Add RTDETR Trainer (#2745)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>single_channel
parent
03bce07848
commit
a0ba8ef5f0
@ -0,0 +1,46 @@
|
|||||||
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
nc: 80 # number of classes
|
||||||
|
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
|
||||||
|
# [depth, width, max_channels]
|
||||||
|
n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
|
||||||
|
s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
|
||||||
|
m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
|
||||||
|
l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
|
||||||
|
x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
|
||||||
|
|
||||||
|
# YOLOv8.0n backbone
|
||||||
|
backbone:
|
||||||
|
# [from, repeats, module, args]
|
||||||
|
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
||||||
|
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
||||||
|
- [-1, 3, C2f, [128, True]]
|
||||||
|
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
||||||
|
- [-1, 6, C2f, [256, True]]
|
||||||
|
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
||||||
|
- [-1, 6, C2f, [512, True]]
|
||||||
|
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
||||||
|
- [-1, 3, C2f, [1024, True]]
|
||||||
|
- [-1, 1, SPPF, [1024, 5]] # 9
|
||||||
|
|
||||||
|
# YOLOv8.0n head
|
||||||
|
head:
|
||||||
|
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
||||||
|
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
||||||
|
- [-1, 3, C2f, [512]] # 12
|
||||||
|
|
||||||
|
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
||||||
|
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
||||||
|
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
|
||||||
|
|
||||||
|
- [-1, 1, Conv, [256, 3, 2]]
|
||||||
|
- [[-1, 12], 1, Concat, [1]] # cat head P4
|
||||||
|
- [-1, 3, C2f, [512]] # 18 (P4/16-medium)
|
||||||
|
|
||||||
|
- [-1, 1, Conv, [512, 3, 2]]
|
||||||
|
- [[-1, 9], 1, Concat, [1]] # cat head P5
|
||||||
|
- [-1, 3, C2f, [1024]] # 21 (P5/32-large)
|
||||||
|
|
||||||
|
- [[15, 18, 21], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
|
@ -0,0 +1,78 @@
|
|||||||
|
from copy import copy
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ultralytics.nn.tasks import RTDETRDetectionModel
|
||||||
|
from ultralytics.yolo.utils import DEFAULT_CFG, RANK, colorstr
|
||||||
|
from ultralytics.yolo.v8.detect import DetectionTrainer
|
||||||
|
|
||||||
|
from .val import RTDETRDataset, RTDETRValidator
|
||||||
|
|
||||||
|
|
||||||
|
class RTDETRTrainer(DetectionTrainer):
|
||||||
|
|
||||||
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||||
|
"""Return a YOLO detection model."""
|
||||||
|
model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
|
||||||
|
if weights:
|
||||||
|
model.load(weights)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def build_dataset(self, img_path, mode='val', batch=None):
|
||||||
|
"""Build RTDETR Dataset
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img_path (str): Path to the folder containing images.
|
||||||
|
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
|
||||||
|
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
|
||||||
|
"""
|
||||||
|
return RTDETRDataset(
|
||||||
|
img_path=img_path,
|
||||||
|
imgsz=self.args.imgsz,
|
||||||
|
batch_size=batch,
|
||||||
|
augment=mode == 'train', # no augmentation
|
||||||
|
hyp=self.args,
|
||||||
|
rect=False, # no rect
|
||||||
|
cache=self.args.cache or None,
|
||||||
|
prefix=colorstr(f'{mode}: '),
|
||||||
|
data=self.data)
|
||||||
|
|
||||||
|
def get_validator(self):
|
||||||
|
"""Returns a DetectionValidator for RTDETR model validation."""
|
||||||
|
self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
|
||||||
|
return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
|
||||||
|
|
||||||
|
def preprocess_batch(self, batch):
|
||||||
|
"""Preprocesses a batch of images by scaling and converting to float."""
|
||||||
|
batch = super().preprocess_batch(batch)
|
||||||
|
bs = len(batch['img'])
|
||||||
|
batch_idx = batch['batch_idx']
|
||||||
|
gt_bbox, gt_class = [], []
|
||||||
|
for i in range(bs):
|
||||||
|
gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
|
||||||
|
gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def train(cfg=DEFAULT_CFG, use_python=False):
|
||||||
|
"""Train and optimize RTDETR model given training data and device."""
|
||||||
|
model = 'rtdetr-l.yaml'
|
||||||
|
data = cfg.data or 'coco128.yaml' # or yolo.ClassificationDataset("mnist")
|
||||||
|
device = cfg.device if cfg.device is not None else ''
|
||||||
|
|
||||||
|
# NOTE: F.grid_sample which is in rt-detr does not support deterministic=True
|
||||||
|
# NOTE: amp training causes nan outputs and end with error while doing bipartite graph matching
|
||||||
|
args = dict(model=model,
|
||||||
|
data=data,
|
||||||
|
device=device,
|
||||||
|
imgsz=640,
|
||||||
|
exist_ok=True,
|
||||||
|
batch=4,
|
||||||
|
deterministic=False,
|
||||||
|
amp=False)
|
||||||
|
trainer = RTDETRTrainer(overrides=args)
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
train()
|
@ -0,0 +1,291 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from ultralytics.vit.utils.ops import HungarianMatcher
|
||||||
|
from ultralytics.yolo.utils.loss import FocalLoss, VarifocalLoss
|
||||||
|
from ultralytics.yolo.utils.metrics import bbox_iou
|
||||||
|
|
||||||
|
|
||||||
|
class DETRLoss(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
nc=80,
|
||||||
|
loss_gain=None,
|
||||||
|
aux_loss=True,
|
||||||
|
use_fl=True,
|
||||||
|
use_vfl=False,
|
||||||
|
use_uni_match=False,
|
||||||
|
uni_match_ind=0):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
nc (int): The number of classes.
|
||||||
|
loss_gain (dict): The coefficient of loss.
|
||||||
|
aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used.
|
||||||
|
use_focal_loss (bool): Use focal loss or not.
|
||||||
|
use_vfl (bool): Use VarifocalLoss or not.
|
||||||
|
use_uni_match (bool): Whether to use a fixed layer to assign labels for auxiliary branch.
|
||||||
|
uni_match_ind (int): The fixed indices of a layer.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if loss_gain is None:
|
||||||
|
loss_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'no_object': 0.1, 'mask': 1, 'dice': 1}
|
||||||
|
self.nc = nc
|
||||||
|
self.matcher = HungarianMatcher(cost_gain={'class': 2, 'bbox': 5, 'giou': 2})
|
||||||
|
self.loss_gain = loss_gain
|
||||||
|
self.aux_loss = aux_loss
|
||||||
|
self.fl = FocalLoss() if use_fl else None
|
||||||
|
self.vfl = VarifocalLoss() if use_vfl else None
|
||||||
|
|
||||||
|
self.use_uni_match = use_uni_match
|
||||||
|
self.uni_match_ind = uni_match_ind
|
||||||
|
self.device = None
|
||||||
|
|
||||||
|
def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''):
|
||||||
|
# logits: [b, query, num_classes], gt_class: list[[n, 1]]
|
||||||
|
name_class = f'loss_class{postfix}'
|
||||||
|
bs, nq = pred_scores.shape[:2]
|
||||||
|
# one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
|
||||||
|
one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
|
||||||
|
one_hot.scatter_(2, targets.unsqueeze(-1), 1)
|
||||||
|
one_hot = one_hot[..., :-1]
|
||||||
|
gt_scores = gt_scores.view(bs, nq, 1) * one_hot
|
||||||
|
|
||||||
|
if self.fl:
|
||||||
|
if num_gts and self.vfl:
|
||||||
|
loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
|
||||||
|
else:
|
||||||
|
loss_cls = self.fl(pred_scores, one_hot.float())
|
||||||
|
loss_cls /= max(num_gts, 1) / nq
|
||||||
|
else:
|
||||||
|
loss_cls = nn.BCEWithLogitsLoss(reduction='none')(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
|
||||||
|
|
||||||
|
return {name_class: loss_cls.squeeze() * self.loss_gain['class']}
|
||||||
|
|
||||||
|
def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=''):
|
||||||
|
# boxes: [b, query, 4], gt_bbox: list[[n, 4]]
|
||||||
|
name_bbox = f'loss_bbox{postfix}'
|
||||||
|
name_giou = f'loss_giou{postfix}'
|
||||||
|
|
||||||
|
loss = {}
|
||||||
|
if len(gt_bboxes) == 0:
|
||||||
|
loss[name_bbox] = torch.tensor(0., device=self.device)
|
||||||
|
loss[name_giou] = torch.tensor(0., device=self.device)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
loss[name_bbox] = self.loss_gain['bbox'] * F.l1_loss(pred_bboxes, gt_bboxes, reduction='sum') / len(gt_bboxes)
|
||||||
|
loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
|
||||||
|
loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
|
||||||
|
loss[name_giou] = self.loss_gain['giou'] * loss[name_giou]
|
||||||
|
loss = {k: v.squeeze() for k, v in loss.items()}
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
|
||||||
|
# masks: [b, query, h, w], gt_mask: list[[n, H, W]]
|
||||||
|
name_mask = f'loss_mask{postfix}'
|
||||||
|
name_dice = f'loss_dice{postfix}'
|
||||||
|
|
||||||
|
loss = {}
|
||||||
|
if sum(len(a) for a in gt_mask) == 0:
|
||||||
|
loss[name_mask] = torch.tensor(0., device=self.device)
|
||||||
|
loss[name_dice] = torch.tensor(0., device=self.device)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
num_gts = len(gt_mask)
|
||||||
|
src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
|
||||||
|
src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
|
||||||
|
# TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
|
||||||
|
loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
|
||||||
|
torch.tensor([num_gts], dtype=torch.float32))
|
||||||
|
loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def _dice_loss(self, inputs, targets, num_gts):
|
||||||
|
inputs = F.sigmoid(inputs)
|
||||||
|
inputs = inputs.flatten(1)
|
||||||
|
targets = targets.flatten(1)
|
||||||
|
numerator = 2 * (inputs * targets).sum(1)
|
||||||
|
denominator = inputs.sum(-1) + targets.sum(-1)
|
||||||
|
loss = 1 - (numerator + 1) / (denominator + 1)
|
||||||
|
return loss.sum() / num_gts
|
||||||
|
|
||||||
|
def _get_loss_aux(self,
|
||||||
|
pred_bboxes,
|
||||||
|
pred_scores,
|
||||||
|
gt_bboxes,
|
||||||
|
gt_cls,
|
||||||
|
gt_groups,
|
||||||
|
match_indices=None,
|
||||||
|
postfix='',
|
||||||
|
masks=None,
|
||||||
|
gt_mask=None):
|
||||||
|
"""Get auxiliary losses"""
|
||||||
|
# NOTE: loss class, bbox, giou, mask, dice
|
||||||
|
loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
|
||||||
|
if match_indices is None and self.use_uni_match:
|
||||||
|
match_indices = self.matcher(pred_bboxes[self.uni_match_ind],
|
||||||
|
pred_scores[self.uni_match_ind],
|
||||||
|
gt_bboxes,
|
||||||
|
gt_cls,
|
||||||
|
gt_groups,
|
||||||
|
masks=masks[self.uni_match_ind] if masks is not None else None,
|
||||||
|
gt_mask=gt_mask)
|
||||||
|
for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
|
||||||
|
aux_masks = masks[i] if masks is not None else None
|
||||||
|
loss_ = self._get_loss(aux_bboxes,
|
||||||
|
aux_scores,
|
||||||
|
gt_bboxes,
|
||||||
|
gt_cls,
|
||||||
|
gt_groups,
|
||||||
|
masks=aux_masks,
|
||||||
|
gt_mask=gt_mask,
|
||||||
|
postfix=postfix,
|
||||||
|
match_indices=match_indices)
|
||||||
|
loss[0] += loss_[f'loss_class{postfix}']
|
||||||
|
loss[1] += loss_[f'loss_bbox{postfix}']
|
||||||
|
loss[2] += loss_[f'loss_giou{postfix}']
|
||||||
|
# if masks is not None and gt_mask is not None:
|
||||||
|
# loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
|
||||||
|
# loss[3] += loss_[f'loss_mask{postfix}']
|
||||||
|
# loss[4] += loss_[f'loss_dice{postfix}']
|
||||||
|
|
||||||
|
loss = {
|
||||||
|
f'loss_class_aux{postfix}': loss[0],
|
||||||
|
f'loss_bbox_aux{postfix}': loss[1],
|
||||||
|
f'loss_giou_aux{postfix}': loss[2]}
|
||||||
|
# if masks is not None and gt_mask is not None:
|
||||||
|
# loss[f'loss_mask_aux{postfix}'] = loss[3]
|
||||||
|
# loss[f'loss_dice_aux{postfix}'] = loss[4]
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def _get_index(self, match_indices):
|
||||||
|
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
|
||||||
|
src_idx = torch.cat([src for (src, _) in match_indices])
|
||||||
|
dst_idx = torch.cat([dst for (_, dst) in match_indices])
|
||||||
|
return (batch_idx, src_idx), dst_idx
|
||||||
|
|
||||||
|
def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
|
||||||
|
pred_assigned = torch.cat([
|
||||||
|
t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
|
||||||
|
for t, (I, _) in zip(pred_bboxes, match_indices)])
|
||||||
|
gt_assigned = torch.cat([
|
||||||
|
t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
|
||||||
|
for t, (_, J) in zip(gt_bboxes, match_indices)])
|
||||||
|
return pred_assigned, gt_assigned
|
||||||
|
|
||||||
|
def _get_loss(self,
|
||||||
|
pred_bboxes,
|
||||||
|
pred_scores,
|
||||||
|
gt_bboxes,
|
||||||
|
gt_cls,
|
||||||
|
gt_groups,
|
||||||
|
masks=None,
|
||||||
|
gt_mask=None,
|
||||||
|
postfix='',
|
||||||
|
match_indices=None):
|
||||||
|
"""Get losses"""
|
||||||
|
if match_indices is None:
|
||||||
|
match_indices = self.matcher(pred_bboxes,
|
||||||
|
pred_scores,
|
||||||
|
gt_bboxes,
|
||||||
|
gt_cls,
|
||||||
|
gt_groups,
|
||||||
|
masks=masks,
|
||||||
|
gt_mask=gt_mask)
|
||||||
|
|
||||||
|
idx, gt_idx = self._get_index(match_indices)
|
||||||
|
pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
|
||||||
|
|
||||||
|
bs, nq = pred_scores.shape[:2]
|
||||||
|
targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
|
||||||
|
targets[idx] = gt_cls[gt_idx]
|
||||||
|
|
||||||
|
gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
|
||||||
|
if len(gt_bboxes):
|
||||||
|
gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)
|
||||||
|
|
||||||
|
loss = {}
|
||||||
|
loss.update(self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix))
|
||||||
|
loss.update(self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix))
|
||||||
|
# if masks is not None and gt_mask is not None:
|
||||||
|
# loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix))
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def forward(self, pred_bboxes, pred_scores, batch, postfix='', **kwargs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
pred_bboxes (torch.Tensor): [l, b, query, 4]
|
||||||
|
pred_scores (torch.Tensor): [l, b, query, num_classes]
|
||||||
|
batch (dict): A dict includes:
|
||||||
|
gt_cls (torch.Tensor) with shape [num_gts, ],
|
||||||
|
gt_bboxes (torch.Tensor): [num_gts, 4],
|
||||||
|
gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
|
||||||
|
postfix (str): postfix of loss name.
|
||||||
|
"""
|
||||||
|
self.device = pred_bboxes.device
|
||||||
|
match_indices = kwargs.get('match_indices', None)
|
||||||
|
gt_cls, gt_bboxes, gt_groups = batch['cls'], batch['bboxes'], batch['gt_groups']
|
||||||
|
|
||||||
|
total_loss = self._get_loss(pred_bboxes[-1],
|
||||||
|
pred_scores[-1],
|
||||||
|
gt_bboxes,
|
||||||
|
gt_cls,
|
||||||
|
gt_groups,
|
||||||
|
postfix=postfix,
|
||||||
|
match_indices=match_indices)
|
||||||
|
|
||||||
|
if self.aux_loss:
|
||||||
|
total_loss.update(
|
||||||
|
self._get_loss_aux(pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices,
|
||||||
|
postfix))
|
||||||
|
|
||||||
|
return total_loss
|
||||||
|
|
||||||
|
|
||||||
|
class RTDETRDetectionLoss(DETRLoss):
|
||||||
|
|
||||||
|
def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
|
||||||
|
pred_bboxes, pred_scores = preds
|
||||||
|
total_loss = super().forward(pred_bboxes, pred_scores, batch)
|
||||||
|
|
||||||
|
if dn_meta is not None:
|
||||||
|
dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group']
|
||||||
|
assert len(batch['gt_groups']) == len(dn_pos_idx)
|
||||||
|
|
||||||
|
# denoising match indices
|
||||||
|
match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups'])
|
||||||
|
|
||||||
|
# compute denoising training loss
|
||||||
|
dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices)
|
||||||
|
total_loss.update(dn_loss)
|
||||||
|
else:
|
||||||
|
total_loss.update({f'{k}_dn': torch.tensor(0., device=self.device) for k in total_loss.keys()})
|
||||||
|
|
||||||
|
return total_loss
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
|
||||||
|
"""Get the match indices for denoising.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dn_pos_idx (List[torch.Tensor]): A list includes positive indices of denoising.
|
||||||
|
dn_num_group (int): The number of groups of denoising.
|
||||||
|
gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dn_match_indices (List(tuple)): Matched indices.
|
||||||
|
|
||||||
|
"""
|
||||||
|
dn_match_indices = []
|
||||||
|
idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
|
||||||
|
for i, num_gt in enumerate(gt_groups):
|
||||||
|
if num_gt > 0:
|
||||||
|
gt_idx = torch.arange(end=num_gt, dtype=torch.int32) + idx_groups[i]
|
||||||
|
gt_idx = gt_idx.repeat(dn_num_group)
|
||||||
|
assert len(dn_pos_idx[i]) == len(gt_idx), 'Expected the same length, '
|
||||||
|
f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.'
|
||||||
|
dn_match_indices.append((dn_pos_idx[i], gt_idx))
|
||||||
|
else:
|
||||||
|
dn_match_indices.append((torch.zeros([0], dtype=torch.int32), torch.zeros([0], dtype=torch.int32)))
|
||||||
|
return dn_match_indices
|
@ -0,0 +1,230 @@
|
|||||||
|
# TODO: license
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from scipy.optimize import linear_sum_assignment
|
||||||
|
|
||||||
|
from ultralytics.yolo.utils.metrics import bbox_iou
|
||||||
|
from ultralytics.yolo.utils.ops import xywh2xyxy, xyxy2xywh
|
||||||
|
|
||||||
|
|
||||||
|
class HungarianMatcher(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
matcher_coeff (dict): The coefficient of hungarian matcher cost.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
if cost_gain is None:
|
||||||
|
cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1}
|
||||||
|
self.cost_gain = cost_gain
|
||||||
|
self.use_fl = use_fl
|
||||||
|
self.with_mask = with_mask
|
||||||
|
self.num_sample_points = num_sample_points
|
||||||
|
self.alpha = alpha
|
||||||
|
self.gamma = gamma
|
||||||
|
|
||||||
|
def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
pred_bboxes (Tensor): [b, query, 4]
|
||||||
|
pred_scores (Tensor): [b, query, num_classes]
|
||||||
|
gt_cls (torch.Tensor) with shape [num_gts, ]
|
||||||
|
gt_bboxes (torch.Tensor): [num_gts, 4]
|
||||||
|
gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
|
||||||
|
masks (Tensor|None): [b, query, h, w]
|
||||||
|
gt_mask (List(Tensor)): list[[n, H, W]]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of size batch_size, containing tuples of (index_i, index_j) where:
|
||||||
|
- index_i is the indices of the selected predictions (in order)
|
||||||
|
- index_j is the indices of the corresponding selected targets (in order)
|
||||||
|
For each batch element, it holds:
|
||||||
|
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
||||||
|
"""
|
||||||
|
bs, nq, nc = pred_scores.shape
|
||||||
|
|
||||||
|
if sum(gt_groups) == 0:
|
||||||
|
return [(torch.tensor([], dtype=torch.int32), torch.tensor([], dtype=torch.int32)) for _ in range(bs)]
|
||||||
|
|
||||||
|
# We flatten to compute the cost matrices in a batch
|
||||||
|
# [batch_size * num_queries, num_classes]
|
||||||
|
pred_scores = pred_scores.detach().view(-1, nc)
|
||||||
|
pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
|
||||||
|
# [batch_size * num_queries, 4]
|
||||||
|
pred_bboxes = pred_bboxes.detach().view(-1, 4)
|
||||||
|
|
||||||
|
# Compute the classification cost
|
||||||
|
pred_scores = pred_scores[:, gt_cls]
|
||||||
|
if self.use_fl:
|
||||||
|
neg_cost_class = (1 - self.alpha) * (pred_scores ** self.gamma) * (-(1 - pred_scores + 1e-8).log())
|
||||||
|
pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
|
||||||
|
cost_class = pos_cost_class - neg_cost_class
|
||||||
|
else:
|
||||||
|
cost_class = -pred_scores
|
||||||
|
|
||||||
|
# Compute the L1 cost between boxes
|
||||||
|
cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt)
|
||||||
|
|
||||||
|
# Compute the GIoU cost between boxes, (bs*num_queries, num_gt)
|
||||||
|
cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
|
||||||
|
|
||||||
|
# Final cost matrix
|
||||||
|
C = self.cost_gain['class'] * cost_class + \
|
||||||
|
self.cost_gain['bbox'] * cost_bbox + \
|
||||||
|
self.cost_gain['giou'] * cost_giou
|
||||||
|
# Compute the mask cost and dice cost
|
||||||
|
if self.with_mask:
|
||||||
|
C += self._cost_mask(bs, gt_groups, masks, gt_mask)
|
||||||
|
|
||||||
|
C = C.view(bs, nq, -1).cpu()
|
||||||
|
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
|
||||||
|
gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
|
||||||
|
# (idx for queries, idx for gt)
|
||||||
|
return [(torch.tensor(i, dtype=torch.int32), torch.tensor(j, dtype=torch.int32) + gt_groups[k])
|
||||||
|
for k, (i, j) in enumerate(indices)]
|
||||||
|
|
||||||
|
def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
|
||||||
|
assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
|
||||||
|
# all masks share the same set of points for efficient matching
|
||||||
|
sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
|
||||||
|
sample_points = 2.0 * sample_points - 1.0
|
||||||
|
|
||||||
|
out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
|
||||||
|
out_mask = out_mask.flatten(0, 1)
|
||||||
|
|
||||||
|
tgt_mask = torch.cat(gt_mask).unsqueeze(1)
|
||||||
|
sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
|
||||||
|
tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
|
||||||
|
|
||||||
|
with torch.cuda.amp.autocast(False):
|
||||||
|
# binary cross entropy cost
|
||||||
|
pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
|
||||||
|
neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
|
||||||
|
cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
|
||||||
|
cost_mask /= self.num_sample_points
|
||||||
|
|
||||||
|
# dice cost
|
||||||
|
out_mask = F.sigmoid(out_mask)
|
||||||
|
numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
|
||||||
|
denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
|
||||||
|
cost_dice = 1 - (numerator + 1) / (denominator + 1)
|
||||||
|
|
||||||
|
C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
|
||||||
|
return C
|
||||||
|
|
||||||
|
|
||||||
|
def get_cdn_group(batch,
|
||||||
|
num_classes,
|
||||||
|
num_queries,
|
||||||
|
class_embed,
|
||||||
|
num_dn=100,
|
||||||
|
cls_noise_ratio=0.5,
|
||||||
|
box_noise_scale=1.0,
|
||||||
|
training=False):
|
||||||
|
"""Get contrastive denoising training group
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (dict): A dict includes:
|
||||||
|
gt_cls (torch.Tensor) with shape [num_gts, ],
|
||||||
|
gt_bboxes (torch.Tensor): [num_gts, 4],
|
||||||
|
gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
|
||||||
|
num_classes (int): Number of classes.
|
||||||
|
num_queries (int): Number of queries.
|
||||||
|
class_embed (torch.Tensor): Embedding weights to map cls to embedding space.
|
||||||
|
num_dn (int): Number of denoising.
|
||||||
|
cls_noise_ratio (float): Noise ratio for class.
|
||||||
|
box_noise_scale (float): Noise scale for bbox.
|
||||||
|
training (bool): If it's training or not.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
if (not training) or num_dn <= 0:
|
||||||
|
return None, None, None, None
|
||||||
|
gt_groups = batch['gt_groups']
|
||||||
|
total_num = sum(gt_groups)
|
||||||
|
max_nums = max(gt_groups)
|
||||||
|
if max_nums == 0:
|
||||||
|
return None, None, None, None
|
||||||
|
|
||||||
|
num_group = num_dn // max_nums
|
||||||
|
num_group = 1 if num_group == 0 else num_group
|
||||||
|
# pad gt to max_num of a batch
|
||||||
|
bs = len(gt_groups)
|
||||||
|
gt_cls = batch['cls'] # (bs*num, )
|
||||||
|
gt_bbox = batch['bboxes'] # bs*num, 4
|
||||||
|
b_idx = batch['batch_idx']
|
||||||
|
|
||||||
|
# each group has positive and negative queries.
|
||||||
|
dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )
|
||||||
|
dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4
|
||||||
|
dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, )
|
||||||
|
|
||||||
|
# positive and negative mask
|
||||||
|
# (bs*num*num_group, ), the second total_num*num_group part as negative samples
|
||||||
|
neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
|
||||||
|
|
||||||
|
if cls_noise_ratio > 0:
|
||||||
|
# half of bbox prob
|
||||||
|
mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
|
||||||
|
idx = torch.nonzero(mask).squeeze(-1)
|
||||||
|
# randomly put a new one here
|
||||||
|
new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
|
||||||
|
dn_cls[idx] = new_label
|
||||||
|
|
||||||
|
if box_noise_scale > 0:
|
||||||
|
known_bbox = xywh2xyxy(dn_bbox)
|
||||||
|
|
||||||
|
diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale # 2*num_group*bs*num, 4
|
||||||
|
|
||||||
|
rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
|
||||||
|
rand_part = torch.rand_like(dn_bbox)
|
||||||
|
rand_part[neg_idx] += 1.0
|
||||||
|
rand_part *= rand_sign
|
||||||
|
known_bbox += rand_part * diff
|
||||||
|
known_bbox.clip_(min=0.0, max=1.0)
|
||||||
|
dn_bbox = xyxy2xywh(known_bbox)
|
||||||
|
dn_bbox = inverse_sigmoid(dn_bbox)
|
||||||
|
|
||||||
|
# total denoising queries
|
||||||
|
num_dn = int(max_nums * 2 * num_group)
|
||||||
|
# class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
|
||||||
|
dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256
|
||||||
|
padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
|
||||||
|
padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)
|
||||||
|
|
||||||
|
map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
|
||||||
|
pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)
|
||||||
|
|
||||||
|
map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
|
||||||
|
padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
|
||||||
|
padding_bbox[(dn_b_idx, map_indices)] = dn_bbox
|
||||||
|
|
||||||
|
tgt_size = num_dn + num_queries
|
||||||
|
attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
|
||||||
|
# match query cannot see the reconstruct
|
||||||
|
attn_mask[num_dn:, :num_dn] = True
|
||||||
|
# reconstruct cannot see each other
|
||||||
|
for i in range(num_group):
|
||||||
|
if i == 0:
|
||||||
|
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
|
||||||
|
if i == num_group - 1:
|
||||||
|
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * i * 2] = True
|
||||||
|
else:
|
||||||
|
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
|
||||||
|
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * 2 * i] = True
|
||||||
|
dn_meta = {
|
||||||
|
'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split([n for n in gt_groups], dim=1)],
|
||||||
|
'dn_num_group': num_group,
|
||||||
|
'dn_num_split': [num_dn, num_queries]}
|
||||||
|
|
||||||
|
return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to(
|
||||||
|
class_embed.device), dn_meta
|
||||||
|
|
||||||
|
|
||||||
|
def inverse_sigmoid(x, eps=1e-6):
|
||||||
|
x = x.clip(min=0., max=1.)
|
||||||
|
return torch.log(x / (1 - x + eps) + eps)
|
Loading…
Reference in new issue