HUB setup (#108)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Ayush Chaurasia
2023-01-02 00:51:14 +05:30
committed by GitHub
parent c6eb6720de
commit 2bc9a5c87e
16 changed files with 631 additions and 122 deletions

View File

@ -0,0 +1,131 @@
import os
import shutil
import psutil
import requests
from IPython import display # to display images and clear console output
from ultralytics.hub.auth import Auth
from ultralytics.hub.session import HubTrainingSession
from ultralytics.hub.utils import PREFIX, split_key
from ultralytics.yolo.utils import LOGGER, emojis, is_colab
from ultralytics.yolo.utils.torch_utils import select_device
from ultralytics.yolo.v8.detect import DetectionTrainer
def checks(verbose=True):
if is_colab():
shutil.rmtree('sample_data', ignore_errors=True) # remove colab /sample_data directory
if verbose:
# System info
gib = 1 << 30 # bytes per GiB
ram = psutil.virtual_memory().total
total, used, free = shutil.disk_usage("/")
display.clear_output()
s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)'
else:
s = ''
select_device(newline=False)
LOGGER.info(f'Setup complete ✅ {s}')
def start(key=''):
# Start training models with Ultralytics HUB. Usage: from src.ultralytics import start; start('API_KEY')
def request_api_key(attempts=0):
"""Prompt the user to input their API key"""
import getpass
max_attempts = 3
tries = f"Attempt {str(attempts + 1)} of {max_attempts}" if attempts > 0 else ""
LOGGER.info(f"{PREFIX}Login. {tries}")
input_key = getpass.getpass("Enter your Ultralytics HUB API key:\n")
auth.api_key, model_id = split_key(input_key)
if not auth.authenticate():
attempts += 1
LOGGER.warning(f"{PREFIX}Invalid API key ⚠️\n")
if attempts < max_attempts:
return request_api_key(attempts)
raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))
else:
return model_id
try:
api_key, model_id = split_key(key)
auth = Auth(api_key) # attempts cookie login if no api key is present
attempts = 1 if len(key) else 0
if not auth.get_state():
if len(key):
LOGGER.warning(f"{PREFIX}Invalid API key ⚠️\n")
model_id = request_api_key(attempts)
LOGGER.info(f"{PREFIX}Authenticated ✅")
if not model_id:
raise ConnectionError(emojis('Connecting with global API key is not currently supported. ❌'))
session = HubTrainingSession(model_id=model_id, auth=auth)
session.check_disk_space()
# TODO: refactor, hardcoded for v8
args = session.model.copy()
args.pop("id")
args.pop("status")
args.pop("weights")
args["data"] = "coco128.yaml"
args["model"] = "yolov8n.yaml"
args["batch_size"] = 16
args["imgsz"] = 64
trainer = DetectionTrainer(overrides=args)
session.register_callbacks(trainer)
setattr(trainer, 'hub_session', session)
trainer.train()
except Exception as e:
LOGGER.warning(f"{PREFIX}{e}")
def reset_model(key=''):
# Reset a trained model to an untrained state
api_key, model_id = split_key(key)
r = requests.post('https://api.ultralytics.com/model-reset', json={"apiKey": api_key, "modelId": model_id})
if r.status_code == 200:
LOGGER.info(f"{PREFIX}model reset successfully")
return
LOGGER.warning(f"{PREFIX}model reset failure {r.status_code} {r.reason}")
def export_model(key='', format='torchscript'):
# Export a model to all formats
api_key, model_id = split_key(key)
formats = ('torchscript', 'onnx', 'openvino', 'engine', 'coreml', 'saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs',
'ultralytics_tflite', 'ultralytics_coreml')
assert format in formats, f"ERROR: Unsupported export format '{format}' passed, valid formats are {formats}"
r = requests.post('https://api.ultralytics.com/export',
json={
"apiKey": api_key,
"modelId": model_id,
"format": format})
assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}"
LOGGER.info(f"{PREFIX}{format} export started ✅")
def get_export(key='', format='torchscript'):
# Get an exported model dictionary with download URL
api_key, model_id = split_key(key)
formats = ('torchscript', 'onnx', 'openvino', 'engine', 'coreml', 'saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs',
'ultralytics_tflite', 'ultralytics_coreml')
assert format in formats, f"ERROR: Unsupported export format '{format}' passed, valid formats are {formats}"
r = requests.post('https://api.ultralytics.com/get-export',
json={
"apiKey": api_key,
"modelId": model_id,
"format": format})
assert r.status_code == 200, f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}"
return r.json()
# temp. For checking
if __name__ == "__main__":
start(key="b3fba421be84a20dbe68644e14436d1cce1b0a0aaa_HeMfHgvHsseMPhdq7Ylz")

69
ultralytics/hub/auth.py Normal file
View File

@ -0,0 +1,69 @@
import requests
from ultralytics.hub.config import HUB_API_ROOT
from ultralytics.hub.utils import request_with_credentials
from ultralytics.yolo.utils import is_colab
API_KEY_PATH = "https://hub.ultralytics.com/settings?tab=api+keys"
class Auth:
id_token = api_key = model_key = False
def __init__(self, api_key=None):
self.api_key = self._clean_api_key(api_key)
self.authenticate() if self.api_key else self.auth_with_cookies()
@staticmethod
def _clean_api_key(key: str) -> str:
"""Strip model from key if present"""
separator = "_"
return key.split(separator)[0] if separator in key else key
def authenticate(self) -> bool:
"""Attempt to authenticate with server"""
try:
header = self.get_auth_header()
if header:
r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
if not r.json().get('success', False):
raise ConnectionError("Unable to authenticate.")
return True
raise ConnectionError("User has not authenticated locally.")
except ConnectionError:
self.id_token = self.api_key = False # reset invalid
return False
def auth_with_cookies(self) -> bool:
"""
Attempt to fetch authentication via cookies and set id_token.
User must be logged in to HUB and running in a supported browser.
"""
if not is_colab():
return False # Currently only works with Colab
try:
authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto")
if authn.get("success", False):
self.id_token = authn.get("data", {}).get("idToken", None)
self.authenticate()
return True
raise ConnectionError("Unable to fetch browser authentication details.")
except ConnectionError:
self.id_token = False # reset invalid
return False
def get_auth_header(self):
if self.id_token:
return {"authorization": f"Bearer {self.id_token}"}
elif self.api_key:
return {"x-api-key": self.api_key}
else:
return None
def get_state(self) -> bool:
"""Get the authentication state"""
return self.id_token or self.api_key
def set_api_key(self, key: str):
"""Get the authentication state"""
self.api_key = key

12
ultralytics/hub/config.py Normal file
View File

@ -0,0 +1,12 @@
import os
# Global variables
REPO_URL = "https://github.com/ultralytics/yolov5.git"
REPO_BRANCH = "ultralytics/HUB" # "master"
ENVIRONMENT = os.environ.get("ULTRALYTICS_ENV", "production")
if ENVIRONMENT == 'production':
HUB_API_ROOT = "https://api.ultralytics.com"
else:
HUB_API_ROOT = "http://127.0.0.1:8000"
print(f'Connected to development server on {HUB_API_ROOT}')

121
ultralytics/hub/session.py Normal file
View File

@ -0,0 +1,121 @@
import signal
import sys
from pathlib import Path
from time import sleep
import requests
from ultralytics import __version__
from ultralytics.hub.config import HUB_API_ROOT
from ultralytics.hub.utils import check_dataset_disk_space, smart_request
from ultralytics.yolo.utils import LOGGER, is_colab, threaded
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
session = None
def signal_handler(signum, frame):
""" Confirm exit """
global hub_logger
LOGGER.info(f'Signal received. {signum} {frame}')
if isinstance(session, HubTrainingSession):
hub_logger.alive = False
del hub_logger
sys.exit(signum)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
class HubTrainingSession:
def __init__(self, model_id, auth):
self.agent_id = None # identifies which instance is communicating with server
self.model_id = model_id
self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
self.auth_header = auth.get_auth_header()
self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
self.t = {} # rate limit timers (seconds)
self.metrics_queue = {} # metrics queue
self.alive = True # for heartbeats
self.model = self._get_model()
self._heartbeats() # start heartbeats
def __del__(self):
# Class destructor
self.alive = False
def upload_metrics(self):
payload = {"metrics": self.metrics_queue.copy(), "type": "metrics"}
smart_request(f'{self.api_url}', json=payload, headers=self.auth_header, code=2)
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
# Upload a model to HUB
file = None
if Path(weights).is_file():
with open(weights, "rb") as f:
file = f.read()
if final:
smart_request(f'{self.api_url}/upload',
data={
"epoch": epoch,
"type": "final",
"map": map},
files={"best.pt": file},
headers=self.auth_header,
retry=10,
timeout=3600,
code=4)
else:
smart_request(f'{self.api_url}/upload',
data={
"epoch": epoch,
"type": "epoch",
"isBest": bool(is_best)},
headers=self.auth_header,
files={"last.pt": file},
code=3)
def _get_model(self):
# Returns model from database by id
api_url = f"{HUB_API_ROOT}/v1/models/{self.model_id}"
headers = self.auth_header
try:
r = smart_request(api_url, method="get", headers=headers, thread=False, code=0)
data = r.json().get("data", None)
if not data:
return
assert data['data'], 'ERROR: Dataset may still be processing. Please wait a minute and try again.' # RF fix
self.model_id = data["id"]
return data
except requests.exceptions.ConnectionError as e:
raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
def check_disk_space(self):
if not check_dataset_disk_space(self.model['data']):
raise MemoryError("Not enough disk space")
# COMMENT: Should not be needed as HUB is now considered an integration and is in integrations_callbacks
# import ultralytics.yolo.utils.callbacks.hub as hub_callbacks
# @staticmethod
# def register_callbacks(trainer):
# for k, v in hub_callbacks.callbacks.items():
# trainer.add_callback(k, v)
@threaded
def _heartbeats(self):
while self.alive:
r = smart_request(f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
json={
"agent": AGENT_NAME,
"agentId": self.agent_id},
headers=self.auth_header,
retry=0,
code=5,
thread=False)
self.agent_id = r.json().get('data', {}).get('agentId', None)
sleep(self.rate_limits['heartbeat'])

