ultralytics 8.0.136 refactor and simplify package (#3748)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Author: Laughing
Date: 2023-07-16 23:47:45 +08:00
Committed by: GitHub
Parent: 8ebe94d1e9
Commit: 620f3eb218
383 changed files with 4213 additions and 4646 deletions


@ -1,45 +0,0 @@
## Models
Welcome to the Ultralytics Models directory! Here you will find a wide variety of pre-configured model configuration
files (`*.yaml`) that can be used to create custom YOLO models. The models in this directory have been expertly crafted
and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image
segmentation tasks.
These model configurations cover a wide range of scenarios, from simple object detection to more complex tasks like
instance segmentation and object tracking. They are also designed to run efficiently on a variety of hardware platforms,
from CPUs to GPUs. Whether you are a seasoned machine learning practitioner or just getting started with YOLO, this
directory provides a great starting point for your custom model development needs.
To get started, simply browse through the models in this directory and find one that best suits your needs. Once you've
selected a model, you can use the provided `*.yaml` file to train and deploy your custom YOLO model with ease. See full
details at the Ultralytics [Docs](https://docs.ultralytics.com/models), and if you need help or have any questions, feel free
to reach out to the Ultralytics team for support. So, don't wait, start creating your custom YOLO model now!
### Usage
Model `*.yaml` files may be used directly in the Command Line Interface (CLI) with a `yolo` command:
```bash
yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100
```
They may also be used directly in a Python environment, and accept the same
[arguments](https://docs.ultralytics.com/usage/cfg/) as in the CLI example above:
```python
from ultralytics import YOLO
model = YOLO("model.yaml")  # build a YOLO model from scratch from a *.yaml config
# model = YOLO("model.pt")  # or load a pre-trained model if available
model.info() # display model information
model.train(data="coco128.yaml", epochs=100) # train the model
```
## Pre-trained Model Architectures
Ultralytics supports many model architectures. Visit https://docs.ultralytics.com/models to view detailed information
and usage. Any of these models can be used by loading their configs or pretrained checkpoints if available.
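As a brief illustration of the two loading paths just described (a sketch only; `yolov8n` stands in for any architecture listed in the docs):
```python
from ultralytics import YOLO

model_from_cfg = YOLO("yolov8n.yaml")  # build the architecture from its config (untrained)
model_from_ckpt = YOLO("yolov8n.pt")  # load the pretrained checkpoint, if one is published
model_from_ckpt.info()  # display model information
```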
## Contributing New Models
If you've developed a new model architecture or have improvements for existing models that you'd like to contribute to the Ultralytics community, please submit your contribution in a new Pull Request. For more details, visit our [Contributing Guide](https://docs.ultralytics.com/help/contributing).


@ -0,0 +1,4 @@
from .rtdetr import RTDETR
from .sam import SAM
__all__ = 'RTDETR', 'SAM' # allow simpler import


@ -0,0 +1,8 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .model import FastSAM
from .predict import FastSAMPredictor
from .prompt import FastSAMPrompt
from .val import FastSAMValidator
__all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMValidator'


@ -0,0 +1,111 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
FastSAM model interface.
Usage - Predict:
from ultralytics import FastSAM
model = FastSAM('last.pt')
results = model.predict('ultralytics/assets/bus.jpg')
"""
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
from ultralytics.engine.model import YOLO
from ultralytics.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from .predict import FastSAMPredictor
from .val import FastSAMValidator
class FastSAM(YOLO):
def __init__(self, model='FastSAM-x.pt'):
"""Call the __init__ method of the parent class (YOLO) with the updated default model"""
if model == 'FastSAM.pt':
model = 'FastSAM-x.pt'
super().__init__(model=model)
# any additional initialization code for FastSAM
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
overrides = self.overrides.copy()
overrides['conf'] = 0.25
overrides.update(kwargs) # prefer kwargs
overrides['mode'] = kwargs.get('mode', 'predict')
assert overrides['mode'] in ['track', 'predict']
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
self.predictor = FastSAMPredictor(overrides=overrides)
self.predictor.setup_model(model=self.model, verbose=False)
return self.predictor(source, stream=stream)
def train(self, **kwargs):
"""Function trains models but raises an error as FastSAM models do not support training."""
raise NotImplementedError("FastSAM models don't support training")
def val(self, **kwargs):
"""Run validation given dataset."""
overrides = dict(task='segment', mode='val')
overrides.update(kwargs) # prefer kwargs
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
        validator = FastSAMValidator(args=args)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
@smart_inference_mode()
def export(self, **kwargs):
"""
Export model.
Args:
            **kwargs : Any other args accepted by the Exporter. To see all args check 'configuration' section in docs
"""
overrides = dict(task='detect')
overrides.update(kwargs)
overrides['mode'] = 'export'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
return Exporter(overrides=args)(model=self.model)
def info(self, detailed=False, verbose=True):
"""
Logs model info.
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
"""
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
def __call__(self, source=None, stream=False, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
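A minimal usage sketch for the class above, assuming a local `FastSAM-x.pt` checkpoint; the image path is a placeholder and the keyword overrides are standard Ultralytics prediction arguments:
```python
from ultralytics import FastSAM

model = FastSAM('FastSAM-x.pt')  # default checkpoint; 'FastSAM.pt' is mapped to 'FastSAM-x.pt'
model.info()  # log model summary
results = model.predict('path/to/image.jpg', conf=0.4, imgsz=1024, retina_masks=True)
```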


@ -0,0 +1,53 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.engine.results import Results
from ultralytics.models.fastsam.utils import bbox_iou
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ops
class FastSAMPredictor(DetectionPredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'segment'
def postprocess(self, preds, img, orig_imgs):
"""TODO: filter by classes."""
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=len(self.model.names),
classes=self.args.classes)
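        # Build a candidate box spanning the whole image (x1=y1=0, x2=w, y2=h, conf=1.0); any prediction
        # whose IoU with it exceeds 0.9 is snapped below to the exact image extent, keeping its confidence and mask coefficients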
full_box = torch.zeros_like(p[0][0])
full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
full_box = full_box.view(1, -1)
critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
if critical_iou_index.numel() != 0:
full_box[0][4] = p[0][critical_iou_index][:, 4]
full_box[0][6:] = p[0][critical_iou_index][:, 6:]
p[0][critical_iou_index] = full_box
results = []
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
for i, pred in enumerate(p):
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
path = self.batch[0]
img_path = path[i] if isinstance(path, list) else path
if not len(pred): # save empty boxes
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
continue
if self.args.retina_masks:
if not isinstance(orig_imgs, torch.Tensor):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
else:
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
if not isinstance(orig_imgs, torch.Tensor):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
results.append(
Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
return results


@ -0,0 +1,406 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
class FastSAMPrompt:
def __init__(self, img_path, results, device='cuda') -> None:
# self.img_path = img_path
self.device = device
self.results = results
self.img_path = img_path
self.ori_img = cv2.imread(img_path)
# Import and assign clip
try:
            import clip  # required for CLIP-based text prompts (see text_prompt/retrieve below)
except ImportError:
from ultralytics.utils.checks import check_requirements
            check_requirements('git+https://github.com/openai/CLIP.git')  # install CLIP from source if missing
import clip
self.clip = clip
@staticmethod
def _segment_image(image, bbox):
image_array = np.array(image)
segmented_image_array = np.zeros_like(image_array)
x1, y1, x2, y2 = bbox
segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
segmented_image = Image.fromarray(segmented_image_array)
black_image = Image.new('RGB', image.size, (255, 255, 255))
# transparency_mask = np.zeros_like((), dtype=np.uint8)
transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
transparency_mask[y1:y2, x1:x2] = 255
transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
black_image.paste(segmented_image, mask=transparency_mask_image)
return black_image
@staticmethod
def _format_results(result, filter=0):
annotations = []
n = len(result.masks.data)
for i in range(n):
mask = result.masks.data[i] == 1.0
if torch.sum(mask) < filter:
continue
annotation = {
'id': i,
'segmentation': mask.cpu().numpy(),
'bbox': result.boxes.data[i],
'score': result.boxes.conf[i]}
annotation['area'] = annotation['segmentation'].sum()
annotations.append(annotation)
return annotations
@staticmethod
def filter_masks(annotations): # filter the overlap mask
annotations.sort(key=lambda x: x['area'], reverse=True)
to_remove = set()
for i in range(len(annotations)):
a = annotations[i]
for j in range(i + 1, len(annotations)):
b = annotations[j]
if i != j and j not in to_remove and b['area'] < a['area'] and \
(a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
to_remove.add(j)
return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
@staticmethod
def _get_bbox_from_mask(mask):
mask = mask.astype(np.uint8)
contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
x1, y1, w, h = cv2.boundingRect(contours[0])
x2, y2 = x1 + w, y1 + h
if len(contours) > 1:
for b in contours:
x_t, y_t, w_t, h_t = cv2.boundingRect(b)
                # Merge multiple bounding boxes into one
x1 = min(x1, x_t)
y1 = min(y1, y_t)
x2 = max(x2, x_t + w_t)
y2 = max(y2, y_t + h_t)
h = y2 - y1
w = x2 - x1
return [x1, y1, x2, y2]
def plot(self,
annotations,
output,
bbox=None,
points=None,
point_label=None,
mask_random_color=True,
better_quality=True,
retina=False,
withContours=True):
if isinstance(annotations[0], dict):
annotations = [annotation['segmentation'] for annotation in annotations]
result_name = os.path.basename(self.img_path)
image = self.ori_img
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
original_h = image.shape[0]
original_w = image.shape[1]
# for macOS only
# plt.switch_backend('TkAgg')
plt.figure(figsize=(original_w / 100, original_h / 100))
# Add subplot with no margin.
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.imshow(image)
if better_quality:
if isinstance(annotations[0], torch.Tensor):
annotations = np.array(annotations.cpu())
for i, mask in enumerate(annotations):
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
if self.device == 'cpu':
annotations = np.array(annotations)
self.fast_show_mask(
annotations,
plt.gca(),
random_color=mask_random_color,
bbox=bbox,
points=points,
pointlabel=point_label,
retinamask=retina,
target_height=original_h,
target_width=original_w,
)
else:
if isinstance(annotations[0], np.ndarray):
annotations = torch.from_numpy(annotations)
self.fast_show_mask_gpu(
annotations,
plt.gca(),
random_color=mask_random_color,
bbox=bbox,
points=points,
pointlabel=point_label,
retinamask=retina,
target_height=original_h,
target_width=original_w,
)
if isinstance(annotations, torch.Tensor):
annotations = annotations.cpu().numpy()
if withContours:
contour_all = []
temp = np.zeros((original_h, original_w, 1))
for i, mask in enumerate(annotations):
if type(mask) == dict:
mask = mask['segmentation']
annotation = mask.astype(np.uint8)
if not retina:
annotation = cv2.resize(
annotation,
(original_w, original_h),
interpolation=cv2.INTER_NEAREST,
)
contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
contour_all.extend(iter(contours))
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
contour_mask = temp / 255 * color.reshape(1, 1, -1)
plt.imshow(contour_mask)
save_path = output
if not os.path.exists(save_path):
os.makedirs(save_path)
plt.axis('off')
fig = plt.gcf()
plt.draw()
try:
buf = fig.canvas.tostring_rgb()
except AttributeError:
fig.canvas.draw()
buf = fig.canvas.tostring_rgb()
cols, rows = fig.canvas.get_width_height()
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
# CPU post process
def fast_show_mask(
self,
annotation,
ax,
random_color=False,
bbox=None,
points=None,
pointlabel=None,
retinamask=True,
target_height=960,
target_width=960,
):
msak_sum = annotation.shape[0]
height = annotation.shape[1]
weight = annotation.shape[2]
        # Sort annotations by area
areas = np.sum(annotation, axis=(1, 2))
sorted_indices = np.argsort(areas)
annotation = annotation[sorted_indices]
index = (annotation != 0).argmax(axis=0)
if random_color:
color = np.random.random((msak_sum, 1, 1, 3))
else:
color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
visual = np.concatenate([color, transparency], axis=-1)
mask_image = np.expand_dims(annotation, -1) * visual
show = np.zeros((height, weight, 4))
h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
        # Use vectorized indexing to update the values of 'show'
show[h_indices, w_indices, :] = mask_image[indices]
if bbox is not None:
x1, y1, x2, y2 = bbox
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
# draw point
if points is not None:
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
s=20,
c='y',
)
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
s=20,
c='m',
)
if not retinamask:
show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
ax.imshow(show)
def fast_show_mask_gpu(
self,
annotation,
ax,
random_color=False,
bbox=None,
points=None,
pointlabel=None,
retinamask=True,
target_height=960,
target_width=960,
):
msak_sum = annotation.shape[0]
height = annotation.shape[1]
weight = annotation.shape[2]
areas = torch.sum(annotation, dim=(1, 2))
sorted_indices = torch.argsort(areas, descending=False)
annotation = annotation[sorted_indices]
        # Find the index of the first non-zero mask at each pixel location
index = (annotation != 0).to(torch.long).argmax(dim=0)
if random_color:
color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
else:
color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([30 / 255, 144 / 255, 1.0]).to(
annotation.device)
transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
visual = torch.cat([color, transparency], dim=-1)
mask_image = torch.unsqueeze(annotation, -1) * visual
        # Gather by 'index': for each pixel it selects which mask in the batch to take, collapsing mask_image into a single image
show = torch.zeros((height, weight, 4)).to(annotation.device)
h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij')
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
        # Use vectorized indexing to update the values of 'show'
show[h_indices, w_indices, :] = mask_image[indices]
show_cpu = show.cpu().numpy()
if bbox is not None:
x1, y1, x2, y2 = bbox
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
# draw point
if points is not None:
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
s=20,
c='y',
)
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
s=20,
c='m',
)
if not retinamask:
show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
ax.imshow(show_cpu)
# clip
@torch.no_grad()
    def retrieve(self, model, preprocess, elements, search_text: str, device) -> torch.Tensor:
preprocessed_images = [preprocess(image).to(device) for image in elements]
tokenized_text = self.clip.tokenize([search_text]).to(device)
stacked_images = torch.stack(preprocessed_images)
image_features = model.encode_image(stacked_images)
text_features = model.encode_text(tokenized_text)
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
probs = 100.0 * image_features @ text_features.T
return probs[:, 0].softmax(dim=0)
def _crop_image(self, format_results):
image = Image.fromarray(cv2.cvtColor(self.ori_img, cv2.COLOR_BGR2RGB))
ori_w, ori_h = image.size
annotations = format_results
mask_h, mask_w = annotations[0]['segmentation'].shape
if ori_w != mask_w or ori_h != mask_h:
image = image.resize((mask_w, mask_h))
cropped_boxes = []
cropped_images = []
not_crop = []
filter_id = []
# annotations, _ = filter_masks(annotations)
# filter_id = list(_)
for _, mask in enumerate(annotations):
if np.sum(mask['segmentation']) <= 100:
filter_id.append(_)
continue
            bbox = self._get_bbox_from_mask(mask['segmentation'])  # bbox of the mask
            cropped_boxes.append(self._segment_image(image, bbox))  # save the cropped image
            # cropped_boxes.append(segment_image(image,mask["segmentation"]))
            cropped_images.append(bbox)  # save the bbox of the cropped image
return cropped_boxes, cropped_images, not_crop, filter_id, annotations
def box_prompt(self, bbox):
assert (bbox[2] != 0 and bbox[3] != 0)
masks = self.results[0].masks.data
target_height = self.ori_img.shape[0]
target_width = self.ori_img.shape[1]
h = masks.shape[1]
w = masks.shape[2]
if h != target_height or w != target_width:
bbox = [
int(bbox[0] * w / target_width),
int(bbox[1] * h / target_height),
int(bbox[2] * w / target_width),
int(bbox[3] * h / target_height), ]
bbox[0] = max(round(bbox[0]), 0)
bbox[1] = max(round(bbox[1]), 0)
bbox[2] = min(round(bbox[2]), w)
bbox[3] = min(round(bbox[3]), h)
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
orig_masks_area = torch.sum(masks, dim=(1, 2))
union = bbox_area + orig_masks_area - masks_area
IoUs = masks_area / union
max_iou_index = torch.argmax(IoUs)
return np.array([masks[max_iou_index].cpu().numpy()])
    def point_prompt(self, points, pointlabel):  # numpy processing
masks = self._format_results(self.results[0], 0)
target_height = self.ori_img.shape[0]
target_width = self.ori_img.shape[1]
h = masks[0]['segmentation'].shape[0]
w = masks[0]['segmentation'].shape[1]
if h != target_height or w != target_width:
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
onemask = np.zeros((h, w))
for i, annotation in enumerate(masks):
mask = annotation['segmentation'] if type(annotation) == dict else annotation
for i, point in enumerate(points):
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
onemask += mask
if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
onemask -= mask
onemask = onemask >= 1
return np.array([onemask])
def text_prompt(self, text):
format_results = self._format_results(self.results[0], 0)
cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
max_idx = scores.argsort()
max_idx = max_idx[-1]
max_idx += sum(np.array(filter_id) <= int(max_idx))
return np.array([annotations[max_idx]['segmentation']])
def everything_prompt(self):
return self.results[0].masks.data
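A minimal sketch of driving the prompt API above with the FastSAM model defined earlier in this commit; paths, coordinates and the text query are placeholders, and `text_prompt` additionally requires the CLIP package installed in `__init__`:
```python
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt

model = FastSAM('FastSAM-x.pt')
everything_results = model.predict('path/to/image.jpg', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)

prompt = FastSAMPrompt('path/to/image.jpg', everything_results, device='cpu')
ann = prompt.everything_prompt()  # all masks
# ann = prompt.box_prompt(bbox=[200, 200, 300, 300])  # mask best matching an xyxy box
# ann = prompt.point_prompt(points=[[620, 360]], pointlabel=[1])  # masks containing a foreground point
# ann = prompt.text_prompt(text='a photo of a dog')  # CLIP-ranked mask for a text query
prompt.plot(annotations=ann, output='./output/')  # save the visualization to disk
```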


@ -0,0 +1,64 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
"""
Adjust bounding boxes to stick to image border if they are within a certain threshold.
Args:
boxes (torch.Tensor): (n, 4)
image_shape (tuple): (height, width)
threshold (int): pixel threshold
Returns:
adjusted_boxes (torch.Tensor): adjusted bounding boxes
"""
# Image dimensions
h, w = image_shape
# Adjust boxes
boxes[boxes[:, 0] < threshold, 0] = 0 # x1
boxes[boxes[:, 1] < threshold, 1] = 0 # y1
boxes[boxes[:, 2] > w - threshold, 2] = w # x2
boxes[boxes[:, 3] > h - threshold, 3] = h # y2
return boxes
def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
"""
Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
Args:
box1 (torch.Tensor): (4, )
boxes (torch.Tensor): (n, 4)
Returns:
high_iou_indices (torch.Tensor): Indices of boxes with IoU > thres
"""
boxes = adjust_bboxes_to_image_border(boxes, image_shape)
# obtain coordinates for intersections
x1 = torch.max(box1[0], boxes[:, 0])
y1 = torch.max(box1[1], boxes[:, 1])
x2 = torch.min(box1[2], boxes[:, 2])
y2 = torch.min(box1[3], boxes[:, 3])
# compute the area of intersection
intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
# compute the area of both individual boxes
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
# compute the area of union
union = box1_area + box2_area - intersection
# compute the IoU
iou = intersection / union # Should be shape (n, )
if raw_output:
return 0 if iou.numel() == 0 else iou
# return indices of boxes with IoU > thres
return torch.nonzero(iou > iou_thres).flatten()
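A small self-contained check of `bbox_iou` on toy tensors (coordinates are arbitrary and kept away from the 20-pixel border-snapping threshold used by `adjust_bboxes_to_image_border`):
```python
import torch

from ultralytics.models.fastsam.utils import bbox_iou

box1 = torch.tensor([100.0, 100.0, 300.0, 300.0])  # reference box, xyxy
boxes = torch.tensor([[102.0, 102.0, 298.0, 298.0],  # nearly identical -> IoU ~0.96
                      [400.0, 400.0, 500.0, 500.0]])  # disjoint -> IoU 0
print(bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640)))  # tensor([0])
print(bbox_iou(box1, boxes, image_shape=(640, 640), raw_output=True))  # raw IoU values, shape (2,)
```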


@ -0,0 +1,244 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from multiprocessing.pool import ThreadPool
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, NUM_THREADS, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou
from ultralytics.utils.plotting import output_to_target, plot_images
class FastSAMValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'segment'
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
def preprocess(self, batch):
"""Preprocesses batch by converting masks to float and sending to device."""
batch = super().preprocess(batch)
batch['masks'] = batch['masks'].to(self.device).float()
return batch
def init_metrics(self, model):
"""Initialize metrics and select mask processing function based on save_json flag."""
super().init_metrics(model)
self.plot_masks = []
if self.args.save_json:
check_requirements('pycocotools>=2.0.6')
self.process = ops.process_mask_upsample # more accurate
else:
self.process = ops.process_mask # faster
def get_desc(self):
"""Return a formatted description of evaluation metrics."""
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P',
'R', 'mAP50', 'mAP50-95)')
def postprocess(self, preds):
"""Postprocesses YOLO predictions and returns output detections with proto."""
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
nc=self.nc)
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
return p, proto
def update_metrics(self, preds, batch):
"""Metrics."""
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
idx = batch['batch_idx'] == si
cls = batch['cls'][idx]
bbox = batch['bboxes'][idx]
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
shape = batch['ori_shape'][si]
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
self.seen += 1
if npr == 0:
if nl:
self.stats.append((correct_bboxes, correct_masks, *torch.zeros(
(2, 0), device=self.device), cls.squeeze(-1)))
if self.args.plots:
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
continue
# Masks
midx = [si] if self.args.overlap_mask else idx
gt_masks = batch['masks'][midx]
pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch['img'][si].shape[1:])
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
predn = pred.clone()
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
ratio_pad=batch['ratio_pad'][si]) # native-space pred
# Evaluate
if nl:
height, width = batch['img'].shape[2:]
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
(width, height, width, height), device=self.device) # target boxes
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
ratio_pad=batch['ratio_pad'][si]) # native-space labels
labelsn = torch.cat((cls, tbox), 1) # native-space labels
correct_bboxes = self._process_batch(predn, labelsn)
                # TODO: maybe remove these `self.` arguments as they are already member variables
correct_masks = self._process_batch(predn,
labelsn,
pred_masks,
gt_masks,
overlap=self.args.overlap_mask,
masks=True)
if self.args.plots:
self.confusion_matrix.process_batch(predn, labelsn)
# Append correct_masks, correct_boxes, pconf, pcls, tcls
self.stats.append((correct_bboxes, correct_masks, pred[:, 4], pred[:, 5], cls.squeeze(-1)))
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
if self.args.plots and self.batch_i < 3:
self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
# Save
if self.args.save_json:
pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
shape,
ratio_pad=batch['ratio_pad'][si])
self.pred_to_json(predn, batch['im_file'][si], pred_masks)
# if self.args.save_txt:
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
def finalize_metrics(self, *args, **kwargs):
"""Sets speed and confusion matrix for evaluation metrics."""
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
"""
Return correct prediction matrix
Arguments:
detections (array[N, 6]), x1, y1, x2, y2, conf, class
labels (array[M, 5]), class, x1, y1, x2, y2
Returns:
correct (array[N, 10]), for 10 IoU levels
"""
if masks:
if overlap:
nl = len(labels)
index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
if gt_masks.shape[1:] != pred_masks.shape[1:]:
gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0]
gt_masks = gt_masks.gt_(0.5)
iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
else: # boxes
iou = box_iou(labels[:, 1:], detections[:, :4])
correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
correct_class = labels[:, 0:1] == detections[:, 5]
for i in range(len(self.iouv)):
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
if x[0].shape[0]:
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
1).cpu().numpy() # [label, detect, iou]
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
# matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
correct[matches[:, 1].astype(int), i] = True
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
def plot_val_samples(self, batch, ni):
"""Plots validation samples with bounding box labels."""
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
batch['masks'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots batch predictions with masks and bounding boxes."""
plot_images(
batch['img'],
*output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
self.plot_masks.clear()
def pred_to_json(self, predn, filename, pred_masks):
"""Save one JSON result."""
# Example result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
from pycocotools.mask import encode # noqa
def single_encode(x):
"""Encode predicted masks as RLE and append results to jdict."""
rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0]
rle['counts'] = rle['counts'].decode('utf-8')
return rle
stem = Path(filename).stem
image_id = int(stem) if stem.isnumeric() else stem
box = ops.xyxy2xywh(predn[:, :4]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
pred_masks = np.transpose(pred_masks, (2, 0, 1))
with ThreadPool(NUM_THREADS) as pool:
rles = pool.map(single_encode, pred_masks)
for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
self.jdict.append({
'image_id': image_id,
'category_id': self.class_map[int(p[5])],
'bbox': [round(x, 3) for x in b],
'score': round(p[4], 5),
'segmentation': rles[i]})
def eval_json(self, stats):
"""Return COCO-style object detection evaluation metrics."""
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements('pycocotools>=2.0.6')
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
assert x.is_file(), f'{x} file not found'
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]):
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
eval.evaluate()
eval.accumulate()
eval.summarize()
idx = i * 4 + 2
stats[self.metrics.keys[idx + 1]], stats[
self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
except Exception as e:
LOGGER.warning(f'pycocotools unable to run: {e}')
return stats
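For reference, a standalone sketch of the RLE encoding that `single_encode` in `pred_to_json` performs (requires `pycocotools`; the toy mask is arbitrary):
```python
import numpy as np
from pycocotools.mask import encode

mask = np.zeros((4, 6), dtype=np.uint8)  # toy binary mask, height=4, width=6
mask[1:3, 2:5] = 1
rle = encode(np.asarray(mask[:, :, None], order='F', dtype='uint8'))[0]  # single-mask RLE dict
rle['counts'] = rle['counts'].decode('utf-8')  # bytes -> str so it can be JSON-serialized
print(rle)  # {'size': [4, 6], 'counts': '...'}
```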


@ -0,0 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .model import NAS
from .predict import NASPredictor
from .val import NASValidator
__all__ = 'NASPredictor', 'NASValidator', 'NAS'


@ -0,0 +1,133 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
YOLO-NAS model interface.
Usage - Predict:
from ultralytics import NAS
model = NAS('yolo_nas_s')
results = model.predict('ultralytics/assets/bus.jpg')
"""
from pathlib import Path
import torch
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from .predict import NASPredictor
from .val import NASValidator
class NAS:
def __init__(self, model='yolo_nas_s.pt') -> None:
# Load or create new NAS model
import super_gradients
self.predictor = None
suffix = Path(model).suffix
if suffix == '.pt':
self._load(model)
elif suffix == '':
self.model = super_gradients.training.models.get(model, pretrained_weights='coco')
self.task = 'detect'
self.model.args = DEFAULT_CFG_DICT # attach args to model
# Standardize model
self.model.fuse = lambda verbose=True: self.model
self.model.stride = torch.tensor([32])
self.model.names = dict(enumerate(self.model._class_names))
self.model.is_fused = lambda: False # for info()
self.model.yaml = {} # for info()
self.model.pt_path = model # for export()
self.model.task = 'detect' # for export()
self.info()
@smart_inference_mode()
def _load(self, weights: str):
self.model = torch.load(weights)
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
overrides = dict(conf=0.25, task='detect', mode='predict')
overrides.update(kwargs) # prefer kwargs
if not self.predictor:
self.predictor = NASPredictor(overrides=overrides)
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source, stream=stream)
def train(self, **kwargs):
"""Function trains models but raises an error as NAS models do not support training."""
raise NotImplementedError("NAS models don't support training")
def val(self, **kwargs):
"""Run validation given dataset."""
overrides = dict(task='detect', mode='val')
overrides.update(kwargs) # prefer kwargs
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = NASValidator(args=args)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
@smart_inference_mode()
def export(self, **kwargs):
"""
Export model.
Args:
            **kwargs : Any other args accepted by the Exporter. To see all args check 'configuration' section in docs
"""
overrides = dict(task='detect')
overrides.update(kwargs)
overrides['mode'] = 'export'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
return Exporter(overrides=args)(model=self.model)
def info(self, detailed=False, verbose=True):
"""
Logs model info.
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
"""
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
def __call__(self, source=None, stream=False, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
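A minimal usage sketch for the interface above; it assumes the `super-gradients` package is installed and that `NAS` is re-exported at the package level as in the module docstring. Paths and argument values are placeholders:
```python
from ultralytics import NAS

model = NAS('yolo_nas_s')  # build from COCO-pretrained super-gradients weights
results = model.predict('path/to/image.jpg', conf=0.4)  # run inference
# model.val(data='coco128.yaml')  # evaluate on a detection dataset
# model.export(format='onnx')  # export through the standard Exporter
```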


@ -0,0 +1,35 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import ops
from ultralytics.utils.ops import xyxy2xywh
class NASPredictor(BasePredictor):
def postprocess(self, preds_in, img, orig_imgs):
"""Postprocesses predictions and returns a list of Results objects."""
# Cat boxes and class scores
boxes = xyxy2xywh(preds_in[0][0])
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
preds = ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
classes=self.args.classes)
results = []
for i, pred in enumerate(preds):
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
if not isinstance(orig_imgs, torch.Tensor):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
path = self.batch[0]
img_path = path[i] if isinstance(path, list) else path
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
return results


@ -0,0 +1,25 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import ops
from ultralytics.utils.ops import xyxy2xywh
__all__ = ['NASValidator']
class NASValidator(DetectionValidator):
def postprocess(self, preds_in):
"""Apply Non-maximum suppression to prediction outputs."""
boxes = xyxy2xywh(preds_in[0][0])
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
return ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=False,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
max_time_img=0.5)


@ -1,50 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
# [depth, width, max_channels]
l: [1.00, 1.00, 1024]
backbone:
# [from, repeats, module, args]
- [-1, 1, HGStem, [32, 48]] # 0-P2/4
- [-1, 6, HGBlock, [48, 128, 3]] # stage 1
- [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
- [-1, 6, HGBlock, [96, 512, 3]] # stage 2
- [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16
- [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut
- [-1, 6, HGBlock, [192, 1024, 5, True, True]]
- [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3
- [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32
- [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4
head:
- [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2
- [-1, 1, AIFI, [1024, 8]]
- [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1
- [[-2, -1], 1, Concat, [1]]
- [-1, 3, RepC3, [256]] # 16, fpn_blocks.0
- [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0
- [[-2, -1], 1, Concat, [1]] # cat backbone P4
- [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1
- [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0
- [[-1, 17], 1, Concat, [1]] # cat Y4
- [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0
- [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1
- [[-1, 12], 1, Concat, [1]] # cat Y5
- [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1
- [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)


@ -1,54 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-x object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
# [depth, width, max_channels]
x: [1.00, 1.00, 2048]
backbone:
# [from, repeats, module, args]
- [-1, 1, HGStem, [32, 64]] # 0-P2/4
- [-1, 6, HGBlock, [64, 128, 3]] # stage 1
- [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
- [-1, 6, HGBlock, [128, 512, 3]]
- [-1, 6, HGBlock, [128, 512, 3, False, True]] # 4-stage 2
- [-1, 1, DWConv, [512, 3, 2, 1, False]] # 5-P3/16
- [-1, 6, HGBlock, [256, 1024, 5, True, False]] # cm, c2, k, light, shortcut
- [-1, 6, HGBlock, [256, 1024, 5, True, True]]
- [-1, 6, HGBlock, [256, 1024, 5, True, True]]
- [-1, 6, HGBlock, [256, 1024, 5, True, True]]
- [-1, 6, HGBlock, [256, 1024, 5, True, True]] # 10-stage 3
- [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 11-P4/32
- [-1, 6, HGBlock, [512, 2048, 5, True, False]]
- [-1, 6, HGBlock, [512, 2048, 5, True, True]] # 13-stage 4
head:
- [-1, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 14 input_proj.2
- [-1, 1, AIFI, [2048, 8]]
- [-1, 1, Conv, [384, 1, 1]] # 16, Y5, lateral_convs.0
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [10, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 18 input_proj.1
- [[-2, -1], 1, Concat, [1]]
- [-1, 3, RepC3, [384]] # 20, fpn_blocks.0
- [-1, 1, Conv, [384, 1, 1]] # 21, Y4, lateral_convs.1
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [4, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 23 input_proj.0
- [[-2, -1], 1, Concat, [1]] # cat backbone P4
- [-1, 3, RepC3, [384]] # X3 (25), fpn_blocks.1
- [-1, 1, Conv, [384, 3, 2]] # 26, downsample_convs.0
- [[-1, 21], 1, Concat, [1]] # cat Y4
- [-1, 3, RepC3, [384]] # F4 (28), pan_blocks.0
- [-1, 1, Conv, [384, 3, 2]] # 29, downsample_convs.1
- [[-1, 16], 1, Concat, [1]] # cat Y5
- [-1, 3, RepC3, [384]] # F5 (31), pan_blocks.1
- [[25, 28, 31], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)


@ -0,0 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .model import RTDETR
from .predict import RTDETRPredictor
from .val import RTDETRValidator
__all__ = 'RTDETRPredictor', 'RTDETRValidator', 'RTDETR'


@ -0,0 +1,173 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
RT-DETR model interface
"""
from pathlib import Path
import torch.nn as nn
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
from ultralytics.nn.tasks import RTDETRDetectionModel, attempt_load_one_weight, yaml_model_load
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, ROOT, is_git_dir
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from .predict import RTDETRPredictor
from .train import RTDETRTrainer
from .val import RTDETRValidator
class RTDETR:
def __init__(self, model='rtdetr-l.pt') -> None:
if model and not model.endswith('.pt') and not model.endswith('.yaml'):
raise NotImplementedError('RT-DETR only supports creating from pt file or yaml file.')
        # Load or create new RT-DETR model
self.predictor = None
self.ckpt = None
suffix = Path(model).suffix
if suffix == '.yaml':
self._new(model)
else:
self._load(model)
def _new(self, cfg: str, verbose=True):
cfg_dict = yaml_model_load(cfg)
self.cfg = cfg
self.task = 'detect'
self.model = RTDETRDetectionModel(cfg_dict, verbose=verbose) # build model
# Below added to allow export from YAMLs
self.model.args = DEFAULT_CFG_DICT # attach args to model
self.model.task = self.task
@smart_inference_mode()
def _load(self, weights: str):
self.model, self.ckpt = attempt_load_one_weight(weights)
self.model.args = DEFAULT_CFG_DICT # attach args to model
self.task = self.model.args['task']
@smart_inference_mode()
def load(self, weights='yolov8n.pt'):
"""
Transfers parameters with matching names and shapes from 'weights' to model.
"""
if isinstance(weights, (str, Path)):
weights, self.ckpt = attempt_load_one_weight(weights)
self.model.load(weights)
return self
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
overrides = dict(conf=0.25, task='detect', mode='predict')
overrides.update(kwargs) # prefer kwargs
if not self.predictor:
self.predictor = RTDETRPredictor(overrides=overrides)
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source, stream=stream)
def train(self, **kwargs):
"""
Trains the model on a given dataset.
Args:
**kwargs (Any): Any number of arguments representing the training configuration.
"""
overrides = dict(task='detect', mode='train')
overrides.update(kwargs)
overrides['deterministic'] = False
if not overrides.get('data'):
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
if overrides.get('resume'):
overrides['resume'] = self.ckpt_path
self.task = overrides.get('task') or self.task
self.trainer = RTDETRTrainer(overrides=overrides)
if not overrides.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
self.trainer.train()
# Update model and cfg after training
if RANK in (-1, 0):
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
self.overrides = self.model.args
self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
def val(self, **kwargs):
"""Run validation given dataset."""
overrides = dict(task='detect', mode='val')
overrides.update(kwargs) # prefer kwargs
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = RTDETRValidator(args=args)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
def info(self, verbose=True):
"""Get model info"""
return model_info(self.model, verbose=verbose)
def _check_is_pytorch_model(self):
"""
        Raises TypeError if the model is not a PyTorch model.
"""
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
pt_module = isinstance(self.model, nn.Module)
if not (pt_module or pt_str):
raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
f'PyTorch models can be used to train, val, predict and export, i.e. '
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
def fuse(self):
"""Fuse PyTorch Conv2d and BatchNorm2d layers."""
self._check_is_pytorch_model()
self.model.fuse()
@smart_inference_mode()
def export(self, **kwargs):
"""
Export model.
Args:
            **kwargs : Any other args accepted by the Exporter. To see all args check 'configuration' section in docs
"""
overrides = dict(task='detect')
overrides.update(kwargs)
overrides['mode'] = 'export'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
return Exporter(overrides=args)(model=self.model)
def __call__(self, source=None, stream=False, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
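A minimal usage sketch for the class above, assuming `RTDETR` is re-exported at the package level (as in the `models/__init__.py` added by this commit); paths and hyperparameters are placeholders:
```python
from ultralytics import RTDETR

model = RTDETR('rtdetr-l.pt')  # or 'rtdetr-l.yaml' to build the architecture from config
model.info()
results = model.predict('path/to/image.jpg', conf=0.5)
# model.train(data='coco128.yaml', epochs=100, imgsz=640)  # training requires a dataset YAML
# model.val(data='coco128.yaml')
```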


@ -0,0 +1,44 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.data.augment import LetterBox
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import ops
class RTDETRPredictor(BasePredictor):
def postprocess(self, preds, img, orig_imgs):
"""Postprocess predictions and returns a list of Results objects."""
nd = preds[0].shape[-1]
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
results = []
for i, bbox in enumerate(bboxes): # (300, 4)
bbox = ops.xywh2xyxy(bbox)
score, cls = scores[i].max(-1, keepdim=True) # (300, 1)
idx = score.squeeze(-1) > self.args.conf # (300, )
if self.args.classes is not None:
idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
pred = torch.cat([bbox, score, cls], dim=-1)[idx] # filter
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
oh, ow = orig_img.shape[:2]
if not isinstance(orig_imgs, torch.Tensor):
pred[..., [0, 2]] *= ow
pred[..., [1, 3]] *= oh
path = self.batch[0]
img_path = path[i] if isinstance(path, list) else path
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
return results
def pre_transform(self, im):
"""Pre-transform input image before inference.
Args:
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
        Returns: A list of transformed images.
"""
        # The input size must be square (640) and scale-filled (stretched, no letterbox padding).
return [LetterBox(self.imgsz, auto=False, scaleFill=True)(image=x) for x in im]


@ -0,0 +1,80 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from copy import copy
import torch
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import RTDETRDetectionModel
from ultralytics.utils import DEFAULT_CFG, RANK, colorstr
from .val import RTDETRDataset, RTDETRValidator
class RTDETRTrainer(DetectionTrainer):
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return a YOLO detection model."""
model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def build_dataset(self, img_path, mode='val', batch=None):
"""Build RTDETR Dataset
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
return RTDETRDataset(
img_path=img_path,
imgsz=self.args.imgsz,
batch_size=batch,
            augment=mode == 'train',  # augment only in train mode
hyp=self.args,
rect=False, # no rect
cache=self.args.cache or None,
prefix=colorstr(f'{mode}: '),
data=self.data)
def get_validator(self):
"""Returns a DetectionValidator for RTDETR model validation."""
self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def preprocess_batch(self, batch):
"""Preprocesses a batch of images by scaling and converting to float."""
batch = super().preprocess_batch(batch)
bs = len(batch['img'])
batch_idx = batch['batch_idx']
gt_bbox, gt_class = [], []
for i in range(bs):
gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
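        # NOTE: gt_bbox and gt_class are assembled per-image here but are not used further in this method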
return batch
def train(cfg=DEFAULT_CFG, use_python=False):
"""Train and optimize RTDETR model given training data and device."""
model = 'rtdetr-l.yaml'
data = cfg.data or 'coco128.yaml' # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else ''
# NOTE: F.grid_sample which is in rt-detr does not support deterministic=True
# NOTE: amp training causes nan outputs and end with error while doing bipartite graph matching
args = dict(model=model,
data=data,
device=device,
imgsz=640,
exist_ok=True,
batch=4,
deterministic=False,
amp=False)
trainer = RTDETRTrainer(overrides=args)
trainer.train()
if __name__ == '__main__':
train()


@ -0,0 +1,151 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from pathlib import Path
import cv2
import numpy as np
import torch
from ultralytics.data import YOLODataset
from ultralytics.data.augment import Compose, Format, v8_transforms
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import colorstr, ops
__all__ = 'RTDETRValidator', # tuple or list
# TODO: Temporarily, RT-DETR does not need padding.
class RTDETRDataset(YOLODataset):
def __init__(self, *args, data=None, **kwargs):
super().__init__(*args, data=data, use_segments=False, use_keypoints=False, **kwargs)
# NOTE: add stretch version load_image for rtdetr mosaic
def load_image(self, i):
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
if im is None: # not cached in RAM
if fn.exists(): # load npy
im = np.load(fn)
else: # read image
im = cv2.imread(f) # BGR
if im is None:
raise FileNotFoundError(f'Image Not Found {f}')
h0, w0 = im.shape[:2] # orig hw
im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR)
# Add to buffer if training with augmentations
if self.augment:
self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
self.buffer.append(i)
if len(self.buffer) >= self.max_buffer_length:
j = self.buffer.pop(0)
self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None
return im, (h0, w0), im.shape[:2]
return self.ims[i], self.im_hw0[i], self.im_hw[i]
def build_transforms(self, hyp=None):
"""Temporarily, only for evaluation."""
if self.augment:
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
else:
# transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)])
transforms = Compose([])
transforms.append(
Format(bbox_format='xywh',
normalize=True,
return_mask=self.use_segments,
return_keypoint=self.use_keypoints,
batch_idx=True,
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask))
return transforms
class RTDETRValidator(DetectionValidator):
def build_dataset(self, img_path, mode='val', batch=None):
"""Build YOLO Dataset
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
return RTDETRDataset(
img_path=img_path,
imgsz=self.args.imgsz,
batch_size=batch,
augment=False, # no augmentation
hyp=self.args,
rect=False, # no rect
cache=self.args.cache or None,
prefix=colorstr(f'{mode}: '),
data=self.data)
def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
bs, _, nd = preds[0].shape
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
bboxes *= self.args.imgsz
outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
for i, bbox in enumerate(bboxes): # (300, 4)
bbox = ops.xywh2xyxy(bbox)
score, cls = scores[i].max(-1) # (300, )
            # No confidence threshold is needed for evaluation; the model returns a fixed set of 300 boxes
            # idx = score > self.args.conf
            pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1)  # (300, 6)
            # Sort by confidence so downstream metrics are computed correctly
            pred = pred[score.argsort(descending=True)]
            outputs[i] = pred  # [idx]
return outputs
def update_metrics(self, preds, batch):
"""Metrics."""
for si, pred in enumerate(preds):
idx = batch['batch_idx'] == si
cls = batch['cls'][idx]
bbox = batch['bboxes'][idx]
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
shape = batch['ori_shape'][si]
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
self.seen += 1
if npr == 0:
if nl:
self.stats.append((correct_bboxes, *torch.zeros((2, 0), device=self.device), cls.squeeze(-1)))
if self.args.plots:
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
continue
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
predn = pred.clone()
predn[..., [0, 2]] *= shape[1] / self.args.imgsz # native-space pred
predn[..., [1, 3]] *= shape[0] / self.args.imgsz # native-space pred
# Evaluate
if nl:
tbox = ops.xywh2xyxy(bbox) # target boxes
                tbox[..., [0, 2]] *= shape[1]  # native-space labels
                tbox[..., [1, 3]] *= shape[0]  # native-space labels
labelsn = torch.cat((cls, tbox), 1) # native-space labels
# NOTE: To get correct metrics, the inputs of `_process_batch` should always be float32 type.
correct_bboxes = self._process_batch(predn.float(), labelsn)
                # TODO: maybe remove these `self.` arguments as they are already member variables
if self.args.plots:
self.confusion_matrix.process_batch(predn, labelsn)
self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], cls.squeeze(-1))) # (conf, pcls, tcls)
# Save
if self.args.save_json:
self.pred_to_json(predn, batch['im_file'][si])
if self.args.save_txt:
file = self.save_dir / 'labels' / f'{Path(batch["im_file"][si]).stem}.txt'
self.save_one_txt(predn, self.args.save_conf, shape, file)
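
The key step in `update_metrics()` above is mapping predictions from the square `imgsz` input frame back to each image's native resolution. A torch-only sketch with illustrative values:

```python
import torch

imgsz = 640
ori_shape = (480, 720)                                    # original image (h, w)
pred = torch.tensor([[320., 320., 480., 480., 0.9, 0.]])  # xyxy, conf, cls in imgsz space

predn = pred.clone()
predn[..., [0, 2]] *= ori_shape[1] / imgsz                # rescale x coordinates
predn[..., [1, 3]] *= ori_shape[0] / imgsz                # rescale y coordinates
print(predn[..., :4])                                     # boxes in native image coordinates
```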

View File

@ -0,0 +1,8 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .model import SAM
from .predict import Predictor
# from .build import build_sam
__all__ = 'SAM', 'Predictor' # tuple or list

View File

@ -0,0 +1,311 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import math
from copy import deepcopy
from itertools import product
from typing import Any, Dict, Generator, ItemsView, List, Tuple
import numpy as np
import torch
class MaskData:
"""
A structure for storing masks and their related data in batched format.
Implements basic filtering and concatenation.
"""
def __init__(self, **kwargs) -> None:
"""Initialize a MaskData object, ensuring all values are supported types."""
for v in kwargs.values():
assert isinstance(
v, (list, np.ndarray, torch.Tensor)), 'MaskData only supports list, numpy arrays, and torch tensors.'
self._stats = dict(**kwargs)
def __setitem__(self, key: str, item: Any) -> None:
"""Set an item in the MaskData object, ensuring it is a supported type."""
assert isinstance(
item, (list, np.ndarray, torch.Tensor)), 'MaskData only supports list, numpy arrays, and torch tensors.'
self._stats[key] = item
def __delitem__(self, key: str) -> None:
"""Delete an item from the MaskData object."""
del self._stats[key]
def __getitem__(self, key: str) -> Any:
"""Get an item from the MaskData object."""
return self._stats[key]
def items(self) -> ItemsView[str, Any]:
"""Return an ItemsView of the MaskData object."""
return self._stats.items()
def filter(self, keep: torch.Tensor) -> None:
"""Filter the MaskData object based on the given boolean tensor."""
for k, v in self._stats.items():
if v is None:
self._stats[k] = None
elif isinstance(v, torch.Tensor):
self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
elif isinstance(v, np.ndarray):
self._stats[k] = v[keep.detach().cpu().numpy()]
elif isinstance(v, list) and keep.dtype == torch.bool:
self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
elif isinstance(v, list):
self._stats[k] = [v[i] for i in keep]
else:
raise TypeError(f'MaskData key {k} has an unsupported type {type(v)}.')
def cat(self, new_stats: 'MaskData') -> None:
"""Concatenate a new MaskData object to the current one."""
for k, v in new_stats.items():
if k not in self._stats or self._stats[k] is None:
self._stats[k] = deepcopy(v)
elif isinstance(v, torch.Tensor):
self._stats[k] = torch.cat([self._stats[k], v], dim=0)
elif isinstance(v, np.ndarray):
self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
elif isinstance(v, list):
self._stats[k] = self._stats[k] + deepcopy(v)
else:
raise TypeError(f'MaskData key {k} has an unsupported type {type(v)}.')
def to_numpy(self) -> None:
"""Convert all torch tensors in the MaskData object to numpy arrays."""
for k, v in self._stats.items():
if isinstance(v, torch.Tensor):
self._stats[k] = v.detach().cpu().numpy()
def is_box_near_crop_edge(boxes: torch.Tensor,
crop_box: List[int],
orig_box: List[int],
atol: float = 20.0) -> torch.Tensor:
"""Return a boolean tensor indicating if boxes are near the crop edge."""
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
return torch.any(near_crop_edge, dim=1)
def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
"""Convert bounding boxes from XYXY format to XYWH format."""
box_xywh = deepcopy(box_xyxy)
box_xywh[2] = box_xywh[2] - box_xywh[0]
box_xywh[3] = box_xywh[3] - box_xywh[1]
return box_xywh
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
"""Yield batches of data from the input arguments."""
assert args and all(len(a) == len(args[0]) for a in args), 'Batched iteration must have same-size inputs.'
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
for b in range(n_batches):
yield [arg[b * batch_size:(b + 1) * batch_size] for arg in args]
def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
"""Encode masks as uncompressed RLEs in the format expected by pycocotools."""
# Put in fortran order and flatten h,w
b, h, w = tensor.shape
tensor = tensor.permute(0, 2, 1).flatten(1)
# Compute change indices
diff = tensor[:, 1:] ^ tensor[:, :-1]
change_indices = diff.nonzero()
# Encode run length
out = []
for i in range(b):
cur_idxs = change_indices[change_indices[:, 0] == i, 1]
cur_idxs = torch.cat([
torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
cur_idxs + 1,
torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), ])
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
counts = [] if tensor[i, 0] == 0 else [0]
counts.extend(btw_idxs.detach().cpu().tolist())
out.append({'size': [h, w], 'counts': counts})
return out
def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
"""Compute a binary mask from an uncompressed RLE."""
h, w = rle['size']
mask = np.empty(h * w, dtype=bool)
idx = 0
parity = False
for count in rle['counts']:
mask[idx:idx + count] = parity
idx += count
parity ^= True
mask = mask.reshape(w, h)
return mask.transpose() # Put in C order
def area_from_rle(rle: Dict[str, Any]) -> int:
"""Calculate the area of a mask from its uncompressed RLE."""
return sum(rle['counts'][1::2])
def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
"""
Computes the stability score for a batch of masks. The stability
score is the IoU between the binary masks obtained by thresholding
the predicted mask logits at high and low values.
"""
# One mask is always contained inside the other.
# Save memory by preventing unnecessary cast to torch.int64
intersections = ((masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1,
dtype=torch.int32))
unions = ((masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32))
return intersections / unions
def build_point_grid(n_per_side: int) -> np.ndarray:
"""Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1]."""
offset = 1 / (2 * n_per_side)
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
return np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
"""Generate point grids for all crop layers."""
return [build_point_grid(int(n_per_side / (scale_per_layer ** i))) for i in range(n_layers + 1)]
def generate_crop_boxes(im_size: Tuple[int, ...], n_layers: int,
overlap_ratio: float) -> Tuple[List[List[int]], List[int]]:
"""Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer."""
crop_boxes, layer_idxs = [], []
im_h, im_w = im_size
short_side = min(im_h, im_w)
# Original image
crop_boxes.append([0, 0, im_w, im_h])
layer_idxs.append(0)
def crop_len(orig_len, n_crops, overlap):
"""Crops bounding boxes to the size of the input image."""
return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
for i_layer in range(n_layers):
n_crops_per_side = 2 ** (i_layer + 1)
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
crop_w = crop_len(im_w, n_crops_per_side, overlap)
crop_h = crop_len(im_h, n_crops_per_side, overlap)
crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
        # Crops in XYXY format
for x0, y0 in product(crop_box_x0, crop_box_y0):
box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
crop_boxes.append(box)
layer_idxs.append(i_layer + 1)
return crop_boxes, layer_idxs
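
A quick sketch of the crop pyramid produced by `generate_crop_boxes()` above, assuming the function is in scope; the image size and overlap ratio are illustrative:

```python
crop_boxes, layer_idxs = generate_crop_boxes(im_size=(1024, 1536), n_layers=1, overlap_ratio=512 / 1500)
print(len(crop_boxes), layer_idxs)  # 5 [0, 1, 1, 1, 1]: the full image plus a 2x2 grid of crops
print(crop_boxes[0])                # [0, 0, 1536, 1024], i.e. the original image in XYXY
```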
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
"""Uncrop bounding boxes by adding the crop box offset."""
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
# Check if boxes has a channel dimension
if len(boxes.shape) == 3:
offset = offset.unsqueeze(1)
return boxes + offset
def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
"""Uncrop points by adding the crop box offset."""
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0]], device=points.device)
# Check if points has a channel dimension
if len(points.shape) == 3:
offset = offset.unsqueeze(1)
return points + offset
def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
"""Uncrop masks by padding them to the original image size."""
x0, y0, x1, y1 = crop_box
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
return masks
# Coordinate transform masks
pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
pad = (x0, pad_x - x0, y0, pad_y - y0)
return torch.nn.functional.pad(masks, pad, value=0)
def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
"""Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator."""
import cv2 # type: ignore
assert mode in {'holes', 'islands'}
correct_holes = mode == 'holes'
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if not small_regions:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
# If every region is below threshold, keep largest
fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True
def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
"""Encode uncompressed RLE (run-length encoding) to COCO RLE format."""
from pycocotools import mask as mask_utils # type: ignore
h, w = uncompressed_rle['size']
rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
rle['counts'] = rle['counts'].decode('utf-8') # Necessary to serialize with json
return rle
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
"""
    Calculates boxes in XYXY format around masks. Returns [0, 0, 0, 0] for
    an empty mask. For an input of shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
"""
# torch.max below raises an error on empty inputs, just skip in this case
if torch.numel(masks) == 0:
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
# Normalize shape to CxHxW
shape = masks.shape
h, w = shape[-2:]
masks = masks.flatten(0, -3) if len(shape) > 2 else masks.unsqueeze(0)
# Get top and bottom edges
in_height, _ = torch.max(masks, dim=-1)
in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
in_height_coords = in_height_coords + h * (~in_height)
top_edges, _ = torch.min(in_height_coords, dim=-1)
# Get left and right edges
in_width, _ = torch.max(masks, dim=-2)
in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
right_edges, _ = torch.max(in_width_coords, dim=-1)
in_width_coords = in_width_coords + w * (~in_width)
left_edges, _ = torch.min(in_width_coords, dim=-1)
# If the mask is empty the right edge will be to the left of the left edge.
# Replace these boxes with [0, 0, 0, 0]
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
out = out * (~empty_filter).unsqueeze(-1)
# Return to original shape
return out.reshape(*shape[:-2], 4) if len(shape) > 2 else out[0]
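
A small end-to-end sketch of the mask utilities above (RLE round trip, area and box extraction), assuming these helpers are importable from wherever this module lands in the package:

```python
import numpy as np
import torch

mask = torch.zeros(1, 8, 8, dtype=torch.bool)
mask[0, 2:5, 3:7] = True                          # a 3x4 rectangle of foreground

rles = mask_to_rle_pytorch(mask)                  # uncompressed RLE per mask
recovered = rle_to_mask(rles[0])                  # back to an HxW numpy bool array
assert np.array_equal(recovered, mask[0].numpy())

print(area_from_rle(rles[0]))                     # 12 foreground pixels
print(batched_mask_to_box(mask))                  # tensor([[3, 2, 6, 4]]) in XYXY
```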

View File

@ -0,0 +1,158 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
import torch
from ultralytics.utils.downloads import attempt_download_asset
from .modules.decoders import MaskDecoder
from .modules.encoders import ImageEncoderViT, PromptEncoder
from .modules.sam import Sam
from .modules.tiny_encoder import TinyViT
from .modules.transformer import TwoWayTransformer
def build_sam_vit_h(checkpoint=None):
"""Build and return a Segment Anything Model (SAM) h-size model."""
return _build_sam(
encoder_embed_dim=1280,
encoder_depth=32,
encoder_num_heads=16,
encoder_global_attn_indexes=[7, 15, 23, 31],
checkpoint=checkpoint,
)
def build_sam_vit_l(checkpoint=None):
"""Build and return a Segment Anything Model (SAM) l-size model."""
return _build_sam(
encoder_embed_dim=1024,
encoder_depth=24,
encoder_num_heads=16,
encoder_global_attn_indexes=[5, 11, 17, 23],
checkpoint=checkpoint,
)
def build_sam_vit_b(checkpoint=None):
"""Build and return a Segment Anything Model (SAM) b-size model."""
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
)
def build_mobile_sam(checkpoint=None):
"""Build and return Mobile Segment Anything Model (Mobile-SAM)."""
return _build_sam(
encoder_embed_dim=[64, 128, 160, 320],
encoder_depth=[2, 2, 6, 2],
encoder_num_heads=[2, 4, 5, 10],
encoder_global_attn_indexes=None,
mobile_sam=True,
checkpoint=checkpoint,
)
def _build_sam(encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
mobile_sam=False):
"""Builds the selected SAM model architecture."""
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
image_encoder = (TinyViT(
img_size=1024,
in_chans=3,
num_classes=1000,
embed_dims=encoder_embed_dim,
depths=encoder_depth,
num_heads=encoder_num_heads,
window_sizes=[7, 7, 14, 7],
mlp_ratio=4.0,
drop_rate=0.0,
drop_path_rate=0.0,
use_checkpoint=False,
mbconv_expand_ratio=4.0,
local_conv_size=3,
layer_lr_decay=0.8,
) if mobile_sam else ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
))
sam = Sam(
image_encoder=image_encoder,
prompt_encoder=PromptEncoder(
embed_dim=prompt_embed_dim,
image_embedding_size=(image_embedding_size, image_embedding_size),
input_image_size=(image_size, image_size),
mask_in_chans=16,
),
mask_decoder=MaskDecoder(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=prompt_embed_dim,
mlp_dim=2048,
num_heads=8,
),
transformer_dim=prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
),
pixel_mean=[123.675, 116.28, 103.53],
pixel_std=[58.395, 57.12, 57.375],
)
if checkpoint is not None:
checkpoint = attempt_download_asset(checkpoint)
with open(checkpoint, 'rb') as f:
state_dict = torch.load(f)
sam.load_state_dict(state_dict)
sam.eval()
# sam.load_state_dict(torch.load(checkpoint), strict=True)
# sam.eval()
return sam
sam_model_map = {
'sam_h.pt': build_sam_vit_h,
'sam_l.pt': build_sam_vit_l,
'sam_b.pt': build_sam_vit_b,
'mobile_sam.pt': build_mobile_sam, }
def build_sam(ckpt='sam_b.pt'):
"""Build a SAM model specified by ckpt."""
model_builder = None
for k in sam_model_map.keys():
if ckpt.endswith(k):
model_builder = sam_model_map.get(k)
if not model_builder:
        raise FileNotFoundError(f'{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}')
return model_builder(ckpt)
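
A minimal sketch of `build_sam()` above: the checkpoint name selects the builder, and `attempt_download_asset()` fetches the weights if they are not present locally.

```python
sam = build_sam('sam_b.pt')  # also 'sam_l.pt', 'sam_h.pt' or 'mobile_sam.pt'
print(type(sam).__name__, sum(p.numel() for p in sam.parameters()))  # model class name and parameter count
```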

View File

@ -0,0 +1,59 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
SAM model interface
"""
from ultralytics.cfg import get_cfg
from ultralytics.utils.torch_utils import model_info
from .build import build_sam
from .predict import Predictor
class SAM:
    """SAM (Segment Anything Model) interface for promptable segmentation."""
    def __init__(self, model='sam_b.pt') -> None:
        """Initialize SAM with a pre-trained *.pt or *.pth checkpoint."""
        if model and not model.endswith('.pt') and not model.endswith('.pth'):
            # Should raise AssertionError instead?
            raise NotImplementedError('Segment Anything prediction requires a pre-trained *.pt or *.pth checkpoint')
self.model = build_sam(model)
self.task = 'segment' # required
self.predictor = None # reuse predictor
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Predicts and returns segmentation masks for given image or video source."""
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
overrides.update(kwargs) # prefer kwargs
if not self.predictor:
self.predictor = Predictor(overrides=overrides)
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source, stream=stream, bboxes=bboxes, points=points, labels=labels)
def train(self, **kwargs):
"""Function trains models but raises an error as SAM models do not support training."""
raise NotImplementedError("SAM models don't support training")
def val(self, **kwargs):
"""Run validation given dataset."""
raise NotImplementedError("SAM models don't support validation")
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, bboxes, points, labels, **kwargs)
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def info(self, detailed=False, verbose=True):
"""
Logs model info.
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
"""
return model_info(self.model, detailed=detailed, verbose=verbose)
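
A minimal usage sketch of the `SAM` interface above, assuming it is re-exported at the top-level `ultralytics` package as part of this refactor; the box prompt values are illustrative.

```python
from ultralytics import SAM  # assumed top-level export

model = SAM('sam_b.pt')
model.info()                                          # log a model summary
results = model.predict('ultralytics/assets/bus.jpg',
                        bboxes=[439, 437, 524, 709])  # optional box prompt in pixel coordinates
```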

View File

@ -0,0 +1 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

View File

@ -0,0 +1,159 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from typing import List, Tuple, Type
import torch
from torch import nn
from torch.nn import functional as F
from ultralytics.nn.modules import LayerNorm2d
class MaskDecoder(nn.Module):
def __init__(
self,
*,
transformer_dim: int,
transformer: nn.Module,
num_multimask_outputs: int = 3,
activation: Type[nn.Module] = nn.GELU,
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
) -> None:
"""
Predicts masks given an image and prompt embeddings, using a transformer architecture.
Arguments:
transformer_dim (int): the channel dimension of the transformer module
transformer (nn.Module): the transformer used to predict masks
num_multimask_outputs (int): the number of masks to predict when disambiguating masks
activation (nn.Module): the type of activation to use when upscaling masks
iou_head_depth (int): the depth of the MLP used to predict mask quality
iou_head_hidden_dim (int): the hidden dimension of the MLP used to predict mask quality
"""
super().__init__()
self.transformer_dim = transformer_dim
self.transformer = transformer
self.num_multimask_outputs = num_multimask_outputs
self.iou_token = nn.Embedding(1, transformer_dim)
self.num_mask_tokens = num_multimask_outputs + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
activation(),
)
self.output_hypernetworks_mlps = nn.ModuleList([
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)])
self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Arguments:
image_embeddings (torch.Tensor): the embeddings from the image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
multimask_output (bool): Whether to return multiple masks or a single mask.
Returns:
torch.Tensor: batched predicted masks
torch.Tensor: batched predictions of mask quality
"""
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
)
# Select the correct mask or masks for output
mask_slice = slice(1, None) if multimask_output else slice(0, 1)
masks = masks[:, mask_slice, :, :]
iou_pred = iou_pred[:, mask_slice]
# Prepare output
return masks, iou_pred
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts masks. See 'forward' for more details."""
# Concatenate output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + dense_prompt_embeddings
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
hyper_in_list: List[torch.Tensor] = [
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)]
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
class MLP(nn.Module):
"""
Lightly adapted from
https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
"""
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_layers: int,
sigmoid_output: bool = False,
) -> None:
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
self.sigmoid_output = sigmoid_output
def forward(self, x):
"""Executes feedforward within the neural network module and applies activation."""
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
if self.sigmoid_output:
x = torch.sigmoid(x)
return x
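
A minimal sketch of the `MLP` head above, mirroring the hypernetwork MLPs that `MaskDecoder` builds (`transformer_dim` to `transformer_dim // 8` over three layers):

```python
import torch

mlp = MLP(input_dim=256, hidden_dim=256, output_dim=32, num_layers=3)
tokens = torch.randn(2, 256)   # (batch, transformer_dim)
print(mlp(tokens).shape)       # torch.Size([2, 32])
```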

View File

@ -0,0 +1,583 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from typing import Any, Optional, Tuple, Type
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.nn.modules import LayerNorm2d, MLPBlock
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
"""
super().__init__()
self.img_size = img_size
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: Optional[nn.Parameter] = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_chans),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
for blk in self.blocks:
x = blk(x)
x = self.neck(x.permute(0, 3, 1, 2))
return x
class PromptEncoder(nn.Module):
def __init__(
self,
embed_dim: int,
image_embedding_size: Tuple[int, int],
input_image_size: Tuple[int, int],
mask_in_chans: int,
activation: Type[nn.Module] = nn.GELU,
) -> None:
"""
Encodes prompts for input to SAM's mask decoder.
Arguments:
embed_dim (int): The prompts' embedding dimension
image_embedding_size (tuple(int, int)): The spatial size of the
image embedding, as (H, W).
          input_image_size (tuple(int, int)): The padded size of the image as input
            to the image encoder, as (H, W).
mask_in_chans (int): The number of hidden channels used for
encoding input masks.
activation (nn.Module): The activation to use when encoding
input masks.
"""
super().__init__()
self.embed_dim = embed_dim
self.input_image_size = input_image_size
self.image_embedding_size = image_embedding_size
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
self.point_embeddings = nn.ModuleList(point_embeddings)
self.not_a_point_embed = nn.Embedding(1, embed_dim)
self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
self.mask_downscaling = nn.Sequential(
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans // 4),
activation(),
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans),
activation(),
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
)
self.no_mask_embed = nn.Embedding(1, embed_dim)
def get_dense_pe(self) -> torch.Tensor:
"""
Returns the positional encoding used to encode point prompts,
applied to a dense set of points the shape of the image encoding.
Returns:
torch.Tensor: Positional encoding with shape
1x(embed_dim)x(embedding_h)x(embedding_w)
"""
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
def _embed_points(
self,
points: torch.Tensor,
labels: torch.Tensor,
pad: bool,
) -> torch.Tensor:
"""Embeds point prompts."""
points = points + 0.5 # Shift to center of pixel
if pad:
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1)
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight
point_embedding[labels == 1] += self.point_embeddings[1].weight
return point_embedding
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel
coords = boxes.reshape(-1, 2, 2)
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
return corner_embedding
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
"""Embeds mask inputs."""
return self.mask_downscaling(masks)
def _get_batch_size(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> int:
"""
Gets the batch size of the output given the batch size of the input prompts.
"""
if points is not None:
return points[0].shape[0]
elif boxes is not None:
return boxes.shape[0]
elif masks is not None:
return masks.shape[0]
else:
return 1
def _get_device(self) -> torch.device:
return self.point_embeddings[0].weight.device
def forward(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Embeds different types of prompts, returning both sparse and dense
embeddings.
Arguments:
points (tuple(torch.Tensor, torch.Tensor), None): point coordinates
and labels to embed.
boxes (torch.Tensor, None): boxes to embed
masks (torch.Tensor, None): masks to embed
Returns:
torch.Tensor: sparse embeddings for the points and boxes, with shape
BxNx(embed_dim), where N is determined by the number of input points
and boxes.
torch.Tensor: dense embeddings for the masks, in the shape
Bx(embed_dim)x(embed_H)x(embed_W)
"""
bs = self._get_batch_size(points, boxes, masks)
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
if points is not None:
coords, labels = points
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
if boxes is not None:
box_embeddings = self._embed_boxes(boxes)
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
if masks is not None:
dense_embeddings = self._embed_masks(masks)
else:
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1,
1).expand(bs, -1, self.image_embedding_size[0],
self.image_embedding_size[1])
return sparse_embeddings, dense_embeddings
class PositionEmbeddingRandom(nn.Module):
"""
Positional encoding using random spatial frequencies.
"""
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
super().__init__()
if scale is None or scale <= 0.0:
scale = 1.0
self.register_buffer(
'positional_encoding_gaussian_matrix',
scale * torch.randn((2, num_pos_feats)),
)
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
"""Positionally encode points that are normalized to [0,1]."""
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coords = 2 * coords - 1
coords = coords @ self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
# outputs d_1 x ... x d_n x C shape
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
"""Generate positional encoding for a grid of the specified size."""
h, w = size
device: Any = self.positional_encoding_gaussian_matrix.device
grid = torch.ones((h, w), device=device, dtype=torch.float32)
y_embed = grid.cumsum(dim=0) - 0.5
x_embed = grid.cumsum(dim=1) - 0.5
y_embed = y_embed / h
x_embed = x_embed / w
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
return pe.permute(2, 0, 1) # C x H x W
def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
"""Positionally encode points that are not normalized to [0,1]."""
coords = coords_input.clone()
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
return self._pe_encoding(coords.to(torch.float)) # B x N x C
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
input_size (tuple(int, int), None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.window_size = window_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (tuple(int, int), None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert (input_size is not None), 'Input size must be provided if using relative positional encoding.'
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int],
hw: Tuple[int, int]) -> torch.Tensor:
"""
    Reverse window partition back into original sequences and remove padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode='linear',
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh)
rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw)
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
B, q_h * q_w, k_h * k_w)
return attn
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
padding: Tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
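
A quick round-trip sketch of the windowing helpers above: partitioning a (B, H, W, C) feature map into 14x14 windows and then reversing the operation recovers the input exactly.

```python
import torch

x = torch.randn(1, 64, 64, 768)                       # (B, H, W, C), as used inside Block
windows, pad_hw = window_partition(x, window_size=14)
print(windows.shape)                                   # (25, 14, 14, 768): a padded 5x5 grid of windows
restored = window_unpartition(windows, 14, pad_hw, (64, 64))
assert torch.equal(restored, x)
```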

View File

@ -0,0 +1,173 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Tuple
import torch
from torch import nn
from torch.nn import functional as F
from .decoders import MaskDecoder
from .encoders import ImageEncoderViT, PromptEncoder
class Sam(nn.Module):
mask_threshold: float = 0.0
image_format: str = 'RGB'
def __init__(self,
image_encoder: ImageEncoderViT,
prompt_encoder: PromptEncoder,
mask_decoder: MaskDecoder,
pixel_mean: List[float] = None,
pixel_std: List[float] = None) -> None:
"""
SAM predicts object masks from an image and input prompts.
Arguments:
image_encoder (ImageEncoderViT): The backbone used to encode the
image into image embeddings that allow for efficient mask prediction.
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
mask_decoder (MaskDecoder): Predicts masks from the image embeddings
and encoded prompts.
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
pixel_std (list(float)): Std values for normalizing pixels in the input image.
"""
if pixel_mean is None:
pixel_mean = [123.675, 116.28, 103.53]
if pixel_std is None:
pixel_std = [58.395, 57.12, 57.375]
super().__init__()
self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder
self.mask_decoder = mask_decoder
self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False)
@property
def device(self) -> Any:
return self.pixel_mean.device
@torch.no_grad()
def forward(
self,
batched_input: List[Dict[str, Any]],
multimask_output: bool,
) -> List[Dict[str, torch.Tensor]]:
"""
Predicts masks end-to-end from provided images and prompts.
If prompts are not known in advance, using SamPredictor is
recommended over calling the model directly.
Arguments:
batched_input (list(dict)): A list over input images, each a
dictionary with the following keys. A prompt key can be
excluded if it is not present.
'image': The image as a torch tensor in 3xHxW format,
already transformed for input to the model.
'original_size': (tuple(int, int)) The original size of
the image before transformation, as (H, W).
'point_coords': (torch.Tensor) Batched point prompts for
this image, with shape BxNx2. Already transformed to the
input frame of the model.
'point_labels': (torch.Tensor) Batched labels for point prompts,
with shape BxN.
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
Already transformed to the input frame of the model.
'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
in the form Bx1xHxW.
multimask_output (bool): Whether the model should predict multiple
disambiguating masks, or return a single mask.
Returns:
(list(dict)): A list over input images, where each element is
as dictionary with the following keys.
'masks': (torch.Tensor) Batched binary mask predictions,
with shape BxCxHxW, where B is the number of input prompts,
C is determined by multimask_output, and (H, W) is the
original size of the image.
'iou_predictions': (torch.Tensor) The model's predictions
of mask quality, in shape BxC.
'low_res_logits': (torch.Tensor) Low resolution logits with
shape BxCxHxW, where H=W=256. Can be passed as mask input
to subsequent iterations of prediction.
"""
input_images = torch.stack([self.preprocess(x['image']) for x in batched_input], dim=0)
image_embeddings = self.image_encoder(input_images)
outputs = []
for image_record, curr_embedding in zip(batched_input, image_embeddings):
if 'point_coords' in image_record:
points = (image_record['point_coords'], image_record['point_labels'])
else:
points = None
sparse_embeddings, dense_embeddings = self.prompt_encoder(
points=points,
boxes=image_record.get('boxes', None),
masks=image_record.get('mask_inputs', None),
)
low_res_masks, iou_predictions = self.mask_decoder(
image_embeddings=curr_embedding.unsqueeze(0),
image_pe=self.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
masks = self.postprocess_masks(
low_res_masks,
input_size=image_record['image'].shape[-2:],
original_size=image_record['original_size'],
)
masks = masks > self.mask_threshold
outputs.append({
'masks': masks,
'iou_predictions': iou_predictions,
'low_res_logits': low_res_masks, })
return outputs
def postprocess_masks(
self,
masks: torch.Tensor,
input_size: Tuple[int, ...],
original_size: Tuple[int, ...],
) -> torch.Tensor:
"""
Remove padding and upscale masks to the original image size.
Arguments:
masks (torch.Tensor): Batched masks from the mask_decoder,
in BxCxHxW format.
input_size (tuple(int, int)): The size of the image input to the
model, in (H, W) format. Used to remove padding.
original_size (tuple(int, int)): The original size of the image
before resizing for input to the model, in (H, W) format.
Returns:
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
is given by original_size.
"""
masks = F.interpolate(
masks,
(self.image_encoder.img_size, self.image_encoder.img_size),
mode='bilinear',
align_corners=False,
)
masks = masks[..., :input_size[0], :input_size[1]]
masks = F.interpolate(masks, original_size, mode='bilinear', align_corners=False)
return masks
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
"""Normalize pixel values and pad to a square input."""
# Normalize colors
x = (x - self.pixel_mean) / self.pixel_std
# Pad
h, w = x.shape[-2:]
padh = self.image_encoder.img_size - h
padw = self.image_encoder.img_size - w
return F.pad(x, (0, padw, 0, padh))
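
A minimal sketch of a prompted forward pass through `Sam`, following the `batched_input` schema documented in `forward()` above. The builder call, the random image and the box values are illustrative; real images must already be resized to the encoder's 1024-pixel input frame, and running the full ViT encoder is best done on a GPU.

```python
import torch

sam = build_sam_vit_b()                                 # randomly initialised weights (see the build script earlier)
image = torch.randn(3, 768, 1024)                       # already-transformed 3xHxW image (long side = 1024)
batched_input = [{
    'image': image,
    'original_size': (1536, 2048),                      # (H, W) before resizing
    'boxes': torch.tensor([[100., 150., 500., 600.]]),  # Bx4 box prompts in the input frame
}]
outputs = sam(batched_input, multimask_output=True)
print(outputs[0]['masks'].shape)                        # torch.Size([1, 3, 1536, 2048]), boolean masks
```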

View File

@ -0,0 +1,653 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# --------------------------------------------------------
# TinyViT Model Architecture
# Copyright (c) 2022 Microsoft
# Adapted from LeViT and Swin Transformer
# LeViT: (https://github.com/facebookresearch/levit)
# Swin: (https://github.com/microsoft/swin-transformer)
# Build the TinyViT Model
# --------------------------------------------------------
import itertools
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from ultralytics.utils.instance import to_2tuple
class Conv2d_BN(torch.nn.Sequential):
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
super().__init__()
self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = torch.nn.BatchNorm2d(b)
torch.nn.init.constant_(bn.weight, bn_weight_init)
torch.nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
@torch.no_grad()
def fuse(self):
c, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / \
(bn.running_var + bn.eps)**0.5
m = torch.nn.Conv2d(w.size(1) * self.c.groups,
w.size(0),
w.shape[2:],
stride=self.c.stride,
padding=self.c.padding,
dilation=self.c.dilation,
groups=self.c.groups)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
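
A quick check of the `fuse()` method above: in eval mode (where BatchNorm uses its running statistics) the fused convolution reproduces the Conv+BN pair.

```python
import torch

m = Conv2d_BN(16, 32, ks=3, stride=1, pad=1).eval()
x = torch.randn(2, 16, 8, 8)
fused = m.fuse()                                  # a single Conv2d with BN folded into weight and bias
assert torch.allclose(m(x), fused(x), atol=1e-5)
```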
# NOTE: This module and timm package is needed only for training.
# from ultralytics.utils.checks import check_requirements
# check_requirements('timm')
# from timm.models.layers import DropPath as TimmDropPath
# from timm.models.layers import trunc_normal_
# class DropPath(TimmDropPath):
#
# def __init__(self, drop_prob=None):
# super().__init__(drop_prob=drop_prob)
# self.drop_prob = drop_prob
#
# def __repr__(self):
# msg = super().__repr__()
# msg += f'(drop_prob={self.drop_prob})'
# return msg
class PatchEmbed(nn.Module):
def __init__(self, in_chans, embed_dim, resolution, activation):
super().__init__()
img_size: Tuple[int, int] = to_2tuple(resolution)
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
self.num_patches = self.patches_resolution[0] * \
self.patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
n = embed_dim
self.seq = nn.Sequential(
Conv2d_BN(in_chans, n // 2, 3, 2, 1),
activation(),
Conv2d_BN(n // 2, n, 3, 2, 1),
)
def forward(self, x):
return self.seq(x)
class MBConv(nn.Module):
def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
super().__init__()
self.in_chans = in_chans
self.hidden_chans = int(in_chans * expand_ratio)
self.out_chans = out_chans
self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
self.act1 = activation()
self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, ks=3, stride=1, pad=1, groups=self.hidden_chans)
self.act2 = activation()
self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
self.act3 = activation()
# NOTE: `DropPath` is needed only for training.
# self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path = nn.Identity()
def forward(self, x):
shortcut = x
x = self.conv1(x)
x = self.act1(x)
x = self.conv2(x)
x = self.act2(x)
x = self.conv3(x)
x = self.drop_path(x)
x += shortcut
x = self.act3(x)
return x
class PatchMerging(nn.Module):
def __init__(self, input_resolution, dim, out_dim, activation):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.out_dim = out_dim
self.act = activation()
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
stride_c = 2
if (out_dim == 320 or out_dim == 448 or out_dim == 576):
stride_c = 1
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
def forward(self, x):
if x.ndim == 3:
H, W = self.input_resolution
B = len(x)
# (B, C, H, W)
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
x = self.conv1(x)
x = self.act(x)
x = self.conv2(x)
x = self.act(x)
x = self.conv3(x)
x = x.flatten(2).transpose(1, 2)
return x
class ConvLayer(nn.Module):
def __init__(
self,
dim,
input_resolution,
depth,
activation,
drop_path=0.,
downsample=None,
use_checkpoint=False,
out_dim=None,
conv_expand_ratio=4.,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
MBConv(
dim,
dim,
conv_expand_ratio,
activation,
drop_path[i] if isinstance(drop_path, list) else drop_path,
) for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.norm = nn.LayerNorm(in_features)
self.fc1 = nn.Linear(in_features, hidden_features)
self.fc2 = nn.Linear(hidden_features, out_features)
self.act = act_layer()
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.norm(x)
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(torch.nn.Module):
def __init__(
self,
dim,
key_dim,
num_heads=8,
attn_ratio=4,
resolution=(14, 14),
):
super().__init__()
# (h, w)
assert isinstance(resolution, tuple) and len(resolution) == 2
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
h = self.dh + nh_kd * 2
self.norm = nn.LayerNorm(dim)
self.qkv = nn.Linear(dim, h)
self.proj = nn.Linear(self.dh, dim)
points = list(itertools.product(range(resolution[0]), range(resolution[1])))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and hasattr(self, 'ab'):
del self.ab
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x): # x (B,N,C)
B, N, _ = x.shape
# Normalization
x = self.norm(x)
qkv = self.qkv(x)
# (B, N, num_heads, d)
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
# (B, num_heads, N, d)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
        # Cached attention biases (self.ab) only exist in eval mode (see train() above)
        if not self.training:
            self.ab = self.ab.to(self.attention_biases.device)
attn = ((q @ k.transpose(-2, -1)) * self.scale +
(self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab))
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
x = self.proj(x)
return x
class TinyViTBlock(nn.Module):
r""" TinyViT Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int, int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float, optional): Dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
local_conv_size (int): the kernel size of the convolution between
Attention and MLP. Default: 3
activation (torch.nn): the activation function. Default: nn.GELU
"""
def __init__(
self,
dim,
input_resolution,
num_heads,
window_size=7,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
local_conv_size=3,
activation=nn.GELU,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
assert window_size > 0, 'window_size must be greater than 0'
self.window_size = window_size
self.mlp_ratio = mlp_ratio
# NOTE: `DropPath` is needed only for training.
# self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path = nn.Identity()
assert dim % num_heads == 0, 'dim must be divisible by num_heads'
head_dim = dim // num_heads
window_resolution = (window_size, window_size)
self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution)
mlp_hidden_dim = int(dim * mlp_ratio)
mlp_activation = activation
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=mlp_activation, drop=drop)
pad = local_conv_size // 2
self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, 'input feature has wrong size'
res_x = x
if H == self.window_size and W == self.window_size:
x = self.attn(x)
else:
x = x.view(B, H, W, C)
pad_b = (self.window_size - H % self.window_size) % self.window_size
pad_r = (self.window_size - W % self.window_size) % self.window_size
padding = pad_b > 0 or pad_r > 0
if padding:
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
pH, pW = H + pad_b, W + pad_r
nH = pH // self.window_size
nW = pW // self.window_size
# window partition
x = x.view(B, nH, self.window_size, nW, self.window_size,
C).transpose(2, 3).reshape(B * nH * nW, self.window_size * self.window_size, C)
x = self.attn(x)
# window reverse
x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C)
if padding:
x = x[:, :H, :W].contiguous()
x = x.view(B, L, C)
x = res_x + self.drop_path(x)
x = x.transpose(1, 2).reshape(B, C, H, W)
x = self.local_conv(x)
x = x.view(B, C, L).transpose(1, 2)
x = x + self.drop_path(self.mlp(x))
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}'
class BasicLayer(nn.Module):
""" A basic TinyViT layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float, optional): Dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
local_conv_size (int): the kernel size of the depthwise convolution between attention and MLP. Default: 3
activation (torch.nn): the activation function. Default: nn.GELU
out_dim (int, optional): the output dimension of the layer. Default: None
"""
def __init__(
self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
downsample=None,
use_checkpoint=False,
local_conv_size=3,
activation=nn.GELU,
out_dim=None,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
TinyViTBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
drop=drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
local_conv_size=local_conv_size,
activation=activation,
) for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
class LayerNorm2d(nn.Module):
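"""Layer normalization applied over the channel dimension of NCHW feature maps."""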
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class TinyViT(nn.Module):
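"""TinyViT backbone: a convolutional first stage, windowed-attention stages and a 256-channel convolutional neck."""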
def __init__(
self,
img_size=224,
in_chans=3,
num_classes=1000,
embed_dims=[96, 192, 384, 768],
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_sizes=[7, 7, 14, 7],
mlp_ratio=4.,
drop_rate=0.,
drop_path_rate=0.1,
use_checkpoint=False,
mbconv_expand_ratio=4.0,
local_conv_size=3,
layer_lr_decay=1.0,
):
super().__init__()
self.img_size = img_size
self.num_classes = num_classes
self.depths = depths
self.num_layers = len(depths)
self.mlp_ratio = mlp_ratio
activation = nn.GELU
self.patch_embed = PatchEmbed(in_chans=in_chans,
embed_dim=embed_dims[0],
resolution=img_size,
activation=activation)
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
kwargs = dict(
dim=embed_dims[i_layer],
input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
# patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
out_dim=embed_dims[min(i_layer + 1,
len(embed_dims) - 1)],
activation=activation,
)
if i_layer == 0:
layer = ConvLayer(
conv_expand_ratio=mbconv_expand_ratio,
**kwargs,
)
else:
layer = BasicLayer(num_heads=num_heads[i_layer],
window_size=window_sizes[i_layer],
mlp_ratio=self.mlp_ratio,
drop=drop_rate,
local_conv_size=local_conv_size,
**kwargs)
self.layers.append(layer)
# Classifier head
self.norm_head = nn.LayerNorm(embed_dims[-1])
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
# init weights
self.apply(self._init_weights)
self.set_layer_lr_decay(layer_lr_decay)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dims[-1],
256,
kernel_size=1,
bias=False,
),
LayerNorm2d(256),
nn.Conv2d(
256,
256,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(256),
)
def set_layer_lr_decay(self, layer_lr_decay):
decay_rate = layer_lr_decay
# layers -> blocks (depth)
depth = sum(self.depths)
lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
def _set_lr_scale(m, scale):
for p in m.parameters():
p.lr_scale = scale
self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
i = 0
for layer in self.layers:
for block in layer.blocks:
block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
i += 1
if layer.downsample is not None:
layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1]))
assert i == depth
for m in [self.norm_head, self.head]:
m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))
for k, p in self.named_parameters():
p.param_name = k
def _check_lr_scale(m):
for p in m.parameters():
assert hasattr(p, 'lr_scale'), p.param_name
self.apply(_check_lr_scale)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# NOTE: This initialization is needed only for training.
# trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'attention_biases'}
def forward_features(self, x):
# x: (N, C, H, W)
x = self.patch_embed(x)
x = self.layers[0](x)
start_i = 1
for i in range(start_i, len(self.layers)):
layer = self.layers[i]
x = layer(x)
B, _, C = x.size()
x = x.view(B, 64, 64, C)
x = x.permute(0, 3, 1, 2)
x = self.neck(x)
return x
def forward(self, x):
x = self.forward_features(x)
return x
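A minimal usage sketch for the TinyViT backbone above. The embedding widths, depths, heads and window sizes below are assumptions matching the TinyViT-5M settings commonly used for mobile SAM image encoders, not values taken from this file; they matter because `forward_features` hard-codes a 64x64 reshape, which only holds for a 1024-pixel input with this downsampling pattern.

```python
import torch

# Assumed TinyViT-5M-style configuration; with these dims the PatchMerging strides leave a 64x64 final grid
encoder = TinyViT(img_size=1024,
                  embed_dims=[64, 128, 160, 320],
                  depths=[2, 2, 6, 2],
                  num_heads=[2, 4, 5, 10],
                  window_sizes=[7, 7, 14, 7])
with torch.no_grad():
    feats = encoder(torch.zeros(1, 3, 1024, 1024))
print(feats.shape)  # torch.Size([1, 256, 64, 64]) from the convolutional neck
```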

View File

@ -0,0 +1,235 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import math
from typing import Tuple, Type
import torch
from torch import Tensor, nn
from ultralytics.nn.modules import MLPBlock
class TwoWayTransformer(nn.Module):
def __init__(
self,
depth: int,
embedding_dim: int,
num_heads: int,
mlp_dim: int,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
) -> None:
"""
A transformer decoder that attends to an input image using
queries whose positional embedding is supplied.
Args:
depth (int): number of layers in the transformer
embedding_dim (int): the channel dimension for the input embeddings
num_heads (int): the number of heads for multihead attention. Must
divide embedding_dim
mlp_dim (int): the channel dimension internal to the MLP block
activation (nn.Module): the activation to use in the MLP block
attention_downsample_rate (int): the downsampling rate applied to the attention embedding dimension
"""
super().__init__()
self.depth = depth
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.mlp_dim = mlp_dim
self.layers = nn.ModuleList()
for i in range(depth):
self.layers.append(
TwoWayAttentionBlock(
embedding_dim=embedding_dim,
num_heads=num_heads,
mlp_dim=mlp_dim,
activation=activation,
attention_downsample_rate=attention_downsample_rate,
skip_first_layer_pe=(i == 0),
))
self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
self.norm_final_attn = nn.LayerNorm(embedding_dim)
def forward(
self,
image_embedding: Tensor,
image_pe: Tensor,
point_embedding: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
image_embedding (torch.Tensor): image to attend to. Should be shape
B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must
have the same shape as image_embedding.
point_embedding (torch.Tensor): the embedding to add to the query points.
Must have shape B x N_points x embedding_dim for any N_points.
Returns:
torch.Tensor: the processed point_embedding
torch.Tensor: the processed image_embedding
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
bs, c, h, w = image_embedding.shape
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
image_pe = image_pe.flatten(2).permute(0, 2, 1)
# Prepare queries
queries = point_embedding
keys = image_embedding
# Apply transformer blocks and final layernorm
for layer in self.layers:
queries, keys = layer(
queries=queries,
keys=keys,
query_pe=point_embedding,
key_pe=image_pe,
)
# Apply the final attention layer from the points to the image
q = queries + point_embedding
k = keys + image_pe
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm_final_attn(queries)
return queries, keys
class TwoWayAttentionBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
num_heads: int,
mlp_dim: int = 2048,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
skip_first_layer_pe: bool = False,
) -> None:
"""
A transformer block with four layers: (1) self-attention of sparse
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
block on sparse inputs, and (4) cross attention of dense inputs to sparse
inputs.
Args:
embedding_dim (int): the channel dimension of the embeddings
num_heads (int): the number of heads in the attention layers
mlp_dim (int): the hidden dimension of the mlp block
activation (nn.Module): the activation of the mlp block
skip_first_layer_pe (bool): skip the PE on the first layer
"""
super().__init__()
self.self_attn = Attention(embedding_dim, num_heads)
self.norm1 = nn.LayerNorm(embedding_dim)
self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
self.norm2 = nn.LayerNorm(embedding_dim)
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
self.norm3 = nn.LayerNorm(embedding_dim)
self.norm4 = nn.LayerNorm(embedding_dim)
self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
self.skip_first_layer_pe = skip_first_layer_pe
def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
"""Apply self-attention and cross-attention to queries and keys and return the processed embeddings."""
# Self attention block
if self.skip_first_layer_pe:
queries = self.self_attn(q=queries, k=queries, v=queries)
else:
q = queries + query_pe
attn_out = self.self_attn(q=q, k=q, v=queries)
queries = queries + attn_out
queries = self.norm1(queries)
# Cross attention block, tokens attending to image embedding
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm2(queries)
# MLP block
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.norm3(queries)
# Cross attention block, image embedding attending to tokens
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
keys = keys + attn_out
keys = self.norm4(keys)
return queries, keys
class Attention(nn.Module):
"""
An attention layer that allows for downscaling the size of the embedding
after projection to queries, keys, and values.
"""
def __init__(
self,
embedding_dim: int,
num_heads: int,
downsample_rate: int = 1,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads
assert self.internal_dim % num_heads == 0, 'num_heads must divide embedding_dim.'
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
"""Separate the input tensor into the specified number of attention heads."""
b, n, c = x.shape
x = x.reshape(b, n, num_heads, c // num_heads)
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
def _recombine_heads(self, x: Tensor) -> Tensor:
"""Recombine the separated attention heads into a single tensor."""
b, n_heads, n_tokens, c_per_head = x.shape
x = x.transpose(1, 2)
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
"""Compute the attention output given the input query, key, and value tensors."""
# Input projections
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
# Separate into heads
q = self._separate_heads(q, self.num_heads)
k = self._separate_heads(k, self.num_heads)
v = self._separate_heads(v, self.num_heads)
# Attention
_, _, _, c_per_head = q.shape
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
attn = attn / math.sqrt(c_per_head)
attn = torch.softmax(attn, dim=-1)
# Get output
out = attn @ v
out = self._recombine_heads(out)
out = self.out_proj(out)
return out
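A short sketch of how the two-way decoder above consumes SAM-style tensors; the 256-dim, 64x64 image embedding and the five point tokens are illustrative assumptions, not values from this file.

```python
import torch

decoder = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
image_embedding = torch.randn(1, 256, 64, 64)  # B x C x H x W image features
image_pe = torch.randn(1, 256, 64, 64)         # positional encoding with the same shape
point_embedding = torch.randn(1, 5, 256)       # B x N_points x C sparse prompt tokens
queries, keys = decoder(image_embedding, image_pe, point_embedding)
print(queries.shape, keys.shape)               # (1, 5, 256) and (1, 4096, 256)
```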

View File

@ -0,0 +1,398 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from ultralytics.data.augment import LetterBox
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops
from ultralytics.utils.torch_utils import select_device
from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score,
generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks)
from .build import build_sam
class Predictor(BasePredictor):
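"""SAM predictor supporting prompt-based inference (boxes, points, masks) and full-image 'segment everything' generation."""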
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
if overrides is None:
overrides = {}
overrides.update(dict(task='segment', mode='predict', imgsz=1024))
super().__init__(cfg, overrides, _callbacks)
# SAM needs retina_masks=True, or the results would be a mess.
self.args.retina_masks = True
# Args for set_image
self.im = None
self.features = None
# Args for segment everything
self.segment_all = False
def preprocess(self, im):
"""Prepares input image before inference.
Args:
im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
"""
if self.im is not None:
return self.im
not_tensor = not isinstance(im, torch.Tensor)
if not_tensor:
im = np.stack(self.pre_transform(im))
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
im = np.ascontiguousarray(im) # contiguous
im = torch.from_numpy(im)
img = im.to(self.device)
img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
if not_tensor:
img = (img - self.mean) / self.std
return img
def pre_transform(self, im):
"""Pre-transform input image before inference.
Args:
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
Returns: A list of transformed images.
"""
assert len(im) == 1, 'SAM model does not currently support batch inference!'
return [LetterBox(self.args.imgsz, auto=False, center=False)(image=x) for x in im]
def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
"""
Predict masks for the given input prompts, using the currently set image.
Args:
im (torch.Tensor): The preprocessed image, (N, C, H, W).
bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
labels (np.ndarray | List, None): (N, ), labels for the point prompts.
1 indicates a foreground point and 0 indicates a background point.
masks (np.ndarray, None): A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form (N, H, W), where
for SAM, H=W=256.
multimask_output (bool): If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will often
produce better masks than a single prediction. If only a single
mask is needed, the model's predicted quality score can be used
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results.
Returns:
(np.ndarray): The output masks in CxHxW format, where C is the
number of masks, and (H, W) is the original image size.
(np.ndarray): An array of length C containing the model's
predictions for the quality of each mask.
(np.ndarray): An array of shape CxHxW, where C is the number
of masks and H=W=256. These low resolution logits can be passed to
a subsequent iteration as mask input.
"""
if all(i is None for i in [bboxes, points, masks]):
return self.generate(im, *args, **kwargs)
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
"""
Predict masks for the given input prompts, using the currently set image.
Args:
im (torch.Tensor): The preprocessed image, (N, C, H, W).
bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
labels (np.ndarray | List, None): (N, ), labels for the point prompts.
1 indicates a foreground point and 0 indicates a background point.
masks (np.ndarray, None): A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form (N, H, W), where
for SAM, H=W=256.
multimask_output (bool): If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will often
produce better masks than a single prediction. If only a single
mask is needed, the model's predicted quality score can be used
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results.
Returns:
(np.ndarray): The output masks in CxHxW format, where C is the
number of masks, and (H, W) is the original image size.
(np.ndarray): An array of length C containing the model's
predictions for the quality of each mask.
(np.ndarray): An array of shape CxHxW, where C is the number
of masks and H=W=256. These low resolution logits can be passed to
a subsequent iteration as mask input.
"""
features = self.model.image_encoder(im) if self.features is None else self.features
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
# Transform input prompts
if points is not None:
points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
points = points[None] if points.ndim == 1 else points
# Assuming labels are all positive if users don't pass labels.
if labels is None:
labels = np.ones(points.shape[0])
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
points *= r
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
points, labels = points[:, None, :], labels[:, None]
if bboxes is not None:
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
bboxes *= r
if masks is not None:
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device)
masks = masks[:, None, :, :]
points = (points, labels) if points is not None else None
# Embed prompts
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
points=points,
boxes=bboxes,
masks=masks,
)
# Predict masks
pred_masks, pred_scores = self.model.mask_decoder(
image_embeddings=features,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
# `d` could be 1 or 3 depends on `multimask_output`.
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
def generate(self,
im,
crop_n_layers=0,
crop_overlap_ratio=512 / 1500,
crop_downscale_factor=1,
point_grids=None,
points_stride=32,
points_batch_size=64,
conf_thres=0.88,
stability_score_thresh=0.95,
stability_score_offset=0.95,
crop_nms_thresh=0.7):
"""Segment the whole image.
Args:
im (torch.Tensor): The preprocessed image, (N, C, H, W).
crop_n_layers (int): If >0, mask prediction will be run again on
crops of the image. Sets the number of layers to run, where each
layer has 2**i_layer number of image crops.
crop_overlap_ratio (float): Sets the degree to which crops overlap.
In the first crop layer, crops will overlap by this fraction of
the image length. Later layers with more crops scale down this overlap.
crop_downscale_factor (int): The number of points-per-side
sampled in layer n is scaled down by crop_downscale_factor**n.
point_grids (list(np.ndarray), None): A list over explicit grids
of points used for sampling, normalized to [0,1]. The nth grid in the
list is used in the nth crop layer. Exclusive with points_stride.
points_stride (int, None): The number of points to be sampled
along one side of the image. The total number of points is
points_stride**2. If None, 'point_grids' must provide explicit
point sampling.
points_batch_size (int): Sets the number of points run simultaneously
by the model. Higher numbers may be faster but use more GPU memory.
conf_thres (float): A filtering threshold in [0,1], using the
model's predicted mask quality.
stability_score_thresh (float): A filtering threshold in [0,1], using
the stability of the mask under changes to the cutoff used to binarize
the model's mask predictions.
stability_score_offset (float): The amount to shift the cutoff when
calculating the stability score.
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
suppression to filter duplicate masks between different crops.
"""
self.segment_all = True
ih, iw = im.shape[2:]
crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
if point_grids is None:
point_grids = build_all_layer_point_grids(
points_stride,
crop_n_layers,
crop_downscale_factor,
)
pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], []
for crop_region, layer_idx in zip(crop_regions, layer_idxs):
x1, y1, x2, y2 = crop_region
w, h = x2 - x1, y2 - y1
area = torch.tensor(w * h, device=im.device)
points_scale = np.array([[w, h]]) # w, h
# Crop image and interpolate to input size
crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode='bilinear', align_corners=False)
# (num_points, 2)
points_for_image = point_grids[layer_idx] * points_scale
crop_masks, crop_scores, crop_bboxes = [], [], []
for (points, ) in batch_iterator(points_batch_size, points_for_image):
pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
# Interpolate predicted masks to input size
pred_mask = F.interpolate(pred_mask[None], (h, w), mode='bilinear', align_corners=False)[0]
idx = pred_score > conf_thres
pred_mask, pred_score = pred_mask[idx], pred_score[idx]
stability_score = calculate_stability_score(pred_mask, self.model.mask_threshold,
stability_score_offset)
idx = stability_score > stability_score_thresh
pred_mask, pred_score = pred_mask[idx], pred_score[idx]
# Bool type is much more memory-efficient.
pred_mask = pred_mask > self.model.mask_threshold
# (N, 4)
pred_bbox = batched_mask_to_box(pred_mask).float()
keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih])
if not torch.all(keep_mask):
pred_bbox = pred_bbox[keep_mask]
pred_mask = pred_mask[keep_mask]
pred_score = pred_score[keep_mask]
crop_masks.append(pred_mask)
crop_bboxes.append(pred_bbox)
crop_scores.append(pred_score)
# Do nms within this crop
crop_masks = torch.cat(crop_masks)
crop_bboxes = torch.cat(crop_bboxes)
crop_scores = torch.cat(crop_scores)
keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou) # NMS
crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region)
crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw)
crop_scores = crop_scores[keep]
pred_masks.append(crop_masks)
pred_bboxes.append(crop_bboxes)
pred_scores.append(crop_scores)
region_areas.append(area.expand(len(crop_masks)))
pred_masks = torch.cat(pred_masks)
pred_bboxes = torch.cat(pred_bboxes)
pred_scores = torch.cat(pred_scores)
region_areas = torch.cat(region_areas)
# Remove duplicate masks between crops
if len(crop_regions) > 1:
scores = 1 / region_areas
keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh)
pred_masks = pred_masks[keep]
pred_bboxes = pred_bboxes[keep]
pred_scores = pred_scores[keep]
return pred_masks, pred_scores, pred_bboxes
def setup_model(self, model, verbose=True):
"""Set up YOLO model with specified thresholds and device."""
device = select_device(self.args.device)
if model is None:
model = build_sam(self.args.model)
model.eval()
self.model = model.to(device)
self.device = device
self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
# TODO: Temporary settings for compatibility
self.model.pt = False
self.model.triton = False
self.model.stride = 32
self.model.fp16 = False
self.done_warmup = True
def postprocess(self, preds, img, orig_imgs):
"""Postprocesses inference output predictions to create detection masks for objects."""
# (N, 1, H, W), (N, 1)
pred_masks, pred_scores = preds[:2]
pred_bboxes = preds[2] if self.segment_all else None
names = dict(enumerate(str(i) for i in range(len(pred_masks))))
results = []
for i, masks in enumerate([pred_masks]):
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
if pred_bboxes is not None:
pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
masks = masks > self.model.mask_threshold # to bool
path = self.batch[0]
img_path = path[i] if isinstance(path, list) else path
results.append(Results(orig_img=orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
# Reset segment-all mode.
self.segment_all = False
return results
def setup_source(self, source):
"""Sets up source and inference mode."""
if source is not None:
super().setup_source(source)
def set_image(self, image):
"""Set image in advance.
Args:
image (str | np.ndarray): image file path or an image array as read by cv2.
"""
if self.model is None:
model = build_sam(self.args.model)
self.setup_model(model)
self.setup_source(image)
assert len(self.dataset) == 1, '`set_image` only supports setting one image!'
for batch in self.dataset:
im = self.preprocess(batch[1])
self.features = self.model.image_encoder(im)
self.im = im
break
def reset_image(self):
self.im = None
self.features = None
@staticmethod
def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
"""
Removes small disconnected regions and holes in masks, then reruns
box NMS to remove any new duplicates. Requires OpenCV as a dependency.
Args:
masks (torch.Tensor): Masks, (N, H, W).
min_area (int): Minimum area threshold.
nms_thresh (float): NMS threshold.
"""
if len(masks) == 0:
return masks
# Filter small disconnected regions and holes
new_masks = []
scores = []
for mask in masks:
mask = mask.cpu().numpy()
mask, changed = remove_small_regions(mask, min_area, mode='holes')
unchanged = not changed
mask, changed = remove_small_regions(mask, min_area, mode='islands')
unchanged = unchanged and not changed
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
# Give score=0 to changed masks and score=1 to unchanged masks
# so NMS will prefer ones that didn't need postprocessing
scores.append(float(unchanged))
# Recalculate boxes and remove any new duplicates
new_masks = torch.cat(new_masks, dim=0)
boxes = batched_mask_to_box(new_masks)
keep = torchvision.ops.nms(
boxes.float(),
torch.as_tensor(scores),
nms_thresh,
)
# Only recalculate masks for masks that have changed
for i in keep:
if scores[i] == 0.0:
masks[i] = new_masks[i]
return masks[keep]
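A hedged usage sketch for the Predictor above. It assumes a SAM checkpoint name that build_sam can resolve (here 'sam_b.pt') and that the base predictor forwards extra keyword arguments through to inference(); neither assumption comes from this file.

```python
# 'sam_b.pt' and the image path are placeholders, not names defined in this commit
predictor = Predictor(overrides=dict(model='sam_b.pt'))

# Prompted segmentation: an XYXY box plus one foreground point routes through prompt_inference()
results = predictor(source='path/to/image.jpg', bboxes=[100, 100, 500, 600], points=[250, 300], labels=[1])

# With no prompts, inference() falls back to generate() and segments the whole image
results = predictor(source='path/to/image.jpg')
```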

View File

@ -0,0 +1 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

View File

@ -0,0 +1,295 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.utils.loss import FocalLoss, VarifocalLoss
from ultralytics.utils.metrics import bbox_iou
from .ops import HungarianMatcher
class DETRLoss(nn.Module):
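"""DETR loss computing classification, L1 box and GIoU losses, with optional auxiliary decoder-layer losses."""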
def __init__(self,
nc=80,
loss_gain=None,
aux_loss=True,
use_fl=True,
use_vfl=False,
use_uni_match=False,
uni_match_ind=0):
"""
DETR loss function.
Args:
nc (int): The number of classes.
loss_gain (dict): The coefficient of loss.
aux_loss (bool): If True, auxiliary losses from each decoder layer are used.
use_fl (bool): Whether to use FocalLoss for the classification loss.
use_vfl (bool): Whether to use VarifocalLoss for the classification loss.
use_uni_match (bool): Whether to use a fixed decoder layer to assign labels for the auxiliary branches.
uni_match_ind (int): The index of that fixed layer when use_uni_match is enabled.
"""
super().__init__()
if loss_gain is None:
loss_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'no_object': 0.1, 'mask': 1, 'dice': 1}
self.nc = nc
self.matcher = HungarianMatcher(cost_gain={'class': 2, 'bbox': 5, 'giou': 2})
self.loss_gain = loss_gain
self.aux_loss = aux_loss
self.fl = FocalLoss() if use_fl else None
self.vfl = VarifocalLoss() if use_vfl else None
self.use_uni_match = use_uni_match
self.uni_match_ind = uni_match_ind
self.device = None
def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''):
# logits: [b, query, num_classes], gt_class: list[[n, 1]]
name_class = f'loss_class{postfix}'
bs, nq = pred_scores.shape[:2]
# one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
one_hot.scatter_(2, targets.unsqueeze(-1), 1)
one_hot = one_hot[..., :-1]
gt_scores = gt_scores.view(bs, nq, 1) * one_hot
if self.fl:
if num_gts and self.vfl:
loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
else:
loss_cls = self.fl(pred_scores, one_hot.float())
loss_cls /= max(num_gts, 1) / nq
else:
loss_cls = nn.BCEWithLogitsLoss(reduction='none')(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
return {name_class: loss_cls.squeeze() * self.loss_gain['class']}
def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=''):
# boxes: [b, query, 4], gt_bbox: list[[n, 4]]
name_bbox = f'loss_bbox{postfix}'
name_giou = f'loss_giou{postfix}'
loss = {}
if len(gt_bboxes) == 0:
loss[name_bbox] = torch.tensor(0., device=self.device)
loss[name_giou] = torch.tensor(0., device=self.device)
return loss
loss[name_bbox] = self.loss_gain['bbox'] * F.l1_loss(pred_bboxes, gt_bboxes, reduction='sum') / len(gt_bboxes)
loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
loss[name_giou] = self.loss_gain['giou'] * loss[name_giou]
loss = {k: v.squeeze() for k, v in loss.items()}
return loss
def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
# masks: [b, query, h, w], gt_mask: list[[n, H, W]]
name_mask = f'loss_mask{postfix}'
name_dice = f'loss_dice{postfix}'
loss = {}
if sum(len(a) for a in gt_mask) == 0:
loss[name_mask] = torch.tensor(0., device=self.device)
loss[name_dice] = torch.tensor(0., device=self.device)
return loss
num_gts = len(gt_mask)
src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
# TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
torch.tensor([num_gts], dtype=torch.float32))
loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
return loss
def _dice_loss(self, inputs, targets, num_gts):
inputs = F.sigmoid(inputs)
inputs = inputs.flatten(1)
targets = targets.flatten(1)
numerator = 2 * (inputs * targets).sum(1)
denominator = inputs.sum(-1) + targets.sum(-1)
loss = 1 - (numerator + 1) / (denominator + 1)
return loss.sum() / num_gts
def _get_loss_aux(self,
pred_bboxes,
pred_scores,
gt_bboxes,
gt_cls,
gt_groups,
match_indices=None,
postfix='',
masks=None,
gt_mask=None):
"""Get auxiliary losses"""
# NOTE: loss class, bbox, giou, mask, dice
loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
if match_indices is None and self.use_uni_match:
match_indices = self.matcher(pred_bboxes[self.uni_match_ind],
pred_scores[self.uni_match_ind],
gt_bboxes,
gt_cls,
gt_groups,
masks=masks[self.uni_match_ind] if masks is not None else None,
gt_mask=gt_mask)
for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
aux_masks = masks[i] if masks is not None else None
loss_ = self._get_loss(aux_bboxes,
aux_scores,
gt_bboxes,
gt_cls,
gt_groups,
masks=aux_masks,
gt_mask=gt_mask,
postfix=postfix,
match_indices=match_indices)
loss[0] += loss_[f'loss_class{postfix}']
loss[1] += loss_[f'loss_bbox{postfix}']
loss[2] += loss_[f'loss_giou{postfix}']
# if masks is not None and gt_mask is not None:
# loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
# loss[3] += loss_[f'loss_mask{postfix}']
# loss[4] += loss_[f'loss_dice{postfix}']
loss = {
f'loss_class_aux{postfix}': loss[0],
f'loss_bbox_aux{postfix}': loss[1],
f'loss_giou_aux{postfix}': loss[2]}
# if masks is not None and gt_mask is not None:
# loss[f'loss_mask_aux{postfix}'] = loss[3]
# loss[f'loss_dice_aux{postfix}'] = loss[4]
return loss
def _get_index(self, match_indices):
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
src_idx = torch.cat([src for (src, _) in match_indices])
dst_idx = torch.cat([dst for (_, dst) in match_indices])
return (batch_idx, src_idx), dst_idx
def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
pred_assigned = torch.cat([
t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
for t, (I, _) in zip(pred_bboxes, match_indices)])
gt_assigned = torch.cat([
t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
for t, (_, J) in zip(gt_bboxes, match_indices)])
return pred_assigned, gt_assigned
def _get_loss(self,
pred_bboxes,
pred_scores,
gt_bboxes,
gt_cls,
gt_groups,
masks=None,
gt_mask=None,
postfix='',
match_indices=None):
"""Get losses"""
if match_indices is None:
match_indices = self.matcher(pred_bboxes,
pred_scores,
gt_bboxes,
gt_cls,
gt_groups,
masks=masks,
gt_mask=gt_mask)
idx, gt_idx = self._get_index(match_indices)
pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
bs, nq = pred_scores.shape[:2]
targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
targets[idx] = gt_cls[gt_idx]
gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
if len(gt_bboxes):
gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)
loss = {}
loss.update(self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix))
loss.update(self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix))
# if masks is not None and gt_mask is not None:
# loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix))
return loss
def forward(self, pred_bboxes, pred_scores, batch, postfix='', **kwargs):
"""
Args:
pred_bboxes (torch.Tensor): [l, b, query, 4]
pred_scores (torch.Tensor): [l, b, query, num_classes]
batch (dict): A dict includes:
gt_cls (torch.Tensor) with shape [num_gts, ],
gt_bboxes (torch.Tensor): [num_gts, 4],
gt_groups (List[int]): a list of length batch_size containing the number of gts in each image.
postfix (str): postfix of loss name.
"""
self.device = pred_bboxes.device
match_indices = kwargs.get('match_indices', None)
gt_cls, gt_bboxes, gt_groups = batch['cls'], batch['bboxes'], batch['gt_groups']
total_loss = self._get_loss(pred_bboxes[-1],
pred_scores[-1],
gt_bboxes,
gt_cls,
gt_groups,
postfix=postfix,
match_indices=match_indices)
if self.aux_loss:
total_loss.update(
self._get_loss_aux(pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices,
postfix))
return total_loss
class RTDETRDetectionLoss(DETRLoss):
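"""DETR loss extended with a contrastive denoising branch for RT-DETR training."""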
def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
pred_bboxes, pred_scores = preds
total_loss = super().forward(pred_bboxes, pred_scores, batch)
if dn_meta is not None:
dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group']
assert len(batch['gt_groups']) == len(dn_pos_idx)
# denoising match indices
match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups'])
# compute denoising training loss
dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices)
total_loss.update(dn_loss)
else:
total_loss.update({f'{k}_dn': torch.tensor(0., device=self.device) for k in total_loss.keys()})
return total_loss
@staticmethod
def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
"""Get the match indices for denoising.
Args:
dn_pos_idx (List[torch.Tensor]): A list includes positive indices of denoising.
dn_num_group (int): The number of groups of denoising.
gt_groups (List[int]): a list of length batch_size containing the number of gts in each image.
Returns:
dn_match_indices (List(tuple)): Matched indices.
"""
dn_match_indices = []
idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
for i, num_gt in enumerate(gt_groups):
if num_gt > 0:
gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
gt_idx = gt_idx.repeat(dn_num_group)
assert len(dn_pos_idx[i]) == len(gt_idx), (
f'Expected the same length, but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.')
dn_match_indices.append((dn_pos_idx[i], gt_idx))
else:
dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
return dn_match_indices
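A self-contained sketch of calling the denoising-aware loss above on random predictions, with RTDETRDetectionLoss from this file in scope. Shapes follow the forward() docstrings; the values and the use_vfl flag are illustrative assumptions.

```python
import torch

criterion = RTDETRDetectionLoss(nc=80, use_vfl=True)
layers, bs, nq, nc = 3, 2, 100, 80
preds = (torch.rand(layers, bs, nq, 4),    # normalized xywh boxes from each decoder layer
         torch.randn(layers, bs, nq, nc))  # class logits from each decoder layer
batch = {
    'cls': torch.tensor([2, 5, 7]),        # three ground-truth labels across the batch
    'bboxes': torch.rand(3, 4),            # normalized xywh ground-truth boxes
    'gt_groups': [1, 2]}                   # image 0 has one gt, image 1 has two
losses = criterion(preds, batch)           # dict of loss_class/loss_bbox/loss_giou plus *_aux and *_dn terms
```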

View File

@ -0,0 +1,260 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from ultralytics.utils.metrics import bbox_iou
from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
class HungarianMatcher(nn.Module):
"""
A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in
an end-to-end fashion.
HungarianMatcher performs optimal assignment over predicted and ground truth bounding boxes using a cost function
that considers classification scores, bounding box coordinates, and optionally, mask predictions.
Attributes:
cost_gain (dict): Dictionary of cost coefficients for different components: 'class', 'bbox', 'giou', 'mask', and 'dice'.
use_fl (bool): Indicates whether to use Focal Loss for the classification cost calculation.
with_mask (bool): Indicates whether the model makes mask predictions.
num_sample_points (int): The number of sample points used in mask cost calculation.
alpha (float): The alpha factor in Focal Loss calculation.
gamma (float): The gamma factor in Focal Loss calculation.
Methods:
forward(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None): Computes the assignment
between predictions and ground truths for a batch.
_cost_mask(bs, num_gts, masks=None, gt_mask=None): Computes the mask cost and dice cost if masks are predicted.
"""
def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
super().__init__()
if cost_gain is None:
cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1}
self.cost_gain = cost_gain
self.use_fl = use_fl
self.with_mask = with_mask
self.num_sample_points = num_sample_points
self.alpha = alpha
self.gamma = gamma
def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
"""
Forward pass for HungarianMatcher. This function computes costs based on prediction and ground truth
(classification cost, L1 cost between boxes and GIoU cost between boxes) and finds the optimal matching
between predictions and ground truth based on these costs.
Args:
pred_bboxes (Tensor): Predicted bounding boxes with shape [batch_size, num_queries, 4].
pred_scores (Tensor): Predicted scores with shape [batch_size, num_queries, num_classes].
gt_cls (torch.Tensor): Ground truth classes with shape [num_gts, ].
gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape [num_gts, 4].
gt_groups (List[int]): List of length equal to batch size, containing the number of ground truths for
each image.
masks (Tensor, optional): Predicted masks with shape [batch_size, num_queries, height, width].
Defaults to None.
gt_mask (List[Tensor], optional): List of ground truth masks, each with shape [num_masks, Height, Width].
Defaults to None.
Returns:
(List[Tuple[Tensor, Tensor]]): A list of size batch_size, each element is a tuple (index_i, index_j), where:
- index_i is the tensor of indices of the selected predictions (in order)
- index_j is the tensor of indices of the corresponding selected ground truth targets (in order)
For each batch element, it holds:
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
"""
bs, nq, nc = pred_scores.shape
if sum(gt_groups) == 0:
return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]
# We flatten to compute the cost matrices in a batch
# [batch_size * num_queries, num_classes]
pred_scores = pred_scores.detach().view(-1, nc)
pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
# [batch_size * num_queries, 4]
pred_bboxes = pred_bboxes.detach().view(-1, 4)
# Compute the classification cost
pred_scores = pred_scores[:, gt_cls]
if self.use_fl:
neg_cost_class = (1 - self.alpha) * (pred_scores ** self.gamma) * (-(1 - pred_scores + 1e-8).log())
pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
cost_class = pos_cost_class - neg_cost_class
else:
cost_class = -pred_scores
# Compute the L1 cost between boxes
cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt)
# Compute the GIoU cost between boxes, (bs*num_queries, num_gt)
cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
# Final cost matrix
C = self.cost_gain['class'] * cost_class + \
self.cost_gain['bbox'] * cost_bbox + \
self.cost_gain['giou'] * cost_giou
# Compute the mask cost and dice cost
if self.with_mask:
C += self._cost_mask(bs, gt_groups, masks, gt_mask)
C = C.view(bs, nq, -1).cpu()
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
# (idx for queries, idx for gt)
return [(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
for k, (i, j) in enumerate(indices)]
def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
# all masks share the same set of points for efficient matching
sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
sample_points = 2.0 * sample_points - 1.0
out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
out_mask = out_mask.flatten(0, 1)
tgt_mask = torch.cat(gt_mask).unsqueeze(1)
sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
with torch.cuda.amp.autocast(False):
# binary cross entropy cost
pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
cost_mask /= self.num_sample_points
# dice cost
out_mask = F.sigmoid(out_mask)
numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
cost_dice = 1 - (numerator + 1) / (denominator + 1)
C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
return C
def get_cdn_group(batch,
num_classes,
num_queries,
class_embed,
num_dn=100,
cls_noise_ratio=0.5,
box_noise_scale=1.0,
training=False):
"""
Get contrastive denoising training group. This function creates a contrastive denoising training group with
positive and negative samples from the ground truths (gt). It applies noise to the class labels and bounding
box coordinates, and returns the modified labels, bounding boxes, attention mask and meta information.
Args:
batch (dict): A dict that includes 'gt_cls' (torch.Tensor with shape [num_gts, ]), 'gt_bboxes'
(torch.Tensor with shape [num_gts, 4]) and 'gt_groups' (List[int]), a list of length batch_size
indicating the number of gts in each image.
num_classes (int): Number of classes.
num_queries (int): Number of queries.
class_embed (torch.Tensor): Embedding weights to map class labels to embedding space.
num_dn (int, optional): Number of denoising. Defaults to 100.
cls_noise_ratio (float, optional): Noise ratio for class labels. Defaults to 0.5.
box_noise_scale (float, optional): Noise scale for bounding box coordinates. Defaults to 1.0.
training (bool, optional): If it's in training mode. Defaults to False.
Returns:
(Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Dict]]): The modified class embeddings,
bounding boxes, attention mask and meta information for denoising. If not in training mode or 'num_dn'
is less than or equal to 0, the function returns None for all elements in the tuple.
"""
if (not training) or num_dn <= 0:
return None, None, None, None
gt_groups = batch['gt_groups']
total_num = sum(gt_groups)
max_nums = max(gt_groups)
if max_nums == 0:
return None, None, None, None
num_group = num_dn // max_nums
num_group = 1 if num_group == 0 else num_group
# pad gt to max_num of a batch
bs = len(gt_groups)
gt_cls = batch['cls'] # (bs*num, )
gt_bbox = batch['bboxes'] # bs*num, 4
b_idx = batch['batch_idx']
# each group has positive and negative queries.
dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )
dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4
dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, )
# Indices of negative samples: the second total_num * num_group entries, shape (total_num*num_group, )
neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
if cls_noise_ratio > 0:
# Apply class-label noise to a random (cls_noise_ratio * 0.5) fraction of the denoising queries
mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
idx = torch.nonzero(mask).squeeze(-1)
# randomly put a new one here
new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
dn_cls[idx] = new_label
if box_noise_scale > 0:
known_bbox = xywh2xyxy(dn_bbox)
diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale # 2*num_group*bs*num, 4
rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
rand_part = torch.rand_like(dn_bbox)
rand_part[neg_idx] += 1.0
rand_part *= rand_sign
known_bbox += rand_part * diff
known_bbox.clip_(min=0.0, max=1.0)
dn_bbox = xyxy2xywh(known_bbox)
dn_bbox = inverse_sigmoid(dn_bbox)
# total denoising queries
num_dn = int(max_nums * 2 * num_group)
# class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256
padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)
map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)
map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
padding_bbox[(dn_b_idx, map_indices)] = dn_bbox
tgt_size = num_dn + num_queries
attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
# Matching queries cannot see the reconstruction (denoising) queries
attn_mask[num_dn:, :num_dn] = True
# Denoising groups cannot see each other
for i in range(num_group):
if i == 0:
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
if i == num_group - 1:
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * i * 2] = True
else:
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * 2 * i] = True
dn_meta = {
'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
'dn_num_group': num_group,
'dn_num_split': [num_dn, num_queries]}
return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to(
class_embed.device), dn_meta
def inverse_sigmoid(x, eps=1e-6):
"""Inverse sigmoid function."""
x = x.clip(min=0., max=1.)
return torch.log(x / (1 - x + eps) + eps)
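A quick sketch of the matcher above on dummy tensors, with HungarianMatcher from this file in scope; shapes mirror the forward() docstring.

```python
import torch

matcher = HungarianMatcher()                # default cost gains from the constructor
bs, nq, nc = 2, 10, 80
pred_bboxes = torch.rand(bs, nq, 4)         # normalized xywh predictions
pred_scores = torch.randn(bs, nq, nc)       # raw logits; sigmoid is applied internally when use_fl=True
gt_cls = torch.tensor([3, 17, 42])          # three ground truths in the whole batch
gt_bboxes = torch.rand(3, 4)
gt_groups = [1, 2]                          # image 0 owns gt 0, image 1 owns gts 1-2
indices = matcher(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups)
# indices[k] is a (query_idx, gt_idx) tensor pair for image k; gt_idx indexes the flattened gt list
```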

View File

@ -1,48 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv3-SPP object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
# Parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple
# darknet53 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [32, 3, 1]], # 0
[-1, 1, Conv, [64, 3, 2]], # 1-P1/2
[-1, 1, Bottleneck, [64]],
[-1, 1, Conv, [128, 3, 2]], # 3-P2/4
[-1, 2, Bottleneck, [128]],
[-1, 1, Conv, [256, 3, 2]], # 5-P3/8
[-1, 8, Bottleneck, [256]],
[-1, 1, Conv, [512, 3, 2]], # 7-P4/16
[-1, 8, Bottleneck, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
[-1, 4, Bottleneck, [1024]], # 10
]
# YOLOv3-SPP head
head:
[[-1, 1, Bottleneck, [1024, False]],
[-1, 1, SPP, [512, [5, 9, 13]]],
[-1, 1, Conv, [1024, 3, 1]],
[-1, 1, Conv, [512, 1, 1]],
[-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large)
[-2, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 8], 1, Concat, [1]], # cat backbone P4
[-1, 1, Bottleneck, [512, False]],
[-1, 1, Bottleneck, [512, False]],
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium)
[-2, 1, Conv, [128, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P3
[-1, 1, Bottleneck, [256, False]],
[-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small)
[[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5)
]

View File

@ -1,39 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv3-tiny object detection model with P4-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
# Parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple
# YOLOv3-tiny backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [16, 3, 1]], # 0
[-1, 1, nn.MaxPool2d, [2, 2, 0]], # 1-P1/2
[-1, 1, Conv, [32, 3, 1]],
[-1, 1, nn.MaxPool2d, [2, 2, 0]], # 3-P2/4
[-1, 1, Conv, [64, 3, 1]],
[-1, 1, nn.MaxPool2d, [2, 2, 0]], # 5-P3/8
[-1, 1, Conv, [128, 3, 1]],
[-1, 1, nn.MaxPool2d, [2, 2, 0]], # 7-P4/16
[-1, 1, Conv, [256, 3, 1]],
[-1, 1, nn.MaxPool2d, [2, 2, 0]], # 9-P5/32
[-1, 1, Conv, [512, 3, 1]],
[-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]], # 11
[-1, 1, nn.MaxPool2d, [2, 1, 0]], # 12
]
# YOLOv3-tiny head
head:
[[-1, 1, Conv, [1024, 3, 1]],
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, Conv, [512, 3, 1]], # 15 (P5/32-large)
[-2, 1, Conv, [128, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 8], 1, Concat, [1]], # cat backbone P4
[-1, 1, Conv, [256, 3, 1]], # 19 (P4/16-medium)
[[19, 15], 1, Detect, [nc]], # Detect(P4, P5)
]

View File

@ -1,48 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv3 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
# Parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple
# darknet53 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [32, 3, 1]], # 0
[-1, 1, Conv, [64, 3, 2]], # 1-P1/2
[-1, 1, Bottleneck, [64]],
[-1, 1, Conv, [128, 3, 2]], # 3-P2/4
[-1, 2, Bottleneck, [128]],
[-1, 1, Conv, [256, 3, 2]], # 5-P3/8
[-1, 8, Bottleneck, [256]],
[-1, 1, Conv, [512, 3, 2]], # 7-P4/16
[-1, 8, Bottleneck, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
[-1, 4, Bottleneck, [1024]], # 10
]
# YOLOv3 head
head:
[[-1, 1, Bottleneck, [1024, False]],
[-1, 1, Conv, [512, 1, 1]],
[-1, 1, Conv, [1024, 3, 1]],
[-1, 1, Conv, [512, 1, 1]],
[-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large)
[-2, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 8], 1, Concat, [1]], # cat backbone P4
[-1, 1, Bottleneck, [512, False]],
[-1, 1, Bottleneck, [512, False]],
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium)
[-2, 1, Conv, [128, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P3
[-1, 1, Bottleneck, [256, False]],
[-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small)
[[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5)
]

View File

@ -1,61 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv5 object detection model with P3-P6 outputs. For details see https://docs.ultralytics.com/models/yolov5
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov5n-p6.yaml' will call yolov5-p6.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024]
s: [0.33, 0.50, 1024]
m: [0.67, 0.75, 1024]
l: [1.00, 1.00, 1024]
x: [1.33, 1.25, 1024]
# YOLOv5 v6.0 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [768, 3, 2]], # 7-P5/32
[-1, 3, C3, [768]],
[-1, 1, Conv, [1024, 3, 2]], # 9-P6/64
[-1, 3, C3, [1024]],
[-1, 1, SPPF, [1024, 5]], # 11
]
# YOLOv5 v6.0 head
head:
[[-1, 1, Conv, [768, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 8], 1, Concat, [1]], # cat backbone P5
[-1, 3, C3, [768, False]], # 15
[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 19
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 23 (P3/8-small)
[-1, 1, Conv, [256, 3, 2]],
[[-1, 20], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 26 (P4/16-medium)
[-1, 1, Conv, [512, 3, 2]],
[[-1, 16], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [768, False]], # 29 (P5/32-large)
[-1, 1, Conv, [768, 3, 2]],
[[-1, 12], 1, Concat, [1]], # cat head P6
[-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge)
[[23, 26, 29, 32], 1, Detect, [nc]], # Detect(P3, P4, P5, P6)
]

View File

@ -1,50 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv5 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov5
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov5n.yaml' will call yolov5.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024]
s: [0.33, 0.50, 1024]
m: [0.67, 0.75, 1024]
l: [1.00, 1.00, 1024]
x: [1.33, 1.25, 1024]
# YOLOv5 v6.0 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 3, C3, [1024]],
[-1, 1, SPPF, [1024, 5]], # 9
]
# YOLOv5 v6.0 head
head:
[[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 13
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 17 (P3/8-small)
[-1, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)
[-1, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)
[[17, 20, 23], 1, Detect, [nc]], # Detect(P3, P4, P5)
]
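Here the per-scale `scales` entries replace the single `depth_multiple`/`width_multiple` pair of the YOLOv3 configs: choosing scale 'n' applies depth 0.33 and width 0.25 and caps channels at 1024 against the same layer list. A rough sketch of how those constants are usually applied per row, assuming a `make_divisible`-style rounding helper (illustrative, not the exact parser):

```python
import math

def make_divisible(x, divisor=8):
    """Round a channel count up to the nearest multiple of `divisor` (assumed helper)."""
    return math.ceil(x / divisor) * divisor

def scale_row(number, c2, depth=0.33, width=0.25, max_channels=1024):
    """Apply compound scaling to one row's repeat count and output channels (scale 'n' shown)."""
    n = max(round(number * depth), 1) if number > 1 else number      # scale repeats
    c = make_divisible(min(c2, max_channels) * width)                # cap, then scale channels
    return n, c

print(scale_row(9, 512))  # the 9x C3 block with 512 channels becomes (3, 128) for YOLOv5n
```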

View File

@ -1,53 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv6 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/models/yolov6
# Parameters
nc: 80 # number of classes
activation: nn.ReLU() # (optional) model default activation function
scales: # model compound scaling constants, i.e. 'model=yolov6n.yaml' will call yolov6.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024]
s: [0.33, 0.50, 1024]
m: [0.67, 0.75, 768]
l: [1.00, 1.00, 512]
x: [1.00, 1.25, 512]
# YOLOv6-3.0s backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 6, Conv, [128, 3, 1]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 12, Conv, [256, 3, 1]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 18, Conv, [512, 3, 1]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 6, Conv, [1024, 3, 1]]
- [-1, 1, SPPF, [1024, 5]] # 9
# YOLOv6-3.0s head
head:
- [-1, 1, Conv, [256, 1, 1]]
- [-1, 1, nn.ConvTranspose2d, [256, 2, 2, 0]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 1, Conv, [256, 3, 1]]
- [-1, 9, Conv, [256, 3, 1]] # 14
- [-1, 1, Conv, [128, 1, 1]]
- [-1, 1, nn.ConvTranspose2d, [128, 2, 2, 0]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 1, Conv, [128, 3, 1]]
- [-1, 9, Conv, [128, 3, 1]] # 19
- [-1, 1, Conv, [128, 3, 2]]
- [[-1, 15], 1, Concat, [1]] # cat head P4
- [-1, 1, Conv, [256, 3, 1]]
- [-1, 9, Conv, [256, 3, 1]] # 23
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 1, Conv, [512, 3, 1]]
- [-1, 9, Conv, [512, 3, 1]] # 27
- [[19, 23, 27], 1, Detect, [nc]] # Detect(P3, P4, P5)
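The `activation: nn.ReLU()` key is what gives this config its YOLOv6 character: it overrides the SiLU default used by the other configs so every Conv block uses ReLU. A minimal sketch of how such a YAML-level override can take effect before any layers are built (the exact parser mechanism is assumed here, not shown in this diff):

```python
import torch.nn as nn

from ultralytics.nn.modules import Conv  # Conv exposes a class-level default activation

act = 'nn.ReLU()'              # value read from the YAML 'activation' key
Conv.default_act = eval(act)   # subsequently built Conv blocks default to ReLU instead of SiLU
print(Conv.default_act)        # ReLU()
```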

View File

@ -1,29 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8-cls image classification model. For Usage examples see https://docs.ultralytics.com/tasks/classify
# Parameters
nc: 1000 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024]
s: [0.33, 0.50, 1024]
m: [0.67, 0.75, 1024]
l: [1.00, 1.00, 1024]
x: [1.00, 1.25, 1024]
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
# YOLOv8.0n head
head:
- [-1, 1, Classify, [nc]] # Classify
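Unlike the detection configs, the classification variant ends the backbone at the P5/32 C2f block and appends a single `Classify` layer that maps those features to `nc` class scores. A rough functional stand-in for that kind of pooled linear head, assuming 1024-channel input features (illustrative only, not the actual Classify module):

```python
import torch
import torch.nn as nn

class ClassifyHeadSketch(nn.Module):
    """Rough stand-in for a pooled linear classification head (illustrative only)."""
    def __init__(self, c1=1024, nc=1000):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)   # collapse the P5/32 feature map to 1x1
        self.fc = nn.Linear(c1, nc)           # per-class logits

    def forward(self, x):
        return self.fc(self.pool(x).flatten(1))

logits = ClassifyHeadSketch()(torch.randn(2, 1024, 7, 7))
print(logits.shape)  # torch.Size([2, 1000])
```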

View File

@ -1,54 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P2-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024]
s: [0.33, 0.50, 1024]
m: [0.67, 0.75, 768]
l: [1.00, 1.00, 512]
x: [1.00, 1.25, 512]
# YOLOv8.0 backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
# YOLOv8.0-p2 head
head:
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 12
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 2], 1, Concat, [1]] # cat backbone P2
- [-1, 3, C2f, [128]] # 18 (P2/4-xsmall)
- [-1, 1, Conv, [128, 3, 2]]
- [[-1, 15], 1, Concat, [1]] # cat head P3
- [-1, 3, C2f, [256]] # 21 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 12], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 24 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 9], 1, Concat, [1]] # cat head P5
- [-1, 3, C2f, [1024]] # 27 (P5/32-large)
- [[18, 21, 24, 27], 1, Detect, [nc]] # Detect(P2, P3, P4, P5)

View File

@ -1,56 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P6 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024]
s: [0.33, 0.50, 1024]
m: [0.67, 0.75, 768]
l: [1.00, 1.00, 512]
x: [1.00, 1.25, 512]
# YOLOv8.0x6 backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [768, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [768, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 11
# YOLOv8.0x6 head
head:
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 8], 1, Concat, [1]] # cat backbone P5
- [-1, 3, C2, [768, False]] # 14
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2, [512, False]] # 17
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2, [256, False]] # 20 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 17], 1, Concat, [1]] # cat head P4
- [-1, 3, C2, [512, False]] # 23 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 14], 1, Concat, [1]] # cat head P5
- [-1, 3, C2, [768, False]] # 26 (P5/32-large)
- [-1, 1, Conv, [768, 3, 2]]
- [[-1, 11], 1, Concat, [1]] # cat head P6
- [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge)
- [[20, 23, 26, 29], 1, Detect, [nc]] # Detect(P3, P4, P5, P6)

View File

@ -1,57 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose
# Parameters
nc: 1 # number of classes
kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
scales: # model compound scaling constants, i.e. 'model=yolov8n-pose-p6.yaml' will call yolov8-pose-p6.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024]
s: [0.33, 0.50, 1024]
m: [0.67, 0.75, 768]
l: [1.00, 1.00, 512]
x: [1.00, 1.25, 512]
# YOLOv8.0x6 backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [768, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [768, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 11
# YOLOv8.0x6 head
head:
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 8], 1, Concat, [1]] # cat backbone P5
- [-1, 3, C2, [768, False]] # 14
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2, [512, False]] # 17
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2, [256, False]] # 20 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 17], 1, Concat, [1]] # cat head P4
- [-1, 3, C2, [512, False]] # 23 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 14], 1, Concat, [1]] # cat head P5
- [-1, 3, C2, [768, False]] # 26 (P5/32-large)
- [-1, 1, Conv, [768, 3, 2]]
- [[-1, 11], 1, Concat, [1]] # cat head P6
- [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge)
- [[20, 23, 26, 29], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5, P6)

View File

@ -1,47 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose
# Parameters
nc: 1 # number of classes
kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
scales: # model compound scaling constants, i.e. 'model=yolov8n-pose.yaml' will call yolov8-pose.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024]
s: [0.33, 0.50, 1024]
m: [0.67, 0.75, 768]
l: [1.00, 1.00, 512]
x: [1.00, 1.25, 512]
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 12
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 12], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 18 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 9], 1, Concat, [1]] # cat head P5
- [-1, 3, C2f, [1024]] # 21 (P5/32-large)
- [[15, 18, 21], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5)
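`kpt_shape: [17, 3]` means every detection carries 17 keypoints with (x, y, visibility), so a post-NMS pose row holds 4 box coordinates, an object confidence, a class index and 17 x 3 = 51 keypoint values; the `PosePredictor` added later in this commit recovers them with `pred[:, 6:].view(len(pred), *kpt_shape)`. A shape-only illustration:

```python
import torch

kpt_shape = (17, 3)                                        # from the YAML: 17 keypoints x (x, y, visible)
pred = torch.zeros(5, 6 + kpt_shape[0] * kpt_shape[1])     # 5 dummy detections after NMS
boxes, conf, cls = pred[:, :4], pred[:, 4], pred[:, 5]
kpts = pred[:, 6:].view(len(pred), *kpt_shape)             # same reshape used in PosePredictor
print(boxes.shape, kpts.shape)                             # torch.Size([5, 4]) torch.Size([5, 17, 3])
```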

View File

@ -1,46 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 12
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 12], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 18 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 9], 1, Concat, [1]] # cat head P5
- [-1, 3, C2f, [1024]] # 21 (P5/32-large)
- [[15, 18, 21], 1, RTDETRDecoder, [nc]] # RTDETRDecoder(P3, P4, P5)

View File

@ -1,46 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8-seg instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-seg.yaml' will call yolov8-seg.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024]
s: [0.33, 0.50, 1024]
m: [0.67, 0.75, 768]
l: [1.00, 1.00, 512]
x: [1.00, 1.25, 512]
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 12
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 12], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 18 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 9], 1, Concat, [1]] # cat head P5
- [-1, 3, C2f, [1024]] # 21 (P5/32-large)
- [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5)

View File

@ -1,46 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 12
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 12], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 18 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 9], 1, Concat, [1]] # cat head P5
- [-1, 3, C2f, [1024]] # 21 (P5/32-large)
- [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)

View File

@ -0,0 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.models.yolo import classify, detect, pose, segment
__all__ = 'classify', 'segment', 'detect', 'pose'
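The new `ultralytics.models.yolo` package simply re-exports the four task sub-packages, so their Predictor/Trainer/Validator classes are reachable from a single import path. A quick check of what that exposes (assumes the package is installed):

```python
from ultralytics.models.yolo import classify, detect, pose, segment

# each task sub-package re-exports its Predictor/Trainer/Validator classes
print(classify.ClassificationTrainer, detect.DetectionValidator, pose.PosePredictor, segment.SegmentationValidator)
```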

View File

@ -0,0 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.models.yolo.classify.predict import ClassificationPredictor, predict
from ultralytics.models.yolo.classify.train import ClassificationTrainer, train
from ultralytics.models.yolo.classify.val import ClassificationValidator, val
__all__ = 'ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val'

View File

@ -0,0 +1,51 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ROOT
class ClassificationPredictor(BasePredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'classify'
def preprocess(self, img):
"""Converts input image to model-compatible data type."""
if not isinstance(img, torch.Tensor):
img = torch.stack([self.transforms(im) for im in img], dim=0)
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
def postprocess(self, preds, img, orig_imgs):
"""Postprocesses predictions to return Results objects."""
results = []
for i, pred in enumerate(preds):
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
path = self.batch[0]
img_path = path[i] if isinstance(path, list) else path
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred))
return results
def predict(cfg=DEFAULT_CFG, use_python=False):
"""Run YOLO model predictions on input images/videos."""
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
else 'https://ultralytics.com/images/bus.jpg'
args = dict(model=model, source=source)
if use_python:
from ultralytics import YOLO
YOLO(model)(**args)
else:
predictor = ClassificationPredictor(overrides=args)
predictor.predict_cli()
if __name__ == '__main__':
predict()

View File

@ -0,0 +1,161 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
import torchvision
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
from ultralytics.utils.plotting import plot_images, plot_results
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
class ClassificationTrainer(BaseTrainer):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
if overrides is None:
overrides = {}
overrides['task'] = 'classify'
if overrides.get('imgsz') is None:
overrides['imgsz'] = 224
super().__init__(cfg, overrides, _callbacks)
def set_model_attributes(self):
"""Set the YOLO model's class names from the loaded dataset."""
self.model.names = self.data['names']
def get_model(self, cfg=None, weights=None, verbose=True):
"""Returns a modified PyTorch model configured for training YOLO."""
model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
for m in model.modules():
if not self.args.pretrained and hasattr(m, 'reset_parameters'):
m.reset_parameters()
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
m.p = self.args.dropout # set dropout
for p in model.parameters():
p.requires_grad = True # for training
return model
def setup_model(self):
"""
load/create/download model for any task
"""
# Classification models require special handling
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
model = str(self.model)
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
if model.endswith('.pt'):
self.model, _ = attempt_load_one_weight(model, device='cpu')
for p in self.model.parameters():
p.requires_grad = True # for training
elif model.endswith('.yaml'):
self.model = self.get_model(cfg=model)
elif model in torchvision.models.__dict__:
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None)
else:
raise FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
return  # don't return ckpt. Classification doesn't support resume
def build_dataset(self, img_path, mode='train', batch=None):
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
"""Returns PyTorch DataLoader with transforms to preprocess images for inference."""
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = self.build_dataset(dataset_path, mode)
loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
# Attach inference transforms
if mode != 'train':
if is_parallel(self.model):
self.model.module.transforms = loader.dataset.torch_transforms
else:
self.model.transforms = loader.dataset.torch_transforms
return loader
def preprocess_batch(self, batch):
"""Preprocesses a batch of images and classes."""
batch['img'] = batch['img'].to(self.device)
batch['cls'] = batch['cls'].to(self.device)
return batch
def progress_string(self):
"""Returns a formatted string showing training progress."""
return ('\n' + '%11s' * (4 + len(self.loss_names))) % \
('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
def get_validator(self):
"""Returns an instance of ClassificationValidator for validation."""
self.loss_names = ['loss']
return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir)
def label_loss_items(self, loss_items=None, prefix='train'):
"""
Returns a loss dict with labelled training loss items tensor
"""
# Not needed for classification but necessary for segmentation & detection
keys = [f'{prefix}/{x}' for x in self.loss_names]
if loss_items is None:
return keys
loss_items = [round(float(loss_items), 5)]
return dict(zip(keys, loss_items))
def resume_training(self, ckpt):
"""Resumes training from a given checkpoint."""
pass
def plot_metrics(self):
"""Plots metrics from a CSV file."""
plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png
def final_eval(self):
"""Evaluate trained model and save validation results."""
for f in self.last, self.best:
if f.exists():
strip_optimizer(f) # strip optimizers
# TODO: validate best.pt after training completes
# if f is self.best:
# LOGGER.info(f'\nValidating {f}...')
# self.validator.args.save_json = True
# self.metrics = self.validator(model=f)
# self.metrics.pop('fitness', None)
# self.run_callbacks('on_fit_epoch_end')
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
def plot_training_samples(self, batch, ni):
"""Plots training samples with their annotations."""
plot_images(images=batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].squeeze(-1),
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def train(cfg=DEFAULT_CFG, use_python=False):
"""Train the YOLO classification model."""
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
data = cfg.data or 'mnist160' # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else ''
args = dict(model=model, data=data, device=device)
if use_python:
from ultralytics import YOLO
YOLO(model).train(**args)
else:
trainer = ClassificationTrainer(overrides=args)
trainer.train()
if __name__ == '__main__':
train()
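Because `setup_model` resolves bare torchvision model names as well as `.pt` and `.yaml` files, the same trainer can fine-tune an ImageNet backbone on a YOLO classification dataset, reshaping its output layer to the dataset's class count. A hedged usage sketch (the epoch count is arbitrary; `mnist160` mirrors the default above):

```python
from ultralytics.models.yolo.classify import ClassificationTrainer

# 'resnet18' is resolved through torchvision.models and its classifier is reshaped to data['nc']
trainer = ClassificationTrainer(overrides=dict(model='resnet18', data='mnist160', imgsz=224, epochs=3))
trainer.train()
```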

View File

@ -0,0 +1,109 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.validator import BaseValidator
from ultralytics.utils import DEFAULT_CFG, LOGGER
from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
from ultralytics.utils.plotting import plot_images
class ClassificationValidator(BaseValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'classify'
self.metrics = ClassifyMetrics()
def get_desc(self):
"""Returns a formatted string summarizing classification metrics."""
return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
def init_metrics(self, model):
"""Initialize confusion matrix, class names, and top-1 and top-5 accuracy."""
self.names = model.names
self.nc = len(model.names)
self.confusion_matrix = ConfusionMatrix(nc=self.nc, task='classify')
self.pred = []
self.targets = []
def preprocess(self, batch):
"""Preprocesses input batch and returns it."""
batch['img'] = batch['img'].to(self.device, non_blocking=True)
batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
batch['cls'] = batch['cls'].to(self.device)
return batch
def update_metrics(self, preds, batch):
"""Updates running metrics with model predictions and batch targets."""
n5 = min(len(self.model.names), 5)
self.pred.append(preds.argsort(1, descending=True)[:, :n5])
self.targets.append(batch['cls'])
def finalize_metrics(self, *args, **kwargs):
"""Finalizes metrics of the model such as confusion_matrix and speed."""
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(save_dir=self.save_dir,
names=self.names.values(),
normalize=normalize,
on_plot=self.on_plot)
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
def get_stats(self):
"""Returns a dictionary of metrics obtained by processing targets and predictions."""
self.metrics.process(self.targets, self.pred)
return self.metrics.results_dict
def build_dataset(self, img_path):
return ClassificationDataset(root=img_path, args=self.args, augment=False)
def get_dataloader(self, dataset_path, batch_size):
"""Builds and returns a data loader for classification tasks with given parameters."""
dataset = self.build_dataset(dataset_path)
return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
def print_results(self):
"""Prints evaluation metrics for YOLO object detection model."""
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
def plot_val_samples(self, batch, ni):
"""Plot validation image samples."""
plot_images(images=batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].squeeze(-1),
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
plot_images(batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=torch.argmax(preds, dim=1),
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
def val(cfg=DEFAULT_CFG, use_python=False):
"""Validate YOLO model using custom data."""
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
data = cfg.data or 'mnist160'
args = dict(model=model, data=data)
if use_python:
from ultralytics import YOLO
YOLO(model).val(**args)
else:
validator = ClassificationValidator(args=args)
validator(model=args['model'])
if __name__ == '__main__':
val()
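Validation keeps the full ranking of each prediction: `update_metrics` stores the top-5 class indices per image and `ClassifyMetrics` then derives the top-1 and top-5 accuracies printed in `print_results`. A tiny example of that ranking step:

```python
import torch

preds = torch.tensor([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]])             # dummy class scores for 2 images
targets = torch.tensor([1, 2])
top5 = preds.argsort(1, descending=True)[:, :min(preds.shape[1], 5)]  # same ranking as update_metrics
print((top5[:, 0] == targets).tolist())                               # [True, False] -> top-1 hits
```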

View File

@ -0,0 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .predict import DetectionPredictor, predict
from .train import DetectionTrainer, train
from .val import DetectionValidator, val
__all__ = 'DetectionPredictor', 'predict', 'DetectionTrainer', 'train', 'DetectionValidator', 'val'

View File

@ -0,0 +1,48 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ROOT, ops
class DetectionPredictor(BasePredictor):
def postprocess(self, preds, img, orig_imgs):
"""Postprocesses predictions and returns a list of Results objects."""
preds = ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
classes=self.args.classes)
results = []
for i, pred in enumerate(preds):
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
if not isinstance(orig_imgs, torch.Tensor):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
path = self.batch[0]
img_path = path[i] if isinstance(path, list) else path
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
return results
def predict(cfg=DEFAULT_CFG, use_python=False):
"""Runs YOLO model inference on input image(s)."""
model = cfg.model or 'yolov8n.pt'
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
else 'https://ultralytics.com/images/bus.jpg'
args = dict(model=model, source=source)
if use_python:
from ultralytics import YOLO
YOLO(model)(**args)
else:
predictor = DetectionPredictor(overrides=args)
predictor.predict_cli()
if __name__ == '__main__':
predict()
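After `non_max_suppression`, each image's `pred` is an `(n, 6)` tensor laid out as `[x1, y1, x2, y2, conf, cls]`, which is why only the first four columns are rescaled to the original image before being wrapped in `Results`. A shape-only illustration:

```python
import torch

pred = torch.zeros(3, 6)                                   # 3 dummy detections: x1, y1, x2, y2, conf, cls
boxes, conf, cls = pred[:, :4], pred[:, 4], pred[:, 5]
print(boxes.shape, conf.shape, cls.shape)                  # torch.Size([3, 4]) torch.Size([3]) torch.Size([3])
```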

View File

@ -0,0 +1,123 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from copy import copy
import numpy as np
from ultralytics.data import build_dataloader, build_yolo_dataset
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
# BaseTrainer python usage
class DetectionTrainer(BaseTrainer):
def build_dataset(self, img_path, mode='train', batch=None):
"""
Build YOLO Dataset.
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` or `val` mode; users can customize different augmentations for each mode.
batch (int, optional): Batch size, used for `rect` batching. Defaults to None.
"""
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs)
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
"""Construct and return dataloader."""
assert mode in ['train', 'val']
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = self.build_dataset(dataset_path, mode, batch_size)
shuffle = mode == 'train'
if getattr(dataset, 'rect', False) and shuffle:
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
shuffle = False
workers = self.args.workers if mode == 'train' else self.args.workers * 2
return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader
def preprocess_batch(self, batch):
"""Preprocesses a batch of images by scaling and converting to float."""
batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255
return batch
def set_model_attributes(self):
"""nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)."""
# self.args.box *= 3 / nl # scale to layers
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
self.model.nc = self.data['nc'] # attach number of classes to model
self.model.names = self.data['names'] # attach class names to model
self.model.args = self.args # attach hyperparameters to model
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return a YOLO detection model."""
model = DetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def get_validator(self):
"""Returns a DetectionValidator for YOLO model validation."""
self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
return yolo.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def label_loss_items(self, loss_items=None, prefix='train'):
"""
Returns a loss dict with labelled training loss items tensor
"""
# Not needed for classification but necessary for segmentation & detection
keys = [f'{prefix}/{x}' for x in self.loss_names]
if loss_items is not None:
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
return dict(zip(keys, loss_items))
else:
return keys
def progress_string(self):
"""Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
return ('\n' + '%11s' *
(4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
def plot_training_samples(self, batch, ni):
"""Plots training samples with their annotations."""
plot_images(images=batch['img'],
batch_idx=batch['batch_idx'],
cls=batch['cls'].squeeze(-1),
bboxes=batch['bboxes'],
paths=batch['im_file'],
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def plot_metrics(self):
"""Plots metrics from a CSV file."""
plot_results(file=self.csv, on_plot=self.on_plot) # save results.png
def plot_training_labels(self):
"""Create a labeled training plot of the YOLO model."""
boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0)
cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0)
plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot)
def train(cfg=DEFAULT_CFG, use_python=False):
"""Train and optimize YOLO model given training data and device."""
model = cfg.model or 'yolov8n.pt'
data = cfg.data or 'coco128.yaml'
device = cfg.device if cfg.device is not None else ''
args = dict(model=model, data=data, device=device)
if use_python:
from ultralytics import YOLO
YOLO(model).train(**args)
else:
trainer = DetectionTrainer(overrides=args)
trainer.train()
if __name__ == '__main__':
train()
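`get_validator` fixes the three detection loss names and `label_loss_items` prefixes them for logging, so training curves appear under keys such as `train/box_loss`. A quick check of that mapping with placeholder values:

```python
loss_names = ('box_loss', 'cls_loss', 'dfl_loss')          # same names set in get_validator
loss_items = [0.05123, 0.41234, 0.90321]                   # placeholder loss values
print(dict(zip([f'train/{x}' for x in loss_names], loss_items)))
# {'train/box_loss': 0.05123, 'train/cls_loss': 0.41234, 'train/dfl_loss': 0.90321}
```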

View File

@ -0,0 +1,276 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import os
from pathlib import Path
import numpy as np
import torch
from ultralytics.data import build_dataloader, build_yolo_dataset
from ultralytics.engine.validator import BaseValidator
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
from ultralytics.utils.plotting import output_to_target, plot_images
from ultralytics.utils.torch_utils import de_parallel
class DetectionValidator(BaseValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize detection model with necessary variables and settings."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'detect'
self.is_coco = False
self.class_map = None
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
self.niou = self.iouv.numel()
def preprocess(self, batch):
"""Preprocesses batch of images for YOLO training."""
batch['img'] = batch['img'].to(self.device, non_blocking=True)
batch['img'] = (batch['img'].half() if self.args.half else batch['img'].float()) / 255
for k in ['batch_idx', 'cls', 'bboxes']:
batch[k] = batch[k].to(self.device)
nb = len(batch['img'])
self.lb = [torch.cat([batch['cls'], batch['bboxes']], dim=-1)[batch['batch_idx'] == i]
for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling
return batch
def init_metrics(self, model):
"""Initialize evaluation metrics for YOLO."""
val = self.data.get(self.args.split, '') # validation path
self.is_coco = isinstance(val, str) and 'coco' in val and val.endswith(f'{os.sep}val2017.txt') # is COCO
self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO
self.names = model.names
self.nc = len(model.names)
self.metrics.names = self.names
self.metrics.plot = self.args.plots
self.confusion_matrix = ConfusionMatrix(nc=self.nc)
self.seen = 0
self.jdict = []
self.stats = []
def get_desc(self):
"""Return a formatted string summarizing class metrics of YOLO model."""
return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)')
def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
return ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det)
def update_metrics(self, preds, batch):
"""Metrics."""
for si, pred in enumerate(preds):
idx = batch['batch_idx'] == si
cls = batch['cls'][idx]
bbox = batch['bboxes'][idx]
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
shape = batch['ori_shape'][si]
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
self.seen += 1
if npr == 0:
if nl:
self.stats.append((correct_bboxes, *torch.zeros((2, 0), device=self.device), cls.squeeze(-1)))
if self.args.plots:
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
continue
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
predn = pred.clone()
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
ratio_pad=batch['ratio_pad'][si]) # native-space pred
# Evaluate
if nl:
height, width = batch['img'].shape[2:]
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
(width, height, width, height), device=self.device) # target boxes
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
ratio_pad=batch['ratio_pad'][si]) # native-space labels
labelsn = torch.cat((cls, tbox), 1) # native-space labels
correct_bboxes = self._process_batch(predn, labelsn)
# TODO: maybe remove these `self.` arguments as they already are member variable
if self.args.plots:
self.confusion_matrix.process_batch(predn, labelsn)
self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], cls.squeeze(-1))) # (conf, pcls, tcls)
# Save
if self.args.save_json:
self.pred_to_json(predn, batch['im_file'][si])
if self.args.save_txt:
file = self.save_dir / 'labels' / f'{Path(batch["im_file"][si]).stem}.txt'
self.save_one_txt(predn, self.args.save_conf, shape, file)
def finalize_metrics(self, *args, **kwargs):
"""Set final values for metrics speed and confusion matrix."""
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
def get_stats(self):
"""Returns metrics statistics and results dictionary."""
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
if len(stats) and stats[0].any():
self.metrics.process(*stats)
self.nt_per_class = np.bincount(stats[-1].astype(int), minlength=self.nc) # number of targets per class
return self.metrics.results_dict
def print_results(self):
"""Prints training/validation set metrics per class."""
pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format
LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
if self.nt_per_class.sum() == 0:
LOGGER.warning(
f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')
# Print results per class
if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
for i, c in enumerate(self.metrics.ap_class_index):
LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(save_dir=self.save_dir,
names=self.names.values(),
normalize=normalize,
on_plot=self.on_plot)
def _process_batch(self, detections, labels):
"""
Return correct prediction matrix
Arguments:
detections (array[N, 6]), x1, y1, x2, y2, conf, class
labels (array[M, 5]), class, x1, y1, x2, y2
Returns:
correct (array[N, 10]), for 10 IoU levels
"""
iou = box_iou(labels[:, 1:], detections[:, :4])
correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
correct_class = labels[:, 0:1] == detections[:, 5]
for i in range(len(self.iouv)):
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
if x[0].shape[0]:
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
1).cpu().numpy() # [label, detect, iou]
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
# matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
correct[matches[:, 1].astype(int), i] = True
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
def build_dataset(self, img_path, mode='val', batch=None):
"""Build YOLO Dataset
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
gs = max(int(de_parallel(self.model).stride if self.model else 0), 32)
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=gs)
def get_dataloader(self, dataset_path, batch_size):
"""Construct and return dataloader."""
dataset = self.build_dataset(dataset_path, batch=batch_size, mode='val')
return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1) # return dataloader
def plot_val_samples(self, batch, ni):
"""Plot validation image samples."""
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
plot_images(batch['img'],
*output_to_target(preds, max_det=self.args.max_det),
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
def save_one_txt(self, predn, save_conf, shape, file):
"""Save YOLO detections to a txt file in normalized coordinates in a specific format."""
gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh
for *xyxy, conf, cls in predn.tolist():
xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
with open(file, 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n')
def pred_to_json(self, predn, filename):
"""Serialize YOLO predictions to COCO json format."""
stem = Path(filename).stem
image_id = int(stem) if stem.isnumeric() else stem
box = ops.xyxy2xywh(predn[:, :4]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
for p, b in zip(predn.tolist(), box.tolist()):
self.jdict.append({
'image_id': image_id,
'category_id': self.class_map[int(p[5])],
'bbox': [round(x, 3) for x in b],
'score': round(p[4], 5)})
def eval_json(self, stats):
"""Evaluates YOLO output in JSON format and returns performance statistics."""
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements('pycocotools>=2.0.6')
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
assert x.is_file(), f'{x} file not found'
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
eval = COCOeval(anno, pred, 'bbox')
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
eval.evaluate()
eval.accumulate()
eval.summarize()
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
except Exception as e:
LOGGER.warning(f'pycocotools unable to run: {e}')
return stats
def val(cfg=DEFAULT_CFG, use_python=False):
"""Validate trained YOLO model on validation dataset."""
model = cfg.model or 'yolov8n.pt'
data = cfg.data or 'coco128.yaml'
args = dict(model=model, data=data)
if use_python:
from ultralytics import YOLO
YOLO(model).val(**args)
else:
validator = DetectionValidator(args=args)
validator(model=args['model'])
if __name__ == '__main__':
val()
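The `iouv` vector built in `__init__` supplies the ten IoU thresholds behind the mAP50-95 metric, and `_process_batch` fills one true-positive column per threshold. For reference, the thresholds it produces:

```python
import torch

iouv = torch.linspace(0.5, 0.95, 10)           # same construction as DetectionValidator.__init__
print([round(v, 2) for v in iouv.tolist()])    # [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
```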

View File

@ -0,0 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .predict import PosePredictor, predict
from .train import PoseTrainer, train
from .val import PoseValidator, val
__all__ = 'PoseTrainer', 'train', 'PoseValidator', 'val', 'PosePredictor', 'predict'

View File

@ -0,0 +1,58 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ROOT, ops
class PosePredictor(DetectionPredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'pose'
def postprocess(self, preds, img, orig_imgs):
"""Return detection results for a given input image or list of images."""
preds = ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
classes=self.args.classes,
nc=len(self.model.names))
results = []
for i, pred in enumerate(preds):
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
shape = orig_img.shape
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, shape)
path = self.batch[0]
img_path = path[i] if isinstance(path, list) else path
results.append(
Results(orig_img=orig_img,
path=img_path,
names=self.model.names,
boxes=pred[:, :6],
keypoints=pred_kpts))
return results
def predict(cfg=DEFAULT_CFG, use_python=False):
"""Runs YOLO to predict objects in an image or video."""
model = cfg.model or 'yolov8n-pose.pt'
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
else 'https://ultralytics.com/images/bus.jpg'
args = dict(model=model, source=source)
if use_python:
from ultralytics import YOLO
YOLO(model)(**args)
else:
predictor = PosePredictor(overrides=args)
predictor.predict_cli()
if __name__ == '__main__':
predict()

View File

@ -0,0 +1,77 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from copy import copy
from ultralytics.models import yolo
from ultralytics.nn.tasks import PoseModel
from ultralytics.utils import DEFAULT_CFG
from ultralytics.utils.plotting import plot_images, plot_results
# BaseTrainer python usage
class PoseTrainer(yolo.detect.DetectionTrainer):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a PoseTrainer object with specified configurations and overrides."""
if overrides is None:
overrides = {}
overrides['task'] = 'pose'
super().__init__(cfg, overrides, _callbacks)
def get_model(self, cfg=None, weights=None, verbose=True):
"""Get pose estimation model with specified configuration and weights."""
model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose)
if weights:
model.load(weights)
return model
def set_model_attributes(self):
"""Sets keypoints shape attribute of PoseModel."""
super().set_model_attributes()
self.model.kpt_shape = self.data['kpt_shape']
def get_validator(self):
"""Returns an instance of the PoseValidator class for validation."""
self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
return yolo.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def plot_training_samples(self, batch, ni):
"""Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
images = batch['img']
kpts = batch['keypoints']
cls = batch['cls'].squeeze(-1)
bboxes = batch['bboxes']
paths = batch['im_file']
batch_idx = batch['batch_idx']
plot_images(images,
batch_idx,
cls,
bboxes,
kpts=kpts,
paths=paths,
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def plot_metrics(self):
"""Plots training/val metrics."""
plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png
def train(cfg=DEFAULT_CFG, use_python=False):
"""Train the YOLO model on the given data and device."""
model = cfg.model or 'yolov8n-pose.yaml'
data = cfg.data or 'coco8-pose.yaml'
device = cfg.device if cfg.device is not None else ''
args = dict(model=model, data=data, device=device)
if use_python:
from ultralytics import YOLO
YOLO(model).train(**args)
else:
trainer = PoseTrainer(overrides=args)
trainer.train()
if __name__ == '__main__':
train()

View File

@ -0,0 +1,224 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from pathlib import Path
import numpy as np
import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
from ultralytics.utils.plotting import output_to_target, plot_images
class PoseValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'pose'
self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
def preprocess(self, batch):
"""Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
batch = super().preprocess(batch)
batch['keypoints'] = batch['keypoints'].to(self.device).float()
return batch
def get_desc(self):
"""Returns description of evaluation metrics in string format."""
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Pose(P',
'R', 'mAP50', 'mAP50-95)')
def postprocess(self, preds):
"""Apply non-maximum suppression and return detections with high confidence scores."""
return ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
nc=self.nc)
def init_metrics(self, model):
"""Initiate pose estimation metrics for YOLO model."""
super().init_metrics(model)
self.kpt_shape = self.data['kpt_shape']
is_pose = self.kpt_shape == [17, 3]
nkpt = self.kpt_shape[0]
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
def update_metrics(self, preds, batch):
"""Metrics."""
for si, pred in enumerate(preds):
idx = batch['batch_idx'] == si
cls = batch['cls'][idx]
bbox = batch['bboxes'][idx]
kpts = batch['keypoints'][idx]
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
nk = kpts.shape[1] # number of keypoints
shape = batch['ori_shape'][si]
correct_kpts = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
self.seen += 1
if npr == 0:
if nl:
self.stats.append((correct_bboxes, correct_kpts, *torch.zeros(
(2, 0), device=self.device), cls.squeeze(-1)))
if self.args.plots:
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
continue
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
predn = pred.clone()
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
ratio_pad=batch['ratio_pad'][si]) # native-space pred
pred_kpts = predn[:, 6:].view(npr, nk, -1)
ops.scale_coords(batch['img'][si].shape[1:], pred_kpts, shape, ratio_pad=batch['ratio_pad'][si])
# Evaluate
if nl:
height, width = batch['img'].shape[2:]
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
(width, height, width, height), device=self.device) # target boxes
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
ratio_pad=batch['ratio_pad'][si]) # native-space labels
tkpts = kpts.clone()
tkpts[..., 0] *= width
tkpts[..., 1] *= height
tkpts = ops.scale_coords(batch['img'][si].shape[1:], tkpts, shape, ratio_pad=batch['ratio_pad'][si])
labelsn = torch.cat((cls, tbox), 1) # native-space labels
correct_bboxes = self._process_batch(predn[:, :6], labelsn)
correct_kpts = self._process_batch(predn[:, :6], labelsn, pred_kpts, tkpts)
if self.args.plots:
self.confusion_matrix.process_batch(predn, labelsn)
# Append correct_kpts, correct_bboxes, pconf, pcls, tcls
self.stats.append((correct_bboxes, correct_kpts, pred[:, 4], pred[:, 5], cls.squeeze(-1)))
# Save
if self.args.save_json:
self.pred_to_json(predn, batch['im_file'][si])
# if self.args.save_txt:
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
def _process_batch(self, detections, labels, pred_kpts=None, gt_kpts=None):
"""
Return correct prediction matrix
Arguments:
detections (array[N, 6]), x1, y1, x2, y2, conf, class
labels (array[M, 5]), class, x1, y1, x2, y2
pred_kpts (array[N, 51]), 51 = 17 * 3
gt_kpts (array[N, 51])
Returns:
correct (array[N, 10]), for 10 IoU levels
"""
if pred_kpts is not None and gt_kpts is not None:
# `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
area = ops.xyxy2xywh(labels[:, 1:])[:, 2:].prod(1) * 0.53
iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)
else: # boxes
iou = box_iou(labels[:, 1:], detections[:, :4])
correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
correct_class = labels[:, 0:1] == detections[:, 5]
for i in range(len(self.iouv)):
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
if x[0].shape[0]:
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
1).cpu().numpy() # [label, detect, iou]
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
# matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
correct[matches[:, 1].astype(int), i] = True
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
def plot_val_samples(self, batch, ni):
"""Plots and saves validation set samples with predicted bounding boxes and keypoints."""
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
kpts=batch['keypoints'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots predictions for YOLO model."""
pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
plot_images(batch['img'],
*output_to_target(preds, max_det=self.args.max_det),
kpts=pred_kpts,
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
def pred_to_json(self, predn, filename):
"""Converts YOLO predictions to COCO JSON format."""
stem = Path(filename).stem
image_id = int(stem) if stem.isnumeric() else stem
box = ops.xyxy2xywh(predn[:, :4]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
for p, b in zip(predn.tolist(), box.tolist()):
self.jdict.append({
'image_id': image_id,
'category_id': self.class_map[int(p[5])],
'bbox': [round(x, 3) for x in b],
'keypoints': p[6:],
'score': round(p[4], 5)})
def eval_json(self, stats):
"""Evaluates object detection model using COCO JSON format."""
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/person_keypoints_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements('pycocotools>=2.0.6')
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
assert x.is_file(), f'{x} file not found'
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'keypoints')]):
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
eval.evaluate()
eval.accumulate()
eval.summarize()
idx = i * 4 + 2
stats[self.metrics.keys[idx + 1]], stats[
self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
except Exception as e:
LOGGER.warning(f'pycocotools unable to run: {e}')
return stats
def val(cfg=DEFAULT_CFG, use_python=False):
"""Performs validation on YOLO model using given data."""
model = cfg.model or 'yolov8n-pose.pt'
data = cfg.data or 'coco8-pose.yaml'
args = dict(model=model, data=data)
if use_python:
from ultralytics import YOLO
YOLO(model).val(**args)
else:
validator = PoseValidator(args=args)
validator(model=args['model'])
if __name__ == '__main__':
val()
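# Usage sketch (illustrative only; assumes the pretrained 'yolov8n-pose.pt' weights and the
# bundled 'coco8-pose.yaml' dataset):
#
#     from ultralytics import YOLO
#
#     metrics = YOLO('yolov8n-pose.pt').val(data='coco8-pose.yaml')  # runs PoseValidator internally
#
# or from the CLI:
#
#     yolo pose val model=yolov8n-pose.pt data=coco8-pose.yaml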

View File

@ -0,0 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .predict import SegmentationPredictor, predict
from .train import SegmentationTrainer, train
from .val import SegmentationValidator, val
__all__ = 'SegmentationPredictor', 'predict', 'SegmentationTrainer', 'train', 'SegmentationValidator', 'val'

View File

@ -0,0 +1,63 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ROOT, ops
class SegmentationPredictor(DetectionPredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'segment'
def postprocess(self, preds, img, orig_imgs):
"""TODO: filter by classes."""
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=len(self.model.names),
classes=self.args.classes)
results = []
        proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output has length 3 for PyTorch models but 1 for exported models
for i, pred in enumerate(p):
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
path = self.batch[0]
img_path = path[i] if isinstance(path, list) else path
if not len(pred): # save empty boxes
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
continue
if self.args.retina_masks:
if not isinstance(orig_imgs, torch.Tensor):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
else:
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
if not isinstance(orig_imgs, torch.Tensor):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
results.append(
Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
return results
def predict(cfg=DEFAULT_CFG, use_python=False):
"""Runs YOLO object detection on an image or video source."""
model = cfg.model or 'yolov8n-seg.pt'
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
else 'https://ultralytics.com/images/bus.jpg'
args = dict(model=model, source=source)
if use_python:
from ultralytics import YOLO
YOLO(model)(**args)
else:
predictor = SegmentationPredictor(overrides=args)
predictor.predict_cli()
if __name__ == '__main__':
predict()
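# Usage sketch (illustrative only; assumes 'yolov8n-seg.pt' weights and a local 'bus.jpg' image):
#
#     from ultralytics import YOLO
#
#     results = YOLO('yolov8n-seg.pt')('bus.jpg', retina_masks=True)  # SegmentationPredictor under the hood
#     masks = results[0].masks  # Masks object; masks.data holds per-instance binary masks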

View File

@ -0,0 +1,65 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from copy import copy
from ultralytics.models import yolo
from ultralytics.nn.tasks import SegmentationModel
from ultralytics.utils import DEFAULT_CFG, RANK
from ultralytics.utils.plotting import plot_images, plot_results
# BaseTrainer python usage
class SegmentationTrainer(yolo.detect.DetectionTrainer):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a SegmentationTrainer object with given arguments."""
if overrides is None:
overrides = {}
overrides['task'] = 'segment'
super().__init__(cfg, overrides, _callbacks)
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return SegmentationModel initialized with specified config and weights."""
model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def get_validator(self):
"""Return an instance of SegmentationValidator for validation of YOLO model."""
self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
return yolo.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def plot_training_samples(self, batch, ni):
"""Creates a plot of training sample images with labels and box coordinates."""
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
batch['masks'],
paths=batch['im_file'],
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def plot_metrics(self):
"""Plots training/val metrics."""
plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png
def train(cfg=DEFAULT_CFG, use_python=False):
"""Train a YOLO segmentation model based on passed arguments."""
model = cfg.model or 'yolov8n-seg.pt'
data = cfg.data or 'coco128-seg.yaml' # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else ''
args = dict(model=model, data=data, device=device)
if use_python:
from ultralytics import YOLO
YOLO(model).train(**args)
else:
trainer = SegmentationTrainer(overrides=args)
trainer.train()
if __name__ == '__main__':
train()
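# Usage sketch (illustrative only; assumes 'yolov8n-seg.pt' weights and the bundled 'coco128-seg.yaml'):
#
#     from ultralytics import YOLO
#
#     model = YOLO('yolov8n-seg.pt')
#     model.train(data='coco128-seg.yaml', epochs=100, imgsz=640)  # uses SegmentationTrainer internally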

View File

@ -0,0 +1,262 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from multiprocessing.pool import ThreadPool
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import DEFAULT_CFG, LOGGER, NUM_THREADS, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou
from ultralytics.utils.plotting import output_to_target, plot_images
class SegmentationValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'segment'
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
def preprocess(self, batch):
"""Preprocesses batch by converting masks to float and sending to device."""
batch = super().preprocess(batch)
batch['masks'] = batch['masks'].to(self.device).float()
return batch
def init_metrics(self, model):
"""Initialize metrics and select mask processing function based on save_json flag."""
super().init_metrics(model)
self.plot_masks = []
if self.args.save_json:
check_requirements('pycocotools>=2.0.6')
self.process = ops.process_mask_upsample # more accurate
else:
self.process = ops.process_mask # faster
def get_desc(self):
"""Return a formatted description of evaluation metrics."""
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P',
'R', 'mAP50', 'mAP50-95)')
def postprocess(self, preds):
"""Postprocesses YOLO predictions and returns output detections with proto."""
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
nc=self.nc)
        proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output has length 3 for PyTorch models but 1 for exported models
return p, proto
def update_metrics(self, preds, batch):
"""Metrics."""
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
idx = batch['batch_idx'] == si
cls = batch['cls'][idx]
bbox = batch['bboxes'][idx]
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
shape = batch['ori_shape'][si]
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
self.seen += 1
if npr == 0:
if nl:
self.stats.append((correct_bboxes, correct_masks, *torch.zeros(
(2, 0), device=self.device), cls.squeeze(-1)))
if self.args.plots:
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
continue
# Masks
midx = [si] if self.args.overlap_mask else idx
gt_masks = batch['masks'][midx]
pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch['img'][si].shape[1:])
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
predn = pred.clone()
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
ratio_pad=batch['ratio_pad'][si]) # native-space pred
# Evaluate
if nl:
height, width = batch['img'].shape[2:]
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
(width, height, width, height), device=self.device) # target boxes
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
ratio_pad=batch['ratio_pad'][si]) # native-space labels
labelsn = torch.cat((cls, tbox), 1) # native-space labels
correct_bboxes = self._process_batch(predn, labelsn)
                # TODO: maybe remove these `self.` arguments as they are already member variables
correct_masks = self._process_batch(predn,
labelsn,
pred_masks,
gt_masks,
overlap=self.args.overlap_mask,
masks=True)
if self.args.plots:
self.confusion_matrix.process_batch(predn, labelsn)
            # Append correct_bboxes, correct_masks, pconf, pcls, tcls
self.stats.append((correct_bboxes, correct_masks, pred[:, 4], pred[:, 5], cls.squeeze(-1)))
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
if self.args.plots and self.batch_i < 3:
self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
# Save
if self.args.save_json:
pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
shape,
ratio_pad=batch['ratio_pad'][si])
self.pred_to_json(predn, batch['im_file'][si], pred_masks)
# if self.args.save_txt:
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
def finalize_metrics(self, *args, **kwargs):
"""Sets speed and confusion matrix for evaluation metrics."""
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
"""
Return correct prediction matrix
Arguments:
detections (array[N, 6]), x1, y1, x2, y2, conf, class
labels (array[M, 5]), class, x1, y1, x2, y2
Returns:
correct (array[N, 10]), for 10 IoU levels
"""
if masks:
if overlap:
nl = len(labels)
index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
if gt_masks.shape[1:] != pred_masks.shape[1:]:
gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0]
gt_masks = gt_masks.gt_(0.5)
iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
else: # boxes
iou = box_iou(labels[:, 1:], detections[:, :4])
correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
correct_class = labels[:, 0:1] == detections[:, 5]
for i in range(len(self.iouv)):
            x = torch.where((iou >= self.iouv[i]) & correct_class)  # IoU >= threshold and classes match
if x[0].shape[0]:
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
1).cpu().numpy() # [label, detect, iou]
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
# matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
correct[matches[:, 1].astype(int), i] = True
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
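    # Note: `mask_iou` used above expects masks flattened to shape (N, H*W). A minimal sketch of
    # the pairwise computation (illustrative only, not the library implementation; the name
    # `pairwise_mask_iou` is an assumption):
    #
    #     import torch
    #
    #     def pairwise_mask_iou(m1, m2, eps=1e-7):   # m1: (N, HW), m2: (M, HW) binary float tensors
    #         inter = m1 @ m2.T                       # (N, M) intersection areas
    #         union = m1.sum(1)[:, None] + m2.sum(1)[None] - inter
    #         return inter / (union + eps)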
def plot_val_samples(self, batch, ni):
"""Plots validation samples with bounding box labels."""
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
batch['masks'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots batch predictions with masks and bounding boxes."""
plot_images(
batch['img'],
*output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
self.plot_masks.clear()
def pred_to_json(self, predn, filename, pred_masks):
"""Save one JSON result."""
# Example result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
from pycocotools.mask import encode # noqa
def single_encode(x):
"""Encode predicted masks as RLE and append results to jdict."""
rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0]
rle['counts'] = rle['counts'].decode('utf-8')
return rle
stem = Path(filename).stem
image_id = int(stem) if stem.isnumeric() else stem
box = ops.xyxy2xywh(predn[:, :4]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
pred_masks = np.transpose(pred_masks, (2, 0, 1))
with ThreadPool(NUM_THREADS) as pool:
rles = pool.map(single_encode, pred_masks)
for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
self.jdict.append({
'image_id': image_id,
'category_id': self.class_map[int(p[5])],
'bbox': [round(x, 3) for x in b],
'score': round(p[4], 5),
'segmentation': rles[i]})
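        # Note: 'counts' is decoded to a utf-8 string above so the RLE stays JSON-serializable.
        # A sketch of recovering the binary mask later with pycocotools (illustrative only; the
        # name `entry` is an assumption for one item of the saved JSON):
        #
        #     from pycocotools.mask import decode
        #
        #     rle = dict(entry['segmentation'])
        #     rle['counts'] = rle['counts'].encode('utf-8')  # pycocotools expects bytes for 'counts'
        #     binary_mask = decode(rle)                      # (H, W) uint8 array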
def eval_json(self, stats):
"""Return COCO-style object detection evaluation metrics."""
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements('pycocotools>=2.0.6')
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
assert x.is_file(), f'{x} file not found'
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]):
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
eval.evaluate()
eval.accumulate()
eval.summarize()
idx = i * 4 + 2
stats[self.metrics.keys[idx + 1]], stats[
self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
except Exception as e:
LOGGER.warning(f'pycocotools unable to run: {e}')
return stats
def val(cfg=DEFAULT_CFG, use_python=False):
"""Validate trained YOLO model on validation data."""
model = cfg.model or 'yolov8n-seg.pt'
data = cfg.data or 'coco128-seg.yaml'
args = dict(model=model, data=data)
if use_python:
from ultralytics import YOLO
YOLO(model).val(**args)
else:
validator = SegmentationValidator(args=args)
validator(model=args['model'])
if __name__ == '__main__':
val()
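# Usage sketch (illustrative only; assumes 'yolov8n-seg.pt' weights; running with save_json=True on
# a COCO validation set enables the pycocotools eval_json() path above):
#
#     from ultralytics import YOLO
#
#     metrics = YOLO('yolov8n-seg.pt').val(data='coco128-seg.yaml', save_json=True)
#     # metrics.box.map and metrics.seg.map hold mAP50-95 for boxes and masks respectively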