From 243fc4b1fe214ff6c27759dad51c37809db8f7f8 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Fri, 28 Apr 2023 00:36:50 +0200
Subject: [PATCH] `ultralytics 8.0.89` SAM predict and auto-annotate (#2298)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yonghye Kwon
Co-authored-by: Paula Derrenger <107626595+pderrenger@users.noreply.github.com>
Co-authored-by: Dhruv Nair
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia
Co-authored-by: Snyk bot
Co-authored-by: Laughing-q <1185102784@qq.com>
---
 docs/robots.txt                               |  41 --
 mkdocs.yml                                    | 219 ++++---
 setup.py                                      |   4 +-
 tests/test_python.py                          |   4 +-
 ultralytics/__init__.py                       |   5 +-
 ultralytics/nn/modules.py                     |  35 ++
 ultralytics/vit/__init__.py                   |   1 +
 ultralytics/vit/sam/__init__.py               |   3 +
 ultralytics/vit/sam/amg.py                    | 311 ++++
 ultralytics/vit/sam/autosize.py               |  92 +++
 ultralytics/vit/sam/build.py                  | 121 ++++
 ultralytics/vit/sam/model.py                  |  35 ++
 ultralytics/vit/sam/modules/__init__.py       |   0
 ultralytics/vit/sam/modules/decoders.py       | 161 +++++
 ultralytics/vit/sam/modules/encoders.py       | 582 ++++++++++++++++++
 ultralytics/vit/sam/modules/mask_generator.py | 352 +++++++++++
 .../vit/sam/modules/prompt_predictor.py       | 240 ++++++++
 ultralytics/vit/sam/modules/sam.py            | 169 +++++
 ultralytics/vit/sam/modules/transformer.py    | 233 +++++++
 ultralytics/vit/sam/predict.py                |  52 ++
 ultralytics/yolo/cfg/default.yaml             |   1 -
 ultralytics/yolo/data/__init__.py             |   4 +-
 ultralytics/yolo/data/annotator.py            |  42 ++
 ultralytics/yolo/data/base.py                 |  22 +-
 ultralytics/yolo/data/build.py                | 103 +---
 .../yolo/data/dataloaders/stream_loaders.py   |  66 +-
 ultralytics/yolo/data/dataset.py              |  34 +-
 ultralytics/yolo/engine/model.py              |   2 +-
 ultralytics/yolo/engine/predictor.py          |  48 +-
 ultralytics/yolo/engine/results.py            |   3 +-
 ultralytics/yolo/engine/trainer.py            |   4 +
 ultralytics/yolo/engine/validator.py          |   4 +
 ultralytics/yolo/utils/callbacks/comet.py     |  85 ++-
 ultralytics/yolo/utils/downloads.py           |   3 +-
 ultralytics/yolo/utils/plotting.py            |  70 ++-
 ultralytics/yolo/utils/tal.py                 |   2 +-
 ultralytics/yolo/v8/classify/predict.py       |   8 +-
 ultralytics/yolo/v8/classify/train.py         |  30 +-
 ultralytics/yolo/v8/classify/val.py           |  33 +-
 ultralytics/yolo/v8/detect/predict.py         |   9 +-
 ultralytics/yolo/v8/detect/train.py           |  64 +-
 ultralytics/yolo/v8/detect/val.py             |  48 +-
 ultralytics/yolo/v8/pose/predict.py           |   6 +-
 ultralytics/yolo/v8/segment/predict.py        |   6 +-
 44 files changed, 2916 insertions(+), 441 deletions(-)
 create mode 100644 ultralytics/vit/__init__.py
 create mode 100644 ultralytics/vit/sam/__init__.py
 create mode 100644 ultralytics/vit/sam/amg.py
 create mode 100644 ultralytics/vit/sam/autosize.py
 create mode 100644 ultralytics/vit/sam/build.py
 create mode 100644 ultralytics/vit/sam/model.py
 create mode 100644 ultralytics/vit/sam/modules/__init__.py
 create mode 100644 ultralytics/vit/sam/modules/decoders.py
 create mode 100644 ultralytics/vit/sam/modules/encoders.py
 create mode 100644 ultralytics/vit/sam/modules/mask_generator.py
 create mode 100644 ultralytics/vit/sam/modules/prompt_predictor.py
 create mode 100644 ultralytics/vit/sam/modules/sam.py
 create mode 100644 ultralytics/vit/sam/modules/transformer.py
 create mode 100644 ultralytics/vit/sam/predict.py
 create mode 100644 ultralytics/yolo/data/annotator.py

diff --git a/docs/robots.txt b/docs/robots.txt
index 162c3e1..7d329b1 100644
--- a/docs/robots.txt
+++ b/docs/robots.txt
@@ -1,42 +1 @@
 User-agent: *
-Disallow: /tutorials/pruning-sparsity/
-Disallow: /tutorials/nvidia-jetson/
-Disallow: /tutorials/training-tips-best-results/
-Disallow: /tutorials/hyperparameter-evolution/
-Disallow: /callbacks/
-Disallow: /config/
-Disallow: /tutorials/transfer-learning-froze-layers/
-Disallow: /environments/Docker-Quickstart/
-Disallow: /tutorials/model-ensembling/
-Disallow: /tutorials/test-time-augmentation/
-Disallow: /quick-start/
-Disallow: /FAQ/augmentation/
-Disallow: /environments/AWS-Quickstart/
-Disallow: /tutorials/pytorch-hub/
-Disallow: /tutorials/torchscript-onnx-coreml-export/
-Disallow: /tasks/tracking/
-Disallow: /cfg/
-Disallow: /tasks/detection/
-Disallow: /tutorials/train-custom-datasets/
-Disallow: /cli/
-Disallow: /tasks/classification/
-Disallow: /tutorials/multi-gpu-training/
-Disallow: /engine/
-Disallow: /tasks/segmentation/
-Disallow: /predict/
-Disallow: /python/
-Disallow: /python
-Disallow: /environments/GCP-Quickstart/
-Disallow: /cli
-Disallow: /tutorials/comet-logging/
-Disallow: /cfg
-Disallow: /tutorials/architecture-summary/
-Disallow: /tutorials/clearml-logging/
-Disallow: /sdk/
-Disallow: /tutorials/roboflow/
-Disallow: /tutorials/training-tips-best-results
-Disallow: /package-framework/mock_detector/
-Disallow: /package-framework/
-Disallow: /tutorials/weights-and-biasis-logging/
-Disallow: /tutorials/pruning-sparsity
-Disallow: /tutorials/train-custom-datasets
diff --git a/mkdocs.yml b/mkdocs.yml
index 990f007..f587be8 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -123,6 +123,64 @@ markdown_extensions:
 plugins:
   - mkdocstrings
   - search
+  - redirects:
+      redirect_maps:
+        callbacks.md: usage/callbacks.md
+        cfg.md: usage/cfg.md
+        cli.md: usage/cli.md
+        config.md: usage/cfg.md
+        engine.md: usage/engine.md
+        environments/AWS-Quickstart.md: yolov5/environments/aws_quickstart_tutorial.md
+        environments/Docker-Quickstart.md: yolov5/environments/docker_image_quickstart_tutorial.md
+        environments/GCP-Quickstart.md: yolov5/environments/google_cloud_quickstart_tutorial.md
+        FAQ/augmentation.md: yolov5/tutorials/tips_for_best_training_results.md
+        package-framework.md: index.md
+        package-framework/mock_detector.md: index.md
+        predict.md: modes/predict.md
+        python.md: usage/python.md
+        quick-start.md: quickstart.md
+        reference/base_pred.md: reference/yolo/engine/predictor.md
+        reference/base_trainer.md: reference/yolo/engine/trainer.md
+        reference/exporter.md: reference/yolo/engine/exporter.md
+        reference/model.md: reference/yolo/engine/model.md
+        reference/nn.md: reference/nn/modules.md
+        reference/ops.md: reference/yolo/utils/ops.md
+        reference/results.md: reference/yolo/engine/results.md
+        sdk.md: index.md
+        tasks/classification.md: tasks/classify.md
+        tasks/detection.md: tasks/detect.md
+        tasks/segmentation.md: tasks/segment.md
+        tasks/keypoints.md: tasks/pose.md
+        tasks/tracking.md: modes/track.md
+        tutorials/architecture-summary.md: yolov5/tutorials/architecture_description.md
+        tutorials/clearml-logging.md: yolov5/tutorials/clearml_logging_integration.md
+        tutorials/comet-logging.md: yolov5/tutorials/comet_logging_integration.md
+        tutorials/hyperparameter-evolution.md: yolov5/tutorials/hyperparameter_evolution.md
+        tutorials/model-ensembling.md: yolov5/tutorials/model_ensembling.md
+        tutorials/multi-gpu-training.md: yolov5/tutorials/multi_gpu_training.md
+        tutorials/nvidia-jetson.md: yolov5/tutorials/running_on_jetson_nano.md
+        tutorials/pruning-sparsity.md: yolov5/tutorials/model_pruning_and_sparsity.md
+        tutorials/pytorch-hub.md: yolov5/tutorials/pytorch_hub_model_loading.md
+        tutorials/roboflow.md: 
yolov5/tutorials/roboflow_datasets_integration.md + tutorials/test-time-augmentation.md: yolov5/tutorials/test_time_augmentation.md + tutorials/torchscript-onnx-coreml-export.md: yolov5/tutorials/model_export.md + tutorials/train-custom-datasets.md: yolov5/tutorials/train_custom_data.md + tutorials/training-tips-best-results.md: yolov5/tutorials/tips_for_best_training_results.md + tutorials/transfer-learning-froze-layers.md: yolov5/tutorials/transfer_learning_with_frozen_layers.md + tutorials/weights-and-biasis-logging.md: yolov5/tutorials/comet_logging_integration.md + yolov5/pytorch_hub.md: yolov5/tutorials/pytorch_hub_model_loading.md + yolov5/hyp_evolution.md: yolov5/tutorials/hyperparameter_evolution.md + yolov5/pruning_sparsity.md: yolov5/tutorials/model_pruning_and_sparsity.md + yolov5/comet.md: yolov5/tutorials/comet_logging_integration.md + yolov5/tta.md: yolov5/tutorials/test_time_augmentation.md + yolov5/multi_gpu_training.md: yolov5/tutorials/multi_gpu_training.md + yolov5/ensemble.md: yolov5/tutorials/model_ensembling.md + yolov5/jetson_nano.md: yolov5/tutorials/running_on_jetson_nano.md + yolov5/transfer_learn_frozen.md: yolov5/tutorials/transfer_learning_with_frozen_layers.md + yolov5/neural_magic.md: yolov5/tutorials/neural_magic_pruning_quantization.md + yolov5/train_custom_data.md: yolov5/tutorials/train_custom_data.md + yolov5/architecture.md: yolov5/tutorials/architecture_description.md + yolov5/export.md: yolov5/tutorials/model_export.md # Primary navigation nav: @@ -166,88 +224,87 @@ nav: - Advanced Customization: usage/engine.md - Ultralytics HUB: hub.md - iOS and Android App: app.md - - Reference: - - hub: - - auth: reference/hub/auth.md - - session: reference/hub/session.md - - utils: reference/hub/utils.md - - nn: - - autobackend: reference/nn/autobackend.md - - autoshape: reference/nn/autoshape.md - - modules: reference/nn/modules.md - - tasks: reference/nn/tasks.md - - tracker: - - track: reference/tracker/track.md - - trackers: - - basetrack: reference/tracker/trackers/basetrack.md - - bot_sort: reference/tracker/trackers/bot_sort.md - - byte_tracker: reference/tracker/trackers/byte_tracker.md - - utils: - - gmc: reference/tracker/utils/gmc.md - - kalman_filter: reference/tracker/utils/kalman_filter.md - - matching: reference/tracker/utils/matching.md - - yolo: - - data: - - augment: reference/yolo/data/augment.md - - base: reference/yolo/data/base.md - - build: reference/yolo/data/build.md - - dataloaders: - - stream_loaders: reference/yolo/data/dataloaders/stream_loaders.md - - v5augmentations: reference/yolo/data/dataloaders/v5augmentations.md - - v5loader: reference/yolo/data/dataloaders/v5loader.md - - dataset: reference/yolo/data/dataset.md - - dataset_wrappers: reference/yolo/data/dataset_wrappers.md - - utils: reference/yolo/data/utils.md - - engine: - - exporter: reference/yolo/engine/exporter.md - - model: reference/yolo/engine/model.md - - predictor: reference/yolo/engine/predictor.md - - results: reference/yolo/engine/results.md - - trainer: reference/yolo/engine/trainer.md - - validator: reference/yolo/engine/validator.md - - utils: - - autobatch: reference/yolo/utils/autobatch.md - - benchmarks: reference/yolo/utils/benchmarks.md - - callbacks: - - base: reference/yolo/utils/callbacks/base.md - - clearml: reference/yolo/utils/callbacks/clearml.md - - comet: reference/yolo/utils/callbacks/comet.md - - hub: reference/yolo/utils/callbacks/hub.md - - mlflow: reference/yolo/utils/callbacks/mlflow.md - - neptune: 
reference/yolo/utils/callbacks/neptune.md - - raytune: reference/yolo/utils/callbacks/raytune.md - - tensorboard: reference/yolo/utils/callbacks/tensorboard.md - - wb: reference/yolo/utils/callbacks/wb.md - - checks: reference/yolo/utils/checks.md - - dist: reference/yolo/utils/dist.md - - downloads: reference/yolo/utils/downloads.md - - errors: reference/yolo/utils/errors.md - - files: reference/yolo/utils/files.md - - instance: reference/yolo/utils/instance.md - - loss: reference/yolo/utils/loss.md - - metrics: reference/yolo/utils/metrics.md - - ops: reference/yolo/utils/ops.md - - plotting: reference/yolo/utils/plotting.md - - tal: reference/yolo/utils/tal.md - - torch_utils: reference/yolo/utils/torch_utils.md - - v8: - - classify: - - predict: reference/yolo/v8/classify/predict.md - - train: reference/yolo/v8/classify/train.md - - val: reference/yolo/v8/classify/val.md - - detect: - - predict: reference/yolo/v8/detect/predict.md - - train: reference/yolo/v8/detect/train.md - - val: reference/yolo/v8/detect/val.md - - pose: - - predict: reference/yolo/v8/pose/predict.md - - train: reference/yolo/v8/pose/train.md - - val: reference/yolo/v8/pose/val.md - - segment: - - predict: reference/yolo/v8/segment/predict.md - - train: reference/yolo/v8/segment/train.md - - val: reference/yolo/v8/segment/val.md + - hub: + - auth: reference/hub/auth.md + - session: reference/hub/session.md + - utils: reference/hub/utils.md + - nn: + - autobackend: reference/nn/autobackend.md + - autoshape: reference/nn/autoshape.md + - modules: reference/nn/modules.md + - tasks: reference/nn/tasks.md + - tracker: + - track: reference/tracker/track.md + - trackers: + - basetrack: reference/tracker/trackers/basetrack.md + - bot_sort: reference/tracker/trackers/bot_sort.md + - byte_tracker: reference/tracker/trackers/byte_tracker.md + - utils: + - gmc: reference/tracker/utils/gmc.md + - kalman_filter: reference/tracker/utils/kalman_filter.md + - matching: reference/tracker/utils/matching.md + - yolo: + - data: + - augment: reference/yolo/data/augment.md + - base: reference/yolo/data/base.md + - build: reference/yolo/data/build.md + - dataloaders: + - stream_loaders: reference/yolo/data/dataloaders/stream_loaders.md + - v5augmentations: reference/yolo/data/dataloaders/v5augmentations.md + - v5loader: reference/yolo/data/dataloaders/v5loader.md + - dataset: reference/yolo/data/dataset.md + - dataset_wrappers: reference/yolo/data/dataset_wrappers.md + - utils: reference/yolo/data/utils.md + - engine: + - exporter: reference/yolo/engine/exporter.md + - model: reference/yolo/engine/model.md + - predictor: reference/yolo/engine/predictor.md + - results: reference/yolo/engine/results.md + - trainer: reference/yolo/engine/trainer.md + - validator: reference/yolo/engine/validator.md + - utils: + - autobatch: reference/yolo/utils/autobatch.md + - benchmarks: reference/yolo/utils/benchmarks.md + - callbacks: + - base: reference/yolo/utils/callbacks/base.md + - clearml: reference/yolo/utils/callbacks/clearml.md + - comet: reference/yolo/utils/callbacks/comet.md + - hub: reference/yolo/utils/callbacks/hub.md + - mlflow: reference/yolo/utils/callbacks/mlflow.md + - neptune: reference/yolo/utils/callbacks/neptune.md + - raytune: reference/yolo/utils/callbacks/raytune.md + - tensorboard: reference/yolo/utils/callbacks/tensorboard.md + - wb: reference/yolo/utils/callbacks/wb.md + - checks: reference/yolo/utils/checks.md + - dist: reference/yolo/utils/dist.md + - downloads: reference/yolo/utils/downloads.md + - errors: 
reference/yolo/utils/errors.md + - files: reference/yolo/utils/files.md + - instance: reference/yolo/utils/instance.md + - loss: reference/yolo/utils/loss.md + - metrics: reference/yolo/utils/metrics.md + - ops: reference/yolo/utils/ops.md + - plotting: reference/yolo/utils/plotting.md + - tal: reference/yolo/utils/tal.md + - torch_utils: reference/yolo/utils/torch_utils.md + - v8: + - classify: + - predict: reference/yolo/v8/classify/predict.md + - train: reference/yolo/v8/classify/train.md + - val: reference/yolo/v8/classify/val.md + - detect: + - predict: reference/yolo/v8/detect/predict.md + - train: reference/yolo/v8/detect/train.md + - val: reference/yolo/v8/detect/val.md + - pose: + - predict: reference/yolo/v8/pose/predict.md + - train: reference/yolo/v8/pose/train.md + - val: reference/yolo/v8/pose/val.md + - segment: + - predict: reference/yolo/v8/segment/predict.md + - train: reference/yolo/v8/segment/train.md + - val: reference/yolo/v8/segment/val.md - YOLOv5: - yolov5/index.md diff --git a/setup.py b/setup.py index 04ce73b..0aa77d8 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,9 @@ setup( include_package_data=True, install_requires=REQUIREMENTS + PKG_REQUIREMENTS, extras_require={ - 'dev': ['check-manifest', 'pytest', 'pytest-cov', 'coverage', 'mkdocs-material', 'mkdocstrings[python]'], + 'dev': [ + 'check-manifest', 'pytest', 'pytest-cov', 'coverage', 'mkdocs-material', 'mkdocstrings[python]', + 'mkdocs-redirects'], 'export': ['coremltools>=6.0', 'openvino-dev>=2022.3', 'tensorflowjs'], # automatically installs tensorflow }, classifiers=[ diff --git a/tests/test_python.py b/tests/test_python.py index 23d94da..56bac37 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -185,7 +185,7 @@ def test_workflow(): def test_predict_callback_and_setup(): # test callback addition for prediction def on_predict_batch_end(predictor): # results -> List[batch_size] - path, _, im0s, _, _ = predictor.batch + path, im0s, _, _ = predictor.batch # print('on_predict_batch_end', im0s[0].shape) im0s = im0s if isinstance(im0s, list) else [im0s] bs = [predictor.dataset.bs for _ in range(len(path))] @@ -194,7 +194,7 @@ def test_predict_callback_and_setup(): model = YOLO(MODEL) model.add_callback('on_predict_batch_end', on_predict_batch_end) - dataset = load_inference_source(source=SOURCE, transforms=model.transforms) + dataset = load_inference_source(source=SOURCE) bs = dataset.bs # noqa access predictor properties results = model.predict(dataset, stream=True) # source already setup for _, (result, im0, bs) in enumerate(results): diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index e2e0eff..fc9d002 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,9 +1,10 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = '8.0.88' +__version__ = '8.0.89' from ultralytics.hub import start +from ultralytics.vit.sam import SAM from ultralytics.yolo.engine.model import YOLO from ultralytics.yolo.utils.checks import check_yolo as checks -__all__ = '__version__', 'YOLO', 'checks', 'start' # allow simpler import +__all__ = '__version__', 'YOLO', 'SAM', 'checks', 'start' # allow simpler import diff --git a/ultralytics/nn/modules.py b/ultralytics/nn/modules.py index d913f60..8ad7672 100644 --- a/ultralytics/nn/modules.py +++ b/ultralytics/nn/modules.py @@ -495,6 +495,41 @@ class Detect(nn.Module): b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) +class MLPBlock(nn.Module): + + def __init__( + self, + embedding_dim, + 
mlp_dim, + act=nn.GELU, + ): + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + + def __init__(self, num_channels, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x): + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + class Segment(Detect): """YOLOv8 Segment head for segmentation models.""" diff --git a/ultralytics/vit/__init__.py b/ultralytics/vit/__init__.py new file mode 100644 index 0000000..32cd34f --- /dev/null +++ b/ultralytics/vit/__init__.py @@ -0,0 +1 @@ +from .sam import SAM # noqa diff --git a/ultralytics/vit/sam/__init__.py b/ultralytics/vit/sam/__init__.py new file mode 100644 index 0000000..64d8d05 --- /dev/null +++ b/ultralytics/vit/sam/__init__.py @@ -0,0 +1,3 @@ +from .build import build_sam # noqa +from .model import SAM # noqa +from .modules.prompt_predictor import PromptPredictor # noqa diff --git a/ultralytics/vit/sam/amg.py b/ultralytics/vit/sam/amg.py new file mode 100644 index 0000000..3c70f7c --- /dev/null +++ b/ultralytics/vit/sam/amg.py @@ -0,0 +1,311 @@ +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + +import numpy as np +import torch + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + """Initialize a MaskData object, ensuring all values are supported types.""" + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor)), 'MaskData only supports list, numpy arrays, and torch tensors.' + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + """Set an item in the MaskData object, ensuring it is a supported type.""" + assert isinstance( + item, (list, np.ndarray, torch.Tensor)), 'MaskData only supports list, numpy arrays, and torch tensors.' 
+ self._stats[key] = item + + def __delitem__(self, key: str) -> None: + """Delete an item from the MaskData object.""" + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + """Get an item from the MaskData object.""" + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + """Return an ItemsView of the MaskData object.""" + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + """Filter the MaskData object based on the given boolean tensor.""" + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f'MaskData key {k} has an unsupported type {type(v)}.') + + def cat(self, new_stats: 'MaskData') -> None: + """Concatenate a new MaskData object to the current one.""" + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f'MaskData key {k} has an unsupported type {type(v)}.') + + def to_numpy(self) -> None: + """Convert all torch tensors in the MaskData object to numpy arrays.""" + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() + + +def is_box_near_crop_edge(boxes: torch.Tensor, + crop_box: List[int], + orig_box: List[int], + atol: float = 20.0) -> torch.Tensor: + """Return a boolean tensor indicating if boxes are near the crop edge.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + """Convert bounding boxes from XYXY format to XYWH format.""" + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + """Yield batches of data from the input arguments.""" + assert args and all(len(a) == len(args[0]) for a in args), 'Batched iteration must have same-size inputs.' 
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size:(b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """Encode masks as uncompressed RLEs in the format expected by pycocotools.""" + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat([ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), ]) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({'size': [h, w], 'counts': counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle['size'] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle['counts']: + mask[idx:idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + """Calculate the area of a mask from its uncompressed RLE.""" + return sum(rle['counts'][1::2]) + + +def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ((masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, + dtype=torch.int32)) + unions = ((masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + return np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + + +def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]: + """Generate point grids for all crop layers.""" + return [build_point_grid(int(n_per_side / (scale_per_layer ** i))) for i in range(n_layers + 1)] + + +def generate_crop_boxes(im_size: Tuple[int, ...], n_layers: int, + overlap_ratio: float) -> Tuple[List[List[int]], List[int]]: + """Generates a list of crop boxes of different sizes. 
Each layer has (2**i)**2 boxes for the ith layer.""" + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + """Crops bounding boxes to the size of the input image.""" + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + """Uncrop bounding boxes by adding the crop box offset.""" + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + """Uncrop points by adding the crop box offset.""" + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor: + """Uncrop masks by padding them to the original image size.""" + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]: + """Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator.""" + import cv2 # type: ignore + + assert mode in {'holes', 'islands'} + correct_holes = mode == 'holes' + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if not small_regions: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if not fill_labels: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + """Encode uncompressed RLE (run-length encoding) to COCO RLE format.""" + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle['size'] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle['counts'] = rle['counts'].decode('utf-8') # Necessary to serialize with json + 
return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + masks = masks.flatten(0, -3) if len(shape) > 2 else masks.unsqueeze(0) + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + return out.reshape(*shape[:-2], 4) if len(shape) > 2 else out[0] diff --git a/ultralytics/vit/sam/autosize.py b/ultralytics/vit/sam/autosize.py new file mode 100644 index 0000000..d0a298c --- /dev/null +++ b/ultralytics/vit/sam/autosize.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from copy import deepcopy +from typing import Tuple + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + + +class ResizeLongestSide: + """ + Resizes images to the longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. 
+ """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) + return F.interpolate(image, target_size, mode='bilinear', align_corners=False, antialias=True) + + def apply_coords_torch(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/ultralytics/vit/sam/build.py b/ultralytics/vit/sam/build.py new file mode 100644 index 0000000..98a5b54 --- /dev/null +++ b/ultralytics/vit/sam/build.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from functools import partial + +import torch + +from ...yolo.utils.downloads import attempt_download_asset +from .modules.decoders import MaskDecoder +from .modules.encoders import ImageEncoderViT, PromptEncoder +from .modules.sam import Sam +from .modules.transformer import TwoWayTransformer + + +def build_sam_vit_h(checkpoint=None): + """Build and return a Segment Anything Model (SAM) h-size model.""" + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +def build_sam_vit_l(checkpoint=None): + """Build and return a Segment Anything Model (SAM) l-size model.""" + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + """Build and return a Segment Anything Model (SAM) b-size model.""" + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + """Builds the selected SAM model architecture.""" + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + attempt_download_asset(checkpoint) + with open(checkpoint, 'rb') as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam + + +sam_model_map = { + # "default": build_sam_vit_h, + 'sam_h.pt': build_sam_vit_h, + 'sam_l.pt': build_sam_vit_l, + 'sam_b.pt': build_sam_vit_b, } + + +def build_sam(ckpt='sam_b.pt'): + """Build a SAM model specified by ckpt.""" + model_builder = sam_model_map.get(ckpt) + if not model_builder: + raise FileNotFoundError(f'{ckpt} is not a supported sam model. Available models are: \n {sam_model_map.keys()}') + + return model_builder(ckpt) diff --git a/ultralytics/vit/sam/model.py b/ultralytics/vit/sam/model.py new file mode 100644 index 0000000..684b9fb --- /dev/null +++ b/ultralytics/vit/sam/model.py @@ -0,0 +1,35 @@ +# SAM model interface + +from ultralytics.yolo.cfg import get_cfg + +from .build import build_sam +from .predict import Predictor + + +class SAM: + + def __init__(self, model='sam_b.pt') -> None: + if model and not (model.endswith('.pt') or model.endswith('.pth')): + # Should raise AssertionError instead? 
+ raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint') + self.model = build_sam(model) + self.predictor = None # reuse predictor + + def predict(self, source, stream=False, **kwargs): + """Predicts and returns segmentation masks for given image or video source.""" + overrides = dict(conf=0.25, task='segment', mode='predict') + overrides.update(kwargs) # prefer kwargs + if not self.predictor: + self.predictor = Predictor(overrides=overrides) + self.predictor.setup_model(model=self.model) + else: # only update args if predictor is already setup + self.predictor.args = get_cfg(self.predictor.args, overrides) + return self.predictor(source, stream=stream) + + def train(self, **kwargs): + """Function trains models but raises an error as SAM models do not support training.""" + raise NotImplementedError("SAM models don't support training") + + def val(self, **kwargs): + """Run validation given dataset.""" + raise NotImplementedError("SAM models don't support validation") diff --git a/ultralytics/vit/sam/modules/__init__.py b/ultralytics/vit/sam/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ultralytics/vit/sam/modules/decoders.py b/ultralytics/vit/sam/modules/decoders.py new file mode 100644 index 0000000..47acca8 --- /dev/null +++ b/ultralytics/vit/sam/modules/decoders.py @@ -0,0 +1,161 @@ +from typing import List, Tuple, Type + +import torch +from torch import nn +from torch.nn import functional as F + +from ultralytics.nn.modules import LayerNorm2d + + +class MaskDecoder(nn.Module): + + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. 
+ + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList([ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]) + + self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + mask_slice = slice(1, None) if multimask_output else slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. 
See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [ + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)] + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + """Executes feedforward within the neural network module and applies activation.""" + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/ultralytics/vit/sam/modules/encoders.py b/ultralytics/vit/sam/modules/encoders.py new file mode 100644 index 0000000..4da6155 --- /dev/null +++ b/ultralytics/vit/sam/modules/encoders.py @@ -0,0 +1,582 @@ +from typing import Any, Optional, Tuple, Type + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ultralytics.nn.modules import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. 
+ num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class PromptEncoder(nn.Module): + + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. 
+ """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + return self.mask_downscaling(masks) + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. 
+ + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, + 1).expand(bs, -1, self.image_embedding_size[0], + self.image_embedding_size[1]) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + 'positional_encoding_gaussian_matrix', + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. 
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert (input_size is not None), 'Input size must be provided if using relative positional encoding.' 
+ # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], + hw: Tuple[int, int]) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode='linear', + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. 
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh) + rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw) + + attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view( + B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/ultralytics/vit/sam/modules/mask_generator.py b/ultralytics/vit/sam/modules/mask_generator.py new file mode 100644 index 0000000..ff17fb9 --- /dev/null +++ b/ultralytics/vit/sam/modules/mask_generator.py @@ -0,0 +1,352 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
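For context, a minimal usage sketch of the SamAutomaticMaskGenerator class this new file defines below. It is illustrative only: it assumes a SAM checkpoint loaded via build_sam (added by this patch in ultralytics/vit/sam/build.py) and an HWC uint8 RGB image; the file names are placeholders.

import cv2
from ultralytics.vit.sam import build_sam
from ultralytics.vit.sam.modules.mask_generator import SamAutomaticMaskGenerator

sam = build_sam('sam_b.pt')  # checkpoint name used elsewhere in this patch; path is a placeholder
generator = SamAutomaticMaskGenerator(sam, points_per_side=32, pred_iou_thresh=0.88)
image = cv2.cvtColor(cv2.imread('bus.jpg'), cv2.COLOR_BGR2RGB)  # HWC uint8 RGB, as generate() expects
masks = generator.generate(image)  # list of dicts with 'segmentation', 'bbox', 'area', 'predicted_iou', ...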
+ +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from ..amg import (MaskData, area_from_rle, batch_iterator, batched_mask_to_box, box_xyxy_to_xywh, + build_all_layer_point_grids, calculate_stability_score, coco_encode_rle, generate_crop_boxes, + is_box_near_crop_edge, mask_to_rle_pytorch, remove_small_regions, rle_to_mask, uncrop_boxes_xyxy, + uncrop_masks, uncrop_points) +from .prompt_predictor import PromptPredictor +from .sam import Sam + + +class SamAutomaticMaskGenerator: + + def __init__( + self, + model: Sam, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = 'binary_mask', + ) -> None: + """ + Using a SAM model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM with a ViT-H backbone. + + Arguments: + model (Sam): The SAM model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. 
+ For large resolutions, 'binary_mask' may consume large amounts of + memory. + """ + + assert (points_per_side is None) != (point_grids is + None), 'Exactly one of points_per_side or point_grid must be provided.' + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in {'binary_mask', 'uncompressed_rle', 'coco_rle'}, f'Unknown output_mode {output_mode}.' + if output_mode == 'coco_rle': + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + + if min_mask_region_area > 0: + import cv2 # type: ignore # noqa: F401 + + self.predictor = PromptPredictor(model) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + + # TODO: Temporary implementation for compatibility + def __call__(self, image: np.ndarray, augment=False, visualize=False) -> List[Dict[str, Any]]: + return self.generate(image) + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. 
+ """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Filter small disconnected regions and holes in masks + if self.min_mask_region_area > 0: + mask_data = self.postprocess_small_regions( + mask_data, + self.min_mask_region_area, + max(self.box_nms_thresh, self.crop_nms_thresh), + ) + + # Encode masks + if self.output_mode == 'coco_rle': + mask_data['segmentations'] = [coco_encode_rle(rle) for rle in mask_data['rles']] + elif self.output_mode == 'binary_mask': + mask_data['segmentations'] = [rle_to_mask(rle) for rle in mask_data['rles']] + else: + mask_data['segmentations'] = mask_data['rles'] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data['segmentations'])): + ann = { + 'segmentation': mask_data['segmentations'][idx], + 'area': area_from_rle(mask_data['rles'][idx]), + 'bbox': box_xyxy_to_xywh(mask_data['boxes'][idx]).tolist(), + 'predicted_iou': mask_data['iou_preds'][idx].item(), + 'point_coords': [mask_data['points'][idx].tolist()], + 'stability_score': mask_data['stability_score'][idx].item(), + 'crop_box': box_xyxy_to_xywh(mask_data['crop_boxes'][idx]).tolist(), } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes(orig_size, self.crop_n_layers, self.crop_overlap_ratio) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data['crop_boxes']) + scores = scores.to(data['boxes'].device) + keep_by_nms = batched_nms( + data['boxes'].float(), + scores, + torch.zeros_like(data['boxes'][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points, ) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) + data.cat(batch_data) + del batch_data + self.predictor.reset_image() + + # Remove duplicates within this crop. 
+ keep_by_nms = batched_nms( + data['boxes'].float(), + data['iou_preds'], + torch.zeros_like(data['boxes'][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data['boxes'] = uncrop_boxes_xyxy(data['boxes'], crop_box) + data['points'] = uncrop_points(data['points'], crop_box) + data['crop_boxes'] = torch.tensor([crop_box for _ in range(len(data['rles']))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + transformed_points = self.predictor.transform.apply_coords(points, im_size) + in_points = torch.as_tensor(transformed_points, device=self.predictor.device) + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + masks, iou_preds, _ = self.predictor.predict_torch( + in_points[:, None, :], + in_labels[:, None], + multimask_output=True, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + ) + del masks + + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data['iou_preds'] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate stability score + data['stability_score'] = calculate_stability_score(data['masks'], self.predictor.model.mask_threshold, + self.stability_score_offset) + if self.stability_score_thresh > 0.0: + keep_mask = data['stability_score'] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data['masks'] = data['masks'] > self.predictor.model.mask_threshold + data['boxes'] = batched_mask_to_box(data['masks']) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge(data['boxes'], crop_box, [0, 0, orig_w, orig_h]) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data['masks'] = uncrop_masks(data['masks'], crop_box, orig_h, orig_w) + data['rles'] = mask_to_rle_pytorch(data['masks']) + del data['masks'] + + return data + + @staticmethod + def postprocess_small_regions(mask_data: MaskData, min_area: int, nms_thresh: float) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. 
+ """ + if len(mask_data['rles']) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data['rles']: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode='holes') + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode='islands') + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data['rles'][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data['boxes'][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data diff --git a/ultralytics/vit/sam/modules/prompt_predictor.py b/ultralytics/vit/sam/modules/prompt_predictor.py new file mode 100644 index 0000000..82da985 --- /dev/null +++ b/ultralytics/vit/sam/modules/prompt_predictor.py @@ -0,0 +1,240 @@ +from typing import Optional, Tuple + +import numpy as np +import torch + +from ..autosize import ResizeLongestSide +from .sam import Sam + + +class PromptPredictor: + + def __init__(self, sam_model: Sam) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image(self, image: np.ndarray, image_format: str = 'RGB') -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in {'RGB', 'BGR'}, f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image(self, transformed_image: torch.Tensor, original_image_size: Tuple[int, ...]) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. 
+          original_image_size (tuple(int, int)): The size of the image
+            before transformation, in (H, W) format.
+        """
+        if len(transformed_image.shape) != 4 \
+                or transformed_image.shape[1] != 3 \
+                or max(*transformed_image.shape[2:]) != self.model.image_encoder.img_size:
+            raise ValueError(f'set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}.')
+        self.reset_image()
+
+        self.original_size = original_image_size
+        self.input_size = tuple(transformed_image.shape[-2:])
+        input_image = self.model.preprocess(transformed_image)
+        self.features = self.model.image_encoder(input_image)
+        self.is_image_set = True
+
+    def predict(
+        self,
+        point_coords: Optional[np.ndarray] = None,
+        point_labels: Optional[np.ndarray] = None,
+        box: Optional[np.ndarray] = None,
+        mask_input: Optional[np.ndarray] = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """
+        Predict masks for the given input prompts, using the currently set image.
+
+        Arguments:
+          point_coords (np.ndarray or None): A Nx2 array of point prompts to the
+            model. Each point is in (X,Y) in pixels.
+          point_labels (np.ndarray or None): A length N array of labels for the
+            point prompts. 1 indicates a foreground point and 0 indicates a
+            background point.
+          box (np.ndarray or None): A length 4 array giving a box prompt to the
+            model, in XYXY format.
+          mask_input (np.ndarray): A low resolution mask input to the model, typically
+            coming from a previous prediction iteration. Has form 1xHxW, where
+            for SAM, H=W=256.
+          multimask_output (bool): If true, the model will return three masks.
+            For ambiguous input prompts (such as a single click), this will often
+            produce better masks than a single prediction. If only a single
+            mask is needed, the model's predicted quality score can be used
+            to select the best mask. For non-ambiguous prompts, such as multiple
+            input prompts, multimask_output=False can give better results.
+          return_logits (bool): If true, returns un-thresholded mask logits
+            instead of a binary mask.
+
+        Returns:
+          (np.ndarray): The output masks in CxHxW format, where C is the
+            number of masks, and (H, W) is the original image size.
+          (np.ndarray): An array of length C containing the model's
+            predictions for the quality of each mask.
+          (np.ndarray): An array of shape CxHxW, where C is the number
+            of masks and H=W=256. These low resolution logits can be passed to
+            a subsequent iteration as mask input.
+        """
+        if not self.is_image_set:
+            raise RuntimeError('An image must be set with .set_image(...) before mask prediction.')
+
+        # Transform input prompts
+        coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
+        if point_coords is not None:
+            assert (point_labels is not None), 'point_labels must be supplied if point_coords is supplied.'
+ point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError('An image must be set with .set_image(...) 
before mask prediction.') + + points = (point_coords, point_labels) if point_coords is not None else None + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError('An image must be set with .set_image(...) to generate an embedding.') + assert self.features is not None, 'Features must exist if an image has been set.' + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/ultralytics/vit/sam/modules/sam.py b/ultralytics/vit/sam/modules/sam.py new file mode 100644 index 0000000..50a30ee --- /dev/null +++ b/ultralytics/vit/sam/modules/sam.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Tuple + +import torch +from torch import nn +from torch.nn import functional as F + +from .decoders import MaskDecoder +from .encoders import ImageEncoderViT, PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = 'RGB' + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. 
+ """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + @torch.no_grad() + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input prompts, + C is determined by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. 
+ """ + input_images = torch.stack([self.preprocess(x['image']) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if 'point_coords' in image_record: + points = (image_record['point_coords'], image_record['point_labels']) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get('boxes', None), + masks=image_record.get('mask_inputs', None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record['image'].shape[-2:], + original_size=image_record['original_size'], + ) + masks = masks > self.mask_threshold + outputs.append({ + 'masks': masks, + 'iou_predictions': iou_predictions, + 'low_res_logits': low_res_masks, }) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode='bilinear', + align_corners=False, + ) + masks = masks[..., :input_size[0], :input_size[1]] + masks = F.interpolate(masks, original_size, mode='bilinear', align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + return F.pad(x, (0, padw, 0, padh)) diff --git a/ultralytics/vit/sam/modules/transformer.py b/ultralytics/vit/sam/modules/transformer.py new file mode 100644 index 0000000..3f32b94 --- /dev/null +++ b/ultralytics/vit/sam/modules/transformer.py @@ -0,0 +1,233 @@ +import math +from typing import Tuple, Type + +import torch +from torch import Tensor, nn + +from ultralytics.nn.modules import MLPBlock + + +class TwoWayTransformer(nn.Module): + + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. 
Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + )) + + self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. 
+ + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]: + """Apply self-attention and cross-attention to queries and keys and return the processed embeddings.""" + + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, 'num_heads must divide embedding_dim.' 
+ + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + """Separate the input tensor into the specified number of attention heads.""" + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + """Recombine the separated attention heads into a single tensor.""" + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + """Compute the attention output given the input query, key, and value tensors.""" + + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/ultralytics/vit/sam/predict.py b/ultralytics/vit/sam/predict.py new file mode 100644 index 0000000..5bbccac --- /dev/null +++ b/ultralytics/vit/sam/predict.py @@ -0,0 +1,52 @@ +import numpy as np +import torch + +from ultralytics.yolo.engine.predictor import BasePredictor +from ultralytics.yolo.engine.results import Results +from ultralytics.yolo.utils.torch_utils import select_device + +from .modules.mask_generator import SamAutomaticMaskGenerator + + +class Predictor(BasePredictor): + + def preprocess(self, im): + """Prepares input image for inference.""" + # TODO: Only support bs=1 for now + # im = ResizeLongestSide(1024).apply_image(im[0]) + # im = torch.as_tensor(im, device=self.device) + # im = im.permute(2, 0, 1).contiguous()[None, :, :, :] + return im[0] + + def setup_model(self, model): + """Set up YOLO model with specified thresholds and device.""" + device = select_device(self.args.device) + model.eval() + self.model = SamAutomaticMaskGenerator(model.to(device), + pred_iou_thresh=self.args.conf, + box_nms_thresh=self.args.iou) + self.device = device + # TODO: Temporary settings for compatibility + self.model.pt = False + self.model.triton = False + self.model.stride = 32 + self.model.fp16 = False + self.done_warmup = True + + def postprocess(self, preds, path, orig_imgs): + """Postprocesses inference output predictions to create detection masks for objects.""" + names = dict(enumerate(list(range(len(preds))))) + results = [] + # TODO + for i, pred in enumerate([preds]): + masks = torch.from_numpy(np.stack([p['segmentation'] for p in pred], axis=0)) + orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs + path = self.batch[0] + img_path = path[i] if isinstance(path, list) else path + results.append(Results(orig_img=orig_img, path=img_path, names=names, masks=masks)) + return results + + # def __call__(self, source=None, model=None, stream=False): + # frame = cv2.imread(source) + # preds = self.model.generate(frame) + # return self.postprocess(preds, source, frame) diff --git 
a/ultralytics/yolo/cfg/default.yaml b/ultralytics/yolo/cfg/default.yaml index 847fe8b..1c46ac4 100644 --- a/ultralytics/yolo/cfg/default.yaml +++ b/ultralytics/yolo/cfg/default.yaml @@ -25,7 +25,6 @@ verbose: True # whether to print verbose output seed: 0 # random seed for reproducibility deterministic: True # whether to enable deterministic mode single_cls: False # train multi-class data as single-class -image_weights: False # use weighted image selection for training rect: False # rectangular training if mode='train' or rectangular validation if mode='val' cos_lr: False # use cosine learning rate scheduler close_mosaic: 0 # (int) disable mosaic augmentation for final epochs diff --git a/ultralytics/yolo/data/__init__.py b/ultralytics/yolo/data/__init__.py index 539c20c..f1d9dee 100644 --- a/ultralytics/yolo/data/__init__.py +++ b/ultralytics/yolo/data/__init__.py @@ -1,9 +1,9 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license from .base import BaseDataset -from .build import build_classification_dataloader, build_dataloader, load_inference_source +from .build import build_dataloader, build_yolo_dataset, load_inference_source from .dataset import ClassificationDataset, SemanticDataset, YOLODataset from .dataset_wrappers import MixAndRectDataset __all__ = ('BaseDataset', 'ClassificationDataset', 'MixAndRectDataset', 'SemanticDataset', 'YOLODataset', - 'build_classification_dataloader', 'build_dataloader', 'load_inference_source') + 'build_yolo_dataset', 'build_dataloader', 'load_inference_source') diff --git a/ultralytics/yolo/data/annotator.py b/ultralytics/yolo/data/annotator.py new file mode 100644 index 0000000..ec52194 --- /dev/null +++ b/ultralytics/yolo/data/annotator.py @@ -0,0 +1,42 @@ +from pathlib import Path + +from ultralytics import YOLO +from ultralytics.vit.sam import PromptPredictor, build_sam +from ultralytics.yolo.utils.torch_utils import select_device + + +def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None): + device = select_device(device) + det_model = YOLO(det_model) + sam_model = build_sam(sam_model) + det_model.to(device) + sam_model.to(device) + + if not output_dir: + output_dir = Path(str(data)).parent / 'labels' + Path(output_dir).mkdir(exist_ok=True, parents=True) + + prompt_predictor = PromptPredictor(sam_model) + det_results = det_model(data, stream=True) + + for result in det_results: + boxes = result.boxes.xyxy # Boxes object for bbox outputs + class_ids = result.boxes.cls.int().tolist() # noqa + prompt_predictor.set_image(result.orig_img) + masks, _, _ = prompt_predictor.predict_torch( + point_coords=None, + point_labels=None, + boxes=prompt_predictor.transform.apply_boxes_torch(boxes, result.orig_shape[:2]), + multimask_output=False, + ) + + result.update(masks=masks.squeeze(1)) + segments = result.masks.xyn # noqa + + with open(str(Path(output_dir) / Path(result.path).stem) + '.txt', 'w') as f: + for i in range(len(segments)): + s = segments[i] + if len(s) == 0: + continue + segment = map(str, segments[i].reshape(-1).tolist()) + f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n') diff --git a/ultralytics/yolo/data/base.py b/ultralytics/yolo/data/base.py index d80a232..ae61667 100644 --- a/ultralytics/yolo/data/base.py +++ b/ultralytics/yolo/data/base.py @@ -24,17 +24,17 @@ class BaseDataset(Dataset): Base dataset class for loading and processing image data. Args: - img_path (str): Image path. - imgsz (int): Target image size for resizing. Default is 640. 
- cache (bool): Cache images in memory or on disk for faster loading. Default is False. - augment (bool): Apply data augmentation. Default is True. - hyp (dict): Dictionary of hyperparameters for data augmentation. Default is None. - prefix (str): Prefix for file paths. Default is an empty string. - rect (bool): Enable rectangular training. Default is False. - batch_size (int): Batch size for rectangular training. Default is None. - stride (int): Stride for rectangular training. Default is 32. - pad (float): Padding for rectangular training. Default is 0.5. - single_cls (bool): Use a single class for all labels. Default is False. + img_path (str): Path to the folder containing images. + imgsz (int, optional): Image size. Defaults to 640. + cache (bool, optional): Cache images to RAM or disk during training. Defaults to False. + augment (bool, optional): If True, data augmentation is applied. Defaults to True. + hyp (dict, optional): Hyperparameters to apply data augmentation. Defaults to None. + prefix (str, optional): Prefix to print in log messages. Defaults to ''. + rect (bool, optional): If True, rectangular training is used. Defaults to False. + batch_size (int, optional): Size of batches. Defaults to None. + stride (int, optional): Stride. Defaults to 32. + pad (float, optional): Padding. Defaults to 0.0. + single_cls (bool, optional): If True, single class training is used. Defaults to False. classes (list): List of included classes. Default is None. Attributes: diff --git a/ultralytics/yolo/data/build.py b/ultralytics/yolo/data/build.py index 919996a..c5f5e2a 100644 --- a/ultralytics/yolo/data/build.py +++ b/ultralytics/yolo/data/build.py @@ -14,9 +14,8 @@ from ultralytics.yolo.data.dataloaders.stream_loaders import (LOADERS, LoadImage from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS from ultralytics.yolo.utils.checks import check_file -from ..utils import LOGGER, RANK, colorstr -from ..utils.torch_utils import torch_distributed_zero_first -from .dataset import ClassificationDataset, YOLODataset +from ..utils import RANK, colorstr +from .dataset import YOLODataset from .utils import PIN_MEMORY @@ -70,34 +69,31 @@ def seed_worker(worker_id): # noqa random.seed(worker_seed) -def build_dataloader(cfg, batch, img_path, data_info, stride=32, rect=False, rank=-1, mode='train'): - """Return an InfiniteDataLoader or DataLoader for training or validation set.""" - assert mode in ['train', 'val'] - shuffle = mode == 'train' - if cfg.rect and shuffle: - LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False") - shuffle = False - with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP - dataset = YOLODataset( - img_path=img_path, - imgsz=cfg.imgsz, - batch_size=batch, - augment=mode == 'train', # augmentation - hyp=cfg, # TODO: probably add a get_hyps_from_cfg function - rect=cfg.rect or rect, # rectangular batches - cache=cfg.cache or None, - single_cls=cfg.single_cls or False, - stride=int(stride), - pad=0.0 if mode == 'train' else 0.5, - prefix=colorstr(f'{mode}: '), - use_segments=cfg.task == 'segment', - use_keypoints=cfg.task == 'pose', - classes=cfg.classes, - data=data_info) +def build_yolo_dataset(cfg, img_path, batch, data_info, mode='train', rect=False, stride=32): + """Build YOLO Dataset""" + dataset = YOLODataset( + img_path=img_path, + imgsz=cfg.imgsz, + batch_size=batch, + augment=mode == 'train', # augmentation + hyp=cfg, # TODO: probably add a get_hyps_from_cfg function + rect=cfg.rect or 
rect, # rectangular batches + cache=cfg.cache or None, + single_cls=cfg.single_cls or False, + stride=int(stride), + pad=0.0 if mode == 'train' else 0.5, + prefix=colorstr(f'{mode}: '), + use_segments=cfg.task == 'segment', + use_keypoints=cfg.task == 'pose', + classes=cfg.classes, + data=data_info) + return dataset + +def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1): + """Return an InfiniteDataLoader or DataLoader for training or validation set.""" batch = min(batch, len(dataset)) nd = torch.cuda.device_count() # number of CUDA devices - workers = cfg.workers if mode == 'train' else cfg.workers * 2 nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers]) # number of workers sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) generator = torch.Generator() @@ -110,36 +106,7 @@ def build_dataloader(cfg, batch, img_path, data_info, stride=32, rect=False, ran pin_memory=PIN_MEMORY, collate_fn=getattr(dataset, 'collate_fn', None), worker_init_fn=seed_worker, - generator=generator), dataset - - -# Build classification -# TODO: using cfg like `build_dataloader` -def build_classification_dataloader(path, - imgsz=224, - batch_size=16, - augment=True, - cache=False, - rank=-1, - workers=8, - shuffle=True): - """Returns Dataloader object to be used with YOLOv5 Classifier.""" - with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP - dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache) - batch_size = min(batch_size, len(dataset)) - nd = torch.cuda.device_count() - nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) - sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) - generator = torch.Generator() - generator.manual_seed(6148914691236517205 + RANK) - return InfiniteDataLoader(dataset, - batch_size=batch_size, - shuffle=shuffle and sampler is None, - num_workers=nw, - sampler=sampler, - pin_memory=PIN_MEMORY, - worker_init_fn=seed_worker, - generator=generator) # or DataLoader(persistent_workers=True) + generator=generator) def check_source(source): @@ -168,7 +135,7 @@ def check_source(source): return source, webcam, screenshot, from_img, in_memory, tensor -def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1, stride=32, auto=True): +def load_inference_source(source=None, imgsz=640, vid_stride=1): """ Loads an inference source for object detection and applies necessary transformations. 
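Before the next hunk, a note on the refactored entry points shown above in this file: dataset construction is now split from dataloader construction. A hedged sketch of the new calling pattern follows; the cfg namespace, data_info dict and image path are stand-ins, not code from this patch.

from ultralytics.yolo.data import build_dataloader, build_yolo_dataset

dataset = build_yolo_dataset(cfg, img_path='coco128/images/train2017', batch=16, data_info=data_info, mode='train')
loader = build_dataloader(dataset, batch=16, workers=cfg.workers, shuffle=True, rank=-1)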
@@ -192,23 +159,13 @@ def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1, elif in_memory: dataset = source elif webcam: - dataset = LoadStreams(source, - imgsz=imgsz, - stride=stride, - auto=auto, - transforms=transforms, - vid_stride=vid_stride) + dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride) elif screenshot: - dataset = LoadScreenshots(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms) + dataset = LoadScreenshots(source, imgsz=imgsz) elif from_img: - dataset = LoadPilAndNumpy(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms) + dataset = LoadPilAndNumpy(source, imgsz=imgsz) else: - dataset = LoadImages(source, - imgsz=imgsz, - stride=stride, - auto=auto, - transforms=transforms, - vid_stride=vid_stride) + dataset = LoadImages(source, imgsz=imgsz, vid_stride=vid_stride) # Attach source types to the dataset setattr(dataset, 'source_type', source_type) diff --git a/ultralytics/yolo/data/dataloaders/stream_loaders.py b/ultralytics/yolo/data/dataloaders/stream_loaders.py index 26d3211..d6aca45 100644 --- a/ultralytics/yolo/data/dataloaders/stream_loaders.py +++ b/ultralytics/yolo/data/dataloaders/stream_loaders.py @@ -15,7 +15,6 @@ import requests import torch from PIL import Image -from ultralytics.yolo.data.augment import LetterBox from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS from ultralytics.yolo.utils import LOGGER, ROOT, is_colab, is_kaggle, ops from ultralytics.yolo.utils.checks import check_requirements @@ -31,12 +30,11 @@ class SourceTypes: class LoadStreams: # YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams` - def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1): + def __init__(self, sources='file.streams', imgsz=640, vid_stride=1): """Initialize instance variables and check for consistent input stream shapes.""" torch.backends.cudnn.benchmark = True # faster for fixed-size inference self.mode = 'stream' self.imgsz = imgsz - self.stride = stride self.vid_stride = vid_stride # video frame-rate stride sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources] n = len(sources) @@ -72,10 +70,6 @@ class LoadStreams: LOGGER.info('') # newline # Check for common shapes - s = np.stack([LetterBox(imgsz, auto, stride=stride)(image=x).shape for x in self.imgs]) - self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal - self.auto = auto and self.rect - self.transforms = transforms # optional self.bs = self.__len__() if not self.rect: @@ -110,14 +104,7 @@ class LoadStreams: raise StopIteration im0 = self.imgs.copy() - if self.transforms: - im = np.stack([self.transforms(x) for x in im0]) # transforms - else: - im = np.stack([LetterBox(self.imgsz, self.auto, stride=self.stride)(image=x) for x in im0]) - im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW - im = np.ascontiguousarray(im) # contiguous - - return self.sources, im, im0, None, '' + return self.sources, im0, None, '' def __len__(self): """Return the length of the sources object.""" @@ -126,7 +113,7 @@ class LoadStreams: class LoadScreenshots: # YOLOv8 screenshot dataloader, i.e. 
`yolo predict source=screen` - def __init__(self, source, imgsz=640, stride=32, auto=True, transforms=None): + def __init__(self, source, imgsz=640): """source = [screen_number left top width height] (pixels).""" check_requirements('mss') import mss # noqa @@ -140,9 +127,6 @@ class LoadScreenshots: elif len(params) == 5: self.screen, left, top, width, height = (int(x) for x in params) self.imgsz = imgsz - self.stride = stride - self.transforms = transforms - self.auto = auto self.mode = 'stream' self.frame = 0 self.sct = mss.mss() @@ -165,19 +149,13 @@ class LoadScreenshots: im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: ' - if self.transforms: - im = self.transforms(im0) # transforms - else: - im = LetterBox(self.imgsz, self.auto, stride=self.stride)(image=im0) - im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB - im = np.ascontiguousarray(im) # contiguous self.frame += 1 - return str(self.screen), im, im0, None, s # screen, img, original img, im0s, s + return str(self.screen), im0, None, s # screen, img, original img, im0s, s class LoadImages: # YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4` - def __init__(self, path, imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1): + def __init__(self, path, imgsz=640, vid_stride=1): """Initialize the Dataloader and raise FileNotFoundError if file not found.""" if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line path = Path(path).read_text().rsplit() @@ -198,13 +176,10 @@ class LoadImages: ni, nv = len(images), len(videos) self.imgsz = imgsz - self.stride = stride self.files = images + videos self.nf = ni + nv # number of files self.video_flag = [False] * ni + [True] * nv self.mode = 'image' - self.auto = auto - self.transforms = transforms # optional self.vid_stride = vid_stride # video frame-rate stride self.bs = 1 if any(videos): @@ -254,14 +229,7 @@ class LoadImages: raise FileNotFoundError(f'Image Not Found {path}') s = f'image {self.count}/{self.nf} {path}: ' - if self.transforms: - im = self.transforms(im0) # transforms - else: - im = LetterBox(self.imgsz, self.auto, stride=self.stride)(image=im0) - im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB - im = np.ascontiguousarray(im) # contiguous - - return path, im, im0, self.cap, s + return [path], [im0], self.cap, s def _new_video(self, path): """Create a new video capture object.""" @@ -290,16 +258,13 @@ class LoadImages: class LoadPilAndNumpy: - def __init__(self, im0, imgsz=640, stride=32, auto=True, transforms=None): + def __init__(self, im0, imgsz=640): """Initialize PIL and Numpy Dataloader.""" if not isinstance(im0, list): im0 = [im0] self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)] self.im0 = [self._single_check(im) for im in im0] self.imgsz = imgsz - self.stride = stride - self.auto = auto - self.transforms = transforms self.mode = 'image' # Generate fake paths self.bs = len(self.im0) @@ -315,16 +280,6 @@ class LoadPilAndNumpy: im = np.ascontiguousarray(im) # contiguous return im - def _single_preprocess(self, im, auto): - """Preprocesses a single image for inference.""" - if self.transforms: - im = self.transforms(im) # transforms - else: - im = LetterBox(self.imgsz, auto=auto, stride=self.stride)(image=im) - im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB - im = np.ascontiguousarray(im) # contiguous - return im - 
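# With the simplification above, the loaders yield raw BGR numpy frames plus a log
# string; letterboxing, RGB conversion and normalization now live in
# BasePredictor.preprocess, added later in this patch. Minimal sketch (the image
# path is illustrative):
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages

for paths, im0s, vid_cap, s in LoadImages('bus.jpg', imgsz=640):
    im0 = im0s[0]  # raw HWC BGR array; no letterbox/transpose applied by the loader
    print(s, im0.shape)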
def __len__(self): """Returns the length of the 'im0' attribute.""" return len(self.im0) @@ -333,11 +288,8 @@ class LoadPilAndNumpy: """Returns batch paths, images, processed images, None, ''.""" if self.count == 1: # loop only once as it's batch inference raise StopIteration - auto = all(x.shape == self.im0[0].shape for x in self.im0) and self.auto - im = [self._single_preprocess(im, auto) for im in self.im0] - im = np.stack(im, 0) if len(im) > 1 else im[0][None] self.count += 1 - return self.paths, im, self.im0, None, '' + return self.paths, self.im0, None, '' def __iter__(self): """Enables iteration for class LoadPilAndNumpy.""" @@ -362,7 +314,7 @@ class LoadTensor: if self.count == 1: raise StopIteration self.count += 1 - return None, self.im0, self.im0, None, '' # self.paths, im, self.im0, None, '' + return None, self.im0, None, '' # self.paths, im, self.im0, None, '' def __len__(self): """Returns the batch size.""" diff --git a/ultralytics/yolo/data/dataset.py b/ultralytics/yolo/data/dataset.py index 3c8c43f..6d6275a 100644 --- a/ultralytics/yolo/data/dataset.py +++ b/ultralytics/yolo/data/dataset.py @@ -21,21 +21,9 @@ class YOLODataset(BaseDataset): Dataset class for loading object detection and/or segmentation labels in YOLO format. Args: - img_path (str): Path to the folder containing images. - imgsz (int, optional): Image size. Defaults to 640. - cache (bool, optional): Cache images to RAM or disk during training. Defaults to False. - augment (bool, optional): If True, data augmentation is applied. Defaults to True. - hyp (dict, optional): Hyperparameters to apply data augmentation. Defaults to None. - prefix (str, optional): Prefix to print in log messages. Defaults to ''. - rect (bool, optional): If True, rectangular training is used. Defaults to False. - batch_size (int, optional): Size of batches. Defaults to None. - stride (int, optional): Stride. Defaults to 32. - pad (float, optional): Padding. Defaults to 0.0. - single_cls (bool, optional): If True, single class training is used. Defaults to False. + data (dict, optional): A dataset YAML dictionary. Defaults to None. use_segments (bool, optional): If True, segmentation masks are used as labels. Defaults to False. use_keypoints (bool, optional): If True, keypoints are used as labels. Defaults to False. - data (dict, optional): A dataset YAML dictionary. Defaults to None. - classes (list): List of included classes. Default is None. Returns: (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model. @@ -43,28 +31,12 @@ class YOLODataset(BaseDataset): cache_version = '1.0.2' # dataset labels *.cache version, >= 1.0.0 for YOLOv8 rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4] - def __init__(self, - img_path, - imgsz=640, - cache=False, - augment=True, - hyp=None, - prefix='', - rect=False, - batch_size=None, - stride=32, - pad=0.0, - single_cls=False, - use_segments=False, - use_keypoints=False, - data=None, - classes=None): + def __init__(self, *args, data=None, use_segments=False, use_keypoints=False, **kwargs): self.use_segments = use_segments self.use_keypoints = use_keypoints self.data = data assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.' 
- super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls, - classes) + super().__init__(*args, **kwargs) def cache_labels(self, path=Path('./labels.cache')): """Cache dataset labels, check images and read shapes. diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index efa938d..c4d811b 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -453,7 +453,7 @@ class YOLO: reduction_factor=3) # Define the callbacks for the hyperparameter search - tuner_callbacks = [WandbLoggerCallback(project='yolov8_tune') if wandb else None] + tuner_callbacks = [WandbLoggerCallback(project='yolov8_tune')] if wandb else [] # Create the Ray Tune hyperparameter search tuner tuner = tune.Tuner(trainable_with_resources, diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index 79b31c3..f27a8f4 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -31,11 +31,13 @@ import platform from pathlib import Path import cv2 +import numpy as np +import torch from ultralytics.nn.autobackend import AutoBackend from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.data import load_inference_source -from ultralytics.yolo.data.augment import classify_transforms +from ultralytics.yolo.data.augment import LetterBox, classify_transforms from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops from ultralytics.yolo.utils.checks import check_imgsz, check_imshow from ultralytics.yolo.utils.files import increment_path @@ -106,9 +108,23 @@ class BasePredictor: self.callbacks = _callbacks or callbacks.get_default_callbacks() callbacks.add_integration_callbacks(self) - def preprocess(self, img): - """Prepares input image before inference.""" - pass + def preprocess(self, im): + """Prepares input image before inference. + + Args: + im (torch.Tensor | List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list. 
+ """ + if not isinstance(im, torch.Tensor): + auto = all(x.shape == im[0].shape for x in im) and self.model.pt + im = np.stack([LetterBox(self.imgsz, auto=auto, stride=self.model.stride)(image=x) for x in im]) + im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w) + im = np.ascontiguousarray(im) # contiguous + im = torch.from_numpy(im) + # NOTE: assuming im with (b, 3, h, w) if it's a tensor + img = im.to(self.device) + img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 + img /= 255 # 0 - 255 to 0.0 - 1.0 + return img def write_results(self, idx, results, batch): """Write inference results to a file or directory.""" @@ -165,16 +181,9 @@ class BasePredictor: def setup_source(self, source): """Sets up source and inference mode.""" self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size - if self.args.task == 'classify': - transforms = getattr(self.model.model, 'transforms', classify_transforms(self.imgsz[0])) - else: # predict, segment - transforms = None - self.dataset = load_inference_source(source=source, - transforms=transforms, - imgsz=self.imgsz, - vid_stride=self.args.vid_stride, - stride=self.model.stride, - auto=self.model.pt) + self.transforms = getattr(self.model.model, 'transforms', classify_transforms( + self.imgsz[0])) if self.args.task == 'classify' else None + self.dataset = load_inference_source(source=source, imgsz=self.imgsz, vid_stride=self.args.vid_stride) self.source_type = self.dataset.source_type if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or # streams len(self.dataset) > 1000 or # images @@ -207,14 +216,12 @@ class BasePredictor: for batch in self.dataset: self.run_callbacks('on_predict_batch_start') self.batch = batch - path, im, im0s, vid_cap, s = batch + path, im0s, vid_cap, s = batch visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False # Preprocess with self.dt[0]: - im = self.preprocess(im) - if len(im.shape) == 3: - im = im[None] # expand for batch dim + im = self.preprocess(im0s) # Inference with self.dt[1]: @@ -226,7 +233,7 @@ class BasePredictor: self.run_callbacks('on_predict_postprocess_end') # Visualize, save, write results - n = len(im) + n = len(im0s) for i in range(n): self.results[i].speed = { 'preprocess': self.dt[0].dt * 1E3 / n, @@ -234,8 +241,7 @@ class BasePredictor: 'postprocess': self.dt[2].dt * 1E3 / n} if self.source_type.tensor: # skip write, show and plot operations if input is raw tensor continue - p, im0 = (path[i], im0s[i].copy()) if self.source_type.webcam or self.source_type.from_img \ - else (path, im0s.copy()) + p, im0 = path[i], im0s[i].copy() p = Path(p) if self.args.verbose or self.args.save or self.args.save_txt or self.args.show: diff --git a/ultralytics/yolo/engine/results.py b/ultralytics/yolo/engine/results.py index f066775..c6a3787 100644 --- a/ultralytics/yolo/engine/results.py +++ b/ultralytics/yolo/engine/results.py @@ -213,7 +213,8 @@ class Results(SimpleClass): img = LetterBox(pred_masks.shape[1:])(image=annotator.result()) img_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute( 2, 0, 1).flip(0).contiguous() / 255 - annotator.masks(pred_masks.data, colors=[colors(x, True) for x in pred_boxes.cls], im_gpu=img_gpu) + idx = pred_boxes.cls if pred_boxes else range(len(pred_masks)) + annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=img_gpu) if pred_boxes and show_boxes: for d in 
reversed(pred_boxes): diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index 6a08e4d..557e8be 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -481,6 +481,10 @@ class BaseTrainer: """ raise NotImplementedError('get_dataloader function not implemented in trainer') + def build_dataset(self, img_path, mode='train', batch=None): + """Build dataset""" + raise NotImplementedError('build_dataset function not implemented in trainer') + def criterion(self, preds, batch): """ Returns loss and individual loss items as Tensor. diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index 96c4086..3d57e90 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -207,6 +207,10 @@ class BaseValidator: """Get data loader from dataset path and batch size.""" raise NotImplementedError('get_dataloader function not implemented for this validator') + def build_dataset(self, img_path): + """Build dataset""" + raise NotImplementedError('build_dataset function not implemented in validator') + def preprocess(self, batch): """Preprocesses an input batch.""" return batch diff --git a/ultralytics/yolo/utils/callbacks/comet.py b/ultralytics/yolo/utils/callbacks/comet.py index 2d55df1..f35eed2 100644 --- a/ultralytics/yolo/utils/callbacks/comet.py +++ b/ultralytics/yolo/utils/callbacks/comet.py @@ -13,20 +13,8 @@ try: except (ImportError, AssertionError): comet_ml = None -COMET_MODE = os.getenv('COMET_MODE', 'online') -COMET_MODEL_NAME = os.getenv('COMET_MODEL_NAME', 'YOLOv8') -# Determines how many batches of image predictions to log from the validation set -COMET_EVAL_BATCH_LOGGING_INTERVAL = int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1)) -# Determines whether to log confusion matrix every evaluation epoch -COMET_EVAL_LOG_CONFUSION_MATRIX = (os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'true').lower() == 'true') -# Determines whether to log image predictions every evaluation epoch -COMET_EVAL_LOG_IMAGE_PREDICTIONS = (os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true') -COMET_MAX_IMAGE_PREDICTIONS = int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100)) - # Ensures certain logging functions only run for supported tasks COMET_SUPPORTED_TASKS = ['detect'] -# Scales reported confidence scores (0.0-1.0) by this value -COMET_MAX_CONFIDENCE_SCORE = int(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100)) # Names of plots created by YOLOv8 that are logged to Comet EVALUATION_PLOT_NAMES = 'F1_curve', 'P_curve', 'R_curve', 'PR_curve', 'confusion_matrix' @@ -35,6 +23,35 @@ LABEL_PLOT_NAMES = 'labels', 'labels_correlogram' _comet_image_prediction_count = 0 +def _get_comet_mode(): + return os.getenv('COMET_MODE', 'online') + + +def _get_comet_model_name(): + return os.getenv('COMET_MODEL_NAME', 'YOLOv8') + + +def _get_eval_batch_logging_interval(): + return int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1)) + + +def _get_max_image_predictions_to_log(): + return int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100)) + + +def _scale_confidence_score(score): + scale = float(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100.0)) + return score * scale + + +def _should_log_confusion_matrix(): + return os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'true').lower() == 'true' + + +def _should_log_image_predictions(): + return os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true' + + def _get_experiment_type(mode, project_name): """Return an experiment based on 
mode and project name.""" if mode == 'offline': @@ -48,13 +65,14 @@ def _create_experiment(args): if RANK not in (-1, 0): return try: - experiment = _get_experiment_type(COMET_MODE, args.project) + comet_mode = _get_comet_mode() + experiment = _get_experiment_type(comet_mode, args.project) experiment.log_parameters(vars(args)) experiment.log_others({ - 'eval_batch_logging_interval': COMET_EVAL_BATCH_LOGGING_INTERVAL, - 'log_confusion_matrix': COMET_EVAL_LOG_CONFUSION_MATRIX, - 'log_image_predictions': COMET_EVAL_LOG_IMAGE_PREDICTIONS, - 'max_image_predictions': COMET_MAX_IMAGE_PREDICTIONS, }) + 'eval_batch_logging_interval': _get_eval_batch_logging_interval(), + 'log_confusion_matrix': _should_log_confusion_matrix(), + 'log_image_predictions': _should_log_image_predictions(), + 'max_image_predictions': _get_max_image_predictions_to_log(), }) experiment.log_other('Created from', 'yolov8') except Exception as e: @@ -74,7 +92,12 @@ def _fetch_trainer_metadata(trainer): save_interval = curr_epoch % save_period == 0 save_assets = save and save_period > 0 and save_interval and not final_epoch - return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch) + return dict( + curr_epoch=curr_epoch, + curr_step=curr_step, + save_assets=save_assets, + final_epoch=final_epoch, + ) def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad): @@ -117,7 +140,10 @@ def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, c data = [] for box, label in zip(bboxes, cls_labels): box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad) - data.append({'boxes': [box], 'label': f'gt_{label}', 'score': COMET_MAX_CONFIDENCE_SCORE}) + data.append({ + 'boxes': [box], + 'label': f'gt_{label}', + 'score': _scale_confidence_score(1.0), }) return {'name': 'ground_truth', 'data': data} @@ -135,7 +161,7 @@ def _format_prediction_annotations_for_detection(image_path, metadata, class_lab data = [] for prediction in predictions: boxes = prediction['bbox'] - score = prediction['score'] * COMET_MAX_CONFIDENCE_SCORE + score = _scale_confidence_score(prediction['score']) cls_label = prediction['category_id'] if class_label_map: cls_label = str(class_label_map[cls_label]) @@ -207,13 +233,16 @@ def _log_image_predictions(experiment, validator, curr_step): dataloader = validator.dataloader class_label_map = validator.names + batch_logging_interval = _get_eval_batch_logging_interval() + max_image_predictions = _get_max_image_predictions_to_log() + for batch_idx, batch in enumerate(dataloader): - if (batch_idx + 1) % COMET_EVAL_BATCH_LOGGING_INTERVAL != 0: + if (batch_idx + 1) % batch_logging_interval != 0: continue image_paths = batch['im_file'] for img_idx, image_path in enumerate(image_paths): - if _comet_image_prediction_count >= COMET_MAX_IMAGE_PREDICTIONS: + if _comet_image_prediction_count >= max_image_predictions: return image_path = Path(image_path) @@ -244,8 +273,9 @@ def _log_plots(experiment, trainer): def _log_model(experiment, trainer): """Log the best-trained model to Comet.ml.""" + model_name = _get_comet_model_name() experiment.log_model( - COMET_MODEL_NAME, + model_name, file_or_folder=str(trainer.best), file_name='best.pt', overwrite=True, @@ -255,7 +285,8 @@ def _log_model(experiment, trainer): def on_pretrain_routine_start(trainer): """Creates or resumes a CometML experiment at the start of a YOLO pre-training routine.""" experiment = 
comet_ml.get_global_experiment() - if not experiment: + is_alive = getattr(experiment, 'alive', False) + if not experiment or not is_alive: _create_experiment(trainer.args) @@ -296,16 +327,16 @@ def on_fit_epoch_end(trainer): model_info = { 'model/parameters': get_num_params(trainer.model), 'model/GFLOPs': round(get_flops(trainer.model), 3), - 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)} + 'model/speed(ms)': round(trainer.validator.speed['inference'], 3), } experiment.log_metrics(model_info, step=curr_step, epoch=curr_epoch) if not save_assets: return _log_model(experiment, trainer) - if COMET_EVAL_LOG_CONFUSION_MATRIX: + if _should_log_confusion_matrix(): _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch) - if COMET_EVAL_LOG_IMAGE_PREDICTIONS: + if _should_log_image_predictions(): _log_image_predictions(experiment, trainer.validator, curr_step) diff --git a/ultralytics/yolo/utils/downloads.py b/ultralytics/yolo/utils/downloads.py index 231f7f9..a8ead2b 100644 --- a/ultralytics/yolo/utils/downloads.py +++ b/ultralytics/yolo/utils/downloads.py @@ -17,7 +17,8 @@ from ultralytics.yolo.utils import LOGGER, checks, clean_url, emojis, is_online, GITHUB_ASSET_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in ('', '6', '-cls', '-seg', '-pose')] + \ [f'yolov5{k}u.pt' for k in 'nsmlx'] + \ - [f'yolov3{k}u.pt' for k in ('', '-spp', '-tiny')] + [f'yolov3{k}u.pt' for k in ('', '-spp', '-tiny')] + \ + [f'sam_{k}.pt' for k in 'bl'] GITHUB_ASSET_STEMS = [Path(k).stem for k in GITHUB_ASSET_NAMES] diff --git a/ultralytics/yolo/utils/plotting.py b/ultralytics/yolo/utils/plotting.py index 774c3ce..a351db2 100644 --- a/ultralytics/yolo/utils/plotting.py +++ b/ultralytics/yolo/utils/plotting.py @@ -192,14 +192,27 @@ class Annotator: """Add rectangle to image (PIL-only).""" self.draw.rectangle(xy, fill, outline, width) - def text(self, xy, text, txt_color=(255, 255, 255), anchor='top'): + def text(self, xy, text, txt_color=(255, 255, 255), anchor='top', box_style=False): """Adds text to an image using PIL or cv2.""" if anchor == 'bottom': # start y from font bottom w, h = self.font.getsize(text) # text width, height xy[1] += 1 - h if self.pil: + if box_style: + w, h = self.font.getsize(text) + self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color) + # Using `txt_color` for background and draw fg with white color + txt_color = (255, 255, 255) self.draw.text(xy, text, fill=txt_color, font=self.font) else: + if box_style: + tf = max(self.lw - 1, 1) # font thickness + w, h = cv2.getTextSize(text, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height + outside = xy[1] - h >= 3 + p2 = xy[0] + w, xy[1] - h - 3 if outside else xy[1] + h + 3 + cv2.rectangle(self.im, xy, p2, txt_color, -1, cv2.LINE_AA) # filled + # Using `txt_color` for background and draw fg with white color + txt_color = (255, 255, 255) tf = max(self.lw - 1, 1) # font thickness cv2.putText(self.im, text, xy, 0, self.lw / 3, txt_color, thickness=tf, lineType=cv2.LINE_AA) @@ -283,7 +296,7 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, def plot_images(images, batch_idx, cls, - bboxes, + bboxes=np.zeros(0, dtype=np.float32), masks=np.zeros(0, dtype=np.uint8), kpts=np.zeros((0, 51), dtype=np.float32), paths=None, @@ -337,27 +350,33 @@ def plot_images(images, annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames if len(cls) > 0: idx = batch_idx == i - - boxes = xywh2xyxy(bboxes[idx, :4]).T 
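# The new `box_style` flag above renders the label on a filled background rectangle;
# the classification branch of plot_images below uses it for class names. Hedged
# sketch with an illustrative dummy image and label:
import numpy as np
from ultralytics.yolo.utils.plotting import Annotator, colors

im = np.zeros((640, 640, 3), dtype=np.uint8)  # dummy BGR canvas
annotator = Annotator(im)
annotator.text((10, 30), 'person', txt_color=colors(0, True), box_style=True)
result = annotator.result()  # annotated image as a numpy array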
classes = cls[idx].astype('int') - labels = bboxes.shape[1] == 4 # labels if no conf column - conf = None if labels else bboxes[idx, 4] # check for confidence presence (label vs pred) - - if boxes.shape[1]: - if boxes.max() <= 1.01: # if normalized with tolerance 0.01 - boxes[[0, 2]] *= w # scale to pixels - boxes[[1, 3]] *= h - elif scale < 1: # absolute coords need scale if image scales - boxes *= scale - boxes[[0, 2]] += x - boxes[[1, 3]] += y - for j, box in enumerate(boxes.T.tolist()): - c = classes[j] - color = colors(c) - c = names.get(c, c) if names else c - if labels or conf[j] > 0.25: # 0.25 conf thresh - label = f'{c}' if labels else f'{c} {conf[j]:.1f}' - annotator.box_label(box, label, color=color) + + if len(bboxes): + boxes = xywh2xyxy(bboxes[idx, :4]).T + labels = bboxes.shape[1] == 4 # labels if no conf column + conf = None if labels else bboxes[idx, 4] # check for confidence presence (label vs pred) + + if boxes.shape[1]: + if boxes.max() <= 1.01: # if normalized with tolerance 0.01 + boxes[[0, 2]] *= w # scale to pixels + boxes[[1, 3]] *= h + elif scale < 1: # absolute coords need scale if image scales + boxes *= scale + boxes[[0, 2]] += x + boxes[[1, 3]] += y + for j, box in enumerate(boxes.T.tolist()): + c = classes[j] + color = colors(c) + c = names.get(c, c) if names else c + if labels or conf[j] > 0.25: # 0.25 conf thresh + label = f'{c}' if labels else f'{c} {conf[j]:.1f}' + annotator.box_label(box, label, color=color) + elif len(classes): + for c in classes: + color = colors(c) + c = names.get(c, c) if names else c + annotator.text((x, y), f'{c}', txt_color=color, box_style=True) # Plot keypoints if len(kpts): @@ -403,11 +422,14 @@ def plot_images(images, @plt_settings() -def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False): +def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False): """Plot training results.csv. 
Usage: from utils.plots import *; plot_results('path/to/results.csv').""" import pandas as pd save_dir = Path(file).parent if file else Path(dir) - if segment: + if classify: + fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True) + index = [1, 4, 2, 3] + elif segment: fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True) index = [1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12] elif pose: diff --git a/ultralytics/yolo/utils/tal.py b/ultralytics/yolo/utils/tal.py index 09868bd..d0ceb6c 100644 --- a/ultralytics/yolo/utils/tal.py +++ b/ultralytics/yolo/utils/tal.py @@ -225,7 +225,7 @@ class TaskAlignedAssigner(nn.Module): target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx] # Assigned target scores - target_labels.clamp(0) + target_labels.clamp_(0) target_scores = F.one_hot(target_labels, self.num_classes) # (b, h*w, 80) fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80) target_scores = torch.where(fg_scores_mask > 0, target_scores, 0) diff --git a/ultralytics/yolo/v8/classify/predict.py b/ultralytics/yolo/v8/classify/predict.py index 363448c..fb486e2 100644 --- a/ultralytics/yolo/v8/classify/predict.py +++ b/ultralytics/yolo/v8/classify/predict.py @@ -9,8 +9,14 @@ from ultralytics.yolo.utils import DEFAULT_CFG, ROOT class ClassificationPredictor(BasePredictor): + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + super().__init__(cfg, overrides, _callbacks) + self.args.task = 'classify' + def preprocess(self, img): """Converts input image to model-compatible data type.""" + if not isinstance(img, torch.Tensor): + img = torch.stack([self.transforms(im) for im in img], dim=0) img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 @@ -19,7 +25,7 @@ class ClassificationPredictor(BasePredictor): results = [] for i, pred in enumerate(preds): orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs - path, _, _, _, _ = self.batch + path = self.batch[0] img_path = path[i] if isinstance(path, list) else path results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred)) diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py index 3a3f284..ef15692 100644 --- a/ultralytics/yolo/v8/classify/train.py +++ b/ultralytics/yolo/v8/classify/train.py @@ -5,10 +5,11 @@ import torchvision from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight from ultralytics.yolo import v8 -from ultralytics.yolo.data import build_classification_dataloader +from ultralytics.yolo.data import ClassificationDataset, build_dataloader from ultralytics.yolo.engine.trainer import BaseTrainer from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr -from ultralytics.yolo.utils.torch_utils import is_parallel, strip_optimizer +from ultralytics.yolo.utils.plotting import plot_images, plot_results +from ultralytics.yolo.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first class ClassificationTrainer(BaseTrainer): @@ -71,14 +72,16 @@ class ClassificationTrainer(BaseTrainer): return # dont return ckpt. 
Classification doesn't support resume + def build_dataset(self, img_path, mode='train'): + dataset = ClassificationDataset(root=img_path, imgsz=self.args.imgsz, augment=mode == 'train') + return dataset + def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'): """Returns PyTorch DataLoader with transforms to preprocess images for inference.""" - loader = build_classification_dataloader(path=dataset_path, - imgsz=self.args.imgsz, - batch_size=batch_size if mode == 'train' else (batch_size * 2), - augment=mode == 'train', - rank=rank, - workers=self.args.workers) + with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP + dataset = self.build_dataset(dataset_path, mode) + + loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank) # Attach inference transforms if mode != 'train': if is_parallel(self.model): @@ -124,6 +127,10 @@ class ClassificationTrainer(BaseTrainer): """Resumes training from a given checkpoint.""" pass + def plot_metrics(self): + """Plots metrics from a CSV file.""" + plot_results(file=self.csv, classify=True) # save results.png + def final_eval(self): """Evaluate trained model and save validation results.""" for f in self.last, self.best: @@ -138,6 +145,13 @@ class ClassificationTrainer(BaseTrainer): # self.run_callbacks('on_fit_epoch_end') LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}") + def plot_training_samples(self, batch, ni): + """Plots training samples with their annotations.""" + plot_images(images=batch['img'], + batch_idx=torch.arange(len(batch['img'])), + cls=batch['cls'].squeeze(-1), + fname=self.save_dir / f'train_batch{ni}.jpg') + def train(cfg=DEFAULT_CFG, use_python=False): """Train the YOLO classification model.""" diff --git a/ultralytics/yolo/v8/classify/val.py b/ultralytics/yolo/v8/classify/val.py index 6722dfc..d6a7ab2 100644 --- a/ultralytics/yolo/v8/classify/val.py +++ b/ultralytics/yolo/v8/classify/val.py @@ -1,9 +1,12 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -from ultralytics.yolo.data import build_classification_dataloader +import torch + +from ultralytics.yolo.data import ClassificationDataset, build_dataloader from ultralytics.yolo.engine.validator import BaseValidator from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER from ultralytics.yolo.utils.metrics import ClassifyMetrics, ConfusionMatrix +from ultralytics.yolo.utils.plotting import plot_images class ClassificationValidator(BaseValidator): @@ -52,20 +55,36 @@ class ClassificationValidator(BaseValidator): self.metrics.process(self.targets, self.pred) return self.metrics.results_dict + def build_dataset(self, img_path): + dataset = ClassificationDataset(root=img_path, imgsz=self.args.imgsz, augment=False) + return dataset + def get_dataloader(self, dataset_path, batch_size): """Builds and returns a data loader for classification tasks with given parameters.""" - return build_classification_dataloader(path=dataset_path, - imgsz=self.args.imgsz, - batch_size=batch_size, - augment=False, - shuffle=False, - workers=self.args.workers) + dataset = self.build_dataset(dataset_path) + return build_dataloader(dataset, batch_size, self.args.workers, rank=-1) def print_results(self): """Prints evaluation metrics for YOLO object detection model.""" pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5)) + def plot_val_samples(self, batch, ni): + """Plot validation image samples.""" + plot_images(images=batch['img'], + 
batch_idx=torch.arange(len(batch['img'])), + cls=batch['cls'].squeeze(-1), + fname=self.save_dir / f'val_batch{ni}_labels.jpg', + names=self.names) + + def plot_predictions(self, batch, preds, ni): + """Plots predicted bounding boxes on input images and saves the result.""" + plot_images(batch['img'], + batch_idx=torch.arange(len(batch['img'])), + cls=torch.argmax(preds, dim=1), + fname=self.save_dir / f'val_batch{ni}_pred.jpg', + names=self.names) # pred + def val(cfg=DEFAULT_CFG, use_python=False): """Validate YOLO model using custom data.""" diff --git a/ultralytics/yolo/v8/detect/predict.py b/ultralytics/yolo/v8/detect/predict.py index 68f0937..31e8a9f 100644 --- a/ultralytics/yolo/v8/detect/predict.py +++ b/ultralytics/yolo/v8/detect/predict.py @@ -9,13 +9,6 @@ from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops class DetectionPredictor(BasePredictor): - def preprocess(self, img): - """Convert an image to PyTorch tensor and normalize pixel values.""" - img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) - img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 - img /= 255 # 0 - 255 to 0.0 - 1.0 - return img - def postprocess(self, preds, img, orig_imgs): """Postprocesses predictions and returns a list of Results objects.""" preds = ops.non_max_suppression(preds, @@ -30,7 +23,7 @@ class DetectionPredictor(BasePredictor): orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs if not isinstance(orig_imgs, torch.Tensor): pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) - path, _, _, _, _ = self.batch + path = self.batch[0] img_path = path[i] if isinstance(path, list) else path results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred)) return results diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py index 077bd2c..2965089 100644 --- a/ultralytics/yolo/v8/detect/train.py +++ b/ultralytics/yolo/v8/detect/train.py @@ -7,41 +7,63 @@ import torch.nn as nn from ultralytics.nn.tasks import DetectionModel from ultralytics.yolo import v8 -from ultralytics.yolo.data import build_dataloader +from ultralytics.yolo.data import build_dataloader, build_yolo_dataset from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader from ultralytics.yolo.engine.trainer import BaseTrainer -from ultralytics.yolo.utils import DEFAULT_CFG, RANK, colorstr +from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr from ultralytics.yolo.utils.loss import BboxLoss from ultralytics.yolo.utils.ops import xywh2xyxy from ultralytics.yolo.utils.plotting import plot_images, plot_labels, plot_results from ultralytics.yolo.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors -from ultralytics.yolo.utils.torch_utils import de_parallel +from ultralytics.yolo.utils.torch_utils import de_parallel, torch_distributed_zero_first # BaseTrainer python usage class DetectionTrainer(BaseTrainer): + def build_dataset(self, img_path, mode='train', batch=None): + """Build YOLO Dataset + + Args: + img_path (str): Path to the folder containing images. + mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. + batch_size (int, optional): Size of batches, this is for `rect`. Defaults to None. 
+ """ + gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) + return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs) + def get_dataloader(self, dataset_path, batch_size, rank=0, mode='train'): """TODO: manage splits differently.""" # Calculate stride - check if model is initialized - gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) - return create_dataloader(path=dataset_path, - imgsz=self.args.imgsz, - batch_size=batch_size, - stride=gs, - hyp=vars(self.args), - augment=mode == 'train', - cache=self.args.cache, - pad=0 if mode == 'train' else 0.5, - rect=self.args.rect or mode == 'val', - rank=rank, - workers=self.args.workers, - close_mosaic=self.args.close_mosaic != 0, - prefix=colorstr(f'{mode}: '), - shuffle=mode == 'train', - seed=self.args.seed)[0] if self.args.v5loader else \ - build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode, - rect=mode == 'val', data_info=self.data)[0] + if self.args.v5loader: + LOGGER.warning("WARNING ⚠️ 'v5loader' feature is deprecated and will be removed soon. You can train using " + 'the default YOLOv8 dataloader instead, no argument is needed.') + gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) + return create_dataloader(path=dataset_path, + imgsz=self.args.imgsz, + batch_size=batch_size, + stride=gs, + hyp=vars(self.args), + augment=mode == 'train', + cache=self.args.cache, + pad=0 if mode == 'train' else 0.5, + rect=self.args.rect or mode == 'val', + rank=rank, + workers=self.args.workers, + close_mosaic=self.args.close_mosaic != 0, + prefix=colorstr(f'{mode}: '), + shuffle=mode == 'train', + seed=self.args.seed)[0] + assert mode in ['train', 'val'] + with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP + dataset = self.build_dataset(dataset_path, mode, batch_size) + shuffle = mode == 'train' + if getattr(dataset, 'rect', False) and shuffle: + LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False") + shuffle = False + workers = self.args.workers if mode == 'train' else self.args.workers * 2 + dataloader = build_dataloader(dataset, batch_size, workers, shuffle, rank) + return dataloader def preprocess_batch(self, batch): """Preprocesses a batch of images by scaling and converting to float.""" diff --git a/ultralytics/yolo/v8/detect/val.py b/ultralytics/yolo/v8/detect/val.py index 1304186..907a530 100644 --- a/ultralytics/yolo/v8/detect/val.py +++ b/ultralytics/yolo/v8/detect/val.py @@ -6,7 +6,7 @@ from pathlib import Path import numpy as np import torch -from ultralytics.yolo.data import build_dataloader +from ultralytics.yolo.data import build_dataloader, build_yolo_dataset from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader from ultralytics.yolo.engine.validator import BaseValidator from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, colorstr, ops @@ -171,24 +171,40 @@ class DetectionValidator(BaseValidator): correct[matches[:, 1].astype(int), i] = True return torch.tensor(correct, dtype=torch.bool, device=detections.device) + def build_dataset(self, img_path, mode='val', batch=None): + """Build YOLO Dataset + + Args: + img_path (str): Path to the folder containing images. + mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. + batch_size (int, optional): Size of batches, this is for `rect`. Defaults to None. 
+ """ + gs = max(int(de_parallel(self.model).stride if self.model else 0), 32) + return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=gs) + def get_dataloader(self, dataset_path, batch_size): """TODO: manage splits differently.""" # Calculate stride - check if model is initialized - gs = max(int(de_parallel(self.model).stride if self.model else 0), 32) - return create_dataloader(path=dataset_path, - imgsz=self.args.imgsz, - batch_size=batch_size, - stride=gs, - hyp=vars(self.args), - cache=False, - pad=0.5, - rect=self.args.rect, - workers=self.args.workers, - prefix=colorstr(f'{self.args.mode}: '), - shuffle=False, - seed=self.args.seed)[0] if self.args.v5loader else \ - build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, data_info=self.data, - mode='val')[0] + if self.args.v5loader: + LOGGER.warning("WARNING ⚠️ 'v5loader' feature is deprecated and will be removed soon. You can train using " + 'the default YOLOv8 dataloader instead, no argument is needed.') + gs = max(int(de_parallel(self.model).stride if self.model else 0), 32) + return create_dataloader(path=dataset_path, + imgsz=self.args.imgsz, + batch_size=batch_size, + stride=gs, + hyp=vars(self.args), + cache=False, + pad=0.5, + rect=self.args.rect, + workers=self.args.workers, + prefix=colorstr(f'{self.args.mode}: '), + shuffle=False, + seed=self.args.seed)[0] + + dataset = self.build_dataset(dataset_path, batch=batch_size, mode='val') + dataloader = build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1) + return dataloader def plot_val_samples(self, batch, ni): """Plot validation image samples.""" diff --git a/ultralytics/yolo/v8/pose/predict.py b/ultralytics/yolo/v8/pose/predict.py index a3af259..0734bc6 100644 --- a/ultralytics/yolo/v8/pose/predict.py +++ b/ultralytics/yolo/v8/pose/predict.py @@ -7,6 +7,10 @@ from ultralytics.yolo.v8.detect.predict import DetectionPredictor class PosePredictor(DetectionPredictor): + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + super().__init__(cfg, overrides, _callbacks) + self.args.task = 'pose' + def postprocess(self, preds, img, orig_img): """Return detection results for a given input image or list of images.""" preds = ops.non_max_suppression(preds, @@ -24,7 +28,7 @@ class PosePredictor(DetectionPredictor): pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round() pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:] pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, shape) - path, _, _, _, _ = self.batch + path = self.batch[0] img_path = path[i] if isinstance(path, list) else path results.append( Results(orig_img=orig_img, diff --git a/ultralytics/yolo/v8/segment/predict.py b/ultralytics/yolo/v8/segment/predict.py index 6ac24ed..0b6ebc4 100644 --- a/ultralytics/yolo/v8/segment/predict.py +++ b/ultralytics/yolo/v8/segment/predict.py @@ -9,6 +9,10 @@ from ultralytics.yolo.v8.detect.predict import DetectionPredictor class SegmentationPredictor(DetectionPredictor): + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + super().__init__(cfg, overrides, _callbacks) + self.args.task = 'segment' + def postprocess(self, preds, img, orig_imgs): """TODO: filter by classes.""" p = ops.non_max_suppression(preds[0], @@ -22,7 +26,7 @@ class SegmentationPredictor(DetectionPredictor): proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported for i, pred in enumerate(p): 
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs - path, _, _, _, _ = self.batch + path = self.batch[0] img_path = path[i] if isinstance(path, list) else path if not len(pred): # save empty boxes results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
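# The task predictors above now pin `self.args.task` in __init__. A hedged usage
# sketch mirroring the predict() helpers in these modules (weights and source are
# illustrative):
from ultralytics.yolo.v8.segment.predict import SegmentationPredictor

args = dict(model='yolov8n-seg.pt', source='bus.jpg')
predictor = SegmentationPredictor(overrides=args)
predictor.predict_cli()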