139
ultralytics/hub/utils.py Normal file
View File

@ -0,0 +1,139 @@
import shutil
import threading
import time
import uuid
import requests
from ultralytics.hub.config import HUB_API_ROOT
from ultralytics.yolo.utils import LOGGER, RANK, SETTINGS, colorstr, emojis
PREFIX = colorstr('Ultralytics: ')
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
def check_dataset_disk_space(url='https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip', sf=2.0):
# Check that url fits on disk with safety factor sf, i.e. require 2GB free if url size is 1GB with sf=2.0
gib = 1 << 30 # bytes per GiB
data = int(requests.head(url).headers['Content-Length']) / gib # dataset size (GB)
total, used, free = (x / gib for x in shutil.disk_usage("/")) # bytes
LOGGER.info(f'{PREFIX}{data:.3f} GB dataset, {free:.1f}/{total:.1f} GB free disk space')
if data * sf < free:
return True # sufficient space
LOGGER.warning(f'{PREFIX}WARNING: Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, '
f'training cancelled ❌. Please free {data * sf - free:.1f} GB additional disk space and try again.')
return False # insufficient space
def request_with_credentials(url: str) -> any:
""" Make a ajax request with cookies attached """
from google.colab import output # noqa
from IPython import display # noqa
display.display(
display.Javascript("""
window._hub_tmp = new Promise((resolve, reject) => {
const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
fetch("%s", {
method: 'POST',
credentials: 'include'
})
.then((response) => resolve(response.json()))
.then((json) => {
clearTimeout(timeout);
}).catch((err) => {
clearTimeout(timeout);
reject(err);
});
});
""" % url))
return output.eval_js("_hub_tmp")
# Deprecated TODO: eliminate this function?
def split_key(key: str = '') -> tuple[str, str]:
"""
Verify and split a 'api_key[sep]model_id' string, sep is one of '.' or '_'
Args:
key (str): The model key to split. If not provided, the user will be prompted to enter it.
Returns:
Tuple[str, str]: A tuple containing the API key and model ID.
"""
import getpass
error_string = emojis(f'{PREFIX}Invalid API key ⚠️\n') # error string
if not key:
key = getpass.getpass('Enter model key: ')
sep = '_' if '_' in key else '.' if '.' in key else None # separator
assert sep, error_string
api_key, model_id = key.split(sep)
assert len(api_key) and len(model_id), error_string
return api_key, model_id
def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method="post", **kwargs):
"""
Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
Args:
*args: Positional arguments to be passed to the requests function specified in method.
retry (int, optional): Number of retries to attempt before giving up. Default is 3.
timeout (int, optional): Timeout in seconds after which the function will give up retrying. Default is 30.
thread (bool, optional): Whether to execute the request in a separate daemon thread. Default is True.
code (int, optional): An identifier for the request, used for logging purposes. Default is -1.
method (str, optional): The HTTP method to use for the request. Choices are 'post' and 'get'. Default is 'post'.
**kwargs: Keyword arguments to be passed to the requests function specified in method.
Returns:
requests.Response: The HTTP response object. If the request is executed in a separate thread, returns None.
"""
retry_codes = (408, 500) # retry only these codes
methods = {'post': requests.post, 'get': requests.get} # request methods
def fcn(*args, **kwargs):
t0 = time.time()
for i in range(retry + 1):
if (time.time() - t0) > timeout:
break
r = methods[method](*args, **kwargs) # i.e. post(url, data, json, files)
if r.status_code == 200:
break
try:
m = r.json().get('message', 'No JSON message.')
except Exception:
m = 'Unable to read JSON.'
if i == 0:
if r.status_code in retry_codes:
m += f' Retrying {retry}x for {timeout}s.' if retry else ''
elif r.status_code == 429: # rate limit
h = r.headers # response headers
m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \
f"Please retry after {h['Retry-After']}s."
LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})")
if r.status_code not in retry_codes:
return r
time.sleep(2 ** i) # exponential standoff
return r
if thread:
threading.Thread(target=fcn, args=args, kwargs=kwargs, daemon=True).start()
else:
return fcn(*args, **kwargs)
def sync_analytics(cfg, enabled=False):
"""
Sync analytics data if enabled in the global settings
Args:
cfg (DictConfig): Configuration for the task and mode.
enabled (bool): For debugging.
"""
if SETTINGS['sync'] and RANK in {-1, 0} and enabled:
cfg = dict(cfg) # convert type from DictConfig to dict
cfg['uuid'] = uuid.getnode() # add the device UUID to the configuration data
# Send a request to the HUB API to sync the analytics data
smart_request(f'{HUB_API_ROOT}/analytics', data=cfg, headers=None, code=3, retry=0)