Revert loss head PR (#2873)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Ayush Chaurasia
2023-05-28 19:45:41 +05:30
committed by GitHub
parent 6391c60089
commit 527a97759b
8 changed files with 335 additions and 327 deletions

View File

@ -13,7 +13,6 @@ from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottlenec
RTDETRDecoder, Segment)
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
from ultralytics.yolo.utils.plotting import feature_visualization
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
intersect_dicts, make_divisible, model_info, scale_img, time_sync)
@ -176,23 +175,6 @@ class BaseModel(nn.Module):
if verbose:
LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
def loss(self, batch, preds=None):
"""
Compute loss
Args:
batch (dict): Batch to compute loss on
pred (torch.Tensor | List[torch.Tensor]): Predictions.
"""
if not hasattr(self, 'criterion'):
self.criterion = self.init_criterion()
preds = self.forward(batch['img']) if preds is None else preds
return self.criterion(preds, batch)
def init_criterion(self):
raise NotImplementedError('compute_loss() needs to be implemented by task heads')
class DetectionModel(BaseModel):
"""YOLOv8 detection model."""
@ -269,9 +251,6 @@ class DetectionModel(BaseModel):
y[-1] = y[-1][..., i:] # small
return y
def init_criterion(self):
return v8DetectionLoss(self)
class SegmentationModel(DetectionModel):
"""YOLOv8 segmentation model."""
@ -284,9 +263,6 @@ class SegmentationModel(DetectionModel):
"""Undocumented function."""
raise NotImplementedError(emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!'))
def init_criterion(self):
return v8SegmentationLoss(self)
class PoseModel(DetectionModel):
"""YOLOv8 pose model."""
@ -300,9 +276,6 @@ class PoseModel(DetectionModel):
cfg['kpt_shape'] = data_kpt_shape
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
def init_criterion(self):
return v8PoseLoss(self)
class ClassificationModel(BaseModel):
"""YOLOv8 classification model."""
@ -370,10 +343,6 @@ class ClassificationModel(BaseModel):
if m[i].out_channels != nc:
m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
def init_criterion(self):
"""Compute the classification loss between predictions and true labels."""
return v8ClassificationLoss()
class Ensemble(nn.ModuleList):
"""Ensemble of models."""