From f0fff8c13e07ddd2219bbb2195d3fce54ec1405a Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 22 Dec 2022 15:22:16 +0100 Subject: [PATCH] Simplify cli.py and fix Detect train Usage (#83) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/yolo/cli.py | 28 ++++++++++++---------------- ultralytics/yolo/v8/detect/train.py | 4 ++-- ultralytics/yolo/v8/segment/train.py | 2 +- 3 files changed, 15 insertions(+), 19 deletions(-) diff --git a/ultralytics/yolo/cli.py b/ultralytics/yolo/cli.py index 2ff00f0..c0ffeec 100644 --- a/ultralytics/yolo/cli.py +++ b/ultralytics/yolo/cli.py @@ -13,9 +13,9 @@ from .utils import LOGGER, colorstr @hydra.main(version_base=None, config_path="utils/configs", config_name="default") def cli(cfg): LOGGER.info(f"{colorstr(f'Ultralytics YOLO v{ultralytics.__version__}')}") + task, mode = cfg.task.lower(), cfg.mode.lower() - module_file = None - if cfg.task.lower() == "init": # special case + if task == "init": # special case shutil.copy2(DEFAULT_CONFIG, os.getcwd()) LOGGER.info(f""" {colorstr("YOLO :")} configuration saved to {os.getcwd()}/{DEFAULT_CONFIG.name}. @@ -23,25 +23,21 @@ def cli(cfg): yolo task='task' mode='mode' --config-name config_file.yaml """) return - elif cfg.task.lower() == "detect": + elif task == "detect": module_file = yolo.detect - elif cfg.task.lower() == "segment": + elif task == "segment": module_file = yolo.segment - elif cfg.task.lower() == "classify": + elif task == "classify": module_file = yolo.classify + else: + raise SyntaxError("task not recognized. Choices are `'detect', 'segment', 'classify'`") - if not module_file: - raise Exception("task not recognized. Choices are `'detect', 'segment', 'classify'`") - - module_function = None - - if cfg.mode.lower() == "train": + if mode == "train": module_function = module_file.train - elif cfg.mode.lower() == "val": + elif mode == "val": module_function = module_file.val - elif cfg.mode.lower() == "predict": + elif mode == "predict": module_function = module_file.predict - - if not module_function: - raise Exception("mode not recognized. Choices are `'train', 'val', 'predict'`") + else: + raise SyntaxError("mode not recognized. Choices are `'train', 'val', 'predict'`") module_function(cfg) diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py index e967ebf..9495374 100644 --- a/ultralytics/yolo/v8/detect/train.py +++ b/ultralytics/yolo/v8/detect/train.py @@ -223,9 +223,9 @@ def train(cfg): if __name__ == "__main__": """ CLI usage: - python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-segments epochs=100 img_size=640 + python ultralytics/yolo/v8/detect/train.py model=yolov5n.yaml data=coco128 epochs=100 img_size=640 TODO: - Direct cli support, i.e, yolov8 classify_train args.epochs 10 + yolo task=detect mode=train model=yolov5n.yaml data=coco128.yaml epochs=100 """ train() diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py index 2ec1df1..6a408fb 100644 --- a/ultralytics/yolo/v8/segment/train.py +++ b/ultralytics/yolo/v8/segment/train.py @@ -243,7 +243,7 @@ def train(cfg): if __name__ == "__main__": """ CLI usage: - python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-segments epochs=100 img_size=640 + python ultralytics/yolo/v8/segment/train.py model=yolov5n-seg.yaml data=coco128-segments epochs=100 img_size=640 TODO: Direct cli support, i.e, yolov8 classify_train args.epochs 10