Add RTDETR Trainer (#2745)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
single_channel
Ayush Chaurasia 1 year ago committed by GitHub
parent 03bce07848
commit a0ba8ef5f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -35,6 +35,7 @@ from ultralytics import RTDETR
model = RTDETR("rtdetr-l.pt") model = RTDETR("rtdetr-l.pt")
model.info() # display model information model.info() # display model information
model.train(data="coco8.yaml") # train
model.predict("path/to/image.jpg") # predict model.predict("path/to/image.jpg") # predict
``` ```
@ -51,7 +52,7 @@ model.predict("path/to/image.jpg") # predict
|------------|--------------------| |------------|--------------------|
| Inference | :heavy_check_mark: | | Inference | :heavy_check_mark: |
| Validation | :heavy_check_mark: | | Validation | :heavy_check_mark: |
| Training | :x: (Coming soon) | | Training | :heavy_check_mark: |
# Citations and Acknowledgements # Citations and Acknowledgements

@ -8,6 +8,11 @@ keywords: Ultralytics, YOLO, loss functions, object detection, keypoint detectio
:::ultralytics.yolo.utils.loss.VarifocalLoss :::ultralytics.yolo.utils.loss.VarifocalLoss
<br><br> <br><br>
# FocalLoss
---
:::ultralytics.yolo.utils.loss.FocalLoss
<br><br>
# BboxLoss # BboxLoss
--- ---
:::ultralytics.yolo.utils.loss.BboxLoss :::ultralytics.yolo.utils.loss.BboxLoss

@ -3,11 +3,6 @@ description: Explore Ultralytics YOLO's FocalLoss, DetMetrics, PoseMetrics, Clas
keywords: YOLOv5, metrics, losses, confusion matrix, detection metrics, pose metrics, classification metrics, intersection over area, intersection over union, keypoint intersection over union, average precision, per class average precision, Ultralytics Docs keywords: YOLOv5, metrics, losses, confusion matrix, detection metrics, pose metrics, classification metrics, intersection over area, intersection over union, keypoint intersection over union, average precision, per class average precision, Ultralytics Docs
--- ---
# FocalLoss
---
:::ultralytics.yolo.utils.metrics.FocalLoss
<br><br>
# ConfusionMatrix # ConfusionMatrix
--- ---
:::ultralytics.yolo.utils.metrics.ConfusionMatrix :::ultralytics.yolo.utils.metrics.ConfusionMatrix

@ -7,7 +7,7 @@ import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from ultralytics import YOLO from ultralytics import RTDETR, YOLO
from ultralytics.yolo.data.build import load_inference_source from ultralytics.yolo.data.build import load_inference_source
from ultralytics.yolo.utils import LINUX, ONLINE, ROOT, SETTINGS from ultralytics.yolo.utils import LINUX, ONLINE, ROOT, SETTINGS
@ -174,6 +174,9 @@ def test_export_paddle(enabled=False):
def test_all_model_yamls(): def test_all_model_yamls():
for m in list((ROOT / 'models').rglob('yolo*.yaml')): for m in list((ROOT / 'models').rglob('yolo*.yaml')):
if m.name == 'yolov8-rtdetr.yaml': # except the rtdetr model
RTDETR(m.name)
else:
YOLO(m.name) YOLO(m.name)

@ -0,0 +1,46 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 12
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 12], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 18 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 9], 1, Concat, [1]] # cat head P5
- [-1, 3, C2f, [1024]] # 21 (P5/32-large)
- [[15, 18, 21], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)

