From f56c9bcc26f49abf1010bd84146f5add6c2e6e4b Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Tue, 8 Nov 2022 20:57:57 +0530 Subject: [PATCH] Segmentation support & other enchancements (#40) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher --- .github/workflows/ci.yaml | 16 +- ultralytics/yolo/data/dataset.py | 3 +- ultralytics/yolo/engine/trainer.py | 57 ++- ultralytics/yolo/engine/validator.py | 40 +- ultralytics/yolo/utils/configs/default.yaml | 19 + ultralytics/yolo/utils/metrics.py | 489 ++++++++++++++++++++ ultralytics/yolo/utils/ops.py | 78 +++- ultralytics/yolo/utils/torch_utils.py | 10 + ultralytics/yolo/v8/__init__.py | 4 +- ultralytics/yolo/v8/classify/train.py | 15 +- ultralytics/yolo/v8/classify/val.py | 10 +- ultralytics/yolo/v8/models/yolov5n-seg.yaml | 48 ++ ultralytics/yolo/v8/models/yolov5n.yaml | 48 ++ ultralytics/yolo/v8/segment/__init__.py | 2 + ultralytics/yolo/v8/segment/train.py | 269 +++++++++++ ultralytics/yolo/v8/segment/val.py | 211 +++++++++ ultralytics/yolov5n-seg.yaml | 48 ++ 17 files changed, 1320 insertions(+), 47 deletions(-) create mode 100644 ultralytics/yolo/v8/models/yolov5n-seg.yaml create mode 100644 ultralytics/yolo/v8/models/yolov5n.yaml create mode 100644 ultralytics/yolo/v8/segment/__init__.py create mode 100644 ultralytics/yolo/v8/segment/train.py create mode 100644 ultralytics/yolo/v8/segment/val.py create mode 100644 ultralytics/yolov5n-seg.yaml diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index dcc8165..c9d041f 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -21,7 +21,8 @@ jobs: os: [ ubuntu-latest ] python-version: [ '3.10' ] model: [ yolov5n ] - include: + torch: [ latest ] +# include: # - os: ubuntu-latest # python-version: '3.7' # '3.6.8' min # model: yolov5n @@ -31,10 +32,10 @@ jobs: # - os: ubuntu-latest # python-version: '3.9' # model: yolov5n - - os: ubuntu-latest - python-version: '3.8' # torch 1.7.0 requires python >=3.6, <=3.8 - model: yolov5n - torch: '1.7.0' # min torch version CI https://pypi.org/project/torchvision/ +# - os: ubuntu-latest +# python-version: '3.8' # torch 1.7.0 requires python >=3.6, <=3.8 +# model: yolov5n +# torch: '1.7.0' # min torch version CI https://pypi.org/project/torchvision/ steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 @@ -93,9 +94,8 @@ jobs: - name: Test segmentation shell: bash # for Windows compatibility run: | - echo "TODO" + python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-segments epochs=1 img_size=64 - name: Test classification shell: bash # for Windows compatibility run: | - echo "TODO" - # python ultralytics/yolo/v8/classify/train.py model=resnet18 data=mnist2560 epochs=1 img_size=64 + python ultralytics/yolo/v8/classify/train.py model=resnet18 data=mnist160 epochs=1 img_size=32 diff --git a/ultralytics/yolo/data/dataset.py b/ultralytics/yolo/data/dataset.py index e332aed..c8cf73e 100644 --- a/ultralytics/yolo/data/dataset.py +++ b/ultralytics/yolo/data/dataset.py @@ -1,6 +1,7 @@ from itertools import repeat from multiprocessing.pool import Pool from pathlib import Path +from typing import OrderedDict import torchvision from tqdm import tqdm @@ -205,7 +206,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"] else: sample = self.torch_transforms(im) - return sample, j + return OrderedDict(img=sample, cls=j) # TODO: support semantic segmentation diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index 3693fea..edbfa1b 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -1,12 +1,17 @@ """ Simple training loop; Boilerplate that could apply to any arbitrary neural network, """ +# TODOs +# 1. finish _set_model_attributes +# 2. allow num_class update for both pretrained and csv_loaded models +# 3. save import os import time from collections import defaultdict from datetime import datetime from pathlib import Path +from telnetlib import TLS from typing import Dict, Union import torch @@ -52,6 +57,8 @@ class BaseTrainer: # Model and Dataloaders. self.trainset, self.testset = self.get_dataset(self.args.data) + if self.args.cfg is not None: + self.model = self.load_cfg(self.args.cfg) if self.args.model is not None: self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device) @@ -133,6 +140,20 @@ class BaseTrainer: self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=rank) self.validator = self.get_validator() print("created testloader :", rank) + self.console.info(self.progress_string()) + + def _set_model_attributes(self): + # TODO: fix and use after self.data_dict is available + ''' + head = utils.torch_utils.de_parallel(self.model).model[-1] + self.args.box *= 3 / head.nl # scale to layers + self.args.cls *= head.nc / 80 * 3 / head.nl # scale to classes and layers + self.args.obj *= (self.args.img_size / 640) ** 2 * 3 / nl # scale to image size and layers + model.nc = nc # attach number of classes to model + model.hyp = hyp # attach hyperparameters to model + model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights + model.names = names + ''' def _do_train(self, rank, world_size): if world_size > 1: @@ -153,13 +174,17 @@ class BaseTrainer: pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') - tloss = 0 - for i, (images, labels) in pbar: + tloss = None + for i, batch in pbar: + # img, label (classification)/ img, targets, paths, _, masks(detection) # callback hook. on_batch_start # forward - images, labels = self.preprocess_batch(images, labels) - self.loss = self.criterion(self.model(images), labels) - tloss = (tloss * i + self.loss.item()) / (i + 1) + batch = self.preprocess_batch(batch) + + # TODO: warmup, multiscale + preds = self.model(batch["img"]) + self.loss, self.loss_items = self.criterion(preds, batch) + tloss = (tloss * i + self.loss_items) / (i + 1) if tloss is not None else self.loss_items # backward self.model.zero_grad(set_to_none=True) @@ -170,9 +195,13 @@ class BaseTrainer: self.trigger_callbacks('on_batch_end') # log - mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB) + mem = (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB) + loss_len = tloss.shape[0] if len(tloss.size()) else 1 + losses = tloss if loss_len > 1 else torch.unsqueeze(tloss, 0) if rank in {-1, 0}: - pbar.desc = f"{f'{epoch + 1}/{self.args.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36 + pbar.set_description( + (" {} " + "{:.3f} " * (2 + loss_len)).format(f'{epoch + 1}/{self.args.epochs}', mem, *losses, + batch["img"].shape[-1])) if rank in [-1, 0]: # validation @@ -240,6 +269,9 @@ class BaseTrainer: return model + def load_cfg(self, cfg): + raise NotImplementedError("This task trainer doesn't support loading cfg files") + def get_validator(self): pass @@ -250,11 +282,11 @@ class BaseTrainer: self.scaler.update() self.optimizer.zero_grad() - def preprocess_batch(self, images, labels): + def preprocess_batch(self, batch): """ Allows custom preprocessing model inputs and ground truths depending on task type """ - return images.to(self.device, non_blocking=True), labels.to(self.device) + return batch def validate(self): """ @@ -270,14 +302,17 @@ class BaseTrainer: def build_targets(self, preds, targets): pass - def criterion(self, preds, targets): + def criterion(self, preds, batch): + """ + Returns loss and individual loss items as Tensor + """ pass def progress_string(self): """ Returns progress string depending on task type. """ - pass + return '' def usage_help(self): """ diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index b1fa885..7ede84c 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -1,8 +1,10 @@ import logging import torch +from omegaconf import DictConfig, OmegaConf from tqdm import tqdm +from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG from ultralytics.yolo.utils.ops import Profile from ultralytics.yolo.utils.torch_utils import select_device @@ -12,12 +14,15 @@ class BaseValidator: Base validator class. """ - def __init__(self, dataloader, device='', half=False, pbar=None, logger=None): + def __init__(self, dataloader, pbar=None, logger=None, args=None): self.dataloader = dataloader - self.half = half - self.device = select_device(device, dataloader.batch_size) self.pbar = pbar self.logger = logger or logging.getLogger() + self.args = args or OmegaConf.load(DEFAULT_CONFIG) + self.device = select_device(self.args.device, dataloader.batch_size) + self.cuda = self.device.type != 'cpu' + self.batch_i = None + self.training = True def __call__(self, trainer=None, model=None): """ @@ -25,45 +30,48 @@ class BaseValidator: if trainer is passed (trainer gets priority). """ training = trainer is not None + self.training = training # trainer = trainer or self.trainer_class.get_trainer() assert training or model is not None, "Either trainer or model is needed for validation" if training: model = trainer.model - self.half &= self.device.type != 'cpu' - model = model.half() if self.half else model + self.args.half &= self.device.type != 'cpu' + model = model.half() if self.args.half else model else: # TODO: handle this when detectMultiBackend is supported # model = DetectMultiBacked(model) pass + # TODO: implement init_model_attributes() model.eval() dt = Profile(), Profile(), Profile(), Profile() loss = 0 n_batches = len(self.dataloader) - desc = self.set_desc() + desc = self.get_desc() bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') - self.init_metrics() + self.init_metrics(model) with torch.cuda.amp.autocast(enabled=self.device.type != 'cpu'): - for images, labels in bar: + for batch_i, batch in enumerate(bar): + self.batch_i = batch_i # pre-process with dt[0]: - images, labels = self.preprocess_batch(images, labels) + batch = self.preprocess_batch(batch) # inference with dt[1]: - preds = model(images) + preds = model(batch["img"]) # TODO: remember to add native augmentation support when implementing model, like: # preds, train_out = model(im, augment=augment) # loss with dt[2]: if training: - loss += trainer.criterion(preds, labels) / images.shape[0] + loss += trainer.criterion(preds, batch)[0] # pre-process predictions with dt[3]: preds = self.preprocess_preds(preds) - self.update_metrics(preds, labels) + self.update_metrics(preds, batch) stats = self.get_stats() self.check_stats(stats) @@ -81,8 +89,8 @@ class BaseValidator: return stats - def preprocess_batch(self, images, labels): - return images.to(self.device, non_blocking=True), labels.to(self.device) + def preprocess_batch(self, batch): + return batch def preprocess_preds(self, preds): return preds @@ -90,7 +98,7 @@ class BaseValidator: def init_metrics(self): pass - def update_metrics(self, preds, targets): + def update_metrics(self, preds, batch): pass def get_stats(self): @@ -102,5 +110,5 @@ class BaseValidator: def print_results(self): pass - def set_desc(self): + def get_desc(self): pass diff --git a/ultralytics/yolo/utils/configs/default.yaml b/ultralytics/yolo/utils/configs/default.yaml index 8508744..b85c63d 100644 --- a/ultralytics/yolo/utils/configs/default.yaml +++ b/ultralytics/yolo/utils/configs/default.yaml @@ -4,6 +4,7 @@ # Train settings ------------------------------------------------------------------------------------------------------- model: null # i.e. yolov5s.pt +cfg: null # i.e. yolov5s.yaml data: null # i.e. coco128.yaml epochs: 300 batch_size: 16 @@ -20,6 +21,23 @@ optimizer: 'SGD' # choices=['SGD', 'Adam', 'AdamW', 'RMSProp'] verbose: False seed: 0 local_rank: -1 +single_cls: False # train multi-class data as single-class +image_weights: False # use weighted image selection for training +shuffle: True +rect: False # support rectangular training +overlap_mask: True # Segmentation masks overlap +mask_ratio: 4 # Segmentation mask downsample ratio + +# Val/Test settings ---------------------------------------------------------------------------------------------------- +save_json: False +save_hybrid: False +conf_thres: 0.001 +iou_thres: 0.6 +max_det: 300 +half: True +plots: False +save_txt: False +task: 'val' # Hyperparameters ------------------------------------------------------------------------------------------------------ lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3) @@ -51,6 +69,7 @@ fliplr: 0.5 # image flip left-right (probability) mosaic: 1.0 # image mosaic (probability) mixup: 0.0 # image mixup (probability) copy_paste: 0.0 # segment copy-paste (probability) +label_smoothing: 0.0 # Hydra configs -------------------------------------------------------------------------------------------------------- hydra: diff --git a/ultralytics/yolo/utils/metrics.py b/ultralytics/yolo/utils/metrics.py index 51172de..62bdcc9 100644 --- a/ultralytics/yolo/utils/metrics.py +++ b/ultralytics/yolo/utils/metrics.py @@ -2,11 +2,19 @@ """ Model validation metrics """ +import math +import warnings +from pathlib import Path +import matplotlib.pyplot as plt import numpy as np import torch +import torch.nn as nn +from ultralytics.yolo.utils import TryExcept + +# boxes def box_area(box): # box = xyxy(4,n) return (box[2] - box[0]) * (box[3] - box[1]) @@ -53,3 +61,484 @@ def box_iou(box1, box2, eps=1e-7): # IoU = inter / (area1 + area2 - inter) return inter / (box_area(box1.T)[:, None] + box_area(box2.T) - inter + eps) + + +def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7): + # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4) + + # Get the coordinates of bounding boxes + if xywh: # transform from xywh to xyxy + (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, 1), box2.chunk(4, 1) + w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2 + b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_ + b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_ + else: # x1, y1, x2, y2 = box1 + b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, 1) + b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, 1) + w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps + w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps + + # Intersection area + inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \ + (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0) + + # Union Area + union = w1 * h1 + w2 * h2 - inter + eps + + # IoU + iou = inter / union + if CIoU or DIoU or GIoU: + cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width + ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height + if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 + c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared + rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2 + if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 + v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2) + with torch.no_grad(): + alpha = v / (v - iou + (1 + eps)) + return iou - (rho2 / c2 + v * alpha) # CIoU + return iou - rho2 / c2 # DIoU + c_area = cw * ch + eps # convex area + return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf + return iou # IoU + + +def mask_iou(mask1, mask2, eps=1e-7): + """ + mask1: [N, n] m1 means number of predicted objects + mask2: [M, n] m2 means number of gt objects + Note: n means image_w x image_h + return: masks iou, [N, M] + """ + intersection = torch.matmul(mask1, mask2.t()).clamp(0) + union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection + return intersection / (union + eps) + + +def masks_iou(mask1, mask2, eps=1e-7): + """ + mask1: [N, n] m1 means number of predicted objects + mask2: [N, n] m2 means number of gt objects + Note: n means image_w x image_h + return: masks iou, (N, ) + """ + intersection = (mask1 * mask2).sum(1).clamp(0) # (N, ) + union = (mask1.sum(1) + mask2.sum(1))[None] - intersection # (area1 + area2) - intersection + return intersection / (union + eps) + + +def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441 + # return positive, negative label smoothing BCE targets + return 1.0 - 0.5 * eps, 0.5 * eps + + +# losses +class FocalLoss(nn.Module): + # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5) + def __init__(self, loss_fcn, gamma=1.5, alpha=0.25): + super().__init__() + self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss() + self.gamma = gamma + self.alpha = alpha + self.reduction = loss_fcn.reduction + self.loss_fcn.reduction = 'none' # required to apply FL to each element + + def forward(self, pred, true): + loss = self.loss_fcn(pred, true) + # p_t = torch.exp(-loss) + # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability + + # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py + pred_prob = torch.sigmoid(pred) # prob from logits + p_t = true * pred_prob + (1 - true) * (1 - pred_prob) + alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha) + modulating_factor = (1.0 - p_t) ** self.gamma + loss *= alpha_factor * modulating_factor + + if self.reduction == 'mean': + return loss.mean() + elif self.reduction == 'sum': + return loss.sum() + else: # 'none' + return loss + + +class ConfusionMatrix: + # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix + def __init__(self, nc, conf=0.25, iou_thres=0.45): + self.matrix = np.zeros((nc + 1, nc + 1)) + self.nc = nc # number of classes + self.conf = conf + self.iou_thres = iou_thres + + def process_batch(self, detections, labels): + """ + Return intersection-over-union (Jaccard index) of boxes. + Both sets of boxes are expected to be in (x1, y1, x2, y2) format. + Arguments: + detections (Array[N, 6]), x1, y1, x2, y2, conf, class + labels (Array[M, 5]), class, x1, y1, x2, y2 + Returns: + None, updates confusion matrix accordingly + """ + if detections is None: + gt_classes = labels.int() + for gc in gt_classes: + self.matrix[self.nc, gc] += 1 # background FN + return + + detections = detections[detections[:, 4] > self.conf] + gt_classes = labels[:, 0].int() + detection_classes = detections[:, 5].int() + iou = box_iou(labels[:, 1:], detections[:, :4]) + + x = torch.where(iou > self.iou_thres) + if x[0].shape[0]: + matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() + if x[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + else: + matches = np.zeros((0, 3)) + + n = matches.shape[0] > 0 + m0, m1, _ = matches.transpose().astype(int) + for i, gc in enumerate(gt_classes): + j = m0 == i + if n and sum(j) == 1: + self.matrix[detection_classes[m1[j]], gc] += 1 # correct + else: + self.matrix[self.nc, gc] += 1 # true background + + if n: + for i, dc in enumerate(detection_classes): + if not any(m1 == i): + self.matrix[dc, self.nc] += 1 # predicted background + + def matrix(self): + return self.matrix + + def tp_fp(self): + tp = self.matrix.diagonal() # true positives + fp = self.matrix.sum(1) - tp # false positives + # fn = self.matrix.sum(0) - tp # false negatives (missed detections) + return tp[:-1], fp[:-1] # remove background class + + @TryExcept('WARNING ⚠️ ConfusionMatrix plot failure') + def plot(self, normalize=True, save_dir='', names=()): + import seaborn as sn + + array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns + array[array < 0.005] = np.nan # don't annotate (would appear as 0.00) + + fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True) + nc, nn = self.nc, len(names) # number of classes, names + sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size + labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels + ticklabels = (names + ['background']) if labels else "auto" + with warnings.catch_warnings(): + warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered + sn.heatmap(array, + ax=ax, + annot=nc < 30, + annot_kws={ + "size": 8}, + cmap='Blues', + fmt='.2f', + square=True, + vmin=0.0, + xticklabels=ticklabels, + yticklabels=ticklabels).set_facecolor((1, 1, 1)) + ax.set_ylabel('True') + ax.set_ylabel('Predicted') + ax.set_title('Confusion Matrix') + fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250) + plt.close(fig) + + def print(self): + for i in range(self.nc + 1): + print(' '.join(map(str, self.matrix[i]))) + + +def fitness_detection(x): + # Model fitness as a weighted combination of metrics + w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95] + return (x[:, :4] * w).sum(1) + + +def fitness_segmentation(x): + # Model fitness as a weighted combination of metrics + w = [0.0, 0.0, 0.1, 0.9, 0.0, 0.0, 0.1, 0.9] + return (x[:, :8] * w).sum(1) + + +def smooth(y, f=0.05): + # Box filter of fraction f + nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd) + p = np.ones(nf // 2) # ones padding + yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded + return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed + + +def compute_ap(recall, precision): + """ Compute the average precision, given the recall and precision curves + # Arguments + recall: The recall curve (list) + precision: The precision curve (list) + # Returns + Average precision, precision curve, recall curve + """ + + # Append sentinel values to beginning and end + mrec = np.concatenate(([0.0], recall, [1.0])) + mpre = np.concatenate(([1.0], precision, [0.0])) + + # Compute the precision envelope + mpre = np.flip(np.maximum.accumulate(np.flip(mpre))) + + # Integrate area under curve + method = 'interp' # methods: 'continuous', 'interp' + if method == 'interp': + x = np.linspace(0, 1, 101) # 101-point interp (COCO) + ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate + else: # 'continuous' + i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve + + return ap, mpre, mrec + + +def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16, prefix=""): + """ Compute the average precision, given the recall and precision curves. + Source: https://github.com/rafaelpadilla/Object-Detection-Metrics. + # Arguments + tp: True positives (nparray, nx1 or nx10). + conf: Objectness value from 0-1 (nparray). + pred_cls: Predicted object classes (nparray). + target_cls: True object classes (nparray). + plot: Plot precision-recall curve at mAP@0.5 + save_dir: Plot save directory + # Returns + The average precision as computed in py-faster-rcnn. + """ + + # Sort by objectness + i = np.argsort(-conf) + tp, conf, pred_cls = tp[i], conf[i], pred_cls[i] + + # Find unique classes + unique_classes, nt = np.unique(target_cls, return_counts=True) + nc = unique_classes.shape[0] # number of classes, number of detections + + # Create Precision-Recall curve and compute AP for each class + px, py = np.linspace(0, 1, 1000), [] # for plotting + ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000)) + for ci, c in enumerate(unique_classes): + i = pred_cls == c + n_l = nt[ci] # number of labels + n_p = i.sum() # number of predictions + if n_p == 0 or n_l == 0: + continue + + # Accumulate FPs and TPs + fpc = (1 - tp[i]).cumsum(0) + tpc = tp[i].cumsum(0) + + # Recall + recall = tpc / (n_l + eps) # recall curve + r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases + + # Precision + precision = tpc / (tpc + fpc) # precision curve + p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score + + # AP from recall-precision curve + for j in range(tp.shape[1]): + ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j]) + if plot and j == 0: + py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5 + + # Compute F1 (harmonic mean of precision and recall) + f1 = 2 * p * r / (p + r + eps) + names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data + names = dict(enumerate(names)) # to dict + # TODO: plot + ''' + if plot: + plot_pr_curve(px, py, ap, Path(save_dir) / f'{prefix}PR_curve.png', names) + plot_mc_curve(px, f1, Path(save_dir) / f'{prefix}F1_curve.png', names, ylabel='F1') + plot_mc_curve(px, p, Path(save_dir) / f'{prefix}P_curve.png', names, ylabel='Precision') + plot_mc_curve(px, r, Path(save_dir) / f'{prefix}R_curve.png', names, ylabel='Recall') + ''' + + i = smooth(f1.mean(0), 0.1).argmax() # max F1 index + p, r, f1 = p[:, i], r[:, i], f1[:, i] + tp = (r * nt).round() # true positives + fp = (tp / (p + eps) - tp).round() # false positives + return tp, fp, p, r, f1, ap, unique_classes.astype(int) + + +def ap_per_class_box_and_mask( + tp_m, + tp_b, + conf, + pred_cls, + target_cls, + plot=False, + save_dir=".", + names=(), +): + """ + Args: + tp_b: tp of boxes. + tp_m: tp of masks. + other arguments see `func: ap_per_class`. + """ + results_boxes = ap_per_class(tp_b, + conf, + pred_cls, + target_cls, + plot=plot, + save_dir=save_dir, + names=names, + prefix="Box")[2:] + results_masks = ap_per_class(tp_m, + conf, + pred_cls, + target_cls, + plot=plot, + save_dir=save_dir, + names=names, + prefix="Mask")[2:] + + results = { + "boxes": { + "p": results_boxes[0], + "r": results_boxes[1], + "ap": results_boxes[3], + "f1": results_boxes[2], + "ap_class": results_boxes[4]}, + "masks": { + "p": results_masks[0], + "r": results_masks[1], + "ap": results_masks[3], + "f1": results_masks[2], + "ap_class": results_masks[4]}} + return results + + +class Metric: + + def __init__(self) -> None: + self.p = [] # (nc, ) + self.r = [] # (nc, ) + self.f1 = [] # (nc, ) + self.all_ap = [] # (nc, 10) + self.ap_class_index = [] # (nc, ) + + @property + def ap50(self): + """AP@0.5 of all classes. + Return: + (nc, ) or []. + """ + return self.all_ap[:, 0] if len(self.all_ap) else [] + + @property + def ap(self): + """AP@0.5:0.95 + Return: + (nc, ) or []. + """ + return self.all_ap.mean(1) if len(self.all_ap) else [] + + @property + def mp(self): + """mean precision of all classes. + Return: + float. + """ + return self.p.mean() if len(self.p) else 0.0 + + @property + def mr(self): + """mean recall of all classes. + Return: + float. + """ + return self.r.mean() if len(self.r) else 0.0 + + @property + def map50(self): + """Mean AP@0.5 of all classes. + Return: + float. + """ + return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0 + + @property + def map(self): + """Mean AP@0.5:0.95 of all classes. + Return: + float. + """ + return self.all_ap.mean() if len(self.all_ap) else 0.0 + + def mean_results(self): + """Mean of results, return mp, mr, map50, map""" + return (self.mp, self.mr, self.map50, self.map) + + def class_result(self, i): + """class-aware result, return p[i], r[i], ap50[i], ap[i]""" + return (self.p[i], self.r[i], self.ap50[i], self.ap[i]) + + def get_maps(self, nc): + maps = np.zeros(nc) + self.map + for i, c in enumerate(self.ap_class_index): + maps[c] = self.ap[i] + return maps + + def update(self, results): + """ + Args: + results: tuple(p, r, ap, f1, ap_class) + """ + p, r, all_ap, f1, ap_class_index = results + self.p = p + self.r = r + self.all_ap = all_ap + self.f1 = f1 + self.ap_class_index = ap_class_index + + +class Metrics: + """Metric for boxes and masks.""" + + def __init__(self) -> None: + self.metric_box = Metric() + self.metric_mask = Metric() + + def update(self, results): + """ + Args: + results: Dict{'boxes': Dict{}, 'masks': Dict{}} + """ + self.metric_box.update(list(results["boxes"].values())) + self.metric_mask.update(list(results["masks"].values())) + + def mean_results(self): + return self.metric_box.mean_results() + self.metric_mask.mean_results() + + def class_result(self, i): + return self.metric_box.class_result(i) + self.metric_mask.class_result(i) + + def get_maps(self, nc): + return self.metric_box.get_maps(nc) + self.metric_mask.get_maps(nc) + + @property + def ap_class_index(self): + # boxes and masks have the same ap_class_index + return self.metric_box.ap_class_index diff --git a/ultralytics/yolo/utils/ops.py b/ultralytics/yolo/utils/ops.py index 4fd5add..197e3ff 100644 --- a/ultralytics/yolo/utils/ops.py +++ b/ultralytics/yolo/utils/ops.py @@ -5,6 +5,7 @@ import time import cv2 import numpy as np import torch +import torch.nn.functional as F import torchvision from ultralytics.yolo.utils import LOGGER @@ -32,14 +33,23 @@ class Profile(contextlib.ContextDecorator): return time.time() +def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper) + # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/ + # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n') + # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n') + # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco + # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet + return [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90] + + def segment2box(segment, width=640, height=640): # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy) x, y = segment.T # segment xy inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height) - x, y, = ( - x[inside], - y[inside], - ) + x, y, = x[inside], y[inside] return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros(4) # xyxy @@ -304,3 +314,63 @@ def resample_segments(segments, n=1000): xp = np.arange(len(s)) segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy return segments + + +def crop_mask(masks, boxes): + """ + "Crop" predicted masks by zeroing out everything not in the predicted bbox. + Vectorized by Chong (thanks Chong). + Args: + - masks should be a size [h, w, n] tensor of masks + - boxes should be a size [n, 4] tensor of bbox coords in relative point form + """ + + n, h, w = masks.shape + x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(1,1,n) + r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,w,1) + c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(h,1,1) + + return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2)) + + +def process_mask_upsample(protos, masks_in, bboxes, shape): + """ + Crop after upsample. + proto_out: [mask_dim, mask_h, mask_w] + out_masks: [n, mask_dim], n is number of masks after nms + bboxes: [n, 4], n is number of masks after nms + shape:input_image_size, (h, w) + return: h, w, n + """ + + c, mh, mw = protos.shape # CHW + masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) + masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW + masks = crop_mask(masks, bboxes) # CHW + return masks.gt_(0.5) + + +def process_mask(protos, masks_in, bboxes, shape, upsample=False): + """ + Crop before upsample. + proto_out: [mask_dim, mask_h, mask_w] + out_masks: [n, mask_dim], n is number of masks after nms + bboxes: [n, 4], n is number of masks after nms + shape:input_image_size, (h, w) + return: h, w, n + """ + + c, mh, mw = protos.shape # CHW + ih, iw = shape + masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) # CHW + + downsampled_bboxes = bboxes.clone() + downsampled_bboxes[:, 0] *= mw / iw + downsampled_bboxes[:, 2] *= mw / iw + downsampled_bboxes[:, 3] *= mh / ih + downsampled_bboxes[:, 1] *= mh / ih + + masks = crop_mask(masks, downsampled_bboxes) # CHW + if upsample: + masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW + return masks.gt_(0.5) diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py index 0466810..c5c3be5 100644 --- a/ultralytics/yolo/utils/torch_utils.py +++ b/ultralytics/yolo/utils/torch_utils.py @@ -179,3 +179,13 @@ def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')): def intersect_state_dicts(da, db, exclude=()): # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape} + + +def is_parallel(model): + # Returns True if model is of type DP or DDP + return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) + + +def de_parallel(model): + # De-parallelize a model: returns single-GPU model if model is of type DP or DDP + return model.module if is_parallel(model) else model diff --git a/ultralytics/yolo/v8/__init__.py b/ultralytics/yolo/v8/__init__.py index 6d856d6..a18b41a 100644 --- a/ultralytics/yolo/v8/__init__.py +++ b/ultralytics/yolo/v8/__init__.py @@ -1,7 +1,7 @@ from pathlib import Path -from ultralytics.yolo.v8 import classify +from ultralytics.yolo.v8 import classify, segment ROOT = Path(__file__).parents[0] # yolov8 ROOT -__all__ = ["classify"] +__all__ = ["classify", "segment"] diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py index 8712d90..e6f9990 100644 --- a/ultralytics/yolo/v8/classify/train.py +++ b/ultralytics/yolo/v8/classify/train.py @@ -38,13 +38,22 @@ class ClassificationTrainer(BaseTrainer): return train_set, test_set def get_dataloader(self, dataset_path, batch_size=None, rank=0): - return build_classification_dataloader(path=dataset_path, batch_size=self.args.batch_size, rank=rank) + return build_classification_dataloader(path=dataset_path, + imgsz=self.args.img_size, + batch_size=self.args.batch_size, + rank=rank) + + def preprocess_batch(self, batch): + batch["img"] = batch["img"].to(self.device) + batch["cls"] = batch["cls"].to(self.device) + return batch def get_validator(self): return v8.classify.ClassificationValidator(self.test_loader, self.device, logger=self.console) - def criterion(self, preds, targets): - return torch.nn.functional.cross_entropy(preds, targets) + def criterion(self, preds, batch): + loss = torch.nn.functional.cross_entropy(preds, batch["cls"]) + return loss, loss @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name) diff --git a/ultralytics/yolo/v8/classify/val.py b/ultralytics/yolo/v8/classify/val.py index 3d3b4e9..f24ae7f 100644 --- a/ultralytics/yolo/v8/classify/val.py +++ b/ultralytics/yolo/v8/classify/val.py @@ -5,10 +5,16 @@ from ultralytics.yolo.engine.validator import BaseValidator class ClassificationValidator(BaseValidator): - def init_metrics(self): + def init_metrics(self, model): self.correct = torch.tensor([]) - def update_metrics(self, preds, targets): + def preprocess_batch(self, batch): + batch["img"] = batch["img"].to(self.device) + batch["cls"] = batch["cls"].to(self.device) + return batch + + def update_metrics(self, preds, batch): + targets = batch["cls"] correct_in_batch = (targets[:, None] == preds).float() self.correct = torch.cat((self.correct, correct_in_batch)) diff --git a/ultralytics/yolo/v8/models/yolov5n-seg.yaml b/ultralytics/yolo/v8/models/yolov5n-seg.yaml new file mode 100644 index 0000000..c28225a --- /dev/null +++ b/ultralytics/yolo/v8/models/yolov5n-seg.yaml @@ -0,0 +1,48 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.25 # layer channel multiple +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 9 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5) + ] diff --git a/ultralytics/yolo/v8/models/yolov5n.yaml b/ultralytics/yolo/v8/models/yolov5n.yaml new file mode 100644 index 0000000..8a28a40 --- /dev/null +++ b/ultralytics/yolo/v8/models/yolov5n.yaml @@ -0,0 +1,48 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.25 # layer channel multiple +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 9 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/ultralytics/yolo/v8/segment/__init__.py b/ultralytics/yolo/v8/segment/__init__.py new file mode 100644 index 0000000..3575c95 --- /dev/null +++ b/ultralytics/yolo/v8/segment/__init__.py @@ -0,0 +1,2 @@ +from ultralytics.yolo.v8.segment.train import SegmentationTrainer +from ultralytics.yolo.v8.segment.val import SegmentationValidator diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py new file mode 100644 index 0000000..5bc0137 --- /dev/null +++ b/ultralytics/yolo/v8/segment/train.py @@ -0,0 +1,269 @@ +import subprocess +import time +from pathlib import Path + +import hydra +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ultralytics.yolo import v8 +from ultralytics.yolo.data import build_dataloader +from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer +from ultralytics.yolo.utils.downloads import download +from ultralytics.yolo.utils.files import WorkingDirectory +from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE +from ultralytics.yolo.utils.modeling.tasks import SegmentationModel +from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy +from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, de_parallel, torch_distributed_zero_first + + +# BaseTrainer python usage +class SegmentationTrainer(BaseTrainer): + + def get_dataset(self, dataset): + # temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module + data = Path("datasets") / dataset + with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(Path.cwd()): + data_dir = data if data.is_dir() else (Path.cwd() / data) + if not data_dir.is_dir(): + self.console.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...') + t = time.time() + if str(data) == 'imagenet': + subprocess.run(f"bash {v8.ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True) + else: + url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip' + download(url, dir=data_dir.parent) + # TODO: add colorstr + s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {'bold', data_dir}\n" + self.console.info(s) + train_set = data_dir.parent / "coco128-seg" + test_set = train_set + return train_set, test_set + + def get_dataloader(self, dataset_path, batch_size, rank=0): + # TODO: manage splits differently + # calculate stride - check if model is initialized + gs = max(int(self.model.stride.max() if self.model else 0), 32) + loader = build_dataloader( + img_path=dataset_path, + img_size=self.args.img_size, + batch_size=batch_size, + single_cls=self.args.single_cls, + cache=self.args.cache, + image_weights=self.args.image_weights, + stride=gs, + rect=self.args.rect, + rank=rank, + workers=self.args.workers, + shuffle=self.args.shuffle, + use_segments=True, + )[0] + return loader + + def preprocess_batch(self, batch): + batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255 + return batch + + def load_cfg(self, cfg): + return SegmentationModel(cfg, nc=80) + + def get_validator(self): + return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console) + + def criterion(self, preds, batch): + head = de_parallel(self.model).model[-1] + sort_obj_iou = False + autobalance = False + + # init losses + BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.args.cls_pw], device=self.device)) + BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.args.obj_pw], device=self.device)) + + # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3 + cp, cn = smooth_BCE(eps=self.args.label_smoothing) # positive, negative BCE targets + + # Focal loss + g = self.args.fl_gamma + if self.args.fl_gamma > 0: + BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) + + balance = {3: [4.0, 1.0, 0.4]}.get(head.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7 + ssi = list(head.stride).index(16) if autobalance else 0 # stride 16 index + BCEcls, BCEobj, gr, autobalance = BCEcls, BCEobj, 1.0, autobalance + + def single_mask_loss(gt_mask, pred, proto, xyxy, area): + # Mask loss for one image + pred_mask = (pred @ proto.view(head.nm, -1)).view(-1, *proto.shape[1:]) # (n,32) @ (32,80,80) -> (n,80,80) + loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none") + return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean() + + def build_targets(p, targets): + # Build targets for compute_loss(), input targets(image,class,x,y,w,h) + nonlocal head + na, nt = head.na, targets.shape[0] # number of anchors, targets + tcls, tbox, indices, anch, tidxs, xywhn = [], [], [], [], [], [] + gain = torch.ones(8, device=self.device) # normalized to gridspace gain + ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, + nt) # same as .repeat_interleave(nt) + if self.args.overlap_mask: + batch = p[0].shape[0] + ti = [] + for i in range(batch): + num = (targets[:, 0] == i).sum() # find number of targets of each image + ti.append(torch.arange(num, device=self.device).float().view(1, num).repeat(na, 1) + 1) # (na, num) + ti = torch.cat(ti, 1) # (na, nt) + else: + ti = torch.arange(nt, device=self.device).float().view(1, nt).repeat(na, 1) + targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None], ti[..., None]), 2) # append anchor indices + + g = 0.5 # bias + off = torch.tensor( + [ + [0, 0], + [1, 0], + [0, 1], + [-1, 0], + [0, -1], # j,k,l,m + # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm + ], + device=self.device).float() * g # offsets + + for i in range(head.nl): + anchors, shape = head.anchors[i], p[i].shape + gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]] # xyxy gain + + # Match targets to anchors + t = targets * gain # shape(3,n,7) + if nt: + # Matches + r = t[..., 4:6] / anchors[:, None] # wh ratio + j = torch.max(r, 1 / r).max(2)[0] < self.args.anchor_t # compare + # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2)) + t = t[j] # filter + + # Offsets + gxy = t[:, 2:4] # grid xy + gxi = gain[[2, 3]] - gxy # inverse + j, k = ((gxy % 1 < g) & (gxy > 1)).T + l, m = ((gxi % 1 < g) & (gxi > 1)).T + j = torch.stack((torch.ones_like(j), j, k, l, m)) + t = t.repeat((5, 1, 1))[j] + offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j] + else: + t = targets[0] + offsets = 0 + + # Define + bc, gxy, gwh, at = t.chunk(4, 1) # (image, class), grid xy, grid wh, anchors + (a, tidx), (b, c) = at.long().T, bc.long().T # anchors, image, class + gij = (gxy - offsets).long() + gi, gj = gij.T # grid indices + + # Append + indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1))) # image, anchor, grid + tbox.append(torch.cat((gxy - gij, gwh), 1)) # box + anch.append(anchors[a]) # anchors + tcls.append(c) # class + tidxs.append(tidx) + xywhn.append(torch.cat((gxy, gwh), 1) / gain[2:6]) # xywh normalized + + return tcls, tbox, indices, anch, tidxs, xywhn + + if self.model.training: + p, proto, = preds + else: + p, proto, train_out = preds + p = train_out + targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) + masks = batch["masks"] + targets, masks = targets.to(self.device), masks.to(self.device).float() + + bs, nm, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width + lcls = torch.zeros(1, device=self.device) + lbox = torch.zeros(1, device=self.device) + lobj = torch.zeros(1, device=self.device) + lseg = torch.zeros(1, device=self.device) + tcls, tbox, indices, anchors, tidxs, xywhn = build_targets(p, targets) + + # Losses + for i, pi in enumerate(p): # layer index, layer predictions + b, a, gj, gi = indices[i] # image, anchor, gridy, gridx + tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device) # target obj + + n = b.shape[0] # number of targets + if n: + pxy, pwh, _, pcls, pmask = pi[b, a, gj, gi].split((2, 2, 1, head.nc, nm), 1) # subset of predictions + + # Box regression + pxy = pxy.sigmoid() * 2 - 0.5 + pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i] + pbox = torch.cat((pxy, pwh), 1) # predicted box + iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze() # iou(prediction, target) + lbox += (1.0 - iou).mean() # iou loss + + # Objectness + iou = iou.detach().clamp(0).type(tobj.dtype) + if sort_obj_iou: + j = iou.argsort() + b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j] + if gr < 1: + iou = (1.0 - gr) + gr * iou + tobj[b, a, gj, gi] = iou # iou ratio + + # Classification + if head.nc > 1: # cls loss (only if multiple classes) + t = torch.full_like(pcls, cn, device=self.device) # targets + t[range(n), tcls[i]] = cp + lcls += BCEcls(pcls, t) # BCE + + # Mask regression + if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample + masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0] + marea = xywhn[i][:, 2:].prod(1) # mask width, height normalized + mxyxy = xywh2xyxy(xywhn[i] * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)) + for bi in b.unique(): + j = b == bi # matching index + if True: + mask_gti = torch.where(masks[bi][None] == tidxs[i][j].view(-1, 1, 1), 1.0, 0.0) + else: + mask_gti = masks[tidxs[i]][j] + lseg += single_mask_loss(mask_gti, pmask[j], proto[bi], mxyxy[j], marea[j]) + + obji = BCEobj(pi[..., 4], tobj) + lobj += obji * balance[i] # obj loss + if autobalance: + balance[i] = balance[i] * 0.9999 + 0.0001 / obji.detach().item() + + if autobalance: + balance = [x / balance[ssi] for x in balance] + lbox *= self.args.box + lobj *= self.args.obj + lcls *= self.args.cls + lseg *= self.args.box / bs + + loss = lbox + lobj + lcls + lseg + return loss * bs, torch.cat((lbox, lseg, lobj, lcls)).detach() + + def progress_string(self): + return ('\n' + '%11s' * 7) % \ + ('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss', 'Size') + + +@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name) +def train(cfg): + cfg.cfg = v8.ROOT / "models/yolov5n-seg.yaml" + cfg.data = cfg.data or "coco128-segments" # or yolo.ClassificationDataset("mnist") + trainer = SegmentationTrainer(cfg) + trainer.train() + + +if __name__ == "__main__": + """ + CLI usage: + python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-segments epochs=100 img_size=640 + + TODO: + Direct cli support, i.e, yolov8 classify_train args.epochs 10 + """ + train() diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py new file mode 100644 index 0000000..40aa7ab --- /dev/null +++ b/ultralytics/yolo/v8/segment/val.py @@ -0,0 +1,211 @@ +import os +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F + +from ultralytics.yolo.engine.validator import BaseValidator +from ultralytics.yolo.utils import ops +from ultralytics.yolo.utils.checks import check_requirements +from ultralytics.yolo.utils.metrics import (ConfusionMatrix, Metrics, ap_per_class_box_and_mask, box_iou, + fitness_segmentation, mask_iou) +from ultralytics.yolo.utils.modeling import yaml_load +from ultralytics.yolo.utils.torch_utils import de_parallel + + +class SegmentationValidator(BaseValidator): + + def __init__(self, dataloader, pbar=None, logger=None, args=None): + super().__init__(dataloader, pbar, logger, args) + if self.args.save_json: + check_requirements(['pycocotools']) + self.process = ops.process_mask_upsample # more accurate + else: + self.process = ops.process_mask # faster + self.data_dict = yaml_load(self.args.data) if self.args.data else None + self.is_coco = False + self.class_map = None + self.targets = None + + def preprocess_batch(self, batch): + batch["img"] = batch["img"].to(self.device, non_blocking=True) + batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 225 + batch["bboxes"] = batch["bboxes"].to(self.device) + batch["masks"] = batch["masks"].to(self.device).float() + self.nb, _, self.height, self.width = batch["img"].shape # batch size, channels, height, width + self.targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) + self.lb = [self.targets[self.targets[:, 0] == i, 1:] + for i in range(self.nb)] if self.args.save_hybrid else [] # for autolabelling + + return batch + + def init_metrics(self, model): + head = de_parallel(model).model[-1] + if self.data_dict: + self.is_coco = isinstance(self.data_dict.get('val'), + str) and self.data_dict['val'].endswith(f'coco{os.sep}val2017.txt') + self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000)) + + self.nc = head.nc + self.nm = head.nm + self.names = model.names + if isinstance(self.names, (list, tuple)): # old format + self.names = dict(enumerate(self.names)) + + self.iouv = torch.linspace(0.5, 0.95, 10, device=self.device) # iou vector for mAP@0.5:0.95 + self.niou = self.iouv.numel() + self.seen = 0 + self.confusion_matrix = ConfusionMatrix(nc=self.nc) + self.metrics = Metrics() + self.loss = torch.zeros(4, device=self.device) + self.jdict = [] + self.stats = [] + + def get_desc(self): + return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P", + "R", "mAP50", "mAP50-95)") + + def preprocess_preds(self, preds): + p = ops.non_max_suppression(preds[0], + self.args.conf_thres, + self.args.iou_thres, + labels=self.lb, + multi_label=True, + agnostic=self.args.single_cls, + max_det=self.args.max_det, + nm=self.nm) + return (p, preds[0], preds[2]) + + def update_metrics(self, preds, batch): + # Metrics + plot_masks = [] # masks for plotting + for si, (pred, proto) in enumerate(zip(preds[0], preds[1])): + labels = self.targets[self.targets[:, 0] == si, 1:] + nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions + shape = Path(batch["im_file"][si]) + # path = batch["shape"][si][0] + correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init + correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init + self.seen += 1 + + if npr == 0: + if nl: + self.stats.append((correct_masks, correct_bboxes, *torch.zeros( + (2, 0), device=self.device), labels[:, 0])) + if self.args.plots: + self.confusion_matrix.process_batch(detections=None, labels=labels[:, 0]) + continue + + # Masks + midx = [si] if self.args.overlap_mask else self.targets[:, 0] == si + gt_masks = batch["masks"][midx] + pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch["img"][si].shape[1:]) + + # Predictions + if self.args.single_cls: + pred[:, 5] = 0 + predn = pred.clone() + ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape, batch["shape"][si][1]) # native-space pred + + # Evaluate + if nl: + tbox = ops.xywh2xyxy(labels[:, 1:5]) # target boxes + ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape, batch["shapes"][si][1]) # native-space labels + labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels + correct_bboxes = self._process_batch(predn, labelsn, self.iouv) + correct_masks = self._process_batch(predn, labelsn, self.iouv, pred_masks, gt_masks, masks=True) + if self.args.plots: + self.confusion_matrix.process_batch(predn, labelsn) + self.stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], labels[:, + 0])) # (conf, pcls, tcls) + + pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8) + if self.plots and self.batch_i < 3: + plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot + + # TODO: Save/log + ''' + if self.args.save_txt: + save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') + if self.args.save_json: + pred_masks = scale_image(im[si].shape[1:], + pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), shape, shapes[si][1]) + save_one_json(predn, jdict, path, class_map, pred_masks) # append to COCO-JSON dictionary + # callbacks.run('on_val_image_end', pred, predn, path, names, im[si]) + ''' + + # TODO Plot images + ''' + if self.args.plots and self.batch_i < 3: + if len(plot_masks): + plot_masks = torch.cat(plot_masks, dim=0) + plot_images_and_masks(im, targets, masks, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names) + plot_images_and_masks(im, output_to_target(preds, max_det=15), plot_masks, paths, + save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred + ''' + + def get_stats(self): + stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy + if len(stats) and stats[0].any(): + # TODO: save_dir + results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir='', names=self.names) + self.metrics.update(results) + self.nt_per_class = np.bincount(stats[4].astype(int), minlength=self.nc) # number of targets per class + keys = ["mp_bbox", "mr_bbox", "map50_bbox", "map_bbox", "mp_mask", "mr_mask", "map50_mask", "map_mask"] + metrics = {"fitness": fitness_segmentation(np.array(self.metrics.mean_results()).reshape(1, -1))} + metrics |= zip(keys, self.metrics.mean_results()) + return metrics + + def print_results(self): + pf = '%22s' + '%11i' * 2 + '%11.3g' * 8 # print format + self.logger.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results())) + if self.nt_per_class.sum() == 0: + self.logger.warning( + f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels') + + # Print results per class + if (self.args.verbose or (self.nc < 50 and not self.training)) and self.nc > 1 and len(self.stats): + for i, c in enumerate(self.metrics.ap_class_index): + self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i))) + + # plot TODO: save_dir + if self.args.plots: + self.confusion_matrix.plot(save_dir='', names=list(self.names.values())) + + def _process_batch(self, detections, labels, iouv, pred_masks=None, gt_masks=None, overlap=False, masks=False): + """ + Return correct prediction matrix + Arguments: + detections (array[N, 6]), x1, y1, x2, y2, conf, class + labels (array[M, 5]), class, x1, y1, x2, y2 + Returns: + correct (array[N, 10]), for 10 IoU levels + """ + if masks: + if overlap: + nl = len(labels) + index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1 + gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640) + gt_masks = torch.where(gt_masks == index, 1.0, 0.0) + if gt_masks.shape[1:] != pred_masks.shape[1:]: + gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0] + gt_masks = gt_masks.gt_(0.5) + iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1)) + else: # boxes + iou = box_iou(labels[:, 1:], detections[:, :4]) + + correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool) + correct_class = labels[:, 0:1] == detections[:, 5] + for i in range(len(iouv)): + x = torch.where((iou >= iouv[i]) & correct_class) # IoU > threshold and classes match + if x[0].shape[0]: + matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), + 1).cpu().numpy() # [label, detect, iou] + if x[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + # matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + correct[matches[:, 1].astype(int), i] = True + return torch.tensor(correct, dtype=torch.bool, device=iouv.device) diff --git a/ultralytics/yolov5n-seg.yaml b/ultralytics/yolov5n-seg.yaml new file mode 100644 index 0000000..d414cad --- /dev/null +++ b/ultralytics/yolov5n-seg.yaml @@ -0,0 +1,48 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.25 # layer channel multiple +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 9 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5) + ] \ No newline at end of file