diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 96a9e31..9411948 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -106,11 +106,7 @@ jobs:
shell: bash # for Windows compatibility
run: |
python -m pip install --upgrade pip wheel
- if [ "${{ matrix.os }}" == "macos-latest" ]; then
- pip install -e ".[export]" --extra-index-url https://download.pytorch.org/whl/cpu
- else
- pip install -e ".[export]" --extra-index-url https://download.pytorch.org/whl/cpu
- fi
+ pip install -e ".[export]" coverage --extra-index-url https://download.pytorch.org/whl/cpu
yolo export format=tflite imgsz=32 || true
- name: Check environment
run: |
@@ -125,16 +121,25 @@ jobs:
pip list
- name: Benchmark DetectionModel
shell: bash
- run: yolo benchmark model='path with spaces/${{ matrix.model }}.pt' imgsz=160 verbose=0.26
+ run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}.pt' imgsz=160 verbose=0.26
- name: Benchmark SegmentationModel
shell: bash
- run: yolo benchmark model='path with spaces/${{ matrix.model }}-seg.pt' imgsz=160 verbose=0.30
+ run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}-seg.pt' imgsz=160 verbose=0.30
- name: Benchmark ClassificationModel
shell: bash
- run: yolo benchmark model='path with spaces/${{ matrix.model }}-cls.pt' imgsz=160 verbose=0.36
+ run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}-cls.pt' imgsz=160 verbose=0.36
- name: Benchmark PoseModel
shell: bash
- run: yolo benchmark model='path with spaces/${{ matrix.model }}-pose.pt' imgsz=160 verbose=0.17
+ run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}-pose.pt' imgsz=160 verbose=0.17
+ - name: Merge Coverage Reports
+ run: |
+ coverage xml -o coverage-benchmarks.xml
+ - name: Upload Coverage Reports to CodeCov
+ uses: codecov/codecov-action@v3
+ with:
+ flags: Benchmarks
+ env:
+ CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
- name: Benchmark Summary
run: |
cat benchmarks.log
@@ -183,9 +188,11 @@ jobs:
- name: Pytest tests
shell: bash # for Windows compatibility
run: pytest --cov=ultralytics/ --cov-report xml tests/
- - name: Upload coverage reports to Codecov
+ - name: Upload Coverage Reports to CodeCov
if: github.repository == 'ultralytics/ultralytics' && matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11'
uses: codecov/codecov-action@v3
+ with:
+ flags: Tests
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
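Note on the benchmark steps above: swapping the `yolo` CLI for `coverage run -a --source=ultralytics -m ultralytics.cfg.__init__` lets each benchmark append to a single `.coverage` data file, which `coverage xml` then converts for the Codecov upload. A minimal sketch of the same flow through coverage.py's Python API, assuming `ultralytics.cfg.entrypoint` is the function that the `-m ultralytics.cfg.__init__` invocation executes:
```python
import coverage

from ultralytics.cfg import entrypoint  # CLI function behind `-m ultralytics.cfg.__init__`

# Sketch only: trace ultralytics code while one benchmark runs, then write
# the XML report that the Codecov step uploads with flags=Benchmarks.
# In the workflow, the `-a` (append) flag is what accumulates data across steps.
cov = coverage.Coverage(source=['ultralytics'])
cov.start()
entrypoint('yolo benchmark model=yolov8n.pt imgsz=160 verbose=0.26')
cov.stop()
cov.save()
cov.xml_report(outfile='coverage-benchmarks.xml')  # the "Merge Coverage Reports" step
```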
diff --git a/.gitignore b/.gitignore
index 9ab57be..d197c74 100644
--- a/.gitignore
+++ b/.gitignore
@@ -118,6 +118,9 @@ venv.bak/
.spyderproject
.spyproject
+# VSCode project settings
+.vscode/
+
# Rope project settings
.ropeproject
diff --git a/docs/reference/engine/exporter.md b/docs/reference/engine/exporter.md
index 2cfc47c..6650f8c 100644
--- a/docs/reference/engine/exporter.md
+++ b/docs/reference/engine/exporter.md
@@ -1,6 +1,6 @@
---
-description: Explore the exporter functionality of Ultralytics. Learn about exporting formats, iOSDetectModel, and try exporting with examples.
-keywords: Ultralytics, Exporter, iOSDetectModel, Export Formats, Try export
+description: Explore the exporter functionality of Ultralytics. Learn about exporting formats, IOSDetectModel, and try exporting with examples.
+keywords: Ultralytics, Exporter, IOSDetectModel, Export Formats, Try export
---
# Reference for `ultralytics/engine/exporter.py`
@@ -14,7 +14,7 @@ keywords: Ultralytics, Exporter, iOSDetectModel, Export Formats, Try export
---
-## ::: ultralytics.engine.exporter.iOSDetectModel
+## ::: ultralytics.engine.exporter.IOSDetectModel
---
@@ -28,7 +28,3 @@ keywords: Ultralytics, Exporter, iOSDetectModel, Export Formats, Try export
---
## ::: ultralytics.engine.exporter.try_export
-
----
-## ::: ultralytics.engine.exporter.export
-
diff --git a/docs/reference/models/rtdetr/train.md b/docs/reference/models/rtdetr/train.md
index e524f0e..f7f2881 100644
--- a/docs/reference/models/rtdetr/train.md
+++ b/docs/reference/models/rtdetr/train.md
@@ -12,7 +12,3 @@ keywords: Ultralytics, RTDETRTrainer, model training, Ultralytics models, PyTorc
---
## ::: ultralytics.models.rtdetr.train.RTDETRTrainer
-
----
-## ::: ultralytics.models.rtdetr.train.train
-
diff --git a/docs/reference/models/yolo/classify/predict.md b/docs/reference/models/yolo/classify/predict.md
index 1078b11..4b2485d 100644
--- a/docs/reference/models/yolo/classify/predict.md
+++ b/docs/reference/models/yolo/classify/predict.md
@@ -12,7 +12,3 @@ keywords: Ultralytics, classification predictor, predict, YOLO, AI models, model
---
## ::: ultralytics.models.yolo.classify.predict.ClassificationPredictor
-
----
-## ::: ultralytics.models.yolo.classify.predict.predict
-
diff --git a/docs/reference/models/yolo/classify/train.md b/docs/reference/models/yolo/classify/train.md
index 7fe2477..42e14a9 100644
--- a/docs/reference/models/yolo/classify/train.md
+++ b/docs/reference/models/yolo/classify/train.md
@@ -12,7 +12,3 @@ keywords: Ultralytics, YOLO, Classification Trainer, deep learning, training pro
---
## ::: ultralytics.models.yolo.classify.train.ClassificationTrainer
-
----
-## ::: ultralytics.models.yolo.classify.train.train
-
diff --git a/docs/reference/models/yolo/classify/val.md b/docs/reference/models/yolo/classify/val.md
index 3235460..df505ec 100644
--- a/docs/reference/models/yolo/classify/val.md
+++ b/docs/reference/models/yolo/classify/val.md
@@ -12,7 +12,3 @@ keywords: Ultralytics, YOLO, ClassificationValidator, model validation, model fi
---
## ::: ultralytics.models.yolo.classify.val.ClassificationValidator
-
----
-## ::: ultralytics.models.yolo.classify.val.val
-
diff --git a/docs/reference/models/yolo/detect/predict.md b/docs/reference/models/yolo/detect/predict.md
index 8440af1..15191e4 100644
--- a/docs/reference/models/yolo/detect/predict.md
+++ b/docs/reference/models/yolo/detect/predict.md
@@ -12,7 +12,3 @@ keywords: Ultralytics, YOLO, DetectionPredictor, detect, predict, object detecti
---
## ::: ultralytics.models.yolo.detect.predict.DetectionPredictor
-
----
-## ::: ultralytics.models.yolo.detect.predict.predict
-
diff --git a/docs/reference/models/yolo/detect/train.md b/docs/reference/models/yolo/detect/train.md
index edae647..57092a3 100644
--- a/docs/reference/models/yolo/detect/train.md
+++ b/docs/reference/models/yolo/detect/train.md
@@ -12,7 +12,3 @@ keywords: Ultralytics YOLO, YOLO, Detection Trainer, Model Training, Machine Lea
---
## ::: ultralytics.models.yolo.detect.train.DetectionTrainer
-
----
-## ::: ultralytics.models.yolo.detect.train.train
-
diff --git a/docs/reference/models/yolo/detect/val.md b/docs/reference/models/yolo/detect/val.md
index a8c0192..1afe42e 100644
--- a/docs/reference/models/yolo/detect/val.md
+++ b/docs/reference/models/yolo/detect/val.md
@@ -12,7 +12,3 @@ keywords: Ultralytics, YOLO, Detection Validator, model valuation, precision, re
---
## ::: ultralytics.models.yolo.detect.val.DetectionValidator
-
----
-## ::: ultralytics.models.yolo.detect.val.val
-
diff --git a/docs/reference/models/yolo/pose/predict.md b/docs/reference/models/yolo/pose/predict.md
index e48fd35..e5ea33c 100644
--- a/docs/reference/models/yolo/pose/predict.md
+++ b/docs/reference/models/yolo/pose/predict.md
@@ -12,7 +12,3 @@ keywords: Ultralytics, YOLO, PosePredictor, machine learning, AI, predictive mod
---
## ::: ultralytics.models.yolo.pose.predict.PosePredictor
-
----
-## ::: ultralytics.models.yolo.pose.predict.predict
-
diff --git a/docs/reference/models/yolo/pose/train.md b/docs/reference/models/yolo/pose/train.md
index ebbef1c..972edd4 100644
--- a/docs/reference/models/yolo/pose/train.md
+++ b/docs/reference/models/yolo/pose/train.md
@@ -12,7 +12,3 @@ keywords: Ultralytics, YOLO, PoseTrainer, pose training, AI modeling, custom dat
---
## ::: ultralytics.models.yolo.pose.train.PoseTrainer
-
----
-## ::: ultralytics.models.yolo.pose.train.train
-
diff --git a/docs/reference/models/yolo/pose/val.md b/docs/reference/models/yolo/pose/val.md
index f916805..a826bc3 100644
--- a/docs/reference/models/yolo/pose/val.md
+++ b/docs/reference/models/yolo/pose/val.md
@@ -12,7 +12,3 @@ keywords: PoseValidator, Ultralytics, YOLO, Object detection, Pose validation
---
## ::: ultralytics.models.yolo.pose.val.PoseValidator
-
----
-## ::: ultralytics.models.yolo.pose.val.val
-
diff --git a/docs/reference/models/yolo/segment/predict.md b/docs/reference/models/yolo/segment/predict.md
index 90fb6a7..acfcc1a 100644
--- a/docs/reference/models/yolo/segment/predict.md
+++ b/docs/reference/models/yolo/segment/predict.md
@@ -12,7 +12,3 @@ keywords: YOLO, Ultralytics, object detection, segmentation predictor
---
## ::: ultralytics.models.yolo.segment.predict.SegmentationPredictor
-
----
-## ::: ultralytics.models.yolo.segment.predict.predict
-
diff --git a/docs/usage/engine.md b/docs/usage/engine.md
index 9acf801..4edf331 100644
--- a/docs/usage/engine.md
+++ b/docs/usage/engine.md
@@ -14,7 +14,7 @@ the required functions or operations as long as the correct formats are followed
custom model and dataloader by just overriding these functions:
* `get_model(cfg, weights)` - The function that builds the model to be trained
-* `get_dataloder()` - The function that builds the dataloader
+* `get_dataloader()` - The function that builds the dataloader
More details and source code can be found in [`BaseTrainer` Reference](../reference/engine/trainer.md)
## DetectionTrainer
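For context on the `get_dataloader()` fix above: this docs page describes customizing `BaseTrainer` subclasses by overriding `get_model(cfg, weights)` and `get_dataloader()`. A minimal sketch of such an override (the `CustomTrainer` name and its body are illustrative, not part of this diff):
```python
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel

class CustomTrainer(DetectionTrainer):
    """Illustrative trainer overriding one of the hooks named in the docs page."""

    def get_model(self, cfg=None, weights=None, verbose=True):
        # Build the model to be trained; load pretrained weights if given.
        model = DetectionModel(cfg, nc=self.data['nc'], verbose=verbose)
        if weights:
            model.load(weights)
        return model

trainer = CustomTrainer(overrides=dict(model='yolov8n.yaml', data='coco8.yaml', epochs=1))
trainer.train()
```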
diff --git a/mkdocs.yml b/mkdocs.yml
index 2beadc6..508e50d 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -401,6 +401,7 @@ plugins:
handlers:
python:
options:
+ docstring_style: google
show_root_heading: true
show_source: true
- ultralytics:
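The `docstring_style: google` option tells the mkdocstrings Python handler to parse the Google-style sections (`Args:`, `Returns:`, `Example:`) used throughout the docstrings in this diff. A hypothetical function, shown only to illustrate the layout this parser expects:
```python
def scale_value(x, gain):
    """
    Scale a value by a gain factor (illustrative only).

    Args:
        x (float): Input value.
        gain (float): Multiplicative scale factor.

    Returns:
        (float): The scaled value.
    """
    return x * gain
```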
diff --git a/requirements.txt b/requirements.txt
index d38e7cd..ed1093f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
# Ultralytics requirements
-# Usage: pip install -r requirements.txt
+# Example: pip install -r requirements.txt
# Base ----------------------------------------
matplotlib>=3.2.2
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 91dc169..37e5694 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -40,6 +40,14 @@ def test_train(task, model, data):
@pytest.mark.parametrize('task,model,data', TASK_ARGS)
def test_val(task, model, data):
+ # Download annotations to run pycocotools eval
+ # from ultralytics.utils import SETTINGS, Path
+ # from ultralytics.utils.downloads import download
+ # url = 'https://github.com/ultralytics/assets/releases/download/v0.0.0/'
+ # download(f'{url}instances_val2017.json', dir=Path(SETTINGS['datasets_dir']) / 'coco8/annotations')
+ # download(f'{url}person_keypoints_val2017.json', dir=Path(SETTINGS['datasets_dir']) / 'coco8-pose/annotations')
+
+ # Validate
run(f'yolo val {task} model={WEIGHTS_DIR / model}.pt data={data} imgsz=32 save_txt save_json')
diff --git a/tests/test_python.py b/tests/test_python.py
index bda1d38..7b6c6f5 100644
--- a/tests/test_python.py
+++ b/tests/test_python.py
@@ -132,13 +132,13 @@ def test_val():
def test_train_scratch():
model = YOLO(CFG)
- model.train(data='coco8.yaml', epochs=1, imgsz=32, cache='disk', batch=-1) # test disk caching with AutoBatch
+ model.train(data='coco8.yaml', epochs=2, imgsz=32, cache='disk', batch=-1, close_mosaic=1)
model(SOURCE)
def test_train_pretrained():
model = YOLO(WEIGHTS_DIR / 'yolov8n-seg.pt')
- model.train(data='coco8-seg.yaml', epochs=1, imgsz=32, cache='ram', copy_paste=0.5, mixup=0.5) # test RAM caching
+ model.train(data='coco8-seg.yaml', epochs=1, imgsz=32, cache='ram', copy_paste=0.5, mixup=0.5)
model(SOURCE)
@@ -283,6 +283,12 @@ def test_data_converter():
coco80_to_coco91_class()
+def test_data_annotator():
+ from ultralytics.data.annotator import auto_annotate
+
+ auto_annotate(ASSETS, det_model='yolov8n.pt', sam_model='mobile_sam.pt', output_dir=TMP / 'auto_annotate_labels')
+
+
def test_events():
# Test event sending
from ultralytics.hub.utils import Events
@@ -304,12 +310,15 @@ def test_utils_init():
def test_utils_checks():
- from ultralytics.utils.checks import check_requirements, check_yolov5u_filename, git_describe
+ from ultralytics.utils.checks import (check_imgsz, check_requirements, check_yolov5u_filename, git_describe,
+ print_args)
check_yolov5u_filename('yolov5n.pt')
# check_imshow(warn=True)
git_describe(ROOT)
check_requirements() # check requirements.txt
+ check_imgsz([600, 600], max_dim=1)
+ print_args()
def test_utils_benchmarks():
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index b8f068f..ad8fa44 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = '8.0.157'
+__version__ = '8.0.158'
from ultralytics.hub import start
from ultralytics.models import RTDETR, SAM, YOLO
diff --git a/ultralytics/data/annotator.py b/ultralytics/data/annotator.py
index 2ea66be..b4e08c7 100644
--- a/ultralytics/data/annotator.py
+++ b/ultralytics/data/annotator.py
@@ -8,6 +8,7 @@ from ultralytics import SAM, YOLO
def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None):
"""
Automatically annotates images using a YOLO object detection model and a SAM segmentation model.
+
Args:
data (str): Path to a folder containing images to be annotated.
det_model (str, optional): Pre-trained YOLO detection model. Defaults to 'yolov8x.pt'.
@@ -15,12 +16,20 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='',
device (str, optional): Device to run the models on. Defaults to an empty string (CPU or GPU, if available).
output_dir (str | None | optional): Directory to save the annotated results.
Defaults to a 'labels' folder in the same directory as 'data'.
+
+ Example:
+ ```python
+ from ultralytics.data.annotator import auto_annotate
+
+ auto_annotate(data='ultralytics/assets', det_model='yolov8n.pt', sam_model='mobile_sam.pt')
+ ```
"""
det_model = YOLO(det_model)
sam_model = SAM(sam_model)
+ data = Path(data)
if not output_dir:
- output_dir = Path(str(data)).parent / 'labels'
+ output_dir = data.parent / f'{data.stem}_auto_annotate_labels'
Path(output_dir).mkdir(exist_ok=True, parents=True)
det_results = det_model(data, stream=True, device=device)
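The `output_dir` default above now derives from the data folder's name rather than a fixed 'labels' directory. A quick sketch of the new path math:
```python
from pathlib import Path

# With the new default, annotations for 'ultralytics/assets' land in a
# sibling folder named after the source directory.
data = Path('ultralytics/assets')
output_dir = data.parent / f'{data.stem}_auto_annotate_labels'
print(output_dir)  # ultralytics/assets_auto_annotate_labels
```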
diff --git a/ultralytics/data/augment.py b/ultralytics/data/augment.py
index 053a3a5..39866ae 100644
--- a/ultralytics/data/augment.py
+++ b/ultralytics/data/augment.py
@@ -402,7 +402,7 @@ class RandomPerspective:
keypoints (ndarray): keypoints, [N, 17, 3].
M (ndarray): affine matrix.
- Return:
+ Returns:
new_keypoints (ndarray): keypoints after affine, [N, 17, 3].
"""
n, nkpt = keypoints.shape[:2]
diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py
index 3398c05..3f3a91a 100644
--- a/ultralytics/engine/exporter.py
+++ b/ultralytics/engine/exporter.py
@@ -484,7 +484,7 @@ class Exporter:
classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None
model = self.model
elif self.model.task == 'detect':
- model = iOSDetectModel(self.model, self.im) if self.args.nms else self.model
+ model = IOSDetectModel(self.model, self.im) if self.args.nms else self.model
else:
if self.args.nms:
LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is only available for Detect models like 'yolov8n.pt'.")
@@ -846,12 +846,11 @@ class Exporter:
out0, out1 = iter(spec.description.output)
if MACOS:
from PIL import Image
- img = Image.new('RGB', (w, h)) # img(192 width, 320 height)
- # img = torch.zeros((*opt.img_size, 3)).numpy() # img size(320,192,3) iDetection
+ img = Image.new('RGB', (w, h)) # w=192, h=320
out = model.predict({'image': img})
- out0_shape = out[out0.name].shape
- out1_shape = out[out1.name].shape
- else: # linux and windows can not run model.predict(), get sizes from pytorch output y
+ out0_shape = out[out0.name].shape # (3780, 80)
+ out1_shape = out[out1.name].shape # (3780, 4)
+ else: # linux and windows can not run model.predict(), get sizes from PyTorch model output y
out0_shape = self.output_shape[2], self.output_shape[1] - 4 # (3780, 80)
out1_shape = self.output_shape[2], 4 # (3780, 4)
@@ -963,11 +962,11 @@ class Exporter:
callback(self)
-class iOSDetectModel(torch.nn.Module):
- """Wrap an Ultralytics YOLO model for iOS export."""
+class IOSDetectModel(torch.nn.Module):
+ """Wrap an Ultralytics YOLO model for Apple iOS CoreML export."""
def __init__(self, model, im):
- """Initialize the iOSDetectModel class with a YOLO model and example image."""
+ """Initialize the IOSDetectModel class with a YOLO model and example image."""
super().__init__()
b, c, h, w = im.shape # batch, channel, height, width
self.model = model
@@ -981,21 +980,3 @@ class iOSDetectModel(torch.nn.Module):
"""Normalize predictions of object detection model with input size-dependent factors."""
xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)
-
-
-def export(cfg=DEFAULT_CFG):
- """Export a YOLOv model to a specific format."""
- cfg.model = cfg.model or 'yolov8n.yaml'
- cfg.format = cfg.format or 'torchscript'
-
- from ultralytics import YOLO
- model = YOLO(cfg.model)
- model.export(**vars(cfg))
-
-
-if __name__ == '__main__':
- """
- CLI:
- yolo mode=export model=yolov8n.yaml format=onnx
- """
- export()
diff --git a/ultralytics/engine/predictor.py b/ultralytics/engine/predictor.py
index ebd51bb..8b10980 100644
--- a/ultralytics/engine/predictor.py
+++ b/ultralytics/engine/predictor.py
@@ -138,12 +138,14 @@ class BasePredictor:
return self.model(im, augment=self.args.augment, visualize=visualize)
def pre_transform(self, im):
- """Pre-transform input image before inference.
+ """
+ Pre-transform input image before inference.
Args:
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
- Return: A list of transformed imgs.
+ Returns:
+ (list): A list of transformed images.
"""
same_shapes = all(x.shape == im[0].shape for x in im)
auto = same_shapes and self.model.pt
diff --git a/ultralytics/models/fastsam/prompt.py b/ultralytics/models/fastsam/prompt.py
index 0d42c40..2ccdec6 100644
--- a/ultralytics/models/fastsam/prompt.py
+++ b/ultralytics/models/fastsam/prompt.py
@@ -26,7 +26,7 @@ class FastSAMPrompt:
import clip # for linear_assignment
except ImportError:
from ultralytics.utils.checks import check_requirements
- check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source
+ check_requirements('git+https://github.com/openai/CLIP.git')
import clip
self.clip = clip
@@ -91,8 +91,6 @@ class FastSAMPrompt:
y1 = min(y1, y_t)
x2 = max(x2, x_t + w_t)
y2 = max(y2, y_t + h_t)
- h = y2 - y1
- w = x2 - x1
return [x1, y1, x2, y2]
def plot(self,
@@ -104,9 +102,11 @@ class FastSAMPrompt:
mask_random_color=True,
better_quality=True,
retina=False,
- withContours=True):
+ with_contours=True):
if isinstance(annotations[0], dict):
annotations = [annotation['segmentation'] for annotation in annotations]
+ if isinstance(annotations, torch.Tensor):
+ annotations = annotations.cpu().numpy()
result_name = os.path.basename(self.img_path)
image = self.ori_img
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
@@ -123,41 +123,22 @@ class FastSAMPrompt:
plt.imshow(image)
if better_quality:
- if isinstance(annotations[0], torch.Tensor):
- annotations = np.array(annotations.cpu())
for i, mask in enumerate(annotations):
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
- if self.device == 'cpu':
- annotations = np.array(annotations)
- self.fast_show_mask(
- annotations,
- plt.gca(),
- random_color=mask_random_color,
- bbox=bbox,
- points=points,
- pointlabel=point_label,
- retinamask=retina,
- target_height=original_h,
- target_width=original_w,
- )
- else:
- if isinstance(annotations[0], np.ndarray):
- annotations = torch.from_numpy(annotations)
- self.fast_show_mask_gpu(
- annotations,
- plt.gca(),
- random_color=mask_random_color,
- bbox=bbox,
- points=points,
- pointlabel=point_label,
- retinamask=retina,
- target_height=original_h,
- target_width=original_w,
- )
- if isinstance(annotations, torch.Tensor):
- annotations = annotations.cpu().numpy()
- if withContours:
+ self.fast_show_mask(
+ annotations,
+ plt.gca(),
+ random_color=mask_random_color,
+ bbox=bbox,
+ points=points,
+ pointlabel=point_label,
+ retinamask=retina,
+ target_height=original_h,
+ target_width=original_w,
+ )
+
+ if with_contours:
contour_all = []
temp = np.zeros((original_h, original_w, 1))
for i, mask in enumerate(annotations):
@@ -184,8 +165,8 @@ class FastSAMPrompt:
LOGGER.info(f'Saved to {save_path.absolute()}')
# CPU post process
+ @staticmethod
def fast_show_mask(
- self,
annotation,
ax,
random_color=False,
@@ -196,32 +177,29 @@ class FastSAMPrompt:
target_height=960,
target_width=960,
):
- msak_sum = annotation.shape[0]
- height = annotation.shape[1]
- weight = annotation.shape[2]
- # 将annotation 按照面积 排序
+ n, h, w = annotation.shape # number of masks, height, width
+
areas = np.sum(annotation, axis=(1, 2))
- sorted_indices = np.argsort(areas)
- annotation = annotation[sorted_indices]
+ annotation = annotation[np.argsort(areas)]
index = (annotation != 0).argmax(axis=0)
if random_color:
- color = np.random.random((msak_sum, 1, 1, 3))
+ color = np.random.random((n, 1, 1, 3))
else:
- color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
- transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
+ color = np.ones((n, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
+ transparency = np.ones((n, 1, 1, 1)) * 0.6
visual = np.concatenate([color, transparency], axis=-1)
mask_image = np.expand_dims(annotation, -1) * visual
- show = np.zeros((height, weight, 4))
- h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
+ show = np.zeros((h, w, 4))
+ h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
- # 使用向量化索引更新show的值
+
show[h_indices, w_indices, :] = mask_image[indices]
if bbox is not None:
x1, y1, x2, y2 = bbox
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
- # draw point
+ # Draw point
if points is not None:
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
@@ -240,63 +218,6 @@ class FastSAMPrompt:
show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
ax.imshow(show)
- def fast_show_mask_gpu(
- self,
- annotation,
- ax,
- random_color=False,
- bbox=None,
- points=None,
- pointlabel=None,
- retinamask=True,
- target_height=960,
- target_width=960,
- ):
- msak_sum = annotation.shape[0]
- height = annotation.shape[1]
- weight = annotation.shape[2]
- areas = torch.sum(annotation, dim=(1, 2))
- sorted_indices = torch.argsort(areas, descending=False)
- annotation = annotation[sorted_indices]
- # 找每个位置第一个非零值下标
- index = (annotation != 0).to(torch.long).argmax(dim=0)
- if random_color:
- color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
- else:
- color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([30 / 255, 144 / 255, 1.0]).to(
- annotation.device)
- transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
- visual = torch.cat([color, transparency], dim=-1)
- mask_image = torch.unsqueeze(annotation, -1) * visual
- # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
- show = torch.zeros((height, weight, 4)).to(annotation.device)
- h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij')
- indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
- # 使用向量化索引更新show的值
- show[h_indices, w_indices, :] = mask_image[indices]
- show_cpu = show.cpu().numpy()
- if bbox is not None:
- x1, y1, x2, y2 = bbox
- ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
- # draw point
- if points is not None:
- plt.scatter(
- [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
- [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
- s=20,
- c='y',
- )
- plt.scatter(
- [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
- [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
- s=20,
- c='m',
- )
- if not retinamask:
- show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
- ax.imshow(show_cpu)
-
- # clip
@torch.no_grad()
def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
preprocessed_images = [preprocess(image).to(device) for image in elements]
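The refactor above collapses the separate CPU and GPU mask renderers into one NumPy `fast_show_mask` static method (tensors are moved to CPU in `plot`). A minimal NumPy sketch of its compositing trick: sort masks ascending by area, then let `argmax` pick, per pixel, the first (smallest) covering mask:
```python
import numpy as np

# Sketch of the vectorized compositing in FastSAMPrompt.fast_show_mask.
n, h, w = 3, 32, 32
annotation = (np.random.rand(n, h, w) > 0.5).astype(np.uint8)  # n binary masks

areas = np.sum(annotation, axis=(1, 2))
annotation = annotation[np.argsort(areas)]    # ascending area
index = (annotation != 0).argmax(axis=0)      # first mask covering each pixel

color = np.random.random((n, 1, 1, 3))        # one RGB color per mask
transparency = np.ones((n, 1, 1, 1)) * 0.6
visual = np.concatenate([color, transparency], axis=-1)
mask_image = np.expand_dims(annotation, -1) * visual  # (n, h, w, 4) RGBA stack

show = np.zeros((h, w, 4))
h_idx, w_idx = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
show[h_idx, w_idx, :] = mask_image[index[h_idx, w_idx], h_idx, w_idx, :]
```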
diff --git a/ultralytics/models/nas/predict.py b/ultralytics/models/nas/predict.py
index fe9f486..32f031c 100644
--- a/ultralytics/models/nas/predict.py
+++ b/ultralytics/models/nas/predict.py
@@ -5,7 +5,6 @@ import torch
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import ops
-from ultralytics.utils.ops import xyxy2xywh
class NASPredictor(BasePredictor):
@@ -14,7 +13,7 @@ class NASPredictor(BasePredictor):
"""Postprocess predictions and returns a list of Results objects."""
# Cat boxes and class scores
- boxes = xyxy2xywh(preds_in[0][0])
+ boxes = ops.xyxy2xywh(preds_in[0][0])
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
preds = ops.non_max_suppression(preds,
diff --git a/ultralytics/models/nas/val.py b/ultralytics/models/nas/val.py
index 05986c0..5c39171 100644
--- a/ultralytics/models/nas/val.py
+++ b/ultralytics/models/nas/val.py
@@ -4,7 +4,6 @@ import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import ops
-from ultralytics.utils.ops import xyxy2xywh
__all__ = ['NASValidator']
@@ -13,7 +12,7 @@ class NASValidator(DetectionValidator):
def postprocess(self, preds_in):
"""Apply Non-maximum suppression to prediction outputs."""
- boxes = xyxy2xywh(preds_in[0][0])
+ boxes = ops.xyxy2xywh(preds_in[0][0])
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
return ops.non_max_suppression(preds,
self.args.conf,
diff --git a/ultralytics/models/rtdetr/predict.py b/ultralytics/models/rtdetr/predict.py
index 356098d..d966d0d 100644
--- a/ultralytics/models/rtdetr/predict.py
+++ b/ultralytics/models/rtdetr/predict.py
@@ -9,6 +9,19 @@ from ultralytics.utils import ops
class RTDETRPredictor(BasePredictor):
+ """
+ A class extending the BasePredictor class for prediction based on an RT-DETR detection model.
+
+ Example:
+ ```python
+ from ultralytics.utils import ASSETS
+ from ultralytics.models.rtdetr import RTDETRPredictor
+
+ args = dict(model='rtdetr-l.pt', source=ASSETS)
+ predictor = RTDETRPredictor(overrides=args)
+ predictor.predict_cli()
+ ```
+ """
def postprocess(self, preds, img, orig_imgs):
"""Postprocess predictions and returns a list of Results objects."""
@@ -38,7 +51,9 @@ class RTDETRPredictor(BasePredictor):
Args:
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
- Return: A list of transformed imgs.
+ Notes: The size must be square (640) and scaleFilled.
+
+ Returns:
+ (list): A list of transformed images.
"""
- # The size must be square(640) and scaleFilled.
return [LetterBox(self.imgsz, auto=False, scaleFill=True)(image=x) for x in im]
diff --git a/ultralytics/models/rtdetr/train.py b/ultralytics/models/rtdetr/train.py
index a900491..1e58668 100644
--- a/ultralytics/models/rtdetr/train.py
+++ b/ultralytics/models/rtdetr/train.py
@@ -6,12 +6,28 @@ import torch
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import RTDETRDetectionModel
-from ultralytics.utils import DEFAULT_CFG, RANK, colorstr
+from ultralytics.utils import RANK, colorstr
from .val import RTDETRDataset, RTDETRValidator
class RTDETRTrainer(DetectionTrainer):
+ """
+ A class extending the DetectionTrainer class for training based on an RT-DETR detection model.
+
+ Notes:
+ - F.grid_sample used in rt-detr does not support the `deterministic=True` argument.
+ - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
+
+ Example:
+ ```python
+ from ultralytics.models.rtdetr.train import RTDETRTrainer
+
+ args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
+ trainer = RTDETRTrainer(overrides=args)
+ trainer.train()
+ ```
+ """
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return a YOLO detection model."""
@@ -54,27 +70,3 @@ class RTDETRTrainer(DetectionTrainer):
gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
return batch
-
-
-def train(cfg=DEFAULT_CFG, use_python=False):
- """Train and optimize RTDETR model given training data and device."""
- model = 'rtdetr-l.yaml'
- data = cfg.data or 'coco8.yaml' # or yolo.ClassificationDataset("mnist")
- device = cfg.device if cfg.device is not None else ''
-
- # NOTE: F.grid_sample which is in rt-detr does not support deterministic=True
- # NOTE: amp training causes nan outputs and end with error while doing bipartite graph matching
- args = dict(model=model,
- data=data,
- device=device,
- imgsz=640,
- exist_ok=True,
- batch=4,
- deterministic=False,
- amp=False)
- trainer = RTDETRTrainer(overrides=args)
- trainer.train()
-
-
-if __name__ == '__main__':
- train()
diff --git a/ultralytics/models/rtdetr/val.py b/ultralytics/models/rtdetr/val.py
index ff6855a..c90a99b 100644
--- a/ultralytics/models/rtdetr/val.py
+++ b/ultralytics/models/rtdetr/val.py
@@ -67,6 +67,18 @@ class RTDETRDataset(YOLODataset):
class RTDETRValidator(DetectionValidator):
+ """
+ A class extending the DetectionValidator class for validation based on an RT-DETR detection model.
+
+ Example:
+ ```python
+ from ultralytics.models.rtdetr import RTDETRValidator
+
+ args = dict(model='rtdetr-l.pt', data='coco8.yaml')
+ validator = RTDETRValidator(args=args)
+ validator(model=args['model'])
+ ```
+ """
def build_dataset(self, img_path, mode='val', batch=None):
"""Build YOLO Dataset
diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py
index c3cec8c..8c0604d 100644
--- a/ultralytics/models/sam/predict.py
+++ b/ultralytics/models/sam/predict.py
@@ -55,12 +55,14 @@ class Predictor(BasePredictor):
return img
def pre_transform(self, im):
- """Pre-transform input image before inference.
+ """
+ Pre-transform input image before inference.
Args:
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
- Return: A list of transformed imgs.
+ Returns:
+ (list): A list of transformed images.
"""
assert len(im) == 1, 'SAM model has not supported batch inference yet!'
return [LetterBox(self.args.imgsz, auto=False, center=False)(image=x) for x in im]
diff --git a/ultralytics/models/yolo/classify/__init__.py b/ultralytics/models/yolo/classify/__init__.py
index 84e7114..33d72e6 100644
--- a/ultralytics/models/yolo/classify/__init__.py
+++ b/ultralytics/models/yolo/classify/__init__.py
@@ -1,7 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-from ultralytics.models.yolo.classify.predict import ClassificationPredictor, predict
-from ultralytics.models.yolo.classify.train import ClassificationTrainer, train
-from ultralytics.models.yolo.classify.val import ClassificationValidator, val
+from ultralytics.models.yolo.classify.predict import ClassificationPredictor
+from ultralytics.models.yolo.classify.train import ClassificationTrainer
+from ultralytics.models.yolo.classify.val import ClassificationValidator
-__all__ = 'ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val'
+__all__ = 'ClassificationPredictor', 'ClassificationTrainer', 'ClassificationValidator'
diff --git a/ultralytics/models/yolo/classify/predict.py b/ultralytics/models/yolo/classify/predict.py
index 8e2f594..95b17e4 100644
--- a/ultralytics/models/yolo/classify/predict.py
+++ b/ultralytics/models/yolo/classify/predict.py
@@ -4,10 +4,26 @@ import torch
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
-from ultralytics.utils import ASSETS, DEFAULT_CFG
+from ultralytics.utils import DEFAULT_CFG
class ClassificationPredictor(BasePredictor):
+ """
+ A class extending the BasePredictor class for prediction based on a classification model.
+
+ Notes:
+ - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
+
+ Example:
+ ```python
+ from ultralytics.utils import ASSETS
+ from ultralytics.models.yolo.classify import ClassificationPredictor
+
+ args = dict(model='yolov8n-cls.pt', source=ASSETS)
+ predictor = ClassificationPredictor(overrides=args)
+ predictor.predict_cli()
+ ```
+ """
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
@@ -30,21 +46,3 @@ class ClassificationPredictor(BasePredictor):
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred))
return results
-
-
-def predict(cfg=DEFAULT_CFG, use_python=False):
- """Run YOLO model predictions on input images/videos."""
- model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
- source = cfg.source or ASSETS
-
- args = dict(model=model, source=source)
- if use_python:
- from ultralytics import YOLO
- YOLO(model)(**args)
- else:
- predictor = ClassificationPredictor(overrides=args)
- predictor.predict_cli()
-
-
-if __name__ == '__main__':
- predict()
diff --git a/ultralytics/models/yolo/classify/train.py b/ultralytics/models/yolo/classify/train.py
index 420322b..8c798f0 100644
--- a/ultralytics/models/yolo/classify/train.py
+++ b/ultralytics/models/yolo/classify/train.py
@@ -13,6 +13,21 @@ from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_di
class ClassificationTrainer(BaseTrainer):
+ """
+ A class extending the BaseTrainer class for training based on a classification model.
+
+ Notes:
+ - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
+
+ Example:
+ ```python
+ from ultralytics.models.yolo.classify import ClassificationTrainer
+
+ args = dict(model='yolov8n-cls.pt', data='imagenet10', epochs=3)
+ trainer = ClassificationTrainer(overrides=args)
+ trainer.train()
+ ```
+ """
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
@@ -137,22 +152,3 @@ class ClassificationTrainer(BaseTrainer):
cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
-
-
-def train(cfg=DEFAULT_CFG, use_python=False):
- """Train a YOLO classification model."""
- model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
- data = cfg.data or 'mnist160' # or yolo.ClassificationDataset("mnist")
- device = cfg.device if cfg.device is not None else ''
-
- args = dict(model=model, data=data, device=device)
- if use_python:
- from ultralytics import YOLO
- YOLO(model).train(**args)
- else:
- trainer = ClassificationTrainer(overrides=args)
- trainer.train()
-
-
-if __name__ == '__main__':
- train()
diff --git a/ultralytics/models/yolo/classify/val.py b/ultralytics/models/yolo/classify/val.py
index 0df2a35..fd913b9 100644
--- a/ultralytics/models/yolo/classify/val.py
+++ b/ultralytics/models/yolo/classify/val.py
@@ -4,12 +4,27 @@ import torch
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.validator import BaseValidator
-from ultralytics.utils import DEFAULT_CFG, LOGGER
+from ultralytics.utils import LOGGER
from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
from ultralytics.utils.plotting import plot_images
class ClassificationValidator(BaseValidator):
+ """
+ A class extending the BaseValidator class for validation based on a classification model.
+
+ Notes:
+ - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
+
+ Example:
+ ```python
+ from ultralytics.models.yolo.classify import ClassificationValidator
+
+ args = dict(model='yolov8n-cls.pt', data='imagenet10')
+ validator = ClassificationValidator(args=args)
+ validator(model=args['model'])
+ ```
+ """
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
@@ -92,21 +107,3 @@ class ClassificationValidator(BaseValidator):
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
-
-
-def val(cfg=DEFAULT_CFG, use_python=False):
- """Validate YOLO model using custom data."""
- model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
- data = cfg.data or 'mnist160'
-
- args = dict(model=model, data=data)
- if use_python:
- from ultralytics import YOLO
- YOLO(model).val(**args)
- else:
- validator = ClassificationValidator(args=args)
- validator(model=args['model'])
-
-
-if __name__ == '__main__':
- val()
diff --git a/ultralytics/models/yolo/detect/__init__.py b/ultralytics/models/yolo/detect/__init__.py
index 481951a..20fc0c4 100644
--- a/ultralytics/models/yolo/detect/__init__.py
+++ b/ultralytics/models/yolo/detect/__init__.py
@@ -1,7 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-from .predict import DetectionPredictor, predict
-from .train import DetectionTrainer, train
-from .val import DetectionValidator, val
+from .predict import DetectionPredictor
+from .train import DetectionTrainer
+from .val import DetectionValidator
-__all__ = 'DetectionPredictor', 'predict', 'DetectionTrainer', 'train', 'DetectionValidator', 'val'
+__all__ = 'DetectionPredictor', 'DetectionTrainer', 'DetectionValidator'
diff --git a/ultralytics/models/yolo/detect/predict.py b/ultralytics/models/yolo/detect/predict.py
index 88b134b..46de75f 100644
--- a/ultralytics/models/yolo/detect/predict.py
+++ b/ultralytics/models/yolo/detect/predict.py
@@ -4,10 +4,23 @@ import torch
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
-from ultralytics.utils import ASSETS, DEFAULT_CFG, ops
+from ultralytics.utils import ops
class DetectionPredictor(BasePredictor):
+ """
+ A class extending the BasePredictor class for prediction based on a detection model.
+
+ Example:
+ ```python
+ from ultralytics.utils import ASSETS
+ from ultralytics.models.yolo.detect import DetectionPredictor
+
+ args = dict(model='yolov8n.pt', source=ASSETS)
+ predictor = DetectionPredictor(overrides=args)
+ predictor.predict_cli()
+ ```
+ """
def postprocess(self, preds, img, orig_imgs):
"""Post-processes predictions and returns a list of Results objects."""
@@ -27,21 +40,3 @@ class DetectionPredictor(BasePredictor):
img_path = path[i] if isinstance(path, list) else path
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
return results
-
-
-def predict(cfg=DEFAULT_CFG, use_python=False):
- """Runs YOLO model inference on input image(s)."""
- model = cfg.model or 'yolov8n.pt'
- source = cfg.source or ASSETS
-
- args = dict(model=model, source=source)
- if use_python:
- from ultralytics import YOLO
- YOLO(model)(**args)
- else:
- predictor = DetectionPredictor(overrides=args)
- predictor.predict_cli()
-
-
-if __name__ == '__main__':
- predict()
diff --git a/ultralytics/models/yolo/detect/train.py b/ultralytics/models/yolo/detect/train.py
index e0eeef7..56d9243 100644
--- a/ultralytics/models/yolo/detect/train.py
+++ b/ultralytics/models/yolo/detect/train.py
@@ -8,12 +8,24 @@ from ultralytics.data import build_dataloader, build_yolo_dataset
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import DetectionModel
-from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
+from ultralytics.utils import LOGGER, RANK
from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
class DetectionTrainer(BaseTrainer):
+ """
+ A class extending the BaseTrainer class for training based on a detection model.
+
+ Example:
+ ```python
+ from ultralytics.models.yolo.detect import DetectionTrainer
+
+ args = dict(model='yolov8n.pt', data='coco8.yaml', epochs=3)
+ trainer = DetectionTrainer(overrides=args)
+ trainer.train()
+ ```
+ """
def build_dataset(self, img_path, mode='train', batch=None):
"""
@@ -102,22 +114,3 @@ class DetectionTrainer(BaseTrainer):
boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0)
cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0)
plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot)
-
-
-def train(cfg=DEFAULT_CFG, use_python=False):
- """Train and optimize YOLO model given training data and device."""
- model = cfg.model or 'yolov8n.pt'
- data = cfg.data or 'coco8.yaml' # or yolo.ClassificationDataset("mnist")
- device = cfg.device if cfg.device is not None else ''
-
- args = dict(model=model, data=data, device=device)
- if use_python:
- from ultralytics import YOLO
- YOLO(model).train(**args)
- else:
- trainer = DetectionTrainer(overrides=args)
- trainer.train()
-
-
-if __name__ == '__main__':
- train()
diff --git a/ultralytics/models/yolo/detect/val.py b/ultralytics/models/yolo/detect/val.py
index d6fb7e1..6199f77 100644
--- a/ultralytics/models/yolo/detect/val.py
+++ b/ultralytics/models/yolo/detect/val.py
@@ -8,7 +8,7 @@ import torch
from ultralytics.data import build_dataloader, build_yolo_dataset, converter
from ultralytics.engine.validator import BaseValidator
-from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
+from ultralytics.utils import LOGGER, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
from ultralytics.utils.plotting import output_to_target, plot_images
@@ -16,6 +16,18 @@ from ultralytics.utils.torch_utils import de_parallel
class DetectionValidator(BaseValidator):
+ """
+ A class extending the BaseValidator class for validation based on a detection model.
+
+ Example:
+ ```python
+ from ultralytics.models.yolo.detect import DetectionValidator
+
+ args = dict(model='yolov8n.pt', data='coco8.yaml')
+ validator = DetectionValidator(args=args)
+ validator(model=args['model'])
+ ```
+ """
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize detection model with necessary variables and settings."""
@@ -254,21 +266,3 @@ class DetectionValidator(BaseValidator):
except Exception as e:
LOGGER.warning(f'pycocotools unable to run: {e}')
return stats
-
-
-def val(cfg=DEFAULT_CFG, use_python=False):
- """Validate trained YOLO model on validation dataset."""
- model = cfg.model or 'yolov8n.pt'
- data = cfg.data or 'coco8.yaml'
-
- args = dict(model=model, data=data)
- if use_python:
- from ultralytics import YOLO
- YOLO(model).val(**args)
- else:
- validator = DetectionValidator(args=args)
- validator(model=args['model'])
-
-
-if __name__ == '__main__':
- val()
diff --git a/ultralytics/models/yolo/pose/__init__.py b/ultralytics/models/yolo/pose/__init__.py
index 8ec6d58..2a79f0f 100644
--- a/ultralytics/models/yolo/pose/__init__.py
+++ b/ultralytics/models/yolo/pose/__init__.py
@@ -1,7 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-from .predict import PosePredictor, predict
-from .train import PoseTrainer, train
-from .val import PoseValidator, val
+from .predict import PosePredictor
+from .train import PoseTrainer
+from .val import PoseValidator
-__all__ = 'PoseTrainer', 'train', 'PoseValidator', 'val', 'PosePredictor', 'predict'
+__all__ = 'PoseTrainer', 'PoseValidator', 'PosePredictor'
diff --git a/ultralytics/models/yolo/pose/predict.py b/ultralytics/models/yolo/pose/predict.py
index ffafadf..1a410a1 100644
--- a/ultralytics/models/yolo/pose/predict.py
+++ b/ultralytics/models/yolo/pose/predict.py
@@ -2,10 +2,23 @@
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
-from ultralytics.utils import ASSETS, DEFAULT_CFG, LOGGER, ops
+from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
class PosePredictor(DetectionPredictor):
+ """
+ A class extending the DetectionPredictor class for prediction based on a pose model.
+
+ Example:
+ ```python
+ from ultralytics.utils import ASSETS
+ from ultralytics.models.yolo.pose import PosePredictor
+
+ args = dict(model='yolov8n-pose.pt', source=ASSETS)
+ predictor = PosePredictor(overrides=args)
+ predictor.predict_cli()
+ ```
+ """
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
@@ -40,21 +53,3 @@ class PosePredictor(DetectionPredictor):
boxes=pred[:, :6],
keypoints=pred_kpts))
return results
-
-
-def predict(cfg=DEFAULT_CFG, use_python=False):
- """Runs YOLO to predict objects in an image or video."""
- model = cfg.model or 'yolov8n-pose.pt'
- source = cfg.source or ASSETS
-
- args = dict(model=model, source=source)
- if use_python:
- from ultralytics import YOLO
- YOLO(model)(**args)
- else:
- predictor = PosePredictor(overrides=args)
- predictor.predict_cli()
-
-
-if __name__ == '__main__':
- predict()
diff --git a/ultralytics/models/yolo/pose/train.py b/ultralytics/models/yolo/pose/train.py
index 979c3f9..2d4f4e0 100644
--- a/ultralytics/models/yolo/pose/train.py
+++ b/ultralytics/models/yolo/pose/train.py
@@ -9,6 +9,18 @@ from ultralytics.utils.plotting import plot_images, plot_results
class PoseTrainer(yolo.detect.DetectionTrainer):
+ """
+ A class extending the DetectionTrainer class for training based on a pose model.
+
+ Example:
+ ```python
+ from ultralytics.models.yolo.pose import PoseTrainer
+
+ args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml', epochs=3)
+ trainer = PoseTrainer(overrides=args)
+ trainer.train()
+ ```
+ """
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a PoseTrainer object with specified configurations and overrides."""
@@ -59,22 +71,3 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
def plot_metrics(self):
"""Plots training/val metrics."""
plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png
-
-
-def train(cfg=DEFAULT_CFG, use_python=False):
- """Train the YOLO model on the given data and device."""
- model = cfg.model or 'yolov8n-pose.yaml'
- data = cfg.data or 'coco8-pose.yaml'
- device = cfg.device if cfg.device is not None else ''
-
- args = dict(model=model, data=data, device=device)
- if use_python:
- from ultralytics import YOLO
- YOLO(model).train(**args)
- else:
- trainer = PoseTrainer(overrides=args)
- trainer.train()
-
-
-if __name__ == '__main__':
- train()
diff --git a/ultralytics/models/yolo/pose/val.py b/ultralytics/models/yolo/pose/val.py
index b68fa50..3332e67 100644
--- a/ultralytics/models/yolo/pose/val.py
+++ b/ultralytics/models/yolo/pose/val.py
@@ -6,13 +6,25 @@ import numpy as np
import torch
from ultralytics.models.yolo.detect import DetectionValidator
-from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
+from ultralytics.utils import LOGGER, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
from ultralytics.utils.plotting import output_to_target, plot_images
class PoseValidator(DetectionValidator):
+ """
+ A class extending the DetectionValidator class for validation based on a pose model.
+
+ Example:
+ ```python
+ from ultralytics.models.yolo.pose import PoseValidator
+
+ args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml')
+ validator = PoseValidator(args=args)
+ validator(model=args['model'])
+ ```
+ """
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
@@ -201,21 +213,3 @@ class PoseValidator(DetectionValidator):
except Exception as e:
LOGGER.warning(f'pycocotools unable to run: {e}')
return stats
-
-
-def val(cfg=DEFAULT_CFG, use_python=False):
- """Performs validation on YOLO model using given data."""
- model = cfg.model or 'yolov8n-pose.pt'
- data = cfg.data or 'coco8-pose.yaml'
-
- args = dict(model=model, data=data)
- if use_python:
- from ultralytics import YOLO
- YOLO(model).val(**args)
- else:
- validator = PoseValidator(args=args)
- validator(model=args['model'])
-
-
-if __name__ == '__main__':
- val()
diff --git a/ultralytics/models/yolo/segment/__init__.py b/ultralytics/models/yolo/segment/__init__.py
index 61a9efe..c84a570 100644
--- a/ultralytics/models/yolo/segment/__init__.py
+++ b/ultralytics/models/yolo/segment/__init__.py
@@ -1,7 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-from .predict import SegmentationPredictor, predict
-from .train import SegmentationTrainer, train
-from .val import SegmentationValidator, val
+from .predict import SegmentationPredictor
+from .train import SegmentationTrainer
+from .val import SegmentationValidator
-__all__ = 'SegmentationPredictor', 'predict', 'SegmentationTrainer', 'train', 'SegmentationValidator', 'val'
+__all__ = 'SegmentationPredictor', 'SegmentationTrainer', 'SegmentationValidator'
diff --git a/ultralytics/models/yolo/segment/predict.py b/ultralytics/models/yolo/segment/predict.py
index c30efe6..866c32c 100644
--- a/ultralytics/models/yolo/segment/predict.py
+++ b/ultralytics/models/yolo/segment/predict.py
@@ -4,10 +4,23 @@ import torch
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
-from ultralytics.utils import ASSETS, DEFAULT_CFG, ops
+from ultralytics.utils import DEFAULT_CFG, ops
class SegmentationPredictor(DetectionPredictor):
+ """
+ A class extending the DetectionPredictor class for prediction based on a segmentation model.
+
+ Example:
+ ```python
+ from ultralytics.utils import ASSETS
+ from ultralytics.models.yolo.segment import SegmentationPredictor
+
+ args = dict(model='yolov8n-seg.pt', source=ASSETS)
+ predictor = SegmentationPredictor(overrides=args)
+ predictor.predict_cli()
+ ```
+ """
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
@@ -42,21 +55,3 @@ class SegmentationPredictor(DetectionPredictor):
results.append(
Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
return results
-
-
-def predict(cfg=DEFAULT_CFG, use_python=False):
- """Runs YOLO object detection on an image or video source."""
- model = cfg.model or 'yolov8n-seg.pt'
- source = cfg.source or ASSETS
-
- args = dict(model=model, source=source)
- if use_python:
- from ultralytics import YOLO
- YOLO(model)(**args)
- else:
- predictor = SegmentationPredictor(overrides=args)
- predictor.predict_cli()
-
-
-if __name__ == '__main__':
- predict()
diff --git a/ultralytics/models/yolo/segment/train.py b/ultralytics/models/yolo/segment/train.py
index e61d7fd..c6e148b 100644
--- a/ultralytics/models/yolo/segment/train.py
+++ b/ultralytics/models/yolo/segment/train.py
@@ -9,6 +9,18 @@ from ultralytics.utils.plotting import plot_images, plot_results
class SegmentationTrainer(yolo.detect.DetectionTrainer):
+ """
+ A class extending the DetectionTrainer class for training based on a segmentation model.
+
+ Example:
+ ```python
+ from ultralytics.models.yolo.segment import SegmentationTrainer
+
+ args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml', epochs=3)
+ trainer = SegmentationTrainer(overrides=args)
+ trainer.train()
+ ```
+ """
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a SegmentationTrainer object with given arguments."""
@@ -46,19 +58,11 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png
-def train(cfg=DEFAULT_CFG, use_python=False):
+def train(cfg=DEFAULT_CFG):
"""Train a YOLO segmentation model based on passed arguments."""
- model = cfg.model or 'yolov8n-seg.pt'
- data = cfg.data or 'coco8-seg.yaml'
- device = cfg.device if cfg.device is not None else ''
-
- args = dict(model=model, data=data, device=device)
- if use_python:
- from ultralytics import YOLO
- YOLO(model).train(**args)
- else:
- trainer = SegmentationTrainer(overrides=args)
- trainer.train()
+ args = dict(model=cfg.model or 'yolov8n-seg.pt', data=cfg.data or 'coco8-seg.yaml')
+ trainer = SegmentationTrainer(overrides=args)
+ trainer.train()
if __name__ == '__main__':
diff --git a/ultralytics/models/yolo/segment/val.py b/ultralytics/models/yolo/segment/val.py
index 6cabcaf..6a3aa15 100644
--- a/ultralytics/models/yolo/segment/val.py
+++ b/ultralytics/models/yolo/segment/val.py
@@ -15,6 +15,18 @@ from ultralytics.utils.plotting import output_to_target, plot_images
class SegmentationValidator(DetectionValidator):
+ """
+ A class extending the DetectionValidator class for validation based on a segmentation model.
+
+ Example:
+ ```python
+ from ultralytics.models.yolo.segment import SegmentationValidator
+
+ args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml')
+ validator = SegmentationValidator(args=args)
+ validator(model=args['model'])
+ ```
+ """
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
@@ -233,18 +245,11 @@ class SegmentationValidator(DetectionValidator):
return stats
-def val(cfg=DEFAULT_CFG, use_python=False):
+def val(cfg=DEFAULT_CFG):
"""Validate trained YOLO model on validation data."""
- model = cfg.model or 'yolov8n-seg.pt'
- data = cfg.data or 'coco8-seg.yaml'
-
- args = dict(model=model, data=data)
- if use_python:
- from ultralytics import YOLO
- YOLO(model).val(**args)
- else:
- validator = SegmentationValidator(args=args)
- validator(model=args['model'])
+ args = dict(model=cfg.model or 'yolov8n-seg.pt', data=cfg.data or 'coco8-seg.yaml')
+ validator = SegmentationValidator(args=args)
+ validator(model=args['model'])
if __name__ == '__main__':
diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
index e1c7ea8..93de16b 100644
--- a/ultralytics/nn/autobackend.py
+++ b/ultralytics/nn/autobackend.py
@@ -414,13 +414,10 @@ class AutoBackend(nn.Module):
scale, zero_point = output['quantization']
x = (x.astype(np.float32) - zero_point) * scale # re-scale
if x.ndim > 2: # if task is not classification
- # Denormalize xywh with input image size
+ # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
# xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
- # See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
- x[:, 0] *= w
- x[:, 1] *= h
- x[:, 2] *= w
- x[:, 3] *= h
+ x[:, [0, 2]] *= w
+ x[:, [1, 3]] *= h
y.append(x)
# TF segment fixes: export is reversed vs ONNX export and protos are transposed
if len(y) == 2: # segment with (det, proto) output order reversed
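The fancy-indexed scaling above is behavior-preserving; a tiny NumPy check of the equivalence:
```python
import numpy as np

# The two fancy-indexed lines replace four scalar ones with identical effect.
x = np.random.rand(5, 6).astype(np.float32)  # rows: [x, y, w, h, ...]
w, h = 640, 480

a = x.copy()
a[:, 0] *= w; a[:, 1] *= h; a[:, 2] *= w; a[:, 3] *= h  # old form

b = x.copy()
b[:, [0, 2]] *= w  # x and width
b[:, [1, 3]] *= h  # y and height

assert np.allclose(a, b)
```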
diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py
index 8d07306..3933e12 100644
--- a/ultralytics/utils/__init__.py
+++ b/ultralytics/utils/__init__.py
@@ -169,7 +169,7 @@ def plt_settings(rcparams=None, backend='Agg'):
"""
Decorator to temporarily set rc parameters and the backend for a plotting function.
- Usage:
+ Example:
decorator: @plt_settings({"font.size": 12})
context manager: with plt_settings({"font.size": 12}):
diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py
index 75a3270..3d7adba 100644
--- a/ultralytics/utils/ops.py
+++ b/ultralytics/utils/ops.py
@@ -18,8 +18,7 @@ from .metrics import box_iou
class Profile(contextlib.ContextDecorator):
"""
- YOLOv8 Profile class.
- Usage: as a decorator with @Profile() or as a context manager with 'with Profile():'
+ YOLOv8 Profile class. Use as a decorator with @Profile() or as a context manager with 'with Profile():'.
"""
def __init__(self, t=0.0):
diff --git a/ultralytics/utils/tal.py b/ultralytics/utils/tal.py
index aea8918..87f4579 100644
--- a/ultralytics/utils/tal.py
+++ b/ultralytics/utils/tal.py
@@ -10,12 +10,14 @@ TORCH_1_10 = check_version(torch.__version__, '1.10.0')
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
- """select the positive anchor center in gt
+ """
+ Select the positive anchor center in gt.
Args:
xy_centers (Tensor): shape(h*w, 4)
gt_bboxes (Tensor): shape(b, n_boxes, 4)
- Return:
+
+ Returns:
(Tensor): shape(b, n_boxes, h*w)
"""
n_anchors = xy_centers.shape[0]
@@ -27,13 +29,14 @@ def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
- """if an anchor box is assigned to multiple gts,
- the one with the highest iou will be selected.
+ """
+ If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.
Args:
mask_pos (Tensor): shape(b, n_max_boxes, h*w)
overlaps (Tensor): shape(b, n_max_boxes, h*w)
- Return:
+
+ Returns:
target_gt_idx (Tensor): shape(b, h*w)
fg_mask (Tensor): shape(b, h*w)
mask_pos (Tensor): shape(b, n_max_boxes, h*w)
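For reference on the docstrings fixed above: a minimal torch sketch of what `select_candidates_in_gts` computes, consistent with the stated shapes (in practice `xy_centers` holds (h*w, 2) center points; the sketch is illustrative, not the library source):
```python
import torch

def select_candidates_in_gts_sketch(xy_centers, gt_bboxes, eps=1e-9):
    """Return a (b, n_boxes, h*w) mask of anchors whose centers lie inside each gt box."""
    n_anchors = xy_centers.shape[0]
    bs, n_boxes, _ = gt_bboxes.shape
    lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top and right-bottom corners
    deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2)
    return deltas.view(bs, n_boxes, n_anchors, -1).amin(3).gt_(eps)

# Shape demo with random values (not valid boxes, shapes only)
mask = select_candidates_in_gts_sketch(torch.rand(8400, 2), torch.rand(2, 3, 4))
print(mask.shape)  # torch.Size([2, 3, 8400])
```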