@ -163,220 +163,187 @@ class RTDETRDecoder(nn.Module):
self, self,
nc=80, nc=80,
ch=(512, 1024, 2048), ch=(512, 1024, 2048),
hidden_dim=256, hd=256, # hidden dim
num_queries=300, nq=300, # num queries
strides=(8, 16, 32), # TODO ndp=4, # num decoder points
nl=3, nh=8, # num head
num_decoder_points=4, ndl=6, # num decoder layers
nhead=8, d_ffn=1024, # dim of feedforward
num_decoder_layers=6,
dim_feedforward=1024,
dropout=0., dropout=0.,
act=nn.ReLU(), act=nn.ReLU(),
eval_idx=-1, eval_idx=-1,
# training args # training args
num_denoising=100, nd=100, # num denoising
label_noise_ratio=0.5, label_noise_ratio=0.5,
box_noise_scale=1.0, box_noise_scale=1.0,
learnt_init_query=False): learnt_init_query=False):
super().__init__() super().__init__()
assert len(ch) <= nl self.hidden_dim = hd
assert len(strides) == len(ch) self.nhead = nh
for _ in range(nl - len(strides)): self.nl = len(ch) # num level
strides.append(strides[-1] * 2)
self.hidden_dim = hidden_dim
self.nhead = nhead
self.feat_strides = strides
self.nl = nl
self.nc = nc self.nc = nc
self.num_queries = num_queries self.num_queries = nq
self.num_decoder_layers = num_decoder_layers self.num_decoder_layers = ndl
# backbone feature projection # backbone feature projection
self._build_input_proj_layer(ch) self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch)
# NOTE: simplified version but it's not consistent with .pt weights.
# self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch)
# Transformer module # Transformer module
decoder_layer = DeformableTransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, dropout, act, nl, decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp)
num_decoder_points) self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx)
self.decoder = DeformableTransformerDecoder(hidden_dim, decoder_layer, num_decoder_layers, eval_idx)
# denoising part # denoising part
self.denoising_class_embed = nn.Embedding(nc, hidden_dim) self.denoising_class_embed = nn.Embedding(nc, hd)
self.num_denoising = num_denoising self.num_denoising = nd
self.label_noise_ratio = label_noise_ratio self.label_noise_ratio = label_noise_ratio
self.box_noise_scale = box_noise_scale self.box_noise_scale = box_noise_scale
# decoder embedding # decoder embedding
self.learnt_init_query = learnt_init_query self.learnt_init_query = learnt_init_query
if learnt_init_query: if learnt_init_query:
self.tgt_embed = nn.Embedding(num_queries, hidden_dim) self.tgt_embed = nn.Embedding(nq, hd)
self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2) self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2)
# encoder head # encoder head
self.enc_output = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim)) self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd))
self.enc_score_head = nn.Linear(hidden_dim, nc) self.enc_score_head = nn.Linear(hd, nc)
self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3) self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3)
# decoder head # decoder head
self.dec_score_head = nn.ModuleList([nn.Linear(hidden_dim, nc) for _ in range(num_decoder_layers)]) self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)])
self.dec_bbox_head = nn.ModuleList([ self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)])
MLP(hidden_dim, hidden_dim, 4, num_layers=3) for _ in range(num_decoder_layers)])
self._reset_parameters() self._reset_parameters()
def forward(self, feats, gt_meta=None): def forward(self, x, batch=None):
from ultralytics.vit.utils.ops import get_cdn_group
# input projection and embedding # input projection and embedding
memory, spatial_shapes, _ = self._get_encoder_input(feats) feats, shapes = self._get_encoder_input(x)
# prepare denoising training # prepare denoising training
if self.training: dn_embed, dn_bbox, attn_mask, dn_meta = \
raise NotImplementedError get_cdn_group(batch,
# denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \ self.nc,
# get_contrastive_denoising_training_group(gt_meta, self.num_queries,
# self.num_classes, self.denoising_class_embed.weight,
# self.num_queries, self.num_denoising,
# self.denoising_class_embed.weight, self.label_noise_ratio,
# self.num_denoising, self.box_noise_scale,
# self.label_noise_ratio, self.training)
# self.box_noise_scale)
else: embed, refer_bbox, enc_bboxes, enc_scores = \
denoising_class, denoising_bbox_unact, attn_mask = None, None, None self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)
target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
self._get_decoder_input(memory, spatial_shapes, denoising_class, denoising_bbox_unact)
# decoder # decoder
out_bboxes, out_logits = self.decoder(target, dec_bboxes, dec_scores = self.decoder(embed,
init_ref_points_unact, refer_bbox,
memory, feats,
spatial_shapes, shapes,
self.dec_bbox_head, self.dec_bbox_head,
self.dec_score_head, self.dec_score_head,
self.query_pos_head, self.query_pos_head,
attn_mask=attn_mask) attn_mask=attn_mask)
if not self.training: if not self.training:
out_logits = out_logits.sigmoid_() dec_scores = dec_scores.sigmoid_()
return out_bboxes, out_logits # enc_topk_bboxes, enc_topk_logits, dn_meta return dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
def _reset_parameters(self):
# class and bbox head init
bias_cls = bias_init_with_prob(0.01)
linear_init_(self.enc_score_head)
constant_(self.enc_score_head.bias, bias_cls)
constant_(self.enc_bbox_head.layers[-1].weight, 0.)
constant_(self.enc_bbox_head.layers[-1].bias, 0.)
for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
linear_init_(cls_)
constant_(cls_.bias, bias_cls)
constant_(reg_.layers[-1].weight, 0.)
constant_(reg_.layers[-1].bias, 0.)
linear_init_(self.enc_output[0])
xavier_uniform_(self.enc_output[0].weight)
if self.learnt_init_query:
xavier_uniform_(self.tgt_embed.weight)
xavier_uniform_(self.query_pos_head.layers[0].weight)
xavier_uniform_(self.query_pos_head.layers[1].weight)
for layer in self.input_proj:
xavier_uniform_(layer[0].weight)
def _build_input_proj_layer(self, ch): def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
self.input_proj = nn.ModuleList()
for in_channels in ch:
self.input_proj.append(
nn.Sequential(nn.Conv2d(in_channels, self.hidden_dim, kernel_size=1, bias=False),
nn.BatchNorm2d(self.hidden_dim)))
in_channels = ch[-1]
for _ in range(self.nl - len(ch)):
self.input_proj.append(
nn.Sequential(nn.Conv2D(in_channels, self.hidden_dim, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(self.hidden_dim)))
in_channels = self.hidden_dim
def _generate_anchors(self, spatial_shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
anchors = [] anchors = []
for lvl, (h, w) in enumerate(spatial_shapes): for i, (h, w) in enumerate(shapes):
grid_y, grid_x = torch.meshgrid(torch.arange(end=h, dtype=torch.float32), grid_y, grid_x = torch.meshgrid(torch.arange(end=h, dtype=dtype, device=device),
torch.arange(end=w, dtype=torch.float32), torch.arange(end=w, dtype=dtype, device=device),
indexing='ij') indexing='ij')
grid_xy = torch.stack([grid_x, grid_y], -1) grid_xy = torch.stack([grid_x, grid_y], -1) # (h, w, 2)
valid_WH = torch.tensor([h, w]).to(torch.float32) valid_WH = torch.tensor([h, w], dtype=dtype, device=device)
grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH # (1, h, w, 2)
wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl) wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0 ** i)
anchors.append(torch.concat([grid_xy, wh], -1).reshape([-1, h * w, 4])) anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4)) # (1, h*w, 4)
anchors = torch.concat(anchors, 1) anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4)
valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1
anchors = torch.log(anchors / (1 - anchors)) anchors = torch.log(anchors / (1 - anchors))
anchors = torch.where(valid_mask, anchors, torch.inf) anchors = torch.where(valid_mask, anchors, torch.inf)
return anchors.to(device=device, dtype=dtype), valid_mask.to(device=device) return anchors, valid_mask
def _get_encoder_input(self, feats): def _get_encoder_input(self, x):
# get projection features # get projection features
proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] x = [self.input_proj[i](feat) for i, feat in enumerate(x)]
if self.nl > len(proj_feats):
len_srcs = len(proj_feats)
for i in range(len_srcs, self.nl):
if i == len_srcs:
proj_feats.append(self.input_proj[i](feats[-1]))
else:
proj_feats.append(self.input_proj[i](proj_feats[-1]))
# get encoder inputs # get encoder inputs
feat_flatten = [] feats = []
spatial_shapes = [] shapes = []
level_start_index = [0] for feat in x:
for feat in proj_feats: h, w = feat.shape[2:]
_, _, h, w = feat.shape
# [b, c, h, w] -> [b, h*w, c] # [b, c, h, w] -> [b, h*w, c]
feat_flatten.append(feat.flatten(2).permute(0, 2, 1)) feats.append(feat.flatten(2).permute(0, 2, 1))
# [nl, 2] # [nl, 2]
spatial_shapes.append([h, w]) shapes.append([h, w])
# [l], start index of each level
level_start_index.append(h * w + level_start_index[-1])
# [b, l, c] # [b, h*w, c]
feat_flatten = torch.concat(feat_flatten, 1) feats = torch.cat(feats, 1)
level_start_index.pop() return feats, shapes
return feat_flatten, spatial_shapes, level_start_index
def _get_decoder_input(self, memory, spatial_shapes, denoising_class=None, denoising_bbox_unact=None): def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
bs, _, _ = memory.shape bs = len(feats)
# prepare input for decoder # prepare input for decoder
anchors, valid_mask = self._generate_anchors(spatial_shapes, dtype=memory.dtype, device=memory.device) anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
memory = torch.where(valid_mask, memory, 0) features = self.enc_output(torch.where(valid_mask, feats, 0)) # bs, h*w, 256
output_memory = self.enc_output(memory)
enc_outputs_class = self.enc_score_head(output_memory) # (bs, h*w, nc) enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc)
enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors # (bs, h*w, 4) # dynamic anchors + static content
enc_outputs_bboxes = self.enc_bbox_head(features) + anchors # (bs, h*w, 4)
# (bs, topk) # query selection
_, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1) # (bs, num_queries)
# extract region proposal boxes topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
# (bs, topk_ind) # (bs, num_queries)
batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1) batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)
topk_ind = topk_ind.view(-1)
# Unsigmoided # Unsigmoided
reference_points_unact = enc_outputs_coord_unact[batch_ind, topk_ind].view(bs, self.num_queries, -1) refer_bbox = enc_outputs_bboxes[batch_ind, topk_ind].view(bs, self.num_queries, -1)
# refer_bbox = torch.gather(enc_outputs_bboxes, 1, topk_ind.reshape(bs, self.num_queries).unsqueeze(-1).repeat(1, 1, 4))
enc_topk_bboxes = torch.sigmoid(reference_points_unact) enc_bboxes = refer_bbox.sigmoid()
if denoising_bbox_unact is not None: if dn_bbox is not None:
reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1) refer_bbox = torch.cat([dn_bbox, refer_bbox], 1)
if self.training: if self.training:
reference_points_unact = reference_points_unact.detach() refer_bbox = refer_bbox.detach()
enc_topk_logits = enc_outputs_class[batch_ind, topk_ind].view(bs, self.num_queries, -1) enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)
# extract region features
if self.learnt_init_query: if self.learnt_init_query:
target = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
else: else:
target = output_memory[batch_ind, topk_ind].view(bs, self.num_queries, -1) embeddings = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
if self.training: if self.training:
target = target.detach() embeddings = embeddings.detach()
if denoising_class is not None: if dn_embed is not None:
target = torch.concat([denoising_class, target], 1) embeddings = torch.cat([dn_embed, embeddings], 1)
return embeddings, refer_bbox, enc_bboxes, enc_scores
# TODO
def _reset_parameters(self):
# class and bbox head init
bias_cls = bias_init_with_prob(0.01) / 80 * self.nc
# NOTE: the weight initialization in `linear_init_` would cause NaN when training with custom datasets.
# linear_init_(self.enc_score_head)
constant_(self.enc_score_head.bias, bias_cls)
constant_(self.enc_bbox_head.layers[-1].weight, 0.)
constant_(self.enc_bbox_head.layers[-1].bias, 0.)
for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
# linear_init_(cls_)
constant_(cls_.bias, bias_cls)
constant_(reg_.layers[-1].weight, 0.)
constant_(reg_.layers[-1].bias, 0.)
return target, reference_points_unact, enc_topk_bboxes, enc_topk_logits linear_init_(self.enc_output[0])
xavier_uniform_(self.enc_output[0].weight)
if self.learnt_init_query:
xavier_uniform_(self.tgt_embed.weight)
xavier_uniform_(self.query_pos_head.layers[0].weight)
xavier_uniform_(self.query_pos_head.layers[1].weight)
for layer in self.input_proj:
xavier_uniform_(layer[0].weight)

