Segmentation support & other enchancements (#40)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
		
							
								
								
									
										16
									
								
								.github/workflows/ci.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										16
									
								
								.github/workflows/ci.yaml
									
									
									
									
										vendored
									
									
								
							| @ -21,7 +21,8 @@ jobs: | ||||
|         os: [ ubuntu-latest ] | ||||
|         python-version: [ '3.10' ] | ||||
|         model: [ yolov5n ] | ||||
|         include: | ||||
|         torch: [ latest ] | ||||
| #        include: | ||||
| #          - os: ubuntu-latest | ||||
| #            python-version: '3.7'  # '3.6.8' min | ||||
| #            model: yolov5n | ||||
| @ -31,10 +32,10 @@ jobs: | ||||
| #          - os: ubuntu-latest | ||||
| #            python-version: '3.9' | ||||
| #            model: yolov5n | ||||
|           - os: ubuntu-latest | ||||
|             python-version: '3.8'  # torch 1.7.0 requires python >=3.6, <=3.8 | ||||
|             model: yolov5n | ||||
|             torch: '1.7.0'  # min torch version CI https://pypi.org/project/torchvision/ | ||||
| #          - os: ubuntu-latest | ||||
| #            python-version: '3.8'  # torch 1.7.0 requires python >=3.6, <=3.8 | ||||
| #            model: yolov5n | ||||
| #            torch: '1.7.0'  # min torch version CI https://pypi.org/project/torchvision/ | ||||
|     steps: | ||||
|       - uses: actions/checkout@v3 | ||||
|       - uses: actions/setup-python@v4 | ||||
| @ -93,9 +94,8 @@ jobs: | ||||
|       - name: Test segmentation | ||||
|         shell: bash  # for Windows compatibility | ||||
|         run: | | ||||
|           echo "TODO" | ||||
|           python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-segments epochs=1 img_size=64 | ||||
|       - name: Test classification | ||||
|         shell: bash  # for Windows compatibility | ||||
|         run: | | ||||
|           echo "TODO" | ||||
|           # python ultralytics/yolo/v8/classify/train.py model=resnet18 data=mnist2560 epochs=1 img_size=64 | ||||
|           python ultralytics/yolo/v8/classify/train.py model=resnet18 data=mnist160 epochs=1 img_size=32 | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| from itertools import repeat | ||||
| from multiprocessing.pool import Pool | ||||
| from pathlib import Path | ||||
| from typing import OrderedDict | ||||
|  | ||||
| import torchvision | ||||
| from tqdm import tqdm | ||||
| @ -205,7 +206,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): | ||||
|             sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"] | ||||
|         else: | ||||
|             sample = self.torch_transforms(im) | ||||
|         return sample, j | ||||
|         return OrderedDict(img=sample, cls=j) | ||||
|  | ||||
|  | ||||
| # TODO: support semantic segmentation | ||||
|  | ||||
| @ -1,12 +1,17 @@ | ||||
| """ | ||||
| Simple training loop; Boilerplate that could apply to any arbitrary neural network, | ||||
| """ | ||||
| # TODOs | ||||
| # 1. finish _set_model_attributes | ||||
| # 2. allow num_class update for both pretrained and csv_loaded models | ||||
| # 3. save | ||||
|  | ||||
| import os | ||||
| import time | ||||
| from collections import defaultdict | ||||
| from datetime import datetime | ||||
| from pathlib import Path | ||||
| from telnetlib import TLS | ||||
| from typing import Dict, Union | ||||
|  | ||||
| import torch | ||||
| @ -52,6 +57,8 @@ class BaseTrainer: | ||||
|  | ||||
|         # Model and Dataloaders. | ||||
|         self.trainset, self.testset = self.get_dataset(self.args.data) | ||||
|         if self.args.cfg is not None: | ||||
|             self.model = self.load_cfg(self.args.cfg) | ||||
|         if self.args.model is not None: | ||||
|             self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device) | ||||
|  | ||||
| @ -133,6 +140,20 @@ class BaseTrainer: | ||||
|             self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=rank) | ||||
|             self.validator = self.get_validator() | ||||
|             print("created testloader :", rank) | ||||
|             self.console.info(self.progress_string()) | ||||
|  | ||||
|     def _set_model_attributes(self): | ||||
|         # TODO: fix and use after self.data_dict is available | ||||
|         ''' | ||||
|         head = utils.torch_utils.de_parallel(self.model).model[-1] | ||||
|         self.args.box *= 3 / head.nl  # scale to layers | ||||
|         self.args.cls *= head.nc / 80 * 3 / head.nl  # scale to classes and layers | ||||
|         self.args.obj *= (self.args.img_size / 640) ** 2 * 3 / nl  # scale to image size and layers | ||||
|         model.nc = nc  # attach number of classes to model | ||||
|         model.hyp = hyp  # attach hyperparameters to model | ||||
|         model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights | ||||
|         model.names = names | ||||
|         ''' | ||||
|  | ||||
|     def _do_train(self, rank, world_size): | ||||
|         if world_size > 1: | ||||
| @ -153,13 +174,17 @@ class BaseTrainer: | ||||
|                 pbar = tqdm(enumerate(self.train_loader), | ||||
|                             total=len(self.train_loader), | ||||
|                             bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') | ||||
|             tloss = 0 | ||||
|             for i, (images, labels) in pbar: | ||||
|             tloss = None | ||||
|             for i, batch in pbar: | ||||
|                 # img, label (classification)/ img, targets, paths, _, masks(detection) | ||||
|                 # callback hook. on_batch_start | ||||
|                 # forward | ||||
|                 images, labels = self.preprocess_batch(images, labels) | ||||
|                 self.loss = self.criterion(self.model(images), labels) | ||||
|                 tloss = (tloss * i + self.loss.item()) / (i + 1) | ||||
|                 batch = self.preprocess_batch(batch) | ||||
|  | ||||
|                 # TODO: warmup, multiscale | ||||
|                 preds = self.model(batch["img"]) | ||||
|                 self.loss, self.loss_items = self.criterion(preds, batch) | ||||
|                 tloss = (tloss * i + self.loss_items) / (i + 1) if tloss is not None else self.loss_items | ||||
|  | ||||
|                 # backward | ||||
|                 self.model.zero_grad(set_to_none=True) | ||||
| @ -170,9 +195,13 @@ class BaseTrainer: | ||||
|                 self.trigger_callbacks('on_batch_end') | ||||
|  | ||||
|                 # log | ||||
|                 mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB) | ||||
|                 mem = (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB) | ||||
|                 loss_len = tloss.shape[0] if len(tloss.size()) else 1 | ||||
|                 losses = tloss if loss_len > 1 else torch.unsqueeze(tloss, 0) | ||||
|                 if rank in {-1, 0}: | ||||
|                     pbar.desc = f"{f'{epoch + 1}/{self.args.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36 | ||||
|                     pbar.set_description( | ||||
|                         (" {} " + "{:.3f}  " * (2 + loss_len)).format(f'{epoch + 1}/{self.args.epochs}', mem, *losses, | ||||
|                                                                       batch["img"].shape[-1])) | ||||
|  | ||||
|             if rank in [-1, 0]: | ||||
|                 # validation | ||||
| @ -240,6 +269,9 @@ class BaseTrainer: | ||||
|  | ||||
|         return model | ||||
|  | ||||
|     def load_cfg(self, cfg): | ||||
|         raise NotImplementedError("This task trainer doesn't support loading cfg files") | ||||
|  | ||||
|     def get_validator(self): | ||||
|         pass | ||||
|  | ||||
| @ -250,11 +282,11 @@ class BaseTrainer: | ||||
|         self.scaler.update() | ||||
|         self.optimizer.zero_grad() | ||||
|  | ||||
|     def preprocess_batch(self, images, labels): | ||||
|     def preprocess_batch(self, batch): | ||||
|         """ | ||||
|         Allows custom preprocessing model inputs and ground truths depending on task type | ||||
|         """ | ||||
|         return images.to(self.device, non_blocking=True), labels.to(self.device) | ||||
|         return batch | ||||
|  | ||||
|     def validate(self): | ||||
|         """ | ||||
| @ -270,14 +302,17 @@ class BaseTrainer: | ||||
|     def build_targets(self, preds, targets): | ||||
|         pass | ||||
|  | ||||
|     def criterion(self, preds, targets): | ||||
|     def criterion(self, preds, batch): | ||||
|         """ | ||||
|         Returns loss and individual loss items as Tensor | ||||
|         """ | ||||
|         pass | ||||
|  | ||||
|     def progress_string(self): | ||||
|         """ | ||||
|         Returns progress string depending on task type. | ||||
|         """ | ||||
|         pass | ||||
|         return '' | ||||
|  | ||||
|     def usage_help(self): | ||||
|         """ | ||||
|  | ||||
| @ -1,8 +1,10 @@ | ||||
| import logging | ||||
|  | ||||
| import torch | ||||
| from omegaconf import DictConfig, OmegaConf | ||||
| from tqdm import tqdm | ||||
|  | ||||
| from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG | ||||
| from ultralytics.yolo.utils.ops import Profile | ||||
| from ultralytics.yolo.utils.torch_utils import select_device | ||||
|  | ||||
| @ -12,12 +14,15 @@ class BaseValidator: | ||||
|     Base validator class. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, dataloader, device='', half=False, pbar=None, logger=None): | ||||
|     def __init__(self, dataloader, pbar=None, logger=None, args=None): | ||||
|         self.dataloader = dataloader | ||||
|         self.half = half | ||||
|         self.device = select_device(device, dataloader.batch_size) | ||||
|         self.pbar = pbar | ||||
|         self.logger = logger or logging.getLogger() | ||||
|         self.args = args or OmegaConf.load(DEFAULT_CONFIG) | ||||
|         self.device = select_device(self.args.device, dataloader.batch_size) | ||||
|         self.cuda = self.device.type != 'cpu' | ||||
|         self.batch_i = None | ||||
|         self.training = True | ||||
|  | ||||
|     def __call__(self, trainer=None, model=None): | ||||
|         """ | ||||
| @ -25,45 +30,48 @@ class BaseValidator: | ||||
|         if trainer is passed (trainer gets priority). | ||||
|         """ | ||||
|         training = trainer is not None | ||||
|         self.training = training | ||||
|         # trainer = trainer or self.trainer_class.get_trainer() | ||||
|         assert training or model is not None, "Either trainer or model is needed for validation" | ||||
|         if training: | ||||
|             model = trainer.model | ||||
|             self.half &= self.device.type != 'cpu' | ||||
|             model = model.half() if self.half else model | ||||
|             self.args.half &= self.device.type != 'cpu' | ||||
|             model = model.half() if self.args.half else model | ||||
|         else:  # TODO: handle this when detectMultiBackend is supported | ||||
|             # model = DetectMultiBacked(model) | ||||
|             pass | ||||
|             # TODO: implement init_model_attributes() | ||||
|  | ||||
|         model.eval() | ||||
|         dt = Profile(), Profile(), Profile(), Profile() | ||||
|         loss = 0 | ||||
|         n_batches = len(self.dataloader) | ||||
|         desc = self.set_desc() | ||||
|         desc = self.get_desc() | ||||
|         bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') | ||||
|         self.init_metrics() | ||||
|         self.init_metrics(model) | ||||
|         with torch.cuda.amp.autocast(enabled=self.device.type != 'cpu'): | ||||
|             for images, labels in bar: | ||||
|             for batch_i, batch in enumerate(bar): | ||||
|                 self.batch_i = batch_i | ||||
|                 # pre-process | ||||
|                 with dt[0]: | ||||
|                     images, labels = self.preprocess_batch(images, labels) | ||||
|                     batch = self.preprocess_batch(batch) | ||||
|  | ||||
|                 # inference | ||||
|                 with dt[1]: | ||||
|                     preds = model(images) | ||||
|                     preds = model(batch["img"]) | ||||
|                     # TODO: remember to add native augmentation support when implementing model, like: | ||||
|                     #  preds, train_out = model(im, augment=augment) | ||||
|  | ||||
|                 # loss | ||||
|                 with dt[2]: | ||||
|                     if training: | ||||
|                         loss += trainer.criterion(preds, labels) / images.shape[0] | ||||
|                         loss += trainer.criterion(preds, batch)[0] | ||||
|  | ||||
|                 # pre-process predictions | ||||
|                 with dt[3]: | ||||
|                     preds = self.preprocess_preds(preds) | ||||
|  | ||||
|                 self.update_metrics(preds, labels) | ||||
|                 self.update_metrics(preds, batch) | ||||
|  | ||||
|         stats = self.get_stats() | ||||
|         self.check_stats(stats) | ||||
| @ -81,8 +89,8 @@ class BaseValidator: | ||||
|  | ||||
|         return stats | ||||
|  | ||||
|     def preprocess_batch(self, images, labels): | ||||
|         return images.to(self.device, non_blocking=True), labels.to(self.device) | ||||
|     def preprocess_batch(self, batch): | ||||
|         return batch | ||||
|  | ||||
|     def preprocess_preds(self, preds): | ||||
|         return preds | ||||
| @ -90,7 +98,7 @@ class BaseValidator: | ||||
|     def init_metrics(self): | ||||
|         pass | ||||
|  | ||||
|     def update_metrics(self, preds, targets): | ||||
|     def update_metrics(self, preds, batch): | ||||
|         pass | ||||
|  | ||||
|     def get_stats(self): | ||||
| @ -102,5 +110,5 @@ class BaseValidator: | ||||
|     def print_results(self): | ||||
|         pass | ||||
|  | ||||
|     def set_desc(self): | ||||
|     def get_desc(self): | ||||
|         pass | ||||
|  | ||||
| @ -4,6 +4,7 @@ | ||||
|  | ||||
| # Train settings ------------------------------------------------------------------------------------------------------- | ||||
| model: null  # i.e. yolov5s.pt | ||||
| cfg: null  # i.e. yolov5s.yaml | ||||
| data: null  # i.e. coco128.yaml | ||||
| epochs: 300 | ||||
| batch_size: 16 | ||||
| @ -20,6 +21,23 @@ optimizer: 'SGD'  # choices=['SGD', 'Adam', 'AdamW', 'RMSProp'] | ||||
| verbose: False | ||||
| seed: 0 | ||||
| local_rank: -1 | ||||
| single_cls: False  # train multi-class data as single-class | ||||
| image_weights: False  # use weighted image selection for training | ||||
| shuffle: True | ||||
| rect: False  # support rectangular training | ||||
| overlap_mask: True  # Segmentation masks overlap | ||||
| mask_ratio: 4  # Segmentation mask downsample ratio | ||||
|  | ||||
| # Val/Test settings ---------------------------------------------------------------------------------------------------- | ||||
| save_json: False | ||||
| save_hybrid: False | ||||
| conf_thres: 0.001 | ||||
| iou_thres: 0.6 | ||||
| max_det: 300 | ||||
| half: True | ||||
| plots: False | ||||
| save_txt: False | ||||
| task: 'val' | ||||
|  | ||||
| # Hyperparameters ------------------------------------------------------------------------------------------------------ | ||||
| lr0: 0.001  # initial learning rate (SGD=1E-2, Adam=1E-3) | ||||
| @ -51,6 +69,7 @@ fliplr: 0.5  # image flip left-right (probability) | ||||
| mosaic: 1.0  # image mosaic (probability) | ||||
| mixup: 0.0  # image mixup (probability) | ||||
| copy_paste: 0.0  # segment copy-paste (probability) | ||||
| label_smoothing: 0.0 | ||||
|  | ||||
| # Hydra configs -------------------------------------------------------------------------------------------------------- | ||||
| hydra: | ||||
|  | ||||
| @ -2,11 +2,19 @@ | ||||
| """ | ||||
| Model validation metrics | ||||
| """ | ||||
| import math | ||||
| import warnings | ||||
| from pathlib import Path | ||||
|  | ||||
| import matplotlib.pyplot as plt | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from ultralytics.yolo.utils import TryExcept | ||||
|  | ||||
|  | ||||
| # boxes | ||||
| def box_area(box): | ||||
|     # box = xyxy(4,n) | ||||
|     return (box[2] - box[0]) * (box[3] - box[1]) | ||||
| @ -53,3 +61,484 @@ def box_iou(box1, box2, eps=1e-7): | ||||
|  | ||||
|     # IoU = inter / (area1 + area2 - inter) | ||||
|     return inter / (box_area(box1.T)[:, None] + box_area(box2.T) - inter + eps) | ||||
|  | ||||
|  | ||||
| def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7): | ||||
|     # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4) | ||||
|  | ||||
|     # Get the coordinates of bounding boxes | ||||
|     if xywh:  # transform from xywh to xyxy | ||||
|         (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, 1), box2.chunk(4, 1) | ||||
|         w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2 | ||||
|         b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_ | ||||
|         b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_ | ||||
|     else:  # x1, y1, x2, y2 = box1 | ||||
|         b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, 1) | ||||
|         b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, 1) | ||||
|         w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps | ||||
|         w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps | ||||
|  | ||||
|     # Intersection area | ||||
|     inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \ | ||||
|             (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0) | ||||
|  | ||||
|     # Union Area | ||||
|     union = w1 * h1 + w2 * h2 - inter + eps | ||||
|  | ||||
|     # IoU | ||||
|     iou = inter / union | ||||
|     if CIoU or DIoU or GIoU: | ||||
|         cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width | ||||
|         ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height | ||||
|         if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 | ||||
|             c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared | ||||
|             rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2 | ||||
|             if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 | ||||
|                 v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2) | ||||
|                 with torch.no_grad(): | ||||
|                     alpha = v / (v - iou + (1 + eps)) | ||||
|                 return iou - (rho2 / c2 + v * alpha)  # CIoU | ||||
|             return iou - rho2 / c2  # DIoU | ||||
|         c_area = cw * ch + eps  # convex area | ||||
|         return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf | ||||
|     return iou  # IoU | ||||
|  | ||||
|  | ||||
| def mask_iou(mask1, mask2, eps=1e-7): | ||||
|     """ | ||||
|     mask1: [N, n] m1 means number of predicted objects | ||||
|     mask2: [M, n] m2 means number of gt objects | ||||
|     Note: n means image_w x image_h | ||||
|     return: masks iou, [N, M] | ||||
|     """ | ||||
|     intersection = torch.matmul(mask1, mask2.t()).clamp(0) | ||||
|     union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection  # (area1 + area2) - intersection | ||||
|     return intersection / (union + eps) | ||||
|  | ||||
|  | ||||
| def masks_iou(mask1, mask2, eps=1e-7): | ||||
|     """ | ||||
|     mask1: [N, n] m1 means number of predicted objects | ||||
|     mask2: [N, n] m2 means number of gt objects | ||||
|     Note: n means image_w x image_h | ||||
|     return: masks iou, (N, ) | ||||
|     """ | ||||
|     intersection = (mask1 * mask2).sum(1).clamp(0)  # (N, ) | ||||
|     union = (mask1.sum(1) + mask2.sum(1))[None] - intersection  # (area1 + area2) - intersection | ||||
|     return intersection / (union + eps) | ||||
|  | ||||
|  | ||||
| def smooth_BCE(eps=0.1):  # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441 | ||||
|     # return positive, negative label smoothing BCE targets | ||||
|     return 1.0 - 0.5 * eps, 0.5 * eps | ||||
|  | ||||
|  | ||||
| # losses | ||||
| class FocalLoss(nn.Module): | ||||
|     # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5) | ||||
|     def __init__(self, loss_fcn, gamma=1.5, alpha=0.25): | ||||
|         super().__init__() | ||||
|         self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss() | ||||
|         self.gamma = gamma | ||||
|         self.alpha = alpha | ||||
|         self.reduction = loss_fcn.reduction | ||||
|         self.loss_fcn.reduction = 'none'  # required to apply FL to each element | ||||
|  | ||||
|     def forward(self, pred, true): | ||||
|         loss = self.loss_fcn(pred, true) | ||||
|         # p_t = torch.exp(-loss) | ||||
|         # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability | ||||
|  | ||||
|         # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py | ||||
|         pred_prob = torch.sigmoid(pred)  # prob from logits | ||||
|         p_t = true * pred_prob + (1 - true) * (1 - pred_prob) | ||||
|         alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha) | ||||
|         modulating_factor = (1.0 - p_t) ** self.gamma | ||||
|         loss *= alpha_factor * modulating_factor | ||||
|  | ||||
|         if self.reduction == 'mean': | ||||
|             return loss.mean() | ||||
|         elif self.reduction == 'sum': | ||||
|             return loss.sum() | ||||
|         else:  # 'none' | ||||
|             return loss | ||||
|  | ||||
|  | ||||
| class ConfusionMatrix: | ||||
|     # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix | ||||
|     def __init__(self, nc, conf=0.25, iou_thres=0.45): | ||||
|         self.matrix = np.zeros((nc + 1, nc + 1)) | ||||
|         self.nc = nc  # number of classes | ||||
|         self.conf = conf | ||||
|         self.iou_thres = iou_thres | ||||
|  | ||||
|     def process_batch(self, detections, labels): | ||||
|         """ | ||||
|         Return intersection-over-union (Jaccard index) of boxes. | ||||
|         Both sets of boxes are expected to be in (x1, y1, x2, y2) format. | ||||
|         Arguments: | ||||
|             detections (Array[N, 6]), x1, y1, x2, y2, conf, class | ||||
|             labels (Array[M, 5]), class, x1, y1, x2, y2 | ||||
|         Returns: | ||||
|             None, updates confusion matrix accordingly | ||||
|         """ | ||||
|         if detections is None: | ||||
|             gt_classes = labels.int() | ||||
|             for gc in gt_classes: | ||||
|                 self.matrix[self.nc, gc] += 1  # background FN | ||||
|             return | ||||
|  | ||||
|         detections = detections[detections[:, 4] > self.conf] | ||||
|         gt_classes = labels[:, 0].int() | ||||
|         detection_classes = detections[:, 5].int() | ||||
|         iou = box_iou(labels[:, 1:], detections[:, :4]) | ||||
|  | ||||
|         x = torch.where(iou > self.iou_thres) | ||||
|         if x[0].shape[0]: | ||||
|             matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() | ||||
|             if x[0].shape[0] > 1: | ||||
|                 matches = matches[matches[:, 2].argsort()[::-1]] | ||||
|                 matches = matches[np.unique(matches[:, 1], return_index=True)[1]] | ||||
|                 matches = matches[matches[:, 2].argsort()[::-1]] | ||||
|                 matches = matches[np.unique(matches[:, 0], return_index=True)[1]] | ||||
|         else: | ||||
|             matches = np.zeros((0, 3)) | ||||
|  | ||||
|         n = matches.shape[0] > 0 | ||||
|         m0, m1, _ = matches.transpose().astype(int) | ||||
|         for i, gc in enumerate(gt_classes): | ||||
|             j = m0 == i | ||||
|             if n and sum(j) == 1: | ||||
|                 self.matrix[detection_classes[m1[j]], gc] += 1  # correct | ||||
|             else: | ||||
|                 self.matrix[self.nc, gc] += 1  # true background | ||||
|  | ||||
|         if n: | ||||
|             for i, dc in enumerate(detection_classes): | ||||
|                 if not any(m1 == i): | ||||
|                     self.matrix[dc, self.nc] += 1  # predicted background | ||||
|  | ||||
|     def matrix(self): | ||||
|         return self.matrix | ||||
|  | ||||
|     def tp_fp(self): | ||||
|         tp = self.matrix.diagonal()  # true positives | ||||
|         fp = self.matrix.sum(1) - tp  # false positives | ||||
|         # fn = self.matrix.sum(0) - tp  # false negatives (missed detections) | ||||
|         return tp[:-1], fp[:-1]  # remove background class | ||||
|  | ||||
|     @TryExcept('WARNING ⚠️ ConfusionMatrix plot failure') | ||||
|     def plot(self, normalize=True, save_dir='', names=()): | ||||
|         import seaborn as sn | ||||
|  | ||||
|         array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1)  # normalize columns | ||||
|         array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00) | ||||
|  | ||||
|         fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True) | ||||
|         nc, nn = self.nc, len(names)  # number of classes, names | ||||
|         sn.set(font_scale=1.0 if nc < 50 else 0.8)  # for label size | ||||
|         labels = (0 < nn < 99) and (nn == nc)  # apply names to ticklabels | ||||
|         ticklabels = (names + ['background']) if labels else "auto" | ||||
|         with warnings.catch_warnings(): | ||||
|             warnings.simplefilter('ignore')  # suppress empty matrix RuntimeWarning: All-NaN slice encountered | ||||
|             sn.heatmap(array, | ||||
|                        ax=ax, | ||||
|                        annot=nc < 30, | ||||
|                        annot_kws={ | ||||
|                            "size": 8}, | ||||
|                        cmap='Blues', | ||||
|                        fmt='.2f', | ||||
|                        square=True, | ||||
|                        vmin=0.0, | ||||
|                        xticklabels=ticklabels, | ||||
|                        yticklabels=ticklabels).set_facecolor((1, 1, 1)) | ||||
|         ax.set_ylabel('True') | ||||
|         ax.set_ylabel('Predicted') | ||||
|         ax.set_title('Confusion Matrix') | ||||
|         fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250) | ||||
|         plt.close(fig) | ||||
|  | ||||
|     def print(self): | ||||
|         for i in range(self.nc + 1): | ||||
|             print(' '.join(map(str, self.matrix[i]))) | ||||
|  | ||||
|  | ||||
| def fitness_detection(x): | ||||
|     # Model fitness as a weighted combination of metrics | ||||
|     w = [0.0, 0.0, 0.1, 0.9]  # weights for [P, R, mAP@0.5, mAP@0.5:0.95] | ||||
|     return (x[:, :4] * w).sum(1) | ||||
|  | ||||
|  | ||||
| def fitness_segmentation(x): | ||||
|     # Model fitness as a weighted combination of metrics | ||||
|     w = [0.0, 0.0, 0.1, 0.9, 0.0, 0.0, 0.1, 0.9] | ||||
|     return (x[:, :8] * w).sum(1) | ||||
|  | ||||
|  | ||||
| def smooth(y, f=0.05): | ||||
|     # Box filter of fraction f | ||||
|     nf = round(len(y) * f * 2) // 2 + 1  # number of filter elements (must be odd) | ||||
|     p = np.ones(nf // 2)  # ones padding | ||||
|     yp = np.concatenate((p * y[0], y, p * y[-1]), 0)  # y padded | ||||
|     return np.convolve(yp, np.ones(nf) / nf, mode='valid')  # y-smoothed | ||||
|  | ||||
|  | ||||
| def compute_ap(recall, precision): | ||||
|     """ Compute the average precision, given the recall and precision curves | ||||
|     # Arguments | ||||
|         recall:    The recall curve (list) | ||||
|         precision: The precision curve (list) | ||||
|     # Returns | ||||
|         Average precision, precision curve, recall curve | ||||
|     """ | ||||
|  | ||||
|     # Append sentinel values to beginning and end | ||||
|     mrec = np.concatenate(([0.0], recall, [1.0])) | ||||
|     mpre = np.concatenate(([1.0], precision, [0.0])) | ||||
|  | ||||
|     # Compute the precision envelope | ||||
|     mpre = np.flip(np.maximum.accumulate(np.flip(mpre))) | ||||
|  | ||||
|     # Integrate area under curve | ||||
|     method = 'interp'  # methods: 'continuous', 'interp' | ||||
|     if method == 'interp': | ||||
|         x = np.linspace(0, 1, 101)  # 101-point interp (COCO) | ||||
|         ap = np.trapz(np.interp(x, mrec, mpre), x)  # integrate | ||||
|     else:  # 'continuous' | ||||
|         i = np.where(mrec[1:] != mrec[:-1])[0]  # points where x axis (recall) changes | ||||
|         ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve | ||||
|  | ||||
|     return ap, mpre, mrec | ||||
|  | ||||
|  | ||||
| def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16, prefix=""): | ||||
|     """ Compute the average precision, given the recall and precision curves. | ||||
|     Source: https://github.com/rafaelpadilla/Object-Detection-Metrics. | ||||
|     # Arguments | ||||
|         tp:  True positives (nparray, nx1 or nx10). | ||||
|         conf:  Objectness value from 0-1 (nparray). | ||||
|         pred_cls:  Predicted object classes (nparray). | ||||
|         target_cls:  True object classes (nparray). | ||||
|         plot:  Plot precision-recall curve at mAP@0.5 | ||||
|         save_dir:  Plot save directory | ||||
|     # Returns | ||||
|         The average precision as computed in py-faster-rcnn. | ||||
|     """ | ||||
|  | ||||
|     # Sort by objectness | ||||
|     i = np.argsort(-conf) | ||||
|     tp, conf, pred_cls = tp[i], conf[i], pred_cls[i] | ||||
|  | ||||
|     # Find unique classes | ||||
|     unique_classes, nt = np.unique(target_cls, return_counts=True) | ||||
|     nc = unique_classes.shape[0]  # number of classes, number of detections | ||||
|  | ||||
|     # Create Precision-Recall curve and compute AP for each class | ||||
|     px, py = np.linspace(0, 1, 1000), []  # for plotting | ||||
|     ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000)) | ||||
|     for ci, c in enumerate(unique_classes): | ||||
|         i = pred_cls == c | ||||
|         n_l = nt[ci]  # number of labels | ||||
|         n_p = i.sum()  # number of predictions | ||||
|         if n_p == 0 or n_l == 0: | ||||
|             continue | ||||
|  | ||||
|         # Accumulate FPs and TPs | ||||
|         fpc = (1 - tp[i]).cumsum(0) | ||||
|         tpc = tp[i].cumsum(0) | ||||
|  | ||||
|         # Recall | ||||
|         recall = tpc / (n_l + eps)  # recall curve | ||||
|         r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases | ||||
|  | ||||
|         # Precision | ||||
|         precision = tpc / (tpc + fpc)  # precision curve | ||||
|         p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score | ||||
|  | ||||
|         # AP from recall-precision curve | ||||
|         for j in range(tp.shape[1]): | ||||
|             ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j]) | ||||
|             if plot and j == 0: | ||||
|                 py.append(np.interp(px, mrec, mpre))  # precision at mAP@0.5 | ||||
|  | ||||
|     # Compute F1 (harmonic mean of precision and recall) | ||||
|     f1 = 2 * p * r / (p + r + eps) | ||||
|     names = [v for k, v in names.items() if k in unique_classes]  # list: only classes that have data | ||||
|     names = dict(enumerate(names))  # to dict | ||||
|     # TODO: plot | ||||
|     ''' | ||||
|     if plot: | ||||
|         plot_pr_curve(px, py, ap, Path(save_dir) / f'{prefix}PR_curve.png', names) | ||||
|         plot_mc_curve(px, f1, Path(save_dir) / f'{prefix}F1_curve.png', names, ylabel='F1') | ||||
|         plot_mc_curve(px, p, Path(save_dir) / f'{prefix}P_curve.png', names, ylabel='Precision') | ||||
|         plot_mc_curve(px, r, Path(save_dir) / f'{prefix}R_curve.png', names, ylabel='Recall') | ||||
|     ''' | ||||
|  | ||||
|     i = smooth(f1.mean(0), 0.1).argmax()  # max F1 index | ||||
|     p, r, f1 = p[:, i], r[:, i], f1[:, i] | ||||
|     tp = (r * nt).round()  # true positives | ||||
|     fp = (tp / (p + eps) - tp).round()  # false positives | ||||
|     return tp, fp, p, r, f1, ap, unique_classes.astype(int) | ||||
|  | ||||
|  | ||||
| def ap_per_class_box_and_mask( | ||||
|         tp_m, | ||||
|         tp_b, | ||||
|         conf, | ||||
|         pred_cls, | ||||
|         target_cls, | ||||
|         plot=False, | ||||
|         save_dir=".", | ||||
|         names=(), | ||||
| ): | ||||
|     """ | ||||
|     Args: | ||||
|         tp_b: tp of boxes. | ||||
|         tp_m: tp of masks. | ||||
|         other arguments see `func: ap_per_class`. | ||||
|     """ | ||||
|     results_boxes = ap_per_class(tp_b, | ||||
|                                  conf, | ||||
|                                  pred_cls, | ||||
|                                  target_cls, | ||||
|                                  plot=plot, | ||||
|                                  save_dir=save_dir, | ||||
|                                  names=names, | ||||
|                                  prefix="Box")[2:] | ||||
|     results_masks = ap_per_class(tp_m, | ||||
|                                  conf, | ||||
|                                  pred_cls, | ||||
|                                  target_cls, | ||||
|                                  plot=plot, | ||||
|                                  save_dir=save_dir, | ||||
|                                  names=names, | ||||
|                                  prefix="Mask")[2:] | ||||
|  | ||||
|     results = { | ||||
|         "boxes": { | ||||
|             "p": results_boxes[0], | ||||
|             "r": results_boxes[1], | ||||
|             "ap": results_boxes[3], | ||||
|             "f1": results_boxes[2], | ||||
|             "ap_class": results_boxes[4]}, | ||||
|         "masks": { | ||||
|             "p": results_masks[0], | ||||
|             "r": results_masks[1], | ||||
|             "ap": results_masks[3], | ||||
|             "f1": results_masks[2], | ||||
|             "ap_class": results_masks[4]}} | ||||
|     return results | ||||
|  | ||||
|  | ||||
| class Metric: | ||||
|  | ||||
|     def __init__(self) -> None: | ||||
|         self.p = []  # (nc, ) | ||||
|         self.r = []  # (nc, ) | ||||
|         self.f1 = []  # (nc, ) | ||||
|         self.all_ap = []  # (nc, 10) | ||||
|         self.ap_class_index = []  # (nc, ) | ||||
|  | ||||
|     @property | ||||
|     def ap50(self): | ||||
|         """AP@0.5 of all classes. | ||||
|         Return: | ||||
|             (nc, ) or []. | ||||
|         """ | ||||
|         return self.all_ap[:, 0] if len(self.all_ap) else [] | ||||
|  | ||||
|     @property | ||||
|     def ap(self): | ||||
|         """AP@0.5:0.95 | ||||
|         Return: | ||||
|             (nc, ) or []. | ||||
|         """ | ||||
|         return self.all_ap.mean(1) if len(self.all_ap) else [] | ||||
|  | ||||
|     @property | ||||
|     def mp(self): | ||||
|         """mean precision of all classes. | ||||
|         Return: | ||||
|             float. | ||||
|         """ | ||||
|         return self.p.mean() if len(self.p) else 0.0 | ||||
|  | ||||
|     @property | ||||
|     def mr(self): | ||||
|         """mean recall of all classes. | ||||
|         Return: | ||||
|             float. | ||||
|         """ | ||||
|         return self.r.mean() if len(self.r) else 0.0 | ||||
|  | ||||
|     @property | ||||
|     def map50(self): | ||||
|         """Mean AP@0.5 of all classes. | ||||
|         Return: | ||||
|             float. | ||||
|         """ | ||||
|         return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0 | ||||
|  | ||||
|     @property | ||||
|     def map(self): | ||||
|         """Mean AP@0.5:0.95 of all classes. | ||||
|         Return: | ||||
|             float. | ||||
|         """ | ||||
|         return self.all_ap.mean() if len(self.all_ap) else 0.0 | ||||
|  | ||||
|     def mean_results(self): | ||||
|         """Mean of results, return mp, mr, map50, map""" | ||||
|         return (self.mp, self.mr, self.map50, self.map) | ||||
|  | ||||
|     def class_result(self, i): | ||||
|         """class-aware result, return p[i], r[i], ap50[i], ap[i]""" | ||||
|         return (self.p[i], self.r[i], self.ap50[i], self.ap[i]) | ||||
|  | ||||
|     def get_maps(self, nc): | ||||
|         maps = np.zeros(nc) + self.map | ||||
|         for i, c in enumerate(self.ap_class_index): | ||||
|             maps[c] = self.ap[i] | ||||
|         return maps | ||||
|  | ||||
|     def update(self, results): | ||||
|         """ | ||||
|         Args: | ||||
|             results: tuple(p, r, ap, f1, ap_class) | ||||
|         """ | ||||
|         p, r, all_ap, f1, ap_class_index = results | ||||
|         self.p = p | ||||
|         self.r = r | ||||
|         self.all_ap = all_ap | ||||
|         self.f1 = f1 | ||||
|         self.ap_class_index = ap_class_index | ||||
|  | ||||
|  | ||||
| class Metrics: | ||||
|     """Metric for boxes and masks.""" | ||||
|  | ||||
|     def __init__(self) -> None: | ||||
|         self.metric_box = Metric() | ||||
|         self.metric_mask = Metric() | ||||
|  | ||||
|     def update(self, results): | ||||
|         """ | ||||
|         Args: | ||||
|             results: Dict{'boxes': Dict{}, 'masks': Dict{}} | ||||
|         """ | ||||
|         self.metric_box.update(list(results["boxes"].values())) | ||||
|         self.metric_mask.update(list(results["masks"].values())) | ||||
|  | ||||
|     def mean_results(self): | ||||
|         return self.metric_box.mean_results() + self.metric_mask.mean_results() | ||||
|  | ||||
|     def class_result(self, i): | ||||
|         return self.metric_box.class_result(i) + self.metric_mask.class_result(i) | ||||
|  | ||||
|     def get_maps(self, nc): | ||||
|         return self.metric_box.get_maps(nc) + self.metric_mask.get_maps(nc) | ||||
|  | ||||
|     @property | ||||
|     def ap_class_index(self): | ||||
|         # boxes and masks have the same ap_class_index | ||||
|         return self.metric_box.ap_class_index | ||||
|  | ||||
| @ -5,6 +5,7 @@ import time | ||||
| import cv2 | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| import torchvision | ||||
|  | ||||
| from ultralytics.yolo.utils import LOGGER | ||||
| @ -32,14 +33,23 @@ class Profile(contextlib.ContextDecorator): | ||||
|         return time.time() | ||||
|  | ||||
|  | ||||
| def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper) | ||||
|     # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/ | ||||
|     # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n') | ||||
|     # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n') | ||||
|     # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco | ||||
|     # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet | ||||
|     return [ | ||||
|         1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, | ||||
|         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, | ||||
|         64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90] | ||||
|  | ||||
|  | ||||
| def segment2box(segment, width=640, height=640): | ||||
|     # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy) | ||||
|     x, y = segment.T  # segment xy | ||||
|     inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height) | ||||
|     x, y, = ( | ||||
|         x[inside], | ||||
|         y[inside], | ||||
|     ) | ||||
|     x, y, = x[inside], y[inside] | ||||
|     return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros(4)  # xyxy | ||||
|  | ||||
|  | ||||
| @ -304,3 +314,63 @@ def resample_segments(segments, n=1000): | ||||
|         xp = np.arange(len(s)) | ||||
|         segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T  # segment xy | ||||
|     return segments | ||||
|  | ||||
|  | ||||
| def crop_mask(masks, boxes): | ||||
|     """ | ||||
|     "Crop" predicted masks by zeroing out everything not in the predicted bbox. | ||||
|     Vectorized by Chong (thanks Chong). | ||||
|     Args: | ||||
|         - masks should be a size [h, w, n] tensor of masks | ||||
|         - boxes should be a size [n, 4] tensor of bbox coords in relative point form | ||||
|     """ | ||||
|  | ||||
|     n, h, w = masks.shape | ||||
|     x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(1,1,n) | ||||
|     r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # rows shape(1,w,1) | ||||
|     c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # cols shape(h,1,1) | ||||
|  | ||||
|     return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2)) | ||||
|  | ||||
|  | ||||
| def process_mask_upsample(protos, masks_in, bboxes, shape): | ||||
|     """ | ||||
|     Crop after upsample. | ||||
|     proto_out: [mask_dim, mask_h, mask_w] | ||||
|     out_masks: [n, mask_dim], n is number of masks after nms | ||||
|     bboxes: [n, 4], n is number of masks after nms | ||||
|     shape:input_image_size, (h, w) | ||||
|     return: h, w, n | ||||
|     """ | ||||
|  | ||||
|     c, mh, mw = protos.shape  # CHW | ||||
|     masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) | ||||
|     masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0]  # CHW | ||||
|     masks = crop_mask(masks, bboxes)  # CHW | ||||
|     return masks.gt_(0.5) | ||||
|  | ||||
|  | ||||
| def process_mask(protos, masks_in, bboxes, shape, upsample=False): | ||||
|     """ | ||||
|     Crop before upsample. | ||||
|     proto_out: [mask_dim, mask_h, mask_w] | ||||
|     out_masks: [n, mask_dim], n is number of masks after nms | ||||
|     bboxes: [n, 4], n is number of masks after nms | ||||
|     shape:input_image_size, (h, w) | ||||
|     return: h, w, n | ||||
|     """ | ||||
|  | ||||
|     c, mh, mw = protos.shape  # CHW | ||||
|     ih, iw = shape | ||||
|     masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)  # CHW | ||||
|  | ||||
|     downsampled_bboxes = bboxes.clone() | ||||
|     downsampled_bboxes[:, 0] *= mw / iw | ||||
|     downsampled_bboxes[:, 2] *= mw / iw | ||||
|     downsampled_bboxes[:, 3] *= mh / ih | ||||
|     downsampled_bboxes[:, 1] *= mh / ih | ||||
|  | ||||
|     masks = crop_mask(masks, downsampled_bboxes)  # CHW | ||||
|     if upsample: | ||||
|         masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0]  # CHW | ||||
|     return masks.gt_(0.5) | ||||
|  | ||||
| @ -179,3 +179,13 @@ def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')): | ||||
| def intersect_state_dicts(da, db, exclude=()): | ||||
|     # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values | ||||
|     return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape} | ||||
|  | ||||
|  | ||||
| def is_parallel(model): | ||||
|     # Returns True if model is of type DP or DDP | ||||
|     return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) | ||||
|  | ||||
|  | ||||
| def de_parallel(model): | ||||
|     # De-parallelize a model: returns single-GPU model if model is of type DP or DDP | ||||
|     return model.module if is_parallel(model) else model | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| from pathlib import Path | ||||
|  | ||||
| from ultralytics.yolo.v8 import classify | ||||
| from ultralytics.yolo.v8 import classify, segment | ||||
|  | ||||
| ROOT = Path(__file__).parents[0]  # yolov8 ROOT | ||||
|  | ||||
| __all__ = ["classify"] | ||||
| __all__ = ["classify", "segment"] | ||||
|  | ||||
| @ -38,13 +38,22 @@ class ClassificationTrainer(BaseTrainer): | ||||
|         return train_set, test_set | ||||
|  | ||||
|     def get_dataloader(self, dataset_path, batch_size=None, rank=0): | ||||
|         return build_classification_dataloader(path=dataset_path, batch_size=self.args.batch_size, rank=rank) | ||||
|         return build_classification_dataloader(path=dataset_path, | ||||
|                                                imgsz=self.args.img_size, | ||||
|                                                batch_size=self.args.batch_size, | ||||
|                                                rank=rank) | ||||
|  | ||||
|     def preprocess_batch(self, batch): | ||||
|         batch["img"] = batch["img"].to(self.device) | ||||
|         batch["cls"] = batch["cls"].to(self.device) | ||||
|         return batch | ||||
|  | ||||
|     def get_validator(self): | ||||
|         return v8.classify.ClassificationValidator(self.test_loader, self.device, logger=self.console) | ||||
|  | ||||
|     def criterion(self, preds, targets): | ||||
|         return torch.nn.functional.cross_entropy(preds, targets) | ||||
|     def criterion(self, preds, batch): | ||||
|         loss = torch.nn.functional.cross_entropy(preds, batch["cls"]) | ||||
|         return loss, loss | ||||
|  | ||||
|  | ||||
| @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name) | ||||
|  | ||||
| @ -5,10 +5,16 @@ from ultralytics.yolo.engine.validator import BaseValidator | ||||
|  | ||||
| class ClassificationValidator(BaseValidator): | ||||
|  | ||||
|     def init_metrics(self): | ||||
|     def init_metrics(self, model): | ||||
|         self.correct = torch.tensor([]) | ||||
|  | ||||
|     def update_metrics(self, preds, targets): | ||||
|     def preprocess_batch(self, batch): | ||||
|         batch["img"] = batch["img"].to(self.device) | ||||
|         batch["cls"] = batch["cls"].to(self.device) | ||||
|         return batch | ||||
|  | ||||
|     def update_metrics(self, preds, batch): | ||||
|         targets = batch["cls"] | ||||
|         correct_in_batch = (targets[:, None] == preds).float() | ||||
|         self.correct = torch.cat((self.correct, correct_in_batch)) | ||||
|  | ||||
|  | ||||
							
								
								
									
										48
									
								
								ultralytics/yolo/v8/models/yolov5n-seg.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								ultralytics/yolo/v8/models/yolov5n-seg.yaml
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,48 @@ | ||||
| # YOLOv5 🚀 by Ultralytics, GPL-3.0 license | ||||
|  | ||||
| # Parameters | ||||
| nc: 80  # number of classes | ||||
| depth_multiple: 0.33  # model depth multiple | ||||
| width_multiple: 0.25  # layer channel multiple | ||||
| anchors: | ||||
|   - [10,13, 16,30, 33,23]  # P3/8 | ||||
|   - [30,61, 62,45, 59,119]  # P4/16 | ||||
|   - [116,90, 156,198, 373,326]  # P5/32 | ||||
|  | ||||
| # YOLOv5 v6.0 backbone | ||||
| backbone: | ||||
|   # [from, number, module, args] | ||||
|   [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2 | ||||
|    [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4 | ||||
|    [-1, 3, C3, [128]], | ||||
|    [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8 | ||||
|    [-1, 6, C3, [256]], | ||||
|    [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16 | ||||
|    [-1, 9, C3, [512]], | ||||
|    [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32 | ||||
|    [-1, 3, C3, [1024]], | ||||
|    [-1, 1, SPPF, [1024, 5]],  # 9 | ||||
|   ] | ||||
|  | ||||
| # YOLOv5 v6.0 head | ||||
| head: | ||||
|   [[-1, 1, Conv, [512, 1, 1]], | ||||
|    [-1, 1, nn.Upsample, [None, 2, 'nearest']], | ||||
|    [[-1, 6], 1, Concat, [1]],  # cat backbone P4 | ||||
|    [-1, 3, C3, [512, False]],  # 13 | ||||
|  | ||||
|    [-1, 1, Conv, [256, 1, 1]], | ||||
|    [-1, 1, nn.Upsample, [None, 2, 'nearest']], | ||||
|    [[-1, 4], 1, Concat, [1]],  # cat backbone P3 | ||||
|    [-1, 3, C3, [256, False]],  # 17 (P3/8-small) | ||||
|  | ||||
|    [-1, 1, Conv, [256, 3, 2]], | ||||
|    [[-1, 14], 1, Concat, [1]],  # cat head P4 | ||||
|    [-1, 3, C3, [512, False]],  # 20 (P4/16-medium) | ||||
|  | ||||
|    [-1, 1, Conv, [512, 3, 2]], | ||||
|    [[-1, 10], 1, Concat, [1]],  # cat head P5 | ||||
|    [-1, 3, C3, [1024, False]],  # 23 (P5/32-large) | ||||
|  | ||||
|    [[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]],  # Detect(P3, P4, P5) | ||||
|   ] | ||||
							
								
								
									
										48
									
								
								ultralytics/yolo/v8/models/yolov5n.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								ultralytics/yolo/v8/models/yolov5n.yaml
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,48 @@ | ||||
| # YOLOv5 🚀 by Ultralytics, GPL-3.0 license | ||||
|  | ||||
| # Parameters | ||||
| nc: 80  # number of classes | ||||
| depth_multiple: 0.33  # model depth multiple | ||||
| width_multiple: 0.25  # layer channel multiple | ||||
| anchors: | ||||
|   - [10,13, 16,30, 33,23]  # P3/8 | ||||
|   - [30,61, 62,45, 59,119]  # P4/16 | ||||
|   - [116,90, 156,198, 373,326]  # P5/32 | ||||
|  | ||||
| # YOLOv5 v6.0 backbone | ||||
| backbone: | ||||
|   # [from, number, module, args] | ||||
|   [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2 | ||||
|    [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4 | ||||
|    [-1, 3, C3, [128]], | ||||
|    [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8 | ||||
|    [-1, 6, C3, [256]], | ||||
|    [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16 | ||||
|    [-1, 9, C3, [512]], | ||||
|    [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32 | ||||
|    [-1, 3, C3, [1024]], | ||||
|    [-1, 1, SPPF, [1024, 5]],  # 9 | ||||
|   ] | ||||
|  | ||||
| # YOLOv5 v6.0 head | ||||
| head: | ||||
|   [[-1, 1, Conv, [512, 1, 1]], | ||||
|    [-1, 1, nn.Upsample, [None, 2, 'nearest']], | ||||
|    [[-1, 6], 1, Concat, [1]],  # cat backbone P4 | ||||
|    [-1, 3, C3, [512, False]],  # 13 | ||||
|  | ||||
|    [-1, 1, Conv, [256, 1, 1]], | ||||
|    [-1, 1, nn.Upsample, [None, 2, 'nearest']], | ||||
|    [[-1, 4], 1, Concat, [1]],  # cat backbone P3 | ||||
|    [-1, 3, C3, [256, False]],  # 17 (P3/8-small) | ||||
|  | ||||
|    [-1, 1, Conv, [256, 3, 2]], | ||||
|    [[-1, 14], 1, Concat, [1]],  # cat head P4 | ||||
|    [-1, 3, C3, [512, False]],  # 20 (P4/16-medium) | ||||
|  | ||||
|    [-1, 1, Conv, [512, 3, 2]], | ||||
|    [[-1, 10], 1, Concat, [1]],  # cat head P5 | ||||
|    [-1, 3, C3, [1024, False]],  # 23 (P5/32-large) | ||||
|  | ||||
|    [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5) | ||||
|   ] | ||||
							
								
								
									
										2
									
								
								ultralytics/yolo/v8/segment/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								ultralytics/yolo/v8/segment/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,2 @@ | ||||
| from ultralytics.yolo.v8.segment.train import SegmentationTrainer | ||||
| from ultralytics.yolo.v8.segment.val import SegmentationValidator | ||||
							
								
								
									
										269
									
								
								ultralytics/yolo/v8/segment/train.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										269
									
								
								ultralytics/yolo/v8/segment/train.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,269 @@ | ||||
| import subprocess | ||||
| import time | ||||
| from pathlib import Path | ||||
|  | ||||
| import hydra | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from ultralytics.yolo import v8 | ||||
| from ultralytics.yolo.data import build_dataloader | ||||
| from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer | ||||
| from ultralytics.yolo.utils.downloads import download | ||||
| from ultralytics.yolo.utils.files import WorkingDirectory | ||||
| from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE | ||||
| from ultralytics.yolo.utils.modeling.tasks import SegmentationModel | ||||
| from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy | ||||
| from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, de_parallel, torch_distributed_zero_first | ||||
|  | ||||
|  | ||||
| # BaseTrainer python usage | ||||
| class SegmentationTrainer(BaseTrainer): | ||||
|  | ||||
|     def get_dataset(self, dataset): | ||||
|         # temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module | ||||
|         data = Path("datasets") / dataset | ||||
|         with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(Path.cwd()): | ||||
|             data_dir = data if data.is_dir() else (Path.cwd() / data) | ||||
|             if not data_dir.is_dir(): | ||||
|                 self.console.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...') | ||||
|                 t = time.time() | ||||
|                 if str(data) == 'imagenet': | ||||
|                     subprocess.run(f"bash {v8.ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True) | ||||
|                 else: | ||||
|                     url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip' | ||||
|                     download(url, dir=data_dir.parent) | ||||
|                 # TODO: add colorstr | ||||
|                 s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {'bold', data_dir}\n" | ||||
|                 self.console.info(s) | ||||
|         train_set = data_dir.parent / "coco128-seg" | ||||
|         test_set = train_set | ||||
|         return train_set, test_set | ||||
|  | ||||
|     def get_dataloader(self, dataset_path, batch_size, rank=0): | ||||
|         # TODO: manage splits differently | ||||
|         # calculate stride - check if model is initialized | ||||
|         gs = max(int(self.model.stride.max() if self.model else 0), 32) | ||||
|         loader = build_dataloader( | ||||
|             img_path=dataset_path, | ||||
|             img_size=self.args.img_size, | ||||
|             batch_size=batch_size, | ||||
|             single_cls=self.args.single_cls, | ||||
|             cache=self.args.cache, | ||||
|             image_weights=self.args.image_weights, | ||||
|             stride=gs, | ||||
|             rect=self.args.rect, | ||||
|             rank=rank, | ||||
|             workers=self.args.workers, | ||||
|             shuffle=self.args.shuffle, | ||||
|             use_segments=True, | ||||
|         )[0] | ||||
|         return loader | ||||
|  | ||||
|     def preprocess_batch(self, batch): | ||||
|         batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255 | ||||
|         return batch | ||||
|  | ||||
|     def load_cfg(self, cfg): | ||||
|         return SegmentationModel(cfg, nc=80) | ||||
|  | ||||
|     def get_validator(self): | ||||
|         return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console) | ||||
|  | ||||
|     def criterion(self, preds, batch): | ||||
|         head = de_parallel(self.model).model[-1] | ||||
|         sort_obj_iou = False | ||||
|         autobalance = False | ||||
|  | ||||
|         # init losses | ||||
|         BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.args.cls_pw], device=self.device)) | ||||
|         BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.args.obj_pw], device=self.device)) | ||||
|  | ||||
|         # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3 | ||||
|         cp, cn = smooth_BCE(eps=self.args.label_smoothing)  # positive, negative BCE targets | ||||
|  | ||||
|         # Focal loss | ||||
|         g = self.args.fl_gamma | ||||
|         if self.args.fl_gamma > 0: | ||||
|             BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) | ||||
|  | ||||
|         balance = {3: [4.0, 1.0, 0.4]}.get(head.nl, [4.0, 1.0, 0.25, 0.06, 0.02])  # P3-P7 | ||||
|         ssi = list(head.stride).index(16) if autobalance else 0  # stride 16 index | ||||
|         BCEcls, BCEobj, gr, autobalance = BCEcls, BCEobj, 1.0, autobalance | ||||
|  | ||||
|         def single_mask_loss(gt_mask, pred, proto, xyxy, area): | ||||
|             # Mask loss for one image | ||||
|             pred_mask = (pred @ proto.view(head.nm, -1)).view(-1, *proto.shape[1:])  # (n,32) @ (32,80,80) -> (n,80,80) | ||||
|             loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none") | ||||
|             return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean() | ||||
|  | ||||
|         def build_targets(p, targets): | ||||
|             # Build targets for compute_loss(), input targets(image,class,x,y,w,h) | ||||
|             nonlocal head | ||||
|             na, nt = head.na, targets.shape[0]  # number of anchors, targets | ||||
|             tcls, tbox, indices, anch, tidxs, xywhn = [], [], [], [], [], [] | ||||
|             gain = torch.ones(8, device=self.device)  # normalized to gridspace gain | ||||
|             ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, | ||||
|                                                                                  nt)  # same as .repeat_interleave(nt) | ||||
|             if self.args.overlap_mask: | ||||
|                 batch = p[0].shape[0] | ||||
|                 ti = [] | ||||
|                 for i in range(batch): | ||||
|                     num = (targets[:, 0] == i).sum()  # find number of targets of each image | ||||
|                     ti.append(torch.arange(num, device=self.device).float().view(1, num).repeat(na, 1) + 1)  # (na, num) | ||||
|                 ti = torch.cat(ti, 1)  # (na, nt) | ||||
|             else: | ||||
|                 ti = torch.arange(nt, device=self.device).float().view(1, nt).repeat(na, 1) | ||||
|             targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None], ti[..., None]), 2)  # append anchor indices | ||||
|  | ||||
|             g = 0.5  # bias | ||||
|             off = torch.tensor( | ||||
|                 [ | ||||
|                     [0, 0], | ||||
|                     [1, 0], | ||||
|                     [0, 1], | ||||
|                     [-1, 0], | ||||
|                     [0, -1],  # j,k,l,m | ||||
|                     # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm | ||||
|                 ], | ||||
|                 device=self.device).float() * g  # offsets | ||||
|  | ||||
|             for i in range(head.nl): | ||||
|                 anchors, shape = head.anchors[i], p[i].shape | ||||
|                 gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]]  # xyxy gain | ||||
|  | ||||
|                 # Match targets to anchors | ||||
|                 t = targets * gain  # shape(3,n,7) | ||||
|                 if nt: | ||||
|                     # Matches | ||||
|                     r = t[..., 4:6] / anchors[:, None]  # wh ratio | ||||
|                     j = torch.max(r, 1 / r).max(2)[0] < self.args.anchor_t  # compare | ||||
|                     # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2)) | ||||
|                     t = t[j]  # filter | ||||
|  | ||||
|                     # Offsets | ||||
|                     gxy = t[:, 2:4]  # grid xy | ||||
|                     gxi = gain[[2, 3]] - gxy  # inverse | ||||
|                     j, k = ((gxy % 1 < g) & (gxy > 1)).T | ||||
|                     l, m = ((gxi % 1 < g) & (gxi > 1)).T | ||||
|                     j = torch.stack((torch.ones_like(j), j, k, l, m)) | ||||
|                     t = t.repeat((5, 1, 1))[j] | ||||
|                     offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j] | ||||
|                 else: | ||||
|                     t = targets[0] | ||||
|                     offsets = 0 | ||||
|  | ||||
|                 # Define | ||||
|                 bc, gxy, gwh, at = t.chunk(4, 1)  # (image, class), grid xy, grid wh, anchors | ||||
|                 (a, tidx), (b, c) = at.long().T, bc.long().T  # anchors, image, class | ||||
|                 gij = (gxy - offsets).long() | ||||
|                 gi, gj = gij.T  # grid indices | ||||
|  | ||||
|                 # Append | ||||
|                 indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))  # image, anchor, grid | ||||
|                 tbox.append(torch.cat((gxy - gij, gwh), 1))  # box | ||||
|                 anch.append(anchors[a])  # anchors | ||||
|                 tcls.append(c)  # class | ||||
|                 tidxs.append(tidx) | ||||
|                 xywhn.append(torch.cat((gxy, gwh), 1) / gain[2:6])  # xywh normalized | ||||
|  | ||||
|             return tcls, tbox, indices, anch, tidxs, xywhn | ||||
|  | ||||
|         if self.model.training: | ||||
|             p, proto, = preds | ||||
|         else: | ||||
|             p, proto, train_out = preds | ||||
|             p = train_out | ||||
|         targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) | ||||
|         masks = batch["masks"] | ||||
|         targets, masks = targets.to(self.device), masks.to(self.device).float() | ||||
|  | ||||
|         bs, nm, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width | ||||
|         lcls = torch.zeros(1, device=self.device) | ||||
|         lbox = torch.zeros(1, device=self.device) | ||||
|         lobj = torch.zeros(1, device=self.device) | ||||
|         lseg = torch.zeros(1, device=self.device) | ||||
|         tcls, tbox, indices, anchors, tidxs, xywhn = build_targets(p, targets) | ||||
|  | ||||
|         # Losses | ||||
|         for i, pi in enumerate(p):  # layer index, layer predictions | ||||
|             b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx | ||||
|             tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device)  # target obj | ||||
|  | ||||
|             n = b.shape[0]  # number of targets | ||||
|             if n: | ||||
|                 pxy, pwh, _, pcls, pmask = pi[b, a, gj, gi].split((2, 2, 1, head.nc, nm), 1)  # subset of predictions | ||||
|  | ||||
|                 # Box regression | ||||
|                 pxy = pxy.sigmoid() * 2 - 0.5 | ||||
|                 pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i] | ||||
|                 pbox = torch.cat((pxy, pwh), 1)  # predicted box | ||||
|                 iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze()  # iou(prediction, target) | ||||
|                 lbox += (1.0 - iou).mean()  # iou loss | ||||
|  | ||||
|                 # Objectness | ||||
|                 iou = iou.detach().clamp(0).type(tobj.dtype) | ||||
|                 if sort_obj_iou: | ||||
|                     j = iou.argsort() | ||||
|                     b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j] | ||||
|                 if gr < 1: | ||||
|                     iou = (1.0 - gr) + gr * iou | ||||
|                 tobj[b, a, gj, gi] = iou  # iou ratio | ||||
|  | ||||
|                 # Classification | ||||
|                 if head.nc > 1:  # cls loss (only if multiple classes) | ||||
|                     t = torch.full_like(pcls, cn, device=self.device)  # targets | ||||
|                     t[range(n), tcls[i]] = cp | ||||
|                     lcls += BCEcls(pcls, t)  # BCE | ||||
|  | ||||
|                 # Mask regression | ||||
|                 if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample | ||||
|                     masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0] | ||||
|                 marea = xywhn[i][:, 2:].prod(1)  # mask width, height normalized | ||||
|                 mxyxy = xywh2xyxy(xywhn[i] * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)) | ||||
|                 for bi in b.unique(): | ||||
|                     j = b == bi  # matching index | ||||
|                     if True: | ||||
|                         mask_gti = torch.where(masks[bi][None] == tidxs[i][j].view(-1, 1, 1), 1.0, 0.0) | ||||
|                     else: | ||||
|                         mask_gti = masks[tidxs[i]][j] | ||||
|                     lseg += single_mask_loss(mask_gti, pmask[j], proto[bi], mxyxy[j], marea[j]) | ||||
|  | ||||
|             obji = BCEobj(pi[..., 4], tobj) | ||||
|             lobj += obji * balance[i]  # obj loss | ||||
|             if autobalance: | ||||
|                 balance[i] = balance[i] * 0.9999 + 0.0001 / obji.detach().item() | ||||
|  | ||||
|         if autobalance: | ||||
|             balance = [x / balance[ssi] for x in balance] | ||||
|         lbox *= self.args.box | ||||
|         lobj *= self.args.obj | ||||
|         lcls *= self.args.cls | ||||
|         lseg *= self.args.box / bs | ||||
|  | ||||
|         loss = lbox + lobj + lcls + lseg | ||||
|         return loss * bs, torch.cat((lbox, lseg, lobj, lcls)).detach() | ||||
|  | ||||
|     def progress_string(self): | ||||
|         return ('\n' + '%11s' * 7) % \ | ||||
|                ('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss', 'Size') | ||||
|  | ||||
|  | ||||
| @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name) | ||||
| def train(cfg): | ||||
|     cfg.cfg = v8.ROOT / "models/yolov5n-seg.yaml" | ||||
|     cfg.data = cfg.data or "coco128-segments"  # or yolo.ClassificationDataset("mnist") | ||||
|     trainer = SegmentationTrainer(cfg) | ||||
|     trainer.train() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     """ | ||||
|     CLI usage: | ||||
|     python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-segments epochs=100 img_size=640 | ||||
|  | ||||
|     TODO: | ||||
|     Direct cli support, i.e, yolov8 classify_train args.epochs 10 | ||||
|     """ | ||||
|     train() | ||||
							
								
								
									
										211
									
								
								ultralytics/yolo/v8/segment/val.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										211
									
								
								ultralytics/yolo/v8/segment/val.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,211 @@ | ||||
