CLI Simplification (#449)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-01-17 23:00:33 +01:00
committed by GitHub
parent 454191bd4b
commit 453b5f259a
16 changed files with 218 additions and 132 deletions

View File

@ -1,18 +1,17 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import argparse
import shutil
from pathlib import Path
import hydra
from hydra import compose, initialize
from ultralytics import hub, yolo
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr
from ultralytics.yolo.utils import DEFAULT_CONFIG, HELP_MSG, LOGGER, PREFIX, print_settings, yaml_load
DIR = Path(__file__).parent
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent.relative_to(DIR)), config_name=DEFAULT_CONFIG.name)
def cli(cfg):
"""
Run a specified task and mode with the given configuration.
@ -21,21 +20,13 @@ def cli(cfg):
cfg (DictConfig): Configuration for the task and mode.
"""
# LOGGER.info(f"{colorstr(f'Ultralytics YOLO v{ultralytics.__version__}')}")
from ultralytics.yolo.configs import get_config
if cfg.cfg:
LOGGER.info(f"Overriding default config with {cfg.cfg}")
LOGGER.info(f"{PREFIX}Overriding default config with {cfg.cfg}")
cfg = get_config(cfg.cfg)
task, mode = cfg.task.lower(), cfg.mode.lower()
# Special case for initializing the configuration
if task == "init":
shutil.copy2(DEFAULT_CONFIG, Path.cwd())
LOGGER.info(f"""
{colorstr("YOLO:")} configuration saved to {Path.cwd() / DEFAULT_CONFIG.name}.
To run experiments using custom configuration:
yolo cfg=config_file.yaml
""")
return
# Mapping from task to module
task_module_map = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify}
module = task_module_map.get(task)
@ -47,10 +38,66 @@ def cli(cfg):
"train": module.train,
"val": module.val,
"predict": module.predict,
"export": yolo.engine.exporter.export,
"checks": hub.checks}
"export": yolo.engine.exporter.export}
func = mode_func_map.get(mode)
if not func:
raise SyntaxError(f"mode not recognized. Choices are {', '.join(mode_func_map.keys())}")
func(cfg)
def entrypoint():
"""
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
to the package. It's a combination of argparse and hydra.
This function allows for:
- passing mandatory YOLO args as a list of strings
- specifying the task to be performed, either 'detect', 'segment' or 'classify'
- specifying the mode, either 'train', 'val', 'test', or 'predict'
- running special modes like 'checks'
- passing overrides to the package's configuration
It uses the package's default config and initializes it using the passed overrides.
Then it calls the CLI function with the composed config
"""
parser = argparse.ArgumentParser(description='YOLO parser')
parser.add_argument('args', type=str, nargs='+', help='YOLO args')
args = parser.parse_args().args
tasks = 'detect', 'segment', 'classify'
modes = 'train', 'val', 'predict', 'export'
special_modes = {
'checks': hub.checks,
'help': lambda: LOGGER.info(HELP_MSG),
'settings': print_settings,
'copy-config': copy_default_config}
overrides = [] # basic overrides, i.e. imgsz=320
defaults = yaml_load(DEFAULT_CONFIG)
for a in args:
if '=' in a:
overrides.append(a)
elif a in tasks:
overrides.append(f'task={a}')
elif a in modes:
overrides.append(f'mode={a}')
elif a in special_modes:
special_modes[a]()
return
elif a in defaults and defaults[a] is False:
overrides.append(f'{a}=True') # auto-True for default False args, i.e. yolo show
else:
raise (SyntaxError(f"'{a}' is not a valid yolo argument\n{HELP_MSG}"))
with initialize(version_base=None, config_path=str(DEFAULT_CONFIG.parent.relative_to(DIR)), job_name="YOLO"):
cfg = compose(config_name=DEFAULT_CONFIG.name, overrides=overrides)
cli(cfg)
# Special modes --------------------------------------------------------------------------------------------------------
def copy_default_config():
new_file = Path.cwd() / DEFAULT_CONFIG.name.replace('.yaml', '_copy.yaml')
shutil.copy2(DEFAULT_CONFIG, new_file)
LOGGER.info(f"{PREFIX}{DEFAULT_CONFIG} copied to {new_file}\n"
f"Usage for running YOLO with this new custom config:\nyolo cfg={new_file} args...")

View File

@ -160,7 +160,7 @@ class BasePredictor:
return list(chain(*list(self.stream_inference(source, model, verbose)))) # merge list of Result into one
def predict_cli(self):
# Method used for cli prediction. It uses always generator as outputs as not required by cli mode
# Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode
gen = self.stream_inference(verbose=True)
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
pass

View File

