Threadpool fixes and CLI improvements (#550)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
@ -1,5 +1,4 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
import argparse
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
@ -9,46 +8,39 @@ from types import SimpleNamespace
|
||||
from typing import Dict, Union
|
||||
|
||||
from ultralytics import __version__, yolo
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, PREFIX, checks, colorstr, print_settings, yaml_load
|
||||
|
||||
DIR = Path(__file__).parent
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, PREFIX, USER_CONFIG_DIR,
|
||||
IterableSimpleNamespace, checks, colorstr, yaml_load, yaml_print)
|
||||
|
||||
CLI_HELP_MSG = \
|
||||
"""
|
||||
YOLOv8 CLI Usage examples:
|
||||
YOLOv8 'yolo' CLI commands use the following syntax:
|
||||
|
||||
1. Install the ultralytics package:
|
||||
yolo TASK MODE ARGS
|
||||
|
||||
pip install ultralytics
|
||||
Where TASK (optional) is one of [detect, segment, classify]
|
||||
MODE (required) is one of [train, val, predict, export]
|
||||
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
|
||||
See all ARGS at https://docs.ultralytics.com/cfg or with 'yolo cfg'
|
||||
|
||||
2. Train, Val, Predict and Export using 'yolo' commands:
|
||||
1. Train a detection model for 10 epochs with an initial learning_rate of 0.01
|
||||
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
||||
|
||||
yolo TASK MODE ARGS
|
||||
2. Predict a YouTube video using a pretrained segmentation model at image size 320:
|
||||
yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
|
||||
|
||||
Where TASK (optional) is one of [detect, segment, classify]
|
||||
MODE (required) is one of [train, val, predict, export]
|
||||
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
|
||||
For a full list of available ARGS see https://docs.ultralytics.com/cfg.
|
||||
3. Val a pretrained detection model at batch-size 1 and image size 640:
|
||||
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
||||
|
||||
Train a detection model for 10 epochs with an initial learning_rate of 0.01
|
||||
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
||||
|
||||
Predict a YouTube video using a pretrained segmentation model at image size 320:
|
||||
yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
|
||||
|
||||
Validate a pretrained detection model at batch-size 1 and image size 640:
|
||||
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
||||
|
||||
Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
||||
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
|
||||
|
||||
3. Run special commands:
|
||||
4. Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
||||
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
|
||||
|
||||
5. Run special commands:
|
||||
yolo help
|
||||
yolo checks
|
||||
yolo version
|
||||
yolo settings
|
||||
yolo copy-cfg
|
||||
yolo cfg
|
||||
|
||||
Docs: https://docs.ultralytics.com/cli
|
||||
Community: https://community.ultralytics.com
|
||||
@ -56,15 +48,6 @@ CLI_HELP_MSG = \
|
||||
"""
|
||||
|
||||
|
||||
class UltralyticsCFG(SimpleNamespace):
|
||||
"""
|
||||
UltralyticsCFG iterable SimpleNamespace class to allow SimpleNamespace to be used with dict() and in for loops
|
||||
"""
|
||||
|
||||
def __iter__(self):
|
||||
return iter(vars(self).items())
|
||||
|
||||
|
||||
def cfg2dict(cfg):
|
||||
"""
|
||||
Convert a configuration object to a dictionary.
|
||||
@ -104,7 +87,7 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace], overrides: Dict = None
|
||||
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
|
||||
|
||||
# Return instance
|
||||
return UltralyticsCFG(**cfg)
|
||||
return IterableSimpleNamespace(**cfg)
|
||||
|
||||
|
||||
def check_cfg_mismatch(base: Dict, custom: Dict):
|
||||
@ -118,12 +101,19 @@ def check_cfg_mismatch(base: Dict, custom: Dict):
|
||||
"""
|
||||
base, custom = (set(x.keys()) for x in (base, custom))
|
||||
mismatched = [x for x in custom if x not in base]
|
||||
for option in mismatched:
|
||||
LOGGER.info(f"{colorstr(option)} is not a valid key. Similar keys: {get_close_matches(option, base, 3, 0.6)}")
|
||||
if mismatched:
|
||||
for x in mismatched:
|
||||
matches = get_close_matches(x, base, 3, 0.6)
|
||||
match_str = f"Similar arguments are {matches}." if matches else 'There are no similar arguments.'
|
||||
LOGGER.warning(f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}")
|
||||
LOGGER.warning(CLI_HELP_MSG)
|
||||
sys.exit()
|
||||
|
||||
|
||||
def argument_error(arg):
|
||||
return SyntaxError(f"'{arg}' is not a valid YOLO argument.\n{CLI_HELP_MSG}")
|
||||
|
||||
|
||||
def entrypoint(debug=False):
|
||||
"""
|
||||
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
|
||||
@ -139,67 +129,61 @@ def entrypoint(debug=False):
|
||||
It uses the package's default cfg and initializes it using the passed overrides.
|
||||
Then it calls the CLI function with the composed cfg
|
||||
"""
|
||||
if debug:
|
||||
args = ['train', 'predict', 'model=yolov8n.pt'] # for testing
|
||||
else:
|
||||
if len(sys.argv) == 1: # no arguments passed
|
||||
LOGGER.info(CLI_HELP_MSG)
|
||||
return
|
||||
|
||||
parser = argparse.ArgumentParser(description='YOLO parser')
|
||||
parser.add_argument('args', type=str, nargs='+', help='YOLO args')
|
||||
args = parser.parse_args().args
|
||||
args = re.sub(r'\s*=\s*', '=', ' '.join(args)).split(' ') # remove whitespaces around = sign
|
||||
args = ['train', 'predict', 'model=yolov8n.pt'] if debug else sys.argv[1:]
|
||||
if not args: # no arguments passed
|
||||
LOGGER.info(CLI_HELP_MSG)
|
||||
return
|
||||
|
||||
tasks = 'detect', 'segment', 'classify'
|
||||
modes = 'train', 'val', 'predict', 'export'
|
||||
special_modes = {
|
||||
special = {
|
||||
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
||||
'checks': checks.check_yolo,
|
||||
'version': lambda: LOGGER.info(__version__),
|
||||
'settings': print_settings,
|
||||
'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'),
|
||||
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
||||
'copy-cfg': copy_default_config}
|
||||
|
||||
overrides = {} # basic overrides, i.e. imgsz=320
|
||||
defaults = yaml_load(DEFAULT_CFG_PATH)
|
||||
for a in args:
|
||||
if '=' in a:
|
||||
if a.startswith('cfg='): # custom.yaml passed
|
||||
custom_config = Path(a.split('=')[-1])
|
||||
LOGGER.info(f"{PREFIX}Overriding {DEFAULT_CFG_PATH} with {custom_config}")
|
||||
overrides = {k: v for k, v in yaml_load(custom_config).items() if k not in {'cfg'}}
|
||||
else:
|
||||
try:
|
||||
re.sub(r' *= *', '=', a) # remove spaces around equals sign
|
||||
k, v = a.split('=')
|
||||
try:
|
||||
if k == 'device': # special DDP handling, i.e. device='0,1,2,3'
|
||||
v = v.replace('[', '').replace(']', '') # handle device=[0,1,2,3]
|
||||
v = v.replace(" ", "") # handle device=[0, 1, 2, 3]
|
||||
v = v.replace('\\', '') # handle device=\'0,1,2,3\'
|
||||
overrides[k] = v
|
||||
else:
|
||||
overrides[k] = eval(v) # convert strings to integers, floats, bools, etc.
|
||||
except (NameError, SyntaxError):
|
||||
if k == 'cfg': # custom.yaml passed
|
||||
LOGGER.info(f"{PREFIX}Overriding {DEFAULT_CFG_PATH} with {v}")
|
||||
overrides = {k: val for k, val in yaml_load(v).items() if k != 'cfg'}
|
||||
else:
|
||||
if v.isnumeric():
|
||||
v = eval(v)
|
||||
elif v.lower() == 'none':
|
||||
v = None
|
||||
elif v.lower() == 'true':
|
||||
v = True
|
||||
elif v.lower() == 'false':
|
||||
v = False
|
||||
elif ',' in v:
|
||||
v = eval(v)
|
||||
overrides[k] = v
|
||||
except (NameError, SyntaxError, ValueError) as e:
|
||||
raise argument_error(a) from e
|
||||
|
||||
elif a in tasks:
|
||||
overrides['task'] = a
|
||||
elif a in modes:
|
||||
overrides['mode'] = a
|
||||
elif a in special_modes:
|
||||
special_modes[a]()
|
||||
elif a in special:
|
||||
special[a]()
|
||||
return
|
||||
elif a in defaults and defaults[a] is False:
|
||||
elif a in DEFAULT_CFG_DICT and DEFAULT_CFG_DICT[a] is False:
|
||||
overrides[a] = True # auto-True for default False args, i.e. 'yolo show' sets show=True
|
||||
elif a in defaults:
|
||||
raise SyntaxError(f"'{a}' is a valid YOLO argument but is missing an '=' sign to set its value, "
|
||||
f"i.e. try '{a}={defaults[a]}'"
|
||||
f"\n{CLI_HELP_MSG}")
|
||||
elif a in DEFAULT_CFG_DICT:
|
||||
raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
|
||||
f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
|
||||
else:
|
||||
raise SyntaxError(
|
||||
f"'{a}' is not a valid YOLO argument. For a full list of valid arguments see "
|
||||
f"https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/configs/default.yaml"
|
||||
f"\n{CLI_HELP_MSG}")
|
||||
raise argument_error(a)
|
||||
|
||||
cfg = get_cfg(defaults, overrides) # create CFG instance
|
||||
cfg = get_cfg(DEFAULT_CFG_DICT, overrides) # create CFG instance
|
||||
|
||||
# Mapping from task to module
|
||||
module = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify}.get(cfg.task)
|
||||
@ -223,8 +207,8 @@ def copy_default_config():
|
||||
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
|
||||
shutil.copy2(DEFAULT_CFG_PATH, new_file)
|
||||
LOGGER.info(f"{PREFIX}{DEFAULT_CFG_PATH} copied to {new_file}\n"
|
||||
f"Usage for running YOLO with this new custom cfg:\nyolo cfg={new_file} args...")
|
||||
f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
entrypoint()
|
||||
entrypoint(debug=True)
|
||||
|
Reference in New Issue
Block a user