Move loss to task heads (#2825)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent f23a03596d
commit d19c5b6ce8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -48,25 +48,22 @@ trainer.train()
You now realize that you need to customize the trainer further to: You now realize that you need to customize the trainer further to:
* Customize the `loss function`. * * Customize the `loss function`.
* Add `callback` that uploads model to your Google Drive after every 10 `epochs` * Add `callback` that uploads model to your Google Drive after every 10 `epochs`
Here's how you can do it: Here's how you can do it:
```python ```python
from ultralytics.yolo.v8.detect import DetectionTrainer from ultralytics.yolo.v8.detect import DetectionTrainer
from ultralytcs.nn.tasks import DetectionModel
class MyCustomModel(DetectionModel):
class CustomTrainer(DetectionTrainer): def init_criterion():
def get_model(self, cfg, weights):
... ...
def criterion(self, preds, batch):
# get ground truth
imgs = batch["imgs"]
bboxes = batch["bboxes"]
...
return loss, loss_items # see Reference-> Trainer for details on the expected format
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
return MyCustomModel(...)
# callback to upload model weights # callback to upload model weights
def log_model(trainer): def log_model(trainer):

@ -13,6 +13,7 @@ from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottlenec
Segment) Segment)
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
from ultralytics.yolo.utils.plotting import feature_visualization from ultralytics.yolo.utils.plotting import feature_visualization
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
intersect_dicts, make_divisible, model_info, scale_img, time_sync) intersect_dicts, make_divisible, model_info, scale_img, time_sync)
@ -173,6 +174,23 @@ class BaseModel(nn.Module):
if verbose: if verbose:
LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights') LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
def loss(self, batch, preds=None):
"""
Compute loss
Args:
batch (dict): Batch to compute loss on
pred (torch.Tensor | List[torch.Tensor]): Predictions.
"""
if not hasattr(self, 'criterion'):
self.criterion = self.init_criterion()
preds = self.forward(batch['img']) if preds is None else preds
return self.criterion(preds, batch)
def init_criterion(self):
raise NotImplementedError('compute_loss() needs to be implemented by task heads')
class DetectionModel(BaseModel): class DetectionModel(BaseModel):
"""YOLOv8 detection model.""" """YOLOv8 detection model."""
@ -249,6 +267,9 @@ class DetectionModel(BaseModel):
y[-1] = y[-1][..., i:] # small y[-1] = y[-1][..., i:] # small
return y return y
def init_criterion(self):
return v8DetectionLoss(self)
class SegmentationModel(DetectionModel): class SegmentationModel(DetectionModel):
"""YOLOv8 segmentation model.""" """YOLOv8 segmentation model."""
@ -261,6 +282,9 @@ class SegmentationModel(DetectionModel):
"""Undocumented function.""" """Undocumented function."""
raise NotImplementedError(emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!')) raise NotImplementedError(emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!'))
def init_criterion(self):
return v8SegmentationLoss(self)
class PoseModel(DetectionModel): class PoseModel(DetectionModel):
"""YOLOv8 pose model.""" """YOLOv8 pose model."""
@ -274,6 +298,9 @@ class PoseModel(DetectionModel):
cfg['kpt_shape'] = data_kpt_shape cfg['kpt_shape'] = data_kpt_shape
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
def init_criterion(self):
return v8PoseLoss(self)
class ClassificationModel(BaseModel): class ClassificationModel(BaseModel):
"""YOLOv8 classification model.""" """YOLOv8 classification model."""
@ -341,6 +368,10 @@ class ClassificationModel(BaseModel):
if m[i].out_channels != nc: if m[i].out_channels != nc:
m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None) m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
def init_criterion(self):
"""Compute the classification loss between predictions and true labels."""
return v8ClassificationLoss()
class Ensemble(nn.ModuleList): class Ensemble(nn.ModuleList):
"""Ensemble of models.""" """Ensemble of models."""

