Simplify cli.py and fix Detect train Usage (#83)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
		| @ -13,9 +13,9 @@ from .utils import LOGGER, colorstr | ||||
| @hydra.main(version_base=None, config_path="utils/configs", config_name="default") | ||||
| def cli(cfg): | ||||
|     LOGGER.info(f"{colorstr(f'Ultralytics YOLO v{ultralytics.__version__}')}") | ||||
|     task, mode = cfg.task.lower(), cfg.mode.lower() | ||||
|  | ||||
|     module_file = None | ||||
|     if cfg.task.lower() == "init":  # special case | ||||
|     if task == "init":  # special case | ||||
|         shutil.copy2(DEFAULT_CONFIG, os.getcwd()) | ||||
|         LOGGER.info(f""" | ||||
|         {colorstr("YOLO :")} configuration saved to {os.getcwd()}/{DEFAULT_CONFIG.name}. | ||||
| @ -23,25 +23,21 @@ def cli(cfg): | ||||
|         yolo task='task' mode='mode' --config-name config_file.yaml | ||||
|                     """) | ||||
|         return | ||||
|     elif cfg.task.lower() == "detect": | ||||
|     elif task == "detect": | ||||
|         module_file = yolo.detect | ||||
|     elif cfg.task.lower() == "segment": | ||||
|     elif task == "segment": | ||||
|         module_file = yolo.segment | ||||
|     elif cfg.task.lower() == "classify": | ||||
|     elif task == "classify": | ||||
|         module_file = yolo.classify | ||||
|     else: | ||||
|         raise SyntaxError("task not recognized. Choices are `'detect', 'segment', 'classify'`") | ||||
|  | ||||
|     if not module_file: | ||||
|         raise Exception("task not recognized. Choices are `'detect', 'segment', 'classify'`") | ||||
|  | ||||
|     module_function = None | ||||
|  | ||||
|     if cfg.mode.lower() == "train": | ||||
|     if mode == "train": | ||||
|         module_function = module_file.train | ||||
|     elif cfg.mode.lower() == "val": | ||||
|     elif mode == "val": | ||||
|         module_function = module_file.val | ||||
|     elif cfg.mode.lower() == "predict": | ||||
|     elif mode == "predict": | ||||
|         module_function = module_file.predict | ||||
|  | ||||
|     if not module_function: | ||||
|         raise Exception("mode not recognized. Choices are `'train', 'val', 'predict'`") | ||||
|     else: | ||||
|         raise SyntaxError("mode not recognized. Choices are `'train', 'val', 'predict'`") | ||||
|     module_function(cfg) | ||||
|  | ||||
| @ -223,9 +223,9 @@ def train(cfg): | ||||
| if __name__ == "__main__": | ||||
|     """ | ||||
|     CLI usage: | ||||
|     python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-segments epochs=100 img_size=640 | ||||
|     python ultralytics/yolo/v8/detect/train.py model=yolov5n.yaml data=coco128 epochs=100 img_size=640 | ||||
|  | ||||
|     TODO: | ||||
|     Direct cli support, i.e, yolov8 classify_train args.epochs 10 | ||||
|     yolo task=detect mode=train model=yolov5n.yaml data=coco128.yaml epochs=100 | ||||
|     """ | ||||
|     train() | ||||
|  | ||||
| @ -243,7 +243,7 @@ def train(cfg): | ||||
| if __name__ == "__main__": | ||||
|     """ | ||||
|     CLI usage: | ||||
|     python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-segments epochs=100 img_size=640 | ||||
|     python ultralytics/yolo/v8/segment/train.py model=yolov5n-seg.yaml data=coco128-segments epochs=100 img_size=640 | ||||
|  | ||||
|     TODO: | ||||
|     Direct cli support, i.e, yolov8 classify_train args.epochs 10 | ||||
|  | ||||
		Reference in New Issue
	
	Block a user