|
|
@ -48,19 +48,21 @@ CLI_HELP_MSG = \
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
# Define keys for arg type checks
|
|
|
|
# Define keys for arg type checks
|
|
|
|
CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear', 'fl_gamma'}
|
|
|
|
CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear', 'fl_gamma'
|
|
|
|
CFG_FRACTION_KEYS = {
|
|
|
|
CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr',
|
|
|
|
'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', 'label_smoothing',
|
|
|
|
'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud',
|
|
|
|
'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', 'fliplr', 'mosaic', 'mixup', 'copy_paste',
|
|
|
|
'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou') # fractional floats limited to 0.0 - 1.0
|
|
|
|
'conf', 'iou'} # fractional floats limited to 0.0 - 1.0
|
|
|
|
CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
|
|
|
|
CFG_INT_KEYS = {
|
|
|
|
'line_thickness', 'workspace', 'nbs', 'save_period')
|
|
|
|
'epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
|
|
|
|
CFG_BOOL_KEYS = ('save', 'exist_ok', 'pretrained', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect',
|
|
|
|
'line_thickness', 'workspace', 'nbs', 'save_period'}
|
|
|
|
'cos_lr', 'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show',
|
|
|
|
CFG_BOOL_KEYS = {
|
|
|
|
'save_txt', 'save_conf', 'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment',
|
|
|
|
'save', 'exist_ok', 'pretrained', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect', 'cos_lr',
|
|
|
|
'agnostic_nms', 'retina_masks', 'boxes', 'keras', 'optimize', 'int8', 'dynamic', 'simplify', 'nms',
|
|
|
|
'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf',
|
|
|
|
'v5loader')
|
|
|
|
'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras',
|
|
|
|
|
|
|
|
'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader'}
|
|
|
|
# Define valid tasks and modes
|
|
|
|
|
|
|
|
TASKS = 'detect', 'segment', 'classify'
|
|
|
|
|
|
|
|
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def cfg2dict(cfg):
|
|
|
|
def cfg2dict(cfg):
|
|
|
@ -196,9 +198,6 @@ def entrypoint(debug=''):
|
|
|
|
LOGGER.info(CLI_HELP_MSG)
|
|
|
|
LOGGER.info(CLI_HELP_MSG)
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
# Define tasks and modes
|
|
|
|
|
|
|
|
tasks = 'detect', 'segment', 'classify'
|
|
|
|
|
|
|
|
modes = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
|
|
|
|
|
|
|
|
special = {
|
|
|
|
special = {
|
|
|
|
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
|
|
|
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
|
|
|
'checks': checks.check_yolo,
|
|
|
|
'checks': checks.check_yolo,
|
|
|
@ -206,7 +205,7 @@ def entrypoint(debug=''):
|
|
|
|
'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'),
|
|
|
|
'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'),
|
|
|
|
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
|
|
|
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
|
|
|
'copy-cfg': copy_default_cfg}
|
|
|
|
'copy-cfg': copy_default_cfg}
|
|
|
|
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in tasks}, **{k: None for k in modes}, **special}
|
|
|
|
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
|
|
|
|
|
|
|
|
|
|
|
|
# Define common mis-uses of special commands, i.e. -h, -help, --help
|
|
|
|
# Define common mis-uses of special commands, i.e. -h, -help, --help
|
|
|
|
special.update({k[0]: v for k, v in special.items()}) # singular
|
|
|
|
special.update({k[0]: v for k, v in special.items()}) # singular
|
|
|
@ -240,9 +239,9 @@ def entrypoint(debug=''):
|
|
|
|
except (NameError, SyntaxError, ValueError, AssertionError) as e:
|
|
|
|
except (NameError, SyntaxError, ValueError, AssertionError) as e:
|
|
|
|
check_cfg_mismatch(full_args_dict, {a: ''}, e)
|
|
|
|
check_cfg_mismatch(full_args_dict, {a: ''}, e)
|
|
|
|
|
|
|
|
|
|
|
|
elif a in tasks:
|
|
|
|
elif a in TASKS:
|
|
|
|
overrides['task'] = a
|
|
|
|
overrides['task'] = a
|
|
|
|
elif a in modes:
|
|
|
|
elif a in MODES:
|
|
|
|
overrides['mode'] = a
|
|
|
|
overrides['mode'] = a
|
|
|
|
elif a in special:
|
|
|
|
elif a in special:
|
|
|
|
special[a]()
|
|
|
|
special[a]()
|
|
|
@ -262,10 +261,10 @@ def entrypoint(debug=''):
|
|
|
|
mode = overrides.get('mode', None)
|
|
|
|
mode = overrides.get('mode', None)
|
|
|
|
if mode is None:
|
|
|
|
if mode is None:
|
|
|
|
mode = DEFAULT_CFG.mode or 'predict'
|
|
|
|
mode = DEFAULT_CFG.mode or 'predict'
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
|
|
|
|
elif mode not in modes:
|
|
|
|
elif mode not in MODES:
|
|
|
|
if mode not in ('checks', checks):
|
|
|
|
if mode not in ('checks', checks):
|
|
|
|
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {modes}.\n{CLI_HELP_MSG}")
|
|
|
|
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
|
|
|
|
LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
|
|
|
|
LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
|
|
|
|
checks.check_yolo()
|
|
|
|
checks.check_yolo()
|
|
|
|
return
|
|
|
|
return
|
|
|
@ -280,11 +279,11 @@ def entrypoint(debug=''):
|
|
|
|
model = YOLO(model)
|
|
|
|
model = YOLO(model)
|
|
|
|
|
|
|
|
|
|
|
|
# Task
|
|
|
|
# Task
|
|
|
|
# if task and task != model.task:
|
|
|
|
task = overrides.get('task', None)
|
|
|
|
# LOGGER.warning(f"WARNING ⚠️ 'task={task}' conflicts with {model.task} model {overrides['model']}. "
|
|
|
|
if task is not None and task not in TASKS:
|
|
|
|
# f"Inheriting 'task={model.task}' from {overrides['model']} and ignoring 'task={task}'.")
|
|
|
|
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
|
|
|
|
overrides['task'] = overrides.get('task', model.task)
|
|
|
|
else:
|
|
|
|
model.task = overrides['task']
|
|
|
|
model.task = task
|
|
|
|
|
|
|
|
|
|
|
|
# Mode
|
|
|
|
# Mode
|
|
|
|
if mode in {'predict', 'track'} and 'source' not in overrides:
|
|
|
|
if mode in {'predict', 'track'} and 'source' not in overrides:
|
|
|
|