diff --git a/README.md b/README.md
index a669a87..8ffef70 100644
--- a/README.md
+++ b/README.md
@@ -91,7 +91,7 @@ pip install ultralytics
YOLOv8 may be used directly in the Command Line Interface (CLI) with a `yolo` command:
```bash
-yolo task=detect mode=predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg"
+yolo predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg"
```
`yolo` can be used for a variety of tasks and modes and accepts additional arguments, i.e. `imgsz=640`. See a full list
@@ -158,10 +158,10 @@ See [Detection Docs](https://docs.ultralytics.com/tasks/detection/) for usage ex
| [YOLOv8x](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x.pt) | 640 | 53.9 | 479.1 | 3.53 | 68.2 | 257.8 |
- **mAPval** values are for single-model single-scale on [COCO val2017](http://cocodataset.org) dataset.
-
Reproduce by `yolo mode=val task=detect data=coco.yaml device=0`
+
Reproduce by `yolo val detect data=coco.yaml device=0`
- **Speed** averaged over COCO val images using an [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/)
instance.
-
Reproduce by `yolo mode=val task=detect data=coco128.yaml batch=1 device=0/cpu`
+
Reproduce by `yolo val detect data=coco128.yaml batch=1 device=0/cpu`
@@ -178,10 +178,10 @@ See [Segmentation Docs](https://docs.ultralytics.com/tasks/segmentation/) for us
| [YOLOv8x](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-seg.pt) | 640 | 53.4 | 43.4 | 712.1 | 4.02 | 71.8 | 344.1 |
- **mAPval** values are for single-model single-scale on [COCO val2017](http://cocodataset.org) dataset.
-
Reproduce by `yolo mode=val task=segment data=coco.yaml device=0`
+
Reproduce by `yolo val segment data=coco.yaml device=0`
- **Speed** averaged over COCO val images using an [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/)
instance.
-
Reproduce by `yolo mode=val task=segment data=coco128-seg.yaml batch=1 device=0/cpu`
+
Reproduce by `yolo val segment data=coco128-seg.yaml batch=1 device=0/cpu`
@@ -198,10 +198,10 @@ See [Classification Docs](https://docs.ultralytics.com/tasks/classification/) fo
| [YOLOv8x](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 78.4 | 94.3 | 232.0 | 1.01 | 57.4 | 154.8 |
- **acc** values are model accuracies on the [ImageNet](https://www.image-net.org/) dataset validation set.
-
Reproduce by `yolo mode=val task=classify data=path/to/ImageNet device=0`
+
Reproduce by `yolo val classify data=path/to/ImageNet device=0`
- **Speed** averaged over ImageNet val images using an [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/)
instance.
-
Reproduce by `yolo mode=val task=classify data=path/to/ImageNet batch=1 device=0/cpu`
+
Reproduce by `yolo val classify data=path/to/ImageNet batch=1 device=0/cpu`
diff --git a/README.zh-CN.md b/README.zh-CN.md
index f6e34e1..78eaa8d 100644
--- a/README.zh-CN.md
+++ b/README.zh-CN.md
@@ -67,7 +67,7 @@ pip install ultralytics
YOLOv8 可以直接在命令行界面(CLI)中使用 `yolo` 命令运行:
```bash
-yolo task=detect mode=predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg"
+yolo predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg"
```
`yolo`可以用于各种任务和模式,并接受额外的参数,例如 `imgsz=640`。参见 YOLOv8 [文档](https://docs.ultralytics.com)中可用`yolo`[参数](https://docs.ultralytics.com/config/)的完整列表。
@@ -124,9 +124,9 @@ success = model.export(format="onnx") # 将模型导出为 ONNX 格式
| [YOLOv8x](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x.pt) | 640 | 53.9 | 479.1 | 3.53 | 68.2 | 257.8 |
- **mAPval** 结果都在 [COCO val2017](http://cocodataset.org) 数据集上,使用单模型单尺度测试得到。
-
复现命令 `yolo mode=val task=detect data=coco.yaml device=0`
+
复现命令 `yolo val detect data=coco.yaml device=0`
- **推理速度**使用 COCO 验证集图片推理时间进行平均得到,测试环境使用 [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) 实例。
-
复现命令 `yolo mode=val task=detect data=coco128.yaml batch=1 device=0/cpu`
+
复现命令 `yolo val detect data=coco128.yaml batch=1 device=0/cpu`
@@ -141,9 +141,9 @@ success = model.export(format="onnx") # 将模型导出为 ONNX 格式
| [YOLOv8x](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-seg.pt) | 640 | 53.4 | 43.4 | 712.1 | 4.02 | 71.8 | 344.1 |
- **mAPval** 结果都在 [COCO val2017](http://cocodataset.org) 数据集上,使用单模型单尺度测试得到。
-
复现命令 `yolo mode=val task=segment data=coco.yaml device=0`
+
复现命令 `yolo val segment data=coco.yaml device=0`
- **推理速度**使用 COCO 验证集图片推理时间进行平均得到,测试环境使用 [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) 实例。
-
复现命令 `yolo mode=val task=segment data=coco128-seg.yaml batch=1 device=0/cpu`
+
复现命令 `yolo val segment data=coco128-seg.yaml batch=1 device=0/cpu`
@@ -158,9 +158,9 @@ success = model.export(format="onnx") # 将模型导出为 ONNX 格式
| [YOLOv8x](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 78.4 | 94.3 | 232.0 | 1.01 | 57.4 | 154.8 |
- **acc** 都在 [ImageNet](https://www.image-net.org/) 数据集上,使用单模型单尺度测试得到。
-
复现命令 `yolo mode=val task=classify data=path/to/ImageNet device=0`
+
复现命令 `yolo val classify data=path/to/ImageNet device=0`
- **推理速度**使用 ImageNet 验证集图片推理时间进行平均得到,测试环境使用 [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) 实例。
-
复现命令 `yolo mode=val task=classify data=path/to/ImageNet batch=1 device=0/cpu`
+
复现命令 `yolo val classify data=path/to/ImageNet batch=1 device=0/cpu`
diff --git a/docs/cli.md b/docs/cli.md
index 63e53f4..ff0a72a 100644
--- a/docs/cli.md
+++ b/docs/cli.md
@@ -4,37 +4,37 @@ YOLO command line interface is the easiest way to get started.
!!! tip "Syntax"
```bash
- yolo task=detect mode=train model=yolov8n.yaml epochs=1 ...
- ... ... ...
- segment predict yolov8n-seg.pt
- classify val yolov8n-cls.pt
+ yolo task=detect mode=train model=yolov8n.yaml args...
+ classify predict yolov8n-cls.yaml args...
+ segment val yolov8n-seg.yaml args...
+ export yolov8n.pt format=onnx args...
```
-The experiment arguments can be overridden directly by pass `arg=val` covered in the next section. You can run any
-supported task by setting `task` and `mode` in cli.
+The default arguments can be overridden directly by passing custom `arg=val` covered in the next section. You can run
+any supported task by setting `task` and `mode` in CLI.
=== "Training"
| | `task` | snippet |
|------------------|------------|------------------------------------------------------------|
- | Detection | `detect` |
yolo task=detect mode=train
|
- | Instance Segment | `segment` | yolo task=segment mode=train
|
- | Classification | `classify` | yolo task=classify mode=train
|
+ | Detection | `detect` | yolo detect train
|
+ | Instance Segment | `segment` | yolo segment train
|
+ | Classification | `classify` | yolo classify train
|
=== "Prediction"
| | `task` | snippet |
|------------------|------------|--------------------------------------------------------------|
- | Detection | `detect` | yolo task=detect mode=predict
|
- | Instance Segment | `segment` | yolo task=segment mode=predict
|
- | Classification | `classify` | yolo task=classify mode=predict
|
+ | Detection | `detect` | yolo detect predict
|
+ | Instance Segment | `segment` | yolo segment predict
|
+ | Classification | `classify` | yolo classify predict
|
=== "Validation"
| | `task` | snippet |
|------------------|------------|-----------------------------------------------------------|
- | Detection | `detect` | yolo task=detect mode=val
|
- | Instance Segment | `segment` | yolo task=segment mode=val
|
- | Classification | `classify` | yolo task=classify mode=val
|
+ | Detection | `detect` | yolo detect val
|
+ | Instance Segment | `segment` | yolo segment val
|
+ | Classification | `classify` | yolo classify val
|
!!! note ""
@@ -44,19 +44,19 @@ supported task by setting `task` and `mode` in cli.
## Overriding default config arguments
-All global default arguments can be overriden by simply passing them as arguments in the cli.
+Default arguments can be overriden by simply passing them as arguments in the CLI.
!!! tip ""
=== "Syntax"
```bash
- yolo task= ... mode= ... {++ arg=val ++}
+ yolo task mode arg=val...
```
=== "Example"
Perform detection training for `10 epochs` with `learning_rate` of `0.01`
```bash
- yolo task=detect mode=train {++ epochs=10 lr0=0.01 ++}
+ yolo detect train epochs=10 lr0=0.01
```
---
@@ -67,23 +67,19 @@ You can override config file entirely by passing a new file. You can create a co
current working dir as follows:
```bash
-yolo task=init
+yolo copy-config
```
-You can then use `cfg=name.yaml` command to pass the new config file
+You can then use `cfg=default_copy.yaml` command to pass the new config file along with any addition args:
```bash
-yolo cfg=default.yaml
+yolo cfg=default_copy.yaml args...
```
??? example
=== "Command"
```bash
- yolo task=init
- yolo cfg=default.yaml
+ yolo copy-config
+ yolo cfg=default_copy.yaml args...
```
- === "Results"
- TODO: add terminal output
-
-
diff --git a/docs/config.md b/docs/config.md
index 57078eb..09226d6 100644
--- a/docs/config.md
+++ b/docs/config.md
@@ -38,13 +38,13 @@ include train, val, and predict.
- Predict: The predict mode is used to make predictions with the model on new data. This mode is typically used in
production or when deploying the model to users.
-| Key | Value | Description |
-|--------|----------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| task | `detect` | Set the task via CLI. See Tasks for all supported tasks like - `detect`, `segment`, `classify`.
- `init` is a special case that creates a copy of default.yaml configs to the current working dir |
-| mode | `train` | Set the mode via CLI. It can be `train`, `val`, `predict` |
-| resume | `False` | Resume last given task when set to `True`.
Resume from a given checkpoint is `model.pt` is passed |
-| model | null | Set the model. Format can differ for task type. Supports `model_name`, `model.yaml` & `model.pt` |
-| data | null | Set the data. Format can differ for task type. Supports `data.yaml`, `data_folder`, `dataset_name` |
+| Key | Value | Description |
+|--------|----------|--------------------------------------------------------------------------------------------------------|
+| task | `detect` | Set the task via CLI. See Tasks for all supported tasks like - `detect`, `segment`, `classify` |
+| mode | `train` | Set the mode via CLI. It can be `train`, `val`, `predict`, `export` |
+| resume | `False` | Resume last given task when set to `True`.
Resume from a given checkpoint is `model.pt` is passed |
+| model | null | Set the model. Format can differ for task type. Supports `model_name`, `model.yaml` & `model.pt` |
+| data | null | Set the data. Format can differ for task type. Supports `data.yaml`, `data_folder`, `dataset_name` |
### Training
@@ -197,6 +197,6 @@ it easier to debug and optimize the training process.
|-----------|---------|---------------------------------------------------------------------------------------------|
| project: | 'runs' | The project name |
| name: | 'exp' | The run name. `exp` gets automatically incremented if not specified, i.e, `exp`, `exp2` ... |
-| exist_ok: | `False` | ??? |
+| exist_ok: | `False` | Will replace current directory contents if set to True and output directory exists. |
| plots | `False` | **Validation**: Save plots while validation |
-| nosave | `False` | Don't save any plots, models or files |
\ No newline at end of file
+| save | `False` | Save any plots, models or files |
\ No newline at end of file
diff --git a/docs/home.md b/docs/index.md
similarity index 100%
rename from docs/home.md
rename to docs/index.md
diff --git a/docs/quickstart.md b/docs/quickstart.md
index 4c4684b..ac5e791 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -36,11 +36,11 @@ CLI requires no customization or code. You can simply run all tasks from the ter
=== "Example training"
```bash
- yolo task=detect mode=train model=yolov8n.pt data=coco128.yaml device=0
+ yolo detect train model=yolov8n.pt data=coco128.yaml device=0
```
=== "Example Multi-GPU training"
```bash
- yolo task=detect mode=train model=yolov8n.pt data=coco128.yaml device=\'0,1,2,3\'
+ yolo detect train model=yolov8n.pt data=coco128.yaml device=\'0,1,2,3\'
```
[CLI Guide](cli.md){ .md-button .md-button--primary}
diff --git a/docs/tasks/classification.md b/docs/tasks/classification.md
index 153d7cd..a6e7a99 100644
--- a/docs/tasks/classification.md
+++ b/docs/tasks/classification.md
@@ -35,7 +35,7 @@ see the [Configuration](../config.md) page.
=== "CLI"
```bash
- yolo task=classify mode=train data=mnist160 model=yolov8n-cls.pt epochs=100 imgsz=64
+ yolo classify train data=mnist160 model=yolov8n-cls.pt epochs=100 imgsz=64
```
## Val
@@ -60,8 +60,8 @@ it's training `data` and arguments as model attributes.
=== "CLI"
```bash
- yolo task=classify mode=val model=yolov8n-cls.pt # val official model
- yolo task=classify mode=val model=path/to/best.pt # val custom model
+ yolo classify val model=yolov8n-cls.pt # val official model
+ yolo classify val model=path/to/best.pt # val custom model
```
## Predict
@@ -85,8 +85,8 @@ Use a trained YOLOv8n-cls model to run predictions on images.
=== "CLI"
```bash
- yolo task=classify mode=predict model=yolov8n-cls.pt source="https://ultralytics.com/images/bus.jpg" # predict with official model
- yolo task=classify mode=predict model=path/to/best.pt source="https://ultralytics.com/images/bus.jpg" # predict with custom model
+ yolo classify predict model=yolov8n-cls.pt source="https://ultralytics.com/images/bus.jpg" # predict with official model
+ yolo classify predict model=path/to/best.pt source="https://ultralytics.com/images/bus.jpg" # predict with custom model
```
## Export
@@ -110,8 +110,8 @@ Export a YOLOv8n-cls model to a different format like ONNX, CoreML, etc.
=== "CLI"
```bash
- yolo mode=export model=yolov8n-cls.pt format=onnx # export official model
- yolo mode=export model=path/to/best.pt format=onnx # export custom trained model
+ yolo export model=yolov8n-cls.pt format=onnx # export official model
+ yolo export model=path/to/best.pt format=onnx # export custom trained model
```
Available YOLOv8-cls export formats include:
diff --git a/docs/tasks/detection.md b/docs/tasks/detection.md
index e208239..4a7df4b 100644
--- a/docs/tasks/detection.md
+++ b/docs/tasks/detection.md
@@ -35,7 +35,7 @@ the [Configuration](../config.md) page.
=== "CLI"
```bash
- yolo task=detect mode=train data=coco128.yaml model=yolov8n.pt epochs=100 imgsz=640
+ yolo detect train data=coco128.yaml model=yolov8n.pt epochs=100 imgsz=640
```
## Val
@@ -60,8 +60,8 @@ training `data` and arguments as model attributes.
=== "CLI"
```bash
- yolo task=detect mode=val model=yolov8n.pt # val official model
- yolo task=detect mode=val model=path/to/best.pt # val custom model
+ yolo detect val model=yolov8n.pt # val official model
+ yolo detect val model=path/to/best.pt # val custom model
```
## Predict
@@ -85,8 +85,8 @@ Use a trained YOLOv8n model to run predictions on images.
=== "CLI"
```bash
- yolo task=detect mode=predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg" # predict with official model
- yolo task=detect mode=predict model=path/to/best.pt source="https://ultralytics.com/images/bus.jpg" # predict with custom model
+ yolo detect predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg" # predict with official model
+ yolo detect predict model=path/to/best.pt source="https://ultralytics.com/images/bus.jpg" # predict with custom model
```
## Export
@@ -110,8 +110,8 @@ Export a YOLOv8n model to a different format like ONNX, CoreML, etc.
=== "CLI"
```bash
- yolo mode=export model=yolov8n.pt format=onnx # export official model
- yolo mode=export model=path/to/best.pt format=onnx # export custom trained model
+ yolo export model=yolov8n.pt format=onnx # export official model
+ yolo export model=path/to/best.pt format=onnx # export custom trained model
```
Available YOLOv8 export formats include:
diff --git a/docs/tasks/segmentation.md b/docs/tasks/segmentation.md
index 5b91dae..977819e 100644
--- a/docs/tasks/segmentation.md
+++ b/docs/tasks/segmentation.md
@@ -35,7 +35,7 @@ arguments see the [Configuration](../config.md) page.
=== "CLI"
```bash
- yolo task=segment mode=train data=coco128-seg.yaml model=yolov8n-seg.pt epochs=100 imgsz=640
+ yolo segment train data=coco128-seg.yaml model=yolov8n-seg.pt epochs=100 imgsz=640
```
## Val
@@ -60,8 +60,8 @@ retains it's training `data` and arguments as model attributes.
=== "CLI"
```bash
- yolo task=segment mode=val model=yolov8n-seg.pt # val official model
- yolo task=segment mode=val model=path/to/best.pt # val custom model
+ yolo segment val model=yolov8n-seg.pt # val official model
+ yolo segment val model=path/to/best.pt # val custom model
```
## Predict
@@ -85,8 +85,8 @@ Use a trained YOLOv8n-seg model to run predictions on images.
=== "CLI"
```bash
- yolo task=segment mode=predict model=yolov8n-seg.pt source="https://ultralytics.com/images/bus.jpg" # predict with official model
- yolo task=segment mode=predict model=path/to/best.pt source="https://ultralytics.com/images/bus.jpg" # predict with custom model
+ yolo segment predict model=yolov8n-seg.pt source="https://ultralytics.com/images/bus.jpg" # predict with official model
+ yolo segment predict model=path/to/best.pt source="https://ultralytics.com/images/bus.jpg" # predict with custom model
```
## Export
@@ -110,8 +110,8 @@ Export a YOLOv8n-seg model to a different format like ONNX, CoreML, etc.
=== "CLI"
```bash
- yolo mode=export model=yolov8n-seg.pt format=onnx # export official model
- yolo mode=export model=path/to/best.pt format=onnx # export custom trained model
+ yolo export model=yolov8n-seg.pt format=onnx # export official model
+ yolo export model=path/to/best.pt format=onnx # export custom trained model
```
Available YOLOv8-seg export formats include:
diff --git a/mkdocs.yml b/mkdocs.yml
index 1eb2a17..249701c 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -75,18 +75,18 @@ plugins:
# Primary navigation
nav:
- - Home: home.md
+ - Home: index.md
- Quickstart: quickstart.md
- Tasks:
- Detection: tasks/detection.md
- Segmentation: tasks/segmentation.md
- Classification: tasks/classification.md
- Usage:
- - CLI: cli.md
- - Python: python.md
- - Predict: predict.md
- - Configuration: config.md
- - Customization Guide: engine.md
+ - CLI: cli.md
+ - Python: python.md
+ - Predict: predict.md
+ - Configuration: config.md
+ - Customization Guide: engine.md
- Ultralytics HUB: hub.md
- iOS and Android App: app.md
- Reference:
@@ -99,4 +99,3 @@ nav:
- Results: reference/results.md
- ultralytics.nn: reference/nn.md
- Operations: reference/ops.md
- - Docs: README.md
diff --git a/setup.py b/setup.py
index 590159c..109ce5e 100644
--- a/setup.py
+++ b/setup.py
@@ -11,6 +11,7 @@ FILE = Path(__file__).resolve()
PARENT = FILE.parent # root directory
README = (PARENT / "README.md").read_text(encoding="utf-8")
REQUIREMENTS = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements((PARENT / 'requirements.txt').read_text())]
+PKG_REQUIREMENTS = ['sentry_sdk'] # pip-only requirements
def get_version():
@@ -35,7 +36,7 @@ setup(
author_email='hello@ultralytics.com',
packages=find_packages(), # required
include_package_data=True,
- install_requires=REQUIREMENTS,
+ install_requires=REQUIREMENTS + PKG_REQUIREMENTS,
extras_require={
'dev':
['check-manifest', 'pytest', 'pytest-cov', 'coverage', 'mkdocs', 'mkdocstrings[python]', 'mkdocs-material']},
@@ -49,4 +50,5 @@ setup(
"Topic :: Scientific/Engineering :: Image Recognition", "Operating System :: POSIX :: Linux",
"Operating System :: MacOS", "Operating System :: Microsoft :: Windows"],
keywords="machine-learning, deep-learning, vision, ML, DL, AI, YOLO, YOLOv3, YOLOv5, YOLOv8, HUB, Ultralytics",
- entry_points={'console_scripts': ['yolo = ultralytics.yolo.cli:cli', 'ultralytics = ultralytics.yolo.cli:cli']})
+ entry_points={
+ 'console_scripts': ['yolo = ultralytics.yolo.cli:entrypoint', 'ultralytics = ultralytics.yolo.cli:entrypoint']})
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 95daa04..65dc367 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -14,30 +14,32 @@ def run(cmd):
subprocess.run(cmd.split(), check=True)
-def test_checks():
- run('yolo mode=checks')
+def test_special_modes():
+ run('yolo checks')
+ run('yolo settings')
+ run('yolo help')
# Train checks ---------------------------------------------------------------------------------------------------------
def test_train_det():
- run(f'yolo mode=train task=detect model={CFG}.yaml data=coco8.yaml imgsz=32 epochs=1')
+ run(f'yolo train detect model={CFG}.yaml data=coco8.yaml imgsz=32 epochs=1')
def test_train_seg():
- run(f'yolo mode=train task=segment model={CFG}-seg.yaml data=coco8-seg.yaml imgsz=32 epochs=1')
+ run(f'yolo train segment model={CFG}-seg.yaml data=coco8-seg.yaml imgsz=32 epochs=1')
def test_train_cls():
- run(f'yolo mode=train task=classify model={CFG}-cls.yaml data=mnist160 imgsz=32 epochs=1')
+ run(f'yolo train classify model={CFG}-cls.yaml data=mnist160 imgsz=32 epochs=1')
# Val checks -----------------------------------------------------------------------------------------------------------
def test_val_detect():
- run(f'yolo mode=val task=detect model={MODEL}.pt data=coco8.yaml imgsz=32 epochs=1')
+ run(f'yolo val detect model={MODEL}.pt data=coco8.yaml imgsz=32 epochs=1')
def test_val_segment():
- run(f'yolo mode=val task=segment model={MODEL}-seg.pt data=coco8-seg.yaml imgsz=32 epochs=1')
+ run(f'yolo val segment model={MODEL}-seg.pt data=coco8-seg.yaml imgsz=32 epochs=1')
def test_val_classify():
@@ -46,11 +48,11 @@ def test_val_classify():
# Predict checks -------------------------------------------------------------------------------------------------------
def test_predict_detect():
- run(f"yolo mode=predict task=detect model={MODEL}.pt source={ROOT / 'assets'}")
+ run(f"yolo predict detect model={MODEL}.pt source={ROOT / 'assets'}")
def test_predict_segment():
- run(f"yolo mode=predict task=segment model={MODEL}-seg.pt source={ROOT / 'assets'}")
+ run(f"yolo predict segment model={MODEL}-seg.pt source={ROOT / 'assets'}")
def test_predict_classify():
@@ -59,12 +61,12 @@ def test_predict_classify():
# Export checks --------------------------------------------------------------------------------------------------------
def test_export_detect_torchscript():
- run(f'yolo mode=export model={MODEL}.pt format=torchscript')
+ run(f'yolo export model={MODEL}.pt format=torchscript')
def test_export_segment_torchscript():
- run(f'yolo mode=export model={MODEL}-seg.pt format=torchscript')
+ run(f'yolo export model={MODEL}-seg.pt format=torchscript')
def test_export_classify_torchscript():
- run(f'yolo mode=export model={MODEL}-cls.pt format=torchscript')
+ run(f'yolo export model={MODEL}-cls.pt format=torchscript')
diff --git a/ultralytics/yolo/cli.py b/ultralytics/yolo/cli.py
index 5c05d11..395fb44 100644
--- a/ultralytics/yolo/cli.py
+++ b/ultralytics/yolo/cli.py
@@ -1,18 +1,17 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
+import argparse
import shutil
from pathlib import Path
-import hydra
+from hydra import compose, initialize
from ultralytics import hub, yolo
-from ultralytics.yolo.configs import get_config
-from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr
+from ultralytics.yolo.utils import DEFAULT_CONFIG, HELP_MSG, LOGGER, PREFIX, print_settings, yaml_load
DIR = Path(__file__).parent
-@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent.relative_to(DIR)), config_name=DEFAULT_CONFIG.name)
def cli(cfg):
"""
Run a specified task and mode with the given configuration.
@@ -21,21 +20,13 @@ def cli(cfg):
cfg (DictConfig): Configuration for the task and mode.
"""
# LOGGER.info(f"{colorstr(f'Ultralytics YOLO v{ultralytics.__version__}')}")
+ from ultralytics.yolo.configs import get_config
+
if cfg.cfg:
- LOGGER.info(f"Overriding default config with {cfg.cfg}")
+ LOGGER.info(f"{PREFIX}Overriding default config with {cfg.cfg}")
cfg = get_config(cfg.cfg)
task, mode = cfg.task.lower(), cfg.mode.lower()
- # Special case for initializing the configuration
- if task == "init":
- shutil.copy2(DEFAULT_CONFIG, Path.cwd())
- LOGGER.info(f"""
- {colorstr("YOLO:")} configuration saved to {Path.cwd() / DEFAULT_CONFIG.name}.
- To run experiments using custom configuration:
- yolo cfg=config_file.yaml
- """)
- return
-
# Mapping from task to module
task_module_map = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify}
module = task_module_map.get(task)
@@ -47,10 +38,66 @@ def cli(cfg):
"train": module.train,
"val": module.val,
"predict": module.predict,
- "export": yolo.engine.exporter.export,
- "checks": hub.checks}
+ "export": yolo.engine.exporter.export}
func = mode_func_map.get(mode)
if not func:
raise SyntaxError(f"mode not recognized. Choices are {', '.join(mode_func_map.keys())}")
func(cfg)
+
+
+def entrypoint():
+ """
+ This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
+ to the package. It's a combination of argparse and hydra.
+
+ This function allows for:
+ - passing mandatory YOLO args as a list of strings
+ - specifying the task to be performed, either 'detect', 'segment' or 'classify'
+ - specifying the mode, either 'train', 'val', 'test', or 'predict'
+ - running special modes like 'checks'
+ - passing overrides to the package's configuration
+
+ It uses the package's default config and initializes it using the passed overrides.
+ Then it calls the CLI function with the composed config
+ """
+ parser = argparse.ArgumentParser(description='YOLO parser')
+ parser.add_argument('args', type=str, nargs='+', help='YOLO args')
+ args = parser.parse_args().args
+
+ tasks = 'detect', 'segment', 'classify'
+ modes = 'train', 'val', 'predict', 'export'
+ special_modes = {
+ 'checks': hub.checks,
+ 'help': lambda: LOGGER.info(HELP_MSG),
+ 'settings': print_settings,
+ 'copy-config': copy_default_config}
+
+ overrides = [] # basic overrides, i.e. imgsz=320
+ defaults = yaml_load(DEFAULT_CONFIG)
+ for a in args:
+ if '=' in a:
+ overrides.append(a)
+ elif a in tasks:
+ overrides.append(f'task={a}')
+ elif a in modes:
+ overrides.append(f'mode={a}')
+ elif a in special_modes:
+ special_modes[a]()
+ return
+ elif a in defaults and defaults[a] is False:
+ overrides.append(f'{a}=True') # auto-True for default False args, i.e. yolo show
+ else:
+ raise (SyntaxError(f"'{a}' is not a valid yolo argument\n{HELP_MSG}"))
+
+ with initialize(version_base=None, config_path=str(DEFAULT_CONFIG.parent.relative_to(DIR)), job_name="YOLO"):
+ cfg = compose(config_name=DEFAULT_CONFIG.name, overrides=overrides)
+ cli(cfg)
+
+
+# Special modes --------------------------------------------------------------------------------------------------------
+def copy_default_config():
+ new_file = Path.cwd() / DEFAULT_CONFIG.name.replace('.yaml', '_copy.yaml')
+ shutil.copy2(DEFAULT_CONFIG, new_file)
+ LOGGER.info(f"{PREFIX}{DEFAULT_CONFIG} copied to {new_file}\n"
+ f"Usage for running YOLO with this new custom config:\nyolo cfg={new_file} args...")
diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py
index 7096794..af951f8 100644
--- a/ultralytics/yolo/engine/predictor.py
+++ b/ultralytics/yolo/engine/predictor.py
@@ -160,7 +160,7 @@ class BasePredictor:
return list(chain(*list(self.stream_inference(source, model, verbose)))) # merge list of Result into one
def predict_cli(self):
- # Method used for cli prediction. It uses always generator as outputs as not required by cli mode
+ # Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode
gen = self.stream_inference(verbose=True)
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
pass
diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py
index c808332..b0d6308 100644
--- a/ultralytics/yolo/utils/__init__.py
+++ b/ultralytics/yolo/utils/__init__.py
@@ -10,6 +10,7 @@ import tempfile
import threading
import uuid
from pathlib import Path
+from typing import Union
import cv2
import git
@@ -41,12 +42,15 @@ HELP_MSG = \
from ultralytics import YOLO
- model = YOLO('yolov8n.yaml') # build a new model from scratch
- model = YOLO('yolov8n.pt') # load a pretrained model (recommended for best training results)
- results = model.train(data='coco128.yaml') # train the model
- results = model.val() # evaluate model performance on the validation set
- results = model.predict(source='bus.jpg') # predict on an image
- success = model.export(format='onnx') # export the model to ONNX format
+ # Load a model
+ model = YOLO("yolov8n.yaml") # build a new model from scratch
+ model = YOLO("yolov8n.pt") # load a pretrained model (recommended for training)
+
+ # Use the model
+ results = model.train(data="coco128.yaml", epochs=3) # train the model
+ results = model.val() # evaluate model performance on the validation set
+ results = model("https://ultralytics.com/images/bus.jpg") # predict on an image
+ success = model.export(format="onnx") # export the model to ONNX format
3. Use the command line interface (CLI):
@@ -161,12 +165,12 @@ def is_pip_package(filepath: str = __name__) -> bool:
return spec is not None and spec.origin is not None
-def is_dir_writeable(dir_path: str) -> bool:
+def is_dir_writeable(dir_path: Union[str, Path]) -> bool:
"""
Check if a directory is writeable.
Args:
- dir_path (str): The path to the directory.
+ dir_path (str) or (Path): The path to the directory.
Returns:
bool: True if the directory is writeable, False otherwise.
@@ -179,6 +183,18 @@ def is_dir_writeable(dir_path: str) -> bool:
return False
+def is_pytest_running():
+ """
+ Returns a boolean indicating if pytest is currently running or not
+ :return: True if pytest is running, False otherwise
+ """
+ try:
+ import sys
+ return "pytest" in sys.modules
+ except ImportError:
+ return False
+
+
def get_git_root_dir():
"""
Determines whether the current file is part of a git repository and if so, returns the repository root directory.
@@ -348,6 +364,17 @@ def yaml_load(file='data.yaml', append_filename=False):
return {**yaml.safe_load(f), 'yaml_file': str(file)} if append_filename else yaml.safe_load(f)
+def set_sentry(dsn=None):
+ """
+ Initialize the Sentry SDK for error tracking and reporting if pytest is not currently running.
+ """
+ if dsn and not is_pytest_running():
+ import sentry_sdk # noqa
+
+ import ultralytics
+ sentry_sdk.init(dsn=dsn, traces_sample_rate=1.0, release=ultralytics.__version__, debug=False)
+
+
def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'):
"""
Loads a global Ultralytics settings YAML file or creates one with default values if it does not exist.
@@ -364,8 +391,9 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'):
is_git = is_git_directory() # True if ultralytics installed via git
root = get_git_root_dir() if is_git else Path()
+ datasets_root = (root.parent if (is_git and is_dir_writeable(root.parent)) else root).resolve()
defaults = {
- 'datasets_dir': str((root.parent if is_git else root) / 'datasets'), # default datasets directory.
+ 'datasets_dir': str(datasets_root / 'datasets'), # default datasets directory.
'weights_dir': str(root / 'weights'), # default weights directory.
'runs_dir': str(root / 'runs'), # default runs directory.
'sync': True, # sync analytics to help with YOLO development
@@ -393,6 +421,26 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'):
return settings
+def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'):
+ """
+ Function that runs on a first-time ultralytics package installation to set up global settings and create necessary
+ directories.
+ """
+ SETTINGS.update(kwargs)
+ yaml_save(file, SETTINGS)
+
+
+def print_settings():
+ """
+ Function that prints Ultralytics settings
+ """
+ import json
+ s = f'\n{PREFIX}Settings:\n'
+ s += json.dumps(SETTINGS, indent=2)
+ s += f"\n\nUpdate settings at {USER_CONFIG_DIR / 'settings.yaml'}"
+ LOGGER.info(s)
+
+
# Run below code on utils init -----------------------------------------------------------------------------------------
# Set logger
@@ -403,15 +451,7 @@ if platform.system() == 'Windows':
setattr(LOGGER, fn.__name__, lambda x: fn(emojis(x))) # emoji safe logging
# Check first-install steps
+PREFIX = colorstr("Ultralytics: ")
SETTINGS = get_settings()
DATASETS_DIR = Path(SETTINGS['datasets_dir']) # global datasets directory
-
-
-def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'):
- """
- Function that runs on a first-time ultralytics package installation to set up global settings and create necessary
- directories.
- """
- SETTINGS.update(kwargs)
-
- yaml_save(file, SETTINGS)
+set_sentry()
diff --git a/ultralytics/yolo/v8/__init__.py b/ultralytics/yolo/v8/__init__.py
index c6ef3d0..05b65c7 100644
--- a/ultralytics/yolo/v8/__init__.py
+++ b/ultralytics/yolo/v8/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
-from ultralytics.yolo.configs import hydra_patch # noqa (patch hydra cli)
+from ultralytics.yolo.configs import hydra_patch # noqa (patch hydra CLI)
from ultralytics.yolo.v8 import classify, detect, segment
__all__ = ["classify", "segment", "detect"]