CLI Simplification (#449)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -1,18 +1,17 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
from hydra import compose, initialize
|
||||
|
||||
from ultralytics import hub, yolo
|
||||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, HELP_MSG, LOGGER, PREFIX, print_settings, yaml_load
|
||||
|
||||
DIR = Path(__file__).parent
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent.relative_to(DIR)), config_name=DEFAULT_CONFIG.name)
|
||||
def cli(cfg):
|
||||
"""
|
||||
Run a specified task and mode with the given configuration.
|
||||
@ -21,21 +20,13 @@ def cli(cfg):
|
||||
cfg (DictConfig): Configuration for the task and mode.
|
||||
"""
|
||||
# LOGGER.info(f"{colorstr(f'Ultralytics YOLO v{ultralytics.__version__}')}")
|
||||
from ultralytics.yolo.configs import get_config
|
||||
|
||||
if cfg.cfg:
|
||||
LOGGER.info(f"Overriding default config with {cfg.cfg}")
|
||||
LOGGER.info(f"{PREFIX}Overriding default config with {cfg.cfg}")
|
||||
cfg = get_config(cfg.cfg)
|
||||
task, mode = cfg.task.lower(), cfg.mode.lower()
|
||||
|
||||
# Special case for initializing the configuration
|
||||
if task == "init":
|
||||
shutil.copy2(DEFAULT_CONFIG, Path.cwd())
|
||||
LOGGER.info(f"""
|
||||
{colorstr("YOLO:")} configuration saved to {Path.cwd() / DEFAULT_CONFIG.name}.
|
||||
To run experiments using custom configuration:
|
||||
yolo cfg=config_file.yaml
|
||||
""")
|
||||
return
|
||||
|
||||
# Mapping from task to module
|
||||
task_module_map = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify}
|
||||
module = task_module_map.get(task)
|
||||
@ -47,10 +38,66 @@ def cli(cfg):
|
||||
"train": module.train,
|
||||
"val": module.val,
|
||||
"predict": module.predict,
|
||||
"export": yolo.engine.exporter.export,
|
||||
"checks": hub.checks}
|
||||
"export": yolo.engine.exporter.export}
|
||||
func = mode_func_map.get(mode)
|
||||
if not func:
|
||||
raise SyntaxError(f"mode not recognized. Choices are {', '.join(mode_func_map.keys())}")
|
||||
|
||||
func(cfg)
|
||||
|
||||
|
||||
def entrypoint():
|
||||
"""
|
||||
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
|
||||
to the package. It's a combination of argparse and hydra.
|
||||
|
||||
This function allows for:
|
||||
- passing mandatory YOLO args as a list of strings
|
||||
- specifying the task to be performed, either 'detect', 'segment' or 'classify'
|
||||
- specifying the mode, either 'train', 'val', 'test', or 'predict'
|
||||
- running special modes like 'checks'
|
||||
- passing overrides to the package's configuration
|
||||
|
||||
It uses the package's default config and initializes it using the passed overrides.
|
||||
Then it calls the CLI function with the composed config
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='YOLO parser')
|
||||
parser.add_argument('args', type=str, nargs='+', help='YOLO args')
|
||||
args = parser.parse_args().args
|
||||
|
||||
tasks = 'detect', 'segment', 'classify'
|
||||
modes = 'train', 'val', 'predict', 'export'
|
||||
special_modes = {
|
||||
'checks': hub.checks,
|
||||
'help': lambda: LOGGER.info(HELP_MSG),
|
||||
'settings': print_settings,
|
||||
'copy-config': copy_default_config}
|
||||
|
||||
overrides = [] # basic overrides, i.e. imgsz=320
|
||||
defaults = yaml_load(DEFAULT_CONFIG)
|
||||
for a in args:
|
||||
if '=' in a:
|
||||
overrides.append(a)
|
||||
elif a in tasks:
|
||||
overrides.append(f'task={a}')
|
||||
elif a in modes:
|
||||
overrides.append(f'mode={a}')
|
||||
elif a in special_modes:
|
||||
special_modes[a]()
|
||||
return
|
||||
elif a in defaults and defaults[a] is False:
|
||||
overrides.append(f'{a}=True') # auto-True for default False args, i.e. yolo show
|
||||
else:
|
||||
raise (SyntaxError(f"'{a}' is not a valid yolo argument\n{HELP_MSG}"))
|
||||
|
||||
with initialize(version_base=None, config_path=str(DEFAULT_CONFIG.parent.relative_to(DIR)), job_name="YOLO"):
|
||||
cfg = compose(config_name=DEFAULT_CONFIG.name, overrides=overrides)
|
||||
cli(cfg)
|
||||
|
||||
|
||||
# Special modes --------------------------------------------------------------------------------------------------------
|
||||
def copy_default_config():
|
||||
new_file = Path.cwd() / DEFAULT_CONFIG.name.replace('.yaml', '_copy.yaml')
|
||||
shutil.copy2(DEFAULT_CONFIG, new_file)
|
||||
LOGGER.info(f"{PREFIX}{DEFAULT_CONFIG} copied to {new_file}\n"
|
||||
f"Usage for running YOLO with this new custom config:\nyolo cfg={new_file} args...")
|
||||
|
@ -160,7 +160,7 @@ class BasePredictor:
|
||||
return list(chain(*list(self.stream_inference(source, model, verbose)))) # merge list of Result into one
|
||||
|
||||
def predict_cli(self):
|
||||
# Method used for cli prediction. It uses always generator as outputs as not required by cli mode
|
||||
# Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode
|
||||
gen = self.stream_inference(verbose=True)
|
||||
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
|
||||
pass
|
||||
|
@ -10,6 +10,7 @@ import tempfile
|
||||
import threading
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import cv2
|
||||
import git
|
||||
@ -41,12 +42,15 @@ HELP_MSG = \
|
||||
|
||||
from ultralytics import YOLO
|
||||
|
||||
model = YOLO('yolov8n.yaml') # build a new model from scratch
|
||||
model = YOLO('yolov8n.pt') # load a pretrained model (recommended for best training results)
|
||||
results = model.train(data='coco128.yaml') # train the model
|
||||
results = model.val() # evaluate model performance on the validation set
|
||||
results = model.predict(source='bus.jpg') # predict on an image
|
||||
success = model.export(format='onnx') # export the model to ONNX format
|
||||
# Load a model
|
||||
model = YOLO("yolov8n.yaml") # build a new model from scratch
|
||||
model = YOLO("yolov8n.pt") # load a pretrained model (recommended for training)
|
||||
|
||||
# Use the model
|
||||
results = model.train(data="coco128.yaml", epochs=3) # train the model
|
||||
results = model.val() # evaluate model performance on the validation set
|
||||
results = model("https://ultralytics.com/images/bus.jpg") # predict on an image
|
||||
success = model.export(format="onnx") # export the model to ONNX format
|
||||
|
||||
3. Use the command line interface (CLI):
|
||||
|
||||
@ -161,12 +165,12 @@ def is_pip_package(filepath: str = __name__) -> bool:
|
||||
return spec is not None and spec.origin is not None
|
||||
|
||||
|
||||
def is_dir_writeable(dir_path: str) -> bool:
|
||||
def is_dir_writeable(dir_path: Union[str, Path]) -> bool:
|
||||
"""
|
||||
Check if a directory is writeable.
|
||||
|
||||
Args:
|
||||
dir_path (str): The path to the directory.
|
||||
dir_path (str) or (Path): The path to the directory.
|
||||
|
||||
Returns:
|
||||
bool: True if the directory is writeable, False otherwise.
|
||||
@ -179,6 +183,18 @@ def is_dir_writeable(dir_path: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def is_pytest_running():
|
||||
"""
|
||||
Returns a boolean indicating if pytest is currently running or not
|
||||
:return: True if pytest is running, False otherwise
|
||||
"""
|
||||
try:
|
||||
import sys
|
||||
return "pytest" in sys.modules
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def get_git_root_dir():
|
||||
"""
|
||||
Determines whether the current file is part of a git repository and if so, returns the repository root directory.
|
||||
@ -348,6 +364,17 @@ def yaml_load(file='data.yaml', append_filename=False):
|
||||
return {**yaml.safe_load(f), 'yaml_file': str(file)} if append_filename else yaml.safe_load(f)
|
||||
|
||||
|
||||
def set_sentry(dsn=None):
|
||||
"""
|
||||
Initialize the Sentry SDK for error tracking and reporting if pytest is not currently running.
|
||||
"""
|
||||
if dsn and not is_pytest_running():
|
||||
import sentry_sdk # noqa
|
||||
|
||||
import ultralytics
|
||||
sentry_sdk.init(dsn=dsn, traces_sample_rate=1.0, release=ultralytics.__version__, debug=False)
|
||||
|
||||
|
||||
def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'):
|
||||
"""
|
||||
Loads a global Ultralytics settings YAML file or creates one with default values if it does not exist.
|
||||
@ -364,8 +391,9 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'):
|
||||
|
||||
is_git = is_git_directory() # True if ultralytics installed via git
|
||||
root = get_git_root_dir() if is_git else Path()
|
||||
datasets_root = (root.parent if (is_git and is_dir_writeable(root.parent)) else root).resolve()
|
||||
defaults = {
|
||||
'datasets_dir': str((root.parent if is_git else root) / 'datasets'), # default datasets directory.
|
||||
'datasets_dir': str(datasets_root / 'datasets'), # default datasets directory.
|
||||
'weights_dir': str(root / 'weights'), # default weights directory.
|
||||
'runs_dir': str(root / 'runs'), # default runs directory.
|
||||
'sync': True, # sync analytics to help with YOLO development
|
||||
@ -393,6 +421,26 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'):
|
||||
return settings
|
||||
|
||||
|
||||
def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'):
|
||||
"""
|
||||
Function that runs on a first-time ultralytics package installation to set up global settings and create necessary
|
||||
directories.
|
||||
"""
|
||||
SETTINGS.update(kwargs)
|
||||
yaml_save(file, SETTINGS)
|
||||
|
||||
|
||||
def print_settings():
|
||||
"""
|
||||
Function that prints Ultralytics settings
|
||||
"""
|
||||
import json
|
||||
s = f'\n{PREFIX}Settings:\n'
|
||||
s += json.dumps(SETTINGS, indent=2)
|
||||
s += f"\n\nUpdate settings at {USER_CONFIG_DIR / 'settings.yaml'}"
|
||||
LOGGER.info(s)
|
||||
|
||||
|
||||
# Run below code on utils init -----------------------------------------------------------------------------------------
|
||||
|
||||
# Set logger
|
||||
@ -403,15 +451,7 @@ if platform.system() == 'Windows':
|
||||
setattr(LOGGER, fn.__name__, lambda x: fn(emojis(x))) # emoji safe logging
|
||||
|
||||
# Check first-install steps
|
||||
PREFIX = colorstr("Ultralytics: ")
|
||||
SETTINGS = get_settings()
|
||||
DATASETS_DIR = Path(SETTINGS['datasets_dir']) # global datasets directory
|
||||
|
||||
|
||||
def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'):
|
||||
"""
|
||||
Function that runs on a first-time ultralytics package installation to set up global settings and create necessary
|
||||
directories.
|
||||
"""
|
||||
SETTINGS.update(kwargs)
|
||||
|
||||
yaml_save(file, SETTINGS)
|
||||
set_sentry()
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
from ultralytics.yolo.configs import hydra_patch # noqa (patch hydra cli)
|
||||
from ultralytics.yolo.configs import hydra_patch # noqa (patch hydra CLI)
|
||||
from ultralytics.yolo.v8 import classify, detect, segment
|
||||
|
||||
__all__ = ["classify", "segment", "detect"]
|
||||
|
Reference in New Issue
Block a user