| import os | ||||
| from pathlib import Path | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from ultralytics.yolo.engine.validator import BaseValidator | ||||
| from ultralytics.yolo.utils import ops | ||||
| from ultralytics.yolo.utils.checks import check_requirements | ||||
| from ultralytics.yolo.utils.metrics import (ConfusionMatrix, Metrics, ap_per_class_box_and_mask, box_iou, | ||||
|                                             fitness_segmentation, mask_iou) | ||||
| from ultralytics.yolo.utils.modeling import yaml_load | ||||
| from ultralytics.yolo.utils.torch_utils import de_parallel | ||||
|  | ||||
|  | ||||
| class SegmentationValidator(BaseValidator): | ||||
|  | ||||
|     def __init__(self, dataloader, pbar=None, logger=None, args=None): | ||||
|         super().__init__(dataloader, pbar, logger, args) | ||||
|         if self.args.save_json: | ||||
|             check_requirements(['pycocotools']) | ||||
|             self.process = ops.process_mask_upsample  # more accurate | ||||
|         else: | ||||
|             self.process = ops.process_mask  # faster | ||||
|         self.data_dict = yaml_load(self.args.data) if self.args.data else None | ||||
|         self.is_coco = False | ||||
|         self.class_map = None | ||||
|         self.targets = None | ||||
|  | ||||
|     def preprocess_batch(self, batch): | ||||
|         batch["img"] = batch["img"].to(self.device, non_blocking=True) | ||||
|         batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 225 | ||||
|         batch["bboxes"] = batch["bboxes"].to(self.device) | ||||
|         batch["masks"] = batch["masks"].to(self.device).float() | ||||
|         self.nb, _, self.height, self.width = batch["img"].shape  # batch size, channels, height, width | ||||
|         self.targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) | ||||
|         self.lb = [self.targets[self.targets[:, 0] == i, 1:] | ||||
|                    for i in range(self.nb)] if self.args.save_hybrid else []  # for autolabelling | ||||
|  | ||||
|         return batch | ||||
|  | ||||
|     def init_metrics(self, model): | ||||
|         head = de_parallel(model).model[-1] | ||||
|         if self.data_dict: | ||||
|             self.is_coco = isinstance(self.data_dict.get('val'), | ||||
|                                       str) and self.data_dict['val'].endswith(f'coco{os.sep}val2017.txt') | ||||
|             self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000)) | ||||
|  | ||||
|         self.nc = head.nc | ||||
|         self.nm = head.nm | ||||
|         self.names = model.names | ||||
|         if isinstance(self.names, (list, tuple)):  # old format | ||||
|             self.names = dict(enumerate(self.names)) | ||||
|  | ||||
|         self.iouv = torch.linspace(0.5, 0.95, 10, device=self.device)  # iou vector for mAP@0.5:0.95 | ||||
|         self.niou = self.iouv.numel() | ||||
|         self.seen = 0 | ||||
|         self.confusion_matrix = ConfusionMatrix(nc=self.nc) | ||||
|         self.metrics = Metrics() | ||||
|         self.loss = torch.zeros(4, device=self.device) | ||||
|         self.jdict = [] | ||||
|         self.stats = [] | ||||
|  | ||||
|     def get_desc(self): | ||||
|         return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P", | ||||
|                                          "R", "mAP50", "mAP50-95)") | ||||
|  | ||||
|     def preprocess_preds(self, preds): | ||||
|         p = ops.non_max_suppression(preds[0], | ||||
|                                     self.args.conf_thres, | ||||
|                                     self.args.iou_thres, | ||||
|                                     labels=self.lb, | ||||
|                                     multi_label=True, | ||||
|                                     agnostic=self.args.single_cls, | ||||
|                                     max_det=self.args.max_det, | ||||
|                                     nm=self.nm) | ||||
|         return (p, preds[0], preds[2]) | ||||
|  | ||||
|     def update_metrics(self, preds, batch): | ||||
|         # Metrics | ||||
|         plot_masks = []  # masks for plotting | ||||
|         for si, (pred, proto) in enumerate(zip(preds[0], preds[1])): | ||||
|             labels = self.targets[self.targets[:, 0] == si, 1:] | ||||
|             nl, npr = labels.shape[0], pred.shape[0]  # number of labels, predictions | ||||
|             shape = Path(batch["im_file"][si]) | ||||
|             # path = batch["shape"][si][0] | ||||
|             correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init | ||||
|             correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init | ||||
|             self.seen += 1 | ||||
|  | ||||
|             if npr == 0: | ||||
|                 if nl: | ||||
|                     self.stats.append((correct_masks, correct_bboxes, *torch.zeros( | ||||
|                         (2, 0), device=self.device), labels[:, 0])) | ||||
|                     if self.args.plots: | ||||
|                         self.confusion_matrix.process_batch(detections=None, labels=labels[:, 0]) | ||||
|                 continue | ||||
|  | ||||
|             # Masks | ||||
|             midx = [si] if self.args.overlap_mask else self.targets[:, 0] == si | ||||
|             gt_masks = batch["masks"][midx] | ||||
|             pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch["img"][si].shape[1:]) | ||||
|  | ||||
|             # Predictions | ||||
|             if self.args.single_cls: | ||||
|                 pred[:, 5] = 0 | ||||
|             predn = pred.clone() | ||||
|             ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape, batch["shape"][si][1])  # native-space pred | ||||
|  | ||||
|             # Evaluate | ||||
|             if nl: | ||||
|                 tbox = ops.xywh2xyxy(labels[:, 1:5])  # target boxes | ||||
|                 ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape, batch["shapes"][si][1])  # native-space labels | ||||
|                 labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels | ||||
|                 correct_bboxes = self._process_batch(predn, labelsn, self.iouv) | ||||
|                 correct_masks = self._process_batch(predn, labelsn, self.iouv, pred_masks, gt_masks, masks=True) | ||||
|                 if self.args.plots: | ||||
|                     self.confusion_matrix.process_batch(predn, labelsn) | ||||
|             self.stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], labels[:, | ||||
|                                                                                              0]))  # (conf, pcls, tcls) | ||||
|  | ||||
|             pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8) | ||||
|             if self.plots and self.batch_i < 3: | ||||
|                 plot_masks.append(pred_masks[:15].cpu())  # filter top 15 to plot | ||||
|  | ||||
|             # TODO: Save/log | ||||
|             ''' | ||||
|             if self.args.save_txt: | ||||
|                 save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') | ||||
|             if self.args.save_json: | ||||
|                 pred_masks = scale_image(im[si].shape[1:], | ||||
|                                          pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), shape, shapes[si][1]) | ||||
|                 save_one_json(predn, jdict, path, class_map, pred_masks)  # append to COCO-JSON dictionary | ||||
|             # callbacks.run('on_val_image_end', pred, predn, path, names, im[si]) | ||||
|             ''' | ||||
|  | ||||
|         # TODO Plot images | ||||
|         ''' | ||||
|         if self.args.plots and self.batch_i < 3: | ||||
|             if len(plot_masks): | ||||
|                 plot_masks = torch.cat(plot_masks, dim=0) | ||||
|             plot_images_and_masks(im, targets, masks, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names) | ||||
|             plot_images_and_masks(im, output_to_target(preds, max_det=15), plot_masks, paths, | ||||
|                                   save_dir / f'val_batch{batch_i}_pred.jpg', names)  # pred | ||||
|         ''' | ||||
|  | ||||
|     def get_stats(self): | ||||
|         stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)]  # to numpy | ||||
|         if len(stats) and stats[0].any(): | ||||
|             # TODO: save_dir | ||||
|             results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir='', names=self.names) | ||||
|             self.metrics.update(results) | ||||
|         self.nt_per_class = np.bincount(stats[4].astype(int), minlength=self.nc)  # number of targets per class | ||||
|         keys = ["mp_bbox", "mr_bbox", "map50_bbox", "map_bbox", "mp_mask", "mr_mask", "map50_mask", "map_mask"] | ||||
|         metrics = {"fitness": fitness_segmentation(np.array(self.metrics.mean_results()).reshape(1, -1))} | ||||
|         metrics |= zip(keys, self.metrics.mean_results()) | ||||
|         return metrics | ||||
|  | ||||
|     def print_results(self): | ||||
|         pf = '%22s' + '%11i' * 2 + '%11.3g' * 8  # print format | ||||
|         self.logger.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results())) | ||||
|         if self.nt_per_class.sum() == 0: | ||||
|             self.logger.warning( | ||||
|                 f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels') | ||||
|  | ||||
|         # Print results per class | ||||
|         if (self.args.verbose or (self.nc < 50 and not self.training)) and self.nc > 1 and len(self.stats): | ||||
|             for i, c in enumerate(self.metrics.ap_class_index): | ||||
|                 self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i))) | ||||
|  | ||||
|         # plot TODO: save_dir | ||||
|         if self.args.plots: | ||||
|             self.confusion_matrix.plot(save_dir='', names=list(self.names.values())) | ||||
|  | ||||
|     def _process_batch(self, detections, labels, iouv, pred_masks=None, gt_masks=None, overlap=False, masks=False): | ||||
|         """ | ||||
|         Return correct prediction matrix | ||||
|         Arguments: | ||||
|             detections (array[N, 6]), x1, y1, x2, y2, conf, class | ||||
|             labels (array[M, 5]), class, x1, y1, x2, y2 | ||||
|         Returns: | ||||
|             correct (array[N, 10]), for 10 IoU levels | ||||
|         """ | ||||
|         if masks: | ||||
|             if overlap: | ||||
|                 nl = len(labels) | ||||
|                 index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1 | ||||
|                 gt_masks = gt_masks.repeat(nl, 1, 1)  # shape(1,640,640) -> (n,640,640) | ||||
|                 gt_masks = torch.where(gt_masks == index, 1.0, 0.0) | ||||
|             if gt_masks.shape[1:] != pred_masks.shape[1:]: | ||||
|                 gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0] | ||||
|                 gt_masks = gt_masks.gt_(0.5) | ||||
|             iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1)) | ||||
|         else:  # boxes | ||||
|             iou = box_iou(labels[:, 1:], detections[:, :4]) | ||||
|  | ||||
|         correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool) | ||||
|         correct_class = labels[:, 0:1] == detections[:, 5] | ||||
|         for i in range(len(iouv)): | ||||
|             x = torch.where((iou >= iouv[i]) & correct_class)  # IoU > threshold and classes match | ||||
|             if x[0].shape[0]: | ||||
|                 matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), | ||||
|                                     1).cpu().numpy()  # [label, detect, iou] | ||||
|                 if x[0].shape[0] > 1: | ||||
|                     matches = matches[matches[:, 2].argsort()[::-1]] | ||||
|                     matches = matches[np.unique(matches[:, 1], return_index=True)[1]] | ||||
|                     # matches = matches[matches[:, 2].argsort()[::-1]] | ||||
|                     matches = matches[np.unique(matches[:, 0], return_index=True)[1]] | ||||
|                 correct[matches[:, 1].astype(int), i] = True | ||||
|         return torch.tensor(correct, dtype=torch.bool, device=iouv.device) | ||||
							
								
								
									
										48
									
								
								ultralytics/yolov5n-seg.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								ultralytics/yolov5n-seg.yaml
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,48 @@ | ||||
| # YOLOv5 🚀 by Ultralytics, GPL-3.0 license | ||||
|  | ||||
| # Parameters | ||||
| nc: 80  # number of classes | ||||
| depth_multiple: 0.33  # model depth multiple | ||||
| width_multiple: 0.25  # layer channel multiple | ||||
| anchors: | ||||
|   - [10,13, 16,30, 33,23]  # P3/8 | ||||
|   - [30,61, 62,45, 59,119]  # P4/16 | ||||
|   - [116,90, 156,198, 373,326]  # P5/32 | ||||
|  | ||||
| # YOLOv5 v6.0 backbone | ||||
| backbone: | ||||
|   # [from, number, module, args] | ||||
|   [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2 | ||||
|    [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4 | ||||
|    [-1, 3, C3, [128]], | ||||
|    [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8 | ||||
|    [-1, 6, C3, [256]], | ||||
|    [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16 | ||||
|    [-1, 9, C3, [512]], | ||||
|    [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32 | ||||
|    [-1, 3, C3, [1024]], | ||||
|    [-1, 1, SPPF, [1024, 5]],  # 9 | ||||
|   ] | ||||
|  | ||||
| # YOLOv5 v6.0 head | ||||
| head: | ||||
|   [[-1, 1, Conv, [512, 1, 1]], | ||||
|    [-1, 1, nn.Upsample, [None, 2, 'nearest']], | ||||
|    [[-1, 6], 1, Concat, [1]],  # cat backbone P4 | ||||
|    [-1, 3, C3, [512, False]],  # 13 | ||||
|  | ||||
|    [-1, 1, Conv, [256, 1, 1]], | ||||
|    [-1, 1, nn.Upsample, [None, 2, 'nearest']], | ||||
|    [[-1, 4], 1, Concat, [1]],  # cat backbone P3 | ||||
|    [-1, 3, C3, [256, False]],  # 17 (P3/8-small) | ||||
|  | ||||
|    [-1, 1, Conv, [256, 3, 2]], | ||||
|    [[-1, 14], 1, Concat, [1]],  # cat head P4 | ||||
|    [-1, 3, C3, [512, False]],  # 20 (P4/16-medium) | ||||
|  | ||||
|    [-1, 1, Conv, [512, 3, 2]], | ||||
|    [[-1, 10], 1, Concat, [1]],  # cat head P5 | ||||
|    [-1, 3, C3, [1024, False]],  # 23 (P5/32-large) | ||||
|  | ||||
|    [[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]],  # Detect(P3, P4, P5) | ||||
|   ] | ||||
		Reference in New Issue
	
	Block a user