Revert loss head PR (#2873)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
@ -13,7 +13,6 @@ from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottlenec
|
||||
RTDETRDecoder, Segment)
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
|
||||
from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
|
||||
from ultralytics.yolo.utils.plotting import feature_visualization
|
||||
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
|
||||
intersect_dicts, make_divisible, model_info, scale_img, time_sync)
|
||||
@ -176,23 +175,6 @@ class BaseModel(nn.Module):
|
||||
if verbose:
|
||||
LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
|
||||
|
||||
def loss(self, batch, preds=None):
|
||||
"""
|
||||
Compute loss
|
||||
|
||||
Args:
|
||||
batch (dict): Batch to compute loss on
|
||||
pred (torch.Tensor | List[torch.Tensor]): Predictions.
|
||||
"""
|
||||
if not hasattr(self, 'criterion'):
|
||||
self.criterion = self.init_criterion()
|
||||
|
||||
preds = self.forward(batch['img']) if preds is None else preds
|
||||
return self.criterion(preds, batch)
|
||||
|
||||
def init_criterion(self):
|
||||
raise NotImplementedError('compute_loss() needs to be implemented by task heads')
|
||||
|
||||
|
||||
class DetectionModel(BaseModel):
|
||||
"""YOLOv8 detection model."""
|
||||
@ -269,9 +251,6 @@ class DetectionModel(BaseModel):
|
||||
y[-1] = y[-1][..., i:] # small
|
||||
return y
|
||||
|
||||
def init_criterion(self):
|
||||
return v8DetectionLoss(self)
|
||||
|
||||
|
||||
class SegmentationModel(DetectionModel):
|
||||
"""YOLOv8 segmentation model."""
|
||||
@ -284,9 +263,6 @@ class SegmentationModel(DetectionModel):
|
||||
"""Undocumented function."""
|
||||
raise NotImplementedError(emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!'))
|
||||
|
||||
def init_criterion(self):
|
||||
return v8SegmentationLoss(self)
|
||||
|
||||
|
||||
class PoseModel(DetectionModel):
|
||||
"""YOLOv8 pose model."""
|
||||
@ -300,9 +276,6 @@ class PoseModel(DetectionModel):
|
||||
cfg['kpt_shape'] = data_kpt_shape
|
||||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||
|
||||
def init_criterion(self):
|
||||
return v8PoseLoss(self)
|
||||
|
||||
|
||||
class ClassificationModel(BaseModel):
|
||||
"""YOLOv8 classification model."""
|
||||
@ -370,10 +343,6 @@ class ClassificationModel(BaseModel):
|
||||
if m[i].out_channels != nc:
|
||||
m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
|
||||
|
||||
def init_criterion(self):
|
||||
"""Compute the classification loss between predictions and true labels."""
|
||||
return v8ClassificationLoss()
|
||||
|
||||
|
||||
class Ensemble(nn.ModuleList):
|
||||
"""Ensemble of models."""
|
||||
|
Reference in New Issue
Block a user