Simplify cli.py and fix Detect train Usage (#83)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent 74feef30c4
commit f0fff8c13e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -13,9 +13,9 @@ from .utils import LOGGER, colorstr
@hydra.main(version_base=None, config_path="utils/configs", config_name="default") @hydra.main(version_base=None, config_path="utils/configs", config_name="default")
def cli(cfg): def cli(cfg):
LOGGER.info(f"{colorstr(f'Ultralytics YOLO v{ultralytics.__version__}')}") LOGGER.info(f"{colorstr(f'Ultralytics YOLO v{ultralytics.__version__}')}")
task, mode = cfg.task.lower(), cfg.mode.lower()
module_file = None if task == "init": # special case
if cfg.task.lower() == "init": # special case
shutil.copy2(DEFAULT_CONFIG, os.getcwd()) shutil.copy2(DEFAULT_CONFIG, os.getcwd())
LOGGER.info(f""" LOGGER.info(f"""
{colorstr("YOLO :")} configuration saved to {os.getcwd()}/{DEFAULT_CONFIG.name}. {colorstr("YOLO :")} configuration saved to {os.getcwd()}/{DEFAULT_CONFIG.name}.
@ -23,25 +23,21 @@ def cli(cfg):
yolo task='task' mode='mode' --config-name config_file.yaml yolo task='task' mode='mode' --config-name config_file.yaml
""") """)
return return
elif cfg.task.lower() == "detect": elif task == "detect":
module_file = yolo.detect module_file = yolo.detect
elif cfg.task.lower() == "segment": elif task == "segment":
module_file = yolo.segment module_file = yolo.segment
elif cfg.task.lower() == "classify": elif task == "classify":
module_file = yolo.classify module_file = yolo.classify
else:
raise SyntaxError("task not recognized. Choices are `'detect', 'segment', 'classify'`")
if not module_file: if mode == "train":
raise Exception("task not recognized. Choices are `'detect', 'segment', 'classify'`")
module_function = None
if cfg.mode.lower() == "train":
module_function = module_file.train module_function = module_file.train
elif cfg.mode.lower() == "val": elif mode == "val":
module_function = module_file.val module_function = module_file.val
elif cfg.mode.lower() == "predict": elif mode == "predict":
module_function = module_file.predict module_function = module_file.predict
else:
if not module_function: raise SyntaxError("mode not recognized. Choices are `'train', 'val', 'predict'`")
raise Exception("mode not recognized. Choices are `'train', 'val', 'predict'`")
module_function(cfg) module_function(cfg)

@ -223,9 +223,9 @@ def train(cfg):
if __name__ == "__main__": if __name__ == "__main__":
""" """
CLI usage: CLI usage:
python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-segments epochs=100 img_size=640 python ultralytics/yolo/v8/detect/train.py model=yolov5n.yaml data=coco128 epochs=100 img_size=640
TODO: TODO:
Direct cli support, i.e, yolov8 classify_train args.epochs 10 yolo task=detect mode=train model=yolov5n.yaml data=coco128.yaml epochs=100
""" """
train() train()

@ -243,7 +243,7 @@ def train(cfg):
if __name__ == "__main__": if __name__ == "__main__":
""" """
CLI usage: CLI usage:
python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-segments epochs=100 img_size=640 python ultralytics/yolo/v8/segment/train.py model=yolov5n-seg.yaml data=coco128-segments epochs=100 img_size=640
TODO: TODO:
Direct cli support, i.e, yolov8 classify_train args.epochs 10 Direct cli support, i.e, yolov8 classify_train args.epochs 10

Loading…
Cancel
Save