Move loss to task heads (#2825)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
Ayush Chaurasia
2023-05-25 16:07:54 +05:30
committed by GitHub
parent f23a03596d
commit d19c5b6ce8
9 changed files with 335 additions and 346 deletions

View File

@ -13,6 +13,7 @@ from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottlenec
Segment)
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
from ultralytics.yolo.utils.plotting import feature_visualization
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
intersect_dicts, make_divisible, model_info, scale_img, time_sync)
@ -173,6 +174,23 @@ class BaseModel(nn.Module):
if verbose:
LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
def loss(self, batch, preds=None):
"""
Compute loss
Args:
batch (dict): Batch to compute loss on
pred (torch.Tensor | List[torch.Tensor]): Predictions.
"""
if not hasattr(self, 'criterion'):
self.criterion = self.init_criterion()
preds = self.forward(batch['img']) if preds is None else preds
return self.criterion(preds, batch)
def init_criterion(self):
raise NotImplementedError('compute_loss() needs to be implemented by task heads')
class DetectionModel(BaseModel):
"""YOLOv8 detection model."""
@ -249,6 +267,9 @@ class DetectionModel(BaseModel):
y[-1] = y[-1][..., i:] # small
return y
def init_criterion(self):
return v8DetectionLoss(self)
class SegmentationModel(DetectionModel):
"""YOLOv8 segmentation model."""
@ -261,6 +282,9 @@ class SegmentationModel(DetectionModel):
"""Undocumented function."""
raise NotImplementedError(emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!'))
def init_criterion(self):
return v8SegmentationLoss(self)
class PoseModel(DetectionModel):
"""YOLOv8 pose model."""
@ -274,6 +298,9 @@ class PoseModel(DetectionModel):
cfg['kpt_shape'] = data_kpt_shape
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
def init_criterion(self):
return v8PoseLoss(self)
class ClassificationModel(BaseModel):
"""YOLOv8 classification model."""
@ -341,6 +368,10 @@ class ClassificationModel(BaseModel):
if m[i].out_channels != nc:
m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
def init_criterion(self):
"""Compute the classification loss between predictions and true labels."""
return v8ClassificationLoss()
class Ensemble(nn.ModuleList):
"""Ensemble of models."""