From a0ba8ef5f0d3e8bcb920afbf443275a8288c486d Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Sat, 17 Jun 2023 17:16:18 +0530 Subject: [PATCH] Add RTDETR Trainer (#2745) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> --- docs/models/rtdetr.md | 5 +- docs/reference/yolo/utils/loss.md | 5 + docs/reference/yolo/utils/metrics.md | 5 - tests/test_python.py | 7 +- .../rt-detr/{rt-detr-l.yaml => rtdetr-l.yaml} | 0 .../rt-detr/{rt-detr-x.yaml => rtdetr-x.yaml} | 0 ultralytics/models/v8/yolov8-rtdetr.yaml | 46 +++ ultralytics/nn/modules/head.py | 271 +++++++--------- ultralytics/nn/modules/transformer.py | 121 ++++---- ultralytics/nn/tasks.py | 43 +-- ultralytics/vit/__init__.py | 2 +- ultralytics/vit/rtdetr/model.py | 49 ++- ultralytics/vit/rtdetr/train.py | 78 +++++ ultralytics/vit/rtdetr/val.py | 43 ++- ultralytics/vit/sam/amg.py | 4 +- ultralytics/vit/sam/modules/sam.py | 18 +- ultralytics/vit/utils/loss.py | 291 ++++++++++++++++++ ultralytics/vit/utils/ops.py | 230 ++++++++++++++ ultralytics/yolo/data/augment.py | 4 +- ultralytics/yolo/engine/trainer.py | 3 +- ultralytics/yolo/utils/loss.py | 26 +- ultralytics/yolo/utils/metrics.py | 35 --- ultralytics/yolo/utils/torch_utils.py | 3 + 23 files changed, 982 insertions(+), 307 deletions(-) rename ultralytics/models/rt-detr/{rt-detr-l.yaml => rtdetr-l.yaml} (100%) rename ultralytics/models/rt-detr/{rt-detr-x.yaml => rtdetr-x.yaml} (100%) create mode 100644 ultralytics/models/v8/yolov8-rtdetr.yaml create mode 100644 ultralytics/vit/rtdetr/train.py create mode 100644 ultralytics/vit/utils/loss.py create mode 100644 ultralytics/vit/utils/ops.py diff --git a/docs/models/rtdetr.md b/docs/models/rtdetr.md index 61d156a..f2c6516 100644 --- a/docs/models/rtdetr.md +++ b/docs/models/rtdetr.md @@ -35,6 +35,7 @@ from ultralytics import RTDETR model = RTDETR("rtdetr-l.pt") model.info() # display model information +model.train(data="coco8.yaml") # train model.predict("path/to/image.jpg") # predict ``` @@ -51,7 +52,7 @@ model.predict("path/to/image.jpg") # predict |------------|--------------------| | Inference | :heavy_check_mark: | | Validation | :heavy_check_mark: | -| Training | :x: (Coming soon) | +| Training | :heavy_check_mark: | # Citations and Acknowledgements @@ -70,4 +71,4 @@ If you use Baidu's RT-DETR in your research or development work, please cite the We would like to acknowledge Baidu and the [PaddlePaddle](https://github.com/PaddlePaddle/PaddleDetection) team for creating and maintaining this valuable resource for the computer vision community. Their contribution to the field with the development of the Vision Transformers-based real-time object detector, RT-DETR, is greatly appreciated. 
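With Training now marked as supported in the table above, the new path can be smoke-tested end to end. The sketch below is minimal and illustrative: the `epochs` and `imgsz` values are assumptions rather than values mandated by this patch, and `rtdetr-l.pt` is assumed to be available locally or via the usual Ultralytics asset download.

```python
from ultralytics import RTDETR

# Load the pretrained RT-DETR-l checkpoint, as in the usage example above
model = RTDETR("rtdetr-l.pt")

# Train on the small coco8 dataset; epochs/imgsz here are illustrative values
model.train(data="coco8.yaml", epochs=5, imgsz=640)

# Validate and run inference with the trained model
model.val(data="coco8.yaml")
model.predict("path/to/image.jpg")
```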
-*Keywords: RT-DETR, Transformer, ViT, Vision Transformers, Baidu RT-DETR, PaddlePaddle, Paddle Paddle RT-DETR, real-time object detection, Vision Transformers-based object detection, pre-trained PaddlePaddle RT-DETR models, Baidu's RT-DETR usage, Ultralytics Python API* \ No newline at end of file +*Keywords: RT-DETR, Transformer, ViT, Vision Transformers, Baidu RT-DETR, PaddlePaddle, Paddle Paddle RT-DETR, real-time object detection, Vision Transformers-based object detection, pre-trained PaddlePaddle RT-DETR models, Baidu's RT-DETR usage, Ultralytics Python API* diff --git a/docs/reference/yolo/utils/loss.md b/docs/reference/yolo/utils/loss.md index 38c4f0a..23cba8f 100644 --- a/docs/reference/yolo/utils/loss.md +++ b/docs/reference/yolo/utils/loss.md @@ -8,6 +8,11 @@ keywords: Ultralytics, YOLO, loss functions, object detection, keypoint detectio :::ultralytics.yolo.utils.loss.VarifocalLoss

+# FocalLoss
+---
+:::ultralytics.yolo.utils.loss.FocalLoss
+

+ # BboxLoss --- :::ultralytics.yolo.utils.loss.BboxLoss diff --git a/docs/reference/yolo/utils/metrics.md b/docs/reference/yolo/utils/metrics.md index 7f14803..c230c2c 100644 --- a/docs/reference/yolo/utils/metrics.md +++ b/docs/reference/yolo/utils/metrics.md @@ -3,11 +3,6 @@ description: Explore Ultralytics YOLO's FocalLoss, DetMetrics, PoseMetrics, Clas keywords: YOLOv5, metrics, losses, confusion matrix, detection metrics, pose metrics, classification metrics, intersection over area, intersection over union, keypoint intersection over union, average precision, per class average precision, Ultralytics Docs --- -# FocalLoss ---- -:::ultralytics.yolo.utils.metrics.FocalLoss -

- # ConfusionMatrix --- :::ultralytics.yolo.utils.metrics.ConfusionMatrix diff --git a/tests/test_python.py b/tests/test_python.py index 53e2789..f633bd6 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -7,7 +7,7 @@ import numpy as np import torch from PIL import Image -from ultralytics import YOLO +from ultralytics import RTDETR, YOLO from ultralytics.yolo.data.build import load_inference_source from ultralytics.yolo.utils import LINUX, ONLINE, ROOT, SETTINGS @@ -174,7 +174,10 @@ def test_export_paddle(enabled=False): def test_all_model_yamls(): for m in list((ROOT / 'models').rglob('yolo*.yaml')): - YOLO(m.name) + if m.name == 'yolov8-rtdetr.yaml': # except the rtdetr model + RTDETR(m.name) + else: + YOLO(m.name) def test_workflow(): diff --git a/ultralytics/models/rt-detr/rt-detr-l.yaml b/ultralytics/models/rt-detr/rtdetr-l.yaml similarity index 100% rename from ultralytics/models/rt-detr/rt-detr-l.yaml rename to ultralytics/models/rt-detr/rtdetr-l.yaml diff --git a/ultralytics/models/rt-detr/rt-detr-x.yaml b/ultralytics/models/rt-detr/rtdetr-x.yaml similarity index 100% rename from ultralytics/models/rt-detr/rt-detr-x.yaml rename to ultralytics/models/rt-detr/rtdetr-x.yaml diff --git a/ultralytics/models/v8/yolov8-rtdetr.yaml b/ultralytics/models/v8/yolov8-rtdetr.yaml new file mode 100644 index 0000000..a058106 --- /dev/null +++ b/ultralytics/models/v8/yolov8-rtdetr.yaml @@ -0,0 +1,46 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, 'nearest']] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, 'nearest']] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index f7105bf..8b55fd0 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -163,220 +163,187 @@ class 
RTDETRDecoder(nn.Module): self, nc=80, ch=(512, 1024, 2048), - hidden_dim=256, - num_queries=300, - strides=(8, 16, 32), # TODO - nl=3, - num_decoder_points=4, - nhead=8, - num_decoder_layers=6, - dim_feedforward=1024, + hd=256, # hidden dim + nq=300, # num queries + ndp=4, # num decoder points + nh=8, # num head + ndl=6, # num decoder layers + d_ffn=1024, # dim of feedforward dropout=0., act=nn.ReLU(), eval_idx=-1, # training args - num_denoising=100, + nd=100, # num denoising label_noise_ratio=0.5, box_noise_scale=1.0, learnt_init_query=False): super().__init__() - assert len(ch) <= nl - assert len(strides) == len(ch) - for _ in range(nl - len(strides)): - strides.append(strides[-1] * 2) - - self.hidden_dim = hidden_dim - self.nhead = nhead - self.feat_strides = strides - self.nl = nl + self.hidden_dim = hd + self.nhead = nh + self.nl = len(ch) # num level self.nc = nc - self.num_queries = num_queries - self.num_decoder_layers = num_decoder_layers + self.num_queries = nq + self.num_decoder_layers = ndl # backbone feature projection - self._build_input_proj_layer(ch) + self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch) + # NOTE: simplified version but it's not consistent with .pt weights. + # self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch) # Transformer module - decoder_layer = DeformableTransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, dropout, act, nl, - num_decoder_points) - self.decoder = DeformableTransformerDecoder(hidden_dim, decoder_layer, num_decoder_layers, eval_idx) + decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp) + self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx) # denoising part - self.denoising_class_embed = nn.Embedding(nc, hidden_dim) - self.num_denoising = num_denoising + self.denoising_class_embed = nn.Embedding(nc, hd) + self.num_denoising = nd self.label_noise_ratio = label_noise_ratio self.box_noise_scale = box_noise_scale # decoder embedding self.learnt_init_query = learnt_init_query if learnt_init_query: - self.tgt_embed = nn.Embedding(num_queries, hidden_dim) - self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2) + self.tgt_embed = nn.Embedding(nq, hd) + self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2) # encoder head - self.enc_output = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim)) - self.enc_score_head = nn.Linear(hidden_dim, nc) - self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3) + self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd)) + self.enc_score_head = nn.Linear(hd, nc) + self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3) # decoder head - self.dec_score_head = nn.ModuleList([nn.Linear(hidden_dim, nc) for _ in range(num_decoder_layers)]) - self.dec_bbox_head = nn.ModuleList([ - MLP(hidden_dim, hidden_dim, 4, num_layers=3) for _ in range(num_decoder_layers)]) + self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)]) + self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)]) self._reset_parameters() - def forward(self, feats, gt_meta=None): + def forward(self, x, batch=None): + from ultralytics.vit.utils.ops import get_cdn_group + # input projection and embedding - memory, spatial_shapes, _ = self._get_encoder_input(feats) + feats, shapes = self._get_encoder_input(x) # prepare denoising training - if self.training: - raise NotImplementedError - # denoising_class, 
denoising_bbox_unact, attn_mask, dn_meta = \ - # get_contrastive_denoising_training_group(gt_meta, - # self.num_classes, - # self.num_queries, - # self.denoising_class_embed.weight, - # self.num_denoising, - # self.label_noise_ratio, - # self.box_noise_scale) - else: - denoising_class, denoising_bbox_unact, attn_mask = None, None, None - - target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \ - self._get_decoder_input(memory, spatial_shapes, denoising_class, denoising_bbox_unact) + dn_embed, dn_bbox, attn_mask, dn_meta = \ + get_cdn_group(batch, + self.nc, + self.num_queries, + self.denoising_class_embed.weight, + self.num_denoising, + self.label_noise_ratio, + self.box_noise_scale, + self.training) + + embed, refer_bbox, enc_bboxes, enc_scores = \ + self._get_decoder_input(feats, shapes, dn_embed, dn_bbox) # decoder - out_bboxes, out_logits = self.decoder(target, - init_ref_points_unact, - memory, - spatial_shapes, + dec_bboxes, dec_scores = self.decoder(embed, + refer_bbox, + feats, + shapes, self.dec_bbox_head, self.dec_score_head, self.query_pos_head, attn_mask=attn_mask) if not self.training: - out_logits = out_logits.sigmoid_() - return out_bboxes, out_logits # enc_topk_bboxes, enc_topk_logits, dn_meta - - def _reset_parameters(self): - # class and bbox head init - bias_cls = bias_init_with_prob(0.01) - linear_init_(self.enc_score_head) - constant_(self.enc_score_head.bias, bias_cls) - constant_(self.enc_bbox_head.layers[-1].weight, 0.) - constant_(self.enc_bbox_head.layers[-1].bias, 0.) - for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head): - linear_init_(cls_) - constant_(cls_.bias, bias_cls) - constant_(reg_.layers[-1].weight, 0.) - constant_(reg_.layers[-1].bias, 0.) + dec_scores = dec_scores.sigmoid_() + return dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta - linear_init_(self.enc_output[0]) - xavier_uniform_(self.enc_output[0].weight) - if self.learnt_init_query: - xavier_uniform_(self.tgt_embed.weight) - xavier_uniform_(self.query_pos_head.layers[0].weight) - xavier_uniform_(self.query_pos_head.layers[1].weight) - for layer in self.input_proj: - xavier_uniform_(layer[0].weight) - - def _build_input_proj_layer(self, ch): - self.input_proj = nn.ModuleList() - for in_channels in ch: - self.input_proj.append( - nn.Sequential(nn.Conv2d(in_channels, self.hidden_dim, kernel_size=1, bias=False), - nn.BatchNorm2d(self.hidden_dim))) - in_channels = ch[-1] - for _ in range(self.nl - len(ch)): - self.input_proj.append( - nn.Sequential(nn.Conv2D(in_channels, self.hidden_dim, kernel_size=3, stride=2, padding=1, bias=False), - nn.BatchNorm2d(self.hidden_dim))) - in_channels = self.hidden_dim - - def _generate_anchors(self, spatial_shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2): + def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2): anchors = [] - for lvl, (h, w) in enumerate(spatial_shapes): - grid_y, grid_x = torch.meshgrid(torch.arange(end=h, dtype=torch.float32), - torch.arange(end=w, dtype=torch.float32), + for i, (h, w) in enumerate(shapes): + grid_y, grid_x = torch.meshgrid(torch.arange(end=h, dtype=dtype, device=device), + torch.arange(end=w, dtype=dtype, device=device), indexing='ij') - grid_xy = torch.stack([grid_x, grid_y], -1) + grid_xy = torch.stack([grid_x, grid_y], -1) # (h, w, 2) - valid_WH = torch.tensor([h, w]).to(torch.float32) - grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH - wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl) - 
anchors.append(torch.concat([grid_xy, wh], -1).reshape([-1, h * w, 4])) + valid_WH = torch.tensor([h, w], dtype=dtype, device=device) + grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH # (1, h, w, 2) + wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0 ** i) + anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4)) # (1, h*w, 4) - anchors = torch.concat(anchors, 1) - valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) + anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4) + valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1 anchors = torch.log(anchors / (1 - anchors)) anchors = torch.where(valid_mask, anchors, torch.inf) - return anchors.to(device=device, dtype=dtype), valid_mask.to(device=device) + return anchors, valid_mask - def _get_encoder_input(self, feats): + def _get_encoder_input(self, x): # get projection features - proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] - if self.nl > len(proj_feats): - len_srcs = len(proj_feats) - for i in range(len_srcs, self.nl): - if i == len_srcs: - proj_feats.append(self.input_proj[i](feats[-1])) - else: - proj_feats.append(self.input_proj[i](proj_feats[-1])) - + x = [self.input_proj[i](feat) for i, feat in enumerate(x)] # get encoder inputs - feat_flatten = [] - spatial_shapes = [] - level_start_index = [0] - for feat in proj_feats: - _, _, h, w = feat.shape + feats = [] + shapes = [] + for feat in x: + h, w = feat.shape[2:] # [b, c, h, w] -> [b, h*w, c] - feat_flatten.append(feat.flatten(2).permute(0, 2, 1)) + feats.append(feat.flatten(2).permute(0, 2, 1)) # [nl, 2] - spatial_shapes.append([h, w]) - # [l], start index of each level - level_start_index.append(h * w + level_start_index[-1]) + shapes.append([h, w]) - # [b, l, c] - feat_flatten = torch.concat(feat_flatten, 1) - level_start_index.pop() - return feat_flatten, spatial_shapes, level_start_index + # [b, h*w, c] + feats = torch.cat(feats, 1) + return feats, shapes - def _get_decoder_input(self, memory, spatial_shapes, denoising_class=None, denoising_bbox_unact=None): - bs, _, _ = memory.shape + def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None): + bs = len(feats) # prepare input for decoder - anchors, valid_mask = self._generate_anchors(spatial_shapes, dtype=memory.dtype, device=memory.device) - memory = torch.where(valid_mask, memory, 0) - output_memory = self.enc_output(memory) + anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device) + features = self.enc_output(torch.where(valid_mask, feats, 0)) # bs, h*w, 256 - enc_outputs_class = self.enc_score_head(output_memory) # (bs, h*w, nc) - enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors # (bs, h*w, 4) + enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc) + # dynamic anchors + static content + enc_outputs_bboxes = self.enc_bbox_head(features) + anchors # (bs, h*w, 4) - # (bs, topk) - _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1) - # extract region proposal boxes - # (bs, topk_ind) + # query selection + # (bs, num_queries) + topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1) + # (bs, num_queries) batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1) - topk_ind = topk_ind.view(-1) # Unsigmoided - reference_points_unact = enc_outputs_coord_unact[batch_ind, topk_ind].view(bs, self.num_queries, -1) + 
refer_bbox = enc_outputs_bboxes[batch_ind, topk_ind].view(bs, self.num_queries, -1) + # refer_bbox = torch.gather(enc_outputs_bboxes, 1, topk_ind.reshape(bs, self.num_queries).unsqueeze(-1).repeat(1, 1, 4)) - enc_topk_bboxes = torch.sigmoid(reference_points_unact) - if denoising_bbox_unact is not None: - reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1) + enc_bboxes = refer_bbox.sigmoid() + if dn_bbox is not None: + refer_bbox = torch.cat([dn_bbox, refer_bbox], 1) if self.training: - reference_points_unact = reference_points_unact.detach() - enc_topk_logits = enc_outputs_class[batch_ind, topk_ind].view(bs, self.num_queries, -1) + refer_bbox = refer_bbox.detach() + enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1) - # extract region features if self.learnt_init_query: - target = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) + embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) else: - target = output_memory[batch_ind, topk_ind].view(bs, self.num_queries, -1) + embeddings = features[batch_ind, topk_ind].view(bs, self.num_queries, -1) if self.training: - target = target.detach() - if denoising_class is not None: - target = torch.concat([denoising_class, target], 1) + embeddings = embeddings.detach() + if dn_embed is not None: + embeddings = torch.cat([dn_embed, embeddings], 1) + + return embeddings, refer_bbox, enc_bboxes, enc_scores - return target, reference_points_unact, enc_topk_bboxes, enc_topk_logits + # TODO + def _reset_parameters(self): + # class and bbox head init + bias_cls = bias_init_with_prob(0.01) / 80 * self.nc + # NOTE: the weight initialization in `linear_init_` would cause NaN when training with custom datasets. + # linear_init_(self.enc_score_head) + constant_(self.enc_score_head.bias, bias_cls) + constant_(self.enc_bbox_head.layers[-1].weight, 0.) + constant_(self.enc_bbox_head.layers[-1].bias, 0.) + for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head): + # linear_init_(cls_) + constant_(cls_.bias, bias_cls) + constant_(reg_.layers[-1].weight, 0.) + constant_(reg_.layers[-1].bias, 0.) + + linear_init_(self.enc_output[0]) + xavier_uniform_(self.enc_output[0].weight) + if self.learnt_init_query: + xavier_uniform_(self.tgt_embed.weight) + xavier_uniform_(self.query_pos_head.layers[0].weight) + xavier_uniform_(self.query_pos_head.layers[1].weight) + for layer in self.input_proj: + xavier_uniform_(layer[0].weight) diff --git a/ultralytics/nn/modules/transformer.py b/ultralytics/nn/modules/transformer.py index 4745905..89a7d91 100644 --- a/ultralytics/nn/modules/transformer.py +++ b/ultralytics/nn/modules/transformer.py @@ -229,23 +229,23 @@ class MSDeformAttn(nn.Module): xavier_uniform_(self.output_proj.weight.data) constant_(self.output_proj.bias.data, 0.) 
- def forward(self, query, reference_points, value, value_spatial_shapes, value_mask=None): + def forward(self, query, refer_bbox, value, value_shapes, value_mask=None): """ https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py Args: - query (Tensor): [bs, query_length, C] - reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0), + query (torch.Tensor): [bs, query_length, C] + refer_bbox (torch.Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area - value (Tensor): [bs, value_length, C] - value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + value (torch.Tensor): [bs, value_length, C] + value_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements Returns: output (Tensor): [bs, Length_{query}, C] """ bs, len_q = query.shape[:2] - _, len_v = value.shape[:2] - assert sum(s[0] * s[1] for s in value_spatial_shapes) == len_v + len_v = value.shape[1] + assert sum(s[0] * s[1] for s in value_shapes) == len_v value = self.value_proj(value) if value_mask is not None: @@ -255,18 +255,17 @@ class MSDeformAttn(nn.Module): attention_weights = self.attention_weights(query).view(bs, len_q, self.n_heads, self.n_levels * self.n_points) attention_weights = F.softmax(attention_weights, -1).view(bs, len_q, self.n_heads, self.n_levels, self.n_points) # N, Len_q, n_heads, n_levels, n_points, 2 - n = reference_points.shape[-1] - if n == 2: - offset_normalizer = torch.as_tensor(value_spatial_shapes, dtype=query.dtype, device=query.device).flip(-1) + num_points = refer_bbox.shape[-1] + if num_points == 2: + offset_normalizer = torch.as_tensor(value_shapes, dtype=query.dtype, device=query.device).flip(-1) add = sampling_offsets / offset_normalizer[None, None, None, :, None, :] - sampling_locations = reference_points[:, :, None, :, None, :] + add - - elif n == 4: - add = sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 - sampling_locations = reference_points[:, :, None, :, None, :2] + add + sampling_locations = refer_bbox[:, :, None, :, None, :] + add + elif num_points == 4: + add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5 + sampling_locations = refer_bbox[:, :, None, :, None, :2] + add else: - raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {n}.') - output = multi_scale_deformable_attn_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights) + raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {num_points}.') + output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights) output = self.output_proj(output) return output @@ -308,33 +307,24 @@ class DeformableTransformerDecoderLayer(nn.Module): tgt = self.norm3(tgt) return tgt - def forward(self, - tgt, - reference_points, - src, - src_spatial_shapes, - src_padding_mask=None, - attn_mask=None, - query_pos=None): + def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None): # self attention - q = k = self.with_pos_embed(tgt, query_pos) - if attn_mask is not None: - attn_mask = torch.where(attn_mask.astype('bool'), torch.zeros(attn_mask.shape, tgt.dtype), - torch.full(attn_mask.shape, float('-inf'), tgt.dtype)) - tgt2 = 
self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1) - tgt = tgt + self.dropout1(tgt2) - tgt = self.norm1(tgt) + q = k = self.with_pos_embed(embed, query_pos) + tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), + attn_mask=attn_mask)[0].transpose(0, 1) + embed = embed + self.dropout1(tgt) + embed = self.norm1(embed) # cross attention - tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos), reference_points, src, src_spatial_shapes, - src_padding_mask) - tgt = tgt + self.dropout2(tgt2) - tgt = self.norm2(tgt) + tgt = self.cross_attn(self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, + padding_mask) + embed = embed + self.dropout2(tgt) + embed = self.norm2(embed) # ffn - tgt = self.forward_ffn(tgt) + embed = self.forward_ffn(embed) - return tgt + return embed class DeformableTransformerDecoder(nn.Module): @@ -349,41 +339,40 @@ class DeformableTransformerDecoder(nn.Module): self.hidden_dim = hidden_dim self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx - def forward(self, - tgt, - reference_points, - src, - src_spatial_shapes, - bbox_head, - score_head, - query_pos_head, - attn_mask=None, - src_padding_mask=None): - output = tgt - dec_out_bboxes = [] - dec_out_logits = [] - ref_points = None - ref_points_detach = torch.sigmoid(reference_points) + def forward( + self, + embed, # decoder embeddings + refer_bbox, # anchor + feats, # image features + shapes, # feature shapes + bbox_head, + score_head, + pos_mlp, + attn_mask=None, + padding_mask=None): + output = embed + dec_bboxes = [] + dec_cls = [] + last_refined_bbox = None + refer_bbox = refer_bbox.sigmoid() for i, layer in enumerate(self.layers): - ref_points_input = ref_points_detach.unsqueeze(2) - query_pos_embed = query_pos_head(ref_points_detach) - output = layer(output, ref_points_input, src, src_spatial_shapes, src_padding_mask, attn_mask, - query_pos_embed) + output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox)) - inter_ref_bbox = torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach)) + # refine bboxes, (bs, num_queries+num_denoising, 4) + refined_bbox = torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(refer_bbox)) if self.training: - dec_out_logits.append(score_head[i](output)) + dec_cls.append(score_head[i](output)) if i == 0: - dec_out_bboxes.append(inter_ref_bbox) + dec_bboxes.append(refined_bbox) else: - dec_out_bboxes.append(torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points))) + dec_bboxes.append(torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(last_refined_bbox))) elif i == self.eval_idx: - dec_out_logits.append(score_head[i](output)) - dec_out_bboxes.append(inter_ref_bbox) + dec_cls.append(score_head[i](output)) + dec_bboxes.append(refined_bbox) break - ref_points = inter_ref_bbox - ref_points_detach = inter_ref_bbox.detach() if self.training else inter_ref_bbox + last_refined_bbox = refined_bbox + refer_bbox = refined_bbox.detach() if self.training else refined_bbox - return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits) + return torch.stack(dec_bboxes), torch.stack(dec_cls) diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 7af3f34..6e75553 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -210,7 +210,9 @@ class BaseModel(nn.Module): """ if not hasattr(self, 'criterion'): self.criterion = self.init_criterion() - return self.criterion(self.predict(batch['img']) if preds is None 
else preds, batch) + + preds = self.forward(batch['img']) if preds is None else preds + return self.criterion(preds, batch) def init_criterion(self): raise NotImplementedError('compute_loss() needs to be implemented by task heads') @@ -410,7 +412,7 @@ class RTDETRDetectionModel(DetectionModel): """Compute the classification loss between predictions and true labels.""" from ultralytics.vit.utils.loss import RTDETRDetectionLoss - return RTDETRDetectionLoss(num_classes=self.nc, use_vfl=True) + return RTDETRDetectionLoss(nc=self.nc, use_vfl=True) def loss(self, batch, preds=None): if not hasattr(self, 'criterion'): @@ -420,31 +422,36 @@ class RTDETRDetectionModel(DetectionModel): # NOTE: preprocess gt_bbox and gt_labels to list. bs = len(img) batch_idx = batch['batch_idx'] - gt_bbox, gt_class = [], [] + gt_groups = [] for i in range(bs): - gt_bbox.append(batch['bboxes'][batch_idx == i].to(img.device)) - gt_class.append(batch['cls'][batch_idx == i].to(device=img.device, dtype=torch.long)) - targets = {'cls': gt_class, 'bboxes': gt_bbox} + gt_groups.append((batch_idx == i).sum().item()) + targets = { + 'cls': batch['cls'].to(img.device, dtype=torch.long).view(-1), + 'bboxes': batch['bboxes'].to(device=img.device), + 'batch_idx': batch_idx.to(img.device, dtype=torch.long).view(-1), + 'gt_groups': gt_groups} preds = self.predict(img, batch=targets) if preds is None else preds - dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta = preds - # NOTE: `dn_meta` means it's eval mode, loss calculation for eval mode is not supported. + dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if dn_meta is None: - return 0, torch.zeros(3, device=dec_out_bboxes.device) - dn_out_bboxes, dec_out_bboxes = torch.split(dec_out_bboxes, dn_meta['dn_num_split'], dim=2) - dn_out_logits, dec_out_logits = torch.split(dec_out_logits, dn_meta['dn_num_split'], dim=2) + dn_bboxes, dn_scores = None, None + else: + dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta['dn_num_split'], dim=2) + dn_scores, dec_scores = torch.split(dec_scores, dn_meta['dn_num_split'], dim=2) - out_bboxes = torch.cat([enc_topk_bboxes.unsqueeze(0), dec_out_bboxes]) - out_logits = torch.cat([enc_topk_logits.unsqueeze(0), dec_out_logits]) + dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4) + dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores]) - loss = self.criterion((out_bboxes, out_logits), + loss = self.criterion((dec_bboxes, dec_scores), targets, - dn_out_bboxes=dn_out_bboxes, - dn_out_logits=dn_out_logits, + dn_bboxes=dn_bboxes, + dn_scores=dn_scores, dn_meta=dn_meta) - return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']]) + # NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses. + return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']], + device=img.device) - def predict(self, x, profile=False, visualize=False, batch=None): + def predict(self, x, profile=False, visualize=False, batch=None, augment=False): """ Perform a forward pass through the network. 
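The loss plumbing above relies on the decoder emitting denoising queries and matching queries stacked along the query dimension, which `dn_meta['dn_num_split']` then separates. Below is a minimal sketch with dummy tensors; the sizes 198 and 300 are illustrative, and only the split logic mirrors the patch.

```python
import torch

# Decoder outputs are (num_layers, batch, num_denoising + num_queries, ...),
# which is what RTDETRDetectionModel.loss receives before splitting.
num_layers, bs, num_dn, num_q, nc = 6, 2, 198, 300, 80
dec_bboxes = torch.rand(num_layers, bs, num_dn + num_q, 4)
dec_scores = torch.rand(num_layers, bs, num_dn + num_q, nc)

dn_num_split = [num_dn, num_q]  # as stored in dn_meta['dn_num_split']
dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_num_split, dim=2)
dn_scores, dec_scores = torch.split(dec_scores, dn_num_split, dim=2)

print(dn_bboxes.shape, dec_bboxes.shape)  # torch.Size([6, 2, 198, 4]) torch.Size([6, 2, 300, 4])
```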
diff --git a/ultralytics/vit/__init__.py b/ultralytics/vit/__init__.py index e142705..8e96f91 100644 --- a/ultralytics/vit/__init__.py +++ b/ultralytics/vit/__init__.py @@ -3,4 +3,4 @@ from .rtdetr import RTDETR from .sam import SAM -__all__ = 'RTDETR', 'SAM', 'SAM' # allow simpler import +__all__ = 'RTDETR', 'SAM' # allow simpler import diff --git a/ultralytics/vit/rtdetr/model.py b/ultralytics/vit/rtdetr/model.py index 1297e15..322912c 100644 --- a/ultralytics/vit/rtdetr/model.py +++ b/ultralytics/vit/rtdetr/model.py @@ -5,15 +5,15 @@ from pathlib import Path -from ultralytics.nn.tasks import DetectionModel, attempt_load_one_weight, yaml_model_load +from ultralytics.nn.tasks import RTDETRDetectionModel, attempt_load_one_weight, yaml_model_load from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.engine.exporter import Exporter -from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir +from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, ROOT, is_git_dir from ultralytics.yolo.utils.checks import check_imgsz -from ultralytics.yolo.utils.torch_utils import model_info +from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode -from ...yolo.utils.torch_utils import smart_inference_mode from .predict import RTDETRPredictor +from .train import RTDETRTrainer from .val import RTDETRValidator @@ -24,6 +24,7 @@ class RTDETR: raise NotImplementedError('RT-DETR only supports creating from pt file or yaml file.') # Load or create new YOLO model self.predictor = None + self.ckpt = None suffix = Path(model).suffix if suffix == '.yaml': self._new(model) @@ -34,7 +35,7 @@ class RTDETR: cfg_dict = yaml_model_load(cfg) self.cfg = cfg self.task = 'detect' - self.model = DetectionModel(cfg_dict, verbose=verbose) # build model + self.model = RTDETRDetectionModel(cfg_dict, verbose=verbose) # build model # Below added to allow export from yamls self.model.args = DEFAULT_CFG_DICT # attach args to model @@ -42,10 +43,20 @@ class RTDETR: @smart_inference_mode() def _load(self, weights: str): - self.model, _ = attempt_load_one_weight(weights) + self.model, self.ckpt = attempt_load_one_weight(weights) self.model.args = DEFAULT_CFG_DICT # attach args to model self.task = self.model.args['task'] + @smart_inference_mode() + def load(self, weights='yolov8n.pt'): + """ + Transfers parameters with matching names and shapes from 'weights' to model. + """ + if isinstance(weights, (str, Path)): + weights, self.ckpt = attempt_load_one_weight(weights) + self.model.load(weights) + return self + @smart_inference_mode() def predict(self, source=None, stream=False, **kwargs): """ @@ -74,8 +85,30 @@ class RTDETR: return self.predictor(source, stream=stream) def train(self, **kwargs): - """Function trains models but raises an error as RTDETR models do not support training.""" - raise NotImplementedError("RTDETR models don't support training") + """ + Trains the model on a given dataset. + + Args: + **kwargs (Any): Any number of arguments representing the training configuration. + """ + overrides = dict(task='detect', mode='train') + overrides.update(kwargs) + overrides['deterministic'] = False + if not overrides.get('data'): + raise AttributeError("Dataset required but missing, i.e. 
pass 'data=coco128.yaml'") + if overrides.get('resume'): + overrides['resume'] = self.ckpt_path + self.task = overrides.get('task') or self.task + self.trainer = RTDETRTrainer(overrides=overrides) + if not overrides.get('resume'): # manually set model only if not resuming + self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml) + self.model = self.trainer.model + self.trainer.train() + # Update model and cfg after training + if RANK in (-1, 0): + self.model, _ = attempt_load_one_weight(str(self.trainer.best)) + self.overrides = self.model.args + self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP def val(self, **kwargs): """Run validation given dataset.""" diff --git a/ultralytics/vit/rtdetr/train.py b/ultralytics/vit/rtdetr/train.py new file mode 100644 index 0000000..5a29589 --- /dev/null +++ b/ultralytics/vit/rtdetr/train.py @@ -0,0 +1,78 @@ +from copy import copy + +import torch + +from ultralytics.nn.tasks import RTDETRDetectionModel +from ultralytics.yolo.utils import DEFAULT_CFG, RANK, colorstr +from ultralytics.yolo.v8.detect import DetectionTrainer + +from .val import RTDETRDataset, RTDETRValidator + + +class RTDETRTrainer(DetectionTrainer): + + def get_model(self, cfg=None, weights=None, verbose=True): + """Return a YOLO detection model.""" + model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) + if weights: + model.load(weights) + return model + + def build_dataset(self, img_path, mode='val', batch=None): + """Build RTDETR Dataset + + Args: + img_path (str): Path to the folder containing images. + mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. + batch (int, optional): Size of batches, this is for `rect`. Defaults to None. 
+ """ + return RTDETRDataset( + img_path=img_path, + imgsz=self.args.imgsz, + batch_size=batch, + augment=mode == 'train', # no augmentation + hyp=self.args, + rect=False, # no rect + cache=self.args.cache or None, + prefix=colorstr(f'{mode}: '), + data=self.data) + + def get_validator(self): + """Returns a DetectionValidator for RTDETR model validation.""" + self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss' + return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) + + def preprocess_batch(self, batch): + """Preprocesses a batch of images by scaling and converting to float.""" + batch = super().preprocess_batch(batch) + bs = len(batch['img']) + batch_idx = batch['batch_idx'] + gt_bbox, gt_class = [], [] + for i in range(bs): + gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device)) + gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long)) + return batch + + +def train(cfg=DEFAULT_CFG, use_python=False): + """Train and optimize RTDETR model given training data and device.""" + model = 'rtdetr-l.yaml' + data = cfg.data or 'coco128.yaml' # or yolo.ClassificationDataset("mnist") + device = cfg.device if cfg.device is not None else '' + + # NOTE: F.grid_sample which is in rt-detr does not support deterministic=True + # NOTE: amp training causes nan outputs and end with error while doing bipartite graph matching + args = dict(model=model, + data=data, + device=device, + imgsz=640, + exist_ok=True, + batch=4, + deterministic=False, + amp=False) + trainer = RTDETRTrainer(overrides=args) + trainer.train() + + +if __name__ == '__main__': + train() diff --git a/ultralytics/vit/rtdetr/val.py b/ultralytics/vit/rtdetr/val.py index e1619a6..57376a6 100644 --- a/ultralytics/vit/rtdetr/val.py +++ b/ultralytics/vit/rtdetr/val.py @@ -2,10 +2,12 @@ from pathlib import Path +import cv2 +import numpy as np import torch from ultralytics.yolo.data import YOLODataset -from ultralytics.yolo.data.augment import Compose, Format, LetterBox +from ultralytics.yolo.data.augment import Compose, Format, v8_transforms from ultralytics.yolo.utils import colorstr, ops from ultralytics.yolo.v8.detect import DetectionValidator @@ -18,9 +20,41 @@ class RTDETRDataset(YOLODataset): def __init__(self, *args, data=None, **kwargs): super().__init__(*args, data=data, use_segments=False, use_keypoints=False, **kwargs) + # NOTE: add stretch version load_image for rtdetr mosaic + def load_image(self, i): + """Loads 1 image from dataset index 'i', returns (im, resized hw).""" + im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i] + if im is None: # not cached in RAM + if fn.exists(): # load npy + im = np.load(fn) + else: # read image + im = cv2.imread(f) # BGR + if im is None: + raise FileNotFoundError(f'Image Not Found {f}') + h0, w0 = im.shape[:2] # orig hw + im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR) + + # Add to buffer if training with augmentations + if self.augment: + self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized + self.buffer.append(i) + if len(self.buffer) >= self.max_buffer_length: + j = self.buffer.pop(0) + self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None + + return im, (h0, w0), im.shape[:2] + + return self.ims[i], self.im_hw0[i], self.im_hw[i] + def build_transforms(self, hyp=None): """Temporarily, only for evaluation.""" - transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)]) + if 
self.augment: + hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0 + hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0 + transforms = v8_transforms(self, self.imgsz, hyp, stretch=True) + else: + # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)]) + transforms = Compose([]) transforms.append( Format(bbox_format='xywh', normalize=True, @@ -65,6 +99,8 @@ class RTDETRValidator(DetectionValidator): # Do not need threshold for evaluation as only got 300 boxes here. # idx = score > self.args.conf pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1) # filter + # sort by confidence to correctly get internal metrics. + pred = pred[score.argsort(descending=True)] outputs[i] = pred # [idx] return outputs @@ -100,7 +136,8 @@ class RTDETRValidator(DetectionValidator): tbox[..., [0, 2]] *= shape[1] # native-space pred tbox[..., [1, 3]] *= shape[0] # native-space pred labelsn = torch.cat((cls, tbox), 1) # native-space labels - correct_bboxes = self._process_batch(predn, labelsn) + # NOTE: To get correct metrics, the inputs of `_process_batch` should always be float32 type. + correct_bboxes = self._process_batch(predn.float(), labelsn) # TODO: maybe remove these `self.` arguments as they already are member variable if self.args.plots: self.confusion_matrix.process_batch(predn, labelsn) diff --git a/ultralytics/vit/sam/amg.py b/ultralytics/vit/sam/amg.py index 3c70f7c..1522931 100644 --- a/ultralytics/vit/sam/amg.py +++ b/ultralytics/vit/sam/amg.py @@ -256,10 +256,8 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup return mask, False fill_labels = [0] + small_regions if not correct_holes: - fill_labels = [i for i in range(n_labels) if i not in fill_labels] # If every region is below threshold, keep largest - if not fill_labels: - fill_labels = [int(np.argmax(sizes)) + 1] + fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1] mask = np.isin(regions, fill_labels) return mask, True diff --git a/ultralytics/vit/sam/modules/sam.py b/ultralytics/vit/sam/modules/sam.py index 50a30ee..34963f1 100644 --- a/ultralytics/vit/sam/modules/sam.py +++ b/ultralytics/vit/sam/modules/sam.py @@ -18,14 +18,12 @@ class Sam(nn.Module): mask_threshold: float = 0.0 image_format: str = 'RGB' - def __init__( - self, - image_encoder: ImageEncoderViT, - prompt_encoder: PromptEncoder, - mask_decoder: MaskDecoder, - pixel_mean: List[float] = [123.675, 116.28, 103.53], - pixel_std: List[float] = [58.395, 57.12, 57.375], - ) -> None: + def __init__(self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = None, + pixel_std: List[float] = None) -> None: """ SAM predicts object masks from an image and input prompts. @@ -38,6 +36,10 @@ class Sam(nn.Module): pixel_mean (list(float)): Mean values for normalizing pixels in the input image. pixel_std (list(float)): Std values for normalizing pixels in the input image. 
""" + if pixel_mean is None: + pixel_mean = [123.675, 116.28, 103.53] + if pixel_std is None: + pixel_std = [58.395, 57.12, 57.375] super().__init__() self.image_encoder = image_encoder self.prompt_encoder = prompt_encoder diff --git a/ultralytics/vit/utils/loss.py b/ultralytics/vit/utils/loss.py new file mode 100644 index 0000000..1a5ba29 --- /dev/null +++ b/ultralytics/vit/utils/loss.py @@ -0,0 +1,291 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ultralytics.vit.utils.ops import HungarianMatcher +from ultralytics.yolo.utils.loss import FocalLoss, VarifocalLoss +from ultralytics.yolo.utils.metrics import bbox_iou + + +class DETRLoss(nn.Module): + + def __init__(self, + nc=80, + loss_gain=None, + aux_loss=True, + use_fl=True, + use_vfl=False, + use_uni_match=False, + uni_match_ind=0): + """ + Args: + nc (int): The number of classes. + loss_gain (dict): The coefficient of loss. + aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used. + use_focal_loss (bool): Use focal loss or not. + use_vfl (bool): Use VarifocalLoss or not. + use_uni_match (bool): Whether to use a fixed layer to assign labels for auxiliary branch. + uni_match_ind (int): The fixed indices of a layer. + """ + super().__init__() + + if loss_gain is None: + loss_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'no_object': 0.1, 'mask': 1, 'dice': 1} + self.nc = nc + self.matcher = HungarianMatcher(cost_gain={'class': 2, 'bbox': 5, 'giou': 2}) + self.loss_gain = loss_gain + self.aux_loss = aux_loss + self.fl = FocalLoss() if use_fl else None + self.vfl = VarifocalLoss() if use_vfl else None + + self.use_uni_match = use_uni_match + self.uni_match_ind = uni_match_ind + self.device = None + + def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''): + # logits: [b, query, num_classes], gt_class: list[[n, 1]] + name_class = f'loss_class{postfix}' + bs, nq = pred_scores.shape[:2] + # one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes) + one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device) + one_hot.scatter_(2, targets.unsqueeze(-1), 1) + one_hot = one_hot[..., :-1] + gt_scores = gt_scores.view(bs, nq, 1) * one_hot + + if self.fl: + if num_gts and self.vfl: + loss_cls = self.vfl(pred_scores, gt_scores, one_hot) + else: + loss_cls = self.fl(pred_scores, one_hot.float()) + loss_cls /= max(num_gts, 1) / nq + else: + loss_cls = nn.BCEWithLogitsLoss(reduction='none')(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss + + return {name_class: loss_cls.squeeze() * self.loss_gain['class']} + + def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=''): + # boxes: [b, query, 4], gt_bbox: list[[n, 4]] + name_bbox = f'loss_bbox{postfix}' + name_giou = f'loss_giou{postfix}' + + loss = {} + if len(gt_bboxes) == 0: + loss[name_bbox] = torch.tensor(0., device=self.device) + loss[name_giou] = torch.tensor(0., device=self.device) + return loss + + loss[name_bbox] = self.loss_gain['bbox'] * F.l1_loss(pred_bboxes, gt_bboxes, reduction='sum') / len(gt_bboxes) + loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True) + loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes) + loss[name_giou] = self.loss_gain['giou'] * loss[name_giou] + loss = {k: v.squeeze() for k, v in loss.items()} + return loss + + def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''): + # masks: [b, query, h, w], gt_mask: list[[n, H, W]] + name_mask = f'loss_mask{postfix}' + name_dice = 
f'loss_dice{postfix}' + + loss = {} + if sum(len(a) for a in gt_mask) == 0: + loss[name_mask] = torch.tensor(0., device=self.device) + loss[name_dice] = torch.tensor(0., device=self.device) + return loss + + num_gts = len(gt_mask) + src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices) + src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0] + # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now. + loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks, + torch.tensor([num_gts], dtype=torch.float32)) + loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts) + return loss + + def _dice_loss(self, inputs, targets, num_gts): + inputs = F.sigmoid(inputs) + inputs = inputs.flatten(1) + targets = targets.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_gts + + def _get_loss_aux(self, + pred_bboxes, + pred_scores, + gt_bboxes, + gt_cls, + gt_groups, + match_indices=None, + postfix='', + masks=None, + gt_mask=None): + """Get auxiliary losses""" + # NOTE: loss class, bbox, giou, mask, dice + loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device) + if match_indices is None and self.use_uni_match: + match_indices = self.matcher(pred_bboxes[self.uni_match_ind], + pred_scores[self.uni_match_ind], + gt_bboxes, + gt_cls, + gt_groups, + masks=masks[self.uni_match_ind] if masks is not None else None, + gt_mask=gt_mask) + for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)): + aux_masks = masks[i] if masks is not None else None + loss_ = self._get_loss(aux_bboxes, + aux_scores, + gt_bboxes, + gt_cls, + gt_groups, + masks=aux_masks, + gt_mask=gt_mask, + postfix=postfix, + match_indices=match_indices) + loss[0] += loss_[f'loss_class{postfix}'] + loss[1] += loss_[f'loss_bbox{postfix}'] + loss[2] += loss_[f'loss_giou{postfix}'] + # if masks is not None and gt_mask is not None: + # loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix) + # loss[3] += loss_[f'loss_mask{postfix}'] + # loss[4] += loss_[f'loss_dice{postfix}'] + + loss = { + f'loss_class_aux{postfix}': loss[0], + f'loss_bbox_aux{postfix}': loss[1], + f'loss_giou_aux{postfix}': loss[2]} + # if masks is not None and gt_mask is not None: + # loss[f'loss_mask_aux{postfix}'] = loss[3] + # loss[f'loss_dice_aux{postfix}'] = loss[4] + return loss + + def _get_index(self, match_indices): + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)]) + src_idx = torch.cat([src for (src, _) in match_indices]) + dst_idx = torch.cat([dst for (_, dst) in match_indices]) + return (batch_idx, src_idx), dst_idx + + def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices): + pred_assigned = torch.cat([ + t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device) + for t, (I, _) in zip(pred_bboxes, match_indices)]) + gt_assigned = torch.cat([ + t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device) + for t, (_, J) in zip(gt_bboxes, match_indices)]) + return pred_assigned, gt_assigned + + def _get_loss(self, + pred_bboxes, + pred_scores, + gt_bboxes, + gt_cls, + gt_groups, + masks=None, + gt_mask=None, + postfix='', + match_indices=None): + """Get losses""" + if match_indices is None: + match_indices = 
self.matcher(pred_bboxes, + pred_scores, + gt_bboxes, + gt_cls, + gt_groups, + masks=masks, + gt_mask=gt_mask) + + idx, gt_idx = self._get_index(match_indices) + pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx] + + bs, nq = pred_scores.shape[:2] + targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype) + targets[idx] = gt_cls[gt_idx] + + gt_scores = torch.zeros([bs, nq], device=pred_scores.device) + if len(gt_bboxes): + gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1) + + loss = {} + loss.update(self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix)) + loss.update(self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix)) + # if masks is not None and gt_mask is not None: + # loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix)) + return loss + + def forward(self, pred_bboxes, pred_scores, batch, postfix='', **kwargs): + """ + Args: + pred_bboxes (torch.Tensor): [l, b, query, 4] + pred_scores (torch.Tensor): [l, b, query, num_classes] + batch (dict): A dict includes: + gt_cls (torch.Tensor) with shape [num_gts, ], + gt_bboxes (torch.Tensor): [num_gts, 4], + gt_groups (List(int)): a list of batch size length includes the number of gts of each image. + postfix (str): postfix of loss name. + """ + self.device = pred_bboxes.device + match_indices = kwargs.get('match_indices', None) + gt_cls, gt_bboxes, gt_groups = batch['cls'], batch['bboxes'], batch['gt_groups'] + + total_loss = self._get_loss(pred_bboxes[-1], + pred_scores[-1], + gt_bboxes, + gt_cls, + gt_groups, + postfix=postfix, + match_indices=match_indices) + + if self.aux_loss: + total_loss.update( + self._get_loss_aux(pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, + postfix)) + + return total_loss + + +class RTDETRDetectionLoss(DETRLoss): + + def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None): + pred_bboxes, pred_scores = preds + total_loss = super().forward(pred_bboxes, pred_scores, batch) + + if dn_meta is not None: + dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group'] + assert len(batch['gt_groups']) == len(dn_pos_idx) + + # denoising match indices + match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups']) + + # compute denoising training loss + dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices) + total_loss.update(dn_loss) + else: + total_loss.update({f'{k}_dn': torch.tensor(0., device=self.device) for k in total_loss.keys()}) + + return total_loss + + @staticmethod + def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups): + """Get the match indices for denoising. + + Args: + dn_pos_idx (List[torch.Tensor]): A list includes positive indices of denoising. + dn_num_group (int): The number of groups of denoising. + gt_groups (List(int)): a list of batch size length includes the number of gts of each image. + + Returns: + dn_match_indices (List(tuple)): Matched indices. + + """ + dn_match_indices = [] + idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) + for i, num_gt in enumerate(gt_groups): + if num_gt > 0: + gt_idx = torch.arange(end=num_gt, dtype=torch.int32) + idx_groups[i] + gt_idx = gt_idx.repeat(dn_num_group) + assert len(dn_pos_idx[i]) == len(gt_idx), 'Expected the same length, ' + f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.' 
+ dn_match_indices.append((dn_pos_idx[i], gt_idx)) + else: + dn_match_indices.append((torch.zeros([0], dtype=torch.int32), torch.zeros([0], dtype=torch.int32))) + return dn_match_indices diff --git a/ultralytics/vit/utils/ops.py b/ultralytics/vit/utils/ops.py new file mode 100644 index 0000000..5b92963 --- /dev/null +++ b/ultralytics/vit/utils/ops.py @@ -0,0 +1,230 @@ +# TODO: license + +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy.optimize import linear_sum_assignment + +from ultralytics.yolo.utils.metrics import bbox_iou +from ultralytics.yolo.utils.ops import xywh2xyxy, xyxy2xywh + + +class HungarianMatcher(nn.Module): + + def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0): + """ + Args: + matcher_coeff (dict): The coefficient of hungarian matcher cost. + """ + super().__init__() + if cost_gain is None: + cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1} + self.cost_gain = cost_gain + self.use_fl = use_fl + self.with_mask = with_mask + self.num_sample_points = num_sample_points + self.alpha = alpha + self.gamma = gamma + + def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None): + """ + Args: + pred_bboxes (Tensor): [b, query, 4] + pred_scores (Tensor): [b, query, num_classes] + gt_cls (torch.Tensor) with shape [num_gts, ] + gt_bboxes (torch.Tensor): [num_gts, 4] + gt_groups (List(int)): a list of batch size length includes the number of gts of each image. + masks (Tensor|None): [b, query, h, w] + gt_mask (List(Tensor)): list[[n, H, W]] + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + bs, nq, nc = pred_scores.shape + + if sum(gt_groups) == 0: + return [(torch.tensor([], dtype=torch.int32), torch.tensor([], dtype=torch.int32)) for _ in range(bs)] + + # We flatten to compute the cost matrices in a batch + # [batch_size * num_queries, num_classes] + pred_scores = pred_scores.detach().view(-1, nc) + pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1) + # [batch_size * num_queries, 4] + pred_bboxes = pred_bboxes.detach().view(-1, 4) + + # Compute the classification cost + pred_scores = pred_scores[:, gt_cls] + if self.use_fl: + neg_cost_class = (1 - self.alpha) * (pred_scores ** self.gamma) * (-(1 - pred_scores + 1e-8).log()) + pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log()) + cost_class = pos_cost_class - neg_cost_class + else: + cost_class = -pred_scores + + # Compute the L1 cost between boxes + cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt) + + # Compute the GIoU cost between boxes, (bs*num_queries, num_gt) + cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1) + + # Final cost matrix + C = self.cost_gain['class'] * cost_class + \ + self.cost_gain['bbox'] * cost_bbox + \ + self.cost_gain['giou'] * cost_giou + # Compute the mask cost and dice cost + if self.with_mask: + C += self._cost_mask(bs, gt_groups, masks, gt_mask) + + C = C.view(bs, nq, -1).cpu() + indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))] + 
+        gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
+        # (idx for queries, idx for gt)
+        return [(torch.tensor(i, dtype=torch.int32), torch.tensor(j, dtype=torch.int32) + gt_groups[k])
+                for k, (i, j) in enumerate(indices)]
+
+    def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
+        assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
+        # all masks share the same set of points for efficient matching
+        sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
+        sample_points = 2.0 * sample_points - 1.0
+
+        out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
+        out_mask = out_mask.flatten(0, 1)
+
+        tgt_mask = torch.cat(gt_mask).unsqueeze(1)
+        sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
+        tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
+
+        with torch.cuda.amp.autocast(False):
+            # binary cross entropy cost
+            pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
+            neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
+            cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
+            cost_mask /= self.num_sample_points
+
+            # dice cost
+            out_mask = F.sigmoid(out_mask)
+            numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
+            denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
+            cost_dice = 1 - (numerator + 1) / (denominator + 1)
+
+            C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
+        return C
+
+
+def get_cdn_group(batch,
+                  num_classes,
+                  num_queries,
+                  class_embed,
+                  num_dn=100,
+                  cls_noise_ratio=0.5,
+                  box_noise_scale=1.0,
+                  training=False):
+    """Get the contrastive denoising (CDN) training group.
+
+    Args:
+        batch (dict): A dict including:
+            gt_cls (torch.Tensor): class labels with shape [num_gts, ],
+            gt_bboxes (torch.Tensor): boxes with shape [num_gts, 4],
+            gt_groups (List[int]): a list of length batch_size containing the number of gts per image.
+        num_classes (int): Number of classes.
+        num_queries (int): Number of queries.
+        class_embed (torch.Tensor): Embedding weights that map class labels to the embedding space.
+        num_dn (int): Number of denoising queries.
+        cls_noise_ratio (float): Noise ratio applied to class labels.
+        box_noise_scale (float): Noise scale applied to bboxes.
+        training (bool): Whether the model is in training mode.
+
+    Returns:
+        (Tuple): padded class embeddings, padded bboxes, attention mask and denoising meta dict,
+            or (None, None, None, None) when not training or no denoising group can be built.
+    """
+    if (not training) or num_dn <= 0:
+        return None, None, None, None
+    gt_groups = batch['gt_groups']
+    total_num = sum(gt_groups)
+    max_nums = max(gt_groups)
+    if max_nums == 0:
+        return None, None, None, None
+
+    num_group = num_dn // max_nums
+    num_group = 1 if num_group == 0 else num_group
+    # pad gt to max_num of a batch
+    bs = len(gt_groups)
+    gt_cls = batch['cls']  # (bs*num, )
+    gt_bbox = batch['bboxes']  # bs*num, 4
+    b_idx = batch['batch_idx']
+
+    # each group has positive and negative queries.
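+    # the gts are repeated 2 * num_group times: one positive and one negative copy per denoising group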
+    dn_cls = gt_cls.repeat(2 * num_group)  # (2*num_group*bs*num, )
+    dn_bbox = gt_bbox.repeat(2 * num_group, 1)  # 2*num_group*bs*num, 4
+    dn_b_idx = b_idx.repeat(2 * num_group).view(-1)  # (2*num_group*bs*num, )
+
+    # positive and negative mask
+    # (bs*num*num_group, ), the second total_num*num_group part as negative samples
+    neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
+
+    if cls_noise_ratio > 0:
+        # noise the labels of (cls_noise_ratio * 0.5) of the denoising samples
+        mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
+        idx = torch.nonzero(mask).squeeze(-1)
+        # randomly assign a new class label
+        new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
+        dn_cls[idx] = new_label
+
+    if box_noise_scale > 0:
+        known_bbox = xywh2xyxy(dn_bbox)
+
+        diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale  # 2*num_group*bs*num, 4
+
+        rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
+        rand_part = torch.rand_like(dn_bbox)
+        rand_part[neg_idx] += 1.0
+        rand_part *= rand_sign
+        known_bbox += rand_part * diff
+        known_bbox.clip_(min=0.0, max=1.0)
+        dn_bbox = xyxy2xywh(known_bbox)
+        dn_bbox = inverse_sigmoid(dn_bbox)
+
+    # total denoising queries
+    num_dn = int(max_nums * 2 * num_group)
+    # class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
+    dn_cls_embed = class_embed[dn_cls]  # bs*num * 2 * num_group, 256
+    padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
+    padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)
+
+    map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
+    pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)
+
+    map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
+    padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
+    padding_bbox[(dn_b_idx, map_indices)] = dn_bbox
+
+    tgt_size = num_dn + num_queries
+    attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
+    # matching queries cannot see the reconstructed (denoising) queries
+    attn_mask[num_dn:, :num_dn] = True
+    # denoising groups cannot see each other
+    for i in range(num_group):
+        if i == 0:
+            attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
+        if i == num_group - 1:
+            attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * i * 2] = True
+        else:
+            attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
+            attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * 2 * i] = True
+    dn_meta = {
+        'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split([n for n in gt_groups], dim=1)],
+        'dn_num_group': num_group,
+        'dn_num_split': [num_dn, num_queries]}
+
+    return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to(
+        class_embed.device), dn_meta
+
+
+def inverse_sigmoid(x, eps=1e-6):
+    x = x.clip(min=0., max=1.)
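+    # clamp to [0, 1] first; eps keeps the log finite at the boundaries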
+    return torch.log(x / (1 - x + eps) + eps)
diff --git a/ultralytics/yolo/data/augment.py b/ultralytics/yolo/data/augment.py
index 4c7f167..42688c9 100644
--- a/ultralytics/yolo/data/augment.py
+++ b/ultralytics/yolo/data/augment.py
@@ -759,7 +759,7 @@ class Format:
         return masks, instances, cls
 
 
-def v8_transforms(dataset, imgsz, hyp):
+def v8_transforms(dataset, imgsz, hyp, stretch=False):
     """Convert images to a size suitable for YOLOv8 training."""
     pre_transform = Compose([
         Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
@@ -770,7 +770,7 @@ def v8_transforms(dataset, imgsz, hyp):
             scale=hyp.scale,
             shear=hyp.shear,
             perspective=hyp.perspective,
-            pre_transform=LetterBox(new_shape=(imgsz, imgsz)),
+            pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
         )])
     flip_idx = dataset.data.get('flip_idx', None)  # for keypoints augmentation
     if dataset.use_keypoints:
diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
index 679c545..e4925b3 100644
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -278,7 +278,8 @@ class BaseTrainer:
         self.epoch_time_start = time.time()
         self.train_time_start = time.time()
         nb = len(self.train_loader)  # number of batches
-        nw = max(round(self.args.warmup_epochs * nb), 100)  # number of warmup iterations
+        nw = max(round(self.args.warmup_epochs *
+                       nb), 100) if self.args.warmup_epochs > 0 else -1  # number of warmup iterations
         last_opt_step = -1
         self.run_callbacks('on_train_start')
         LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
diff --git a/ultralytics/yolo/utils/loss.py b/ultralytics/yolo/utils/loss.py
index f165acb..71ed0a5 100644
--- a/ultralytics/yolo/utils/loss.py
+++ b/ultralytics/yolo/utils/loss.py
@@ -24,10 +24,34 @@ class VarifocalLoss(nn.Module):
         weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
         with torch.cuda.amp.autocast(enabled=False):
             loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
-                    weight).sum()
+                    weight).mean(1).sum()
         return loss
 
 
+# Losses
+class FocalLoss(nn.Module):
+    """Focal loss applied to binary cross-entropy with logits, i.e. loss = FocalLoss()(pred_logits, labels)."""
+
+    def __init__(self, ):
+        super().__init__()
+
+    def forward(self, pred, label, gamma=1.5, alpha=0.25):
+        """Calculates focal loss between predicted logits and binary labels."""
+        loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
+        # p_t = torch.exp(-loss)
+        # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability
+
+        # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
+        pred_prob = pred.sigmoid()  # prob from logits
+        p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
+        modulating_factor = (1.0 - p_t) ** gamma
+        loss *= modulating_factor
+        if alpha > 0:
+            alpha_factor = label * alpha + (1 - label) * (1 - alpha)
+            loss *= alpha_factor
+        return loss.mean(1).sum()
 
 
 class BboxLoss(nn.Module):
 
     def __init__(self, reg_max, use_dfl=False):
diff --git a/ultralytics/yolo/utils/metrics.py b/ultralytics/yolo/utils/metrics.py
index 30899e4..8544adf 100644
--- a/ultralytics/yolo/utils/metrics.py
+++ b/ultralytics/yolo/utils/metrics.py
@@ -9,7 +9,6 @@ from pathlib import Path
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
-import torch.nn as nn
 
 from ultralytics.yolo.utils import LOGGER, SimpleClass, TryExcept, plt_settings
 
@@ -175,40 +174,6 @@ def smooth_BCE(eps=0.1):  # https://github.com/ultralytics/yolov3/issues/238#iss
     return 1.0 - 0.5 * eps, 0.5 * eps
 
 
-# Losses
-class FocalLoss(nn.Module):
-    """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
-
-    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
-        """Initialize FocalLoss object with given loss function and hyperparameters."""
-        super().__init__()
-        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
-        self.gamma = gamma
-        self.alpha = alpha
-        self.reduction = loss_fcn.reduction
-        self.loss_fcn.reduction = 'none'  # required to apply FL to each element
-
-    def forward(self, pred, true):
-        """Calculates and updates confusion matrix for object detection/classification tasks."""
-        loss = self.loss_fcn(pred, true)
-        # p_t = torch.exp(-loss)
-        # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability
-
-        # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
-        pred_prob = torch.sigmoid(pred)  # prob from logits
-        p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
-        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
-        modulating_factor = (1.0 - p_t) ** self.gamma
-        loss *= alpha_factor * modulating_factor
-
-        if self.reduction == 'mean':
-            return loss.mean()
-        elif self.reduction == 'sum':
-            return loss.sum()
-        else:  # 'None'
-            return loss
-
-
 class ConfusionMatrix:
     """
     A class for calculating and updating a confusion matrix for object detection and classification tasks.
diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py
index cb3725a..a9d7917 100644
--- a/ultralytics/yolo/utils/torch_utils.py
+++ b/ultralytics/yolo/utils/torch_utils.py
@@ -327,6 +327,9 @@ def init_seeds(seed=0, deterministic=False):
             os.environ['PYTHONHASHSEED'] = str(seed)
         else:
             LOGGER.warning('WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.')
+    else:
+        torch.use_deterministic_algorithms(False)
+        torch.backends.cudnn.deterministic = False
 
 
 class ModelEMA:
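
Below is a minimal, illustrative sketch (not part of the patch) of how the new `HungarianMatcher` from `ultralytics/vit/utils/ops.py` can be exercised on random tensors. Shapes follow its `forward()` docstring; it assumes the patch is applied and `scipy` is installed.

```python
import torch

from ultralytics.vit.utils.ops import HungarianMatcher

matcher = HungarianMatcher()  # default cost gains: class=1, bbox=5, giou=2

bs, nq, nc = 2, 8, 80
pred_bboxes = torch.rand(bs, nq, 4)    # normalized xywh predictions
pred_scores = torch.randn(bs, nq, nc)  # raw class logits
gt_bboxes = torch.rand(3, 4)           # 3 GT boxes for the whole batch, normalized xywh
gt_cls = torch.tensor([1, 5, 5])       # GT class ids
gt_groups = [2, 1]                     # image 0 has 2 GTs, image 1 has 1

indices = matcher(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups)
# indices is a list with one (query_idx, gt_idx) tensor pair per image;
# gt_idx values index into the batch-flattened gt tensors
```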