From d19c5b6ce8bceeb9aa5ac9c24de1e882a691ceb3 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Thu, 25 May 2023 16:07:54 +0530 Subject: [PATCH] Move loss to task heads (#2825) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com> --- docs/usage/engine.md | 19 +- ultralytics/nn/tasks.py | 31 +++ ultralytics/yolo/engine/trainer.py | 9 +- ultralytics/yolo/engine/validator.py | 3 +- ultralytics/yolo/utils/loss.py | 293 ++++++++++++++++++++++++++ ultralytics/yolo/v8/classify/train.py | 7 - ultralytics/yolo/v8/detect/train.py | 107 ---------- ultralytics/yolo/v8/pose/train.py | 104 --------- ultralytics/yolo/v8/segment/train.py | 108 ---------- 9 files changed, 335 insertions(+), 346 deletions(-) diff --git a/docs/usage/engine.md b/docs/usage/engine.md index 2a77f7e..24bc22f 100644 --- a/docs/usage/engine.md +++ b/docs/usage/engine.md @@ -48,25 +48,22 @@ trainer.train() You now realize that you need to customize the trainer further to: -* Customize the `loss function`. +* * Customize the `loss function`. * Add `callback` that uploads model to your Google Drive after every 10 `epochs` Here's how you can do it: ```python from ultralytics.yolo.v8.detect import DetectionTrainer +from ultralytcs.nn.tasks import DetectionModel + +class MyCustomModel(DetectionModel): + def init_criterion(): + ... class CustomTrainer(DetectionTrainer): def get_model(self, cfg, weights): - ... - - def criterion(self, preds, batch): - # get ground truth - imgs = batch["imgs"] - bboxes = batch["bboxes"] - ... - return loss, loss_items # see Reference-> Trainer for details on the expected format - + return MyCustomModel(...) # callback to upload model weights def log_model(trainer): @@ -84,4 +81,4 @@ To know more about Callback triggering events and entry point, checkout our [Cal ## Other engine components There are other components that can be customized similarly like `Validators` and `Predictors` -See Reference section for more information on these. \ No newline at end of file +See Reference section for more information on these. diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 8564560..2a6ae97 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -13,6 +13,7 @@ from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottlenec Segment) from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml +from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss from ultralytics.yolo.utils.plotting import feature_visualization from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts, make_divisible, model_info, scale_img, time_sync) @@ -173,6 +174,23 @@ class BaseModel(nn.Module): if verbose: LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights') + def loss(self, batch, preds=None): + """ + Compute loss + + Args: + batch (dict): Batch to compute loss on + pred (torch.Tensor | List[torch.Tensor]): Predictions. + """ + if not hasattr(self, 'criterion'): + self.criterion = self.init_criterion() + + preds = self.forward(batch['img']) if preds is None else preds + return self.criterion(preds, batch) + + def init_criterion(self): + raise NotImplementedError('compute_loss() needs to be implemented by task heads') + class DetectionModel(BaseModel): """YOLOv8 detection model.""" @@ -249,6 +267,9 @@ class DetectionModel(BaseModel): y[-1] = y[-1][..., i:] # small return y + def init_criterion(self): + return v8DetectionLoss(self) + class SegmentationModel(DetectionModel): """YOLOv8 segmentation model.""" @@ -261,6 +282,9 @@ class SegmentationModel(DetectionModel): """Undocumented function.""" raise NotImplementedError(emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!')) + def init_criterion(self): + return v8SegmentationLoss(self) + class PoseModel(DetectionModel): """YOLOv8 pose model.""" @@ -274,6 +298,9 @@ class PoseModel(DetectionModel): cfg['kpt_shape'] = data_kpt_shape super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) + def init_criterion(self): + return v8PoseLoss(self) + class ClassificationModel(BaseModel): """YOLOv8 classification model.""" @@ -341,6 +368,10 @@ class ClassificationModel(BaseModel): if m[i].out_channels != nc: m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None) + def init_criterion(self): + """Compute the classification loss between predictions and true labels.""" + return v8ClassificationLoss() + class Ensemble(nn.ModuleList): """Ensemble of models.""" diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index 80e8ca0..6926201 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -325,8 +325,7 @@ class BaseTrainer: # Forward with torch.cuda.amp.autocast(self.amp): batch = self.preprocess_batch(batch) - preds = self.model(batch['img']) - self.loss, self.loss_items = self.criterion(preds, batch) + self.loss, self.loss_items = de_parallel(self.model).loss(batch) if RANK != -1: self.loss *= world_size self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \ @@ -496,12 +495,6 @@ class BaseTrainer: """Build dataset""" raise NotImplementedError('build_dataset function not implemented in trainer') - def criterion(self, preds, batch): - """ - Returns loss and individual loss items as Tensor. - """ - raise NotImplementedError('criterion function not implemented in trainer') - def label_loss_items(self, loss_items=None, prefix='train'): """ Returns a loss dict with labelled training loss items tensor diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index ad107a2..09f297b 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -162,7 +162,8 @@ class BaseValidator: # Loss with dt[2]: if self.training: - self.loss += trainer.criterion(preds, batch)[1] + loss_items = model.loss(batch, preds) + self.loss += loss_items[1] # Postprocess with dt[3]: diff --git a/ultralytics/yolo/utils/loss.py b/ultralytics/yolo/utils/loss.py index 73aba68..5266982 100644 --- a/ultralytics/yolo/utils/loss.py +++ b/ultralytics/yolo/utils/loss.py @@ -4,6 +4,10 @@ import torch import torch.nn as nn import torch.nn.functional as F +from ultralytics.yolo.utils.metrics import OKS_SIGMA +from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh +from ultralytics.yolo.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors + from .metrics import bbox_iou from .tal import bbox2dist @@ -73,3 +77,292 @@ class KeypointLoss(nn.Module): # e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula e = d / (2 * self.sigmas) ** 2 / (area + 1e-9) / 2 # from cocoeval return kpt_loss_factor * ((1 - torch.exp(-e)) * kpt_mask).mean() + + +# Criterion class for computing Detection training losses +class v8DetectionLoss: + + def __init__(self, model): # model must be de-paralleled + + device = next(model.parameters()).device # get model device + h = model.args # hyperparameters + + m = model.model[-1] # Detect() module + self.bce = nn.BCEWithLogitsLoss(reduction='none') + self.hyp = h + self.stride = m.stride # model strides + self.nc = m.nc # number of classes + self.no = m.no + self.reg_max = m.reg_max + self.device = device + + self.use_dfl = m.reg_max > 1 + + self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0) + self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device) + self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device) + + def preprocess(self, targets, batch_size, scale_tensor): + """Preprocesses the target counts and matches with the input batch size to output a tensor.""" + if targets.shape[0] == 0: + out = torch.zeros(batch_size, 0, 5, device=self.device) + else: + i = targets[:, 0] # image index + _, counts = i.unique(return_counts=True) + counts = counts.to(dtype=torch.int32) + out = torch.zeros(batch_size, counts.max(), 5, device=self.device) + for j in range(batch_size): + matches = i == j + n = matches.sum() + if n: + out[j, :n] = targets[matches, 1:] + out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor)) + return out + + def bbox_decode(self, anchor_points, pred_dist): + """Decode predicted object bounding box coordinates from anchor points and distribution.""" + if self.use_dfl: + b, a, c = pred_dist.shape # batch, anchors, channels + pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype)) + # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype)) + # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2) + return dist2bbox(pred_dist, anchor_points, xywh=False) + + def __call__(self, preds, batch): + """Calculate the sum of the loss for box, cls and dfl multiplied by batch size.""" + loss = torch.zeros(3, device=self.device) # box, cls, dfl + feats = preds[1] if isinstance(preds, tuple) else preds + pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( + (self.reg_max * 4, self.nc), 1) + + pred_scores = pred_scores.permute(0, 2, 1).contiguous() + pred_distri = pred_distri.permute(0, 2, 1).contiguous() + + dtype = pred_scores.dtype + batch_size = pred_scores.shape[0] + imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w) + anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) + + # targets + targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1) + targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) + gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy + mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) + + # pboxes + pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) + + _, target_bboxes, target_scores, fg_mask, _ = self.assigner( + pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), + anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt) + + target_scores_sum = max(target_scores.sum(), 1) + + # cls loss + # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way + loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE + + # bbox loss + if fg_mask.sum(): + target_bboxes /= stride_tensor + loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, + target_scores_sum, fg_mask) + + loss[0] *= self.hyp.box # box gain + loss[1] *= self.hyp.cls # cls gain + loss[2] *= self.hyp.dfl # dfl gain + + return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) + + +# Criterion class for computing training losses +class v8SegmentationLoss(v8DetectionLoss): + + def __init__(self, model, overlap=True): # model must be de-paralleled + super().__init__(model) + self.nm = model.model[-1].nm # number of masks + self.overlap = overlap + + def __call__(self, preds, batch): + """Calculate and return the loss for the YOLO model.""" + loss = torch.zeros(4, device=self.device) # box, cls, dfl + feats, pred_masks, proto = preds if len(preds) == 3 else preds[1] + batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width + pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( + (self.reg_max * 4, self.nc), 1) + + # b, grids, .. + pred_scores = pred_scores.permute(0, 2, 1).contiguous() + pred_distri = pred_distri.permute(0, 2, 1).contiguous() + pred_masks = pred_masks.permute(0, 2, 1).contiguous() + + dtype = pred_scores.dtype + imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w) + anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) + + # targets + try: + batch_idx = batch['batch_idx'].view(-1, 1) + targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1) + targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) + gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy + mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) + except RuntimeError as e: + raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n' + "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, " + "i.e. 'yolo train model=yolov8n-seg.pt data=coco128.yaml'.\nVerify your dataset is a " + "correctly formatted 'segment' dataset using 'data=coco128-seg.yaml' " + 'as an example.\nSee https://docs.ultralytics.com/tasks/segment/ for help.') from e + + # pboxes + pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) + + _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner( + pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), + anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt) + + target_scores_sum = max(target_scores.sum(), 1) + + # cls loss + # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way + loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE + + if fg_mask.sum(): + # bbox loss + loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor, + target_scores, target_scores_sum, fg_mask) + # masks loss + masks = batch['masks'].to(self.device).float() + if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample + masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0] + + for i in range(batch_size): + if fg_mask[i].sum(): + mask_idx = target_gt_idx[i][fg_mask[i]] + if self.overlap: + gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0) + else: + gt_mask = masks[batch_idx.view(-1) == i][mask_idx] + xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]] + marea = xyxy2xywh(xyxyn)[:, 2:].prod(1) + mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device) + loss[1] += self.single_mask_loss(gt_mask, pred_masks[i][fg_mask[i]], proto[i], mxyxy, marea) # seg + + # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove + else: + loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss + + # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove + else: + loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss + + loss[0] *= self.hyp.box # box gain + loss[1] *= self.hyp.box / batch_size # seg gain + loss[2] *= self.hyp.cls # cls gain + loss[3] *= self.hyp.dfl # dfl gain + + return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) + + def single_mask_loss(self, gt_mask, pred, proto, xyxy, area): + """Mask loss for one image.""" + pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:]) # (n, 32) @ (32,80,80) -> (n,80,80) + loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none') + return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean() + + +# Criterion class for computing training losses +class v8PoseLoss(v8DetectionLoss): + + def __init__(self, model): # model must be de-paralleled + super().__init__(model) + self.kpt_shape = model.model[-1].kpt_shape + self.bce_pose = nn.BCEWithLogitsLoss() + is_pose = self.kpt_shape == [17, 3] + nkpt = self.kpt_shape[0] # number of keypoints + sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt + self.keypoint_loss = KeypointLoss(sigmas=sigmas) + + def __call__(self, preds, batch): + """Calculate the total loss and detach it.""" + loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility + feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1] + pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( + (self.reg_max * 4, self.nc), 1) + + # b, grids, .. + pred_scores = pred_scores.permute(0, 2, 1).contiguous() + pred_distri = pred_distri.permute(0, 2, 1).contiguous() + pred_kpts = pred_kpts.permute(0, 2, 1).contiguous() + + dtype = pred_scores.dtype + imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w) + anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) + + # targets + batch_size = pred_scores.shape[0] + batch_idx = batch['batch_idx'].view(-1, 1) + targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1) + targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) + gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy + mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) + + # pboxes + pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) + pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3) + + _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner( + pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), + anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt) + + target_scores_sum = max(target_scores.sum(), 1) + + # cls loss + # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way + loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE + + # bbox loss + if fg_mask.sum(): + target_bboxes /= stride_tensor + loss[0], loss[4] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, + target_scores_sum, fg_mask) + keypoints = batch['keypoints'].to(self.device).float().clone() + keypoints[..., 0] *= imgsz[1] + keypoints[..., 1] *= imgsz[0] + for i in range(batch_size): + if fg_mask[i].sum(): + idx = target_gt_idx[i][fg_mask[i]] + gt_kpt = keypoints[batch_idx.view(-1) == i][idx] # (n, 51) + gt_kpt[..., 0] /= stride_tensor[fg_mask[i]] + gt_kpt[..., 1] /= stride_tensor[fg_mask[i]] + area = xyxy2xywh(target_bboxes[i][fg_mask[i]])[:, 2:].prod(1, keepdim=True) + pred_kpt = pred_kpts[i][fg_mask[i]] + kpt_mask = gt_kpt[..., 2] != 0 + loss[1] += self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area) # pose loss + # kpt_score loss + if pred_kpt.shape[-1] == 3: + loss[2] += self.bce_pose(pred_kpt[..., 2], kpt_mask.float()) # keypoint obj loss + + loss[0] *= self.hyp.box # box gain + loss[1] *= self.hyp.pose / batch_size # pose gain + loss[2] *= self.hyp.kobj / batch_size # kobj gain + loss[3] *= self.hyp.cls # cls gain + loss[4] *= self.hyp.dfl # dfl gain + + return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) + + def kpts_decode(self, anchor_points, pred_kpts): + """Decodes predicted keypoints to image coordinates.""" + y = pred_kpts.clone() + y[..., :2] *= 2.0 + y[..., 0] += anchor_points[:, [0]] - 0.5 + y[..., 1] += anchor_points[:, [1]] - 0.5 + return y + + +class v8ClassificationLoss: + + def __call__(self, preds, batch): + """Compute the classification loss between predictions and true labels.""" + loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / 64 # TODO: remove hardcoding + loss_items = loss.detach() + return loss, loss_items diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py index 6c8b657..2949644 100644 --- a/ultralytics/yolo/v8/classify/train.py +++ b/ultralytics/yolo/v8/classify/train.py @@ -41,7 +41,6 @@ class ClassificationTrainer(BaseTrainer): m.p = self.args.dropout # set dropout for p in model.parameters(): p.requires_grad = True # for training - return model def setup_model(self): @@ -103,12 +102,6 @@ class ClassificationTrainer(BaseTrainer): self.loss_names = ['loss'] return v8.classify.ClassificationValidator(self.test_loader, self.save_dir) - def criterion(self, preds, batch): - """Compute the classification loss between predictions and true labels.""" - loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs - loss_items = loss.detach() - return loss, loss_items - def label_loss_items(self, loss_items=None, prefix='train'): """ Returns a loss dict with labelled training loss items tensor diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py index 7f0fc1d..1b475ed 100644 --- a/ultralytics/yolo/v8/detect/train.py +++ b/ultralytics/yolo/v8/detect/train.py @@ -2,8 +2,6 @@ from copy import copy import numpy as np -import torch -import torch.nn as nn from ultralytics.nn.tasks import DetectionModel from ultralytics.yolo import v8 @@ -11,10 +9,7 @@ from ultralytics.yolo.data import build_dataloader, build_yolo_dataset from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader from ultralytics.yolo.engine.trainer import BaseTrainer from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr -from ultralytics.yolo.utils.loss import BboxLoss -from ultralytics.yolo.utils.ops import xywh2xyxy from ultralytics.yolo.utils.plotting import plot_images, plot_labels, plot_results -from ultralytics.yolo.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors from ultralytics.yolo.utils.torch_utils import de_parallel, torch_distributed_zero_first @@ -91,12 +86,6 @@ class DetectionTrainer(BaseTrainer): self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss' return v8.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) - def criterion(self, preds, batch): - """Compute loss for YOLO prediction and ground-truth.""" - if not hasattr(self, 'compute_loss'): - self.compute_loss = Loss(de_parallel(self.model)) - return self.compute_loss(preds, batch) - def label_loss_items(self, loss_items=None, prefix='train'): """ Returns a loss dict with labelled training loss items tensor @@ -135,102 +124,6 @@ class DetectionTrainer(BaseTrainer): plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot) -# Criterion class for computing training losses -class Loss: - - def __init__(self, model): # model must be de-paralleled - - device = next(model.parameters()).device # get model device - h = model.args # hyperparameters - - m = model.model[-1] # Detect() module - self.bce = nn.BCEWithLogitsLoss(reduction='none') - self.hyp = h - self.stride = m.stride # model strides - self.nc = m.nc # number of classes - self.no = m.no - self.reg_max = m.reg_max - self.device = device - - self.use_dfl = m.reg_max > 1 - - self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0) - self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device) - self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device) - - def preprocess(self, targets, batch_size, scale_tensor): - """Preprocesses the target counts and matches with the input batch size to output a tensor.""" - if targets.shape[0] == 0: - out = torch.zeros(batch_size, 0, 5, device=self.device) - else: - i = targets[:, 0] # image index - _, counts = i.unique(return_counts=True) - counts = counts.to(dtype=torch.int32) - out = torch.zeros(batch_size, counts.max(), 5, device=self.device) - for j in range(batch_size): - matches = i == j - n = matches.sum() - if n: - out[j, :n] = targets[matches, 1:] - out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor)) - return out - - def bbox_decode(self, anchor_points, pred_dist): - """Decode predicted object bounding box coordinates from anchor points and distribution.""" - if self.use_dfl: - b, a, c = pred_dist.shape # batch, anchors, channels - pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype)) - # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype)) - # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2) - return dist2bbox(pred_dist, anchor_points, xywh=False) - - def __call__(self, preds, batch): - """Calculate the sum of the loss for box, cls and dfl multiplied by batch size.""" - loss = torch.zeros(3, device=self.device) # box, cls, dfl - feats = preds[1] if isinstance(preds, tuple) else preds - pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( - (self.reg_max * 4, self.nc), 1) - - pred_scores = pred_scores.permute(0, 2, 1).contiguous() - pred_distri = pred_distri.permute(0, 2, 1).contiguous() - - dtype = pred_scores.dtype - batch_size = pred_scores.shape[0] - imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w) - anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) - - # targets - targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1) - targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) - gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy - mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) - - # pboxes - pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) - - _, target_bboxes, target_scores, fg_mask, _ = self.assigner( - pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), - anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt) - - target_scores_sum = max(target_scores.sum(), 1) - - # cls loss - # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way - loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE - - # bbox loss - if fg_mask.sum(): - target_bboxes /= stride_tensor - loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, - target_scores_sum, fg_mask) - - loss[0] *= self.hyp.box # box gain - loss[1] *= self.hyp.cls # cls gain - loss[2] *= self.hyp.dfl # dfl gain - - return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) - - def train(cfg=DEFAULT_CFG, use_python=False): """Train and optimize YOLO model given training data and device.""" model = cfg.model or 'yolov8n.pt' diff --git a/ultralytics/yolo/v8/pose/train.py b/ultralytics/yolo/v8/pose/train.py index b7890f8..af3043c 100644 --- a/ultralytics/yolo/v8/pose/train.py +++ b/ultralytics/yolo/v8/pose/train.py @@ -2,19 +2,10 @@ from copy import copy -import torch -import torch.nn as nn - from ultralytics.nn.tasks import PoseModel from ultralytics.yolo import v8 from ultralytics.yolo.utils import DEFAULT_CFG -from ultralytics.yolo.utils.loss import KeypointLoss -from ultralytics.yolo.utils.metrics import OKS_SIGMA -from ultralytics.yolo.utils.ops import xyxy2xywh from ultralytics.yolo.utils.plotting import plot_images, plot_results -from ultralytics.yolo.utils.tal import make_anchors -from ultralytics.yolo.utils.torch_utils import de_parallel -from ultralytics.yolo.v8.detect.train import Loss # BaseTrainer python usage @@ -45,12 +36,6 @@ class PoseTrainer(v8.detect.DetectionTrainer): self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss' return v8.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) - def criterion(self, preds, batch): - """Computes pose loss for the YOLO model.""" - if not hasattr(self, 'compute_loss'): - self.compute_loss = PoseLoss(de_parallel(self.model)) - return self.compute_loss(preds, batch) - def plot_training_samples(self, batch, ni): """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.""" images = batch['img'] @@ -73,95 +58,6 @@ class PoseTrainer(v8.detect.DetectionTrainer): plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png -# Criterion class for computing training losses -class PoseLoss(Loss): - - def __init__(self, model): # model must be de-paralleled - super().__init__(model) - self.kpt_shape = model.model[-1].kpt_shape - self.bce_pose = nn.BCEWithLogitsLoss() - is_pose = self.kpt_shape == [17, 3] - nkpt = self.kpt_shape[0] # number of keypoints - sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt - self.keypoint_loss = KeypointLoss(sigmas=sigmas) - - def __call__(self, preds, batch): - """Calculate the total loss and detach it.""" - loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility - feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1] - pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( - (self.reg_max * 4, self.nc), 1) - - # b, grids, .. - pred_scores = pred_scores.permute(0, 2, 1).contiguous() - pred_distri = pred_distri.permute(0, 2, 1).contiguous() - pred_kpts = pred_kpts.permute(0, 2, 1).contiguous() - - dtype = pred_scores.dtype - imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w) - anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) - - # targets - batch_size = pred_scores.shape[0] - batch_idx = batch['batch_idx'].view(-1, 1) - targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1) - targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) - gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy - mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) - - # pboxes - pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) - pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3) - - _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner( - pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), - anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt) - - target_scores_sum = max(target_scores.sum(), 1) - - # cls loss - # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way - loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE - - # bbox loss - if fg_mask.sum(): - target_bboxes /= stride_tensor - loss[0], loss[4] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, - target_scores_sum, fg_mask) - keypoints = batch['keypoints'].to(self.device).float().clone() - keypoints[..., 0] *= imgsz[1] - keypoints[..., 1] *= imgsz[0] - for i in range(batch_size): - if fg_mask[i].sum(): - idx = target_gt_idx[i][fg_mask[i]] - gt_kpt = keypoints[batch_idx.view(-1) == i][idx] # (n, 51) - gt_kpt[..., 0] /= stride_tensor[fg_mask[i]] - gt_kpt[..., 1] /= stride_tensor[fg_mask[i]] - area = xyxy2xywh(target_bboxes[i][fg_mask[i]])[:, 2:].prod(1, keepdim=True) - pred_kpt = pred_kpts[i][fg_mask[i]] - kpt_mask = gt_kpt[..., 2] != 0 - loss[1] += self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area) # pose loss - # kpt_score loss - if pred_kpt.shape[-1] == 3: - loss[2] += self.bce_pose(pred_kpt[..., 2], kpt_mask.float()) # keypoint obj loss - - loss[0] *= self.hyp.box # box gain - loss[1] *= self.hyp.pose / batch_size # pose gain - loss[2] *= self.hyp.kobj / batch_size # kobj gain - loss[3] *= self.hyp.cls # cls gain - loss[4] *= self.hyp.dfl # dfl gain - - return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) - - def kpts_decode(self, anchor_points, pred_kpts): - """Decodes predicted keypoints to image coordinates.""" - y = pred_kpts.clone() - y[..., :2] *= 2.0 - y[..., 0] += anchor_points[:, [0]] - 0.5 - y[..., 1] += anchor_points[:, [1]] - 0.5 - return y - - def train(cfg=DEFAULT_CFG, use_python=False): """Train the YOLO model on the given data and device.""" model = cfg.model or 'yolov8n-pose.yaml' diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py index 32ce510..ab66cf0 100644 --- a/ultralytics/yolo/v8/segment/train.py +++ b/ultralytics/yolo/v8/segment/train.py @@ -1,17 +1,10 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license from copy import copy -import torch -import torch.nn.functional as F - from ultralytics.nn.tasks import SegmentationModel from ultralytics.yolo import v8 from ultralytics.yolo.utils import DEFAULT_CFG, RANK -from ultralytics.yolo.utils.ops import crop_mask, xyxy2xywh from ultralytics.yolo.utils.plotting import plot_images, plot_results -from ultralytics.yolo.utils.tal import make_anchors -from ultralytics.yolo.utils.torch_utils import de_parallel -from ultralytics.yolo.v8.detect.train import Loss # BaseTrainer python usage @@ -37,12 +30,6 @@ class SegmentationTrainer(v8.detect.DetectionTrainer): self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss' return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) - def criterion(self, preds, batch): - """Returns the computed loss using the SegLoss class on the given predictions and batch.""" - if not hasattr(self, 'compute_loss'): - self.compute_loss = SegLoss(de_parallel(self.model), overlap=self.args.overlap_mask) - return self.compute_loss(preds, batch) - def plot_training_samples(self, batch, ni): """Creates a plot of training sample images with labels and box coordinates.""" plot_images(batch['img'], @@ -59,101 +46,6 @@ class SegmentationTrainer(v8.detect.DetectionTrainer): plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png -# Criterion class for computing training losses -class SegLoss(Loss): - - def __init__(self, model, overlap=True): # model must be de-paralleled - super().__init__(model) - self.nm = model.model[-1].nm # number of masks - self.overlap = overlap - - def __call__(self, preds, batch): - """Calculate and return the loss for the YOLO model.""" - loss = torch.zeros(4, device=self.device) # box, cls, dfl - feats, pred_masks, proto = preds if len(preds) == 3 else preds[1] - batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width - pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( - (self.reg_max * 4, self.nc), 1) - - # b, grids, .. - pred_scores = pred_scores.permute(0, 2, 1).contiguous() - pred_distri = pred_distri.permute(0, 2, 1).contiguous() - pred_masks = pred_masks.permute(0, 2, 1).contiguous() - - dtype = pred_scores.dtype - imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w) - anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) - - # targets - try: - batch_idx = batch['batch_idx'].view(-1, 1) - targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1) - targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) - gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy - mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) - except RuntimeError as e: - raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n' - "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, " - "i.e. 'yolo train model=yolov8n-seg.pt data=coco128.yaml'.\nVerify your dataset is a " - "correctly formatted 'segment' dataset using 'data=coco128-seg.yaml' " - 'as an example.\nSee https://docs.ultralytics.com/tasks/segment/ for help.') from e - - # pboxes - pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) - - _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner( - pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), - anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt) - - target_scores_sum = max(target_scores.sum(), 1) - - # cls loss - # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way - loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE - - if fg_mask.sum(): - # bbox loss - loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor, - target_scores, target_scores_sum, fg_mask) - # masks loss - masks = batch['masks'].to(self.device).float() - if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample - masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0] - - for i in range(batch_size): - if fg_mask[i].sum(): - mask_idx = target_gt_idx[i][fg_mask[i]] - if self.overlap: - gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0) - else: - gt_mask = masks[batch_idx.view(-1) == i][mask_idx] - xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]] - marea = xyxy2xywh(xyxyn)[:, 2:].prod(1) - mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device) - loss[1] += self.single_mask_loss(gt_mask, pred_masks[i][fg_mask[i]], proto[i], mxyxy, marea) # seg - - # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove - else: - loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss - - # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove - else: - loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss - - loss[0] *= self.hyp.box # box gain - loss[1] *= self.hyp.box / batch_size # seg gain - loss[2] *= self.hyp.cls # cls gain - loss[3] *= self.hyp.dfl # dfl gain - - return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) - - def single_mask_loss(self, gt_mask, pred, proto, xyxy, area): - """Mask loss for one image.""" - pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:]) # (n, 32) @ (32,80,80) -> (n,80,80) - loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none') - return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean() - - def train(cfg=DEFAULT_CFG, use_python=False): """Train a YOLO segmentation model based on passed arguments.""" model = cfg.model or 'yolov8n-seg.pt'