From 453b5f259a2ea93a04081f3e83c34ae72e64c651 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 17 Jan 2023 23:00:33 +0100 Subject: [PATCH] CLI Simplification (#449) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- README.md | 14 ++--- README.zh-CN.md | 14 ++--- docs/cli.md | 50 ++++++++--------- docs/config.md | 18 +++---- docs/{home.md => index.md} | 0 docs/quickstart.md | 4 +- docs/tasks/classification.md | 14 ++--- docs/tasks/detection.md | 14 ++--- docs/tasks/segmentation.md | 14 ++--- mkdocs.yml | 13 +++-- setup.py | 6 ++- tests/test_cli.py | 26 ++++----- ultralytics/yolo/cli.py | 81 ++++++++++++++++++++++------ ultralytics/yolo/engine/predictor.py | 2 +- ultralytics/yolo/utils/__init__.py | 78 ++++++++++++++++++++------- ultralytics/yolo/v8/__init__.py | 2 +- 16 files changed, 218 insertions(+), 132 deletions(-) rename docs/{home.md => index.md} (100%) diff --git a/README.md b/README.md index a669a87..8ffef70 100644 --- a/README.md +++ b/README.md @@ -91,7 +91,7 @@ pip install ultralytics YOLOv8 may be used directly in the Command Line Interface (CLI) with a `yolo` command: ```bash -yolo task=detect mode=predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg" +yolo predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg" ``` `yolo` can be used for a variety of tasks and modes and accepts additional arguments, i.e. `imgsz=640`. See a full list @@ -158,10 +158,10 @@ See [Detection Docs](https://docs.ultralytics.com/tasks/detection/) for usage ex | [YOLOv8x](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x.pt) | 640 | 53.9 | 479.1 | 3.53 | 68.2 | 257.8 | - **mAPval** values are for single-model single-scale on [COCO val2017](http://cocodataset.org) dataset. -
Reproduce by `yolo mode=val task=detect data=coco.yaml device=0` +
Reproduce by `yolo val detect data=coco.yaml device=0` - **Speed** averaged over COCO val images using an [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) instance. -
Reproduce by `yolo mode=val task=detect data=coco128.yaml batch=1 device=0/cpu` +
Reproduce by `yolo val detect data=coco128.yaml batch=1 device=0/cpu` @@ -178,10 +178,10 @@ See [Segmentation Docs](https://docs.ultralytics.com/tasks/segmentation/) for us | [YOLOv8x](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-seg.pt) | 640 | 53.4 | 43.4 | 712.1 | 4.02 | 71.8 | 344.1 | - **mAPval** values are for single-model single-scale on [COCO val2017](http://cocodataset.org) dataset. -
Reproduce by `yolo mode=val task=segment data=coco.yaml device=0` +
Reproduce by `yolo val segment data=coco.yaml device=0` - **Speed** averaged over COCO val images using an [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) instance. -
Reproduce by `yolo mode=val task=segment data=coco128-seg.yaml batch=1 device=0/cpu` +
Reproduce by `yolo val segment data=coco128-seg.yaml batch=1 device=0/cpu` @@ -198,10 +198,10 @@ See [Classification Docs](https://docs.ultralytics.com/tasks/classification/) fo | [YOLOv8x](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 78.4 | 94.3 | 232.0 | 1.01 | 57.4 | 154.8 | - **acc** values are model accuracies on the [ImageNet](https://www.image-net.org/) dataset validation set. -
Reproduce by `yolo mode=val task=classify data=path/to/ImageNet device=0` +
Reproduce by `yolo val classify data=path/to/ImageNet device=0` - **Speed** averaged over ImageNet val images using an [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) instance. -
Reproduce by `yolo mode=val task=classify data=path/to/ImageNet batch=1 device=0/cpu` +
Reproduce by `yolo val classify data=path/to/ImageNet batch=1 device=0/cpu` diff --git a/README.zh-CN.md b/README.zh-CN.md index f6e34e1..78eaa8d 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -67,7 +67,7 @@ pip install ultralytics YOLOv8 可以直接在命令行界面(CLI)中使用 `yolo` 命令运行: ```bash -yolo task=detect mode=predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg" +yolo predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg" ``` `yolo`可以用于各种任务和模式,并接受额外的参数,例如 `imgsz=640`。参见 YOLOv8 [文档](https://docs.ultralytics.com)中可用`yolo`[参数](https://docs.ultralytics.com/config/)的完整列表。 @@ -124,9 +124,9 @@ success = model.export(format="onnx") # 将模型导出为 ONNX 格式 | [YOLOv8x](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x.pt) | 640 | 53.9 | 479.1 | 3.53 | 68.2 | 257.8 | - **mAPval** 结果都在 [COCO val2017](http://cocodataset.org) 数据集上,使用单模型单尺度测试得到。 -
复现命令 `yolo mode=val task=detect data=coco.yaml device=0` +
复现命令 `yolo val detect data=coco.yaml device=0` - **推理速度**使用 COCO 验证集图片推理时间进行平均得到,测试环境使用 [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) 实例。 -
复现命令 `yolo mode=val task=detect data=coco128.yaml batch=1 device=0/cpu` +
复现命令 `yolo val detect data=coco128.yaml batch=1 device=0/cpu` @@ -141,9 +141,9 @@ success = model.export(format="onnx") # 将模型导出为 ONNX 格式 | [YOLOv8x](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-seg.pt) | 640 | 53.4 | 43.4 | 712.1 | 4.02 | 71.8 | 344.1 | - **mAPval** 结果都在 [COCO val2017](http://cocodataset.org) 数据集上,使用单模型单尺度测试得到。 -
复现命令 `yolo mode=val task=segment data=coco.yaml device=0` +
复现命令 `yolo val segment data=coco.yaml device=0` - **推理速度**使用 COCO 验证集图片推理时间进行平均得到,测试环境使用 [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) 实例。 -
复现命令 `yolo mode=val task=segment data=coco128-seg.yaml batch=1 device=0/cpu` +
复现命令 `yolo val segment data=coco128-seg.yaml batch=1 device=0/cpu` @@ -158,9 +158,9 @@ success = model.export(format="onnx") # 将模型导出为 ONNX 格式 | [YOLOv8x](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 78.4 | 94.3 | 232.0 | 1.01 | 57.4 | 154.8 | - **acc** 都在 [ImageNet](https://www.image-net.org/) 数据集上,使用单模型单尺度测试得到。 -
复现命令 `yolo mode=val task=classify data=path/to/ImageNet device=0` +
复现命令 `yolo val classify data=path/to/ImageNet device=0` - **推理速度**使用 ImageNet 验证集图片推理时间进行平均得到,测试环境使用 [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) 实例。 -
复现命令 `yolo mode=val task=classify data=path/to/ImageNet batch=1 device=0/cpu` +
复现命令 `yolo val classify data=path/to/ImageNet batch=1 device=0/cpu` diff --git a/docs/cli.md b/docs/cli.md index 63e53f4..ff0a72a 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -4,37 +4,37 @@ YOLO command line interface is the easiest way to get started. !!! tip "Syntax" ```bash - yolo task=detect mode=train model=yolov8n.yaml epochs=1 ... - ... ... ... - segment predict yolov8n-seg.pt - classify val yolov8n-cls.pt + yolo task=detect mode=train model=yolov8n.yaml args... + classify predict yolov8n-cls.yaml args... + segment val yolov8n-seg.yaml args... + export yolov8n.pt format=onnx args... ``` -The experiment arguments can be overridden directly by pass `arg=val` covered in the next section. You can run any -supported task by setting `task` and `mode` in cli. +The default arguments can be overridden directly by passing custom `arg=val` covered in the next section. You can run +any supported task by setting `task` and `mode` in CLI. === "Training" | | `task` | snippet | |------------------|------------|------------------------------------------------------------| - | Detection | `detect` |
yolo task=detect mode=train       
| - | Instance Segment | `segment` |
yolo task=segment mode=train      
| - | Classification | `classify` |
yolo task=classify mode=train    
| + | Detection | `detect` |
yolo detect train       
| + | Instance Segment | `segment` |
yolo segment train      
| + | Classification | `classify` |
yolo classify train    
| === "Prediction" | | `task` | snippet | |------------------|------------|--------------------------------------------------------------| - | Detection | `detect` |
yolo task=detect mode=predict       
| - | Instance Segment | `segment` |
yolo task=segment mode=predict     
| - | Classification | `classify` |
yolo task=classify mode=predict    
| + | Detection | `detect` |
yolo detect predict       
| + | Instance Segment | `segment` |
yolo segment predict     
| + | Classification | `classify` |
yolo classify predict    
| === "Validation" | | `task` | snippet | |------------------|------------|-----------------------------------------------------------| - | Detection | `detect` |
yolo task=detect mode=val        
| - | Instance Segment | `segment` |
yolo task=segment mode=val       
| - | Classification | `classify` |
yolo task=classify mode=val      
| + | Detection | `detect` |
yolo detect val        
| + | Instance Segment | `segment` |
yolo segment val       
| + | Classification | `classify` |
yolo classify val      
| !!! note "" @@ -44,19 +44,19 @@ supported task by setting `task` and `mode` in cli. ## Overriding default config arguments -All global default arguments can be overriden by simply passing them as arguments in the cli. +Default arguments can be overriden by simply passing them as arguments in the CLI. !!! tip "" === "Syntax" ```bash - yolo task= ... mode= ... {++ arg=val ++} + yolo task mode arg=val... ``` === "Example" Perform detection training for `10 epochs` with `learning_rate` of `0.01` ```bash - yolo task=detect mode=train {++ epochs=10 lr0=0.01 ++} + yolo detect train epochs=10 lr0=0.01 ``` --- @@ -67,23 +67,19 @@ You can override config file entirely by passing a new file. You can create a co current working dir as follows: ```bash -yolo task=init +yolo copy-config ``` -You can then use `cfg=name.yaml` command to pass the new config file +You can then use `cfg=default_copy.yaml` command to pass the new config file along with any addition args: ```bash -yolo cfg=default.yaml +yolo cfg=default_copy.yaml args... ``` ??? example === "Command" ```bash - yolo task=init - yolo cfg=default.yaml + yolo copy-config + yolo cfg=default_copy.yaml args... ``` - === "Results" - TODO: add terminal output - - diff --git a/docs/config.md b/docs/config.md index 57078eb..09226d6 100644 --- a/docs/config.md +++ b/docs/config.md @@ -38,13 +38,13 @@ include train, val, and predict. - Predict: The predict mode is used to make predictions with the model on new data. This mode is typically used in production or when deploying the model to users. -| Key | Value | Description | -|--------|----------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| task | `detect` | Set the task via CLI. See Tasks for all supported tasks like - `detect`, `segment`, `classify`.
- `init` is a special case that creates a copy of default.yaml configs to the current working dir | -| mode | `train` | Set the mode via CLI. It can be `train`, `val`, `predict` | -| resume | `False` | Resume last given task when set to `True`.
Resume from a given checkpoint is `model.pt` is passed | -| model | null | Set the model. Format can differ for task type. Supports `model_name`, `model.yaml` & `model.pt` | -| data | null | Set the data. Format can differ for task type. Supports `data.yaml`, `data_folder`, `dataset_name` | +| Key | Value | Description | +|--------|----------|--------------------------------------------------------------------------------------------------------| +| task | `detect` | Set the task via CLI. See Tasks for all supported tasks like - `detect`, `segment`, `classify` | +| mode | `train` | Set the mode via CLI. It can be `train`, `val`, `predict`, `export` | +| resume | `False` | Resume last given task when set to `True`.
Resume from a given checkpoint is `model.pt` is passed | +| model | null | Set the model. Format can differ for task type. Supports `model_name`, `model.yaml` & `model.pt` | +| data | null | Set the data. Format can differ for task type. Supports `data.yaml`, `data_folder`, `dataset_name` | ### Training @@ -197,6 +197,6 @@ it easier to debug and optimize the training process. |-----------|---------|---------------------------------------------------------------------------------------------| | project: | 'runs' | The project name | | name: | 'exp' | The run name. `exp` gets automatically incremented if not specified, i.e, `exp`, `exp2` ... | -| exist_ok: | `False` | ??? | +| exist_ok: | `False` | Will replace current directory contents if set to True and output directory exists. | | plots | `False` | **Validation**: Save plots while validation | -| nosave | `False` | Don't save any plots, models or files | \ No newline at end of file +| save | `False` | Save any plots, models or files | \ No newline at end of file diff --git a/docs/home.md b/docs/index.md similarity index 100% rename from docs/home.md rename to docs/index.md diff --git a/docs/quickstart.md b/docs/quickstart.md index 4c4684b..ac5e791 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -36,11 +36,11 @@ CLI requires no customization or code. You can simply run all tasks from the ter === "Example training" ```bash - yolo task=detect mode=train model=yolov8n.pt data=coco128.yaml device=0 + yolo detect train model=yolov8n.pt data=coco128.yaml device=0 ``` === "Example Multi-GPU training" ```bash - yolo task=detect mode=train model=yolov8n.pt data=coco128.yaml device=\'0,1,2,3\' + yolo detect train model=yolov8n.pt data=coco128.yaml device=\'0,1,2,3\' ``` [CLI Guide](cli.md){ .md-button .md-button--primary} diff --git a/docs/tasks/classification.md b/docs/tasks/classification.md index 153d7cd..a6e7a99 100644 --- a/docs/tasks/classification.md +++ b/docs/tasks/classification.md @@ -35,7 +35,7 @@ see the [Configuration](../config.md) page. === "CLI" ```bash - yolo task=classify mode=train data=mnist160 model=yolov8n-cls.pt epochs=100 imgsz=64 + yolo classify train data=mnist160 model=yolov8n-cls.pt epochs=100 imgsz=64 ``` ## Val @@ -60,8 +60,8 @@ it's training `data` and arguments as model attributes. === "CLI" ```bash - yolo task=classify mode=val model=yolov8n-cls.pt # val official model - yolo task=classify mode=val model=path/to/best.pt # val custom model + yolo classify val model=yolov8n-cls.pt # val official model + yolo classify val model=path/to/best.pt # val custom model ``` ## Predict @@ -85,8 +85,8 @@ Use a trained YOLOv8n-cls model to run predictions on images. === "CLI" ```bash - yolo task=classify mode=predict model=yolov8n-cls.pt source="https://ultralytics.com/images/bus.jpg" # predict with official model - yolo task=classify mode=predict model=path/to/best.pt source="https://ultralytics.com/images/bus.jpg" # predict with custom model + yolo classify predict model=yolov8n-cls.pt source="https://ultralytics.com/images/bus.jpg" # predict with official model + yolo classify predict model=path/to/best.pt source="https://ultralytics.com/images/bus.jpg" # predict with custom model ``` ## Export @@ -110,8 +110,8 @@ Export a YOLOv8n-cls model to a different format like ONNX, CoreML, etc. === "CLI" ```bash - yolo mode=export model=yolov8n-cls.pt format=onnx # export official model - yolo mode=export model=path/to/best.pt format=onnx # export custom trained model + yolo export model=yolov8n-cls.pt format=onnx # export official model + yolo export model=path/to/best.pt format=onnx # export custom trained model ``` Available YOLOv8-cls export formats include: diff --git a/docs/tasks/detection.md b/docs/tasks/detection.md index e208239..4a7df4b 100644 --- a/docs/tasks/detection.md +++ b/docs/tasks/detection.md @@ -35,7 +35,7 @@ the [Configuration](../config.md) page. === "CLI" ```bash - yolo task=detect mode=train data=coco128.yaml model=yolov8n.pt epochs=100 imgsz=640 + yolo detect train data=coco128.yaml model=yolov8n.pt epochs=100 imgsz=640 ``` ## Val @@ -60,8 +60,8 @@ training `data` and arguments as model attributes. === "CLI" ```bash - yolo task=detect mode=val model=yolov8n.pt # val official model - yolo task=detect mode=val model=path/to/best.pt # val custom model + yolo detect val model=yolov8n.pt # val official model + yolo detect val model=path/to/best.pt # val custom model ``` ## Predict @@ -85,8 +85,8 @@ Use a trained YOLOv8n model to run predictions on images. === "CLI" ```bash - yolo task=detect mode=predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg" # predict with official model - yolo task=detect mode=predict model=path/to/best.pt source="https://ultralytics.com/images/bus.jpg" # predict with custom model + yolo detect predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg" # predict with official model + yolo detect predict model=path/to/best.pt source="https://ultralytics.com/images/bus.jpg" # predict with custom model ``` ## Export @@ -110,8 +110,8 @@ Export a YOLOv8n model to a different format like ONNX, CoreML, etc. === "CLI" ```bash - yolo mode=export model=yolov8n.pt format=onnx # export official model - yolo mode=export model=path/to/best.pt format=onnx # export custom trained model + yolo export model=yolov8n.pt format=onnx # export official model + yolo export model=path/to/best.pt format=onnx # export custom trained model ``` Available YOLOv8 export formats include: diff --git a/docs/tasks/segmentation.md b/docs/tasks/segmentation.md index 5b91dae..977819e 100644 --- a/docs/tasks/segmentation.md +++ b/docs/tasks/segmentation.md @@ -35,7 +35,7 @@ arguments see the [Configuration](../config.md) page. === "CLI" ```bash - yolo task=segment mode=train data=coco128-seg.yaml model=yolov8n-seg.pt epochs=100 imgsz=640 + yolo segment train data=coco128-seg.yaml model=yolov8n-seg.pt epochs=100 imgsz=640 ``` ## Val @@ -60,8 +60,8 @@ retains it's training `data` and arguments as model attributes. === "CLI" ```bash - yolo task=segment mode=val model=yolov8n-seg.pt # val official model - yolo task=segment mode=val model=path/to/best.pt # val custom model + yolo segment val model=yolov8n-seg.pt # val official model + yolo segment val model=path/to/best.pt # val custom model ``` ## Predict @@ -85,8 +85,8 @@ Use a trained YOLOv8n-seg model to run predictions on images. === "CLI" ```bash - yolo task=segment mode=predict model=yolov8n-seg.pt source="https://ultralytics.com/images/bus.jpg" # predict with official model - yolo task=segment mode=predict model=path/to/best.pt source="https://ultralytics.com/images/bus.jpg" # predict with custom model + yolo segment predict model=yolov8n-seg.pt source="https://ultralytics.com/images/bus.jpg" # predict with official model + yolo segment predict model=path/to/best.pt source="https://ultralytics.com/images/bus.jpg" # predict with custom model ``` ## Export @@ -110,8 +110,8 @@ Export a YOLOv8n-seg model to a different format like ONNX, CoreML, etc. === "CLI" ```bash - yolo mode=export model=yolov8n-seg.pt format=onnx # export official model - yolo mode=export model=path/to/best.pt format=onnx # export custom trained model + yolo export model=yolov8n-seg.pt format=onnx # export official model + yolo export model=path/to/best.pt format=onnx # export custom trained model ``` Available YOLOv8-seg export formats include: diff --git a/mkdocs.yml b/mkdocs.yml index 1eb2a17..249701c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -75,18 +75,18 @@ plugins: # Primary navigation nav: - - Home: home.md + - Home: index.md - Quickstart: quickstart.md - Tasks: - Detection: tasks/detection.md - Segmentation: tasks/segmentation.md - Classification: tasks/classification.md - Usage: - - CLI: cli.md - - Python: python.md - - Predict: predict.md - - Configuration: config.md - - Customization Guide: engine.md + - CLI: cli.md + - Python: python.md + - Predict: predict.md + - Configuration: config.md + - Customization Guide: engine.md - Ultralytics HUB: hub.md - iOS and Android App: app.md - Reference: @@ -99,4 +99,3 @@ nav: - Results: reference/results.md - ultralytics.nn: reference/nn.md - Operations: reference/ops.md - - Docs: README.md diff --git a/setup.py b/setup.py index 590159c..109ce5e 100644 --- a/setup.py +++ b/setup.py @@ -11,6 +11,7 @@ FILE = Path(__file__).resolve() PARENT = FILE.parent # root directory README = (PARENT / "README.md").read_text(encoding="utf-8") REQUIREMENTS = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements((PARENT / 'requirements.txt').read_text())] +PKG_REQUIREMENTS = ['sentry_sdk'] # pip-only requirements def get_version(): @@ -35,7 +36,7 @@ setup( author_email='hello@ultralytics.com', packages=find_packages(), # required include_package_data=True, - install_requires=REQUIREMENTS, + install_requires=REQUIREMENTS + PKG_REQUIREMENTS, extras_require={ 'dev': ['check-manifest', 'pytest', 'pytest-cov', 'coverage', 'mkdocs', 'mkdocstrings[python]', 'mkdocs-material']}, @@ -49,4 +50,5 @@ setup( "Topic :: Scientific/Engineering :: Image Recognition", "Operating System :: POSIX :: Linux", "Operating System :: MacOS", "Operating System :: Microsoft :: Windows"], keywords="machine-learning, deep-learning, vision, ML, DL, AI, YOLO, YOLOv3, YOLOv5, YOLOv8, HUB, Ultralytics", - entry_points={'console_scripts': ['yolo = ultralytics.yolo.cli:cli', 'ultralytics = ultralytics.yolo.cli:cli']}) + entry_points={ + 'console_scripts': ['yolo = ultralytics.yolo.cli:entrypoint', 'ultralytics = ultralytics.yolo.cli:entrypoint']}) diff --git a/tests/test_cli.py b/tests/test_cli.py index 95daa04..65dc367 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -14,30 +14,32 @@ def run(cmd): subprocess.run(cmd.split(), check=True) -def test_checks(): - run('yolo mode=checks') +def test_special_modes(): + run('yolo checks') + run('yolo settings') + run('yolo help') # Train checks --------------------------------------------------------------------------------------------------------- def test_train_det(): - run(f'yolo mode=train task=detect model={CFG}.yaml data=coco8.yaml imgsz=32 epochs=1') + run(f'yolo train detect model={CFG}.yaml data=coco8.yaml imgsz=32 epochs=1') def test_train_seg(): - run(f'yolo mode=train task=segment model={CFG}-seg.yaml data=coco8-seg.yaml imgsz=32 epochs=1') + run(f'yolo train segment model={CFG}-seg.yaml data=coco8-seg.yaml imgsz=32 epochs=1') def test_train_cls(): - run(f'yolo mode=train task=classify model={CFG}-cls.yaml data=mnist160 imgsz=32 epochs=1') + run(f'yolo train classify model={CFG}-cls.yaml data=mnist160 imgsz=32 epochs=1') # Val checks ----------------------------------------------------------------------------------------------------------- def test_val_detect(): - run(f'yolo mode=val task=detect model={MODEL}.pt data=coco8.yaml imgsz=32 epochs=1') + run(f'yolo val detect model={MODEL}.pt data=coco8.yaml imgsz=32 epochs=1') def test_val_segment(): - run(f'yolo mode=val task=segment model={MODEL}-seg.pt data=coco8-seg.yaml imgsz=32 epochs=1') + run(f'yolo val segment model={MODEL}-seg.pt data=coco8-seg.yaml imgsz=32 epochs=1') def test_val_classify(): @@ -46,11 +48,11 @@ def test_val_classify(): # Predict checks ------------------------------------------------------------------------------------------------------- def test_predict_detect(): - run(f"yolo mode=predict task=detect model={MODEL}.pt source={ROOT / 'assets'}") + run(f"yolo predict detect model={MODEL}.pt source={ROOT / 'assets'}") def test_predict_segment(): - run(f"yolo mode=predict task=segment model={MODEL}-seg.pt source={ROOT / 'assets'}") + run(f"yolo predict segment model={MODEL}-seg.pt source={ROOT / 'assets'}") def test_predict_classify(): @@ -59,12 +61,12 @@ def test_predict_classify(): # Export checks -------------------------------------------------------------------------------------------------------- def test_export_detect_torchscript(): - run(f'yolo mode=export model={MODEL}.pt format=torchscript') + run(f'yolo export model={MODEL}.pt format=torchscript') def test_export_segment_torchscript(): - run(f'yolo mode=export model={MODEL}-seg.pt format=torchscript') + run(f'yolo export model={MODEL}-seg.pt format=torchscript') def test_export_classify_torchscript(): - run(f'yolo mode=export model={MODEL}-cls.pt format=torchscript') + run(f'yolo export model={MODEL}-cls.pt format=torchscript') diff --git a/ultralytics/yolo/cli.py b/ultralytics/yolo/cli.py index 5c05d11..395fb44 100644 --- a/ultralytics/yolo/cli.py +++ b/ultralytics/yolo/cli.py @@ -1,18 +1,17 @@ # Ultralytics YOLO 🚀, GPL-3.0 license +import argparse import shutil from pathlib import Path -import hydra +from hydra import compose, initialize from ultralytics import hub, yolo -from ultralytics.yolo.configs import get_config -from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr +from ultralytics.yolo.utils import DEFAULT_CONFIG, HELP_MSG, LOGGER, PREFIX, print_settings, yaml_load DIR = Path(__file__).parent -@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent.relative_to(DIR)), config_name=DEFAULT_CONFIG.name) def cli(cfg): """ Run a specified task and mode with the given configuration. @@ -21,21 +20,13 @@ def cli(cfg): cfg (DictConfig): Configuration for the task and mode. """ # LOGGER.info(f"{colorstr(f'Ultralytics YOLO v{ultralytics.__version__}')}") + from ultralytics.yolo.configs import get_config + if cfg.cfg: - LOGGER.info(f"Overriding default config with {cfg.cfg}") + LOGGER.info(f"{PREFIX}Overriding default config with {cfg.cfg}") cfg = get_config(cfg.cfg) task, mode = cfg.task.lower(), cfg.mode.lower() - # Special case for initializing the configuration - if task == "init": - shutil.copy2(DEFAULT_CONFIG, Path.cwd()) - LOGGER.info(f""" - {colorstr("YOLO:")} configuration saved to {Path.cwd() / DEFAULT_CONFIG.name}. - To run experiments using custom configuration: - yolo cfg=config_file.yaml - """) - return - # Mapping from task to module task_module_map = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify} module = task_module_map.get(task) @@ -47,10 +38,66 @@ def cli(cfg): "train": module.train, "val": module.val, "predict": module.predict, - "export": yolo.engine.exporter.export, - "checks": hub.checks} + "export": yolo.engine.exporter.export} func = mode_func_map.get(mode) if not func: raise SyntaxError(f"mode not recognized. Choices are {', '.join(mode_func_map.keys())}") func(cfg) + + +def entrypoint(): + """ + This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed + to the package. It's a combination of argparse and hydra. + + This function allows for: + - passing mandatory YOLO args as a list of strings + - specifying the task to be performed, either 'detect', 'segment' or 'classify' + - specifying the mode, either 'train', 'val', 'test', or 'predict' + - running special modes like 'checks' + - passing overrides to the package's configuration + + It uses the package's default config and initializes it using the passed overrides. + Then it calls the CLI function with the composed config + """ + parser = argparse.ArgumentParser(description='YOLO parser') + parser.add_argument('args', type=str, nargs='+', help='YOLO args') + args = parser.parse_args().args + + tasks = 'detect', 'segment', 'classify' + modes = 'train', 'val', 'predict', 'export' + special_modes = { + 'checks': hub.checks, + 'help': lambda: LOGGER.info(HELP_MSG), + 'settings': print_settings, + 'copy-config': copy_default_config} + + overrides = [] # basic overrides, i.e. imgsz=320 + defaults = yaml_load(DEFAULT_CONFIG) + for a in args: + if '=' in a: + overrides.append(a) + elif a in tasks: + overrides.append(f'task={a}') + elif a in modes: + overrides.append(f'mode={a}') + elif a in special_modes: + special_modes[a]() + return + elif a in defaults and defaults[a] is False: + overrides.append(f'{a}=True') # auto-True for default False args, i.e. yolo show + else: + raise (SyntaxError(f"'{a}' is not a valid yolo argument\n{HELP_MSG}")) + + with initialize(version_base=None, config_path=str(DEFAULT_CONFIG.parent.relative_to(DIR)), job_name="YOLO"): + cfg = compose(config_name=DEFAULT_CONFIG.name, overrides=overrides) + cli(cfg) + + +# Special modes -------------------------------------------------------------------------------------------------------- +def copy_default_config(): + new_file = Path.cwd() / DEFAULT_CONFIG.name.replace('.yaml', '_copy.yaml') + shutil.copy2(DEFAULT_CONFIG, new_file) + LOGGER.info(f"{PREFIX}{DEFAULT_CONFIG} copied to {new_file}\n" + f"Usage for running YOLO with this new custom config:\nyolo cfg={new_file} args...") diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index 7096794..af951f8 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -160,7 +160,7 @@ class BasePredictor: return list(chain(*list(self.stream_inference(source, model, verbose)))) # merge list of Result into one def predict_cli(self): - # Method used for cli prediction. It uses always generator as outputs as not required by cli mode + # Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode gen = self.stream_inference(verbose=True) for _ in gen: # running CLI inference without accumulating any outputs (do not modify) pass diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py index c808332..b0d6308 100644 --- a/ultralytics/yolo/utils/__init__.py +++ b/ultralytics/yolo/utils/__init__.py @@ -10,6 +10,7 @@ import tempfile import threading import uuid from pathlib import Path +from typing import Union import cv2 import git @@ -41,12 +42,15 @@ HELP_MSG = \ from ultralytics import YOLO - model = YOLO('yolov8n.yaml') # build a new model from scratch - model = YOLO('yolov8n.pt') # load a pretrained model (recommended for best training results) - results = model.train(data='coco128.yaml') # train the model - results = model.val() # evaluate model performance on the validation set - results = model.predict(source='bus.jpg') # predict on an image - success = model.export(format='onnx') # export the model to ONNX format + # Load a model + model = YOLO("yolov8n.yaml") # build a new model from scratch + model = YOLO("yolov8n.pt") # load a pretrained model (recommended for training) + + # Use the model + results = model.train(data="coco128.yaml", epochs=3) # train the model + results = model.val() # evaluate model performance on the validation set + results = model("https://ultralytics.com/images/bus.jpg") # predict on an image + success = model.export(format="onnx") # export the model to ONNX format 3. Use the command line interface (CLI): @@ -161,12 +165,12 @@ def is_pip_package(filepath: str = __name__) -> bool: return spec is not None and spec.origin is not None -def is_dir_writeable(dir_path: str) -> bool: +def is_dir_writeable(dir_path: Union[str, Path]) -> bool: """ Check if a directory is writeable. Args: - dir_path (str): The path to the directory. + dir_path (str) or (Path): The path to the directory. Returns: bool: True if the directory is writeable, False otherwise. @@ -179,6 +183,18 @@ def is_dir_writeable(dir_path: str) -> bool: return False +def is_pytest_running(): + """ + Returns a boolean indicating if pytest is currently running or not + :return: True if pytest is running, False otherwise + """ + try: + import sys + return "pytest" in sys.modules + except ImportError: + return False + + def get_git_root_dir(): """ Determines whether the current file is part of a git repository and if so, returns the repository root directory. @@ -348,6 +364,17 @@ def yaml_load(file='data.yaml', append_filename=False): return {**yaml.safe_load(f), 'yaml_file': str(file)} if append_filename else yaml.safe_load(f) +def set_sentry(dsn=None): + """ + Initialize the Sentry SDK for error tracking and reporting if pytest is not currently running. + """ + if dsn and not is_pytest_running(): + import sentry_sdk # noqa + + import ultralytics + sentry_sdk.init(dsn=dsn, traces_sample_rate=1.0, release=ultralytics.__version__, debug=False) + + def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'): """ Loads a global Ultralytics settings YAML file or creates one with default values if it does not exist. @@ -364,8 +391,9 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'): is_git = is_git_directory() # True if ultralytics installed via git root = get_git_root_dir() if is_git else Path() + datasets_root = (root.parent if (is_git and is_dir_writeable(root.parent)) else root).resolve() defaults = { - 'datasets_dir': str((root.parent if is_git else root) / 'datasets'), # default datasets directory. + 'datasets_dir': str(datasets_root / 'datasets'), # default datasets directory. 'weights_dir': str(root / 'weights'), # default weights directory. 'runs_dir': str(root / 'runs'), # default runs directory. 'sync': True, # sync analytics to help with YOLO development @@ -393,6 +421,26 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'): return settings +def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'): + """ + Function that runs on a first-time ultralytics package installation to set up global settings and create necessary + directories. + """ + SETTINGS.update(kwargs) + yaml_save(file, SETTINGS) + + +def print_settings(): + """ + Function that prints Ultralytics settings + """ + import json + s = f'\n{PREFIX}Settings:\n' + s += json.dumps(SETTINGS, indent=2) + s += f"\n\nUpdate settings at {USER_CONFIG_DIR / 'settings.yaml'}" + LOGGER.info(s) + + # Run below code on utils init ----------------------------------------------------------------------------------------- # Set logger @@ -403,15 +451,7 @@ if platform.system() == 'Windows': setattr(LOGGER, fn.__name__, lambda x: fn(emojis(x))) # emoji safe logging # Check first-install steps +PREFIX = colorstr("Ultralytics: ") SETTINGS = get_settings() DATASETS_DIR = Path(SETTINGS['datasets_dir']) # global datasets directory - - -def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'): - """ - Function that runs on a first-time ultralytics package installation to set up global settings and create necessary - directories. - """ - SETTINGS.update(kwargs) - - yaml_save(file, SETTINGS) +set_sentry() diff --git a/ultralytics/yolo/v8/__init__.py b/ultralytics/yolo/v8/__init__.py index c6ef3d0..05b65c7 100644 --- a/ultralytics/yolo/v8/__init__.py +++ b/ultralytics/yolo/v8/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license -from ultralytics.yolo.configs import hydra_patch # noqa (patch hydra cli) +from ultralytics.yolo.configs import hydra_patch # noqa (patch hydra CLI) from ultralytics.yolo.v8 import classify, detect, segment __all__ = ["classify", "segment", "detect"]