diff --git a/docs/models/rtdetr.md b/docs/models/rtdetr.md
index 61d156a..f2c6516 100644
--- a/docs/models/rtdetr.md
+++ b/docs/models/rtdetr.md
@@ -35,6 +35,7 @@ from ultralytics import RTDETR
model = RTDETR("rtdetr-l.pt")
model.info() # display model information
+model.train(data="coco8.yaml") # train
model.predict("path/to/image.jpg") # predict
```
@@ -51,7 +52,7 @@ model.predict("path/to/image.jpg") # predict
|------------|--------------------|
| Inference | :heavy_check_mark: |
| Validation | :heavy_check_mark: |
-| Training | :x: (Coming soon) |
+| Training | :heavy_check_mark: |
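+
+With training now supported, an RT-DETR model can be fine-tuned in a single call. A minimal example (the `epochs` and `imgsz` values below are illustrative):
+
+```python
+from ultralytics import RTDETR
+
+model = RTDETR("rtdetr-l.pt")
+model.train(data="coco8.yaml", epochs=10, imgsz=640)  # fine-tune on the COCO8 sample dataset
+```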
# Citations and Acknowledgements
@@ -70,4 +71,4 @@ If you use Baidu's RT-DETR in your research or development work, please cite the
We would like to acknowledge Baidu and the [PaddlePaddle](https://github.com/PaddlePaddle/PaddleDetection) team for creating and maintaining this valuable resource for the computer vision community. Their contribution to the field with the development of the Vision Transformers-based real-time object detector, RT-DETR, is greatly appreciated.
-*Keywords: RT-DETR, Transformer, ViT, Vision Transformers, Baidu RT-DETR, PaddlePaddle, Paddle Paddle RT-DETR, real-time object detection, Vision Transformers-based object detection, pre-trained PaddlePaddle RT-DETR models, Baidu's RT-DETR usage, Ultralytics Python API*
\ No newline at end of file
+*Keywords: RT-DETR, Transformer, ViT, Vision Transformers, Baidu RT-DETR, PaddlePaddle, Paddle Paddle RT-DETR, real-time object detection, Vision Transformers-based object detection, pre-trained PaddlePaddle RT-DETR models, Baidu's RT-DETR usage, Ultralytics Python API*
diff --git a/docs/reference/yolo/utils/loss.md b/docs/reference/yolo/utils/loss.md
index 38c4f0a..23cba8f 100644
--- a/docs/reference/yolo/utils/loss.md
+++ b/docs/reference/yolo/utils/loss.md
@@ -8,6 +8,11 @@ keywords: Ultralytics, YOLO, loss functions, object detection, keypoint detectio
:::ultralytics.yolo.utils.loss.VarifocalLoss
+# FocalLoss
+---
+:::ultralytics.yolo.utils.loss.FocalLoss
+
+
# BboxLoss
---
:::ultralytics.yolo.utils.loss.BboxLoss
diff --git a/docs/reference/yolo/utils/metrics.md b/docs/reference/yolo/utils/metrics.md
index 7f14803..c230c2c 100644
--- a/docs/reference/yolo/utils/metrics.md
+++ b/docs/reference/yolo/utils/metrics.md
@@ -3,11 +3,6 @@ description: Explore Ultralytics YOLO's FocalLoss, DetMetrics, PoseMetrics, Clas
keywords: YOLOv5, metrics, losses, confusion matrix, detection metrics, pose metrics, classification metrics, intersection over area, intersection over union, keypoint intersection over union, average precision, per class average precision, Ultralytics Docs
---
-# FocalLoss
----
-:::ultralytics.yolo.utils.metrics.FocalLoss
-
-
# ConfusionMatrix
---
:::ultralytics.yolo.utils.metrics.ConfusionMatrix
diff --git a/tests/test_python.py b/tests/test_python.py
index 53e2789..f633bd6 100644
--- a/tests/test_python.py
+++ b/tests/test_python.py
@@ -7,7 +7,7 @@ import numpy as np
import torch
from PIL import Image
-from ultralytics import YOLO
+from ultralytics import RTDETR, YOLO
from ultralytics.yolo.data.build import load_inference_source
from ultralytics.yolo.utils import LINUX, ONLINE, ROOT, SETTINGS
@@ -174,7 +174,10 @@ def test_export_paddle(enabled=False):
def test_all_model_yamls():
for m in list((ROOT / 'models').rglob('yolo*.yaml')):
- YOLO(m.name)
+        if m.name == 'yolov8-rtdetr.yaml':  # special case: this config must be built with the RTDETR class
+ RTDETR(m.name)
+ else:
+ YOLO(m.name)
def test_workflow():
diff --git a/ultralytics/models/rt-detr/rt-detr-l.yaml b/ultralytics/models/rt-detr/rtdetr-l.yaml
similarity index 100%
rename from ultralytics/models/rt-detr/rt-detr-l.yaml
rename to ultralytics/models/rt-detr/rtdetr-l.yaml
diff --git a/ultralytics/models/rt-detr/rt-detr-x.yaml b/ultralytics/models/rt-detr/rtdetr-x.yaml
similarity index 100%
rename from ultralytics/models/rt-detr/rt-detr-x.yaml
rename to ultralytics/models/rt-detr/rtdetr-x.yaml
diff --git a/ultralytics/models/v8/yolov8-rtdetr.yaml b/ultralytics/models/v8/yolov8-rtdetr.yaml
new file mode 100644
index 0000000..a058106
--- /dev/null
+++ b/ultralytics/models/v8/yolov8-rtdetr.yaml
@@ -0,0 +1,46 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
+
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+ # [depth, width, max_channels]
+ n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
+ s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
+ m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
+ l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
+ x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
+
+# YOLOv8.0n backbone
+backbone:
+ # [from, repeats, module, args]
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+ - [-1, 3, C2f, [128, True]]
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+ - [-1, 6, C2f, [256, True]]
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+ - [-1, 6, C2f, [512, True]]
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+ - [-1, 3, C2f, [1024, True]]
+ - [-1, 1, SPPF, [1024, 5]] # 9
+
+# YOLOv8.0n head
+head:
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+ - [-1, 3, C2f, [512]] # 12
+
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+ - [-1, 3, C2f, [256]] # 15 (P3/8-small)
+
+ - [-1, 1, Conv, [256, 3, 2]]
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
+ - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
+
+ - [-1, 1, Conv, [512, 3, 2]]
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
+ - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
+
+ - [[15, 18, 21], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
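Note: this config reuses the standard YOLOv8 backbone/neck and only swaps the `Detect` head for `RTDETRDecoder`. A quick sanity check that the YAML parses and builds, mirroring the special case added in `test_all_model_yamls` (a minimal sketch, not part of the diff):

```python
from ultralytics import RTDETR

# Build the hybrid YOLOv8-backbone + RT-DETR-decoder model from the new config
model = RTDETR('yolov8-rtdetr.yaml')
model.info()  # print a layer/parameter summary to confirm construction
```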
diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py
index f7105bf..8b55fd0 100644
--- a/ultralytics/nn/modules/head.py
+++ b/ultralytics/nn/modules/head.py
@@ -163,220 +163,187 @@ class RTDETRDecoder(nn.Module):
self,
nc=80,
ch=(512, 1024, 2048),
- hidden_dim=256,
- num_queries=300,
- strides=(8, 16, 32), # TODO
- nl=3,
- num_decoder_points=4,
- nhead=8,
- num_decoder_layers=6,
- dim_feedforward=1024,
+ hd=256, # hidden dim
+ nq=300, # num queries
+ ndp=4, # num decoder points
+ nh=8, # num head
+ ndl=6, # num decoder layers
+ d_ffn=1024, # dim of feedforward
dropout=0.,
act=nn.ReLU(),
eval_idx=-1,
# training args
- num_denoising=100,
+ nd=100, # num denoising
label_noise_ratio=0.5,
box_noise_scale=1.0,
learnt_init_query=False):
super().__init__()
- assert len(ch) <= nl
- assert len(strides) == len(ch)
- for _ in range(nl - len(strides)):
- strides.append(strides[-1] * 2)
-
- self.hidden_dim = hidden_dim
- self.nhead = nhead
- self.feat_strides = strides
- self.nl = nl
+ self.hidden_dim = hd
+ self.nhead = nh
+ self.nl = len(ch) # num level
self.nc = nc
- self.num_queries = num_queries
- self.num_decoder_layers = num_decoder_layers
+ self.num_queries = nq
+ self.num_decoder_layers = ndl
# backbone feature projection
- self._build_input_proj_layer(ch)
+ self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch)
+ # NOTE: simplified version but it's not consistent with .pt weights.
+ # self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch)
# Transformer module
- decoder_layer = DeformableTransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, dropout, act, nl,
- num_decoder_points)
- self.decoder = DeformableTransformerDecoder(hidden_dim, decoder_layer, num_decoder_layers, eval_idx)
+ decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp)
+ self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx)
# denoising part
- self.denoising_class_embed = nn.Embedding(nc, hidden_dim)
- self.num_denoising = num_denoising
+ self.denoising_class_embed = nn.Embedding(nc, hd)
+ self.num_denoising = nd
self.label_noise_ratio = label_noise_ratio
self.box_noise_scale = box_noise_scale
# decoder embedding
self.learnt_init_query = learnt_init_query
if learnt_init_query:
- self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
- self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
+ self.tgt_embed = nn.Embedding(nq, hd)
+ self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2)
# encoder head
- self.enc_output = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim))
- self.enc_score_head = nn.Linear(hidden_dim, nc)
- self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
+ self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd))
+ self.enc_score_head = nn.Linear(hd, nc)
+ self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3)
# decoder head
- self.dec_score_head = nn.ModuleList([nn.Linear(hidden_dim, nc) for _ in range(num_decoder_layers)])
- self.dec_bbox_head = nn.ModuleList([
- MLP(hidden_dim, hidden_dim, 4, num_layers=3) for _ in range(num_decoder_layers)])
+ self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)])
+ self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)])
self._reset_parameters()
- def forward(self, feats, gt_meta=None):
+ def forward(self, x, batch=None):
+ from ultralytics.vit.utils.ops import get_cdn_group
+
# input projection and embedding
- memory, spatial_shapes, _ = self._get_encoder_input(feats)
+ feats, shapes = self._get_encoder_input(x)
# prepare denoising training
- if self.training:
- raise NotImplementedError
- # denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
- # get_contrastive_denoising_training_group(gt_meta,
- # self.num_classes,
- # self.num_queries,
- # self.denoising_class_embed.weight,
- # self.num_denoising,
- # self.label_noise_ratio,
- # self.box_noise_scale)
- else:
- denoising_class, denoising_bbox_unact, attn_mask = None, None, None
-
- target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
- self._get_decoder_input(memory, spatial_shapes, denoising_class, denoising_bbox_unact)
+ dn_embed, dn_bbox, attn_mask, dn_meta = \
+ get_cdn_group(batch,
+ self.nc,
+ self.num_queries,
+ self.denoising_class_embed.weight,
+ self.num_denoising,
+ self.label_noise_ratio,
+ self.box_noise_scale,
+ self.training)
+
+ embed, refer_bbox, enc_bboxes, enc_scores = \
+ self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)
# decoder
- out_bboxes, out_logits = self.decoder(target,
- init_ref_points_unact,
- memory,
- spatial_shapes,
+ dec_bboxes, dec_scores = self.decoder(embed,
+ refer_bbox,
+ feats,
+ shapes,
self.dec_bbox_head,
self.dec_score_head,
self.query_pos_head,
attn_mask=attn_mask)
if not self.training:
- out_logits = out_logits.sigmoid_()
- return out_bboxes, out_logits # enc_topk_bboxes, enc_topk_logits, dn_meta
-
- def _reset_parameters(self):
- # class and bbox head init
- bias_cls = bias_init_with_prob(0.01)
- linear_init_(self.enc_score_head)
- constant_(self.enc_score_head.bias, bias_cls)
- constant_(self.enc_bbox_head.layers[-1].weight, 0.)
- constant_(self.enc_bbox_head.layers[-1].bias, 0.)
- for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
- linear_init_(cls_)
- constant_(cls_.bias, bias_cls)
- constant_(reg_.layers[-1].weight, 0.)
- constant_(reg_.layers[-1].bias, 0.)
+ dec_scores = dec_scores.sigmoid_()
+ return dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
- linear_init_(self.enc_output[0])
- xavier_uniform_(self.enc_output[0].weight)
- if self.learnt_init_query:
- xavier_uniform_(self.tgt_embed.weight)
- xavier_uniform_(self.query_pos_head.layers[0].weight)
- xavier_uniform_(self.query_pos_head.layers[1].weight)
- for layer in self.input_proj:
- xavier_uniform_(layer[0].weight)
-
- def _build_input_proj_layer(self, ch):
- self.input_proj = nn.ModuleList()
- for in_channels in ch:
- self.input_proj.append(
- nn.Sequential(nn.Conv2d(in_channels, self.hidden_dim, kernel_size=1, bias=False),
- nn.BatchNorm2d(self.hidden_dim)))
- in_channels = ch[-1]
- for _ in range(self.nl - len(ch)):
- self.input_proj.append(
- nn.Sequential(nn.Conv2D(in_channels, self.hidden_dim, kernel_size=3, stride=2, padding=1, bias=False),
- nn.BatchNorm2d(self.hidden_dim)))
- in_channels = self.hidden_dim
-
- def _generate_anchors(self, spatial_shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
+ def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
anchors = []
- for lvl, (h, w) in enumerate(spatial_shapes):
- grid_y, grid_x = torch.meshgrid(torch.arange(end=h, dtype=torch.float32),
- torch.arange(end=w, dtype=torch.float32),
+ for i, (h, w) in enumerate(shapes):
+ grid_y, grid_x = torch.meshgrid(torch.arange(end=h, dtype=dtype, device=device),
+ torch.arange(end=w, dtype=dtype, device=device),
indexing='ij')
- grid_xy = torch.stack([grid_x, grid_y], -1)
+ grid_xy = torch.stack([grid_x, grid_y], -1) # (h, w, 2)
- valid_WH = torch.tensor([h, w]).to(torch.float32)
- grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
- wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl)
- anchors.append(torch.concat([grid_xy, wh], -1).reshape([-1, h * w, 4]))
+ valid_WH = torch.tensor([h, w], dtype=dtype, device=device)
+ grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH # (1, h, w, 2)
+ wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0 ** i)
+ anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4)) # (1, h*w, 4)
- anchors = torch.concat(anchors, 1)
- valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
+ anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4)
+        valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)  # (1, h*w*nl, 1)
anchors = torch.log(anchors / (1 - anchors))
anchors = torch.where(valid_mask, anchors, torch.inf)
- return anchors.to(device=device, dtype=dtype), valid_mask.to(device=device)
+ return anchors, valid_mask
- def _get_encoder_input(self, feats):
+ def _get_encoder_input(self, x):
# get projection features
- proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
- if self.nl > len(proj_feats):
- len_srcs = len(proj_feats)
- for i in range(len_srcs, self.nl):
- if i == len_srcs:
- proj_feats.append(self.input_proj[i](feats[-1]))
- else:
- proj_feats.append(self.input_proj[i](proj_feats[-1]))
-
+ x = [self.input_proj[i](feat) for i, feat in enumerate(x)]
# get encoder inputs
- feat_flatten = []
- spatial_shapes = []
- level_start_index = [0]
- for feat in proj_feats:
- _, _, h, w = feat.shape
+ feats = []
+ shapes = []
+ for feat in x:
+ h, w = feat.shape[2:]
# [b, c, h, w] -> [b, h*w, c]
- feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
+ feats.append(feat.flatten(2).permute(0, 2, 1))
# [nl, 2]
- spatial_shapes.append([h, w])
- # [l], start index of each level
- level_start_index.append(h * w + level_start_index[-1])
+ shapes.append([h, w])
- # [b, l, c]
- feat_flatten = torch.concat(feat_flatten, 1)
- level_start_index.pop()
- return feat_flatten, spatial_shapes, level_start_index
+ # [b, h*w, c]
+ feats = torch.cat(feats, 1)
+ return feats, shapes
- def _get_decoder_input(self, memory, spatial_shapes, denoising_class=None, denoising_bbox_unact=None):
- bs, _, _ = memory.shape
+ def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
+ bs = len(feats)
# prepare input for decoder
- anchors, valid_mask = self._generate_anchors(spatial_shapes, dtype=memory.dtype, device=memory.device)
- memory = torch.where(valid_mask, memory, 0)
- output_memory = self.enc_output(memory)
+ anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
+ features = self.enc_output(torch.where(valid_mask, feats, 0)) # bs, h*w, 256
- enc_outputs_class = self.enc_score_head(output_memory) # (bs, h*w, nc)
- enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors # (bs, h*w, 4)
+ enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc)
+ # dynamic anchors + static content
+ enc_outputs_bboxes = self.enc_bbox_head(features) + anchors # (bs, h*w, 4)
- # (bs, topk)
- _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1)
- # extract region proposal boxes
- # (bs, topk_ind)
+ # query selection
+ # (bs, num_queries)
+ topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
+ # (bs, num_queries)
batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)
- topk_ind = topk_ind.view(-1)
# Unsigmoided
- reference_points_unact = enc_outputs_coord_unact[batch_ind, topk_ind].view(bs, self.num_queries, -1)
+ refer_bbox = enc_outputs_bboxes[batch_ind, topk_ind].view(bs, self.num_queries, -1)
+ # refer_bbox = torch.gather(enc_outputs_bboxes, 1, topk_ind.reshape(bs, self.num_queries).unsqueeze(-1).repeat(1, 1, 4))
- enc_topk_bboxes = torch.sigmoid(reference_points_unact)
- if denoising_bbox_unact is not None:
- reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1)
+ enc_bboxes = refer_bbox.sigmoid()
+ if dn_bbox is not None:
+ refer_bbox = torch.cat([dn_bbox, refer_bbox], 1)
if self.training:
- reference_points_unact = reference_points_unact.detach()
- enc_topk_logits = enc_outputs_class[batch_ind, topk_ind].view(bs, self.num_queries, -1)
+ refer_bbox = refer_bbox.detach()
+ enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)
- # extract region features
if self.learnt_init_query:
- target = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
+ embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
else:
- target = output_memory[batch_ind, topk_ind].view(bs, self.num_queries, -1)
+ embeddings = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
if self.training:
- target = target.detach()
- if denoising_class is not None:
- target = torch.concat([denoising_class, target], 1)
+ embeddings = embeddings.detach()
+ if dn_embed is not None:
+ embeddings = torch.cat([dn_embed, embeddings], 1)
+
+ return embeddings, refer_bbox, enc_bboxes, enc_scores
- return target, reference_points_unact, enc_topk_bboxes, enc_topk_logits
+ # TODO
+ def _reset_parameters(self):
+ # class and bbox head init
+ bias_cls = bias_init_with_prob(0.01) / 80 * self.nc
+ # NOTE: the weight initialization in `linear_init_` would cause NaN when training with custom datasets.
+ # linear_init_(self.enc_score_head)
+ constant_(self.enc_score_head.bias, bias_cls)
+ constant_(self.enc_bbox_head.layers[-1].weight, 0.)
+ constant_(self.enc_bbox_head.layers[-1].bias, 0.)
+ for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
+ # linear_init_(cls_)
+ constant_(cls_.bias, bias_cls)
+ constant_(reg_.layers[-1].weight, 0.)
+ constant_(reg_.layers[-1].bias, 0.)
+
+ linear_init_(self.enc_output[0])
+ xavier_uniform_(self.enc_output[0].weight)
+ if self.learnt_init_query:
+ xavier_uniform_(self.tgt_embed.weight)
+ xavier_uniform_(self.query_pos_head.layers[0].weight)
+ xavier_uniform_(self.query_pos_head.layers[1].weight)
+ for layer in self.input_proj:
+ xavier_uniform_(layer[0].weight)
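Note: the rewritten `_get_decoder_input` selects the top `num_queries` encoder proposals by pairing a flattened top-k index with a repeated batch index; the commented-out `torch.gather` line above is an equivalent formulation. A minimal sketch with toy shapes (all names here are illustrative, not part of the diff):

```python
import torch

bs, n, k = 2, 5, 3                # 2 images, 5 encoder proposals each, keep top 3
scores = torch.rand(bs, n)        # stand-in for enc_outputs_scores.max(-1).values
bboxes = torch.rand(bs, n, 4)     # stand-in for enc_outputs_bboxes

# Flattened pairwise indexing, as in RTDETRDecoder._get_decoder_input
topk_ind = torch.topk(scores, k, dim=1).indices.view(-1)          # (bs*k,)
batch_ind = torch.arange(bs).unsqueeze(-1).repeat(1, k).view(-1)  # (bs*k,) -> [0,0,0,1,1,1]
selected = bboxes[batch_ind, topk_ind].view(bs, k, -1)            # (bs, k, 4)

# Equivalent gather formulation (the commented-out alternative in the diff)
gathered = torch.gather(bboxes, 1, topk_ind.view(bs, k).unsqueeze(-1).repeat(1, 1, 4))
assert torch.equal(selected, gathered)
```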
diff --git a/ultralytics/nn/modules/transformer.py b/ultralytics/nn/modules/transformer.py
index 4745905..89a7d91 100644
--- a/ultralytics/nn/modules/transformer.py
+++ b/ultralytics/nn/modules/transformer.py
@@ -229,23 +229,23 @@ class MSDeformAttn(nn.Module):
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
- def forward(self, query, reference_points, value, value_spatial_shapes, value_mask=None):
+ def forward(self, query, refer_bbox, value, value_shapes, value_mask=None):
"""
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
Args:
- query (Tensor): [bs, query_length, C]
- reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
+ query (torch.Tensor): [bs, query_length, C]
+ refer_bbox (torch.Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area
- value (Tensor): [bs, value_length, C]
- value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+ value (torch.Tensor): [bs, value_length, C]
+ value_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
Returns:
output (Tensor): [bs, Length_{query}, C]
"""
bs, len_q = query.shape[:2]
- _, len_v = value.shape[:2]
- assert sum(s[0] * s[1] for s in value_spatial_shapes) == len_v
+ len_v = value.shape[1]
+ assert sum(s[0] * s[1] for s in value_shapes) == len_v
value = self.value_proj(value)
if value_mask is not None:
@@ -255,18 +255,17 @@ class MSDeformAttn(nn.Module):
attention_weights = self.attention_weights(query).view(bs, len_q, self.n_heads, self.n_levels * self.n_points)
attention_weights = F.softmax(attention_weights, -1).view(bs, len_q, self.n_heads, self.n_levels, self.n_points)
# N, Len_q, n_heads, n_levels, n_points, 2
- n = reference_points.shape[-1]
- if n == 2:
- offset_normalizer = torch.as_tensor(value_spatial_shapes, dtype=query.dtype, device=query.device).flip(-1)
+ num_points = refer_bbox.shape[-1]
+ if num_points == 2:
+ offset_normalizer = torch.as_tensor(value_shapes, dtype=query.dtype, device=query.device).flip(-1)
add = sampling_offsets / offset_normalizer[None, None, None, :, None, :]
- sampling_locations = reference_points[:, :, None, :, None, :] + add
-
- elif n == 4:
- add = sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
- sampling_locations = reference_points[:, :, None, :, None, :2] + add
+ sampling_locations = refer_bbox[:, :, None, :, None, :] + add
+ elif num_points == 4:
+ add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5
+ sampling_locations = refer_bbox[:, :, None, :, None, :2] + add
else:
- raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {n}.')
- output = multi_scale_deformable_attn_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights)
+ raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {num_points}.')
+ output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
output = self.output_proj(output)
return output
@@ -308,33 +307,24 @@ class DeformableTransformerDecoderLayer(nn.Module):
tgt = self.norm3(tgt)
return tgt
- def forward(self,
- tgt,
- reference_points,
- src,
- src_spatial_shapes,
- src_padding_mask=None,
- attn_mask=None,
- query_pos=None):
+ def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None):
# self attention
- q = k = self.with_pos_embed(tgt, query_pos)
- if attn_mask is not None:
- attn_mask = torch.where(attn_mask.astype('bool'), torch.zeros(attn_mask.shape, tgt.dtype),
- torch.full(attn_mask.shape, float('-inf'), tgt.dtype))
- tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1)
- tgt = tgt + self.dropout1(tgt2)
- tgt = self.norm1(tgt)
+ q = k = self.with_pos_embed(embed, query_pos)
+ tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1),
+ attn_mask=attn_mask)[0].transpose(0, 1)
+ embed = embed + self.dropout1(tgt)
+ embed = self.norm1(embed)
# cross attention
- tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos), reference_points, src, src_spatial_shapes,
- src_padding_mask)
- tgt = tgt + self.dropout2(tgt2)
- tgt = self.norm2(tgt)
+ tgt = self.cross_attn(self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes,
+ padding_mask)
+ embed = embed + self.dropout2(tgt)
+ embed = self.norm2(embed)
# ffn
- tgt = self.forward_ffn(tgt)
+ embed = self.forward_ffn(embed)
- return tgt
+ return embed
class DeformableTransformerDecoder(nn.Module):
@@ -349,41 +339,40 @@ class DeformableTransformerDecoder(nn.Module):
self.hidden_dim = hidden_dim
self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
- def forward(self,
- tgt,
- reference_points,
- src,
- src_spatial_shapes,
- bbox_head,
- score_head,
- query_pos_head,
- attn_mask=None,
- src_padding_mask=None):
- output = tgt
- dec_out_bboxes = []
- dec_out_logits = []
- ref_points = None
- ref_points_detach = torch.sigmoid(reference_points)
+ def forward(
+ self,
+ embed, # decoder embeddings
+ refer_bbox, # anchor
+ feats, # image features
+ shapes, # feature shapes
+ bbox_head,
+ score_head,
+ pos_mlp,
+ attn_mask=None,
+ padding_mask=None):
+ output = embed
+ dec_bboxes = []
+ dec_cls = []
+ last_refined_bbox = None
+ refer_bbox = refer_bbox.sigmoid()
for i, layer in enumerate(self.layers):
- ref_points_input = ref_points_detach.unsqueeze(2)
- query_pos_embed = query_pos_head(ref_points_detach)
- output = layer(output, ref_points_input, src, src_spatial_shapes, src_padding_mask, attn_mask,
- query_pos_embed)
+ output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox))
- inter_ref_bbox = torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach))
+ # refine bboxes, (bs, num_queries+num_denoising, 4)
+ refined_bbox = torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(refer_bbox))
if self.training:
- dec_out_logits.append(score_head[i](output))
+ dec_cls.append(score_head[i](output))
if i == 0:
- dec_out_bboxes.append(inter_ref_bbox)
+ dec_bboxes.append(refined_bbox)
else:
- dec_out_bboxes.append(torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points)))
+ dec_bboxes.append(torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(last_refined_bbox)))
elif i == self.eval_idx:
- dec_out_logits.append(score_head[i](output))
- dec_out_bboxes.append(inter_ref_bbox)
+ dec_cls.append(score_head[i](output))
+ dec_bboxes.append(refined_bbox)
break
- ref_points = inter_ref_bbox
- ref_points_detach = inter_ref_bbox.detach() if self.training else inter_ref_bbox
+ last_refined_bbox = refined_bbox
+ refer_bbox = refined_bbox.detach() if self.training else refined_bbox
- return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
+ return torch.stack(dec_bboxes), torch.stack(dec_cls)
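Note: each decoder layer refines boxes in logit space via `sigmoid(bbox_head(output) + inverse_sigmoid(refer_bbox))`, so the refined coordinates always stay in (0, 1); during training the running reference is detached so each layer's bbox head receives its own gradient. A minimal sketch of the refinement step (the local `inverse_sigmoid` is a simplified stand-in for the library helper):

```python
import torch

def inverse_sigmoid(x, eps=1e-6):
    # Simplified stand-in for the library helper used in the decoder above
    x = x.clamp(min=eps, max=1 - eps)
    return torch.log(x / (1 - x))

refer_bbox = torch.tensor([0.30, 0.40, 0.20, 0.25])  # normalized cx, cy, w, h
delta = torch.tensor([0.50, -0.20, 0.00, 0.10])      # stand-in for bbox_head[i](output)

refined_bbox = torch.sigmoid(delta + inverse_sigmoid(refer_bbox))
print(refined_bbox)  # still inside (0, 1) regardless of the magnitude of delta
```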
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 7af3f34..6e75553 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -210,7 +210,9 @@ class BaseModel(nn.Module):
"""
if not hasattr(self, 'criterion'):
self.criterion = self.init_criterion()
- return self.criterion(self.predict(batch['img']) if preds is None else preds, batch)
+
+ preds = self.forward(batch['img']) if preds is None else preds
+ return self.criterion(preds, batch)
def init_criterion(self):
raise NotImplementedError('compute_loss() needs to be implemented by task heads')
@@ -410,7 +412,7 @@ class RTDETRDetectionModel(DetectionModel):
"""Compute the classification loss between predictions and true labels."""
from ultralytics.vit.utils.loss import RTDETRDetectionLoss
- return RTDETRDetectionLoss(num_classes=self.nc, use_vfl=True)
+ return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)
def loss(self, batch, preds=None):
if not hasattr(self, 'criterion'):
@@ -420,31 +422,36 @@ class RTDETRDetectionModel(DetectionModel):
# NOTE: preprocess gt_bbox and gt_labels to list.
bs = len(img)
batch_idx = batch['batch_idx']
- gt_bbox, gt_class = [], []
+ gt_groups = []
for i in range(bs):
- gt_bbox.append(batch['bboxes'][batch_idx == i].to(img.device))
- gt_class.append(batch['cls'][batch_idx == i].to(device=img.device, dtype=torch.long))
- targets = {'cls': gt_class, 'bboxes': gt_bbox}
+ gt_groups.append((batch_idx == i).sum().item())
+ targets = {
+ 'cls': batch['cls'].to(img.device, dtype=torch.long).view(-1),
+ 'bboxes': batch['bboxes'].to(device=img.device),
+ 'batch_idx': batch_idx.to(img.device, dtype=torch.long).view(-1),
+ 'gt_groups': gt_groups}
preds = self.predict(img, batch=targets) if preds is None else preds
- dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta = preds
- # NOTE: `dn_meta` means it's eval mode, loss calculation for eval mode is not supported.
+ dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds
if dn_meta is None:
- return 0, torch.zeros(3, device=dec_out_bboxes.device)
- dn_out_bboxes, dec_out_bboxes = torch.split(dec_out_bboxes, dn_meta['dn_num_split'], dim=2)
- dn_out_logits, dec_out_logits = torch.split(dec_out_logits, dn_meta['dn_num_split'], dim=2)
+ dn_bboxes, dn_scores = None, None
+ else:
+ dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta['dn_num_split'], dim=2)
+ dn_scores, dec_scores = torch.split(dec_scores, dn_meta['dn_num_split'], dim=2)
- out_bboxes = torch.cat([enc_topk_bboxes.unsqueeze(0), dec_out_bboxes])
- out_logits = torch.cat([enc_topk_logits.unsqueeze(0), dec_out_logits])
+ dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4)
+ dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])
- loss = self.criterion((out_bboxes, out_logits),
+ loss = self.criterion((dec_bboxes, dec_scores),
targets,
- dn_out_bboxes=dn_out_bboxes,
- dn_out_logits=dn_out_logits,
+ dn_bboxes=dn_bboxes,
+ dn_scores=dn_scores,
dn_meta=dn_meta)
- return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']])
+        # NOTE: RT-DETR computes about 12 losses in total; all of them are used for backward, but only the main three are reported.
+ return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
+ device=img.device)
- def predict(self, x, profile=False, visualize=False, batch=None):
+ def predict(self, x, profile=False, visualize=False, batch=None, augment=False):
"""
Perform a forward pass through the network.
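Note: `get_cdn_group` prepends the denoising queries, so the decoder output holds `num_denoising + num_queries` entries along dim 2, and `dn_meta['dn_num_split']` records where to cut. A toy illustration of the split performed in `loss` (shapes are illustrative):

```python
import torch

num_layers, bs, num_dn, num_q = 6, 2, 20, 30
dec_bboxes = torch.rand(num_layers, bs, num_dn + num_q, 4)  # denoising queries come first

dn_meta = {'dn_num_split': [num_dn, num_q]}  # as produced by get_cdn_group
dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta['dn_num_split'], dim=2)
print(dn_bboxes.shape, dec_bboxes.shape)  # (6, 2, 20, 4) and (6, 2, 30, 4)
```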
diff --git a/ultralytics/vit/__init__.py b/ultralytics/vit/__init__.py
index e142705..8e96f91 100644
--- a/ultralytics/vit/__init__.py
+++ b/ultralytics/vit/__init__.py
@@ -3,4 +3,4 @@
from .rtdetr import RTDETR
from .sam import SAM
-__all__ = 'RTDETR', 'SAM', 'SAM' # allow simpler import
+__all__ = 'RTDETR', 'SAM' # allow simpler import
diff --git a/ultralytics/vit/rtdetr/model.py b/ultralytics/vit/rtdetr/model.py
index 1297e15..322912c 100644
--- a/ultralytics/vit/rtdetr/model.py
+++ b/ultralytics/vit/rtdetr/model.py
@@ -5,15 +5,15 @@
from pathlib import Path
-from ultralytics.nn.tasks import DetectionModel, attempt_load_one_weight, yaml_model_load
+from ultralytics.nn.tasks import RTDETRDetectionModel, attempt_load_one_weight, yaml_model_load
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
-from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir
+from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, ROOT, is_git_dir
from ultralytics.yolo.utils.checks import check_imgsz
-from ultralytics.yolo.utils.torch_utils import model_info
+from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode
-from ...yolo.utils.torch_utils import smart_inference_mode
from .predict import RTDETRPredictor
+from .train import RTDETRTrainer
from .val import RTDETRValidator
@@ -24,6 +24,7 @@ class RTDETR:
raise NotImplementedError('RT-DETR only supports creating from pt file or yaml file.')
# Load or create new YOLO model
self.predictor = None
+ self.ckpt = None
suffix = Path(model).suffix
if suffix == '.yaml':
self._new(model)
@@ -34,7 +35,7 @@ class RTDETR:
cfg_dict = yaml_model_load(cfg)
self.cfg = cfg
self.task = 'detect'
- self.model = DetectionModel(cfg_dict, verbose=verbose) # build model
+ self.model = RTDETRDetectionModel(cfg_dict, verbose=verbose) # build model
# Below added to allow export from yamls
self.model.args = DEFAULT_CFG_DICT # attach args to model
@@ -42,10 +43,20 @@ class RTDETR:
@smart_inference_mode()
def _load(self, weights: str):
- self.model, _ = attempt_load_one_weight(weights)
+ self.model, self.ckpt = attempt_load_one_weight(weights)
self.model.args = DEFAULT_CFG_DICT # attach args to model
self.task = self.model.args['task']
+ @smart_inference_mode()
+ def load(self, weights='yolov8n.pt'):
+ """
+ Transfers parameters with matching names and shapes from 'weights' to model.
+ """
+ if isinstance(weights, (str, Path)):
+ weights, self.ckpt = attempt_load_one_weight(weights)
+ self.model.load(weights)
+ return self
+
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
@@ -74,8 +85,30 @@ class RTDETR:
return self.predictor(source, stream=stream)
def train(self, **kwargs):
- """Function trains models but raises an error as RTDETR models do not support training."""
- raise NotImplementedError("RTDETR models don't support training")
+ """
+ Trains the model on a given dataset.
+
+ Args:
+ **kwargs (Any): Any number of arguments representing the training configuration.
+ """
+ overrides = dict(task='detect', mode='train')
+ overrides.update(kwargs)
+ overrides['deterministic'] = False
+ if not overrides.get('data'):
+ raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
+ if overrides.get('resume'):
+ overrides['resume'] = self.ckpt_path
+ self.task = overrides.get('task') or self.task
+ self.trainer = RTDETRTrainer(overrides=overrides)
+ if not overrides.get('resume'): # manually set model only if not resuming
+ self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
+ self.model = self.trainer.model
+ self.trainer.train()
+ # Update model and cfg after training
+ if RANK in (-1, 0):
+ self.model, _ = attempt_load_one_weight(str(self.trainer.best))
+ self.overrides = self.model.args
+ self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
def val(self, **kwargs):
"""Run validation given dataset."""
diff --git a/ultralytics/vit/rtdetr/train.py b/ultralytics/vit/rtdetr/train.py
new file mode 100644
index 0000000..5a29589
--- /dev/null
+++ b/ultralytics/vit/rtdetr/train.py
@@ -0,0 +1,78 @@
+from copy import copy
+
+import torch
+
+from ultralytics.nn.tasks import RTDETRDetectionModel
+from ultralytics.yolo.utils import DEFAULT_CFG, RANK, colorstr
+from ultralytics.yolo.v8.detect import DetectionTrainer
+
+from .val import RTDETRDataset, RTDETRValidator
+
+
+class RTDETRTrainer(DetectionTrainer):
+
+ def get_model(self, cfg=None, weights=None, verbose=True):
+ """Return a YOLO detection model."""
+ model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
+ if weights:
+ model.load(weights)
+ return model
+
+ def build_dataset(self, img_path, mode='val', batch=None):
+ """Build RTDETR Dataset
+
+ Args:
+ img_path (str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode; users can customize different augmentations for each mode.
+ batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+ """
+ return RTDETRDataset(
+ img_path=img_path,
+ imgsz=self.args.imgsz,
+ batch_size=batch,
+            augment=mode == 'train',  # augment only in train mode
+ hyp=self.args,
+ rect=False, # no rect
+ cache=self.args.cache or None,
+ prefix=colorstr(f'{mode}: '),
+ data=self.data)
+
+ def get_validator(self):
+ """Returns a DetectionValidator for RTDETR model validation."""
+ self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
+ return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
+
+ def preprocess_batch(self, batch):
+ """Preprocesses a batch of images by scaling and converting to float."""
+ batch = super().preprocess_batch(batch)
+ bs = len(batch['img'])
+ batch_idx = batch['batch_idx']
+ gt_bbox, gt_class = [], []
+ for i in range(bs):
+ gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
+ gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
+ return batch
+
+
+def train(cfg=DEFAULT_CFG, use_python=False):
+ """Train and optimize RTDETR model given training data and device."""
+ model = 'rtdetr-l.yaml'
+    data = cfg.data or 'coco128.yaml'
+ device = cfg.device if cfg.device is not None else ''
+
+    # NOTE: F.grid_sample, which RT-DETR uses, does not support deterministic=True
+    # NOTE: AMP training causes NaN outputs and errors during bipartite graph matching
+ args = dict(model=model,
+ data=data,
+ device=device,
+ imgsz=640,
+ exist_ok=True,
+ batch=4,
+ deterministic=False,
+ amp=False)
+ trainer = RTDETRTrainer(overrides=args)
+ trainer.train()
+
+
+if __name__ == '__main__':
+ train()
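Note: the `train()` entry point above can also be reproduced programmatically with custom overrides; a minimal sketch (dataset and batch size are illustrative, and `deterministic=False`/`amp=False` follow the NOTEs above):

```python
from ultralytics.vit.rtdetr.train import RTDETRTrainer

args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=1,
            batch=4, deterministic=False, amp=False)
trainer = RTDETRTrainer(overrides=args)
trainer.train()
```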
diff --git a/ultralytics/vit/rtdetr/val.py b/ultralytics/vit/rtdetr/val.py
index e1619a6..57376a6 100644
--- a/ultralytics/vit/rtdetr/val.py
+++ b/ultralytics/vit/rtdetr/val.py
@@ -2,10 +2,12 @@
from pathlib import Path
+import cv2
+import numpy as np
import torch
from ultralytics.yolo.data import YOLODataset
-from ultralytics.yolo.data.augment import Compose, Format, LetterBox
+from ultralytics.yolo.data.augment import Compose, Format, v8_transforms
from ultralytics.yolo.utils import colorstr, ops
from ultralytics.yolo.v8.detect import DetectionValidator
@@ -18,9 +20,41 @@ class RTDETRDataset(YOLODataset):
def __init__(self, *args, data=None, **kwargs):
super().__init__(*args, data=data, use_segments=False, use_keypoints=False, **kwargs)
+    # NOTE: stretch-resize version of load_image for RT-DETR mosaic
+ def load_image(self, i):
+        """Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)."""
+ im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
+ if im is None: # not cached in RAM
+ if fn.exists(): # load npy
+ im = np.load(fn)
+ else: # read image
+ im = cv2.imread(f) # BGR
+ if im is None:
+ raise FileNotFoundError(f'Image Not Found {f}')
+ h0, w0 = im.shape[:2] # orig hw
+ im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR)
+
+ # Add to buffer if training with augmentations
+ if self.augment:
+ self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
+ self.buffer.append(i)
+ if len(self.buffer) >= self.max_buffer_length:
+ j = self.buffer.pop(0)
+ self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None
+
+ return im, (h0, w0), im.shape[:2]
+
+ return self.ims[i], self.im_hw0[i], self.im_hw[i]
+
def build_transforms(self, hyp=None):
"""Temporarily, only for evaluation."""
- transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)])
+ if self.augment:
+ hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
+ hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
+ transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
+ else:
+ # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)])
+ transforms = Compose([])
transforms.append(
Format(bbox_format='xywh',
normalize=True,
@@ -65,6 +99,8 @@ class RTDETRValidator(DetectionValidator):
# Do not need threshold for evaluation as only got 300 boxes here.
# idx = score > self.args.conf
pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1) # filter
+            # Sort by confidence in descending order so internal metrics are computed correctly
+ pred = pred[score.argsort(descending=True)]
outputs[i] = pred # [idx]
return outputs
@@ -100,7 +136,8 @@ class RTDETRValidator(DetectionValidator):
tbox[..., [0, 2]] *= shape[1] # native-space pred
tbox[..., [1, 3]] *= shape[0] # native-space pred
labelsn = torch.cat((cls, tbox), 1) # native-space labels
- correct_bboxes = self._process_batch(predn, labelsn)
+ # NOTE: To get correct metrics, the inputs of `_process_batch` should always be float32 type.
+ correct_bboxes = self._process_batch(predn.float(), labelsn)
# TODO: maybe remove these `self.` arguments as they already are member variable
if self.args.plots:
self.confusion_matrix.process_batch(predn, labelsn)
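Note: the new sort in `postprocess` matters because AP computation assumes detections arrive in descending-confidence order, and RT-DETR's fixed 300 outputs are not pre-sorted. A toy example of the idiom (values are made up):

```python
import torch

# Toy predictions: rows of (x, y, w, h, score, cls)
pred = torch.tensor([[10., 10., 5., 5., 0.2, 0.],
                     [20., 20., 5., 5., 0.9, 1.],
                     [30., 30., 5., 5., 0.5, 2.]])
score = pred[:, 4]

pred = pred[score.argsort(descending=True)]  # as in RTDETRValidator.postprocess
print(pred[:, 4])  # tensor([0.9000, 0.5000, 0.2000])
```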
diff --git a/ultralytics/vit/sam/amg.py b/ultralytics/vit/sam/amg.py
index 3c70f7c..1522931 100644
--- a/ultralytics/vit/sam/amg.py
+++ b/ultralytics/vit/sam/amg.py
@@ -256,10 +256,8 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
- fill_labels = [i for i in range(n_labels) if i not in fill_labels]
# If every region is below threshold, keep largest
- if not fill_labels:
- fill_labels = [int(np.argmax(sizes)) + 1]
+ fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True
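Note: the refactor collapses the empty-list fallback into a single expression; an empty list is falsy, so `or` supplies the largest-region fallback. A toy illustration (values are made up):

```python
import numpy as np

n_labels = 4                # background (0) plus regions 1..3
fill_labels = [0, 1, 2, 3]  # every region fell below the area threshold
sizes = [15, 40, 22]        # areas of regions 1..3 (background excluded)

keep = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
print(keep)  # [2] -> the largest region is kept as the fallback
```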
diff --git a/ultralytics/vit/sam/modules/sam.py b/ultralytics/vit/sam/modules/sam.py
index 50a30ee..34963f1 100644
--- a/ultralytics/vit/sam/modules/sam.py
+++ b/ultralytics/vit/sam/modules/sam.py
@@ -18,14 +18,12 @@ class Sam(nn.Module):
mask_threshold: float = 0.0
image_format: str = 'RGB'
- def __init__(
- self,
- image_encoder: ImageEncoderViT,
- prompt_encoder: PromptEncoder,
- mask_decoder: MaskDecoder,
- pixel_mean: List[float] = [123.675, 116.28, 103.53],
- pixel_std: List[float] = [58.395, 57.12, 57.375],
- ) -> None:
+ def __init__(self,
+ image_encoder: ImageEncoderViT,
+ prompt_encoder: PromptEncoder,
+ mask_decoder: MaskDecoder,
+ pixel_mean: List[float] = None,
+ pixel_std: List[float] = None) -> None:
"""
SAM predicts object masks from an image and input prompts.
@@ -38,6 +36,10 @@ class Sam(nn.Module):
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
pixel_std (list(float)): Std values for normalizing pixels in the input image.
"""
+ if pixel_mean is None:
+ pixel_mean = [123.675, 116.28, 103.53]
+ if pixel_std is None:
+ pixel_std = [58.395, 57.12, 57.375]
super().__init__()
self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder
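Note: the `Sam.__init__` change replaces mutable list defaults with `None` plus in-body initialization, avoiding Python's shared-default pitfall. A minimal demonstration (function names are illustrative):

```python
def append_bad(item, bucket=[]):  # the default list is created once and shared across calls
    bucket.append(item)
    return bucket

def append_good(item, bucket=None):  # the pattern now used in Sam.__init__
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket

print(append_bad(1), append_bad(2))    # [1, 2] [1, 2] -- state leaked between calls
print(append_good(1), append_good(2))  # [1] [2]
```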
diff --git a/ultralytics/vit/utils/loss.py b/ultralytics/vit/utils/loss.py
new file mode 100644
index 0000000..1a5ba29
--- /dev/null
+++ b/ultralytics/vit/utils/loss.py
@@ -0,0 +1,291 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ultralytics.vit.utils.ops import HungarianMatcher
+from ultralytics.yolo.utils.loss import FocalLoss, VarifocalLoss
+from ultralytics.yolo.utils.metrics import bbox_iou
+
+
+class DETRLoss(nn.Module):
+
+ def __init__(self,
+ nc=80,
+ loss_gain=None,
+ aux_loss=True,
+ use_fl=True,
+ use_vfl=False,
+ use_uni_match=False,
+ uni_match_ind=0):
+ """
+ Args:
+ nc (int): The number of classes.
+ loss_gain (dict): The coefficient of loss.
+ aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used.
+            use_fl (bool): Use FocalLoss or not.
+ use_vfl (bool): Use VarifocalLoss or not.
+ use_uni_match (bool): Whether to use a fixed layer to assign labels for auxiliary branch.
+ uni_match_ind (int): The fixed indices of a layer.
+ """
+ super().__init__()
+
+ if loss_gain is None:
+ loss_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'no_object': 0.1, 'mask': 1, 'dice': 1}
+ self.nc = nc
+ self.matcher = HungarianMatcher(cost_gain={'class': 2, 'bbox': 5, 'giou': 2})
+ self.loss_gain = loss_gain
+ self.aux_loss = aux_loss
+ self.fl = FocalLoss() if use_fl else None
+ self.vfl = VarifocalLoss() if use_vfl else None
+
+ self.use_uni_match = use_uni_match
+ self.uni_match_ind = uni_match_ind
+ self.device = None
+
+ def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''):
+ # logits: [b, query, num_classes], gt_class: list[[n, 1]]
+ name_class = f'loss_class{postfix}'
+ bs, nq = pred_scores.shape[:2]
+ # one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
+ one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
+ one_hot.scatter_(2, targets.unsqueeze(-1), 1)
+ one_hot = one_hot[..., :-1]
+ gt_scores = gt_scores.view(bs, nq, 1) * one_hot
+
+ if self.fl:
+ if num_gts and self.vfl:
+ loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
+ else:
+ loss_cls = self.fl(pred_scores, one_hot.float())
+ loss_cls /= max(num_gts, 1) / nq
+ else:
+ loss_cls = nn.BCEWithLogitsLoss(reduction='none')(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
+
+ return {name_class: loss_cls.squeeze() * self.loss_gain['class']}
+
+ def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=''):
+ # boxes: [b, query, 4], gt_bbox: list[[n, 4]]
+ name_bbox = f'loss_bbox{postfix}'
+ name_giou = f'loss_giou{postfix}'
+
+ loss = {}
+ if len(gt_bboxes) == 0:
+ loss[name_bbox] = torch.tensor(0., device=self.device)
+ loss[name_giou] = torch.tensor(0., device=self.device)
+ return loss
+
+ loss[name_bbox] = self.loss_gain['bbox'] * F.l1_loss(pred_bboxes, gt_bboxes, reduction='sum') / len(gt_bboxes)
+ loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
+ loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
+ loss[name_giou] = self.loss_gain['giou'] * loss[name_giou]
+ loss = {k: v.squeeze() for k, v in loss.items()}
+ return loss
+
+ def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
+ # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
+ name_mask = f'loss_mask{postfix}'
+ name_dice = f'loss_dice{postfix}'
+
+ loss = {}
+ if sum(len(a) for a in gt_mask) == 0:
+ loss[name_mask] = torch.tensor(0., device=self.device)
+ loss[name_dice] = torch.tensor(0., device=self.device)
+ return loss
+
+ num_gts = len(gt_mask)
+ src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
+ src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
+ # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
+ loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
+ torch.tensor([num_gts], dtype=torch.float32))
+ loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
+ return loss
+
+ def _dice_loss(self, inputs, targets, num_gts):
+ inputs = F.sigmoid(inputs)
+ inputs = inputs.flatten(1)
+ targets = targets.flatten(1)
+ numerator = 2 * (inputs * targets).sum(1)
+ denominator = inputs.sum(-1) + targets.sum(-1)
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ return loss.sum() / num_gts
+
+ def _get_loss_aux(self,
+ pred_bboxes,
+ pred_scores,
+ gt_bboxes,
+ gt_cls,
+ gt_groups,
+ match_indices=None,
+ postfix='',
+ masks=None,
+ gt_mask=None):
+ """Get auxiliary losses"""
+ # NOTE: loss class, bbox, giou, mask, dice
+ loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
+ if match_indices is None and self.use_uni_match:
+ match_indices = self.matcher(pred_bboxes[self.uni_match_ind],
+ pred_scores[self.uni_match_ind],
+ gt_bboxes,
+ gt_cls,
+ gt_groups,
+ masks=masks[self.uni_match_ind] if masks is not None else None,
+ gt_mask=gt_mask)
+ for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
+ aux_masks = masks[i] if masks is not None else None
+ loss_ = self._get_loss(aux_bboxes,
+ aux_scores,
+ gt_bboxes,
+ gt_cls,
+ gt_groups,
+ masks=aux_masks,
+ gt_mask=gt_mask,
+ postfix=postfix,
+ match_indices=match_indices)
+ loss[0] += loss_[f'loss_class{postfix}']
+ loss[1] += loss_[f'loss_bbox{postfix}']
+ loss[2] += loss_[f'loss_giou{postfix}']
+ # if masks is not None and gt_mask is not None:
+ # loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
+ # loss[3] += loss_[f'loss_mask{postfix}']
+ # loss[4] += loss_[f'loss_dice{postfix}']
+
+ loss = {
+ f'loss_class_aux{postfix}': loss[0],
+ f'loss_bbox_aux{postfix}': loss[1],
+ f'loss_giou_aux{postfix}': loss[2]}
+ # if masks is not None and gt_mask is not None:
+ # loss[f'loss_mask_aux{postfix}'] = loss[3]
+ # loss[f'loss_dice_aux{postfix}'] = loss[4]
+ return loss
+
+ def _get_index(self, match_indices):
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
+ src_idx = torch.cat([src for (src, _) in match_indices])
+ dst_idx = torch.cat([dst for (_, dst) in match_indices])
+ return (batch_idx, src_idx), dst_idx
+
+ def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
+ pred_assigned = torch.cat([
+ t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
+ for t, (I, _) in zip(pred_bboxes, match_indices)])
+ gt_assigned = torch.cat([
+ t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
+ for t, (_, J) in zip(gt_bboxes, match_indices)])
+ return pred_assigned, gt_assigned
+
+ def _get_loss(self,
+ pred_bboxes,
+ pred_scores,
+ gt_bboxes,
+ gt_cls,
+ gt_groups,
+ masks=None,
+ gt_mask=None,
+ postfix='',
+ match_indices=None):
+ """Get losses"""
+ if match_indices is None:
+ match_indices = self.matcher(pred_bboxes,
+ pred_scores,
+ gt_bboxes,
+ gt_cls,
+ gt_groups,
+ masks=masks,
+ gt_mask=gt_mask)
+
+ idx, gt_idx = self._get_index(match_indices)
+ pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
+
+ bs, nq = pred_scores.shape[:2]
+ targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
+ targets[idx] = gt_cls[gt_idx]
+
+ gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
+ if len(gt_bboxes):
+ gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)
+
+ loss = {}
+ loss.update(self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix))
+ loss.update(self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix))
+ # if masks is not None and gt_mask is not None:
+ # loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix))
+ return loss
+
+ def forward(self, pred_bboxes, pred_scores, batch, postfix='', **kwargs):
+ """
+ Args:
+ pred_bboxes (torch.Tensor): [l, b, query, 4]
+ pred_scores (torch.Tensor): [l, b, query, num_classes]
+ batch (dict): A dict includes:
+ gt_cls (torch.Tensor) with shape [num_gts, ],
+ gt_bboxes (torch.Tensor): [num_gts, 4],
+ gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
+ postfix (str): postfix of loss name.
+ """
+ self.device = pred_bboxes.device
+ match_indices = kwargs.get('match_indices', None)
+ gt_cls, gt_bboxes, gt_groups = batch['cls'], batch['bboxes'], batch['gt_groups']
+
+ total_loss = self._get_loss(pred_bboxes[-1],
+ pred_scores[-1],
+ gt_bboxes,
+ gt_cls,
+ gt_groups,
+ postfix=postfix,
+ match_indices=match_indices)
+
+ if self.aux_loss:
+ total_loss.update(
+ self._get_loss_aux(pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices,
+ postfix))
+
+ return total_loss
+
+
+class RTDETRDetectionLoss(DETRLoss):
+
+ def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
+ pred_bboxes, pred_scores = preds
+ total_loss = super().forward(pred_bboxes, pred_scores, batch)
+
+ if dn_meta is not None:
+ dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group']
+ assert len(batch['gt_groups']) == len(dn_pos_idx)
+
+ # denoising match indices
+ match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups'])
+
+ # compute denoising training loss
+ dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices)
+ total_loss.update(dn_loss)
+ else:
+ total_loss.update({f'{k}_dn': torch.tensor(0., device=self.device) for k in total_loss.keys()})
+
+ return total_loss
+
+ @staticmethod
+ def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
+ """Get the match indices for denoising.
+
+ Args:
+ dn_pos_idx (List[torch.Tensor]): A list includes positive indices of denoising.
+ dn_num_group (int): The number of groups of denoising.
+ gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
+
+ Returns:
+ dn_match_indices (List(tuple)): Matched indices.
+
+ """
+ dn_match_indices = []
+ idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
+ for i, num_gt in enumerate(gt_groups):
+ if num_gt > 0:
+ gt_idx = torch.arange(end=num_gt, dtype=torch.int32) + idx_groups[i]
+ gt_idx = gt_idx.repeat(dn_num_group)
+                assert len(dn_pos_idx[i]) == len(gt_idx), 'Expected the same length, ' \
+                    f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.'
+ dn_match_indices.append((dn_pos_idx[i], gt_idx))
+ else:
+ dn_match_indices.append((torch.zeros([0], dtype=torch.int32), torch.zeros([0], dtype=torch.int32)))
+ return dn_match_indices
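Note: `get_dn_match_indices` pairs each denoising group's positive queries with the same ground-truth boxes, offsetting GT indices per image via the cumulative `gt_groups`. A small worked example (toy indices):

```python
import torch

from ultralytics.vit.utils.loss import RTDETRDetectionLoss

gt_groups = [2, 1]      # image 0 has 2 GT boxes, image 1 has 1
dn_num_group = 3        # three denoising groups
dn_pos_idx = [torch.arange(6), torch.arange(3)]  # toy positive query indices per image

for pos, gt in RTDETRDetectionLoss.get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
    print(pos.tolist(), gt.tolist())
# image 0: GT indices [0, 1] repeated per group -> [0, 1, 0, 1, 0, 1]
# image 1: global GT index [2], repeated        -> [2, 2, 2]
```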
diff --git a/ultralytics/vit/utils/ops.py b/ultralytics/vit/utils/ops.py
new file mode 100644
index 0000000..5b92963
--- /dev/null
+++ b/ultralytics/vit/utils/ops.py
@@ -0,0 +1,230 @@
+# TODO: license
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from scipy.optimize import linear_sum_assignment
+
+from ultralytics.nn.modules.utils import inverse_sigmoid
+from ultralytics.yolo.utils.metrics import bbox_iou
+from ultralytics.yolo.utils.ops import xywh2xyxy, xyxy2xywh
+
+
+class HungarianMatcher(nn.Module):
+
+ def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
+ """
+ Args:
+            cost_gain (dict): The coefficient of each Hungarian matcher cost component.
+ """
+ super().__init__()
+ if cost_gain is None:
+ cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1}
+ self.cost_gain = cost_gain
+ self.use_fl = use_fl
+ self.with_mask = with_mask
+ self.num_sample_points = num_sample_points
+ self.alpha = alpha
+ self.gamma = gamma
+
+ def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
+ """
+ Args:
+ pred_bboxes (Tensor): [b, query, 4]
+ pred_scores (Tensor): [b, query, num_classes]
+ gt_cls (torch.Tensor) with shape [num_gts, ]
+ gt_bboxes (torch.Tensor): [num_gts, 4]
+ gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
+ masks (Tensor|None): [b, query, h, w]
+ gt_mask (List(Tensor)): list[[n, H, W]]
+
+ Returns:
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
+ - index_i is the indices of the selected predictions (in order)
+ - index_j is the indices of the corresponding selected targets (in order)
+ For each batch element, it holds:
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+ """
+ bs, nq, nc = pred_scores.shape
+
+ if sum(gt_groups) == 0:
+ return [(torch.tensor([], dtype=torch.int32), torch.tensor([], dtype=torch.int32)) for _ in range(bs)]
+
+ # We flatten to compute the cost matrices in a batch
+ # [batch_size * num_queries, num_classes]
+ pred_scores = pred_scores.detach().view(-1, nc)
+ pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
+ # [batch_size * num_queries, 4]
+ pred_bboxes = pred_bboxes.detach().view(-1, 4)
+
+ # Compute the classification cost
+ pred_scores = pred_scores[:, gt_cls]
+ if self.use_fl:
+ neg_cost_class = (1 - self.alpha) * (pred_scores ** self.gamma) * (-(1 - pred_scores + 1e-8).log())
+ pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
+ cost_class = pos_cost_class - neg_cost_class
+ else:
+ cost_class = -pred_scores
+
+ # Compute the L1 cost between boxes
+ cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt)
+
+ # Compute the GIoU cost between boxes, (bs*num_queries, num_gt)
+ cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
+
+ # Final cost matrix
+ C = self.cost_gain['class'] * cost_class + \
+ self.cost_gain['bbox'] * cost_bbox + \
+ self.cost_gain['giou'] * cost_giou
+ # Compute the mask cost and dice cost
+ if self.with_mask:
+ C += self._cost_mask(bs, gt_groups, masks, gt_mask)
+
+ C = C.view(bs, nq, -1).cpu()
+ indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
+ gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
+ # (idx for queries, idx for gt)
+ return [(torch.tensor(i, dtype=torch.int32), torch.tensor(j, dtype=torch.int32) + gt_groups[k])
+ for k, (i, j) in enumerate(indices)]
+
+ def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
+ assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
+ # all masks share the same set of points for efficient matching
+ sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
+ sample_points = 2.0 * sample_points - 1.0
+
+ out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
+ out_mask = out_mask.flatten(0, 1)
+
+ tgt_mask = torch.cat(gt_mask).unsqueeze(1)
+ sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
+ tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
+
+ with torch.cuda.amp.autocast(False):
+ # binary cross entropy cost
+ pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
+ neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
+ cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
+ cost_mask /= self.num_sample_points
+
+ # dice cost
+ out_mask = F.sigmoid(out_mask)
+ numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
+ denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
+ cost_dice = 1 - (numerator + 1) / (denominator + 1)
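+ # Soft Dice distance 1 - (2*|X∩Y| + 1) / (|X| + |Y| + 1); the +1 smoothing
+ # avoids division by zero for empty masks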
+
+ C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
+ return C
+
+
+def get_cdn_group(batch,
+ num_classes,
+ num_queries,
+ class_embed,
+ num_dn=100,
+ cls_noise_ratio=0.5,
+ box_noise_scale=1.0,
+ training=False):
+ """Get contrastive denoising training group
+
+ Args:
+ batch (dict): A dict that includes:
+ 'cls' (torch.Tensor): gt class labels with shape [num_gts, ],
+ 'bboxes' (torch.Tensor): gt bboxes with shape [num_gts, 4],
+ 'batch_idx' (torch.Tensor): image index of each gt with shape [num_gts, ],
+ 'gt_groups' (List[int]): a list of length batch_size containing the number of gts per image.
+ num_classes (int): Number of classes.
+ num_queries (int): Number of queries.
+ class_embed (torch.Tensor): Embedding weights to map cls to embedding space.
+ num_dn (int): Total number of denoising queries.
+ cls_noise_ratio (float): Noise ratio for class labels.
+ box_noise_scale (float): Noise scale for bbox coordinates.
+ training (bool): Whether the model is in training mode.
+
+ Returns:
+ padding_cls (torch.Tensor): Denoising class embeddings with shape [bs, num_dn, embed_dim].
+ padding_bbox (torch.Tensor): Denoising bbox logits with shape [bs, num_dn, 4].
+ attn_mask (torch.Tensor): Attention mask separating denoising and matching queries.
+ dn_meta (dict): Meta info containing 'dn_pos_idx', 'dn_num_group' and 'dn_num_split'.
+ Returns (None, None, None, None) when not training or there are no gts.
+ """
+ if (not training) or num_dn <= 0:
+ return None, None, None, None
+ gt_groups = batch['gt_groups']
+ total_num = sum(gt_groups)
+ max_nums = max(gt_groups)
+ if max_nums == 0:
+ return None, None, None, None
+
+ num_group = max(num_dn // max_nums, 1) # number of denoising groups, at least 1
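+ # E.g. num_dn=100 and max_nums=8 -> 12 groups, i.e. 12 * 8 positive and
+ # 12 * 8 negative padded queries (192 total, see num_dn below)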
+ # Pad gts to the max gt count of the batch
+ bs = len(gt_groups)
+ gt_cls = batch['cls'] # (bs*num, )
+ gt_bbox = batch['bboxes'] # bs*num, 4
+ b_idx = batch['batch_idx']
+
+ # each group has positive and negative queries.
+ dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )
+ dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4
+ dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, )
+
+ # Positive and negative sample split
+ # Within the flattened dn samples, the first total_num * num_group entries are positive, the rest negative
+ neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
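+ # E.g. gt_groups=[2, 3] (total_num=5) and num_group=4: samples 0-19 are
+ # positive and samples 20-39 negative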
+
+ if cls_noise_ratio > 0:
+ # Corrupt each dn class label with probability cls_noise_ratio / 2
+ mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
+ idx = torch.nonzero(mask).squeeze(-1)
+ # Randomly assign a new class label
+ new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
+ dn_cls[idx] = new_label
+
+ if box_noise_scale > 0:
+ known_bbox = xywh2xyxy(dn_bbox)
+
+ diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale # noise magnitude: half of w, h per coordinate (2*num_group*bs*num, 4)
+
+ rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
+ rand_part = torch.rand_like(dn_bbox)
+ rand_part[neg_idx] += 1.0
+ rand_part *= rand_sign
+ known_bbox += rand_part * diff
+ known_bbox.clip_(min=0.0, max=1.0)
+ dn_bbox = xyxy2xywh(known_bbox)
+ dn_bbox = inverse_sigmoid(dn_bbox)
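+ # Positive queries are jittered by at most half a box width/height (scaled by
+ # box_noise_scale); negatives get rand_part + 1.0, pushing them 1-2x further
+ # to act as hard negatives. inverse_sigmoid maps the noised boxes to logit
+ # space so the decoder's sigmoid recovers normalized coordinates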
+
+ # total denoising queries
+ num_dn = int(max_nums * 2 * num_group)
+ # class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
+ dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256
+ padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
+ padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)
+
+ map_indices = torch.cat([torch.arange(num, dtype=torch.long) for num in gt_groups])
+ pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)
+
+ map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
+ padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
+ padding_bbox[(dn_b_idx, map_indices)] = dn_bbox
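+ # Scatter each flattened dn sample into its (image, slot) position in the
+ # padded per-image tensors; unused slots stay zero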
+
+ tgt_size = num_dn + num_queries
+ attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
+ # Matching queries cannot see the denoising (reconstruction) queries
+ attn_mask[num_dn:, :num_dn] = True
+ # Denoising groups cannot see each other
+ for i in range(num_group):
+ if i == 0:
+ attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
+ if i == num_group - 1:
+ attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * i * 2] = True
+ else:
+ attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
+ attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * 2 * i] = True
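+ # E.g. num_group=2, max_nums=3 (num_dn=12): rows 0-5 (group 0) are masked
+ # from columns 6-11 and rows 6-11 (group 1) from columns 0-5, so each
+ # denoising group only attends within itself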
+ dn_meta = {
+ 'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
+ 'dn_num_group': num_group,
+ 'dn_num_split': [num_dn, num_queries]}
+
+ return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to(
+ class_embed.device), dn_meta
+
+
+def inverse_sigmoid(x, eps=1e-6):
+ """Inverse of the sigmoid function, clipped to [0, 1] and stabilized with eps."""
+ x = x.clip(min=0., max=1.)
+ return torch.log(x / (1 - x + eps) + eps)
diff --git a/ultralytics/yolo/data/augment.py b/ultralytics/yolo/data/augment.py
index 4c7f167..42688c9 100644
--- a/ultralytics/yolo/data/augment.py
+++ b/ultralytics/yolo/data/augment.py
@@ -759,7 +759,7 @@ class Format:
return masks, instances, cls
-def v8_transforms(dataset, imgsz, hyp):
+def v8_transforms(dataset, imgsz, hyp, stretch=False):
"""Convert images to a size suitable for YOLOv8 training."""
pre_transform = Compose([
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
@@ -770,7 +770,7 @@ def v8_transforms(dataset, imgsz, hyp):
scale=hyp.scale,
shear=hyp.shear,
perspective=hyp.perspective,
- pre_transform=LetterBox(new_shape=(imgsz, imgsz)),
+ pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
)])
flip_idx = dataset.data.get('flip_idx', None) # for keypoints augmentation
if dataset.use_keypoints:
diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
index 679c545..e4925b3 100644
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -278,7 +278,8 @@ class BaseTrainer:
self.epoch_time_start = time.time()
self.train_time_start = time.time()
nb = len(self.train_loader) # number of batches
- nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
+ # number of warmup iterations; warmup_epochs <= 0 disables warmup (nw = -1)
+ nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1
last_opt_step = -1
self.run_callbacks('on_train_start')
LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
diff --git a/ultralytics/yolo/utils/loss.py b/ultralytics/yolo/utils/loss.py
index f165acb..71ed0a5 100644
--- a/ultralytics/yolo/utils/loss.py
+++ b/ultralytics/yolo/utils/loss.py
@@ -24,10 +24,34 @@ class VarifocalLoss(nn.Module):
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
with torch.cuda.amp.autocast(enabled=False):
loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
- weight).sum()
+ weight).mean(1).sum()
return loss
+# Losses
+class FocalLoss(nn.Module):
+ """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, pred, label, gamma=1.5, alpha=0.25):
+ """Calculates and updates confusion matrix for object detection/classification tasks."""
+ loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
+ # p_t = torch.exp(-loss)
+ # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
+
+ # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
+ pred_prob = pred.sigmoid() # prob from logits
+ p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
+ modulating_factor = (1.0 - p_t) ** gamma
+ loss *= modulating_factor
+ if alpha > 0:
+ alpha_factor = label * alpha + (1 - label) * (1 - alpha)
+ loss *= alpha_factor
+ return loss.mean(1).sum()
+
+
class BboxLoss(nn.Module):
def __init__(self, reg_max, use_dfl=False):
diff --git a/ultralytics/yolo/utils/metrics.py b/ultralytics/yolo/utils/metrics.py
index 30899e4..8544adf 100644
--- a/ultralytics/yolo/utils/metrics.py
+++ b/ultralytics/yolo/utils/metrics.py
@@ -9,7 +9,6 @@ from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
-import torch.nn as nn
from ultralytics.yolo.utils import LOGGER, SimpleClass, TryExcept, plt_settings
@@ -175,40 +174,6 @@ def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#iss
return 1.0 - 0.5 * eps, 0.5 * eps
-# Losses
-class FocalLoss(nn.Module):
- """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
-
- def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
- """Initialize FocalLoss object with given loss function and hyperparameters."""
- super().__init__()
- self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
- self.gamma = gamma
- self.alpha = alpha
- self.reduction = loss_fcn.reduction
- self.loss_fcn.reduction = 'none' # required to apply FL to each element
-
- def forward(self, pred, true):
- """Calculates and updates confusion matrix for object detection/classification tasks."""
- loss = self.loss_fcn(pred, true)
- # p_t = torch.exp(-loss)
- # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
-
- # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
- pred_prob = torch.sigmoid(pred) # prob from logits
- p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
- alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
- modulating_factor = (1.0 - p_t) ** self.gamma
- loss *= alpha_factor * modulating_factor
-
- if self.reduction == 'mean':
- return loss.mean()
- elif self.reduction == 'sum':
- return loss.sum()
- else: # 'None'
- return loss
-
-
class ConfusionMatrix:
"""
A class for calculating and updating a confusion matrix for object detection and classification tasks.
diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py
index cb3725a..a9d7917 100644
--- a/ultralytics/yolo/utils/torch_utils.py
+++ b/ultralytics/yolo/utils/torch_utils.py
@@ -327,6 +327,9 @@ def init_seeds(seed=0, deterministic=False):
os.environ['PYTHONHASHSEED'] = str(seed)
else:
LOGGER.warning('WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.')
+ else:
+ torch.use_deterministic_algorithms(False)
+ torch.backends.cudnn.deterministic = False
class ModelEMA: