ultralytics 8.0.44
export and task fixes (#1088)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
@ -15,7 +15,7 @@ EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ['ultralytics_tflite', 'ultralytics_c
|
||||
|
||||
def start(key=''):
|
||||
"""
|
||||
Start training models with Ultralytics HUB. Usage: from src.ultralytics import start; start('API_KEY')
|
||||
Start training models with Ultralytics HUB. Usage: from ultralytics.hub import start; start('API_KEY')
|
||||
"""
|
||||
auth = Auth(key)
|
||||
try:
|
||||
@ -30,9 +30,9 @@ def start(key=''):
|
||||
session = HubTrainingSession(model_id=model_id, auth=auth)
|
||||
session.check_disk_space()
|
||||
|
||||
trainer = YOLO(session.input_file)
|
||||
session.register_callbacks(trainer)
|
||||
trainer.train(**session.train_args)
|
||||
model = YOLO(session.input_file)
|
||||
session.register_callbacks(model)
|
||||
model.train(**session.train_args)
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'{PREFIX}{e}')
|
||||
|
||||
@ -93,6 +93,5 @@ def get_export(key='', format='torchscript'):
|
||||
return r.json()
|
||||
|
||||
|
||||
# temp. For checking
|
||||
if __name__ == '__main__':
|
||||
start()
|
||||
|
@ -26,6 +26,7 @@ class HubTrainingSession:
|
||||
self._timers = {} # rate limit timers (seconds)
|
||||
self._metrics_queue = {} # metrics queue
|
||||
self.model = self._get_model()
|
||||
self.alive = True
|
||||
self._start_heartbeat() # start heartbeats
|
||||
self._register_signal_handlers()
|
||||
|
||||
@ -52,37 +53,6 @@ class HubTrainingSession:
|
||||
payload = {'metrics': self._metrics_queue.copy(), 'type': 'metrics'}
|
||||
smart_request(f'{self.api_url}', json=payload, headers=self.auth_header, code=2)
|
||||
|
||||
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
|
||||
# Upload a model to HUB
|
||||
file = None
|
||||
if Path(weights).is_file():
|
||||
with open(weights, 'rb') as f:
|
||||
file = f.read()
|
||||
if final:
|
||||
smart_request(
|
||||
f'{self.api_url}/upload',
|
||||
data={
|
||||
'epoch': epoch,
|
||||
'type': 'final',
|
||||
'map': map},
|
||||
files={'best.pt': file},
|
||||
headers=self.auth_header,
|
||||
retry=10,
|
||||
timeout=3600,
|
||||
code=4,
|
||||
)
|
||||
else:
|
||||
smart_request(
|
||||
f'{self.api_url}/upload',
|
||||
data={
|
||||
'epoch': epoch,
|
||||
'type': 'epoch',
|
||||
'isBest': bool(is_best)},
|
||||
headers=self.auth_header,
|
||||
files={'last.pt': file},
|
||||
code=3,
|
||||
)
|
||||
|
||||
def _get_model(self):
|
||||
# Returns model from database by id
|
||||
api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
|
||||
@ -151,7 +121,7 @@ class HubTrainingSession:
|
||||
model_info = {
|
||||
'model/parameters': get_num_params(trainer.model),
|
||||
'model/GFLOPs': round(get_flops(trainer.model), 3),
|
||||
'model/speed(ms)': round(trainer.validator.speed[1], 3)}
|
||||
'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
|
||||
all_plots = {**all_plots, **model_info}
|
||||
self._metrics_queue[trainer.epoch] = json.dumps(all_plots)
|
||||
if time() - self._timers['metrics'] > self._rate_limits['metrics']:
|
||||
@ -169,52 +139,45 @@ class HubTrainingSession:
|
||||
|
||||
def on_train_end(self, trainer):
|
||||
# Upload final model and metrics with exponential standoff
|
||||
LOGGER.info(f'{PREFIX}Training completed successfully ✅')
|
||||
LOGGER.info(f'{PREFIX}Uploading final {self.model_id}')
|
||||
LOGGER.info(f'{PREFIX}Training completed successfully ✅\n'
|
||||
f'{PREFIX}Uploading final {self.model_id}')
|
||||
|
||||
# hack for fetching mAP
|
||||
mAP = trainer.metrics.get('metrics/mAP50-95(B)', 0)
|
||||
self._upload_model(trainer.epoch, trainer.best, map=mAP, final=True) # results[3] is mAP0.5:0.95
|
||||
self._upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True)
|
||||
self.alive = False # stop heartbeats
|
||||
LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀')
|
||||
|
||||
def _upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
|
||||
# Upload a model to HUB
|
||||
file = None
|
||||
if Path(weights).is_file():
|
||||
with open(weights, 'rb') as f:
|
||||
file = f.read()
|
||||
file_param = {'best.pt' if final else 'last.pt': file}
|
||||
endpoint = f'{self.api_url}/upload'
|
||||
else:
|
||||
LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload failed. Missing model {weights}.')
|
||||
file = None
|
||||
data = {'epoch': epoch}
|
||||
if final:
|
||||
data.update({'type': 'final', 'map': map})
|
||||
else:
|
||||
data.update({'type': 'epoch', 'isBest': bool(is_best)})
|
||||
|
||||
smart_request(
|
||||
endpoint,
|
||||
data=data,
|
||||
files=file_param,
|
||||
headers=self.auth_header,
|
||||
retry=10 if final else None,
|
||||
timeout=3600 if final else None,
|
||||
code=4 if final else 3,
|
||||
)
|
||||
smart_request(f'{self.api_url}/upload',
|
||||
data=data,
|
||||
files={'best.pt' if final else 'last.pt': file},
|
||||
headers=self.auth_header,
|
||||
retry=10 if final else None,
|
||||
timeout=3600 if final else None,
|
||||
code=4 if final else 3)
|
||||
|
||||
@threaded
|
||||
def _start_heartbeat(self):
|
||||
self.alive = True
|
||||
while self.alive:
|
||||
r = smart_request(
|
||||
f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
|
||||
json={
|
||||
'agent': AGENT_NAME,
|
||||
'agentId': self.agent_id},
|
||||
headers=self.auth_header,
|
||||
retry=0,
|
||||
code=5,
|
||||
thread=False,
|
||||
)
|
||||
r = smart_request(f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
|
||||
json={
|
||||
'agent': AGENT_NAME,
|
||||
'agentId': self.agent_id},
|
||||
headers=self.auth_header,
|
||||
retry=0,
|
||||
code=5,
|
||||
thread=False)
|
||||
self.agent_id = r.json().get('data', {}).get('agentId', None)
|
||||
sleep(self._rate_limits['heartbeat'])
|
||||
|
Reference in New Issue
Block a user