@ -229,23 +229,23 @@ class MSDeformAttn(nn.Module):
xavier_uniform_(self.output_proj.weight.data) xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.) constant_(self.output_proj.bias.data, 0.)
def forward(self, query, reference_points, value, value_spatial_shapes, value_mask=None): def forward(self, query, refer_bbox, value, value_shapes, value_mask=None):
""" """
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
Args: Args:
query (Tensor): [bs, query_length, C] query (torch.Tensor): [bs, query_length, C]
reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0), refer_bbox (torch.Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area bottom-right (1, 1), including padding area
value (Tensor): [bs, value_length, C] value (torch.Tensor): [bs, value_length, C]
value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] value_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
Returns: Returns:
output (Tensor): [bs, Length_{query}, C] output (Tensor): [bs, Length_{query}, C]
""" """
bs, len_q = query.shape[:2] bs, len_q = query.shape[:2]
_, len_v = value.shape[:2] len_v = value.shape[1]
assert sum(s[0] * s[1] for s in value_spatial_shapes) == len_v assert sum(s[0] * s[1] for s in value_shapes) == len_v
value = self.value_proj(value) value = self.value_proj(value)
if value_mask is not None: if value_mask is not None:
@ -255,18 +255,17 @@ class MSDeformAttn(nn.Module):
attention_weights = self.attention_weights(query).view(bs, len_q, self.n_heads, self.n_levels * self.n_points) attention_weights = self.attention_weights(query).view(bs, len_q, self.n_heads, self.n_levels * self.n_points)
attention_weights = F.softmax(attention_weights, -1).view(bs, len_q, self.n_heads, self.n_levels, self.n_points) attention_weights = F.softmax(attention_weights, -1).view(bs, len_q, self.n_heads, self.n_levels, self.n_points)
# N, Len_q, n_heads, n_levels, n_points, 2 # N, Len_q, n_heads, n_levels, n_points, 2
n = reference_points.shape[-1] num_points = refer_bbox.shape[-1]
if n == 2: if num_points == 2:
offset_normalizer = torch.as_tensor(value_spatial_shapes, dtype=query.dtype, device=query.device).flip(-1) offset_normalizer = torch.as_tensor(value_shapes, dtype=query.dtype, device=query.device).flip(-1)
add = sampling_offsets / offset_normalizer[None, None, None, :, None, :] add = sampling_offsets / offset_normalizer[None, None, None, :, None, :]
sampling_locations = reference_points[:, :, None, :, None, :] + add sampling_locations = refer_bbox[:, :, None, :, None, :] + add
elif num_points == 4:
elif n == 4: add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5
add = sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 sampling_locations = refer_bbox[:, :, None, :, None, :2] + add
sampling_locations = reference_points[:, :, None, :, None, :2] + add
else: else:
raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {n}.') raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {num_points}.')
output = multi_scale_deformable_attn_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights) output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
output = self.output_proj(output) output = self.output_proj(output)
return output return output
@ -308,33 +307,24 @@ class DeformableTransformerDecoderLayer(nn.Module):
tgt = self.norm3(tgt) tgt = self.norm3(tgt)
return tgt return tgt
def forward(self, def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None):
tgt,
reference_points,
src,
src_spatial_shapes,
src_padding_mask=None,
attn_mask=None,
query_pos=None):
# self attention # self attention
q = k = self.with_pos_embed(tgt, query_pos) q = k = self.with_pos_embed(embed, query_pos)
if attn_mask is not None: tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1),
attn_mask = torch.where(attn_mask.astype('bool'), torch.zeros(attn_mask.shape, tgt.dtype), attn_mask=attn_mask)[0].transpose(0, 1)
torch.full(attn_mask.shape, float('-inf'), tgt.dtype)) embed = embed + self.dropout1(tgt)
tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1) embed = self.norm1(embed)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# cross attention # cross attention
tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos), reference_points, src, src_spatial_shapes, tgt = self.cross_attn(self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes,
src_padding_mask) padding_mask)
tgt = tgt + self.dropout2(tgt2) embed = embed + self.dropout2(tgt)
tgt = self.norm2(tgt) embed = self.norm2(embed)
# ffn # ffn
tgt = self.forward_ffn(tgt) embed = self.forward_ffn(embed)
return tgt return embed
class DeformableTransformerDecoder(nn.Module): class DeformableTransformerDecoder(nn.Module):
@ -349,41 +339,40 @@ class DeformableTransformerDecoder(nn.Module):
self.hidden_dim = hidden_dim self.hidden_dim = hidden_dim
self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
def forward(self, def forward(
tgt, self,
reference_points, embed, # decoder embeddings
src, refer_bbox, # anchor
src_spatial_shapes, feats, # image features
shapes, # feature shapes
bbox_head, bbox_head,
score_head, score_head,
query_pos_head, pos_mlp,
attn_mask=None, attn_mask=None,
src_padding_mask=None): padding_mask=None):
output = tgt output = embed
dec_out_bboxes = [] dec_bboxes = []
dec_out_logits = [] dec_cls = []
ref_points = None last_refined_bbox = None
ref_points_detach = torch.sigmoid(reference_points) refer_bbox = refer_bbox.sigmoid()
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
ref_points_input = ref_points_detach.unsqueeze(2) output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox))
query_pos_embed = query_pos_head(ref_points_detach)
output = layer(output, ref_points_input, src, src_spatial_shapes, src_padding_mask, attn_mask,
query_pos_embed)
inter_ref_bbox = torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach)) # refine bboxes, (bs, num_queries+num_denoising, 4)
refined_bbox = torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(refer_bbox))
if self.training: if self.training:
dec_out_logits.append(score_head[i](output)) dec_cls.append(score_head[i](output))
if i == 0: if i == 0:
dec_out_bboxes.append(inter_ref_bbox) dec_bboxes.append(refined_bbox)
else: else:
dec_out_bboxes.append(torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points))) dec_bboxes.append(torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(last_refined_bbox)))
elif i == self.eval_idx: elif i == self.eval_idx:
dec_out_logits.append(score_head[i](output)) dec_cls.append(score_head[i](output))
dec_out_bboxes.append(inter_ref_bbox) dec_bboxes.append(refined_bbox)
break break
ref_points = inter_ref_bbox last_refined_bbox = refined_bbox
ref_points_detach = inter_ref_bbox.detach() if self.training else inter_ref_bbox refer_bbox = refined_bbox.detach() if self.training else refined_bbox
return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits) return torch.stack(dec_bboxes), torch.stack(dec_cls)

