|
|
@ -8,9 +8,9 @@ from pathlib import Path
|
|
|
|
from types import SimpleNamespace
|
|
|
|
from types import SimpleNamespace
|
|
|
|
from typing import Dict, List, Union
|
|
|
|
from typing import Dict, List, Union
|
|
|
|
|
|
|
|
|
|
|
|
from ultralytics import __version__
|
|
|
|
|
|
|
|
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, PREFIX, ROOT,
|
|
|
|
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, PREFIX, ROOT,
|
|
|
|
USER_CONFIG_DIR, IterableSimpleNamespace, colorstr, emojis, yaml_load, yaml_print)
|
|
|
|
USER_CONFIG_DIR, IterableSimpleNamespace, __version__, colorstr, emojis, yaml_load,
|
|
|
|
|
|
|
|
yaml_print)
|
|
|
|
from ultralytics.yolo.utils.checks import check_yolo
|
|
|
|
from ultralytics.yolo.utils.checks import check_yolo
|
|
|
|
|
|
|
|
|
|
|
|
CLI_HELP_MSG = \
|
|
|
|
CLI_HELP_MSG = \
|
|
|
@ -25,13 +25,13 @@ CLI_HELP_MSG = \
|
|
|
|
See all ARGS at https://docs.ultralytics.com/cfg or with 'yolo cfg'
|
|
|
|
See all ARGS at https://docs.ultralytics.com/cfg or with 'yolo cfg'
|
|
|
|
|
|
|
|
|
|
|
|
1. Train a detection model for 10 epochs with an initial learning_rate of 0.01
|
|
|
|
1. Train a detection model for 10 epochs with an initial learning_rate of 0.01
|
|
|
|
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
|
|
|
yolo train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
|
|
|
|
|
|
|
|
|
|
|
2. Predict a YouTube video using a pretrained segmentation model at image size 320:
|
|
|
|
2. Predict a YouTube video using a pretrained segmentation model at image size 320:
|
|
|
|
yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
|
|
|
|
yolo predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
|
|
|
|
|
|
|
|
|
|
|
|
3. Val a pretrained detection model at batch-size 1 and image size 640:
|
|
|
|
3. Val a pretrained detection model at batch-size 1 and image size 640:
|
|
|
|
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
|
|
|
yolo val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
|
|
|
|
|
|
|
|
|
|
|
4. Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
|
|
|
4. Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
|
|
|
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
|
|
|
|
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
|
|
|
@ -56,7 +56,7 @@ CFG_FRACTION_KEYS = {
|
|
|
|
'mixup', 'copy_paste', 'conf', 'iou'}
|
|
|
|
'mixup', 'copy_paste', 'conf', 'iou'}
|
|
|
|
CFG_INT_KEYS = {
|
|
|
|
CFG_INT_KEYS = {
|
|
|
|
'epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
|
|
|
|
'epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
|
|
|
|
'line_thickness', 'workspace', 'nbs'}
|
|
|
|
'line_thickness', 'workspace', 'nbs', 'save_period'}
|
|
|
|
CFG_BOOL_KEYS = {
|
|
|
|
CFG_BOOL_KEYS = {
|
|
|
|
'save', 'exist_ok', 'pretrained', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect', 'cos_lr',
|
|
|
|
'save', 'exist_ok', 'pretrained', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect', 'cos_lr',
|
|
|
|
'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf',
|
|
|
|
'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf',
|
|
|
@ -131,7 +131,7 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG, override
|
|
|
|
return IterableSimpleNamespace(**cfg)
|
|
|
|
return IterableSimpleNamespace(**cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_cfg_mismatch(base: Dict, custom: Dict):
|
|
|
|
def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
This function checks for any mismatched keys between a custom configuration list and a base configuration list.
|
|
|
|
This function checks for any mismatched keys between a custom configuration list and a base configuration list.
|
|
|
|
If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
|
|
|
|
If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
|
|
|
@ -143,12 +143,12 @@ def check_cfg_mismatch(base: Dict, custom: Dict):
|
|
|
|
base, custom = (set(x.keys()) for x in (base, custom))
|
|
|
|
base, custom = (set(x.keys()) for x in (base, custom))
|
|
|
|
mismatched = [x for x in custom if x not in base]
|
|
|
|
mismatched = [x for x in custom if x not in base]
|
|
|
|
if mismatched:
|
|
|
|
if mismatched:
|
|
|
|
|
|
|
|
string = ''
|
|
|
|
for x in mismatched:
|
|
|
|
for x in mismatched:
|
|
|
|
matches = get_close_matches(x, base, 3, 0.6)
|
|
|
|
matches = get_close_matches(x, base)
|
|
|
|
match_str = f"Similar arguments are {matches}." if matches else 'There are no similar arguments.'
|
|
|
|
match_str = f"Similar arguments are {matches}." if matches else ''
|
|
|
|
LOGGER.warning(f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}")
|
|
|
|
string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
|
|
|
|
LOGGER.warning(CLI_HELP_MSG)
|
|
|
|
raise SyntaxError(string + CLI_HELP_MSG) from e
|
|
|
|
sys.exit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def merge_equals_args(args: List[str]) -> List[str]:
|
|
|
|
def merge_equals_args(args: List[str]) -> List[str]:
|
|
|
@ -178,10 +178,6 @@ def merge_equals_args(args: List[str]) -> List[str]:
|
|
|
|
return new_args
|
|
|
|
return new_args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def argument_error(arg):
|
|
|
|
|
|
|
|
return SyntaxError(f"'{arg}' is not a valid YOLO argument.\n{CLI_HELP_MSG}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def entrypoint(debug=''):
|
|
|
|
def entrypoint(debug=''):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
|
|
|
|
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
|
|
|
@ -212,6 +208,7 @@ def entrypoint(debug=''):
|
|
|
|
'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'),
|
|
|
|
'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'),
|
|
|
|
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
|
|
|
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
|
|
|
'copy-cfg': copy_default_cfg}
|
|
|
|
'copy-cfg': copy_default_cfg}
|
|
|
|
|
|
|
|
FULL_ARGS_DICT = {**DEFAULT_CFG_DICT, **{k: None for k in tasks}, **{k: None for k in modes}, **special}
|
|
|
|
special = {**special, **{f'-{k}': v for k, v in special.items()}, **{f'--{k}': v for k, v in special.items()}}
|
|
|
|
special = {**special, **{f'-{k}': v for k, v in special.items()}, **{f'--{k}': v for k, v in special.items()}}
|
|
|
|
|
|
|
|
|
|
|
|
overrides = {} # basic overrides, i.e. imgsz=320
|
|
|
|
overrides = {} # basic overrides, i.e. imgsz=320
|
|
|
@ -236,7 +233,7 @@ def entrypoint(debug=''):
|
|
|
|
v = eval(v)
|
|
|
|
v = eval(v)
|
|
|
|
overrides[k] = v
|
|
|
|
overrides[k] = v
|
|
|
|
except (NameError, SyntaxError, ValueError, AssertionError) as e:
|
|
|
|
except (NameError, SyntaxError, ValueError, AssertionError) as e:
|
|
|
|
raise argument_error(a) from e
|
|
|
|
check_cfg_mismatch(FULL_ARGS_DICT, {a: ""}, e)
|
|
|
|
|
|
|
|
|
|
|
|
elif a in tasks:
|
|
|
|
elif a in tasks:
|
|
|
|
overrides['task'] = a
|
|
|
|
overrides['task'] = a
|
|
|
@ -251,7 +248,7 @@ def entrypoint(debug=''):
|
|
|
|
raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
|
|
|
|
raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
|
|
|
|
f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
|
|
|
|
f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
raise argument_error(a)
|
|
|
|
check_cfg_mismatch(FULL_ARGS_DICT, {a: ""})
|
|
|
|
|
|
|
|
|
|
|
|
# Defaults
|
|
|
|
# Defaults
|
|
|
|
task2model = dict(detect='yolov8n.pt', segment='yolov8n-seg.pt', classify='yolov8n-cls.pt')
|
|
|
|
task2model = dict(detect='yolov8n.pt', segment='yolov8n-seg.pt', classify='yolov8n-cls.pt')
|
|
|
|