@ -325,8 +325,7 @@ class BaseTrainer:
# Forward # Forward
with torch.cuda.amp.autocast(self.amp): with torch.cuda.amp.autocast(self.amp):
batch = self.preprocess_batch(batch) batch = self.preprocess_batch(batch)
preds = self.model(batch['img']) self.loss, self.loss_items = de_parallel(self.model).loss(batch)
self.loss, self.loss_items = self.criterion(preds, batch)
if RANK != -1: if RANK != -1:
self.loss *= world_size self.loss *= world_size
self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \ self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
@ -496,12 +495,6 @@ class BaseTrainer:
"""Build dataset""" """Build dataset"""
raise NotImplementedError('build_dataset function not implemented in trainer') raise NotImplementedError('build_dataset function not implemented in trainer')
def criterion(self, preds, batch):
"""
Returns loss and individual loss items as Tensor.
"""
raise NotImplementedError('criterion function not implemented in trainer')
def label_loss_items(self, loss_items=None, prefix='train'): def label_loss_items(self, loss_items=None, prefix='train'):
""" """
Returns a loss dict with labelled training loss items tensor Returns a loss dict with labelled training loss items tensor

@ -162,7 +162,8 @@ class BaseValidator:
# Loss # Loss
with dt[2]: with dt[2]:
if self.training: if self.training:
self.loss += trainer.criterion(preds, batch)[1] loss_items = model.loss(batch, preds)
self.loss += loss_items[1]
# Postprocess # Postprocess
with dt[3]: with dt[3]:

@ -4,6 +4,10 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from ultralytics.yolo.utils.metrics import OKS_SIGMA
from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
from ultralytics.yolo.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors
from .metrics import bbox_iou from .metrics import bbox_iou
from .tal import bbox2dist from .tal import bbox2dist
@ -73,3 +77,292 @@ class KeypointLoss(nn.Module):
# e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula # e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula
e = d / (2 * self.sigmas) ** 2 / (area + 1e-9) / 2 # from cocoeval e = d / (2 * self.sigmas) ** 2 / (area + 1e-9) / 2 # from cocoeval
return kpt_loss_factor * ((1 - torch.exp(-e)) * kpt_mask).mean() return kpt_loss_factor * ((1 - torch.exp(-e)) * kpt_mask).mean()
# Criterion class for computing Detection training losses
class v8DetectionLoss:
def __init__(self, model): # model must be de-paralleled
device = next(model.parameters()).device # get model device
h = model.args # hyperparameters
m = model.model[-1] # Detect() module
self.bce = nn.BCEWithLogitsLoss(reduction='none')
self.hyp = h
self.stride = m.stride # model strides
self.nc = m.nc # number of classes
self.no = m.no
self.reg_max = m.reg_max
self.device = device
self.use_dfl = m.reg_max > 1
self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
def preprocess(self, targets, batch_size, scale_tensor):
"""Preprocesses the target counts and matches with the input batch size to output a tensor."""
if targets.shape[0] == 0:
out = torch.zeros(batch_size, 0, 5, device=self.device)
else:
i = targets[:, 0] # image index
_, counts = i.unique(return_counts=True)
counts = counts.to(dtype=torch.int32)
out = torch.zeros(batch_size, counts.max(), 5, device=self.device)
for j in range(batch_size):
matches = i == j
n = matches.sum()
if n:
out[j, :n] = targets[matches, 1:]
out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
return out
def bbox_decode(self, anchor_points, pred_dist):
"""Decode predicted object bounding box coordinates from anchor points and distribution."""
if self.use_dfl:
b, a, c = pred_dist.shape # batch, anchors, channels
pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
# pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
# pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
return dist2bbox(pred_dist, anchor_points, xywh=False)
def __call__(self, preds, batch):
"""Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
loss = torch.zeros(3, device=self.device) # box, cls, dfl
feats = preds[1] if isinstance(preds, tuple) else preds
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1)
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
dtype = pred_scores.dtype
batch_size = pred_scores.shape[0]
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# targets
targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
# pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
target_scores_sum = max(target_scores.sum(), 1)
# cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
# bbox loss
if fg_mask.sum():
target_bboxes /= stride_tensor
loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
target_scores_sum, fg_mask)
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.cls # cls gain
loss[2] *= self.hyp.dfl # dfl gain
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
# Criterion class for computing training losses
class v8SegmentationLoss(v8DetectionLoss):
def __init__(self, model, overlap=True): # model must be de-paralleled
super().__init__(model)
self.nm = model.model[-1].nm # number of masks
self.overlap = overlap
def __call__(self, preds, batch):
"""Calculate and return the loss for the YOLO model."""
loss = torch.zeros(4, device=self.device) # box, cls, dfl
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1)
# b, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
pred_masks = pred_masks.permute(0, 2, 1).contiguous()
dtype = pred_scores.dtype
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# targets
try:
batch_idx = batch['batch_idx'].view(-1, 1)
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
except RuntimeError as e:
raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n'
"This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
"i.e. 'yolo train model=yolov8n-seg.pt data=coco128.yaml'.\nVerify your dataset is a "
"correctly formatted 'segment' dataset using 'data=coco128-seg.yaml' "
'as an example.\nSee https://docs.ultralytics.com/tasks/segment/ for help.') from e
# pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
target_scores_sum = max(target_scores.sum(), 1)
# cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
if fg_mask.sum():
# bbox loss
loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
target_scores, target_scores_sum, fg_mask)
# masks loss
masks = batch['masks'].to(self.device).float()
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
for i in range(batch_size):
if fg_mask[i].sum():
mask_idx = target_gt_idx[i][fg_mask[i]]
if self.overlap:
gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0)
else:
gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]]
marea = xyxy2xywh(xyxyn)[:, 2:].prod(1)
mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)
loss[1] += self.single_mask_loss(gt_mask, pred_masks[i][fg_mask[i]], proto[i], mxyxy, marea) # seg
# WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
else:
loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
# WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
else:
loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.box / batch_size # seg gain
loss[2] *= self.hyp.cls # cls gain
loss[3] *= self.hyp.dfl # dfl gain
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
"""Mask loss for one image."""
pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:]) # (n, 32) @ (32,80,80) -> (n,80,80)
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
# Criterion class for computing training losses
class v8PoseLoss(v8DetectionLoss):
def __init__(self, model): # model must be de-paralleled
super().__init__(model)
self.kpt_shape = model.model[-1].kpt_shape
self.bce_pose = nn.BCEWithLogitsLoss()
is_pose = self.kpt_shape == [17, 3]
nkpt = self.kpt_shape[0] # number of keypoints
sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
self.keypoint_loss = KeypointLoss(sigmas=sigmas)
def __call__(self, preds, batch):
"""Calculate the total loss and detach it."""
loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1)
# b, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
dtype = pred_scores.dtype
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# targets
batch_size = pred_scores.shape[0]
batch_idx = batch['batch_idx'].view(-1, 1)
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
# pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3)
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
target_scores_sum = max(target_scores.sum(), 1)
# cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
# bbox loss
if fg_mask.sum():
target_bboxes /= stride_tensor
loss[0], loss[4] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
target_scores_sum, fg_mask)
keypoints = batch['keypoints'].to(self.device).float().clone()
keypoints[..., 0] *= imgsz[1]
keypoints[..., 1] *= imgsz[0]
for i in range(batch_size):
if fg_mask[i].sum():
idx = target_gt_idx[i][fg_mask[i]]
gt_kpt = keypoints[batch_idx.view(-1) == i][idx] # (n, 51)
gt_kpt[..., 0] /= stride_tensor[fg_mask[i]]
gt_kpt[..., 1] /= stride_tensor[fg_mask[i]]
area = xyxy2xywh(target_bboxes[i][fg_mask[i]])[:, 2:].prod(1, keepdim=True)
pred_kpt = pred_kpts[i][fg_mask[i]]
kpt_mask = gt_kpt[..., 2] != 0
loss[1] += self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area) # pose loss
# kpt_score loss
if pred_kpt.shape[-1] == 3:
loss[2] += self.bce_pose(pred_kpt[..., 2], kpt_mask.float()) # keypoint obj loss
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.pose / batch_size # pose gain
loss[2] *= self.hyp.kobj / batch_size # kobj gain
loss[3] *= self.hyp.cls # cls gain
loss[4] *= self.hyp.dfl # dfl gain
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
def kpts_decode(self, anchor_points, pred_kpts):
"""Decodes predicted keypoints to image coordinates."""
y = pred_kpts.clone()
y[..., :2] *= 2.0
y[..., 0] += anchor_points[:, [0]] - 0.5
y[..., 1] += anchor_points[:, [1]] - 0.5
return y
class v8ClassificationLoss:
def __call__(self, preds, batch):
"""Compute the classification loss between predictions and true labels."""
loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / 64 # TODO: remove hardcoding
loss_items = loss.detach()
return loss, loss_items

@ -41,7 +41,6 @@ class ClassificationTrainer(BaseTrainer):
m.p = self.args.dropout # set dropout m.p = self.args.dropout # set dropout
for p in model.parameters(): for p in model.parameters():
p.requires_grad = True # for training p.requires_grad = True # for training
return model return model
def setup_model(self): def setup_model(self):
@ -103,12 +102,6 @@ class ClassificationTrainer(BaseTrainer):
self.loss_names = ['loss'] self.loss_names = ['loss']
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir) return v8.classify.ClassificationValidator(self.test_loader, self.save_dir)
def criterion(self, preds, batch):
"""Compute the classification loss between predictions and true labels."""
loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs
loss_items = loss.detach()
return loss, loss_items
def label_loss_items(self, loss_items=None, prefix='train'): def label_loss_items(self, loss_items=None, prefix='train'):
""" """
Returns a loss dict with labelled training loss items tensor Returns a loss dict with labelled training loss items tensor

@ -2,8 +2,6 @@
from copy import copy from copy import copy
import numpy as np import numpy as np
import torch
import torch.nn as nn
from ultralytics.nn.tasks import DetectionModel from ultralytics.nn.tasks import DetectionModel
from ultralytics.yolo import v8 from ultralytics.yolo import v8
@ -11,10 +9,7 @@ from ultralytics.yolo.data import build_dataloader, build_yolo_dataset
from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
from ultralytics.yolo.engine.trainer import BaseTrainer from ultralytics.yolo.engine.trainer import BaseTrainer
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
from ultralytics.yolo.utils.loss import BboxLoss
from ultralytics.yolo.utils.ops import xywh2xyxy
from ultralytics.yolo.utils.plotting import plot_images, plot_labels, plot_results from ultralytics.yolo.utils.plotting import plot_images, plot_labels, plot_results
from ultralytics.yolo.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors
from ultralytics.yolo.utils.torch_utils import de_parallel, torch_distributed_zero_first from ultralytics.yolo.utils.torch_utils import de_parallel, torch_distributed_zero_first
@ -91,12 +86,6 @@ class DetectionTrainer(BaseTrainer):
self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss' self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
return v8.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) return v8.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def criterion(self, preds, batch):
"""Compute loss for YOLO prediction and ground-truth."""
if not hasattr(self, 'compute_loss'):
self.compute_loss = Loss(de_parallel(self.model))
return self.compute_loss(preds, batch)
def label_loss_items(self, loss_items=None, prefix='train'): def label_loss_items(self, loss_items=None, prefix='train'):
""" """
Returns a loss dict with labelled training loss items tensor Returns a loss dict with labelled training loss items tensor
@ -135,102 +124,6 @@ class DetectionTrainer(BaseTrainer):
plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot) plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot)
# Criterion class for computing training losses
class Loss:
def __init__(self, model): # model must be de-paralleled
device = next(model.parameters()).device # get model device
h = model.args # hyperparameters
m = model.model[-1] # Detect() module
self.bce = nn.BCEWithLogitsLoss(reduction='none')
self.hyp = h
self.stride = m.stride # model strides
self.nc = m.nc # number of classes
self.no = m.no
self.reg_max = m.reg_max
self.device = device
self.use_dfl = m.reg_max > 1
self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
def preprocess(self, targets, batch_size, scale_tensor):
"""Preprocesses the target counts and matches with the input batch size to output a tensor."""
if targets.shape[0] == 0:
out = torch.zeros(batch_size, 0, 5, device=self.device)
else:
i = targets[:, 0] # image index
_, counts = i.unique(return_counts=True)
counts = counts.to(dtype=torch.int32)
out = torch.zeros(batch_size, counts.max(), 5, device=self.device)
for j in range(batch_size):
matches = i == j
n = matches.sum()
if n:
out[j, :n] = targets[matches, 1:]
out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
return out
def bbox_decode(self, anchor_points, pred_dist):
"""Decode predicted object bounding box coordinates from anchor points and distribution."""
if self.use_dfl:
b, a, c = pred_dist.shape # batch, anchors, channels
pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
# pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
# pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
return dist2bbox(pred_dist, anchor_points, xywh=False)
def __call__(self, preds, batch):
"""Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
loss = torch.zeros(3, device=self.device) # box, cls, dfl
feats = preds[1] if isinstance(preds, tuple) else preds
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1)
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
dtype = pred_scores.dtype
batch_size = pred_scores.shape[0]
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# targets
targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
# pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
target_scores_sum = max(target_scores.sum(), 1)
# cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
# bbox loss
if fg_mask.sum():
target_bboxes /= stride_tensor
loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
target_scores_sum, fg_mask)
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.cls # cls gain
loss[2] *= self.hyp.dfl # dfl gain
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
def train(cfg=DEFAULT_CFG, use_python=False): def train(cfg=DEFAULT_CFG, use_python=False):
"""Train and optimize YOLO model given training data and device.""" """Train and optimize YOLO model given training data and device."""
model = cfg.model or 'yolov8n.pt' model = cfg.model or 'yolov8n.pt'