@ -210,7 +210,9 @@ class BaseModel(nn.Module):
""" """
if not hasattr(self, 'criterion'): if not hasattr(self, 'criterion'):
self.criterion = self.init_criterion() self.criterion = self.init_criterion()
return self.criterion(self.predict(batch['img']) if preds is None else preds, batch)
preds = self.forward(batch['img']) if preds is None else preds
return self.criterion(preds, batch)
def init_criterion(self): def init_criterion(self):
raise NotImplementedError('compute_loss() needs to be implemented by task heads') raise NotImplementedError('compute_loss() needs to be implemented by task heads')
@ -410,7 +412,7 @@ class RTDETRDetectionModel(DetectionModel):
"""Compute the classification loss between predictions and true labels.""" """Compute the classification loss between predictions and true labels."""
from ultralytics.vit.utils.loss import RTDETRDetectionLoss from ultralytics.vit.utils.loss import RTDETRDetectionLoss
return RTDETRDetectionLoss(num_classes=self.nc, use_vfl=True) return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)
def loss(self, batch, preds=None): def loss(self, batch, preds=None):
if not hasattr(self, 'criterion'): if not hasattr(self, 'criterion'):
@ -420,31 +422,36 @@ class RTDETRDetectionModel(DetectionModel):
# NOTE: preprocess gt_bbox and gt_labels to list. # NOTE: preprocess gt_bbox and gt_labels to list.
bs = len(img) bs = len(img)
batch_idx = batch['batch_idx'] batch_idx = batch['batch_idx']
gt_bbox, gt_class = [], [] gt_groups = []
for i in range(bs): for i in range(bs):
gt_bbox.append(batch['bboxes'][batch_idx == i].to(img.device)) gt_groups.append((batch_idx == i).sum().item())
gt_class.append(batch['cls'][batch_idx == i].to(device=img.device, dtype=torch.long)) targets = {
targets = {'cls': gt_class, 'bboxes': gt_bbox} 'cls': batch['cls'].to(img.device, dtype=torch.long).view(-1),
'bboxes': batch['bboxes'].to(device=img.device),
'batch_idx': batch_idx.to(img.device, dtype=torch.long).view(-1),
'gt_groups': gt_groups}
preds = self.predict(img, batch=targets) if preds is None else preds preds = self.predict(img, batch=targets) if preds is None else preds
dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta = preds dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds
# NOTE: `dn_meta` means it's eval mode, loss calculation for eval mode is not supported.
if dn_meta is None: if dn_meta is None:
return 0, torch.zeros(3, device=dec_out_bboxes.device) dn_bboxes, dn_scores = None, None
dn_out_bboxes, dec_out_bboxes = torch.split(dec_out_bboxes, dn_meta['dn_num_split'], dim=2) else:
dn_out_logits, dec_out_logits = torch.split(dec_out_logits, dn_meta['dn_num_split'], dim=2) dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta['dn_num_split'], dim=2)
dn_scores, dec_scores = torch.split(dec_scores, dn_meta['dn_num_split'], dim=2)
out_bboxes = torch.cat([enc_topk_bboxes.unsqueeze(0), dec_out_bboxes]) dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4)
out_logits = torch.cat([enc_topk_logits.unsqueeze(0), dec_out_logits]) dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])
loss = self.criterion((out_bboxes, out_logits), loss = self.criterion((dec_bboxes, dec_scores),
targets, targets,
dn_out_bboxes=dn_out_bboxes, dn_bboxes=dn_bboxes,
dn_out_logits=dn_out_logits, dn_scores=dn_scores,
dn_meta=dn_meta) dn_meta=dn_meta)
return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']]) # NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses.
return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
device=img.device)
def predict(self, x, profile=False, visualize=False, batch=None): def predict(self, x, profile=False, visualize=False, batch=None, augment=False):
""" """
Perform a forward pass through the network. Perform a forward pass through the network.

@ -3,4 +3,4 @@
from .rtdetr import RTDETR from .rtdetr import RTDETR
from .sam import SAM from .sam import SAM
__all__ = 'RTDETR', 'SAM', 'SAM' # allow simpler import __all__ = 'RTDETR', 'SAM' # allow simpler import

