# Ultralytics YOLO 🚀, GPL-3.0 license
import signal
import sys
from pathlib import Path
from time import sleep

import requests

from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_request
from ultralytics.yolo.utils import LOGGER, PREFIX, __version__, checks, emojis, is_colab, threaded

# Agent label sent with every heartbeat; distinguishes Colab from local runs.
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'


class HUBTrainingSession:
    """
    Session that links one HUB model to a local training run.

    On construction it fetches the model record from the HUB API, starts the
    heartbeat loop and installs SIGTERM/SIGINT handlers. It also exposes helpers
    to upload queued metrics and model checkpoints.
    """

    def __init__(self, model_id, auth):
        """
        Fetch the model record and start heartbeats.

        Args:
            model_id: Identifier of the HUB model to train.
            auth: Object exposing ``get_auth_header()``; the returned header is
                sent with every request made by this session.

        Raises:
            ConnectionRefusedError: If the HUB server cannot be reached.
            ValueError: If the model record is missing, already trained, or its
                dataset is not available yet.
        """
        self.agent_id = None  # identifies which instance is communicating with server
        self.model_id = model_id
        self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
        self.auth_header = auth.get_auth_header()
        self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0}  # rate limits (seconds)
        self.timers = {}  # rate limit timers (seconds)
        self.metrics_queue = {}  # metrics queue
        self.model = self._get_model()
        self.alive = True  # must be set before the heartbeat loop starts, it is the loop condition
        self._start_heartbeat()  # start heartbeats
        self._register_signal_handlers()

    def _register_signal_handlers(self):
        """Route SIGTERM and SIGINT to ``_handle_signal``."""
        signal.signal(signal.SIGTERM, self._handle_signal)
        signal.signal(signal.SIGINT, self._handle_signal)

    def _handle_signal(self, signum, frame):
        """
        Prevent heartbeats from being sent on Colab after kill.

        This method does not use frame, it is included as it is passed by signal.
        The process exits with ``signum`` as its exit status, but only while the
        session is still alive; a signal received afterwards is ignored here.
        """
        if self.alive:
            LOGGER.info(f'{PREFIX}Kill signal received! ❌')
            self._stop_heartbeat()
            sys.exit(signum)

    def _stop_heartbeat(self):
        """End the heartbeat loop."""
        self.alive = False

    def upload_metrics(self):
        """Post a snapshot (copy) of the queued metrics to the model endpoint."""
        payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'}
        smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2)

    def _get_model(self):
        """
        Fetch the model record from HUB by id and derive the training settings.

        Side effects: sets ``self.model_id`` (from the record), ``self.train_args``
        and ``self.model_file``.

        Returns:
            The ``data`` payload of the HUB response (a dict).

        Raises:
            ConnectionRefusedError: If the request fails with a connection error.
            ValueError: If the response carries no model data, the model is
                already trained, or its dataset entry is empty.
            KeyError: If a required field is missing from the model record.
        """
        api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'

        try:
            response = smart_request('get', api_url, headers=self.auth_header, thread=False, code=0)
            data = response.json().get('data', None)

            # Fail with a readable message instead of an AttributeError on None below.
            if data is None:
                raise ValueError(f'HUB returned no data for model {self.model_id}.')

            if data.get('status', None) == 'trained':
                raise ValueError(
                    emojis(f'Model is already trained and uploaded to '
                           f'https://hub.ultralytics.com/models/{self.model_id} 🚀'))

            if not data.get('data', None):
                raise ValueError('Dataset may still be processing. Please wait a minute and try again.')  # RF fix
            self.model_id = data['id']

            # TODO: restore when server keys when dataset URL and GPU train is working

            self.train_args = {
                'batch': data['batch_size'],
                'epochs': data['epochs'],
                'imgsz': data['imgsz'],
                'patience': data['patience'],
                'device': data['device'],
                'cache': data['cache'],
                'data': data['data']}

            # Prefer 'cfg' and fall back to 'weights'. `or` keeps the fallback lazy (no KeyError
            # on 'weights' when 'cfg' is usable) and also covers a null/empty 'cfg' value.
            self.model_file = data.get('cfg') or data['weights']
            self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u

            return data
        except requests.exceptions.ConnectionError as e:
            raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e

    def check_disk_space(self):
        """Raise MemoryError when the dataset disk-space check fails for this model's dataset."""
        if not check_dataset_disk_space(self.model['data']):
            raise MemoryError('Not enough disk space')

    def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
        """
        Upload a model checkpoint to HUB.

        Args:
            epoch: Epoch number reported with the upload.
            weights: Path to the weights file to upload.
            is_best: Whether this checkpoint is the best so far (epoch uploads only).
            map: Metric value reported with the final upload.
            final: If True, upload as the final model ('best.pt'); otherwise as
                an intermediate epoch checkpoint ('last.pt').
        """
        weights_path = Path(weights)
        if weights_path.is_file():
            file = weights_path.read_bytes()
        else:
            LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')
            file = None  # the request below is still sent, with None as the file value

        url = f'{self.api_url}/upload'
        # url = 'http://httpbin.org/post'  # for debug
        data = {'epoch': epoch}
        if final:
            data.update({'type': 'final', 'map': map})
            smart_request('post',
                          url,
                          data=data,
                          files={'best.pt': file},
                          headers=self.auth_header,
                          retry=10,
                          timeout=3600,
                          thread=False,
                          progress=True,
                          code=4)
        else:
            data.update({'type': 'epoch', 'isBest': bool(is_best)})
            smart_request('post', url, data=data, files={'last.pt': file}, headers=self.auth_header, code=3)

    @threaded
    def _start_heartbeat(self):
        """
        Send a heartbeat for this model until the session is stopped.

        Each iteration posts the agent name and current agent id, stores the
        agent id returned by the server, then sleeps for the heartbeat rate limit.
        """
        while self.alive:
            r = smart_request('post',
                              f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
                              json={
                                  'agent': AGENT_NAME,
                                  'agentId': self.agent_id},
                              headers=self.auth_header,
                              retry=0,
                              code=5,
                              thread=False)  # already in a thread
            # `or {}` also covers a response whose 'data' is present but null.
            self.agent_id = (r.json().get('data') or {}).get('agentId', None)
            sleep(self.rate_limits['heartbeat'])