ultralytics 8.0.136
refactor and simplify package (#3748)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
@ -1,421 +1,10 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
import re
|
||||
import shutil
|
||||
import importlib
|
||||
import sys
|
||||
from difflib import get_close_matches
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, USER_CONFIG_DIR,
|
||||
IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn,
|
||||
get_settings, yaml_load, yaml_print)
|
||||
from ultralytics.utils import LOGGER
|
||||
|
||||
# Define valid tasks and modes
|
||||
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
|
||||
TASKS = 'detect', 'segment', 'classify', 'pose'
|
||||
TASK2DATA = {'detect': 'coco8.yaml', 'segment': 'coco8-seg.yaml', 'classify': 'imagenet100', 'pose': 'coco8-pose.yaml'}
|
||||
TASK2MODEL = {
|
||||
'detect': 'yolov8n.pt',
|
||||
'segment': 'yolov8n-seg.pt',
|
||||
'classify': 'yolov8n-cls.pt',
|
||||
'pose': 'yolov8n-pose.pt'}
|
||||
TASK2METRIC = {
|
||||
'detect': 'metrics/mAP50-95(B)',
|
||||
'segment': 'metrics/mAP50-95(M)',
|
||||
'classify': 'metrics/accuracy_top1',
|
||||
'pose': 'metrics/mAP50-95(P)'}
|
||||
# Set modules in sys.modules under their old name
|
||||
sys.modules['ultralytics.yolo.cfg'] = importlib.import_module('ultralytics.cfg')
|
||||
|
||||
|
||||
CLI_HELP_MSG = \
|
||||
f"""
|
||||
Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax:
|
||||
|
||||
yolo TASK MODE ARGS
|
||||
|
||||
Where TASK (optional) is one of {TASKS}
|
||||
MODE (required) is one of {MODES}
|
||||
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
|
||||
See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg'
|
||||
|
||||
1. Train a detection model for 10 epochs with an initial learning_rate of 0.01
|
||||
yolo train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
||||
|
||||
2. Predict a YouTube video using a pretrained segmentation model at image size 320:
|
||||
yolo predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
|
||||
|
||||
3. Val a pretrained detection model at batch-size 1 and image size 640:
|
||||
yolo val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
||||
|
||||
4. Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
||||
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
|
||||
|
||||
5. Run special commands:
|
||||
yolo help
|
||||
yolo checks
|
||||
yolo version
|
||||
yolo settings
|
||||
yolo copy-cfg
|
||||
yolo cfg
|
||||
|
||||
Docs: https://docs.ultralytics.com
|
||||
Community: https://community.ultralytics.com
|
||||
GitHub: https://github.com/ultralytics/ultralytics
|
||||
"""
|
||||
|
||||
# Define keys for arg type checks
|
||||
CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'
|
||||
CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr',
|
||||
'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud',
|
||||
'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou', 'fraction') # fraction floats 0.0 - 1.0
|
||||
CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
|
||||
'line_width', 'workspace', 'nbs', 'save_period')
|
||||
CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'rect', 'cos_lr', 'overlap_mask', 'val',
|
||||
'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf', 'save_crop',
|
||||
'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras',
|
||||
'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'profile')
|
||||
|
||||
|
||||
def cfg2dict(cfg):
|
||||
"""
|
||||
Convert a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
|
||||
|
||||
Args:
|
||||
cfg (str | Path | SimpleNamespace): Configuration object to be converted to a dictionary.
|
||||
|
||||
Returns:
|
||||
cfg (dict): Configuration object in dictionary format.
|
||||
"""
|
||||
if isinstance(cfg, (str, Path)):
|
||||
cfg = yaml_load(cfg) # load dict
|
||||
elif isinstance(cfg, SimpleNamespace):
|
||||
cfg = vars(cfg) # convert to dict
|
||||
return cfg
|
||||
|
||||
|
||||
def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None):
|
||||
"""
|
||||
Load and merge configuration data from a file or dictionary.
|
||||
|
||||
Args:
|
||||
cfg (str | Path | Dict | SimpleNamespace): Configuration data.
|
||||
overrides (str | Dict | optional): Overrides in the form of a file name or a dictionary. Default is None.
|
||||
|
||||
Returns:
|
||||
(SimpleNamespace): Training arguments namespace.
|
||||
"""
|
||||
cfg = cfg2dict(cfg)
|
||||
|
||||
# Merge overrides
|
||||
if overrides:
|
||||
overrides = cfg2dict(overrides)
|
||||
check_cfg_mismatch(cfg, overrides)
|
||||
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
|
||||
|
||||
# Special handling for numeric project/name
|
||||
for k in 'project', 'name':
|
||||
if k in cfg and isinstance(cfg[k], (int, float)):
|
||||
cfg[k] = str(cfg[k])
|
||||
if cfg.get('name') == 'model': # assign model to 'name' arg
|
||||
cfg['name'] = cfg.get('model', '').split('.')[0]
|
||||
LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")
|
||||
|
||||
# Type and Value checks
|
||||
for k, v in cfg.items():
|
||||
if v is not None: # None values may be from optional args
|
||||
if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
|
||||
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
|
||||
elif k in CFG_FRACTION_KEYS:
|
||||
if not isinstance(v, (int, float)):
|
||||
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
|
||||
if not (0.0 <= v <= 1.0):
|
||||
raise ValueError(f"'{k}={v}' is an invalid value. "
|
||||
f"Valid '{k}' values are between 0.0 and 1.0.")
|
||||
elif k in CFG_INT_KEYS and not isinstance(v, int):
|
||||
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"'{k}' must be an int (i.e. '{k}=8')")
|
||||
elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
|
||||
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")
|
||||
|
||||
# Return instance
|
||||
return IterableSimpleNamespace(**cfg)
|
||||
|
||||
|
||||
def _handle_deprecation(custom):
|
||||
"""
|
||||
Hardcoded function to handle deprecated config keys
|
||||
"""
|
||||
|
||||
for key in custom.copy().keys():
|
||||
if key == 'hide_labels':
|
||||
deprecation_warn(key, 'show_labels')
|
||||
custom['show_labels'] = custom.pop('hide_labels') == 'False'
|
||||
if key == 'hide_conf':
|
||||
deprecation_warn(key, 'show_conf')
|
||||
custom['show_conf'] = custom.pop('hide_conf') == 'False'
|
||||
if key == 'line_thickness':
|
||||
deprecation_warn(key, 'line_width')
|
||||
custom['line_width'] = custom.pop('line_thickness')
|
||||
|
||||
return custom
|
||||
|
||||
|
||||
def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
|
||||
"""
|
||||
This function checks for any mismatched keys between a custom configuration list and a base configuration list.
|
||||
If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
|
||||
|
||||
Args:
|
||||
custom (dict): a dictionary of custom configuration options
|
||||
base (dict): a dictionary of base configuration options
|
||||
"""
|
||||
custom = _handle_deprecation(custom)
|
||||
base, custom = (set(x.keys()) for x in (base, custom))
|
||||
mismatched = [x for x in custom if x not in base]
|
||||
if mismatched:
|
||||
string = ''
|
||||
for x in mismatched:
|
||||
matches = get_close_matches(x, base) # key list
|
||||
matches = [f'{k}={DEFAULT_CFG_DICT[k]}' if DEFAULT_CFG_DICT.get(k) is not None else k for k in matches]
|
||||
match_str = f'Similar arguments are i.e. {matches}.' if matches else ''
|
||||
string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
|
||||
raise SyntaxError(string + CLI_HELP_MSG) from e
|
||||
|
||||
|
||||
def merge_equals_args(args: List[str]) -> List[str]:
|
||||
"""
|
||||
Merges arguments around isolated '=' args in a list of strings.
|
||||
The function considers cases where the first argument ends with '=' or the second starts with '=',
|
||||
as well as when the middle one is an equals sign.
|
||||
|
||||
Args:
|
||||
args (List[str]): A list of strings where each element is an argument.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of strings where the arguments around isolated '=' are merged.
|
||||
"""
|
||||
new_args = []
|
||||
for i, arg in enumerate(args):
|
||||
if arg == '=' and 0 < i < len(args) - 1: # merge ['arg', '=', 'val']
|
||||
new_args[-1] += f'={args[i + 1]}'
|
||||
del args[i + 1]
|
||||
elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]: # merge ['arg=', 'val']
|
||||
new_args.append(f'{arg}{args[i + 1]}')
|
||||
del args[i + 1]
|
||||
elif arg.startswith('=') and i > 0: # merge ['arg', '=val']
|
||||
new_args[-1] += arg
|
||||
else:
|
||||
new_args.append(arg)
|
||||
return new_args
|
||||
|
||||
|
||||
def handle_yolo_hub(args: List[str]) -> None:
|
||||
"""
|
||||
Handle Ultralytics HUB command-line interface (CLI) commands.
|
||||
|
||||
This function processes Ultralytics HUB CLI commands such as login and logout.
|
||||
It should be called when executing a script with arguments related to HUB authentication.
|
||||
|
||||
Args:
|
||||
args (List[str]): A list of command line arguments
|
||||
|
||||
Example:
|
||||
python my_script.py hub login your_api_key
|
||||
"""
|
||||
from ultralytics import hub
|
||||
|
||||
if args[0] == 'login':
|
||||
key = args[1] if len(args) > 1 else ''
|
||||
# Log in to Ultralytics HUB using the provided API key
|
||||
hub.login(key)
|
||||
elif args[0] == 'logout':
|
||||
# Log out from Ultralytics HUB
|
||||
hub.logout()
|
||||
|
||||
|
||||
def handle_yolo_settings(args: List[str]) -> None:
|
||||
"""
|
||||
Handle YOLO settings command-line interface (CLI) commands.
|
||||
|
||||
This function processes YOLO settings CLI commands such as reset.
|
||||
It should be called when executing a script with arguments related to YOLO settings management.
|
||||
|
||||
Args:
|
||||
args (List[str]): A list of command line arguments for YOLO settings management.
|
||||
|
||||
Example:
|
||||
python my_script.py yolo settings reset
|
||||
"""
|
||||
path = USER_CONFIG_DIR / 'settings.yaml' # get SETTINGS YAML file path
|
||||
if any(args) and args[0] == 'reset':
|
||||
path.unlink() # delete the settings file
|
||||
get_settings() # create new settings
|
||||
LOGGER.info('Settings reset successfully') # inform the user that settings have been reset
|
||||
yaml_print(path) # print the current settings
|
||||
|
||||
|
||||
def entrypoint(debug=''):
|
||||
"""
|
||||
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
|
||||
to the package.
|
||||
|
||||
This function allows for:
|
||||
- passing mandatory YOLO args as a list of strings
|
||||
- specifying the task to be performed, either 'detect', 'segment' or 'classify'
|
||||
- specifying the mode, either 'train', 'val', 'test', or 'predict'
|
||||
- running special modes like 'checks'
|
||||
- passing overrides to the package's configuration
|
||||
|
||||
It uses the package's default cfg and initializes it using the passed overrides.
|
||||
Then it calls the CLI function with the composed cfg
|
||||
"""
|
||||
args = (debug.split(' ') if debug else sys.argv)[1:]
|
||||
if not args: # no arguments passed
|
||||
LOGGER.info(CLI_HELP_MSG)
|
||||
return
|
||||
|
||||
special = {
|
||||
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
||||
'checks': checks.check_yolo,
|
||||
'version': lambda: LOGGER.info(__version__),
|
||||
'settings': lambda: handle_yolo_settings(args[1:]),
|
||||
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
||||
'hub': lambda: handle_yolo_hub(args[1:]),
|
||||
'login': lambda: handle_yolo_hub(args),
|
||||
'copy-cfg': copy_default_cfg}
|
||||
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
|
||||
|
||||
# Define common mis-uses of special commands, i.e. -h, -help, --help
|
||||
special.update({k[0]: v for k, v in special.items()}) # singular
|
||||
special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith('s')}) # singular
|
||||
special = {**special, **{f'-{k}': v for k, v in special.items()}, **{f'--{k}': v for k, v in special.items()}}
|
||||
|
||||
overrides = {} # basic overrides, i.e. imgsz=320
|
||||
for a in merge_equals_args(args): # merge spaces around '=' sign
|
||||
if a.startswith('--'):
|
||||
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
|
||||
a = a[2:]
|
||||
if a.endswith(','):
|
||||
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
|
||||
a = a[:-1]
|
||||
if '=' in a:
|
||||
try:
|
||||
re.sub(r' *= *', '=', a) # remove spaces around equals sign
|
||||
k, v = a.split('=', 1) # split on first '=' sign
|
||||
assert v, f"missing '{k}' value"
|
||||
if k == 'cfg': # custom.yaml passed
|
||||
LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
|
||||
overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'}
|
||||
else:
|
||||
if v.lower() == 'none':
|
||||
v = None
|
||||
elif v.lower() == 'true':
|
||||
v = True
|
||||
elif v.lower() == 'false':
|
||||
v = False
|
||||
else:
|
||||
with contextlib.suppress(Exception):
|
||||
v = eval(v)
|
||||
overrides[k] = v
|
||||
except (NameError, SyntaxError, ValueError, AssertionError) as e:
|
||||
check_cfg_mismatch(full_args_dict, {a: ''}, e)
|
||||
|
||||
elif a in TASKS:
|
||||
overrides['task'] = a
|
||||
elif a in MODES:
|
||||
overrides['mode'] = a
|
||||
elif a.lower() in special:
|
||||
special[a.lower()]()
|
||||
return
|
||||
elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):
|
||||
overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True
|
||||
elif a in DEFAULT_CFG_DICT:
|
||||
raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
|
||||
f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
|
||||
else:
|
||||
check_cfg_mismatch(full_args_dict, {a: ''})
|
||||
|
||||
# Check keys
|
||||
check_cfg_mismatch(full_args_dict, overrides)
|
||||
|
||||
# Mode
|
||||
mode = overrides.get('mode', None)
|
||||
if mode is None:
|
||||
mode = DEFAULT_CFG.mode or 'predict'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
|
||||
elif mode not in MODES:
|
||||
if mode not in ('checks', checks):
|
||||
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
|
||||
LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
|
||||
checks.check_yolo()
|
||||
return
|
||||
|
||||
# Task
|
||||
task = overrides.pop('task', None)
|
||||
if task:
|
||||
if task not in TASKS:
|
||||
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
|
||||
if 'model' not in overrides:
|
||||
overrides['model'] = TASK2MODEL[task]
|
||||
|
||||
# Model
|
||||
model = overrides.pop('model', DEFAULT_CFG.model)
|
||||
if model is None:
|
||||
model = 'yolov8n.pt'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
|
||||
overrides['model'] = model
|
||||
if 'rtdetr' in model.lower(): # guess architecture
|
||||
from ultralytics import RTDETR
|
||||
model = RTDETR(model) # no task argument
|
||||
elif 'sam' in model.lower():
|
||||
from ultralytics import SAM
|
||||
model = SAM(model)
|
||||
else:
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(model, task=task)
|
||||
if isinstance(overrides.get('pretrained'), str):
|
||||
model.load(overrides['pretrained'])
|
||||
|
||||
# Task Update
|
||||
if task != model.task:
|
||||
if task:
|
||||
LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
|
||||
f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")
|
||||
task = model.task
|
||||
|
||||
# Mode
|
||||
if mode in ('predict', 'track') and 'source' not in overrides:
|
||||
overrides['source'] = DEFAULT_CFG.source or ROOT / 'assets' if (ROOT / 'assets').exists() \
|
||||
else 'https://ultralytics.com/images/bus.jpg'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
|
||||
elif mode in ('train', 'val'):
|
||||
if 'data' not in overrides:
|
||||
overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
|
||||
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
|
||||
elif mode == 'export':
|
||||
if 'format' not in overrides:
|
||||
overrides['format'] = DEFAULT_CFG.format or 'torchscript'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
|
||||
|
||||
# Run command in python
|
||||
# getattr(model, mode)(**vars(get_cfg(overrides=overrides))) # default args using default.yaml
|
||||
getattr(model, mode)(**overrides) # default args from model
|
||||
|
||||
|
||||
# Special modes --------------------------------------------------------------------------------------------------------
|
||||
def copy_default_cfg():
|
||||
"""Copy and create a new default configuration file with '_copy' appended to its name."""
|
||||
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
|
||||
shutil.copy2(DEFAULT_CFG_PATH, new_file)
|
||||
LOGGER.info(f'{DEFAULT_CFG_PATH} copied to {new_file}\n'
|
||||
f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Example Usage: entrypoint(debug='yolo predict model=yolov8n.pt')
|
||||
entrypoint(debug='')
|
||||
LOGGER.warning("WARNING ⚠️ 'ultralytics.yolo.cfg' is deprecated since '8.0.136' and will be removed in '8.1.0'. "
|
||||
"Please use 'ultralytics.cfg' instead.")
|
||||
|
@ -1,114 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
# Default training settings and hyperparameters for medium-augmentation COCO training
|
||||
|
||||
task: detect # (str) YOLO task, i.e. detect, segment, classify, pose
|
||||
mode: train # (str) YOLO mode, i.e. train, val, predict, export, track, benchmark
|
||||
|
||||
# Train settings -------------------------------------------------------------------------------------------------------
|
||||
model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml
|
||||
data: # (str, optional) path to data file, i.e. coco128.yaml
|
||||
epochs: 100 # (int) number of epochs to train for
|
||||
patience: 50 # (int) epochs to wait for no observable improvement for early stopping of training
|
||||
batch: 16 # (int) number of images per batch (-1 for AutoBatch)
|
||||
imgsz: 640 # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes
|
||||
save: True # (bool) save train checkpoints and predict results
|
||||
save_period: -1 # (int) Save checkpoint every x epochs (disabled if < 1)
|
||||
cache: False # (bool) True/ram, disk or False. Use cache for data loading
|
||||
device: # (int | str | list, optional) device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
|
||||
workers: 8 # (int) number of worker threads for data loading (per RANK if DDP)
|
||||
project: # (str, optional) project name
|
||||
name: # (str, optional) experiment name, results saved to 'project/name' directory
|
||||
exist_ok: False # (bool) whether to overwrite existing experiment
|
||||
pretrained: True # (bool | str) whether to use a pretrained model (bool) or a model to load weights from (str)
|
||||
optimizer: auto # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
|
||||
verbose: True # (bool) whether to print verbose output
|
||||
seed: 0 # (int) random seed for reproducibility
|
||||
deterministic: True # (bool) whether to enable deterministic mode
|
||||
single_cls: False # (bool) train multi-class data as single-class
|
||||
rect: False # (bool) rectangular training if mode='train' or rectangular validation if mode='val'
|
||||
cos_lr: False # (bool) use cosine learning rate scheduler
|
||||
close_mosaic: 10 # (int) disable mosaic augmentation for final epochs
|
||||
resume: False # (bool) resume training from last checkpoint
|
||||
amp: True # (bool) Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check
|
||||
fraction: 1.0 # (float) dataset fraction to train on (default is 1.0, all images in train set)
|
||||
profile: False # (bool) profile ONNX and TensorRT speeds during training for loggers
|
||||
# Segmentation
|
||||
overlap_mask: True # (bool) masks should overlap during training (segment train only)
|
||||
mask_ratio: 4 # (int) mask downsample ratio (segment train only)
|
||||
# Classification
|
||||
dropout: 0.0 # (float) use dropout regularization (classify train only)
|
||||
|
||||
# Val/Test settings ----------------------------------------------------------------------------------------------------
|
||||
val: True # (bool) validate/test during training
|
||||
split: val # (str) dataset split to use for validation, i.e. 'val', 'test' or 'train'
|
||||
save_json: False # (bool) save results to JSON file
|
||||
save_hybrid: False # (bool) save hybrid version of labels (labels + additional predictions)
|
||||
conf: # (float, optional) object confidence threshold for detection (default 0.25 predict, 0.001 val)
|
||||
iou: 0.7 # (float) intersection over union (IoU) threshold for NMS
|
||||
max_det: 300 # (int) maximum number of detections per image
|
||||
half: False # (bool) use half precision (FP16)
|
||||
dnn: False # (bool) use OpenCV DNN for ONNX inference
|
||||
plots: True # (bool) save plots during train/val
|
||||
|
||||
# Prediction settings --------------------------------------------------------------------------------------------------
|
||||
source: # (str, optional) source directory for images or videos
|
||||
show: False # (bool) show results if possible
|
||||
save_txt: False # (bool) save results as .txt file
|
||||
save_conf: False # (bool) save results with confidence scores
|
||||
save_crop: False # (bool) save cropped images with results
|
||||
show_labels: True # (bool) show object labels in plots
|
||||
show_conf: True # (bool) show object confidence scores in plots
|
||||
vid_stride: 1 # (int) video frame-rate stride
|
||||
line_width: # (int, optional) line width of the bounding boxes, auto if missing
|
||||
visualize: False # (bool) visualize model features
|
||||
augment: False # (bool) apply image augmentation to prediction sources
|
||||
agnostic_nms: False # (bool) class-agnostic NMS
|
||||
classes: # (int | list[int], optional) filter results by class, i.e. class=0, or class=[0,2,3]
|
||||
retina_masks: False # (bool) use high-resolution segmentation masks
|
||||
boxes: True # (bool) Show boxes in segmentation predictions
|
||||
|
||||
# Export settings ------------------------------------------------------------------------------------------------------
|
||||
format: torchscript # (str) format to export to, choices at https://docs.ultralytics.com/modes/export/#export-formats
|
||||
keras: False # (bool) use Kera=s
|
||||
optimize: False # (bool) TorchScript: optimize for mobile
|
||||
int8: False # (bool) CoreML/TF INT8 quantization
|
||||
dynamic: False # (bool) ONNX/TF/TensorRT: dynamic axes
|
||||
simplify: False # (bool) ONNX: simplify model
|
||||
opset: # (int, optional) ONNX: opset version
|
||||
workspace: 4 # (int) TensorRT: workspace size (GB)
|
||||
nms: False # (bool) CoreML: add NMS
|
||||
|
||||
# Hyperparameters ------------------------------------------------------------------------------------------------------
|
||||
lr0: 0.01 # (float) initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
|
||||
lrf: 0.01 # (float) final learning rate (lr0 * lrf)
|
||||
momentum: 0.937 # (float) SGD momentum/Adam beta1
|
||||
weight_decay: 0.0005 # (float) optimizer weight decay 5e-4
|
||||
warmup_epochs: 3.0 # (float) warmup epochs (fractions ok)
|
||||
warmup_momentum: 0.8 # (float) warmup initial momentum
|
||||
warmup_bias_lr: 0.1 # (float) warmup initial bias lr
|
||||
box: 7.5 # (float) box loss gain
|
||||
cls: 0.5 # (float) cls loss gain (scale with pixels)
|
||||
dfl: 1.5 # (float) dfl loss gain
|
||||
pose: 12.0 # (float) pose loss gain
|
||||
kobj: 1.0 # (float) keypoint obj loss gain
|
||||
label_smoothing: 0.0 # (float) label smoothing (fraction)
|
||||
nbs: 64 # (int) nominal batch size
|
||||
hsv_h: 0.015 # (float) image HSV-Hue augmentation (fraction)
|
||||
hsv_s: 0.7 # (float) image HSV-Saturation augmentation (fraction)
|
||||
hsv_v: 0.4 # (float) image HSV-Value augmentation (fraction)
|
||||
degrees: 0.0 # (float) image rotation (+/- deg)
|
||||
translate: 0.1 # (float) image translation (+/- fraction)
|
||||
scale: 0.5 # (float) image scale (+/- gain)
|
||||
shear: 0.0 # (float) image shear (+/- deg)
|
||||
perspective: 0.0 # (float) image perspective (+/- fraction), range 0-0.001
|
||||
flipud: 0.0 # (float) image flip up-down (probability)
|
||||
fliplr: 0.5 # (float) image flip left-right (probability)
|
||||
mosaic: 1.0 # (float) image mosaic (probability)
|
||||
mixup: 0.0 # (float) image mixup (probability)
|
||||
copy_paste: 0.0 # (float) segment copy-paste (probability)
|
||||
|
||||
# Custom config.yaml ---------------------------------------------------------------------------------------------------
|
||||
cfg: # (str, optional) for overriding defaults.yaml
|
||||
|
||||
# Tracker settings ------------------------------------------------------------------------------------------------------
|
||||
tracker: botsort.yaml # (str) tracker type, choices=[botsort.yaml, bytetrack.yaml]
|
Reference in New Issue
Block a user