Threadpool fixes and CLI improvements (#550)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
import signal
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
|
||||
@ -13,22 +13,6 @@ AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__versio
|
||||
|
||||
session = None
|
||||
|
||||
# Causing problems in tests (non-authenticated)
|
||||
# import signal
|
||||
# import sys
|
||||
# def signal_handler(signum, frame):
|
||||
# """ Confirm exit """
|
||||
# global hub_logger
|
||||
# LOGGER.info(f'Signal received. {signum} {frame}')
|
||||
# if isinstance(session, HubTrainingSession):
|
||||
# hub_logger.alive = False
|
||||
# del hub_logger
|
||||
# sys.exit(signum)
|
||||
#
|
||||
#
|
||||
# signal.signal(signal.SIGTERM, signal_handler)
|
||||
# signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
|
||||
class HubTrainingSession:
|
||||
|
||||
@ -43,10 +27,11 @@ class HubTrainingSession:
|
||||
self.alive = True # for heartbeats
|
||||
self.model = self._get_model()
|
||||
self._heartbeats() # start heartbeats
|
||||
signal.signal(signal.SIGTERM, self.shutdown) # register the shutdown function to be called on exit
|
||||
signal.signal(signal.SIGINT, self.shutdown)
|
||||
|
||||
def __del__(self):
|
||||
# Class destructor
|
||||
self.alive = False
|
||||
def shutdown(self, *args): # noqa
|
||||
self.alive = False # stop heartbeats
|
||||
|
||||
def upload_metrics(self):
|
||||
payload = {"metrics": self.metrics_queue.copy(), "type": "metrics"}
|
||||
@ -100,13 +85,6 @@ class HubTrainingSession:
|
||||
if not check_dataset_disk_space(self.model['data']):
|
||||
raise MemoryError("Not enough disk space")
|
||||
|
||||
# COMMENT: Should not be needed as HUB is now considered an integration and is in integrations_callbacks
|
||||
# import ultralytics.yolo.utils.callbacks.hub as hub_callbacks
|
||||
# @staticmethod
|
||||
# def register_callbacks(trainer):
|
||||
# for k, v in hub_callbacks.callbacks.items():
|
||||
# trainer.add_callback(k, v)
|
||||
|
||||
@threaded
|
||||
def _heartbeats(self):
|
||||
while self.alive:
|
||||
|
@ -4,6 +4,7 @@ import os
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
from random import random
|
||||
|
||||
import requests
|
||||
|
||||
@ -14,7 +15,7 @@ HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/h
|
||||
HUB_API_ROOT = os.environ.get("ULTRALYTICS_HUB_API", "https://api.ultralytics.com")
|
||||
|
||||
|
||||
def check_dataset_disk_space(url='https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip', sf=2.0):
|
||||
def check_dataset_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=2.0):
|
||||
# Check that url fits on disk with safety factor sf, i.e. require 2GB free if url size is 1GB with sf=2.0
|
||||
gib = 1 << 30 # bytes per GiB
|
||||
data = int(requests.head(url).headers['Content-Length']) / gib # dataset size (GB)
|
||||
@ -130,18 +131,18 @@ def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method="post
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@TryExcept()
|
||||
def sync_analytics(cfg, all_keys=False, enabled=False):
|
||||
@TryExcept(verbose=False)
|
||||
def traces(cfg, all_keys=False, traces_sample_rate=0.0):
|
||||
"""
|
||||
Sync analytics data if enabled in the global settings
|
||||
Sync traces data if enabled in the global settings
|
||||
|
||||
Args:
|
||||
cfg (UltralyticsCFG): Configuration for the task and mode.
|
||||
cfg (IterableSimpleNamespace): Configuration for the task and mode.
|
||||
all_keys (bool): Sync all items, not just non-default values.
|
||||
enabled (bool): For debugging.
|
||||
traces_sample_rate (float): Fraction of traces captured from 0.0 to 1.0
|
||||
"""
|
||||
if SETTINGS['sync'] and RANK in {-1, 0} and enabled:
|
||||
cfg = dict(cfg) # convert type from UltralyticsCFG to dict
|
||||
if SETTINGS['sync'] and RANK in {-1, 0} and (random() < traces_sample_rate):
|
||||
cfg = vars(cfg) # convert type from IterableSimpleNamespace to dict
|
||||
if not all_keys:
|
||||
cfg = {k: v for k, v in cfg.items() if v != DEFAULT_CFG_DICT.get(k, None)} # retain non-default values
|
||||
cfg['uuid'] = SETTINGS['uuid'] # add the device UUID to the configuration data
|
||||
|
Reference in New Issue
Block a user