@ -2,19 +2,10 @@
from copy import copy from copy import copy
import torch
import torch.nn as nn
from ultralytics.nn.tasks import PoseModel from ultralytics.nn.tasks import PoseModel
from ultralytics.yolo import v8 from ultralytics.yolo import v8
from ultralytics.yolo.utils import DEFAULT_CFG from ultralytics.yolo.utils import DEFAULT_CFG
from ultralytics.yolo.utils.loss import KeypointLoss
from ultralytics.yolo.utils.metrics import OKS_SIGMA
from ultralytics.yolo.utils.ops import xyxy2xywh
from ultralytics.yolo.utils.plotting import plot_images, plot_results from ultralytics.yolo.utils.plotting import plot_images, plot_results
from ultralytics.yolo.utils.tal import make_anchors
from ultralytics.yolo.utils.torch_utils import de_parallel
from ultralytics.yolo.v8.detect.train import Loss
# BaseTrainer python usage # BaseTrainer python usage
@ -45,12 +36,6 @@ class PoseTrainer(v8.detect.DetectionTrainer):
self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss' self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
return v8.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) return v8.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def criterion(self, preds, batch):
"""Computes pose loss for the YOLO model."""
if not hasattr(self, 'compute_loss'):
self.compute_loss = PoseLoss(de_parallel(self.model))
return self.compute_loss(preds, batch)
def plot_training_samples(self, batch, ni): def plot_training_samples(self, batch, ni):
"""Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.""" """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
images = batch['img'] images = batch['img']
@ -73,95 +58,6 @@ class PoseTrainer(v8.detect.DetectionTrainer):
plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png
# Criterion class for computing training losses
class PoseLoss(Loss):
def __init__(self, model): # model must be de-paralleled
super().__init__(model)
self.kpt_shape = model.model[-1].kpt_shape
self.bce_pose = nn.BCEWithLogitsLoss()
is_pose = self.kpt_shape == [17, 3]
nkpt = self.kpt_shape[0] # number of keypoints
sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
self.keypoint_loss = KeypointLoss(sigmas=sigmas)
def __call__(self, preds, batch):
"""Calculate the total loss and detach it."""
loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1)
# b, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
dtype = pred_scores.dtype
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# targets
batch_size = pred_scores.shape[0]
batch_idx = batch['batch_idx'].view(-1, 1)
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
# pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3)
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
target_scores_sum = max(target_scores.sum(), 1)
# cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
# bbox loss
if fg_mask.sum():
target_bboxes /= stride_tensor
loss[0], loss[4] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
target_scores_sum, fg_mask)
keypoints = batch['keypoints'].to(self.device).float().clone()
keypoints[..., 0] *= imgsz[1]
keypoints[..., 1] *= imgsz[0]
for i in range(batch_size):
if fg_mask[i].sum():
idx = target_gt_idx[i][fg_mask[i]]
gt_kpt = keypoints[batch_idx.view(-1) == i][idx] # (n, 51)
gt_kpt[..., 0] /= stride_tensor[fg_mask[i]]
gt_kpt[..., 1] /= stride_tensor[fg_mask[i]]
area = xyxy2xywh(target_bboxes[i][fg_mask[i]])[:, 2:].prod(1, keepdim=True)
pred_kpt = pred_kpts[i][fg_mask[i]]
kpt_mask = gt_kpt[..., 2] != 0
loss[1] += self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area) # pose loss
# kpt_score loss
if pred_kpt.shape[-1] == 3:
loss[2] += self.bce_pose(pred_kpt[..., 2], kpt_mask.float()) # keypoint obj loss
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.pose / batch_size # pose gain
loss[2] *= self.hyp.kobj / batch_size # kobj gain
loss[3] *= self.hyp.cls # cls gain
loss[4] *= self.hyp.dfl # dfl gain
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
def kpts_decode(self, anchor_points, pred_kpts):
"""Decodes predicted keypoints to image coordinates."""
y = pred_kpts.clone()
y[..., :2] *= 2.0
y[..., 0] += anchor_points[:, [0]] - 0.5
y[..., 1] += anchor_points[:, [1]] - 0.5
return y
def train(cfg=DEFAULT_CFG, use_python=False): def train(cfg=DEFAULT_CFG, use_python=False):
"""Train the YOLO model on the given data and device.""" """Train the YOLO model on the given data and device."""
model = cfg.model or 'yolov8n-pose.yaml' model = cfg.model or 'yolov8n-pose.yaml'

@ -1,17 +1,10 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
from copy import copy from copy import copy
import torch
import torch.nn.functional as F
from ultralytics.nn.tasks import SegmentationModel from ultralytics.nn.tasks import SegmentationModel
from ultralytics.yolo import v8 from ultralytics.yolo import v8
from ultralytics.yolo.utils import DEFAULT_CFG, RANK from ultralytics.yolo.utils import DEFAULT_CFG, RANK
from ultralytics.yolo.utils.ops import crop_mask, xyxy2xywh
from ultralytics.yolo.utils.plotting import plot_images, plot_results from ultralytics.yolo.utils.plotting import plot_images, plot_results
from ultralytics.yolo.utils.tal import make_anchors
from ultralytics.yolo.utils.torch_utils import de_parallel
from ultralytics.yolo.v8.detect.train import Loss
# BaseTrainer python usage # BaseTrainer python usage
@ -37,12 +30,6 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss' self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def criterion(self, preds, batch):
"""Returns the computed loss using the SegLoss class on the given predictions and batch."""
if not hasattr(self, 'compute_loss'):
self.compute_loss = SegLoss(de_parallel(self.model), overlap=self.args.overlap_mask)
return self.compute_loss(preds, batch)
def plot_training_samples(self, batch, ni): def plot_training_samples(self, batch, ni):
"""Creates a plot of training sample images with labels and box coordinates.""" """Creates a plot of training sample images with labels and box coordinates."""
plot_images(batch['img'], plot_images(batch['img'],
@ -59,101 +46,6 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png
# Criterion class for computing training losses
class SegLoss(Loss):
def __init__(self, model, overlap=True): # model must be de-paralleled
super().__init__(model)
self.nm = model.model[-1].nm # number of masks
self.overlap = overlap
def __call__(self, preds, batch):
"""Calculate and return the loss for the YOLO model."""
loss = torch.zeros(4, device=self.device) # box, cls, dfl
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1)
# b, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
pred_masks = pred_masks.permute(0, 2, 1).contiguous()
dtype = pred_scores.dtype
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# targets
try:
batch_idx = batch['batch_idx'].view(-1, 1)
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
except RuntimeError as e:
raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n'
"This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
"i.e. 'yolo train model=yolov8n-seg.pt data=coco128.yaml'.\nVerify your dataset is a "
"correctly formatted 'segment' dataset using 'data=coco128-seg.yaml' "
'as an example.\nSee https://docs.ultralytics.com/tasks/segment/ for help.') from e
# pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
target_scores_sum = max(target_scores.sum(), 1)
# cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
if fg_mask.sum():
# bbox loss
loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
target_scores, target_scores_sum, fg_mask)
# masks loss
masks = batch['masks'].to(self.device).float()
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
for i in range(batch_size):
if fg_mask[i].sum():
mask_idx = target_gt_idx[i][fg_mask[i]]
if self.overlap:
gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0)
else:
gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]]
marea = xyxy2xywh(xyxyn)[:, 2:].prod(1)
mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)
loss[1] += self.single_mask_loss(gt_mask, pred_masks[i][fg_mask[i]], proto[i], mxyxy, marea) # seg
# WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
else:
loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
# WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
else:
loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.box / batch_size # seg gain
loss[2] *= self.hyp.cls # cls gain
loss[3] *= self.hyp.dfl # dfl gain
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
"""Mask loss for one image."""
pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:]) # (n, 32) @ (32,80,80) -> (n,80,80)
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
def train(cfg=DEFAULT_CFG, use_python=False): def train(cfg=DEFAULT_CFG, use_python=False):
"""Train a YOLO segmentation model based on passed arguments.""" """Train a YOLO segmentation model based on passed arguments."""
model = cfg.model or 'yolov8n-seg.pt' model = cfg.model or 'yolov8n-seg.pt'

Loading…
Cancel
Save