Update .pre-commit-config.yaml (#1026)

Authored by Glenn Jocher on 2023-02-17 22:26:40 +01:00 (committed by GitHub)
parent 9047d737f4
commit edd3ff1669
76 changed files with 928 additions and 935 deletions

View File

@@ -10,10 +10,10 @@ from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import LOGGER, PREFIX, emojis
# Define all export formats
-EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ["ultralytics_tflite", "ultralytics_coreml"]
+EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ['ultralytics_tflite', 'ultralytics_coreml']
-def start(key=""):
+def start(key=''):
"""
Start training models with Ultralytics HUB. Usage: from src.ultralytics import start; start('API_KEY')
"""
@@ -34,7 +34,7 @@ def start(key=""):
session.register_callbacks(trainer)
trainer.train(**session.train_args)
except Exception as e:
LOGGER.warning(f"{PREFIX}{e}")
LOGGER.warning(f'{PREFIX}{e}')
def request_api_key(auth, max_attempts=3):
@@ -43,56 +43,56 @@ def request_api_key(auth, max_attempts=3):
"""
import getpass
for attempts in range(max_attempts):
LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
input_key = getpass.getpass("Enter your Ultralytics HUB API key:\n")
LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}')
input_key = getpass.getpass('Enter your Ultralytics HUB API key:\n')
auth.api_key, model_id = split_key(input_key)
if auth.authenticate():
LOGGER.info(f"{PREFIX}Authenticated ✅")
LOGGER.info(f'{PREFIX}Authenticated ✅')
return model_id
LOGGER.warning(f"{PREFIX}Invalid API key ⚠️\n")
LOGGER.warning(f'{PREFIX}Invalid API key ⚠️\n')
raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))
raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌'))
def reset_model(key=""):
def reset_model(key=''):
# Reset a trained model to an untrained state
api_key, model_id = split_key(key)
r = requests.post("https://api.ultralytics.com/model-reset", json={"apiKey": api_key, "modelId": model_id})
r = requests.post('https://api.ultralytics.com/model-reset', json={'apiKey': api_key, 'modelId': model_id})
if r.status_code == 200:
LOGGER.info(f"{PREFIX}model reset successfully")
LOGGER.info(f'{PREFIX}model reset successfully')
return
LOGGER.warning(f"{PREFIX}model reset failure {r.status_code} {r.reason}")
LOGGER.warning(f'{PREFIX}model reset failure {r.status_code} {r.reason}')
def export_model(key="", format="torchscript"):
def export_model(key='', format='torchscript'):
# Export a model to all formats
assert format in EXPORT_FORMATS_HUB, f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}"
api_key, model_id = split_key(key)
r = requests.post("https://api.ultralytics.com/export",
r = requests.post('https://api.ultralytics.com/export',
json={
"apiKey": api_key,
"modelId": model_id,
"format": format})
assert (r.status_code == 200), f"{PREFIX}{format} export failure {r.status_code} {r.reason}"
LOGGER.info(f"{PREFIX}{format} export started ✅")
'apiKey': api_key,
'modelId': model_id,
'format': format})
assert (r.status_code == 200), f'{PREFIX}{format} export failure {r.status_code} {r.reason}'
LOGGER.info(f'{PREFIX}{format} export started ✅')
def get_export(key="", format="torchscript"):
def get_export(key='', format='torchscript'):
# Get an exported model dictionary with download URL
assert format in EXPORT_FORMATS_HUB, f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}"
api_key, model_id = split_key(key)
r = requests.post("https://api.ultralytics.com/get-export",
r = requests.post('https://api.ultralytics.com/get-export',
json={
"apiKey": api_key,
"modelId": model_id,
"format": format})
assert (r.status_code == 200), f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}"
'apiKey': api_key,
'modelId': model_id,
'format': format})
assert (r.status_code == 200), f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
return r.json()
# temp. For checking
if __name__ == "__main__":
if __name__ == '__main__':
start()
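The helpers above are the public HUB entry points touched by this quote-style change. For orientation, a minimal usage sketch follows; it is not part of the commit, and the combined 'API_KEY_MODEL_ID' key format plus the import path are assumptions inferred from split_key() and the module layout, not facts taken from this diff.

# Illustrative sketch only (assumed import path and key format), not from this commit.
from ultralytics import hub

hub.start('API_KEY_MODEL_ID')                                            # train the HUB model tied to MODEL_ID
hub.export_model('API_KEY_MODEL_ID', format='torchscript')               # queue a server-side export
export_info = hub.get_export('API_KEY_MODEL_ID', format='torchscript')   # dict including the download URL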

View File

@@ -5,7 +5,7 @@ import requests
from ultralytics.hub.utils import HUB_API_ROOT, request_with_credentials
from ultralytics.yolo.utils import is_colab
API_KEY_PATH = "https://hub.ultralytics.com/settings?tab=api+keys"
API_KEY_PATH = 'https://hub.ultralytics.com/settings?tab=api+keys'
class Auth:
@@ -18,7 +18,7 @@ class Auth:
@staticmethod
def _clean_api_key(key: str) -> str:
"""Strip model from key if present"""
separator = "_"
separator = '_'
return key.split(separator)[0] if separator in key else key
def authenticate(self) -> bool:
@@ -26,11 +26,11 @@ class Auth:
try:
header = self.get_auth_header()
if header:
r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
if not r.json().get('success', False):
raise ConnectionError("Unable to authenticate.")
raise ConnectionError('Unable to authenticate.')
return True
raise ConnectionError("User has not authenticated locally.")
raise ConnectionError('User has not authenticated locally.')
except ConnectionError:
self.id_token = self.api_key = False # reset invalid
return False
@@ -43,21 +43,21 @@ class Auth:
if not is_colab():
return False # Currently only works with Colab
try:
authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto")
if authn.get("success", False):
self.id_token = authn.get("data", {}).get("idToken", None)
authn = request_with_credentials(f'{HUB_API_ROOT}/v1/auth/auto')
if authn.get('success', False):
self.id_token = authn.get('data', {}).get('idToken', None)
self.authenticate()
return True
raise ConnectionError("Unable to fetch browser authentication details.")
raise ConnectionError('Unable to fetch browser authentication details.')
except ConnectionError:
self.id_token = False # reset invalid
return False
def get_auth_header(self):
if self.id_token:
return {"authorization": f"Bearer {self.id_token}"}
return {'authorization': f'Bearer {self.id_token}'}
elif self.api_key:
return {"x-api-key": self.api_key}
return {'x-api-key': self.api_key}
else:
return None

View File

@@ -11,7 +11,7 @@ from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_
from ultralytics.yolo.utils import LOGGER, PREFIX, __version__, emojis, is_colab, threaded
from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
AGENT_NAME = f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local"
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
session = None
@@ -20,9 +20,9 @@ class HubTrainingSession:
def __init__(self, model_id, auth):
self.agent_id = None # identifies which instance is communicating with server
self.model_id = model_id
self.api_url = f"{HUB_API_ROOT}/v1/models/{model_id}"
self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
self.auth_header = auth.get_auth_header()
self._rate_limits = {"metrics": 3.0, "ckpt": 900.0, "heartbeat": 300.0} # rate limits (seconds)
self._rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
self._timers = {} # rate limit timers (seconds)
self._metrics_queue = {} # metrics queue
self.model = self._get_model()
@@ -40,7 +40,7 @@ class HubTrainingSession:
passed by signal.
"""
if self.alive is True:
LOGGER.info(f"{PREFIX}Kill signal received! ❌")
LOGGER.info(f'{PREFIX}Kill signal received! ❌')
self._stop_heartbeat()
sys.exit(signum)
@@ -49,23 +49,23 @@
self.alive = False
def upload_metrics(self):
payload = {"metrics": self._metrics_queue.copy(), "type": "metrics"}
smart_request(f"{self.api_url}", json=payload, headers=self.auth_header, code=2)
payload = {'metrics': self._metrics_queue.copy(), 'type': 'metrics'}
smart_request(f'{self.api_url}', json=payload, headers=self.auth_header, code=2)
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
# Upload a model to HUB
file = None
if Path(weights).is_file():
with open(weights, "rb") as f:
with open(weights, 'rb') as f:
file = f.read()
if final:
smart_request(
f"{self.api_url}/upload",
f'{self.api_url}/upload',
data={
"epoch": epoch,
"type": "final",
"map": map},
files={"best.pt": file},
'epoch': epoch,
'type': 'final',
'map': map},
files={'best.pt': file},
headers=self.auth_header,
retry=10,
timeout=3600,
@@ -73,66 +73,66 @@ class HubTrainingSession:
)
else:
smart_request(
f"{self.api_url}/upload",
f'{self.api_url}/upload',
data={
"epoch": epoch,
"type": "epoch",
"isBest": bool(is_best)},
'epoch': epoch,
'type': 'epoch',
'isBest': bool(is_best)},
headers=self.auth_header,
files={"last.pt": file},
files={'last.pt': file},
code=3,
)
def _get_model(self):
# Returns model from database by id
api_url = f"{HUB_API_ROOT}/v1/models/{self.model_id}"
api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
headers = self.auth_header
try:
response = smart_request(api_url, method="get", headers=headers, thread=False, code=0)
data = response.json().get("data", None)
response = smart_request(api_url, method='get', headers=headers, thread=False, code=0)
data = response.json().get('data', None)
if data.get("status", None) == "trained":
if data.get('status', None) == 'trained':
raise ValueError(
emojis(f"Model is already trained and uploaded to "
f"https://hub.ultralytics.com/models/{self.model_id} 🚀"))
emojis(f'Model is already trained and uploaded to '
f'https://hub.ultralytics.com/models/{self.model_id} 🚀'))
if not data.get("data", None):
raise ValueError("Dataset may still be processing. Please wait a minute and try again.") # RF fix
self.model_id = data["id"]
if not data.get('data', None):
raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
self.model_id = data['id']
# TODO: restore when server keys when dataset URL and GPU train is working
self.train_args = {
"batch": data["batch_size"],
"epochs": data["epochs"],
"imgsz": data["imgsz"],
"patience": data["patience"],
"device": data["device"],
"cache": data["cache"],
"data": data["data"]}
'batch': data['batch_size'],
'epochs': data['epochs'],
'imgsz': data['imgsz'],
'patience': data['patience'],
'device': data['device'],
'cache': data['cache'],
'data': data['data']}
self.input_file = data.get("cfg", data["weights"])
self.input_file = data.get('cfg', data['weights'])
# hack for yolov5 cfg adds u
if "cfg" in data and "yolov5" in data["cfg"]:
self.input_file = data["cfg"].replace(".yaml", "u.yaml")
if 'cfg' in data and 'yolov5' in data['cfg']:
self.input_file = data['cfg'].replace('.yaml', 'u.yaml')
return data
except requests.exceptions.ConnectionError as e:
raise ConnectionRefusedError("ERROR: The HUB server is not online. Please try again later.") from e
raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
except Exception:
raise
def check_disk_space(self):
if not check_dataset_disk_space(self.model["data"]):
raise MemoryError("Not enough disk space")
if not check_dataset_disk_space(self.model['data']):
raise MemoryError('Not enough disk space')
def register_callbacks(self, trainer):
trainer.add_callback("on_pretrain_routine_end", self.on_pretrain_routine_end)
trainer.add_callback("on_fit_epoch_end", self.on_fit_epoch_end)
trainer.add_callback("on_model_save", self.on_model_save)
trainer.add_callback("on_train_end", self.on_train_end)
trainer.add_callback('on_pretrain_routine_end', self.on_pretrain_routine_end)
trainer.add_callback('on_fit_epoch_end', self.on_fit_epoch_end)
trainer.add_callback('on_model_save', self.on_model_save)
trainer.add_callback('on_train_end', self.on_train_end)
def on_pretrain_routine_end(self, trainer):
"""
@@ -140,57 +140,57 @@ class HubTrainingSession:
This method does not use trainer. It is passed to all callbacks by default.
"""
# Start timer for upload rate limit
LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀")
self._timers = {"metrics": time(), "ckpt": time()} # start timer on self.rate_limit
LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀')
self._timers = {'metrics': time(), 'ckpt': time()} # start timer on self.rate_limit
def on_fit_epoch_end(self, trainer):
# Upload metrics after val end
all_plots = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics}
all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
if trainer.epoch == 0:
model_info = {
"model/parameters": get_num_params(trainer.model),
"model/GFLOPs": round(get_flops(trainer.model), 3),
"model/speed(ms)": round(trainer.validator.speed[1], 3)}
'model/parameters': get_num_params(trainer.model),
'model/GFLOPs': round(get_flops(trainer.model), 3),
'model/speed(ms)': round(trainer.validator.speed[1], 3)}
all_plots = {**all_plots, **model_info}
self._metrics_queue[trainer.epoch] = json.dumps(all_plots)
if time() - self._timers["metrics"] > self._rate_limits["metrics"]:
if time() - self._timers['metrics'] > self._rate_limits['metrics']:
self.upload_metrics()
self._timers["metrics"] = time() # reset timer
self._timers['metrics'] = time() # reset timer
self._metrics_queue = {} # reset queue
def on_model_save(self, trainer):
# Upload checkpoints with rate limiting
is_best = trainer.best_fitness == trainer.fitness
if time() - self._timers["ckpt"] > self._rate_limits["ckpt"]:
LOGGER.info(f"{PREFIX}Uploading checkpoint {self.model_id}")
if time() - self._timers['ckpt'] > self._rate_limits['ckpt']:
LOGGER.info(f'{PREFIX}Uploading checkpoint {self.model_id}')
self._upload_model(trainer.epoch, trainer.last, is_best)
self._timers["ckpt"] = time() # reset timer
self._timers['ckpt'] = time() # reset timer
def on_train_end(self, trainer):
# Upload final model and metrics with exponential standoff
LOGGER.info(f"{PREFIX}Training completed successfully ✅")
LOGGER.info(f"{PREFIX}Uploading final {self.model_id}")
LOGGER.info(f'{PREFIX}Training completed successfully ✅')
LOGGER.info(f'{PREFIX}Uploading final {self.model_id}')
# hack for fetching mAP
mAP = trainer.metrics.get("metrics/mAP50-95(B)", 0)
mAP = trainer.metrics.get('metrics/mAP50-95(B)', 0)
self._upload_model(trainer.epoch, trainer.best, map=mAP, final=True) # results[3] is mAP0.5:0.95
self.alive = False # stop heartbeats
LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀")
LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀')
def _upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
# Upload a model to HUB
file = None
if Path(weights).is_file():
with open(weights, "rb") as f:
with open(weights, 'rb') as f:
file = f.read()
file_param = {"best.pt" if final else "last.pt": file}
endpoint = f"{self.api_url}/upload"
data = {"epoch": epoch}
file_param = {'best.pt' if final else 'last.pt': file}
endpoint = f'{self.api_url}/upload'
data = {'epoch': epoch}
if final:
data.update({"type": "final", "map": map})
data.update({'type': 'final', 'map': map})
else:
data.update({"type": "epoch", "isBest": bool(is_best)})
data.update({'type': 'epoch', 'isBest': bool(is_best)})
smart_request(
endpoint,
@@ -207,14 +207,14 @@ class HubTrainingSession:
self.alive = True
while self.alive:
r = smart_request(
f"{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}",
f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
json={
"agent": AGENT_NAME,
"agentId": self.agent_id},
'agent': AGENT_NAME,
'agentId': self.agent_id},
headers=self.auth_header,
retry=0,
code=5,
thread=False,
)
self.agent_id = r.json().get("data", {}).get("agentId", None)
sleep(self._rate_limits["heartbeat"])
self.agent_id = r.json().get('data', {}).get('agentId', None)
sleep(self._rate_limits['heartbeat'])
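The on_fit_epoch_end, on_model_save and heartbeat methods above all share one client-side rate-limiting pattern: a per-key timer compared against _rate_limits before each upload. A stripped-down sketch of that pattern follows, with illustrative names; it is not the HubTrainingSession implementation and is not part of this commit.

# Illustrative sketch of the timer-based rate limiting used above, not from this commit.
from time import time

rate_limits = {'metrics': 3.0, 'ckpt': 900.0}  # minimum seconds between uploads
timers = {'metrics': time(), 'ckpt': time()}   # started once training begins

def maybe_upload(kind, upload_fn):
    # Upload only when the per-kind interval has elapsed, then restart that timer.
    if time() - timers[kind] > rate_limits[kind]:
        upload_fn()
        timers[kind] = time()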

View File

@@ -18,14 +18,14 @@ from ultralytics.yolo.utils.checks import check_online
PREFIX = colorstr('Ultralytics: ')
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
HUB_API_ROOT = os.environ.get("ULTRALYTICS_HUB_API", "https://api.ultralytics.com")
HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.com')
def check_dataset_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=2.0):
# Check that url fits on disk with safety factor sf, i.e. require 2GB free if url size is 1GB with sf=2.0
gib = 1 << 30 # bytes per GiB
data = int(requests.head(url).headers['Content-Length']) / gib # dataset size (GB)
-total, used, free = (x / gib for x in shutil.disk_usage("/")) # bytes
+total, used, free = (x / gib for x in shutil.disk_usage('/')) # bytes
LOGGER.info(f'{PREFIX}{data:.3f} GB dataset, {free:.1f}/{total:.1f} GB free disk space')
if data * sf < free:
return True # sufficient space
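A worked instance of the safety-factor check above (illustrative numbers, not from this commit): a 1.0 GiB Content-Length with sf=2.0 requires at least 2.0 GiB free before the function reports sufficient space.

# Illustrative numbers only, not from this commit.
data, sf, free = 1.0, 2.0, 3.0  # GiB: dataset size, safety factor, free disk
assert data * sf < free         # 2.0 GiB needed < 3.0 GiB free -> sufficient space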
@@ -57,7 +57,7 @@ def request_with_credentials(url: str) -> any:
});
});
""" % url))
return output.eval_js("_hub_tmp")
return output.eval_js('_hub_tmp')
# Deprecated TODO: eliminate this function?
@@ -84,7 +84,7 @@ def split_key(key=''):
return api_key, model_id
-def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method="post", verbose=True, **kwargs):
+def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method='post', verbose=True, **kwargs):
"""
Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
@@ -128,7 +128,7 @@ def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method="post
m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \
f"Please retry after {h['Retry-After']}s."
if verbose:
LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})")
LOGGER.warning(f'{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})')
if r.status_code not in retry_codes:
return r
time.sleep(2 ** i) # exponential standoff
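The hunk above shows only a fragment of smart_request's retry handling; the overall behaviour is an exponential-backoff loop. The sketch below is a self-contained illustration, with the retryable status codes assumed rather than taken from this diff, and is not part of the commit.

# Illustrative exponential-backoff sketch, not the smart_request implementation itself.
import time
import requests

def post_with_backoff(url, retry=3, retry_codes=(408, 500), **kwargs):
    r = None
    for i in range(retry + 1):
        r = requests.post(url, **kwargs)
        if r.status_code not in retry_codes:
            break             # success or a non-retryable status
        time.sleep(2 ** i)    # exponential backoff: 1 s, 2 s, 4 s, ...
    return r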
@@ -149,17 +149,17 @@ class Traces:
self.rate_limit = 3.0 # rate limit (seconds)
self.t = 0.0 # rate limit timer (seconds)
self.metadata = {
"sys_argv_name": Path(sys.argv[0]).name,
"install": 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
"python": platform.python_version(),
"release": __version__,
"environment": ENVIRONMENT}
'sys_argv_name': Path(sys.argv[0]).name,
'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
'python': platform.python_version(),
'release': __version__,
'environment': ENVIRONMENT}
self.enabled = SETTINGS['sync'] and \
RANK in {-1, 0} and \
check_online() and \
not is_pytest_running() and \
not is_github_actions_ci() and \
-(is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
+(is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')
def __call__(self, cfg, all_keys=False, traces_sample_rate=1.0):
"""