@ -5,15 +5,15 @@
from pathlib import Path from pathlib import Path
from ultralytics.nn.tasks import DetectionModel, attempt_load_one_weight, yaml_model_load from ultralytics.nn.tasks import RTDETRDetectionModel, attempt_load_one_weight, yaml_model_load
from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, ROOT, is_git_dir
from ultralytics.yolo.utils.checks import check_imgsz from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.torch_utils import model_info from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode
from ...yolo.utils.torch_utils import smart_inference_mode
from .predict import RTDETRPredictor from .predict import RTDETRPredictor
from .train import RTDETRTrainer
from .val import RTDETRValidator from .val import RTDETRValidator
@ -24,6 +24,7 @@ class RTDETR:
raise NotImplementedError('RT-DETR only supports creating from pt file or yaml file.') raise NotImplementedError('RT-DETR only supports creating from pt file or yaml file.')
# Load or create new YOLO model # Load or create new YOLO model
self.predictor = None self.predictor = None
self.ckpt = None
suffix = Path(model).suffix suffix = Path(model).suffix
if suffix == '.yaml': if suffix == '.yaml':
self._new(model) self._new(model)
@ -34,7 +35,7 @@ class RTDETR:
cfg_dict = yaml_model_load(cfg) cfg_dict = yaml_model_load(cfg)
self.cfg = cfg self.cfg = cfg
self.task = 'detect' self.task = 'detect'
self.model = DetectionModel(cfg_dict, verbose=verbose) # build model self.model = RTDETRDetectionModel(cfg_dict, verbose=verbose) # build model
# Below added to allow export from yamls # Below added to allow export from yamls
self.model.args = DEFAULT_CFG_DICT # attach args to model self.model.args = DEFAULT_CFG_DICT # attach args to model
@ -42,10 +43,20 @@ class RTDETR:
@smart_inference_mode() @smart_inference_mode()
def _load(self, weights: str): def _load(self, weights: str):
self.model, _ = attempt_load_one_weight(weights) self.model, self.ckpt = attempt_load_one_weight(weights)
self.model.args = DEFAULT_CFG_DICT # attach args to model self.model.args = DEFAULT_CFG_DICT # attach args to model
self.task = self.model.args['task'] self.task = self.model.args['task']
@smart_inference_mode()
def load(self, weights='yolov8n.pt'):
"""
Transfers parameters with matching names and shapes from 'weights' to model.
"""
if isinstance(weights, (str, Path)):
weights, self.ckpt = attempt_load_one_weight(weights)
self.model.load(weights)
return self
@smart_inference_mode() @smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs): def predict(self, source=None, stream=False, **kwargs):
""" """
@ -74,8 +85,30 @@ class RTDETR:
return self.predictor(source, stream=stream) return self.predictor(source, stream=stream)
def train(self, **kwargs): def train(self, **kwargs):
"""Function trains models but raises an error as RTDETR models do not support training.""" """
raise NotImplementedError("RTDETR models don't support training") Trains the model on a given dataset.
Args:
**kwargs (Any): Any number of arguments representing the training configuration.
"""
overrides = dict(task='detect', mode='train')
overrides.update(kwargs)
overrides['deterministic'] = False
if not overrides.get('data'):
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
if overrides.get('resume'):
overrides['resume'] = self.ckpt_path
self.task = overrides.get('task') or self.task
self.trainer = RTDETRTrainer(overrides=overrides)
if not overrides.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
self.trainer.train()
# Update model and cfg after training
if RANK in (-1, 0):
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
self.overrides = self.model.args
self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
def val(self, **kwargs): def val(self, **kwargs):
"""Run validation given dataset.""" """Run validation given dataset."""

@ -0,0 +1,78 @@
from copy import copy
import torch
from ultralytics.nn.tasks import RTDETRDetectionModel
from ultralytics.yolo.utils import DEFAULT_CFG, RANK, colorstr
from ultralytics.yolo.v8.detect import DetectionTrainer
from .val import RTDETRDataset, RTDETRValidator
class RTDETRTrainer(DetectionTrainer):
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return a YOLO detection model."""
model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def build_dataset(self, img_path, mode='val', batch=None):
"""Build RTDETR Dataset
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
return RTDETRDataset(
img_path=img_path,
imgsz=self.args.imgsz,
batch_size=batch,
augment=mode == 'train', # no augmentation
hyp=self.args,
rect=False, # no rect
cache=self.args.cache or None,
prefix=colorstr(f'{mode}: '),
data=self.data)
def get_validator(self):
"""Returns a DetectionValidator for RTDETR model validation."""
self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def preprocess_batch(self, batch):
"""Preprocesses a batch of images by scaling and converting to float."""
batch = super().preprocess_batch(batch)
bs = len(batch['img'])
batch_idx = batch['batch_idx']
gt_bbox, gt_class = [], []
for i in range(bs):
gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
return batch
def train(cfg=DEFAULT_CFG, use_python=False):
"""Train and optimize RTDETR model given training data and device."""
model = 'rtdetr-l.yaml'
data = cfg.data or 'coco128.yaml' # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else ''
# NOTE: F.grid_sample which is in rt-detr does not support deterministic=True
# NOTE: amp training causes nan outputs and end with error while doing bipartite graph matching
args = dict(model=model,
data=data,
device=device,
imgsz=640,
exist_ok=True,
batch=4,
deterministic=False,
amp=False)
trainer = RTDETRTrainer(overrides=args)
trainer.train()
if __name__ == '__main__':
train()

