From c6985da9de9d17d825a76a6e3a8e9ece5d094824 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Tue, 17 Jan 2023 19:02:34 +0530 Subject: [PATCH] New YOLOv8 `Results()` class for prediction outputs (#314) Signed-off-by: dependabot[bot] Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Viet Nhat Thai <60825385+vietnhatthai@users.noreply.github.com> Co-authored-by: Paula Derrenger <107626595+pderrenger@users.noreply.github.com> --- .github/workflows/ci.yaml | 2 +- .github/workflows/cla.yml | 2 +- MANIFEST.in | 1 + README.md | 10 +- README.zh-CN.md | 2 +- docs/README.md | 2 +- docs/cli.md | 2 +- docs/{index.md => home.md} | 0 docs/predict.md | 72 +++++ docs/python.md | 72 +++-- docs/reference/results.md | 11 + examples/tutorial.ipynb | 10 +- mkdocs.yml | 9 +- setup.py | 13 +- tests/test_cli.py | 29 +- tests/test_engine.py | 12 +- tests/test_python.py | 23 ++ ultralytics/__init__.py | 2 +- ultralytics/nn/autobackend.py | 7 +- ultralytics/nn/{results.py => autoshape.py} | 0 ultralytics/nn/tasks.py | 8 +- .../yolo/data/dataloaders/stream_loaders.py | 62 +++- ultralytics/yolo/engine/model.py | 23 +- ultralytics/yolo/engine/predictor.py | 122 +++++--- ultralytics/yolo/engine/results.py | 284 ++++++++++++++++++ ultralytics/yolo/engine/trainer.py | 6 +- ultralytics/yolo/utils/__init__.py | 13 +- ultralytics/yolo/utils/ops.py | 150 +++++---- ultralytics/yolo/v8/classify/predict.py | 25 +- ultralytics/yolo/v8/classify/train.py | 2 + ultralytics/yolo/v8/detect/predict.py | 38 +-- ultralytics/yolo/v8/segment/predict.py | 64 ++-- 32 files changed, 816 insertions(+), 262 deletions(-) rename docs/{index.md => home.md} (100%) create mode 100644 docs/predict.md create mode 100644 docs/reference/results.md rename ultralytics/nn/{results.py => autoshape.py} (100%) create mode 100644 ultralytics/yolo/engine/results.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c1b726a..788ffa7 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -104,4 +104,4 @@ jobs: yolo mode=export model=runs/classify/train/weights/last.pt imgsz=32 format=torchscript - name: Pytest tests shell: bash # for Windows compatibility - run: pytest tests \ No newline at end of file + run: pytest tests diff --git a/.github/workflows/cla.yml b/.github/workflows/cla.yml index 89183e8..f699260 100644 --- a/.github/workflows/cla.yml +++ b/.github/workflows/cla.yml @@ -18,7 +18,7 @@ jobs: steps: - name: "CLA Assistant" if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I sign the CLA') || github.event_name == 'pull_request_target' - uses: contributor-assistant/github-action@v2.2.0 + uses: contributor-assistant/github-action@v2.2.1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # must be repository secret token diff --git a/MANIFEST.in b/MANIFEST.in index 1635ec1..f37cd80 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,3 +3,4 @@ include requirements.txt include LICENSE include setup.py recursive-include ultralytics *.yaml +recursive-exclude __pycache__ * diff --git a/README.md b/README.md index 136aa80..a669a87 100644 --- a/README.md +++ b/README.md @@ -56,14 +56,14 @@ To request an Enterprise License please complete the form at [Ultralytics Licens
-[Ultralytics Live Session 3](https://youtu.be/IPcpYO5ITa8) ✨ is here! Join us on January 18th at 18 CET as we dive into the latest advancements in YOLOv8, and demonstrate how to use this cutting-edge, SOTA model to improve your object detection, instance segmentation, and image classification projects. See firsthand how YOLOv8's speed, accuracy, and ease of use make it a top choice for professionals and researchers alike. +[Ultralytics Live Session 3](https://youtu.be/IPcpYO5ITa8) ✨ is here! Join us on January 24th at 18 CET as we dive into the latest advancements in YOLOv8, and demonstrate how to use this cutting-edge, SOTA model to improve your object detection, instance segmentation, and image classification projects. See firsthand how YOLOv8's speed, accuracy, and ease of use make it a top choice for professionals and researchers alike. -In addition to learning about the exciting new features and improvements of Ultralytics YOLOv8, you will also have the opportunity to ask questions and interact with our team during the live Q&A session. We encourage all of you to come prepared with any questions you may have. +In addition to learning about the exciting new features and improvements of Ultralytics YOLOv8, you will also have the opportunity to ask questions and interact with our team during the live Q&A session. We encourage you to come prepared with any questions you may have. -Don't miss out on this opportunity! To join the webinar, visit our YouTube [Channel](https://www.youtube.com/@Ultralytics/streams) and turn on your notifications! +To join the webinar, visit our YouTube [Channel](https://www.youtube.com/@Ultralytics/streams) and turn on your notifications! - +
##
Documentation
@@ -76,7 +76,7 @@ documentation on training, validation, prediction and deployment. Pip install the ultralytics package including all [requirements.txt](https://github.com/ultralytics/ultralytics/blob/main/requirements.txt) in a -[**Python>=3.7.0**](https://www.python.org/) environment, including +[**3.10>=Python>=3.7**](https://www.python.org/) environment, including [**PyTorch>=1.7**](https://pytorch.org/get-started/locally/). ```bash diff --git a/README.zh-CN.md b/README.zh-CN.md index aae97b9..f6e34e1 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -53,7 +53,7 @@
安装 -Pip 安装包含所有 [requirements.txt](https://github.com/ultralytics/ultralytics/blob/main/requirements.txt) 的 ultralytics 包,环境要求 [**Python>=3.7.0**](https://www.python.org/),且 [**PyTorch>=1.7**](https://pytorch.org/get-started/locally/)。 +Pip 安装包含所有 [requirements.txt](https://github.com/ultralytics/ultralytics/blob/main/requirements.txt) 的 ultralytics 包,环境要求 [**3.10>=Python>=3.7**](https://www.python.org/),且 [**PyTorch>=1.7**](https://pytorch.org/get-started/locally/)。 ```bash pip install ultralytics diff --git a/docs/README.md b/docs/README.md index 8c8981b..3b3e306 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,6 +1,6 @@ # Ultralytics Docs -Deployed to https://docs.ultralytics.com +Ultralytics Docs are deployed to [https://docs.ultralytics.com](https://docs.ultralytics.com). ### Install Ultralytics package diff --git a/docs/cli.md b/docs/cli.md index 520efa2..63e53f4 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -83,7 +83,7 @@ yolo cfg=default.yaml yolo task=init yolo cfg=default.yaml ``` - === "Result" + === "Results" TODO: add terminal output diff --git a/docs/index.md b/docs/home.md similarity index 100% rename from docs/index.md rename to docs/home.md diff --git a/docs/predict.md b/docs/predict.md new file mode 100644 index 0000000..ba47f76 --- /dev/null +++ b/docs/predict.md @@ -0,0 +1,72 @@ +Inference or prediction of a task returns a list of `Results` objects. Alternatively, in the streaming mode, it returns a generator of `Results` objects which is memory efficient. Streaming mode can be enabled by passing `stream=True` in predictor's call method. + +!!! example "Predict" + === "Getting a List" + ```python + inputs = [img, img] # list of np arrays + results = model(inputs) # List of Results objects + for result in results: + boxes = results.boxes # Boxes object for bbox outputs + masks = results.masks # Masks object for segmenation masks outputs + probs = results.probs # Class probabilities for classification outputs + ... + ``` + === "Getting a Generator" + ```python + inputs = [img, img] # list of np arrays + results = model(inputs, stream="True") # Generator of Results objects + for result in results: + boxes = results.boxes # Boxes object for bbox outputs + masks = results.masks # Masks object for segmenation masks outputs + probs = results.probs # Class probabilities for classification outputs + ... + ``` + +## Working with Results + +Results object consists of these component objects: + +- `results.boxes` : It is an object of class `Boxes`. It has properties and methods for manipulating bboxes +- `results.masks` : It is an object of class `Masks`. It can be used to index masks or to get segment coordinates. +- `results.prob` : It is a `Tensor` object. It contains the class probabilities/logits. + +Each result is composed of torch.Tensor by default, in which you can easily use following functionality: +```python +results = results.cuda() +results = results.cpu() +results = results.to("cpu") +results = results.numpy() +``` +### Boxes +`Boxes` object can be used index, manipulate and convert bboxes to different formats. The box format conversion operations are cached, which means they're only calculated once per object and those values are reused for future calls. + +- Indexing a `Boxes` objects returns a `Boxes` object +```python +boxes = results.boxes +box = boxes[0] # returns one box +box.xyxy +``` +- Properties and conversions +``` +results.boxes.xyxy # box with xyxy format, (N, 4) +results.boxes.xywh # box with xywh format, (N, 4) +results.boxes.xyxyn # box with xyxy format but normalized, (N, 4) +results.boxes.xywhn # box with xywh format but normalized, (N, 4) +results.boxes.conf # confidence score, (N, 1) +results.boxes.cls # cls, (N, 1) +``` +### Masks +`Masks` object can be used index, manipulate and convert masks to segments. The segment conversion operation is cached. + +```python +results.masks.masks # masks, (N, H, W) +results.masks.segments # bounding coordinates of masks, List[segment] * N +``` + +### probs +`probs` attribute of `Results` class is a `Tensor` containing class probabilities of a classification operation. +```python +results.probs # cls prob, (num_class, ) +``` + +Class reference documentation for `Results` module and its components can be found [here](reference/results.md) \ No newline at end of file diff --git a/docs/python.md b/docs/python.md index 9722125..22a88d7 100644 --- a/docs/python.md +++ b/docs/python.md @@ -1,5 +1,4 @@ -This is the simplest way of simply using YOLOv8 models in a Python environment. It can be imported from -the `ultralytics` module. +The simplest way of simply using YOLOv8 directly in a Python environment. !!! example "Train" @@ -51,35 +50,60 @@ the `ultralytics` module. === "From source" ```python from ultralytics import YOLO + from PIL import Image + import cv2 model = YOLO("model.pt") - model.predict(source="0") # accepts all formats - img/folder/vid.*(mp4/format). 0 for webcam - model.predict(source="folder", show=True) # Display preds. Accepts all yolo predict arguments + # accepts all formats - image/dir/Path/URL/video/PIL/ndarray. 0 for webcam + results = model.predict(source="0") + results = model.predict(source="folder", show=True) # Display preds. Accepts all YOLO predict arguments - ``` + # from PIL + im1 = Image.open("bus.jpg") + results = model.predict(source=im1, save=True) # save plotted images - === "From image/ndarray/tensor" - ```python - # TODO, still working on it. - ``` + # from ndarray + im2 = cv2.imread("bus.jpg") + results = model.predict(source=im2, save=True, save_txt=True) # save predictions as labels + # from list of PIL/ndarray + results = model.predict(source=[im1, im2]) + ``` - === "Return outputs" + === "Results usage" ```python - from ultralytics import YOLO - - model = YOLO("model.pt") - outputs = model.predict(source="0", return_outputs=True) # treat predict as a Python generator - for output in outputs: - # each output here is a dict. - # for detection - print(output["det"]) # np.ndarray, (N, 6), xyxy, score, cls - # for segmentation - print(output["det"]) # np.ndarray, (N, 6), xyxy, score, cls - print(output["segment"]) # List[np.ndarray] * N, bounding coordinates of masks - # for classify - print(output["prob"]) # np.ndarray, (num_class, ), cls prob - + # results would be a list of Results object including all the predictions by default + # but be careful as it could occupy a lot memory when there're many images, + # especially the task is segmentation. + # 1. return as a list + results = model.predict(source="folder") + + # results would be a generator which is more friendly to memory by setting stream=True + # 2. return as a generator + results = model.predict(source=0, stream=True) + + for result in results: + # detection + result.boxes.xyxy # box with xyxy format, (N, 4) + result.boxes.xywh # box with xywh format, (N, 4) + result.boxes.xyxyn # box with xyxy format but normalized, (N, 4) + result.boxes.xywhn # box with xywh format but normalized, (N, 4) + result.boxes.conf # confidence score, (N, 1) + result.boxes.cls # cls, (N, 1) + + # segmentation + result.masks.masks # masks, (N, H, W) + result.masks.segments # bounding coordinates of masks, List[segment] * N + + # classification + result.probs # cls prob, (num_class, ) + + # Each result is composed of torch.Tensor by default, + # in which you can easily use following functionality: + result = result.cuda() + result = result.cpu() + result = result.to("cpu") + result = result.numpy() ``` !!! note "Export and Deployment" diff --git a/docs/reference/results.md b/docs/reference/results.md new file mode 100644 index 0000000..e222ec4 --- /dev/null +++ b/docs/reference/results.md @@ -0,0 +1,11 @@ +### Results API Reference + +:::ultralytics.yolo.engine.results.Results + +### Boxes API Reference + +:::ultralytics.yolo.engine.results.Boxes + +### Masks API Reference + +:::ultralytics.yolo.engine.results.Masks diff --git a/examples/tutorial.ipynb b/examples/tutorial.ipynb index 13dcf05..fd3e4ba 100644 --- a/examples/tutorial.ipynb +++ b/examples/tutorial.ipynb @@ -68,7 +68,7 @@ "import ultralytics\n", "ultralytics.checks()" ], - "execution_count": 1, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -145,7 +145,7 @@ }, "source": [ "        \n", - "" + "" ] }, { @@ -185,7 +185,7 @@ "# Validate YOLOv8n on COCO128 val\n", "!yolo task=detect mode=val model=yolov8n.pt data=coco128.yaml" ], - "execution_count": 2, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -310,7 +310,7 @@ "# Train YOLOv8n on COCO128 for 3 epochs\n", "!yolo task=detect mode=train model=yolov8n.pt data=coco128.yaml epochs=3 imgsz=640" ], - "execution_count": 3, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -501,7 +501,7 @@ "id": "CYIjW4igCjqD", "outputId": "3bb45917-f90e-4951-959d-7bcd26680f2e" }, - "execution_count": 4, + "execution_count": null, "outputs": [ { "output_type": "stream", diff --git a/mkdocs.yml b/mkdocs.yml index 6add0f1..1eb2a17 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -75,7 +75,7 @@ plugins: # Primary navigation nav: - - Home: index.md + - Home: home.md - Quickstart: quickstart.md - Tasks: - Detection: tasks/detection.md @@ -84,6 +84,7 @@ nav: - Usage: - CLI: cli.md - Python: python.md + - Predict: predict.md - Configuration: config.md - Customization Guide: engine.md - Ultralytics HUB: hub.md @@ -95,5 +96,7 @@ nav: - Validator: reference/base_val.md - Predictor: reference/base_pred.md - Exporter: reference/exporter.md - - nn Module: reference/nn.md - - operations: reference/ops.md + - Results: reference/results.md + - ultralytics.nn: reference/nn.md + - Operations: reference/ops.md + - Docs: README.md diff --git a/setup.py b/setup.py index 05f25ad..590159c 100644 --- a/setup.py +++ b/setup.py @@ -15,22 +15,22 @@ REQUIREMENTS = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements((PARENT def get_version(): file = PARENT / 'ultralytics/__init__.py' - return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', file.read_text(), re.M)[1] + return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', file.read_text(encoding="utf-8"), re.M)[1] setup( name="ultralytics", # name of pypi package version=get_version(), # version of pypi package - python_requires=">=3.7.0", + python_requires=">=3.7,<=3.11", license='GPL-3.0', - description='Ultralytics YOLOv8 and HUB', + description='Ultralytics YOLOv8', long_description=README, long_description_content_type="text/markdown", url="https://github.com/ultralytics/ultralytics", project_urls={ 'Bug Reports': 'https://github.com/ultralytics/ultralytics/issues', 'Funding': 'https://ultralytics.com', - 'Source': 'https://github.com/ultralytics/ultralytics',}, + 'Source': 'https://github.com/ultralytics/ultralytics'}, author="Ultralytics", author_email='hello@ultralytics.com', packages=find_packages(), # required @@ -38,7 +38,7 @@ setup( install_requires=REQUIREMENTS, extras_require={ 'dev': - ['check-manifest', 'pytest', 'pytest-cov', 'coverage', 'mkdocs', 'mkdocstrings[python]', 'mkdocs-material'],}, + ['check-manifest', 'pytest', 'pytest-cov', 'coverage', 'mkdocs', 'mkdocstrings[python]', 'mkdocs-material']}, classifiers=[ "Intended Audience :: Developers", "Intended Audience :: Science/Research", "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", "Programming Language :: Python :: 3", @@ -49,5 +49,4 @@ setup( "Topic :: Scientific/Engineering :: Image Recognition", "Operating System :: POSIX :: Linux", "Operating System :: MacOS", "Operating System :: Microsoft :: Windows"], keywords="machine-learning, deep-learning, vision, ML, DL, AI, YOLO, YOLOv3, YOLOv5, YOLOv8, HUB, Ultralytics", - entry_points={ - 'console_scripts': ['yolo = ultralytics.yolo.cli:cli', 'ultralytics = ultralytics.yolo.cli:cli'],}) + entry_points={'console_scripts': ['yolo = ultralytics.yolo.cli:cli', 'ultralytics = ultralytics.yolo.cli:cli']}) diff --git a/tests/test_cli.py b/tests/test_cli.py index 8259540..95daa04 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license -import os +import subprocess from pathlib import Path from ultralytics.yolo.utils import ROOT, SETTINGS @@ -9,30 +9,35 @@ MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n' CFG = 'yolov8n' +def run(cmd): + # Run a subprocess command with check=True + subprocess.run(cmd.split(), check=True) + + def test_checks(): - os.system('yolo mode=checks') + run('yolo mode=checks') # Train checks --------------------------------------------------------------------------------------------------------- def test_train_det(): - os.system(f'yolo mode=train task=detect model={CFG}.yaml data=coco8.yaml imgsz=32 epochs=1') + run(f'yolo mode=train task=detect model={CFG}.yaml data=coco8.yaml imgsz=32 epochs=1') def test_train_seg(): - os.system(f'yolo mode=train task=segment model={CFG}-seg.yaml data=coco8-seg.yaml imgsz=32 epochs=1') + run(f'yolo mode=train task=segment model={CFG}-seg.yaml data=coco8-seg.yaml imgsz=32 epochs=1') def test_train_cls(): - os.system(f'yolo mode=train task=classify model={CFG}-cls.yaml data=mnist160 imgsz=32 epochs=1') + run(f'yolo mode=train task=classify model={CFG}-cls.yaml data=mnist160 imgsz=32 epochs=1') # Val checks ----------------------------------------------------------------------------------------------------------- def test_val_detect(): - os.system(f'yolo mode=val task=detect model={MODEL}.pt data=coco8.yaml imgsz=32 epochs=1') + run(f'yolo mode=val task=detect model={MODEL}.pt data=coco8.yaml imgsz=32 epochs=1') def test_val_segment(): - os.system(f'yolo mode=val task=segment model={MODEL}-seg.pt data=coco8-seg.yaml imgsz=32 epochs=1') + run(f'yolo mode=val task=segment model={MODEL}-seg.pt data=coco8-seg.yaml imgsz=32 epochs=1') def test_val_classify(): @@ -41,11 +46,11 @@ def test_val_classify(): # Predict checks ------------------------------------------------------------------------------------------------------- def test_predict_detect(): - os.system(f"yolo mode=predict model={MODEL}.pt source={ROOT / 'assets'}") + run(f"yolo mode=predict task=detect model={MODEL}.pt source={ROOT / 'assets'}") def test_predict_segment(): - os.system(f"yolo mode=predict model={MODEL}-seg.pt source={ROOT / 'assets'}") + run(f"yolo mode=predict task=segment model={MODEL}-seg.pt source={ROOT / 'assets'}") def test_predict_classify(): @@ -54,12 +59,12 @@ def test_predict_classify(): # Export checks -------------------------------------------------------------------------------------------------------- def test_export_detect_torchscript(): - os.system(f'yolo mode=export model={MODEL}.pt format=torchscript') + run(f'yolo mode=export model={MODEL}.pt format=torchscript') def test_export_segment_torchscript(): - os.system(f'yolo mode=export model={MODEL}-seg.pt format=torchscript') + run(f'yolo mode=export model={MODEL}-seg.pt format=torchscript') def test_export_classify_torchscript(): - pass + run(f'yolo mode=export model={MODEL}-cls.pt format=torchscript') diff --git a/tests/test_engine.py b/tests/test_engine.py index e33811f..2755655 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -28,8 +28,8 @@ def test_detect(): # Predictor pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]}) - result = pred(source=SOURCE, model=f"{MODEL}.pt", return_outputs=True) - assert len(list(result)), "predictor test failed" + result = pred(source=SOURCE, model=f"{MODEL}.pt") + assert len(result), "predictor test failed" overrides["resume"] = trainer.last trainer = detect.DetectionTrainer(overrides=overrides) @@ -58,8 +58,8 @@ def test_segment(): # Predictor pred = segment.SegmentationPredictor(overrides={"imgsz": [64, 64]}) - result = pred(source=SOURCE, model=f"{MODEL}-seg.pt", return_outputs=True) - assert len(list(result)) == 2, "predictor test failed" + result = pred(source=SOURCE, model=f"{MODEL}-seg.pt") + assert len(result) == 2, "predictor test failed" # Test resume overrides["resume"] = trainer.last @@ -90,5 +90,5 @@ def test_classify(): # Predictor pred = classify.ClassificationPredictor(overrides={"imgsz": [64, 64]}) - result = pred(source=SOURCE, model=trainer.best, return_outputs=True) - assert len(list(result)) == 2, "predictor test failed" + result = pred(source=SOURCE, model=trainer.best) + assert len(result) == 2, "predictor test failed" diff --git a/tests/test_python.py b/tests/test_python.py index 7faa3e2..d553068 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -2,6 +2,10 @@ from pathlib import Path +import cv2 +import torch +from PIL import Image + from ultralytics import YOLO from ultralytics.yolo.utils import ROOT, SETTINGS @@ -35,6 +39,21 @@ def test_predict_dir(): model.predict(source=ROOT / "assets") +def test_predict_img(): + model = YOLO(MODEL) + img = Image.open(str(SOURCE)) + output = model(source=img, save=True, verbose=True) # PIL + assert len(output) == 1, "predict test failed" + img = cv2.imread(str(SOURCE)) + output = model(source=img, save=True, save_txt=True) # ndarray + assert len(output) == 1, "predict test failed" + output = model(source=[img, img], save=True, save_txt=True) # batch + assert len(output) == 2, "predict test failed" + tens = torch.zeros(320, 640, 3) + output = model(tens.numpy()) + assert len(output) == 1, "predict test failed" + + def test_val(): model = YOLO(MODEL) model.val(data="coco8.yaml", imgsz=32) @@ -106,3 +125,7 @@ def test_workflow(): model.val() model.predict(SOURCE) model.export(format="onnx", opset=12) # export a model to ONNX format + + +if __name__ == "__main__": + test_predict_img() diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index df2e7f5..6a1aeec 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license -__version__ = "8.0.6" +__version__ = "8.0.7" from ultralytics.hub import checks from ultralytics.yolo.engine.model import YOLO diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py index 6f68da2..f2f1b10 100644 --- a/ultralytics/nn/autobackend.py +++ b/ultralytics/nn/autobackend.py @@ -65,6 +65,7 @@ class AutoBackend(nn.Module): model = weights.to(device) model = model.fuse() if fuse else model names = model.module.names if hasattr(model, 'module') else model.names # get class names + stride = max(int(model.stride.max()), 32) # model stride model.half() if fp16 else model.float() self.model = model # explicitly assign for to(), cpu(), cuda(), half() pt = True @@ -236,7 +237,7 @@ class AutoBackend(nn.Module): Runs inference on the YOLOv8 MultiBackend model. Args: - im (torch.tensor): The image tensor to perform inference on. + im (torch.Tensor): The image tensor to perform inference on. augment (bool): whether to perform data augmentation during inference, defaults to False visualize (bool): whether to visualize the output predictions, defaults to False @@ -328,10 +329,10 @@ class AutoBackend(nn.Module): Convert a numpy array to a tensor. Args: - x (numpy.ndarray): The array to be converted. + x (np.ndarray): The array to be converted. Returns: - (torch.tensor): The converted tensor + (torch.Tensor): The converted tensor """ return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x diff --git a/ultralytics/nn/results.py b/ultralytics/nn/autoshape.py similarity index 100% rename from ultralytics/nn/results.py rename to ultralytics/nn/autoshape.py diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index c3ec574..6d16351 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -27,12 +27,12 @@ class BaseModel(nn.Module): Wrapper for `_forward_once` method. Args: - x (torch.tensor): The input image tensor + x (torch.Tensor): The input image tensor profile (bool): Whether to profile the model, defaults to False visualize (bool): Whether to return the intermediate feature maps, defaults to False Returns: - (torch.tensor): The output of the network. + (torch.Tensor): The output of the network. """ return self._forward_once(x, profile, visualize) @@ -41,12 +41,12 @@ class BaseModel(nn.Module): Perform a forward pass through the network. Args: - x (torch.tensor): The input tensor to the model + x (torch.Tensor): The input tensor to the model profile (bool): Print the computation time of each layer if True, defaults to False. visualize (bool): Save the feature maps of the model if True, defaults to False Returns: - (torch.tensor): The last output of the model. + (torch.Tensor): The last output of the model. """ y, dt = [], [] # outputs for m in self.model: diff --git a/ultralytics/yolo/data/dataloaders/stream_loaders.py b/ultralytics/yolo/data/dataloaders/stream_loaders.py index 5d2373c..1637322 100644 --- a/ultralytics/yolo/data/dataloaders/stream_loaders.py +++ b/ultralytics/yolo/data/dataloaders/stream_loaders.py @@ -11,10 +11,11 @@ from urllib.parse import urlparse import cv2 import numpy as np import torch +from PIL import Image from ultralytics.yolo.data.augment import LetterBox from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS -from ultralytics.yolo.utils import LOGGER, is_colab, is_kaggle, ops +from ultralytics.yolo.utils import LOGGER, ROOT, is_colab, is_kaggle, ops from ultralytics.yolo.utils.checks import check_requirements @@ -36,7 +37,7 @@ class LoadStreams: if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc' check_requirements(('pafy', 'youtube_dl==2020.12.2')) - import pafy + import pafy # noqa s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam if s == 0: @@ -109,7 +110,7 @@ class LoadScreenshots: def __init__(self, source, imgsz=640, stride=32, auto=True, transforms=None): # source = [screen_number left top width height] (pixels) check_requirements('mss') - import mss + import mss # noqa source, *params = source.split() self.screen, left, top, width, height = 0, None, None, None, None # default to full screen 0 @@ -254,3 +255,58 @@ class LoadImages: def __len__(self): return self.nf # number of files + + +class LoadPilAndNumpy: + + def __init__(self, im0, imgsz=640, stride=32, auto=True, transforms=None): + if not isinstance(im0, list): + im0 = [im0] + self.im0 = [self._single_check(im) for im in im0] + self.imgsz = imgsz + self.stride = stride + self.auto = auto + self.transforms = transforms + self.mode = 'image' + # generate fake paths + self.paths = [f"image{i}.jpg" for i in range(len(self.im0))] + + @staticmethod + def _single_check(im): + assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}" + if isinstance(im, Image.Image): + im = np.asarray(im)[:, :, ::-1] + im = np.ascontiguousarray(im) # contiguous + return im + + def _single_preprocess(self, im, auto): + if self.transforms: + im = self.transforms(im) # transforms + else: + im = LetterBox(self.imgsz, auto=auto, stride=self.stride)(image=im) + im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB + im = np.ascontiguousarray(im) # contiguous + return im + + def __len__(self): + return len(self.im0) + + def __next__(self): + if self.count == 1: # loop only once as it's batch inference + raise StopIteration + auto = all(x.shape == self.im0[0].shape for x in self.im0) and self.auto + im = [self._single_preprocess(im, auto) for im in self.im0] + im = np.stack(im, 0) if len(im) > 1 else im[0][None] + self.count += 1 + return self.paths, im, self.im0, None, '' + + def __iter__(self): + self.count = 0 + return self + + +if __name__ == "__main__": + img = cv2.imread(str(ROOT / "assets/bus.jpg")) + dataset = LoadPilAndNumpy(im0=img) + for d in dataset: + print(d[0]) diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index 18a5313..c7eff56 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -54,8 +54,8 @@ class YOLO: # Load or create new YOLO model {'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model) - def __call__(self, source, **kwargs): - return self.predict(source, **kwargs) + def __call__(self, source=None, stream=False, verbose=False, **kwargs): + return self.predict(source, stream, verbose, **kwargs) def _new(self, cfg: str, verbose=True): """ @@ -111,13 +111,20 @@ class YOLO: self.model.fuse() @smart_inference_mode() - def predict(self, source, return_outputs=False, **kwargs): + def predict(self, source=None, stream=False, verbose=False, **kwargs): """ - Visualize prediction. + Perform prediction using the YOLO model. Args: - source (str): Accepts all source types accepted by yolo - **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs + source (str | int | PIL | np.ndarray): The source of the image to make predictions on. + Accepts all source types accepted by the YOLO model. + stream (bool): Whether to stream the predictions or not. Defaults to False. + verbose (bool): Whether to print verbose information or not. Defaults to False. + **kwargs : Additional keyword arguments passed to the predictor. + Check the 'configuration' section in the documentation for all available options. + + Returns: + (dict): The prediction results. """ overrides = self.overrides.copy() overrides["conf"] = 0.25 @@ -127,8 +134,8 @@ class YOLO: predictor = self.PredictorClass(overrides=overrides) predictor.args.imgsz = check_imgsz(predictor.args.imgsz, min_dim=2) # check image size - predictor.setup(model=self.model, source=source, return_outputs=return_outputs) - return predictor() if return_outputs else predictor.predict_cli() + predictor.setup(model=self.model, source=source) + return predictor(stream=stream, verbose=verbose) @smart_inference_mode() def val(self, data=None, **kwargs): diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index 1ad0021..546091b 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -27,13 +27,14 @@ Usage - formats: """ import platform from collections import defaultdict +from itertools import chain from pathlib import Path import cv2 from ultralytics.nn.autobackend import AutoBackend from ultralytics.yolo.configs import get_config -from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams +from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, SETTINGS, callbacks, colorstr, ops from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow @@ -89,7 +90,6 @@ class BasePredictor: self.vid_path, self.vid_writer = None, None self.annotator = None self.data_path = None - self.output = {} self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks callbacks.add_integration_callbacks(self) @@ -99,29 +99,18 @@ class BasePredictor: def get_annotator(self, img): raise NotImplementedError("get_annotator function needs to be implemented") - def write_results(self, pred, batch, print_string): + def write_results(self, results, batch, print_string): raise NotImplementedError("print_results function needs to be implemented") def postprocess(self, preds, img, orig_img): return preds - def setup(self, source=None, model=None, return_outputs=False): + def setup(self, source=None, model=None): # source - source = str(source if source is not None else self.args.source) - is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) - is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://')) - webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file) - screenshot = source.lower().startswith('screen') - if is_url and is_file: - source = check_file(source) # download - + source, webcam, screenshot, from_img = self.check_source(source) # model - device = select_device(self.args.device) - model = model or self.args.model - self.args.half &= device.type != 'cpu' # half precision only supported on CUDA - model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half) - stride, pt = model.stride, model.pt - imgsz = check_imgsz(self.args.imgsz, stride=stride) # check image size + stride, pt = self.setup_model(model) + imgsz = check_imgsz(self.args.imgsz, stride=stride, min_dim=2) # check image size # Dataloader bs = 1 # batch_size @@ -131,7 +120,7 @@ class BasePredictor: imgsz=imgsz, stride=stride, auto=pt, - transforms=getattr(model.model, 'transforms', None), + transforms=getattr(self.model.model, 'transforms', None), vid_stride=self.args.vid_stride) bs = len(self.dataset) elif screenshot: @@ -139,32 +128,47 @@ class BasePredictor: imgsz=imgsz, stride=stride, auto=pt, - transforms=getattr(model.model, 'transforms', None)) + transforms=getattr(self.model.model, 'transforms', None)) + elif from_img: + self.dataset = LoadPilAndNumpy(source, + imgsz=imgsz, + stride=stride, + auto=pt, + transforms=getattr(self.model.model, 'transforms', None)) else: self.dataset = LoadImages(source, imgsz=imgsz, stride=stride, auto=pt, - transforms=getattr(model.model, 'transforms', None), + transforms=getattr(self.model.model, 'transforms', None), vid_stride=self.args.vid_stride) self.vid_path, self.vid_writer = [None] * bs, [None] * bs - model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup + self.model.warmup(imgsz=(1 if pt or self.model.triton else bs, 3, *imgsz)) # warmup - self.model = model self.webcam = webcam self.screenshot = screenshot + self.from_img = from_img self.imgsz = imgsz self.done_setup = True - self.device = device - self.return_outputs = return_outputs - return model @smart_inference_mode() - def __call__(self, source=None, model=None, return_outputs=False): + def __call__(self, source=None, model=None, verbose=False, stream=False): + if stream: + return self.stream_inference(source, model, verbose) + else: + return list(chain(*list(self.stream_inference(source, model, verbose)))) # merge list of Result into one + + def predict_cli(self): + # Method used for cli prediction. It uses always generator as outputs as not required by cli mode + gen = self.stream_inference(verbose=True) + for _ in gen: # running CLI inference without accumulating any outputs (do not modify) + pass + + def stream_inference(self, source=None, model=None, verbose=False): self.run_callbacks("on_predict_start") - model = self.model if self.done_setup else self.setup(source, model, return_outputs) - model.eval() + if not self.done_setup: + self.setup(source, model) self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile()) for batch in self.dataset: self.run_callbacks("on_predict_batch_start") @@ -177,17 +181,17 @@ class BasePredictor: # Inference with self.dt[1]: - preds = model(im, augment=self.args.augment, visualize=visualize) + preds = self.model(im, augment=self.args.augment, visualize=visualize) # postprocess with self.dt[2]: - preds = self.postprocess(preds, im, im0s) - + results = self.postprocess(preds, im, im0s) for i in range(len(im)): - if self.webcam: - path, im0s = path[i], im0s[i] - p = Path(path) - s += self.write_results(i, preds, (p, im, im0s)) + p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s) + p = Path(p) + + if verbose or self.args.save or self.args.save_txt: + s += self.write_results(i, results, (p, im, im0)) if self.args.show: self.show(p) @@ -195,30 +199,50 @@ class BasePredictor: if self.args.save: self.save_preds(vid_cap, i, str(self.save_dir / p.name)) - if self.return_outputs: - yield self.output - self.output.clear() + yield results # Print time (inference-only) - LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms") + if verbose: + LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms") self.run_callbacks("on_predict_batch_end") # Print results - t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image - LOGGER.info( - f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape {(1, 3, *self.imgsz)}' - % t) + if verbose: + t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image + LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape ' + f'{(1, 3, *self.imgsz)}' % t) if self.args.save_txt or self.args.save: - s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" if self.args.save_txt else '' + s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" \ + if self.args.save_txt else '' LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}") self.run_callbacks("on_predict_end") - def predict_cli(self, source=None, model=None, return_outputs=False): - # as __call__ is a generator now so have to treat it like a generator - for _ in (self.__call__(source, model, return_outputs)): - pass + def setup_model(self, model): + device = select_device(self.args.device) + model = model or self.args.model + self.args.half &= device.type != 'cpu' # half precision only supported on CUDA + model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half) + self.model = model + self.device = device + self.model.eval() + return model.stride, model.pt + + def check_source(self, source): + source = source if source is not None else self.args.source + webcam, screenshot, from_img = False, False, False + if isinstance(source, (str, int, Path)): # int for local usb carame + source = str(source) + is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) + is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://')) + webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file) + screenshot = source.lower().startswith('screen') + if is_url and is_file: + source = check_file(source) # download + else: + from_img = True + return source, webcam, screenshot, from_img def show(self, p): im0 = self.annotator.result() diff --git a/ultralytics/yolo/engine/results.py b/ultralytics/yolo/engine/results.py new file mode 100644 index 0000000..cf20e8d --- /dev/null +++ b/ultralytics/yolo/engine/results.py @@ -0,0 +1,284 @@ +from functools import lru_cache + +import numpy as np +import torch + +from ultralytics.yolo.utils import LOGGER, ops + + +class Results: + """ + A class for storing and manipulating inference results. + + Args: + boxes (Boxes, optional): A Boxes object containing the detection bounding boxes. + masks (Masks, optional): A Masks object containing the detection masks. + probs (torch.Tensor, optional): A tensor containing the detection class probabilities. + orig_shape (tuple, optional): Original image size. + + Attributes: + boxes (Boxes, optional): A Boxes object containing the detection bounding boxes. + masks (Masks, optional): A Masks object containing the detection masks. + probs (torch.Tensor, optional): A tensor containing the detection class probabilities. + orig_shape (tuple, optional): Original image size. + """ + + def __init__(self, boxes=None, masks=None, probs=None, orig_shape=None) -> None: + self.boxes = Boxes(boxes, orig_shape) if boxes is not None else None # native size boxes + self.masks = Masks(masks, orig_shape) if masks is not None else None # native size or imgsz masks + self.probs = probs.softmax(0) if probs is not None else None + self.orig_shape = orig_shape + self.comp = ["boxes", "masks", "probs"] + + def pandas(self): + pass + # TODO masks.pandas + boxes.pandas + cls.pandas + + def __getitem__(self, idx): + r = Results(orig_shape=self.orig_shape) + for item in self.comp: + if getattr(self, item) is None: + continue + setattr(r, item, getattr(self, item)[idx]) + return r + + def cpu(self): + r = Results(orig_shape=self.orig_shape) + for item in self.comp: + if getattr(self, item) is None: + continue + setattr(r, item, getattr(self, item).cpu()) + return r + + def numpy(self): + r = Results(orig_shape=self.orig_shape) + for item in self.comp: + if getattr(self, item) is None: + continue + setattr(r, item, getattr(self, item).numpy()) + return r + + def cuda(self): + r = Results(orig_shape=self.orig_shape) + for item in self.comp: + if getattr(self, item) is None: + continue + setattr(r, item, getattr(self, item).cuda()) + return r + + def to(self, *args, **kwargs): + r = Results(orig_shape=self.orig_shape) + for item in self.comp: + if getattr(self, item) is None: + continue + setattr(r, item, getattr(self, item).to(*args, **kwargs)) + return r + + def __len__(self): + for item in self.comp: + if getattr(self, item) is None: + continue + return len(getattr(self, item)) + + def __str__(self): + return self.__repr__() + + def __repr__(self): + s = f'Ultralytics YOLO {self.__class__} instance\n' # string + if self.boxes: + s = s + self.boxes.__repr__() + '\n' + if self.masks: + s = s + self.masks.__repr__() + '\n' + if self.probs: + s = s + self.probs.__repr__() + s += f'original size: {self.orig_shape}\n' + + return s + + +class Boxes: + """ + A class for storing and manipulating detection boxes. + + Args: + boxes (torch.Tensor) or (numpy.ndarray): A tensor or numpy array containing the detection boxes, + with shape (num_boxes, 6). The last two columns should contain confidence and class values. + orig_shape (tuple): Original image size, in the format (height, width). + + Attributes: + boxes (torch.Tensor) or (numpy.ndarray): A tensor or numpy array containing the detection boxes, + with shape (num_boxes, 6). + orig_shape (torch.Tensor) or (numpy.ndarray): Original image size, in the format (height, width). + + Properties: + xyxy (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format. + conf (torch.Tensor) or (numpy.ndarray): The confidence values of the boxes. + cls (torch.Tensor) or (numpy.ndarray): The class values of the boxes. + xywh (torch.Tensor) or (numpy.ndarray): The boxes in xywh format. + xyxyn (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format normalized by original image size. + xywhn (torch.Tensor) or (numpy.ndarray): The boxes in xywh format normalized by original image size. + """ + + def __init__(self, boxes, orig_shape) -> None: + if boxes.ndim == 1: + boxes = boxes[None, :] + assert boxes.shape[-1] == 6 # xyxy, conf, cls + self.boxes = boxes + self.orig_shape = torch.as_tensor(orig_shape, device=boxes.device) if isinstance(boxes, torch.Tensor) \ + else np.asarray(orig_shape) + + @property + def xyxy(self): + return self.boxes[:, :4] + + @property + def conf(self): + return self.boxes[:, -2] + + @property + def cls(self): + return self.boxes[:, -1] + + @property + @lru_cache(maxsize=2) # maxsize 1 should suffice + def xywh(self): + return ops.xyxy2xywh(self.xyxy) + + @property + @lru_cache(maxsize=2) + def xyxyn(self): + return self.xyxy / self.orig_shape[[1, 0, 1, 0]] + + @property + @lru_cache(maxsize=2) + def xywhn(self): + return self.xywh / self.orig_shape[[1, 0, 1, 0]] + + def cpu(self): + boxes = self.boxes.cpu() + return Boxes(boxes, self.orig_shape) + + def numpy(self): + boxes = self.boxes.numpy() + return Boxes(boxes, self.orig_shape) + + def cuda(self): + boxes = self.boxes.cuda() + return Boxes(boxes, self.orig_shape) + + def to(self, *args, **kwargs): + boxes = self.boxes.to(*args, **kwargs) + return Boxes(boxes, self.orig_shape) + + def pandas(self): + LOGGER.info('results.pandas() method not yet implemented') + ''' + new = copy(self) # return copy + ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns + cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns + for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]): + a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update + setattr(new, k, [pd.DataFrame(x, columns=c) for x in a]) + return new + ''' + + @property + def shape(self): + return self.boxes.shape + + def __len__(self): # override len(results) + return len(self.boxes) + + def __str__(self): + return self.__repr__() + + def __repr__(self): + return (f"Ultralytics YOLO {self.__class__} masks\n" + f"type: {type(self.boxes)}\n" + + f"shape: {self.boxes.shape}\n" + f"dtype: {self.boxes.dtype}") + + def __getitem__(self, idx): + boxes = self.boxes[idx] + return Boxes(boxes, self.orig_shape) + + +class Masks: + """ + A class for storing and manipulating detection masks. + + Args: + masks (torch.Tensor): A tensor containing the detection masks, with shape (num_masks, height, width). + orig_shape (tuple): Original image size, in the format (height, width). + + Attributes: + masks (torch.Tensor): A tensor containing the detection masks, with shape (num_masks, height, width). + orig_shape (tuple): Original image size, in the format (height, width). + + Properties: + segments (list): A list of segments which includes x,y,w,h,label,confidence, and mask of each detection masks. + """ + + def __init__(self, masks, orig_shape) -> None: + self.masks = masks # N, h, w + self.orig_shape = orig_shape + + @property + @lru_cache(maxsize=1) + def segments(self): + return [ + ops.scale_segments(self.masks.shape[1:], x, self.orig_shape, normalize=True) + for x in reversed(ops.masks2segments(self.masks))] + + @property + def shape(self): + return self.masks.shape + + def cpu(self): + masks = self.masks.cpu() + return Masks(masks, self.orig_shape) + + def numpy(self): + masks = self.masks.numpy() + return Masks(masks, self.orig_shape) + + def cuda(self): + masks = self.masks.cuda() + return Masks(masks, self.orig_shape) + + def to(self, *args, **kwargs): + masks = self.masks.to(*args, **kwargs) + return Masks(masks, self.orig_shape) + + def __len__(self): # override len(results) + return len(self.masks) + + def __str__(self): + return self.__repr__() + + def __repr__(self): + return (f"Ultralytics YOLO {self.__class__} masks\n" + f"type: {type(self.masks)}\n" + + f"shape: {self.masks.shape}\n" + f"dtype: {self.masks.dtype}") + + def __getitem__(self, idx): + masks = self.masks[idx] + return Masks(masks, self.im_shape, self.orig_shape) + + +if __name__ == "__main__": + # test examples + results = Results(boxes=torch.randn((2, 6)), masks=torch.randn((2, 160, 160)), orig_shape=[640, 640]) + results = results.cuda() + print("--cuda--pass--") + results = results.cpu() + print("--cpu--pass--") + results = results.to("cuda:0") + print("--to-cuda--pass--") + results = results.to("cpu") + print("--to-cpu--pass--") + results = results.numpy() + print("--numpy--pass--") + # box = Boxes(boxes=torch.randn((2, 6)), orig_shape=[5, 5]) + # box = box.cuda() + # box = box.cpu() + # box = box.numpy() + # for b in box: + # print(b) diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index 5c6f262..6ee15fa 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -30,7 +30,7 @@ from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, yaml_save) from ultralytics.yolo.utils.autobatch import check_train_batch_size -from ultralytics.yolo.utils.checks import check_file, print_args +from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command from ultralytics.yolo.utils.files import get_latest_run, increment_path from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer @@ -203,7 +203,9 @@ class BaseTrainer: self.set_model_attributes() if world_size > 1: self.model = DDP(self.model, device_ids=[rank]) - + # Check imgsz + gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride) + self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs * 2) # Batch size if self.batch_size == -1: if RANK == -1: # single-GPU only, estimate best batch size diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py index 830bbf4..c808332 100644 --- a/ultralytics/yolo/utils/__init__.py +++ b/ultralytics/yolo/utils/__init__.py @@ -5,7 +5,6 @@ import inspect import logging.config import os import platform -import subprocess import sys import tempfile import threading @@ -13,6 +12,7 @@ import uuid from pathlib import Path import cv2 +import git import numpy as np import pandas as pd import torch @@ -134,10 +134,8 @@ def is_git_directory() -> bool: Returns: bool: True if the current working directory is inside a git repository, False otherwise. """ - import git try: - from git import Repo - Repo(search_parent_directories=True) + git.Repo(search_parent_directories=True) # subprocess.run(["git", "rev-parse", "--git-dir"], capture_output=True, check=True) # CLI alternative return True except git.exc.InvalidGitRepositoryError: # subprocess.CalledProcessError: @@ -187,9 +185,10 @@ def get_git_root_dir(): If the current file is not part of a git repository, returns None. """ try: - output = subprocess.run(["git", "rev-parse", "--git-dir"], capture_output=True, check=True) - return Path(output.stdout.strip().decode('utf-8')).parent.resolve() # parent/.git - except subprocess.CalledProcessError: + # output = subprocess.run(["git", "rev-parse", "--git-dir"], capture_output=True, check=True) + # return Path(output.stdout.strip().decode('utf-8')).parent.resolve() # CLI alternative + return Path(git.Repo(search_parent_directories=True).working_tree_dir) + except git.exc.InvalidGitRepositoryError: # (subprocess.CalledProcessError, FileNotFoundError): return None diff --git a/ultralytics/yolo/utils/ops.py b/ultralytics/yolo/utils/ops.py index c24e6ab..4e213ec 100644 --- a/ultralytics/yolo/utils/ops.py +++ b/ultralytics/yolo/utils/ops.py @@ -15,20 +15,39 @@ from .metrics import box_iou class Profile(contextlib.ContextDecorator): - # YOLOv8 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager + """ + YOLOv8 Profile class. + Usage: as a decorator with @Profile() or as a context manager with 'with Profile():' + """ + def __init__(self, t=0.0): + """ + Initialize the Profile class. + + Args: + t (float): Initial time. Defaults to 0.0. + """ self.t = t self.cuda = torch.cuda.is_available() def __enter__(self): + """ + Start timing. + """ self.start = self.time() return self def __exit__(self, type, value, traceback): + """ + Stop timing. + """ self.dt = self.time() - self.start # delta-time self.t += self.dt # accumulate dt def time(self): + """ + Get current time. + """ if self.cuda: torch.cuda.synchronize() return time.time() @@ -48,15 +67,15 @@ def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper) def segment2box(segment, width=640, height=640): """ - Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to - (xyxy) + Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy) + Args: - segment (torch.tensor): the segment label + segment (torch.Tensor): the segment label width (int): the width of the image. Defaults to 640 height (int): The height of the image. Defaults to 640 Returns: - (np.array): the minimum and maximum x and y values of the segment. + (np.ndarray): the minimum and maximum x and y values of the segment. """ # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy) x, y = segment.T # segment xy @@ -67,15 +86,18 @@ def segment2box(segment, width=640, height=640): def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None): """ - Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in (img1_shape) to the shape of a different image (img0_shape). + Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in + (img1_shape) to the shape of a different image (img0_shape). + Args: img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width). - boxes (torch.tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2) + boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2) img0_shape (tuple): the shape of the target image, in the format of (height, width). - ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be calculated based on the size difference between the two images. + ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be + calculated based on the size difference between the two images. Returns: - boxes (torch.tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2) + boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2) """ if ratio_pad is None: # calculate from img0_shape gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new @@ -92,7 +114,16 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None): def make_divisible(x, divisor): - # Returns nearest x divisible by divisor + """ + Returns the nearest number that is divisible by the given divisor. + + Args: + x (int): The number to make divisible. + divisor (int or torch.Tensor): The divisor. + + Returns: + int: The nearest number divisible by the divisor. + """ if isinstance(divisor, torch.Tensor): divisor = int(divisor.max()) # to int return math.ceil(x / divisor) * divisor @@ -232,7 +263,7 @@ def clip_boxes(boxes, shape): shape Args: - boxes (torch.tensor): the bounding boxes to clip + boxes (torch.Tensor): the bounding boxes to clip shape (tuple): the shape of the image """ if isinstance(boxes, torch.Tensor): # faster individually @@ -246,7 +277,19 @@ def clip_boxes(boxes, shape): def clip_coords(boxes, shape): - # Clip bounding xyxy bounding boxes to image shape (height, width) + """ + Clip bounding xyxy bounding boxes to image shape (height, width). + + Args: + boxes (torch.Tensor or numpy.ndarray): Bounding boxes to be clipped. + shape (tuple): The shape of the image. (height, width) + + Returns: + None + + Note: + The input `boxes` is modified in-place, there is no return value. + """ if isinstance(boxes, torch.Tensor): # faster individually boxes[:, 0].clamp_(0, shape[1]) # x1 boxes[:, 1].clamp_(0, shape[0]) # y1 @@ -263,12 +306,12 @@ def scale_image(im1_shape, masks, im0_shape, ratio_pad=None): Args: im1_shape (tuple): model input shape, [h, w] - masks (torch.tensor): [h, w, num] + masks (torch.Tensor): [h, w, num] im0_shape (tuple): the original image shape ratio_pad (tuple): the ratio of the padding to the original image. Returns: - masks (torch.tensor): The masks that are being returned. + masks (torch.Tensor): The masks that are being returned. """ # Rescale coordinates (xyxy) from im1_shape to im0_shape if ratio_pad is None: # calculate from im0_shape @@ -297,9 +340,9 @@ def xyxy2xywh(x): Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format. Args: - x (np.ndarray) or (torch.Tensor): The input tensor containing the bounding box coordinates in (x1, y1, x2, y2) format. + x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format. Returns: - y (numpy.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, width, height) format. + y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, width, height) format. """ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center @@ -311,12 +354,13 @@ def xyxy2xywh(x): def xywh2xyxy(x): """ - Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner. + Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the + top-left corner and (x2, y2) is the bottom-right corner. Args: - x (np.ndarray) or (torch.Tensor): The input tensor containing the bounding box coordinates in (x, y, width, height) format. + x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x, y, width, height) format. Returns: - y (numpy.ndarray) or (torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format. + y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format. """ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x @@ -337,7 +381,8 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): padw (int): Padding width. Defaults to 0 padh (int): Padding height. Defaults to 0 Returns: - y (numpy.ndarray) or (torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box. + y (np.ndarray) or (torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where + x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box. """ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x @@ -349,16 +394,17 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0): """ - Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y, width and height are normalized to image dimensions + Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. + x, y, width and height are normalized to image dimensions Args: - x (np.ndarray) or (torch.Tensor): The input tensor containing the bounding box coordinates in (x1, y1, x2, y2) format. + x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format. w (int): The width of the image. Defaults to 640 h (int): The height of the image. Defaults to 640 clip (bool): If True, the boxes will be clipped to the image boundaries. Defaults to False eps (float): The minimum value of the box's width and height. Defaults to 0.0 Returns: - y (numpy.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format + y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format """ if clip: clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip @@ -375,13 +421,13 @@ def xyn2xy(x, w=640, h=640, padw=0, padh=0): Convert normalized coordinates to pixel coordinates of shape (n,2) Args: - x (numpy.ndarray) or (torch.Tensor): The input tensor of normalized bounding box coordinates + x (np.ndarray) or (torch.Tensor): The input tensor of normalized bounding box coordinates w (int): The width of the image. Defaults to 640 h (int): The height of the image. Defaults to 640 padw (int): The width of the padding. Defaults to 0 padh (int): The height of the padding. Defaults to 0 Returns: - y (numpy.ndarray) or (torch.Tensor): The x and y coordinates of the top left corner of the bounding box + y (np.ndarray) or (torch.Tensor): The x and y coordinates of the top left corner of the bounding box """ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) y[..., 0] = w * x[..., 0] + padw # top left x @@ -394,9 +440,9 @@ def xywh2ltwh(x): Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates. Args: - x (numpy.ndarray) or (torch.Tensor): The input tensor with the bounding box coordinates in the xywh format + x (np.ndarray) or (torch.Tensor): The input tensor with the bounding box coordinates in the xywh format Returns: - y (numpy.ndarray) or (torch.Tensor): The bounding box coordinates in the xyltwh format + y (np.ndarray) or (torch.Tensor): The bounding box coordinates in the xyltwh format """ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x @@ -409,9 +455,9 @@ def xyxy2ltwh(x): Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right Args: - x (numpy.ndarray) or (torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format + x (np.ndarray) or (torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format Returns: - y (numpy.ndarray) or (torch.Tensor): The bounding box coordinates in the xyltwh format. + y (np.ndarray) or (torch.Tensor): The bounding box coordinates in the xyltwh format. """ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) y[:, 2] = x[:, 2] - x[:, 0] # width @@ -424,7 +470,7 @@ def ltwh2xywh(x): Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center Args: - x (torch.tensor): the input tensor + x (torch.Tensor): the input tensor """ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) y[:, 0] = x[:, 0] + x[:, 2] / 2 # center x @@ -437,10 +483,10 @@ def ltwh2xyxy(x): It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right Args: - x (numpy.ndarray) or (torch.Tensor): the input image + x (np.ndarray) or (torch.Tensor): the input image Returns: - y (numpy.ndarray) or (torch.Tensor): the xyxy coordinates of the bounding boxes. + y (np.ndarray) or (torch.Tensor): the xyxy coordinates of the bounding boxes. """ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) y[:, 2] = x[:, 2] + x[:, 0] # width @@ -456,7 +502,7 @@ def segments2boxes(segments): segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates Returns: - (np.array): the xywh coordinates of the bounding boxes. + (np.ndarray): the xywh coordinates of the bounding boxes. """ boxes = [] for s in segments: @@ -467,7 +513,7 @@ def segments2boxes(segments): def resample_segments(segments, n=1000): """ - It takes a list of segments (n,2) and returns a list of segments (n,2) where each segment has been up-sampled to n points + Inputs a list of segments (n,2) and returns a list of segments (n,2) up-sampled to n points each. Args: segments (list): a list of (n,2) arrays, where n is the number of points in the segment. @@ -489,11 +535,11 @@ def crop_mask(masks, boxes): It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box Args: - masks (torch.tensor): [h, w, n] tensor of masks - boxes (torch.tensor): [n, 4] tensor of bbox coordinates in relative point form + masks (torch.Tensor): [h, w, n] tensor of masks + boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form Returns: - (torch.tensor): The masks are being cropped to the bounding box. + (torch.Tensor): The masks are being cropped to the bounding box. """ n, h, w = masks.shape x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(1,1,n) @@ -509,13 +555,13 @@ def process_mask_upsample(protos, masks_in, bboxes, shape): quality but is slower. Args: - protos (torch.tensor): [mask_dim, mask_h, mask_w] - masks_in (torch.tensor): [n, mask_dim], n is number of masks after nms - bboxes (torch.tensor): [n, 4], n is number of masks after nms + protos (torch.Tensor): [mask_dim, mask_h, mask_w] + masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms + bboxes (torch.Tensor): [n, 4], n is number of masks after nms shape (tuple): the size of the input image (h,w) Returns: - (torch.tensor): The upsampled masks. + (torch.Tensor): The upsampled masks. """ c, mh, mw = protos.shape # CHW masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) @@ -530,13 +576,13 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False): downsampled quality of mask Args: - protos (torch.tensor): [mask_dim, mask_h, mask_w] - masks_in (torch.tensor): [n, mask_dim], n is number of masks after nms - bboxes (torch.tensor): [n, 4], n is number of masks after nms + protos (torch.Tensor): [mask_dim, mask_h, mask_w] + masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms + bboxes (torch.Tensor): [n, 4], n is number of masks after nms shape (tuple): the size of the input image (h,w) Returns: - (torch.tensor): The processed masks. + (torch.Tensor): The processed masks. """ c, mh, mw = protos.shape # CHW @@ -560,13 +606,13 @@ def process_mask_native(protos, masks_in, bboxes, shape): It takes the output of the mask head, and crops it after upsampling to the bounding boxes. Args: - protos (torch.tensor): [mask_dim, mask_h, mask_w] - masks_in (torch.tensor): [n, mask_dim], n is number of masks after nms - bboxes (torch.tensor): [n, 4], n is number of masks after nms + protos (torch.Tensor): [mask_dim, mask_h, mask_w] + masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms + bboxes (torch.Tensor): [n, 4], n is number of masks after nms shape (tuple): the size of the input image (h,w) Returns: - masks (torch.tensor): The returned masks with dimensions [h, w, n] + masks (torch.Tensor): The returned masks with dimensions [h, w, n] """ c, mh, mw = protos.shape # CHW masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) @@ -587,13 +633,13 @@ def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=F Args: img1_shape (tuple): The shape of the image that the segments are from. - segments (torch.tensor): the segments to be scaled + segments (torch.Tensor): the segments to be scaled img0_shape (tuple): the shape of the image that the segmentation is being applied to ratio_pad (tuple): the ratio of the image size to the padded image size. normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False Returns: - segments (torch.tensor): the segmented image. + segments (torch.Tensor): the segmented image. """ if ratio_pad is None: # calculate from img0_shape gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new @@ -617,7 +663,7 @@ def masks2segments(masks, strategy='largest'): It takes a list of masks(n,h,w) and returns a list of segments(n,xy) Args: - masks (torch.tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160) + masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160) strategy (str): 'concat' or 'largest'. Defaults to largest Returns: diff --git a/ultralytics/yolo/v8/classify/predict.py b/ultralytics/yolo/v8/classify/predict.py index f4f8cfe..612adbb 100644 --- a/ultralytics/yolo/v8/classify/predict.py +++ b/ultralytics/yolo/v8/classify/predict.py @@ -4,8 +4,8 @@ import hydra import torch from ultralytics.yolo.engine.predictor import BasePredictor +from ultralytics.yolo.engine.results import Results from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT -from ultralytics.yolo.utils.checks import check_imgsz from ultralytics.yolo.utils.plotting import Annotator @@ -15,20 +15,27 @@ class ClassificationPredictor(BasePredictor): return Annotator(img, example=str(self.model.names), pil=True) def preprocess(self, img): - img = torch.Tensor(img).to(self.model.device) + img = (img if isinstance(img, torch.Tensor) else torch.Tensor(img)).to(self.model.device) img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 return img - def write_results(self, idx, preds, batch): + def postprocess(self, preds, img, orig_img): + results = [] + for i, pred in enumerate(preds): + shape = orig_img[i].shape if isinstance(orig_img, list) else orig_img.shape + results.append(Results(probs=pred.softmax(0), orig_shape=shape[:2])) + return results + + def write_results(self, idx, results, batch): p, im, im0 = batch log_string = "" if len(im.shape) == 3: im = im[None] # expand for batch dim self.seen += 1 im0 = im0.copy() - if self.webcam: # batch_size >= 1 + if self.webcam or self.from_img: # batch_size >= 1 log_string += f'{idx}: ' - frame = self.dataset.cound + frame = self.dataset.count else: frame = getattr(self.dataset, 'frame', 0) @@ -38,9 +45,10 @@ class ClassificationPredictor(BasePredictor): log_string += '%gx%g ' % im.shape[2:] # print string self.annotator = self.get_annotator(im0) - prob = preds[idx].softmax(0) - if self.return_outputs: - self.output["prob"] = prob.cpu().numpy() + result = results[idx] + if len(result) == 0: + return log_string + prob = result.probs # Print results top5i = prob.argsort(0, descending=True)[:5].tolist() # top 5 indices log_string += f"{', '.join(f'{self.model.names[j]} {prob[j]:.2f}' for j in top5i)}, " @@ -59,7 +67,6 @@ class ClassificationPredictor(BasePredictor): @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def predict(cfg): cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18" - cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2) # check image size cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" predictor = ClassificationPredictor(cfg) predictor.predict_cli() diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py index aca9703..a68aad5 100644 --- a/ultralytics/yolo/v8/classify/train.py +++ b/ultralytics/yolo/v8/classify/train.py @@ -56,6 +56,8 @@ class ClassificationTrainer(BaseTrainer): # Load a YOLO model locally, from torchvision, or from Ultralytics assets if model.endswith(".pt"): self.model, _ = attempt_load_one_weight(model, device='cpu') + for p in model.parameters(): + p.requires_grad = True # for training elif model.endswith(".yaml"): self.model = self.get_model(cfg=model) elif model in torchvision.models.__dict__: diff --git a/ultralytics/yolo/v8/detect/predict.py b/ultralytics/yolo/v8/detect/predict.py index d2c5c06..94a0d89 100644 --- a/ultralytics/yolo/v8/detect/predict.py +++ b/ultralytics/yolo/v8/detect/predict.py @@ -4,8 +4,8 @@ import hydra import torch from ultralytics.yolo.engine.predictor import BasePredictor +from ultralytics.yolo.engine.results import Results from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops -from ultralytics.yolo.utils.checks import check_imgsz from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box @@ -27,58 +27,53 @@ class DetectionPredictor(BasePredictor): agnostic=self.args.agnostic_nms, max_det=self.args.max_det) + results = [] for i, pred in enumerate(preds): - shape = orig_img[i].shape if self.webcam else orig_img.shape + shape = orig_img[i].shape if isinstance(orig_img, list) else orig_img.shape pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round() + results.append(Results(boxes=pred, orig_shape=shape[:2])) + return results - return preds - - def write_results(self, idx, preds, batch): + def write_results(self, idx, results, batch): p, im, im0 = batch log_string = "" if len(im.shape) == 3: im = im[None] # expand for batch dim self.seen += 1 im0 = im0.copy() - if self.webcam: # batch_size >= 1 + if self.webcam or self.from_img: # batch_size >= 1 log_string += f'{idx}: ' frame = self.dataset.count else: frame = getattr(self.dataset, 'frame', 0) - self.data_path = p - # save_path = str(self.save_dir / p.name) # im.jpg self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}') log_string += '%gx%g ' % im.shape[2:] # print string self.annotator = self.get_annotator(im0) - det = preds[idx] + det = results[idx].boxes # TODO: make boxes inherit from tensors if len(det) == 0: return log_string - for c in det[:, 5].unique(): - n = (det[:, 5] == c).sum() # detections per class + for c in det.cls.unique(): + n = (det.cls == c).sum() # detections per class log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, " - if self.return_outputs: - self.output["det"] = det.cpu().numpy() - # write - gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh - for *xyxy, conf, cls in reversed(det): + for d in reversed(det): + cls, conf = d.cls.squeeze(), d.conf.squeeze() if self.args.save_txt: # Write to file - xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh - line = (cls, *xywh, conf) if self.args.save_conf else (cls, *xywh) # label format + line = (cls, *(d.xywhn.view(-1).tolist()), conf) \ + if self.args.save_conf else (cls, *(d.xywhn.view(-1).tolist())) # label format with open(f'{self.txt_path}.txt', 'a') as f: f.write(('%g ' * len(line)).rstrip() % line + '\n') - if self.args.save or self.args.save_crop or self.args.show: # Add bbox to image c = int(cls) # integer class label = None if self.args.hide_labels else ( self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}') - self.annotator.box_label(xyxy, label, color=colors(c, True)) + self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.save_crop: imc = im0.copy() - save_one_box(xyxy, + save_one_box(d.xyxy, imc, file=self.save_dir / 'crops' / self.model.model.names[c] / f'{self.data_path.stem}.jpg', BGR=True) @@ -89,7 +84,6 @@ class DetectionPredictor(BasePredictor): @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def predict(cfg): cfg.model = cfg.model or "yolov8n.pt" - cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2) # check image size cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" predictor = DetectionPredictor(cfg) predictor.predict_cli() diff --git a/ultralytics/yolo/v8/segment/predict.py b/ultralytics/yolo/v8/segment/predict.py index 47f3021..17af713 100644 --- a/ultralytics/yolo/v8/segment/predict.py +++ b/ultralytics/yolo/v8/segment/predict.py @@ -3,8 +3,8 @@ import hydra import torch +from ultralytics.yolo.engine.results import Results from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops -from ultralytics.yolo.utils.checks import check_imgsz from ultralytics.yolo.utils.plotting import colors, save_one_box from ultralytics.yolo.v8.detect.predict import DetectionPredictor @@ -12,7 +12,6 @@ from ultralytics.yolo.v8.detect.predict import DetectionPredictor class SegmentationPredictor(DetectionPredictor): def postprocess(self, preds, img, orig_img): - masks = [] # TODO: filter by classes p = ops.non_max_suppression(preds[0], self.args.conf, @@ -20,27 +19,29 @@ class SegmentationPredictor(DetectionPredictor): agnostic=self.args.agnostic_nms, max_det=self.args.max_det, nm=32) + results = [] proto = preds[1][-1] for i, pred in enumerate(p): - shape = orig_img[i].shape if self.webcam else orig_img.shape + shape = orig_img[i].shape if isinstance(orig_img, list) else orig_img.shape if not len(pred): + results.append(Results(boxes=pred[:, :6], orig_shape=shape[:2])) # save empty boxes continue if self.args.retina_masks: pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round() - masks.append(ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], shape[:2])) # HWC + masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], shape[:2]) # HWC else: - masks.append(ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)) # HWC + masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round() + results.append(Results(boxes=pred[:, :6], masks=masks, orig_shape=shape[:2])) + return results - return (p, masks) - - def write_results(self, idx, preds, batch): + def write_results(self, idx, results, batch): p, im, im0 = batch log_string = "" if len(im.shape) == 3: im = im[None] # expand for batch dim self.seen += 1 - if self.webcam: # batch_size >= 1 + if self.webcam or self.from_img: # batch_size >= 1 log_string += f'{idx}: ' frame = self.dataset.count else: @@ -51,54 +52,48 @@ class SegmentationPredictor(DetectionPredictor): log_string += '%gx%g ' % im.shape[2:] # print string self.annotator = self.get_annotator(im0) - preds, masks = preds - det = preds[idx] - if len(det) == 0: + result = results[idx] + if len(result) == 0: return log_string - # Segments - mask = masks[idx] - if self.args.save_txt or self.return_outputs: - shape = im0.shape if self.args.retina_masks else im.shape[2:] - segments = [ - ops.scale_segments(shape, x, im0.shape, normalize=False) for x in reversed(ops.masks2segments(mask))] + det, mask = result.boxes, result.masks # getting tensors TODO: mask mask,box inherit for tensor # Print results - for c in det[:, 5].unique(): - n = (det[:, 5] == c).sum() # detections per class - log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, " # add to string + for c in det.cls.unique(): + n = (det.cls == c).sum() # detections per class + log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, " # Mask plotting self.annotator.masks( - mask, - colors=[colors(x, True) for x in det[:, 5]], + mask.masks, + colors=[colors(x, True) for x in det.cls], im_gpu=torch.as_tensor(im0, dtype=torch.float16).to(self.device).permute(2, 0, 1).flip(0).contiguous() / 255 if self.args.retina_masks else im[idx]) - det = reversed(det[:, :6]) - if self.return_outputs: - self.output["det"] = det.cpu().numpy() - self.output["segment"] = segments + # Segments + if self.args.save_txt: + segments = mask.segments # Write results - for j, (*xyxy, conf, cls) in enumerate(det): + for j, d in enumerate(reversed(det)): + cls, conf = d.cls.squeeze(), d.conf.squeeze() if self.args.save_txt: # Write to file seg = segments[j].copy() - seg[:, 0] /= shape[1] # width - seg[:, 1] /= shape[0] # height seg = seg.reshape(-1) # (n,2) to (n*2) line = (cls, *seg, conf) if self.args.save_conf else (cls, *seg) # label format with open(f'{self.txt_path}.txt', 'a') as f: f.write(('%g ' * len(line)).rstrip() % line + '\n') - if self.args.save or self.args.save_crop or self.args.show: + if self.args.save or self.args.save_crop or self.args.show: # Add bbox to image c = int(cls) # integer class label = None if self.args.hide_labels else ( self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}') - self.annotator.box_label(xyxy, label, color=colors(c, True)) - # annotator.draw.polygon(segments[j], outline=colors(c, True), width=3) + self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.save_crop: imc = im0.copy() - save_one_box(xyxy, imc, file=self.save_dir / 'crops' / self.model.names[c] / f'{p.stem}.jpg', BGR=True) + save_one_box(d.xyxy, + imc, + file=self.save_dir / 'crops' / self.model.model.names[c] / f'{self.data_path.stem}.jpg', + BGR=True) return log_string @@ -106,7 +101,6 @@ class SegmentationPredictor(DetectionPredictor): @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def predict(cfg): cfg.model = cfg.model or "yolov8n-seg.pt" - cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2) # check image size cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" predictor = SegmentationPredictor(cfg)