@ -10,6 +10,7 @@ import tempfile
import threading
import uuid
from pathlib import Path
from typing import Union
import cv2
import git
@ -41,12 +42,15 @@ HELP_MSG = \
from ultralytics import YOLO
model = YOLO('yolov8n.yaml') # build a new model from scratch
model = YOLO('yolov8n.pt') # load a pretrained model (recommended for best training results)
results = model.train(data='coco128.yaml') # train the model
results = model.val() # evaluate model performance on the validation set
results = model.predict(source='bus.jpg') # predict on an image
success = model.export(format='onnx') # export the model to ONNX format
# Load a model
model = YOLO("yolov8n.yaml") # build a new model from scratch
model = YOLO("yolov8n.pt") # load a pretrained model (recommended for training)
# Use the model
results = model.train(data="coco128.yaml", epochs=3) # train the model
results = model.val() # evaluate model performance on the validation set
results = model("https://ultralytics.com/images/bus.jpg") # predict on an image
success = model.export(format="onnx") # export the model to ONNX format
3. Use the command line interface (CLI):
@ -161,12 +165,12 @@ def is_pip_package(filepath: str = __name__) -> bool:
return spec is not None and spec.origin is not None
def is_dir_writeable(dir_path: str) -> bool:
def is_dir_writeable(dir_path: Union[str, Path]) -> bool:
"""
Check if a directory is writeable.
Args:
dir_path (str): The path to the directory.
dir_path (str) or (Path): The path to the directory.
Returns:
bool: True if the directory is writeable, False otherwise.
@ -179,6 +183,18 @@ def is_dir_writeable(dir_path: str) -> bool:
return False
def is_pytest_running():
"""
Returns a boolean indicating if pytest is currently running or not
:return: True if pytest is running, False otherwise
"""
try:
import sys
return "pytest" in sys.modules
except ImportError:
return False
def get_git_root_dir():
"""
Determines whether the current file is part of a git repository and if so, returns the repository root directory.
@ -348,6 +364,17 @@ def yaml_load(file='data.yaml', append_filename=False):
return {**yaml.safe_load(f), 'yaml_file': str(file)} if append_filename else yaml.safe_load(f)
def set_sentry(dsn=None):
"""
Initialize the Sentry SDK for error tracking and reporting if pytest is not currently running.
"""
if dsn and not is_pytest_running():
import sentry_sdk # noqa
import ultralytics
sentry_sdk.init(dsn=dsn, traces_sample_rate=1.0, release=ultralytics.__version__, debug=False)
def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'):
"""
Loads a global Ultralytics settings YAML file or creates one with default values if it does not exist.
@ -364,8 +391,9 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'):
is_git = is_git_directory() # True if ultralytics installed via git
root = get_git_root_dir() if is_git else Path()
datasets_root = (root.parent if (is_git and is_dir_writeable(root.parent)) else root).resolve()
defaults = {
'datasets_dir': str((root.parent if is_git else root) / 'datasets'), # default datasets directory.
'datasets_dir': str(datasets_root / 'datasets'), # default datasets directory.
'weights_dir': str(root / 'weights'), # default weights directory.
'runs_dir': str(root / 'runs'), # default runs directory.
'sync': True, # sync analytics to help with YOLO development
@ -393,6 +421,26 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'):
return settings
def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'):
"""
Function that runs on a first-time ultralytics package installation to set up global settings and create necessary
directories.
"""
SETTINGS.update(kwargs)
yaml_save(file, SETTINGS)
def print_settings():
"""
Function that prints Ultralytics settings
"""
import json
s = f'\n{PREFIX}Settings:\n'
s += json.dumps(SETTINGS, indent=2)
s += f"\n\nUpdate settings at {USER_CONFIG_DIR / 'settings.yaml'}"
LOGGER.info(s)
# Run below code on utils init -----------------------------------------------------------------------------------------
# Set logger
@ -403,15 +451,7 @@ if platform.system() == 'Windows':
setattr(LOGGER, fn.__name__, lambda x: fn(emojis(x))) # emoji safe logging
# Check first-install steps
PREFIX = colorstr("Ultralytics: ")
SETTINGS = get_settings()
DATASETS_DIR = Path(SETTINGS['datasets_dir']) # global datasets directory
def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'):
"""
Function that runs on a first-time ultralytics package installation to set up global settings and create necessary
directories.
"""
SETTINGS.update(kwargs)
yaml_save(file, SETTINGS)
set_sentry()

View File

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
from ultralytics.yolo.configs import hydra_patch # noqa (patch hydra cli)
from ultralytics.yolo.configs import hydra_patch # noqa (patch hydra CLI)
from ultralytics.yolo.v8 import classify, detect, segment
__all__ = ["classify", "segment", "detect"]