@ -2,10 +2,12 @@
from pathlib import Path from pathlib import Path
import cv2
import numpy as np
import torch import torch
from ultralytics.yolo.data import YOLODataset from ultralytics.yolo.data import YOLODataset
from ultralytics.yolo.data.augment import Compose, Format, LetterBox from ultralytics.yolo.data.augment import Compose, Format, v8_transforms
from ultralytics.yolo.utils import colorstr, ops from ultralytics.yolo.utils import colorstr, ops
from ultralytics.yolo.v8.detect import DetectionValidator from ultralytics.yolo.v8.detect import DetectionValidator
@ -18,9 +20,41 @@ class RTDETRDataset(YOLODataset):
def __init__(self, *args, data=None, **kwargs): def __init__(self, *args, data=None, **kwargs):
super().__init__(*args, data=data, use_segments=False, use_keypoints=False, **kwargs) super().__init__(*args, data=data, use_segments=False, use_keypoints=False, **kwargs)
# NOTE: add stretch version load_image for rtdetr mosaic
def load_image(self, i):
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
if im is None: # not cached in RAM
if fn.exists(): # load npy
im = np.load(fn)
else: # read image
im = cv2.imread(f) # BGR
if im is None:
raise FileNotFoundError(f'Image Not Found {f}')
h0, w0 = im.shape[:2] # orig hw
im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR)
# Add to buffer if training with augmentations
if self.augment:
self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
self.buffer.append(i)
if len(self.buffer) >= self.max_buffer_length:
j = self.buffer.pop(0)
self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None
return im, (h0, w0), im.shape[:2]
return self.ims[i], self.im_hw0[i], self.im_hw[i]
def build_transforms(self, hyp=None): def build_transforms(self, hyp=None):
"""Temporarily, only for evaluation.""" """Temporarily, only for evaluation."""
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)]) if self.augment:
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
else:
# transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)])
transforms = Compose([])
transforms.append( transforms.append(
Format(bbox_format='xywh', Format(bbox_format='xywh',
normalize=True, normalize=True,
@ -65,6 +99,8 @@ class RTDETRValidator(DetectionValidator):
# Do not need threshold for evaluation as only got 300 boxes here. # Do not need threshold for evaluation as only got 300 boxes here.
# idx = score > self.args.conf # idx = score > self.args.conf
pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1) # filter pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1) # filter
# sort by confidence to correctly get internal metrics.
pred = pred[score.argsort(descending=True)]
outputs[i] = pred # [idx] outputs[i] = pred # [idx]
return outputs return outputs
@ -100,7 +136,8 @@ class RTDETRValidator(DetectionValidator):
tbox[..., [0, 2]] *= shape[1] # native-space pred tbox[..., [0, 2]] *= shape[1] # native-space pred
tbox[..., [1, 3]] *= shape[0] # native-space pred tbox[..., [1, 3]] *= shape[0] # native-space pred
labelsn = torch.cat((cls, tbox), 1) # native-space labels labelsn = torch.cat((cls, tbox), 1) # native-space labels
correct_bboxes = self._process_batch(predn, labelsn) # NOTE: To get correct metrics, the inputs of `_process_batch` should always be float32 type.
correct_bboxes = self._process_batch(predn.float(), labelsn)
# TODO: maybe remove these `self.` arguments as they already are member variable # TODO: maybe remove these `self.` arguments as they already are member variable
if self.args.plots: if self.args.plots:
self.confusion_matrix.process_batch(predn, labelsn) self.confusion_matrix.process_batch(predn, labelsn)

@ -256,10 +256,8 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
return mask, False return mask, False
fill_labels = [0] + small_regions fill_labels = [0] + small_regions
if not correct_holes: if not correct_holes:
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
# If every region is below threshold, keep largest # If every region is below threshold, keep largest
if not fill_labels: fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
fill_labels = [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels) mask = np.isin(regions, fill_labels)
return mask, True return mask, True

@ -18,14 +18,12 @@ class Sam(nn.Module):
mask_threshold: float = 0.0 mask_threshold: float = 0.0
image_format: str = 'RGB' image_format: str = 'RGB'
def __init__( def __init__(self,
self,
image_encoder: ImageEncoderViT, image_encoder: ImageEncoderViT,
prompt_encoder: PromptEncoder, prompt_encoder: PromptEncoder,
mask_decoder: MaskDecoder, mask_decoder: MaskDecoder,
pixel_mean: List[float] = [123.675, 116.28, 103.53], pixel_mean: List[float] = None,
pixel_std: List[float] = [58.395, 57.12, 57.375], pixel_std: List[float] = None) -> None:
) -> None:
""" """
SAM predicts object masks from an image and input prompts. SAM predicts object masks from an image and input prompts.
@ -38,6 +36,10 @@ class Sam(nn.Module):
pixel_mean (list(float)): Mean values for normalizing pixels in the input image. pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
pixel_std (list(float)): Std values for normalizing pixels in the input image. pixel_std (list(float)): Std values for normalizing pixels in the input image.
""" """
if pixel_mean is None:
pixel_mean = [123.675, 116.28, 103.53]
if pixel_std is None:
pixel_std = [58.395, 57.12, 57.375]
super().__init__() super().__init__()
self.image_encoder = image_encoder self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder self.prompt_encoder = prompt_encoder

@ -0,0 +1,291 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.vit.utils.ops import HungarianMatcher
from ultralytics.yolo.utils.loss import FocalLoss, VarifocalLoss
from ultralytics.yolo.utils.metrics import bbox_iou
class DETRLoss(nn.Module):
def __init__(self,
nc=80,
loss_gain=None,
aux_loss=True,
use_fl=True,
use_vfl=False,
use_uni_match=False,
uni_match_ind=0):
"""
Args:
nc (int): The number of classes.
loss_gain (dict): The coefficient of loss.
aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used.
use_focal_loss (bool): Use focal loss or not.
use_vfl (bool): Use VarifocalLoss or not.
use_uni_match (bool): Whether to use a fixed layer to assign labels for auxiliary branch.
uni_match_ind (int): The fixed indices of a layer.
"""
super().__init__()
if loss_gain is None:
loss_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'no_object': 0.1, 'mask': 1, 'dice': 1}
self.nc = nc
self.matcher = HungarianMatcher(cost_gain={'class': 2, 'bbox': 5, 'giou': 2})
self.loss_gain = loss_gain
self.aux_loss = aux_loss
self.fl = FocalLoss() if use_fl else None
self.vfl = VarifocalLoss() if use_vfl else None
self.use_uni_match = use_uni_match
self.uni_match_ind = uni_match_ind
self.device = None
def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''):
# logits: [b, query, num_classes], gt_class: list[[n, 1]]
name_class = f'loss_class{postfix}'
bs, nq = pred_scores.shape[:2]
# one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
one_hot.scatter_(2, targets.unsqueeze(-1), 1)
one_hot = one_hot[..., :-1]
gt_scores = gt_scores.view(bs, nq, 1) * one_hot
if self.fl:
if num_gts and self.vfl:
loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
else:
loss_cls = self.fl(pred_scores, one_hot.float())
loss_cls /= max(num_gts, 1) / nq
else:
loss_cls = nn.BCEWithLogitsLoss(reduction='none')(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
return {name_class: loss_cls.squeeze() * self.loss_gain['class']}
def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=''):
# boxes: [b, query, 4], gt_bbox: list[[n, 4]]
name_bbox = f'loss_bbox{postfix}'
name_giou = f'loss_giou{postfix}'
loss = {}
if len(gt_bboxes) == 0:
loss[name_bbox] = torch.tensor(0., device=self.device)
loss[name_giou] = torch.tensor(0., device=self.device)
return loss
loss[name_bbox] = self.loss_gain['bbox'] * F.l1_loss(pred_bboxes, gt_bboxes, reduction='sum') / len(gt_bboxes)
loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
loss[name_giou] = self.loss_gain['giou'] * loss[name_giou]
loss = {k: v.squeeze() for k, v in loss.items()}
return loss
def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
# masks: [b, query, h, w], gt_mask: list[[n, H, W]]
name_mask = f'loss_mask{postfix}'
name_dice = f'loss_dice{postfix}'
loss = {}
if sum(len(a) for a in gt_mask) == 0:
loss[name_mask] = torch.tensor(0., device=self.device)
loss[name_dice] = torch.tensor(0., device=self.device)
return loss
num_gts = len(gt_mask)
src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
# TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
torch.tensor([num_gts], dtype=torch.float32))
loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
return loss
def _dice_loss(self, inputs, targets, num_gts):
inputs = F.sigmoid(inputs)
inputs = inputs.flatten(1)
targets = targets.flatten(1)
numerator = 2 * (inputs * targets).sum(1)
denominator = inputs.sum(-1) + targets.sum(-1)
loss = 1 - (numerator + 1) / (denominator + 1)
return loss.sum() / num_gts
def _get_loss_aux(self,
pred_bboxes,
pred_scores,
gt_bboxes,
gt_cls,
gt_groups,
match_indices=None,
postfix='',
masks=None,
gt_mask=None):
"""Get auxiliary losses"""
# NOTE: loss class, bbox, giou, mask, dice
loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
if match_indices is None and self.use_uni_match:
match_indices = self.matcher(pred_bboxes[self.uni_match_ind],
pred_scores[self.uni_match_ind],
gt_bboxes,
gt_cls,
gt_groups,
masks=masks[self.uni_match_ind] if masks is not None else None,
gt_mask=gt_mask)
for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
aux_masks = masks[i] if masks is not None else None
loss_ = self._get_loss(aux_bboxes,
aux_scores,
gt_bboxes,
gt_cls,
gt_groups,
masks=aux_masks,
gt_mask=gt_mask,
postfix=postfix,
match_indices=match_indices)
loss[0] += loss_[f'loss_class{postfix}']
loss[1] += loss_[f'loss_bbox{postfix}']
loss[2] += loss_[f'loss_giou{postfix}']
# if masks is not None and gt_mask is not None:
# loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
# loss[3] += loss_[f'loss_mask{postfix}']
# loss[4] += loss_[f'loss_dice{postfix}']
loss = {
f'loss_class_aux{postfix}': loss[0],
f'loss_bbox_aux{postfix}': loss[1],
f'loss_giou_aux{postfix}': loss[2]}
# if masks is not None and gt_mask is not None:
# loss[f'loss_mask_aux{postfix}'] = loss[3]
# loss[f'loss_dice_aux{postfix}'] = loss[4]
return loss
def _get_index(self, match_indices):
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
src_idx = torch.cat([src for (src, _) in match_indices])
dst_idx = torch.cat([dst for (_, dst) in match_indices])
return (batch_idx, src_idx), dst_idx
def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
pred_assigned = torch.cat([
t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
for t, (I, _) in zip(pred_bboxes, match_indices)])
gt_assigned = torch.cat([
t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
for t, (_, J) in zip(gt_bboxes, match_indices)])
return pred_assigned, gt_assigned
def _get_loss(self,
pred_bboxes,
pred_scores,
gt_bboxes,
gt_cls,
gt_groups,
masks=None,
gt_mask=None,
postfix='',
match_indices=None):
"""Get losses"""
if match_indices is None:
match_indices = self.matcher(pred_bboxes,
pred_scores,
gt_bboxes,
gt_cls,
gt_groups,
masks=masks,
gt_mask=gt_mask)
idx, gt_idx = self._get_index(match_indices)
pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
bs, nq = pred_scores.shape[:2]
targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
targets[idx] = gt_cls[gt_idx]
gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
if len(gt_bboxes):
gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)
loss = {}
loss.update(self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix))
loss.update(self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix))
# if masks is not None and gt_mask is not None:
# loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix))
return loss
def forward(self, pred_bboxes, pred_scores, batch, postfix='', **kwargs):
"""
Args:
pred_bboxes (torch.Tensor): [l, b, query, 4]
pred_scores (torch.Tensor): [l, b, query, num_classes]
batch (dict): A dict includes:
gt_cls (torch.Tensor) with shape [num_gts, ],
gt_bboxes (torch.Tensor): [num_gts, 4],
gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
postfix (str): postfix of loss name.
"""
self.device = pred_bboxes.device
match_indices = kwargs.get('match_indices', None)
gt_cls, gt_bboxes, gt_groups = batch['cls'], batch['bboxes'], batch['gt_groups']
total_loss = self._get_loss(pred_bboxes[-1],
pred_scores[-1],
gt_bboxes,
gt_cls,
gt_groups,
postfix=postfix,
match_indices=match_indices)
if self.aux_loss:
total_loss.update(
self._get_loss_aux(pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices,
postfix))
return total_loss
class RTDETRDetectionLoss(DETRLoss):
def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
pred_bboxes, pred_scores = preds
total_loss = super().forward(pred_bboxes, pred_scores, batch)
if dn_meta is not None:
dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group']
assert len(batch['gt_groups']) == len(dn_pos_idx)
# denoising match indices
match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups'])
# compute denoising training loss
dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices)
total_loss.update(dn_loss)
else:
total_loss.update({f'{k}_dn': torch.tensor(0., device=self.device) for k in total_loss.keys()})
return total_loss
@staticmethod
def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
"""Get the match indices for denoising.
Args:
dn_pos_idx (List[torch.Tensor]): A list includes positive indices of denoising.
dn_num_group (int): The number of groups of denoising.
gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
Returns:
dn_match_indices (List(tuple)): Matched indices.
"""
dn_match_indices = []
idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
for i, num_gt in enumerate(gt_groups):
if num_gt > 0:
gt_idx = torch.arange(end=num_gt, dtype=torch.int32) + idx_groups[i]
gt_idx = gt_idx.repeat(dn_num_group)
assert len(dn_pos_idx[i]) == len(gt_idx), 'Expected the same length, '
f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.'
dn_match_indices.append((dn_pos_idx[i], gt_idx))
else:
dn_match_indices.append((torch.zeros([0], dtype=torch.int32), torch.zeros([0], dtype=torch.int32)))
return dn_match_indices

