Threadpool fixes and CLI improvements (#550)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Glenn Jocher
2023-01-22 17:08:08 +01:00
committed by GitHub
parent d9a0fba251
commit 21b701c4ea
22 changed files with 338 additions and 251 deletions

View File

@ -8,9 +8,9 @@ import platform
import sys
import tempfile
import threading
import types
import uuid
from pathlib import Path
from types import SimpleNamespace
from typing import Union
import cv2
@ -55,10 +55,34 @@ HELP_MSG = \
3. Use the command line interface (CLI):
yolo task=detect mode=train model=yolov8n.yaml args...
classify predict yolov8n-cls.yaml args...
segment val yolov8n-seg.yaml args...
export yolov8n.pt format=onnx args...
YOLOv8 'yolo' CLI commands use the following syntax:
yolo TASK MODE ARGS
Where TASK (optional) is one of [detect, segment, classify]
MODE (required) is one of [train, val, predict, export]
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
See all ARGS at https://docs.ultralytics.com/cfg or with 'yolo cfg'
- Train a detection model for 10 epochs with an initial learning_rate of 0.01
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
- Predict a YouTube video using a pretrained segmentation model at image size 320:
yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
- Val a pretrained detection model at batch-size 1 and image size 640:
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
- Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
- Run special commands:
yolo help
yolo checks
yolo version
yolo settings
yolo copy-cfg
yolo cfg
Docs: https://docs.ultralytics.com
Community: https://community.ultralytics.com
@ -73,11 +97,24 @@ cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with Py
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training
# Default config dictionary
class IterableSimpleNamespace(SimpleNamespace):
"""
Iterable SimpleNamespace class to allow SimpleNamespace to be used with dict() and in for loops
"""
def __iter__(self):
return iter(vars(self).items())
def __str__(self):
return '\n'.join(f"{k}={v}" for k, v in vars(self).items())
# Default configuration
with open(DEFAULT_CFG_PATH, errors='ignore') as f:
DEFAULT_CFG_DICT = yaml.safe_load(f)
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
DEFAULT_CFG = types.SimpleNamespace(**DEFAULT_CFG_DICT)
DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
def is_colab():
@ -307,14 +344,15 @@ def set_logging(name=LOGGING_NAME, verbose=True):
class TryExcept(contextlib.ContextDecorator):
# YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager
def __init__(self, msg=''):
def __init__(self, msg='', verbose=True):
self.msg = msg
self.verbose = verbose
def __enter__(self):
pass
def __exit__(self, exc_type, value, traceback):
if value:
if self.verbose and value:
print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
return True
@ -366,6 +404,21 @@ def yaml_load(file='data.yaml', append_filename=False):
return {**yaml.safe_load(f), 'yaml_file': str(file)} if append_filename else yaml.safe_load(f)
def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
"""
Pretty prints a yaml file or a yaml-formatted dictionary.
Args:
yaml_file: The file path of the yaml file or a yaml-formatted dictionary.
Returns:
None
"""
yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
dump = yaml.dump(yaml_dict, default_flow_style=False)
LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}")
def set_sentry(dsn=None):
"""
Initialize the Sentry SDK for error tracking and reporting if pytest is not currently running.
@ -379,7 +432,6 @@ def set_sentry(dsn=None):
debug=False,
traces_sample_rate=1.0,
release=ultralytics.__version__,
send_default_pii=True,
environment='production', # 'dev' or 'production'
ignore_errors=[KeyboardInterrupt])
@ -439,17 +491,6 @@ def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'):
yaml_save(file, SETTINGS)
def print_settings():
"""
Function that prints Ultralytics settings
"""
import json
s = f'\n{PREFIX}Settings:\n'
s += json.dumps(SETTINGS, indent=2)
s += f"\n\nUpdate settings at {USER_CONFIG_DIR / 'settings.yaml'}"
LOGGER.info(s)
# Run below code on utils init -----------------------------------------------------------------------------------------
# Set logger

View File

@ -3,7 +3,7 @@
import json
from time import time
from ultralytics.hub.utils import PREFIX, sync_analytics
from ultralytics.hub.utils import PREFIX, traces
from ultralytics.yolo.utils import LOGGER
@ -43,24 +43,24 @@ def on_train_end(trainer):
LOGGER.info(f"{PREFIX}Training completed successfully ✅\n"
f"{PREFIX}Uploading final {session.model_id}")
session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics['metrics/mAP50-95(B)'], final=True)
session.alive = False # stop heartbeats
session.shutdown() # stop heartbeats
LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀")
def on_train_start(trainer):
sync_analytics(trainer.args)
traces(trainer.args, traces_sample_rate=0.0)
def on_val_start(validator):
sync_analytics(validator.args)
traces(validator.args, traces_sample_rate=0.0)
def on_predict_start(predictor):
sync_analytics(predictor.args)
traces(predictor.args, traces_sample_rate=0.0)
def on_export_start(exporter):
sync_analytics(exporter.args)
traces(exporter.args, traces_sample_rate=0.0)
callbacks = {

View File

@ -154,7 +154,7 @@ def check_python(minimum: str = '3.7.0') -> bool:
Returns:
None
"""
check_version(platform.python_version(), minimum, name='Python ', hard=True)
return check_version(platform.python_version(), minimum, name='Python ', hard=True)
@TryExcept()
@ -223,8 +223,10 @@ def check_file(file, suffix=''):
files = []
for d in 'models', 'yolo/data': # search directories
files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
assert len(files), f'File not found: {file}' # assert file was found
assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
if not files:
raise FileNotFoundError(f"{file} does not exist")
elif len(files) > 1:
raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
return files[0] # return file

View File

@ -141,10 +141,14 @@ def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1
dir = Path(dir)
dir.mkdir(parents=True, exist_ok=True) # make directory
if threads > 1:
pool = ThreadPool(threads)
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
pool.close()
pool.join()
# pool = ThreadPool(threads)
# pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
# pool.close()
# pool.join()
with ThreadPool(threads) as pool:
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
pool.close()
pool.join()
else:
for u in [url] if isinstance(url, (str, Path)) else url:
download_one(u, dir)

View File

@ -62,7 +62,9 @@ def select_device(device='', batch_size=0, newline=False):
# device = None or 'cpu' or 0 or '0' or '0,1,2,3'
ver = git_describe() or ultralytics.__version__ # git commit or pip package version
s = f'Ultralytics YOLOv{ver} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
device = str(device).strip().lower().replace('cuda:', '').replace('none', '') # to string, 'cuda:0' to '0'
device = str(device).lower()
for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
device = device.replace(remove, '') # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
cpu = device == 'cpu'
mps = device == 'mps' # Apple Metal Performance Shaders (MPS)
if cpu or mps:
@ -369,3 +371,26 @@ def profile(input, ops, n=10, device=None):
results.append(None)
torch.cuda.empty_cache()
return results
class EarlyStopping:
# early stopper
def __init__(self, patience=30):
self.best_fitness = 0.0 # i.e. mAP
self.best_epoch = 0
self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
self.possible_stop = False # possible stop may occur next epoch
def __call__(self, epoch, fitness):
if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
self.best_epoch = epoch
self.best_fitness = fitness
delta = epoch - self.best_epoch # epochs without improvement
self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
stop = delta >= self.patience # stop training if patience exceeded
if stop:
LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
f'i.e. `patience=300` or use `patience=0` to disable EarlyStopping.')
return stop