Move loss calculation to head (#2874)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Ayush Chaurasia
2023-05-30 22:54:30 +05:30
committed by GitHub
parent 7f077f7654
commit facb7861cf
9 changed files with 417 additions and 348 deletions


@@ -13,6 +13,7 @@ from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottlenec
RTDETRDecoder, Segment)
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
+from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
from ultralytics.yolo.utils.plotting import feature_visualization
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
intersect_dicts, make_divisible, model_info, scale_img, time_sync)
@@ -28,29 +29,30 @@ class BaseModel(nn.Module):
The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family.
"""
-    def forward(self, x, profile=False, visualize=False):
+    def forward(self, x, *args, **kwargs):
         """
         Forward pass of the model on a single scale.
         Wrapper for `_forward_once` method.

         Args:
-            x (torch.Tensor): The input image tensor
-            profile (bool): Whether to profile the model, defaults to False
-            visualize (bool): Whether to return the intermediate feature maps, defaults to False
+            x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels.

         Returns:
             (torch.Tensor): The output of the network.
         """
-        return self._forward_once(x, profile, visualize)
+        if isinstance(x, dict):  # for cases of training and validating while training.
+            return self.loss(x, *args, **kwargs)
+        return self.predict(x, *args, **kwargs)
-    def _forward_once(self, x, profile=False, visualize=False):
+    def predict(self, x, profile=False, visualize=False, augment=False):
         """
         Perform a forward pass through the network.

         Args:
-            x (torch.Tensor): The input tensor to the model
+            x (torch.Tensor): The input tensor to the model.
             profile (bool): Print the computation time of each layer if True, defaults to False.
-            visualize (bool): Save the feature maps of the model if True, defaults to False
+            visualize (bool): Save the feature maps of the model if True, defaults to False.
+            augment (bool): Augment image during prediction, defaults to False.

         Returns:
             (torch.Tensor): The last output of the model.
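
For illustration, a minimal sketch of the new calling convention (the model config, tensor shapes, and empty-label batch below are assumptions for the example, not part of this diff): a plain image tensor is dispatched to predict(), while a batch dict of the kind produced by the Ultralytics dataloaders is dispatched to loss().

import torch
from ultralytics.nn.tasks import DetectionModel

model = DetectionModel('yolov8n.yaml')  # any detection config would do here

# Inference: a plain tensor is dispatched to predict()
imgs = torch.zeros(2, 3, 640, 640)
preds = model(imgs)

# Training: a batch dict (keys as used elsewhere in this file) is dispatched to loss()
batch = {
    'img': imgs,
    'batch_idx': torch.zeros(0),   # image index of each box
    'cls': torch.zeros(0, 1),      # class label of each box
    'bboxes': torch.zeros(0, 4)}   # normalized xywh boxes
total_loss, loss_items = model(batch)  # assuming the criterion returns a (total, items) pair
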
@@ -175,6 +177,21 @@ class BaseModel(nn.Module):
if verbose:
LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
+    def loss(self, batch, preds=None):
+        """
+        Compute loss.
+
+        Args:
+            batch (dict): Batch to compute loss on.
+            preds (torch.Tensor | List[torch.Tensor]): Predictions.
+        """
+        if not hasattr(self, 'criterion'):
+            self.criterion = self.init_criterion()
+        return self.criterion(self.predict(batch['img']) if preds is None else preds, batch)
+
+    def init_criterion(self):
+        raise NotImplementedError('init_criterion() needs to be implemented by task heads')
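
The two methods above form a lazy template-method pattern: loss() builds and caches self.criterion on first use, and each task model overrides init_criterion() to supply its task-specific loss, as the subclasses below do. A hedged sketch of how a hypothetical new task would plug in (MyTaskModel and MyTaskLoss are illustrative names only, not part of this diff):

from ultralytics.nn.tasks import BaseModel

class MyTaskLoss:
    """Placeholder criterion: any callable taking (preds, batch) and returning a loss works."""

    def __init__(self, model):
        self.model = model  # criteria in this file take the model to read strides, nc, etc.

    def __call__(self, preds, batch):
        ...  # task-specific loss computation goes here

class MyTaskModel(BaseModel):
    def init_criterion(self):
        return MyTaskLoss(self)  # built once, then cached by BaseModel.loss()
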
class DetectionModel(BaseModel):
"""YOLOv8 detection model."""
@@ -208,11 +225,11 @@ class DetectionModel(BaseModel):
self.info()
LOGGER.info('')
-    def forward(self, x, augment=False, profile=False, visualize=False):
+    def predict(self, x, augment=False, profile=False, visualize=False):
         """Run forward pass on input image(s) with optional augmentation and profiling."""
         if augment:
             return self._forward_augment(x)  # augmented inference, None
-        return self._forward_once(x, profile, visualize)  # single-scale inference, train
+        return super().predict(x, profile=profile, visualize=visualize)  # single-scale inference, train
def _forward_augment(self, x):
"""Perform augmentations on input image x and return augmented inference and train outputs."""
@@ -222,7 +239,7 @@ class DetectionModel(BaseModel):
y = [] # outputs
for si, fi in zip(s, f):
xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
-            yi = self._forward_once(xi)[0]  # forward
+            yi = super().predict(xi)[0]  # forward
# cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
yi = self._descale_pred(yi, fi, si, img_size)
y.append(yi)
@@ -251,6 +268,9 @@ class DetectionModel(BaseModel):
y[-1] = y[-1][..., i:] # small
return y
+    def init_criterion(self):
+        """Initialize the loss criterion for the DetectionModel."""
+        return v8DetectionLoss(self)
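
Note that loss() also accepts precomputed predictions, so a caller that has already run a forward pass (e.g. a validator running during training) can reuse those outputs instead of paying for a second pass. A hedged sketch, reusing the hypothetical model and batch from the example above (and assuming the same (total, items) return pair):

preds = model.predict(batch['img'])                # forward pass already done elsewhere
total_loss, loss_items = model.loss(batch, preds)  # reuses preds; predict() is not called again
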
class SegmentationModel(DetectionModel):
"""YOLOv8 segmentation model."""
@@ -263,6 +283,9 @@ class SegmentationModel(DetectionModel):
"""Undocumented function."""
raise NotImplementedError(emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!'))
def init_criterion(self):
return v8SegmentationLoss(self)
class PoseModel(DetectionModel):
"""YOLOv8 pose model."""
@@ -276,6 +299,13 @@ class PoseModel(DetectionModel):
cfg['kpt_shape'] = data_kpt_shape
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
+    def init_criterion(self):
+        """Initialize the loss criterion for the PoseModel."""
+        return v8PoseLoss(self)
+
+    def _forward_augment(self, x):
+        """Augmented inference is not yet supported for PoseModel."""
+        raise NotImplementedError(emojis('WARNING ⚠️ PoseModel has not supported augment inference yet!'))
class ClassificationModel(BaseModel):
"""YOLOv8 classification model."""
@@ -343,6 +373,85 @@ class ClassificationModel(BaseModel):
if m[i].out_channels != nc:
m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
+    def init_criterion(self):
+        """Initialize the loss criterion for the ClassificationModel."""
+        return v8ClassificationLoss()
+class RTDETRDetectionModel(DetectionModel):
+    """RT-DETR detection model."""
+
+    def __init__(self, cfg='rtdetr-l.yaml', ch=3, nc=None, verbose=True):
+        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
+
+    def init_criterion(self):
+        """Initialize the loss criterion for the RTDETRDetectionModel."""
+        from ultralytics.vit.utils.loss import RTDETRDetectionLoss
+
+        return RTDETRDetectionLoss(num_classes=self.nc, use_vfl=True)
+    def loss(self, batch, preds=None):
+        """Compute the loss for the given batch of data."""
+        if not hasattr(self, 'criterion'):
+            self.criterion = self.init_criterion()
+
+        img = batch['img']
+        # NOTE: preprocess gt_bbox and gt_labels to list.
+        bs = len(img)
+        batch_idx = batch['batch_idx']
+        gt_bbox, gt_class = [], []
+        for i in range(bs):
+            gt_bbox.append(batch['bboxes'][batch_idx == i].to(img.device))
+            gt_class.append(batch['cls'][batch_idx == i].to(device=img.device, dtype=torch.long))
+        targets = {'cls': gt_class, 'bboxes': gt_bbox}
+
+        preds = self.predict(img, batch=targets) if preds is None else preds
+        dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta = preds
+        # NOTE: dn_meta is None in eval mode; loss calculation for eval mode is not supported.
+        if dn_meta is None:
+            return 0, torch.zeros(3, device=dec_out_bboxes.device)
+        dn_out_bboxes, dec_out_bboxes = torch.split(dec_out_bboxes, dn_meta['dn_num_split'], dim=2)
+        dn_out_logits, dec_out_logits = torch.split(dec_out_logits, dn_meta['dn_num_split'], dim=2)
+
+        out_bboxes = torch.cat([enc_topk_bboxes.unsqueeze(0), dec_out_bboxes])
+        out_logits = torch.cat([enc_topk_logits.unsqueeze(0), dec_out_logits])
+
+        loss = self.criterion((out_bboxes, out_logits),
+                              targets,
+                              dn_out_bboxes=dn_out_bboxes,
+                              dn_out_logits=dn_out_logits,
+                              dn_meta=dn_meta)
+        return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']])
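
For illustration of the split above (layer, batch, and query counts here are assumptions, not values from this diff): the decoder stacks denoising (DN) queries and matching queries along dim 2, and dn_meta['dn_num_split'] records the two group sizes, so torch.split recovers the two branches:

import torch

# Assume 6 decoder layers, batch size 2, 198 DN queries + 300 matching queries, 4 box coords
dec_out_bboxes = torch.rand(6, 2, 498, 4)
dn_meta = {'dn_num_split': [198, 300]}

dn_bboxes, match_bboxes = torch.split(dec_out_bboxes, dn_meta['dn_num_split'], dim=2)
print(dn_bboxes.shape)     # torch.Size([6, 2, 198, 4]) -> denoising branch
print(match_bboxes.shape)  # torch.Size([6, 2, 300, 4]) -> matching branch
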
+    def predict(self, x, profile=False, visualize=False, batch=None):
+        """
+        Perform a forward pass through the network.
+
+        Args:
+            x (torch.Tensor): The input tensor to the model.
+            profile (bool): Print the computation time of each layer if True, defaults to False.
+            visualize (bool): Save the feature maps of the model if True, defaults to False.
+            batch (dict): A dict including gt boxes and labels from dataloader.
+
+        Returns:
+            (torch.Tensor): The last output of the model.
+        """
+        y, dt = [], []  # outputs
+        for m in self.model[:-1]:  # except the head part
+            if m.f != -1:  # if not from previous layer
+                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
+            if profile:
+                self._profile_one_layer(m, x, dt)
+            x = m(x)  # run
+            y.append(x if m.i in self.save else None)  # save output
+            if visualize:
+                feature_visualization(x, m.type, m.i, save_dir=visualize)
+        head = self.model[-1]
+        x = head([y[j] for j in head.f], batch)  # head inference
+        return x
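
Why this predict() override takes a batch argument: the RTDETRDecoder head consumes the GT dict during training to build its denoising query groups, so the forward pass itself is label-dependent in training; at eval time batch is None and the head's dn_meta output is None, which loss() above uses to skip loss computation. A hedged, self-contained sketch of that contract using a stub head (the real RTDETRDecoder signature is assumed to match the call in predict() above):

from typing import Optional

def stub_head(feats: list, batch: Optional[dict]):
    """Stands in for RTDETRDecoder: 5-tuple output, dn_meta present only when GT is given."""
    dn_meta = {'dn_num_split': [198, 300]} if batch is not None else None
    return ('dec_bboxes', 'dec_logits', 'enc_bboxes', 'enc_logits', dn_meta)

train_out = stub_head(['f1', 'f2', 'f3'], {'cls': [], 'bboxes': []})  # training call
eval_out = stub_head(['f1', 'f2', 'f3'], None)                        # eval call
assert eval_out[-1] is None  # loss() checks exactly this to bail out in eval mode
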
+    def _forward_augment(self, x):
+        """Augmented inference is not yet supported for RT-DETR models."""
+        raise NotImplementedError(emojis('WARNING ⚠️ RTDETRModel has not supported augment inference yet!'))
class Ensemble(nn.ModuleList):
"""Ensemble of models."""