@ -0,0 +1,230 @@
# TODO: license
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from ultralytics.yolo.utils.metrics import bbox_iou
from ultralytics.yolo.utils.ops import xywh2xyxy, xyxy2xywh
class HungarianMatcher(nn.Module):
def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
"""
Args:
matcher_coeff (dict): The coefficient of hungarian matcher cost.
"""
super().__init__()
if cost_gain is None:
cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1}
self.cost_gain = cost_gain
self.use_fl = use_fl
self.with_mask = with_mask
self.num_sample_points = num_sample_points
self.alpha = alpha
self.gamma = gamma
def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
"""
Args:
pred_bboxes (Tensor): [b, query, 4]
pred_scores (Tensor): [b, query, num_classes]
gt_cls (torch.Tensor) with shape [num_gts, ]
gt_bboxes (torch.Tensor): [num_gts, 4]
gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
masks (Tensor|None): [b, query, h, w]
gt_mask (List(Tensor)): list[[n, H, W]]
Returns:
A list of size batch_size, containing tuples of (index_i, index_j) where:
- index_i is the indices of the selected predictions (in order)
- index_j is the indices of the corresponding selected targets (in order)
For each batch element, it holds:
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
"""
bs, nq, nc = pred_scores.shape
if sum(gt_groups) == 0:
return [(torch.tensor([], dtype=torch.int32), torch.tensor([], dtype=torch.int32)) for _ in range(bs)]
# We flatten to compute the cost matrices in a batch
# [batch_size * num_queries, num_classes]
pred_scores = pred_scores.detach().view(-1, nc)
pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
# [batch_size * num_queries, 4]
pred_bboxes = pred_bboxes.detach().view(-1, 4)
# Compute the classification cost
pred_scores = pred_scores[:, gt_cls]
if self.use_fl:
neg_cost_class = (1 - self.alpha) * (pred_scores ** self.gamma) * (-(1 - pred_scores + 1e-8).log())
pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
cost_class = pos_cost_class - neg_cost_class
else:
cost_class = -pred_scores
# Compute the L1 cost between boxes
cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt)
# Compute the GIoU cost between boxes, (bs*num_queries, num_gt)
cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
# Final cost matrix
C = self.cost_gain['class'] * cost_class + \
self.cost_gain['bbox'] * cost_bbox + \
self.cost_gain['giou'] * cost_giou
# Compute the mask cost and dice cost
if self.with_mask:
C += self._cost_mask(bs, gt_groups, masks, gt_mask)
C = C.view(bs, nq, -1).cpu()
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
# (idx for queries, idx for gt)
return [(torch.tensor(i, dtype=torch.int32), torch.tensor(j, dtype=torch.int32) + gt_groups[k])
for k, (i, j) in enumerate(indices)]
def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
# all masks share the same set of points for efficient matching
sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
sample_points = 2.0 * sample_points - 1.0
out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
out_mask = out_mask.flatten(0, 1)
tgt_mask = torch.cat(gt_mask).unsqueeze(1)
sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
with torch.cuda.amp.autocast(False):
# binary cross entropy cost
pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
cost_mask /= self.num_sample_points
# dice cost
out_mask = F.sigmoid(out_mask)
numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
cost_dice = 1 - (numerator + 1) / (denominator + 1)
C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
return C
def get_cdn_group(batch,
num_classes,
num_queries,
class_embed,
num_dn=100,
cls_noise_ratio=0.5,
box_noise_scale=1.0,
training=False):
"""Get contrastive denoising training group
Args:
batch (dict): A dict includes:
gt_cls (torch.Tensor) with shape [num_gts, ],
gt_bboxes (torch.Tensor): [num_gts, 4],
gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
num_classes (int): Number of classes.
num_queries (int): Number of queries.
class_embed (torch.Tensor): Embedding weights to map cls to embedding space.
num_dn (int): Number of denoising.
cls_noise_ratio (float): Noise ratio for class.
box_noise_scale (float): Noise scale for bbox.
training (bool): If it's training or not.
Returns:
"""
if (not training) or num_dn <= 0:
return None, None, None, None
gt_groups = batch['gt_groups']
total_num = sum(gt_groups)
max_nums = max(gt_groups)
if max_nums == 0:
return None, None, None, None
num_group = num_dn // max_nums
num_group = 1 if num_group == 0 else num_group
# pad gt to max_num of a batch
bs = len(gt_groups)
gt_cls = batch['cls'] # (bs*num, )
gt_bbox = batch['bboxes'] # bs*num, 4
b_idx = batch['batch_idx']
# each group has positive and negative queries.
dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )
dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4
dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, )
# positive and negative mask
# (bs*num*num_group, ), the second total_num*num_group part as negative samples
neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
if cls_noise_ratio > 0:
# half of bbox prob
mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
idx = torch.nonzero(mask).squeeze(-1)
# randomly put a new one here
new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
dn_cls[idx] = new_label
if box_noise_scale > 0:
known_bbox = xywh2xyxy(dn_bbox)
diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale # 2*num_group*bs*num, 4
rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
rand_part = torch.rand_like(dn_bbox)
rand_part[neg_idx] += 1.0
rand_part *= rand_sign
known_bbox += rand_part * diff
known_bbox.clip_(min=0.0, max=1.0)
dn_bbox = xyxy2xywh(known_bbox)
dn_bbox = inverse_sigmoid(dn_bbox)
# total denoising queries
num_dn = int(max_nums * 2 * num_group)
# class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256
padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)
map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)
map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
padding_bbox[(dn_b_idx, map_indices)] = dn_bbox
tgt_size = num_dn + num_queries
attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
# match query cannot see the reconstruct
attn_mask[num_dn:, :num_dn] = True
# reconstruct cannot see each other
for i in range(num_group):
if i == 0:
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
if i == num_group - 1:
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * i * 2] = True
else:
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * 2 * i] = True
dn_meta = {
'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split([n for n in gt_groups], dim=1)],
'dn_num_group': num_group,
'dn_num_split': [num_dn, num_queries]}
return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to(
class_embed.device), dn_meta
def inverse_sigmoid(x, eps=1e-6):
x = x.clip(min=0., max=1.)
return torch.log(x / (1 - x + eps) + eps)

