ultralytics 8.0.50
AMP check and YOLOv5u YAMLs (#1263)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Troy <wudashuo@vip.qq.com> Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Huijae Lee <46982469+ZeroAct@users.noreply.github.com>
This commit is contained in:
@ -61,8 +61,10 @@ CFG_BOOL_KEYS = ('save', 'exist_ok', 'pretrained', 'verbose', 'deterministic', '
|
||||
'v5loader')
|
||||
|
||||
# Define valid tasks and modes
|
||||
TASKS = 'detect', 'segment', 'classify'
|
||||
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
|
||||
TASKS = 'detect', 'segment', 'classify'
|
||||
TASK2DATA = {'detect': 'coco128.yaml', 'segment': 'coco128-seg.yaml', 'classify': 'imagenet100'}
|
||||
TASK2MODEL = {'detect': 'yolov8n.pt', 'segment': 'yolov8n-seg.pt', 'classify': 'yolov8n-cls.pt'}
|
||||
|
||||
|
||||
def cfg2dict(cfg):
|
||||
@ -274,8 +276,11 @@ def entrypoint(debug=''):
|
||||
|
||||
# Task
|
||||
task = overrides.pop('task', None)
|
||||
if task and task not in TASKS:
|
||||
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
|
||||
if task:
|
||||
if task not in TASKS:
|
||||
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
|
||||
if 'model' not in overrides:
|
||||
overrides['model'] = TASK2MODEL[task]
|
||||
|
||||
# Model
|
||||
model = overrides.pop('model', DEFAULT_CFG.model)
|
||||
@ -287,9 +292,10 @@ def entrypoint(debug=''):
|
||||
model = YOLO(model, task=task)
|
||||
|
||||
# Task Update
|
||||
if task and task != model.task:
|
||||
LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
|
||||
f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")
|
||||
if task != model.task:
|
||||
if task:
|
||||
LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
|
||||
f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")
|
||||
task = model.task
|
||||
|
||||
# Mode
|
||||
@ -299,8 +305,7 @@ def entrypoint(debug=''):
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
|
||||
elif mode in ('train', 'val'):
|
||||
if 'data' not in overrides:
|
||||
task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='imagenet100')
|
||||
overrides['data'] = task2data.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
|
||||
overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
|
||||
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
|
||||
elif mode == 'export':
|
||||
if 'format' not in overrides:
|
||||
@ -322,4 +327,4 @@ def copy_default_cfg():
|
||||
|
||||
if __name__ == '__main__':
|
||||
# entrypoint(debug='yolo predict model=yolov8n.pt')
|
||||
entrypoint(debug='')
|
||||
entrypoint(debug='yolo train model=yolov8n-seg.pt')
|
||||
|
@ -6,7 +6,7 @@ mode: train # YOLO mode, i.e. train, val, predict, export
|
||||
|
||||
# Train settings -------------------------------------------------------------------------------------------------------
|
||||
model: # path to model file, i.e. yolov8n.pt, yolov8n.yaml
|
||||
data: # path to data file, i.e. i.e. coco128.yaml
|
||||
data: # path to data file, i.e. coco128.yaml
|
||||
epochs: 100 # number of epochs to train for
|
||||
patience: 50 # epochs to wait for no observable improvement for early stopping of training
|
||||
batch: 16 # number of images per batch (-1 for AutoBatch)
|
||||
|
Reference in New Issue
Block a user