import os
import shutil

import psutil
import requests
from IPython import display  # to display images and clear console output

from ultralytics.hub.auth import Auth
from ultralytics.hub.session import HubTrainingSession
from ultralytics.hub.utils import PREFIX, split_key
from ultralytics.yolo.utils import LOGGER, emojis, is_colab
from ultralytics.yolo.utils.torch_utils import select_device
from ultralytics.yolo.v8.detect import DetectionTrainer

# Export formats accepted by export_model() and get_export(); kept in one place so both stay in sync.
EXPORT_FORMATS = ('torchscript', 'onnx', 'openvino', 'engine', 'coreml', 'saved_model', 'pb', 'tflite', 'edgetpu',
                  'tfjs', 'ultralytics_tflite', 'ultralytics_coreml')


def checks(verbose=True):
    """Run environment setup checks and log a one-line summary.

    Args:
        verbose: When true, clear the notebook output and append CPU count, RAM and
            disk usage of "/" to the summary line; otherwise the summary has no details.
    """
    if is_colab():
        shutil.rmtree('sample_data', ignore_errors=True)  # remove colab /sample_data directory

    if verbose:
        # System info
        gib = 1 << 30  # bytes per GiB
        ram = psutil.virtual_memory().total
        total, _used, free = shutil.disk_usage("/")
        display.clear_output()
        s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)'
    else:
        s = ''

    select_device(newline=False)
    LOGGER.info(f'Setup complete ✅ {s}')


def start(key=''):
    """Start training models with Ultralytics HUB.

    Usage: from src.ultralytics import start; start('API_KEY')

    The key is passed through split_key() to obtain an API key and a model id. If the
    resulting Auth object reports no valid state, the user is prompted for a key (up to
    three attempts in total, counting a supplied-but-rejected key as the first).

    Any exception raised during authentication, session setup or training is caught and
    logged as a warning rather than propagated.

    Args:
        key: Combined API-key/model-id string understood by split_key().
    """

    def request_api_key(attempts=0):
        """Prompt the user to input their API key and return the model id from the accepted key.

        Raises:
            ConnectionError: When the attempt limit is reached without authenticating.
        """
        import getpass

        max_attempts = 3
        # Loop form of the retry; each pass is one prompt/authenticate round.
        while True:
            tries = f"Attempt {attempts + 1} of {max_attempts}" if attempts > 0 else ""
            LOGGER.info(f"{PREFIX}Login. {tries}")
            input_key = getpass.getpass("Enter your Ultralytics HUB API key:\n")
            # `auth` is the Auth instance created in the enclosing start() before this helper is called.
            auth.api_key, model_id = split_key(input_key)
            if auth.authenticate():
                return model_id
            attempts += 1
            LOGGER.warning(f"{PREFIX}Invalid API key ⚠️\n")
            if attempts >= max_attempts:
                raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))

    try:
        api_key, model_id = split_key(key)
        auth = Auth(api_key)  # attempts cookie login if no api key is present
        attempts = 1 if key else 0  # a supplied key already counts as one attempt
        if not auth.get_state():
            if key:
                LOGGER.warning(f"{PREFIX}Invalid API key ⚠️\n")
            model_id = request_api_key(attempts)
        LOGGER.info(f"{PREFIX}Authenticated ✅")

        if not model_id:
            raise ConnectionError(emojis('Connecting with global API key is not currently supported. ❌'))

        session = HubTrainingSession(model_id=model_id, auth=auth)
        session.check_disk_space()

        # TODO: refactor, hardcoded for v8
        args = session.model.copy()
        # Drop HUB bookkeeping fields before handing the dict to the trainer as overrides.
        args.pop("id")
        args.pop("status")
        args.pop("weights")
        args["data"] = "coco128.yaml"
        args["model"] = "yolov8n.yaml"
        args["batch_size"] = 16
        args["imgsz"] = 64

        trainer = DetectionTrainer(overrides=args)
        session.register_callbacks(trainer)
        trainer.hub_session = session
        trainer.train()
    except Exception as e:
        # Deliberate best-effort boundary: report the failure instead of raising to the caller.
        LOGGER.warning(f"{PREFIX}{e}")


def reset_model(key=''):
    """Reset a trained model to an untrained state.

    Posts the API key and model id from split_key(key) to the HUB model-reset endpoint
    and logs success (HTTP 200) or a warning with the status code and reason.
    """
    api_key, model_id = split_key(key)
    r = requests.post('https://api.ultralytics.com/model-reset', json={"apiKey": api_key, "modelId": model_id})

    if r.status_code == 200:
        LOGGER.info(f"{PREFIX}model reset successfully")
        return
    LOGGER.warning(f"{PREFIX}model reset failure {r.status_code} {r.reason}")


def _check_export_format(format):
    """Raise AssertionError if `format` is not one of EXPORT_FORMATS."""
    # Raised explicitly (not via `assert`) so the check survives `python -O`, while
    # keeping the AssertionError type existing callers may already handle.
    if format not in EXPORT_FORMATS:
        raise AssertionError(
            f"ERROR: Unsupported export format '{format}' passed, valid formats are {EXPORT_FORMATS}")


def export_model(key='', format='torchscript'):
    """Request an export of a HUB model to the given format.

    Args:
        key: Combined API-key/model-id string understood by split_key().
        format: One of EXPORT_FORMATS.

    Raises:
        AssertionError: If `format` is unsupported or the endpoint does not return HTTP 200.
    """
    api_key, model_id = split_key(key)
    _check_export_format(format)

    r = requests.post('https://api.ultralytics.com/export',
                      json={
                          "apiKey": api_key,
                          "modelId": model_id,
                          "format": format})
    if r.status_code != 200:
        # Explicit raise so a failed request is never reported as started under `python -O`.
        raise AssertionError(f"{PREFIX}{format} export failure {r.status_code} {r.reason}")
    LOGGER.info(f"{PREFIX}{format} export started ✅")


def get_export(key='', format='torchscript'):
    """Get an exported model dictionary with download URL.

    Args:
        key: Combined API-key/model-id string understood by split_key().
        format: One of EXPORT_FORMATS.

    Returns:
        The decoded JSON body of the get-export response.

    Raises:
        AssertionError: If `format` is unsupported or the endpoint does not return HTTP 200.
    """
    api_key, model_id = split_key(key)
    _check_export_format(format)

    r = requests.post('https://api.ultralytics.com/get-export',
                      json={
                          "apiKey": api_key,
                          "modelId": model_id,
                          "format": format})
    if r.status_code != 200:
        # Explicit raise so an error payload is never returned as an export under `python -O`.
        raise AssertionError(f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}")
    return r.json()


# temp. For checking
if __name__ == "__main__":
    # SECURITY: a literal API key/model id used to be committed here. Credentials must not
    # live in source control; supply the key via the environment instead. The variable name
    # is local to this debug entry point. With no variable set, start() receives its default ''.
    start(key=os.environ.get("ULTRALYTICS_HUB_KEY", ""))