@ -759,7 +759,7 @@ class Format:
return masks, instances, cls return masks, instances, cls
def v8_transforms(dataset, imgsz, hyp): def v8_transforms(dataset, imgsz, hyp, stretch=False):
"""Convert images to a size suitable for YOLOv8 training.""" """Convert images to a size suitable for YOLOv8 training."""
pre_transform = Compose([ pre_transform = Compose([
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic), Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
@ -770,7 +770,7 @@ def v8_transforms(dataset, imgsz, hyp):
scale=hyp.scale, scale=hyp.scale,
shear=hyp.shear, shear=hyp.shear,
perspective=hyp.perspective, perspective=hyp.perspective,
pre_transform=LetterBox(new_shape=(imgsz, imgsz)), pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
)]) )])
flip_idx = dataset.data.get('flip_idx', None) # for keypoints augmentation flip_idx = dataset.data.get('flip_idx', None) # for keypoints augmentation
if dataset.use_keypoints: if dataset.use_keypoints:

@ -278,7 +278,8 @@ class BaseTrainer:
self.epoch_time_start = time.time() self.epoch_time_start = time.time()
self.train_time_start = time.time() self.train_time_start = time.time()
nb = len(self.train_loader) # number of batches nb = len(self.train_loader) # number of batches
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations nw = max(round(self.args.warmup_epochs *
nb), 100) if self.args.warmup_epochs > 0 else -1 # number of warmup iterations
last_opt_step = -1 last_opt_step = -1
self.run_callbacks('on_train_start') self.run_callbacks('on_train_start')
LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n' LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'

@ -24,10 +24,34 @@ class VarifocalLoss(nn.Module):
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') * loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
weight).sum() weight).mean(1).sum()
return loss return loss
# Losses
class FocalLoss(nn.Module):
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
def __init__(self, ):
super().__init__()
def forward(self, pred, label, gamma=1.5, alpha=0.25):
"""Calculates and updates confusion matrix for object detection/classification tasks."""
loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
# p_t = torch.exp(-loss)
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
# TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
pred_prob = pred.sigmoid() # prob from logits
p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
modulating_factor = (1.0 - p_t) ** gamma
loss *= modulating_factor
if alpha > 0:
alpha_factor = label * alpha + (1 - label) * (1 - alpha)
loss *= alpha_factor
return loss.mean(1).sum()
class BboxLoss(nn.Module): class BboxLoss(nn.Module):
def __init__(self, reg_max, use_dfl=False): def __init__(self, reg_max, use_dfl=False):

@ -9,7 +9,6 @@ from pathlib import Path
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
from ultralytics.yolo.utils import LOGGER, SimpleClass, TryExcept, plt_settings from ultralytics.yolo.utils import LOGGER, SimpleClass, TryExcept, plt_settings
@ -175,40 +174,6 @@ def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#iss
return 1.0 - 0.5 * eps, 0.5 * eps return 1.0 - 0.5 * eps, 0.5 * eps
# Losses
class FocalLoss(nn.Module):
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
"""Initialize FocalLoss object with given loss function and hyperparameters."""
super().__init__()
self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
self.gamma = gamma
self.alpha = alpha
self.reduction = loss_fcn.reduction
self.loss_fcn.reduction = 'none' # required to apply FL to each element
def forward(self, pred, true):
"""Calculates and updates confusion matrix for object detection/classification tasks."""
loss = self.loss_fcn(pred, true)
# p_t = torch.exp(-loss)
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
# TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
pred_prob = torch.sigmoid(pred) # prob from logits
p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
modulating_factor = (1.0 - p_t) ** self.gamma
loss *= alpha_factor * modulating_factor
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else: # 'None'
return loss
class ConfusionMatrix: class ConfusionMatrix:
""" """
A class for calculating and updating a confusion matrix for object detection and classification tasks. A class for calculating and updating a confusion matrix for object detection and classification tasks.

@ -327,6 +327,9 @@ def init_seeds(seed=0, deterministic=False):
os.environ['PYTHONHASHSEED'] = str(seed) os.environ['PYTHONHASHSEED'] = str(seed)
else: else:
LOGGER.warning('WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.') LOGGER.warning('WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.')
else:
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = False
class ModelEMA: class ModelEMA:

Loading…
Cancel
Save