Segmentation support & other enhancements (#40)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Branch: single_channel
Ayush Chaurasia, 2 years ago, committed by GitHub
parent c617ee1c79
commit f56c9bcc26

@@ -21,7 +21,8 @@ jobs:
        os: [ ubuntu-latest ]
        python-version: [ '3.10' ]
        model: [ yolov5n ]
-       include:
+       torch: [ latest ]
+       # include:
        # - os: ubuntu-latest
        # python-version: '3.7' # '3.6.8' min
        # model: yolov5n
@@ -31,10 +32,10 @@ jobs:
        # - os: ubuntu-latest
        # python-version: '3.9'
        # model: yolov5n
-       - os: ubuntu-latest
-       python-version: '3.8' # torch 1.7.0 requires python >=3.6, <=3.8
-       model: yolov5n
-       torch: '1.7.0' # min torch version CI https://pypi.org/project/torchvision/
+       # - os: ubuntu-latest
+       # python-version: '3.8' # torch 1.7.0 requires python >=3.6, <=3.8
+       # model: yolov5n
+       # torch: '1.7.0' # min torch version CI https://pypi.org/project/torchvision/
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v4
@@ -93,9 +94,8 @@ jobs:
      - name: Test segmentation
        shell: bash # for Windows compatibility
        run: |
-         echo "TODO"
+         python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-segments epochs=1 img_size=64
      - name: Test classification
        shell: bash # for Windows compatibility
        run: |
-         echo "TODO"
-         # python ultralytics/yolo/v8/classify/train.py model=resnet18 data=mnist2560 epochs=1 img_size=64
+         python ultralytics/yolo/v8/classify/train.py model=resnet18 data=mnist160 epochs=1 img_size=32

@@ -1,6 +1,7 @@
from itertools import repeat
from multiprocessing.pool import Pool
from pathlib import Path
+from collections import OrderedDict

import torchvision
from tqdm import tqdm

@@ -205,7 +206,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
        else:
            sample = self.torch_transforms(im)
-       return sample, j
+       return OrderedDict(img=sample, cls=j)
        # TODO: support semantic segmentation
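Returning a keyed mapping lets downstream code address batch fields by name instead of position. A minimal usage sketch (the constructor arguments below are illustrative assumptions, not part of this diff):

    from torch.utils.data import DataLoader

    # hypothetical instantiation; ClassificationDataset args here are assumed for illustration
    dataset = ClassificationDataset(root="datasets/mnist160", augment=False)
    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    batch = next(iter(loader))                   # default collate batches each dict key
    imgs, classes = batch["img"], batch["cls"]   # keyed access instead of tuple unpacking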

@@ -1,12 +1,17 @@
"""
Simple training loop; Boilerplate that could apply to any arbitrary neural network,
"""
+# TODOs
+# 1. finish _set_model_attributes
+# 2. allow num_class update for both pretrained and csv_loaded models
+# 3. save

import os
import time
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Dict, Union

import torch

@@ -52,6 +57,8 @@ class BaseTrainer:
        # Model and Dataloaders.
        self.trainset, self.testset = self.get_dataset(self.args.data)
+       if self.args.cfg is not None:
+           self.model = self.load_cfg(self.args.cfg)
        if self.args.model is not None:
            self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device)

@@ -133,6 +140,20 @@
        self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=rank)
        self.validator = self.get_validator()
        print("created testloader :", rank)
+       self.console.info(self.progress_string())

+   def _set_model_attributes(self):
+       # TODO: fix and use after self.data_dict is available
+       '''
+       head = utils.torch_utils.de_parallel(self.model).model[-1]
+       self.args.box *= 3 / head.nl  # scale to layers
+       self.args.cls *= head.nc / 80 * 3 / head.nl  # scale to classes and layers
+       self.args.obj *= (self.args.img_size / 640) ** 2 * 3 / nl  # scale to image size and layers
+       model.nc = nc  # attach number of classes to model
+       model.hyp = hyp  # attach hyperparameters to model
+       model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
+       model.names = names
+       '''
    def _do_train(self, rank, world_size):
        if world_size > 1:

@@ -153,13 +174,17 @@
        pbar = tqdm(enumerate(self.train_loader),
                    total=len(self.train_loader),
                    bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
-       tloss = 0
-       for i, (images, labels) in pbar:
+       tloss = None
+       for i, batch in pbar:
+           # img, label (classification) / img, targets, paths, _, masks (detection)
            # callback hook. on_batch_start
            # forward
-           images, labels = self.preprocess_batch(images, labels)
-           self.loss = self.criterion(self.model(images), labels)
-           tloss = (tloss * i + self.loss.item()) / (i + 1)
+           batch = self.preprocess_batch(batch)
+           # TODO: warmup, multiscale
+           preds = self.model(batch["img"])
+           self.loss, self.loss_items = self.criterion(preds, batch)
+           tloss = (tloss * i + self.loss_items) / (i + 1) if tloss is not None else self.loss_items

            # backward
            self.model.zero_grad(set_to_none=True)

@@ -170,9 +195,13 @@
            self.trigger_callbacks('on_batch_end')

            # log
-           mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
+           mem = (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
+           loss_len = tloss.shape[0] if len(tloss.size()) else 1
+           losses = tloss if loss_len > 1 else torch.unsqueeze(tloss, 0)
            if rank in {-1, 0}:
-               pbar.desc = f"{f'{epoch + 1}/{self.args.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36
+               pbar.set_description(
+                   (" {} " + "{:.3f} " * (2 + loss_len)).format(f'{epoch + 1}/{self.args.epochs}', mem, *losses,
+                                                                batch["img"].shape[-1]))

        if rank in [-1, 0]:
            # validation
@@ -240,6 +269,9 @@
        return model

+   def load_cfg(self, cfg):
+       raise NotImplementedError("This task trainer doesn't support loading cfg files")

    def get_validator(self):
        pass

@@ -250,11 +282,11 @@
        self.scaler.update()
        self.optimizer.zero_grad()

-   def preprocess_batch(self, images, labels):
+   def preprocess_batch(self, batch):
        """
        Allows custom preprocessing model inputs and ground truths depending on task type
        """
-       return images.to(self.device, non_blocking=True), labels.to(self.device)
+       return batch

    def validate(self):
        """

@@ -270,14 +302,17 @@
    def build_targets(self, preds, targets):
        pass

-   def criterion(self, preds, targets):
+   def criterion(self, preds, batch):
+       """
+       Returns loss and individual loss items as Tensor
+       """
        pass

    def progress_string(self):
        """
        Returns progress string depending on task type.
        """
-       pass
+       return ''

    def usage_help(self):
        """

@@ -1,8 +1,10 @@
import logging

import torch
+from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm

+from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import select_device

@@ -12,12 +14,15 @@ class BaseValidator:
    Base validator class.
    """

-   def __init__(self, dataloader, device='', half=False, pbar=None, logger=None):
+   def __init__(self, dataloader, pbar=None, logger=None, args=None):
        self.dataloader = dataloader
-       self.half = half
-       self.device = select_device(device, dataloader.batch_size)
        self.pbar = pbar
        self.logger = logger or logging.getLogger()
+       self.args = args or OmegaConf.load(DEFAULT_CONFIG)
+       self.device = select_device(self.args.device, dataloader.batch_size)
+       self.cuda = self.device.type != 'cpu'
+       self.batch_i = None
+       self.training = True
    def __call__(self, trainer=None, model=None):
        """

@@ -25,45 +30,48 @@
        if trainer is passed (trainer gets priority).
        """
        training = trainer is not None
+       self.training = training
        # trainer = trainer or self.trainer_class.get_trainer()
        assert training or model is not None, "Either trainer or model is needed for validation"
        if training:
            model = trainer.model
-           self.half &= self.device.type != 'cpu'
-           model = model.half() if self.half else model
+           self.args.half &= self.device.type != 'cpu'
+           model = model.half() if self.args.half else model
        else:  # TODO: handle this when detectMultiBackend is supported
            # model = DetectMultiBacked(model)
            pass
+       # TODO: implement init_model_attributes()

        model.eval()
        dt = Profile(), Profile(), Profile(), Profile()
        loss = 0
        n_batches = len(self.dataloader)
-       desc = self.set_desc()
+       desc = self.get_desc()
        bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
-       self.init_metrics()
+       self.init_metrics(model)
        with torch.cuda.amp.autocast(enabled=self.device.type != 'cpu'):
-           for images, labels in bar:
+           for batch_i, batch in enumerate(bar):
+               self.batch_i = batch_i
                # pre-process
                with dt[0]:
-                   images, labels = self.preprocess_batch(images, labels)
+                   batch = self.preprocess_batch(batch)

                # inference
                with dt[1]:
-                   preds = model(images)
+                   preds = model(batch["img"])
                    # TODO: remember to add native augmentation support when implementing model, like:
                    #  preds, train_out = model(im, augment=augment)

                # loss
                with dt[2]:
                    if training:
-                       loss += trainer.criterion(preds, labels) / images.shape[0]
+                       loss += trainer.criterion(preds, batch)[0]

                # pre-process predictions
                with dt[3]:
                    preds = self.preprocess_preds(preds)

-               self.update_metrics(preds, labels)
+               self.update_metrics(preds, batch)

        stats = self.get_stats()
        self.check_stats(stats)
@@ -81,8 +89,8 @@
        return stats

-   def preprocess_batch(self, images, labels):
-       return images.to(self.device, non_blocking=True), labels.to(self.device)
+   def preprocess_batch(self, batch):
+       return batch

    def preprocess_preds(self, preds):
        return preds

@@ -90,7 +98,7 @@
    def init_metrics(self, model):
        pass

-   def update_metrics(self, preds, targets):
+   def update_metrics(self, preds, batch):
        pass

    def get_stats(self):

@@ -102,5 +110,5 @@
    def print_results(self):
        pass

-   def set_desc(self):
+   def get_desc(self):
        pass

@@ -4,6 +4,7 @@
# Train settings -------------------------------------------------------------------------------------------------------
model: null # i.e. yolov5s.pt
+cfg: null # i.e. yolov5s.yaml
data: null # i.e. coco128.yaml
epochs: 300
batch_size: 16

@@ -20,6 +21,23 @@ optimizer: 'SGD' # choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
verbose: False
seed: 0
local_rank: -1
+single_cls: False # train multi-class data as single-class
+image_weights: False # use weighted image selection for training
+shuffle: True
+rect: False # support rectangular training
+overlap_mask: True # Segmentation masks overlap
+mask_ratio: 4 # Segmentation mask downsample ratio

+# Val/Test settings ----------------------------------------------------------------------------------------------------
+save_json: False
+save_hybrid: False
+conf_thres: 0.001
+iou_thres: 0.6
+max_det: 300
+half: True
+plots: False
+save_txt: False
+task: 'val'

# Hyperparameters ------------------------------------------------------------------------------------------------------
lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)

@@ -51,6 +69,7 @@ fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability)
+label_smoothing: 0.0

# Hydra configs --------------------------------------------------------------------------------------------------------
hydra:

@@ -2,11 +2,19 @@
"""
Model validation metrics
"""
+import math
+import warnings
+from pathlib import Path

+import matplotlib.pyplot as plt
import numpy as np
import torch
+import torch.nn as nn

+from ultralytics.yolo.utils import TryExcept

+# boxes

def box_area(box):
    # box = xyxy(4,n)
    return (box[2] - box[0]) * (box[3] - box[1])

@@ -53,3 +61,484 @@ def box_iou(box1, box2, eps=1e-7):
    # IoU = inter / (area1 + area2 - inter)
    return inter / (box_area(box1.T)[:, None] + box_area(box2.T) - inter + eps)
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
# Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
# Get the coordinates of bounding boxes
if xywh: # transform from xywh to xyxy
(x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, 1), box2.chunk(4, 1)
w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
else: # x1, y1, x2, y2 = box1
b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, 1)
b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, 1)
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
# Intersection area
inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
(torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
# Union Area
union = w1 * h1 + w2 * h2 - inter + eps
# IoU
iou = inter / union
if CIoU or DIoU or GIoU:
cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
with torch.no_grad():
alpha = v / (v - iou + (1 + eps))
return iou - (rho2 / c2 + v * alpha) # CIoU
return iou - rho2 / c2 # DIoU
c_area = cw * ch + eps # convex area
return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf
return iou # IoU
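A quick sanity check of bbox_iou under its xywh default, shapes (1, 4) against (n, 4); a minimal sketch, not part of the diff:

    import torch

    b1 = torch.tensor([[50., 50., 20., 20.]])                        # one predicted box, xywh
    b2 = torch.tensor([[50., 50., 20., 20.], [55., 55., 20., 20.]])  # two target boxes, xywh
    print(bbox_iou(b1, b2))             # plain IoU, shape (2, 1); first entry ~1.0
    print(bbox_iou(b1, b2, CIoU=True))  # CIoU subtracts center-distance and aspect-ratio penalties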
def mask_iou(mask1, mask2, eps=1e-7):
"""
mask1: [N, n] m1 means number of predicted objects
mask2: [M, n] m2 means number of gt objects
Note: n means image_w x image_h
return: masks iou, [N, M]
"""
intersection = torch.matmul(mask1, mask2.t()).clamp(0)
union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection
return intersection / (union + eps)
def masks_iou(mask1, mask2, eps=1e-7):
"""
mask1: [N, n] m1 means number of predicted objects
mask2: [N, n] m2 means number of gt objects
Note: n means image_w x image_h
return: masks iou, (N, )
"""
intersection = (mask1 * mask2).sum(1).clamp(0) # (N, )
union = (mask1.sum(1) + mask2.sum(1))[None] - intersection # (area1 + area2) - intersection
return intersection / (union + eps)
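Both mask IoU helpers expect masks flattened to (N, image_h * image_w); a tiny sketch with toy binary masks, not part of the diff:

    import torch

    m1 = torch.tensor([[1., 1., 0., 0.]])                    # one predicted mask, flattened
    m2 = torch.tensor([[1., 1., 0., 0.], [0., 1., 1., 0.]])  # two gt masks, flattened
    print(mask_iou(m1, m2))   # pairwise (N, M) matrix, ~[[1.000, 0.333]]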
def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
# return positive, negative label smoothing BCE targets
return 1.0 - 0.5 * eps, 0.5 * eps
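smooth_BCE only shifts the hard 1/0 BCE targets toward each other, e.g.:

    cp, cn = smooth_BCE(eps=0.1)   # cp = 0.95 (positive target), cn = 0.05 (negative target)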
# losses
class FocalLoss(nn.Module):
# Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
super().__init__()
self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
self.gamma = gamma
self.alpha = alpha
self.reduction = loss_fcn.reduction
self.loss_fcn.reduction = 'none' # required to apply FL to each element
def forward(self, pred, true):
loss = self.loss_fcn(pred, true)
# p_t = torch.exp(-loss)
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
# TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
pred_prob = torch.sigmoid(pred) # prob from logits
p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
modulating_factor = (1.0 - p_t) ** self.gamma
loss *= alpha_factor * modulating_factor
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else: # 'none'
return loss
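Usage follows the comment above: wrap an existing BCE-with-logits criterion. A minimal sketch with random tensors, not part of the diff:

    import torch
    import torch.nn as nn

    criterion = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5, alpha=0.25)
    logits = torch.randn(4, 80)                      # raw per-class scores
    targets = torch.randint(0, 2, (4, 80)).float()   # binary targets
    loss = criterion(logits, targets)                # scalar; BCEWithLogitsLoss defaults to 'mean'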
class ConfusionMatrix:
# Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
def __init__(self, nc, conf=0.25, iou_thres=0.45):
self.matrix = np.zeros((nc + 1, nc + 1))
self.nc = nc # number of classes
self.conf = conf
self.iou_thres = iou_thres
def process_batch(self, detections, labels):
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Arguments:
detections (Array[N, 6]), x1, y1, x2, y2, conf, class
labels (Array[M, 5]), class, x1, y1, x2, y2
Returns:
None, updates confusion matrix accordingly
"""
if detections is None:
gt_classes = labels.int()
for gc in gt_classes:
self.matrix[self.nc, gc] += 1 # background FN
return
detections = detections[detections[:, 4] > self.conf]
gt_classes = labels[:, 0].int()
detection_classes = detections[:, 5].int()
iou = box_iou(labels[:, 1:], detections[:, :4])
x = torch.where(iou > self.iou_thres)
if x[0].shape[0]:
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
else:
matches = np.zeros((0, 3))
n = matches.shape[0] > 0
m0, m1, _ = matches.transpose().astype(int)
for i, gc in enumerate(gt_classes):
j = m0 == i
if n and sum(j) == 1:
self.matrix[detection_classes[m1[j]], gc] += 1 # correct
else:
self.matrix[self.nc, gc] += 1 # true background
if n:
for i, dc in enumerate(detection_classes):
if not any(m1 == i):
self.matrix[dc, self.nc] += 1 # predicted background
def matrix(self):
return self.matrix
def tp_fp(self):
tp = self.matrix.diagonal() # true positives
fp = self.matrix.sum(1) - tp # false positives
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
return tp[:-1], fp[:-1] # remove background class
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
def plot(self, normalize=True, save_dir='', names=()):
import seaborn as sn
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
nc, nn = self.nc, len(names) # number of classes, names
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
ticklabels = (names + ['background']) if labels else "auto"
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
sn.heatmap(array,
ax=ax,
annot=nc < 30,
annot_kws={
"size": 8},
cmap='Blues',
fmt='.2f',
square=True,
vmin=0.0,
xticklabels=ticklabels,
yticklabels=ticklabels).set_facecolor((1, 1, 1))
        ax.set_xlabel('True')
        ax.set_ylabel('Predicted')
ax.set_title('Confusion Matrix')
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
plt.close(fig)
def print(self):
for i in range(self.nc + 1):
print(' '.join(map(str, self.matrix[i])))
def fitness_detection(x):
# Model fitness as a weighted combination of metrics
w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
return (x[:, :4] * w).sum(1)
def fitness_segmentation(x):
# Model fitness as a weighted combination of metrics
w = [0.0, 0.0, 0.1, 0.9, 0.0, 0.0, 0.1, 0.9]
return (x[:, :8] * w).sum(1)
def smooth(y, f=0.05):
# Box filter of fraction f
nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
p = np.ones(nf // 2) # ones padding
yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded
return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed
def compute_ap(recall, precision):
""" Compute the average precision, given the recall and precision curves
# Arguments
recall: The recall curve (list)
precision: The precision curve (list)
# Returns
Average precision, precision curve, recall curve
"""
# Append sentinel values to beginning and end
mrec = np.concatenate(([0.0], recall, [1.0]))
mpre = np.concatenate(([1.0], precision, [0.0]))
# Compute the precision envelope
mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
# Integrate area under curve
method = 'interp' # methods: 'continuous', 'interp'
if method == 'interp':
x = np.linspace(0, 1, 101) # 101-point interp (COCO)
ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
else: # 'continuous'
i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve
return ap, mpre, mrec
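compute_ap integrates a monotone precision envelope over recall; a toy run on hand-made curves (illustrative values only):

    import numpy as np

    recall = np.array([0.2, 0.4, 0.6, 0.8])
    precision = np.array([1.0, 0.9, 0.75, 0.6])
    ap, mpre, mrec = compute_ap(recall, precision)
    print(float(ap))   # area under the 101-point interpolated PR curve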
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16, prefix=""):
""" Compute the average precision, given the recall and precision curves.
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
# Arguments
tp: True positives (nparray, nx1 or nx10).
conf: Objectness value from 0-1 (nparray).
pred_cls: Predicted object classes (nparray).
target_cls: True object classes (nparray).
plot: Plot precision-recall curve at mAP@0.5
save_dir: Plot save directory
# Returns
The average precision as computed in py-faster-rcnn.
"""
# Sort by objectness
i = np.argsort(-conf)
tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
# Find unique classes
unique_classes, nt = np.unique(target_cls, return_counts=True)
nc = unique_classes.shape[0] # number of classes, number of detections
# Create Precision-Recall curve and compute AP for each class
px, py = np.linspace(0, 1, 1000), [] # for plotting
ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
for ci, c in enumerate(unique_classes):
i = pred_cls == c
n_l = nt[ci] # number of labels
n_p = i.sum() # number of predictions
if n_p == 0 or n_l == 0:
continue
# Accumulate FPs and TPs
fpc = (1 - tp[i]).cumsum(0)
tpc = tp[i].cumsum(0)
# Recall
recall = tpc / (n_l + eps) # recall curve
r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
# Precision
precision = tpc / (tpc + fpc) # precision curve
p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score
# AP from recall-precision curve
for j in range(tp.shape[1]):
ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
if plot and j == 0:
py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5
# Compute F1 (harmonic mean of precision and recall)
f1 = 2 * p * r / (p + r + eps)
names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
names = dict(enumerate(names)) # to dict
# TODO: plot
'''
if plot:
plot_pr_curve(px, py, ap, Path(save_dir) / f'{prefix}PR_curve.png', names)
plot_mc_curve(px, f1, Path(save_dir) / f'{prefix}F1_curve.png', names, ylabel='F1')
plot_mc_curve(px, p, Path(save_dir) / f'{prefix}P_curve.png', names, ylabel='Precision')
plot_mc_curve(px, r, Path(save_dir) / f'{prefix}R_curve.png', names, ylabel='Recall')
'''
i = smooth(f1.mean(0), 0.1).argmax() # max F1 index
p, r, f1 = p[:, i], r[:, i], f1[:, i]
tp = (r * nt).round() # true positives
fp = (tp / (p + eps) - tp).round() # false positives
return tp, fp, p, r, f1, ap, unique_classes.astype(int)
def ap_per_class_box_and_mask(
tp_m,
tp_b,
conf,
pred_cls,
target_cls,
plot=False,
save_dir=".",
names=(),
):
"""
Args:
tp_b: tp of boxes.
tp_m: tp of masks.
other arguments see `func: ap_per_class`.
"""
results_boxes = ap_per_class(tp_b,
conf,
pred_cls,
target_cls,
plot=plot,
save_dir=save_dir,
names=names,
prefix="Box")[2:]
results_masks = ap_per_class(tp_m,
conf,
pred_cls,
target_cls,
plot=plot,
save_dir=save_dir,
names=names,
prefix="Mask")[2:]
results = {
"boxes": {
"p": results_boxes[0],
"r": results_boxes[1],
"ap": results_boxes[3],
"f1": results_boxes[2],
"ap_class": results_boxes[4]},
"masks": {
"p": results_masks[0],
"r": results_masks[1],
"ap": results_masks[3],
"f1": results_masks[2],
"ap_class": results_masks[4]}}
return results
class Metric:
def __init__(self) -> None:
self.p = [] # (nc, )
self.r = [] # (nc, )
self.f1 = [] # (nc, )
self.all_ap = [] # (nc, 10)
self.ap_class_index = [] # (nc, )
@property
def ap50(self):
"""AP@0.5 of all classes.
Return:
(nc, ) or [].
"""
return self.all_ap[:, 0] if len(self.all_ap) else []
@property
def ap(self):
"""AP@0.5:0.95
Return:
(nc, ) or [].
"""
return self.all_ap.mean(1) if len(self.all_ap) else []
@property
def mp(self):
"""mean precision of all classes.
Return:
float.
"""
return self.p.mean() if len(self.p) else 0.0
@property
def mr(self):
"""mean recall of all classes.
Return:
float.
"""
return self.r.mean() if len(self.r) else 0.0
@property
def map50(self):
"""Mean AP@0.5 of all classes.
Return:
float.
"""
return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0
@property
def map(self):
"""Mean AP@0.5:0.95 of all classes.
Return:
float.
"""
return self.all_ap.mean() if len(self.all_ap) else 0.0
def mean_results(self):
"""Mean of results, return mp, mr, map50, map"""
return (self.mp, self.mr, self.map50, self.map)
def class_result(self, i):
"""class-aware result, return p[i], r[i], ap50[i], ap[i]"""
return (self.p[i], self.r[i], self.ap50[i], self.ap[i])
def get_maps(self, nc):
maps = np.zeros(nc) + self.map
for i, c in enumerate(self.ap_class_index):
maps[c] = self.ap[i]
return maps
def update(self, results):
"""
Args:
results: tuple(p, r, ap, f1, ap_class)
"""
p, r, all_ap, f1, ap_class_index = results
self.p = p
self.r = r
self.all_ap = all_ap
self.f1 = f1
self.ap_class_index = ap_class_index
class Metrics:
"""Metric for boxes and masks."""
def __init__(self) -> None:
self.metric_box = Metric()
self.metric_mask = Metric()
def update(self, results):
"""
Args:
results: Dict{'boxes': Dict{}, 'masks': Dict{}}
"""
self.metric_box.update(list(results["boxes"].values()))
self.metric_mask.update(list(results["masks"].values()))
def mean_results(self):
return self.metric_box.mean_results() + self.metric_mask.mean_results()
def class_result(self, i):
return self.metric_box.class_result(i) + self.metric_mask.class_result(i)
def get_maps(self, nc):
return self.metric_box.get_maps(nc) + self.metric_mask.get_maps(nc)
@property
def ap_class_index(self):
# boxes and masks have the same ap_class_index
return self.metric_box.ap_class_index

@@ -5,6 +5,7 @@ import time

import cv2
import numpy as np
import torch
+import torch.nn.functional as F
import torchvision

from ultralytics.yolo.utils import LOGGER
@@ -32,14 +33,23 @@
        return time.time()
def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
# https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
# a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
# b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
# x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
# x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
return [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
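The mapping is consumed when exporting detections to COCO JSON, where category ids follow the 91-class paper numbering:

    class_map = coco80_to_coco91_class()
    print(class_map[0])    # 1  ('person')
    print(class_map[79])   # 90 ('toothbrush')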
def segment2box(segment, width=640, height=640):
    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
-   x, y, = (
-       x[inside],
-       y[inside],
-   )
+   x, y, = x[inside], y[inside]
    return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros(4)  # xyxy
@@ -304,3 +314,63 @@ def resample_segments(segments, n=1000):
        xp = np.arange(len(s))
        segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T  # segment xy
    return segments
def crop_mask(masks, boxes):
    """
    "Crop" predicted masks by zeroing out everything not in the predicted bbox.
    Vectorized by Chong (thanks Chong).

    Args:
        - masks should be a size [n, h, w] tensor of masks
        - boxes should be a size [n, 4] tensor of bbox coords in relative point form
    """
    n, h, w = masks.shape
    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(n,1,1)
    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # x coordinates, shape(1,1,w)
    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # y coordinates, shape(1,h,1)
    return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
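crop_mask works purely by broadcasting comparisons against coordinate ranges; a small sketch showing pixels outside each box being zeroed (toy values, not part of the diff):

    import torch

    masks = torch.ones(1, 4, 4)                # one 4x4 mask, all foreground
    boxes = torch.tensor([[1., 1., 3., 3.]])   # xyxy in mask-pixel coordinates
    print(crop_mask(masks, boxes)[0])          # only the 2x2 block with 1 <= x < 3, 1 <= y < 3 survives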
def process_mask_upsample(protos, masks_in, bboxes, shape):
"""
Crop after upsample.
proto_out: [mask_dim, mask_h, mask_w]
out_masks: [n, mask_dim], n is number of masks after nms
bboxes: [n, 4], n is number of masks after nms
shape:input_image_size, (h, w)
return: h, w, n
"""
c, mh, mw = protos.shape # CHW
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
masks = crop_mask(masks, bboxes) # CHW
return masks.gt_(0.5)
def process_mask(protos, masks_in, bboxes, shape, upsample=False):
"""
Crop before upsample.
proto_out: [mask_dim, mask_h, mask_w]
out_masks: [n, mask_dim], n is number of masks after nms
bboxes: [n, 4], n is number of masks after nms
shape:input_image_size, (h, w)
return: h, w, n
"""
c, mh, mw = protos.shape # CHW
ih, iw = shape
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) # CHW
downsampled_bboxes = bboxes.clone()
downsampled_bboxes[:, 0] *= mw / iw
downsampled_bboxes[:, 2] *= mw / iw
downsampled_bboxes[:, 3] *= mh / ih
downsampled_bboxes[:, 1] *= mh / ih
masks = crop_mask(masks, downsampled_bboxes) # CHW
if upsample:
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
return masks.gt_(0.5)
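Shapes are the contract here: prototypes come from the Segment head, coefficients and boxes from NMS output. A hedged sketch with random tensors, just to show the plumbing (all sizes illustrative):

    import torch

    protos = torch.randn(32, 160, 160)   # [mask_dim, mask_h, mask_w] from the proto layer
    coeffs = torch.randn(5, 32)          # per-detection mask coefficients after NMS
    boxes = torch.tensor([[10., 10., 300., 300.]]).repeat(5, 1)  # xyxy in input-image space
    masks = process_mask(protos, coeffs, boxes, shape=(640, 640), upsample=True)
    print(masks.shape, masks.dtype)      # torch.Size([5, 640, 640]), bool after gt_(0.5)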

@@ -179,3 +179,13 @@ def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):

def intersect_state_dicts(da, db, exclude=()):
    # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
    return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
def is_parallel(model):
# Returns True if model is of type DP or DDP
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
def de_parallel(model):
# De-parallelize a model: returns single-GPU model if model is of type DP or DDP
return model.module if is_parallel(model) else model
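These mirror the yolov5 helpers; de_parallel lets task code reach module attributes (e.g. the detection head) whether or not the model is wrapped:

    import torch.nn as nn

    model = nn.Linear(10, 2)
    wrapped = nn.DataParallel(model)
    assert de_parallel(wrapped) is model   # unwraps DP/DDP
    assert de_parallel(model) is model     # no-op for plain modules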

@@ -1,7 +1,7 @@
from pathlib import Path

-from ultralytics.yolo.v8 import classify
+from ultralytics.yolo.v8 import classify, segment

ROOT = Path(__file__).parents[0]  # yolov8 ROOT

-__all__ = ["classify"]
+__all__ = ["classify", "segment"]

@@ -38,13 +38,22 @@ class ClassificationTrainer(BaseTrainer):
        return train_set, test_set

    def get_dataloader(self, dataset_path, batch_size=None, rank=0):
-       return build_classification_dataloader(path=dataset_path, batch_size=self.args.batch_size, rank=rank)
+       return build_classification_dataloader(path=dataset_path,
+                                              imgsz=self.args.img_size,
+                                              batch_size=self.args.batch_size,
+                                              rank=rank)

+   def preprocess_batch(self, batch):
+       batch["img"] = batch["img"].to(self.device)
+       batch["cls"] = batch["cls"].to(self.device)
+       return batch

    def get_validator(self):
        return v8.classify.ClassificationValidator(self.test_loader, self.device, logger=self.console)

-   def criterion(self, preds, targets):
-       return torch.nn.functional.cross_entropy(preds, targets)
+   def criterion(self, preds, batch):
+       loss = torch.nn.functional.cross_entropy(preds, batch["cls"])
+       return loss, loss
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name) @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)

@@ -5,10 +5,16 @@ from ultralytics.yolo.engine.validator import BaseValidator

class ClassificationValidator(BaseValidator):

-   def init_metrics(self):
+   def init_metrics(self, model):
        self.correct = torch.tensor([])

+   def preprocess_batch(self, batch):
+       batch["img"] = batch["img"].to(self.device)
+       batch["cls"] = batch["cls"].to(self.device)
+       return batch

-   def update_metrics(self, preds, targets):
+   def update_metrics(self, preds, batch):
+       targets = batch["cls"]
        correct_in_batch = (targets[:, None] == preds).float()
        self.correct = torch.cat((self.correct, correct_in_batch))

@@ -0,0 +1,48 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.25 # layer channel multiple
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
# YOLOv5 v6.0 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 3, C3, [1024]],
[-1, 1, SPPF, [1024, 5]], # 9
]
# YOLOv5 v6.0 head
head:
[[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 13
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 17 (P3/8-small)
[-1, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)
[-1, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)
[[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5)
]

@@ -0,0 +1,48 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.25 # layer channel multiple
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
# YOLOv5 v6.0 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 3, C3, [1024]],
[-1, 1, SPPF, [1024, 5]], # 9
]
# YOLOv5 v6.0 head
head:
[[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 13
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 17 (P3/8-small)
[-1, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)
[-1, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)
[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]

@@ -0,0 +1,2 @@
from ultralytics.yolo.v8.segment.train import SegmentationTrainer
from ultralytics.yolo.v8.segment.val import SegmentationValidator

@@ -0,0 +1,269 @@
import subprocess
import time
from pathlib import Path
import hydra
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.yolo import v8
from ultralytics.yolo.data import build_dataloader
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
from ultralytics.yolo.utils.downloads import download
from ultralytics.yolo.utils.files import WorkingDirectory
from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
from ultralytics.yolo.utils.modeling.tasks import SegmentationModel
from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy
from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, de_parallel, torch_distributed_zero_first
# BaseTrainer python usage
class SegmentationTrainer(BaseTrainer):
def get_dataset(self, dataset):
# temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module
data = Path("datasets") / dataset
with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(Path.cwd()):
data_dir = data if data.is_dir() else (Path.cwd() / data)
if not data_dir.is_dir():
self.console.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
t = time.time()
if str(data) == 'imagenet':
subprocess.run(f"bash {v8.ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
else:
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
download(url, dir=data_dir.parent)
# TODO: add colorstr
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {'bold', data_dir}\n"
self.console.info(s)
train_set = data_dir.parent / "coco128-seg"
test_set = train_set
return train_set, test_set
def get_dataloader(self, dataset_path, batch_size, rank=0):
# TODO: manage splits differently
# calculate stride - check if model is initialized
gs = max(int(self.model.stride.max() if self.model else 0), 32)
loader = build_dataloader(
img_path=dataset_path,
img_size=self.args.img_size,
batch_size=batch_size,
single_cls=self.args.single_cls,
cache=self.args.cache,
image_weights=self.args.image_weights,
stride=gs,
rect=self.args.rect,
rank=rank,
workers=self.args.workers,
shuffle=self.args.shuffle,
use_segments=True,
)[0]
return loader
def preprocess_batch(self, batch):
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
return batch
def load_cfg(self, cfg):
return SegmentationModel(cfg, nc=80)
def get_validator(self):
return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console)
def criterion(self, preds, batch):
head = de_parallel(self.model).model[-1]
sort_obj_iou = False
autobalance = False
# init losses
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.args.cls_pw], device=self.device))
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.args.obj_pw], device=self.device))
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
cp, cn = smooth_BCE(eps=self.args.label_smoothing) # positive, negative BCE targets
# Focal loss
g = self.args.fl_gamma
if self.args.fl_gamma > 0:
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
balance = {3: [4.0, 1.0, 0.4]}.get(head.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
ssi = list(head.stride).index(16) if autobalance else 0 # stride 16 index
BCEcls, BCEobj, gr, autobalance = BCEcls, BCEobj, 1.0, autobalance
def single_mask_loss(gt_mask, pred, proto, xyxy, area):
# Mask loss for one image
pred_mask = (pred @ proto.view(head.nm, -1)).view(-1, *proto.shape[1:]) # (n,32) @ (32,80,80) -> (n,80,80)
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
def build_targets(p, targets):
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
nonlocal head
na, nt = head.na, targets.shape[0] # number of anchors, targets
tcls, tbox, indices, anch, tidxs, xywhn = [], [], [], [], [], []
gain = torch.ones(8, device=self.device) # normalized to gridspace gain
ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1,
nt) # same as .repeat_interleave(nt)
if self.args.overlap_mask:
batch = p[0].shape[0]
ti = []
for i in range(batch):
num = (targets[:, 0] == i).sum() # find number of targets of each image
ti.append(torch.arange(num, device=self.device).float().view(1, num).repeat(na, 1) + 1) # (na, num)
ti = torch.cat(ti, 1) # (na, nt)
else:
ti = torch.arange(nt, device=self.device).float().view(1, nt).repeat(na, 1)
targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None], ti[..., None]), 2) # append anchor indices
g = 0.5 # bias
off = torch.tensor(
[
[0, 0],
[1, 0],
[0, 1],
[-1, 0],
[0, -1], # j,k,l,m
# [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
],
device=self.device).float() * g # offsets
for i in range(head.nl):
anchors, shape = head.anchors[i], p[i].shape
gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]] # xyxy gain
# Match targets to anchors
t = targets * gain # shape(3,n,7)
if nt:
# Matches
r = t[..., 4:6] / anchors[:, None] # wh ratio
j = torch.max(r, 1 / r).max(2)[0] < self.args.anchor_t # compare
# j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
t = t[j] # filter
# Offsets
gxy = t[:, 2:4] # grid xy
gxi = gain[[2, 3]] - gxy # inverse
j, k = ((gxy % 1 < g) & (gxy > 1)).T
l, m = ((gxi % 1 < g) & (gxi > 1)).T
j = torch.stack((torch.ones_like(j), j, k, l, m))
t = t.repeat((5, 1, 1))[j]
offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
else:
t = targets[0]
offsets = 0
# Define
bc, gxy, gwh, at = t.chunk(4, 1) # (image, class), grid xy, grid wh, anchors
(a, tidx), (b, c) = at.long().T, bc.long().T # anchors, image, class
gij = (gxy - offsets).long()
gi, gj = gij.T # grid indices
# Append
indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1))) # image, anchor, grid
tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
anch.append(anchors[a]) # anchors
tcls.append(c) # class
tidxs.append(tidx)
xywhn.append(torch.cat((gxy, gwh), 1) / gain[2:6]) # xywh normalized
return tcls, tbox, indices, anch, tidxs, xywhn
if self.model.training:
p, proto, = preds
else:
p, proto, train_out = preds
p = train_out
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
masks = batch["masks"]
targets, masks = targets.to(self.device), masks.to(self.device).float()
bs, nm, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
lcls = torch.zeros(1, device=self.device)
lbox = torch.zeros(1, device=self.device)
lobj = torch.zeros(1, device=self.device)
lseg = torch.zeros(1, device=self.device)
tcls, tbox, indices, anchors, tidxs, xywhn = build_targets(p, targets)
# Losses
for i, pi in enumerate(p): # layer index, layer predictions
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device) # target obj
n = b.shape[0] # number of targets
if n:
pxy, pwh, _, pcls, pmask = pi[b, a, gj, gi].split((2, 2, 1, head.nc, nm), 1) # subset of predictions
# Box regression
pxy = pxy.sigmoid() * 2 - 0.5
pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]
pbox = torch.cat((pxy, pwh), 1) # predicted box
iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze() # iou(prediction, target)
lbox += (1.0 - iou).mean() # iou loss
# Objectness
iou = iou.detach().clamp(0).type(tobj.dtype)
if sort_obj_iou:
j = iou.argsort()
b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j]
if gr < 1:
iou = (1.0 - gr) + gr * iou
tobj[b, a, gj, gi] = iou # iou ratio
# Classification
if head.nc > 1: # cls loss (only if multiple classes)
t = torch.full_like(pcls, cn, device=self.device) # targets
t[range(n), tcls[i]] = cp
lcls += BCEcls(pcls, t) # BCE
# Mask regression
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
marea = xywhn[i][:, 2:].prod(1) # mask width, height normalized
mxyxy = xywh2xyxy(xywhn[i] * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device))
for bi in b.unique():
j = b == bi # matching index
if self.args.overlap_mask:
mask_gti = torch.where(masks[bi][None] == tidxs[i][j].view(-1, 1, 1), 1.0, 0.0)
else:
mask_gti = masks[tidxs[i]][j]
lseg += single_mask_loss(mask_gti, pmask[j], proto[bi], mxyxy[j], marea[j])
obji = BCEobj(pi[..., 4], tobj)
lobj += obji * balance[i] # obj loss
if autobalance:
balance[i] = balance[i] * 0.9999 + 0.0001 / obji.detach().item()
if autobalance:
balance = [x / balance[ssi] for x in balance]
lbox *= self.args.box
lobj *= self.args.obj
lcls *= self.args.cls
lseg *= self.args.box / bs
loss = lbox + lobj + lcls + lseg
return loss * bs, torch.cat((lbox, lseg, lobj, lcls)).detach()
def progress_string(self):
return ('\n' + '%11s' * 7) % \
('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss', 'Size')
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
def train(cfg):
cfg.cfg = cfg.cfg or v8.ROOT / "models/yolov5n-seg.yaml"
cfg.data = cfg.data or "coco128-segments" # or yolo.ClassificationDataset("mnist")
trainer = SegmentationTrainer(cfg)
trainer.train()
if __name__ == "__main__":
"""
CLI usage:
python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-segments epochs=100 img_size=640
TODO:
Direct CLI support, i.e. yolov8 classify_train args.epochs 10
"""
train()

@@ -0,0 +1,211 @@
import os
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import ops
from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.yolo.utils.metrics import (ConfusionMatrix, Metrics, ap_per_class_box_and_mask, box_iou,
fitness_segmentation, mask_iou)
from ultralytics.yolo.utils.modeling import yaml_load
from ultralytics.yolo.utils.torch_utils import de_parallel
class SegmentationValidator(BaseValidator):
def __init__(self, dataloader, pbar=None, logger=None, args=None):
super().__init__(dataloader, pbar, logger, args)
if self.args.save_json:
check_requirements(['pycocotools'])
self.process = ops.process_mask_upsample # more accurate
else:
self.process = ops.process_mask # faster
self.data_dict = yaml_load(self.args.data) if self.args.data else None
self.is_coco = False
self.class_map = None
self.targets = None
def preprocess_batch(self, batch):
batch["img"] = batch["img"].to(self.device, non_blocking=True)
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 225
batch["bboxes"] = batch["bboxes"].to(self.device)
batch["masks"] = batch["masks"].to(self.device).float()
self.nb, _, self.height, self.width = batch["img"].shape # batch size, channels, height, width
self.targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
self.lb = [self.targets[self.targets[:, 0] == i, 1:]
for i in range(self.nb)] if self.args.save_hybrid else [] # for autolabelling
return batch
def init_metrics(self, model):
head = de_parallel(model).model[-1]
if self.data_dict:
self.is_coco = isinstance(self.data_dict.get('val'),
str) and self.data_dict['val'].endswith(f'coco{os.sep}val2017.txt')
self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
self.nc = head.nc
self.nm = head.nm
self.names = model.names
if isinstance(self.names, (list, tuple)): # old format
self.names = dict(enumerate(self.names))
self.iouv = torch.linspace(0.5, 0.95, 10, device=self.device) # iou vector for mAP@0.5:0.95
self.niou = self.iouv.numel()
self.seen = 0
self.confusion_matrix = ConfusionMatrix(nc=self.nc)
self.metrics = Metrics()
self.loss = torch.zeros(4, device=self.device)
self.jdict = []
self.stats = []
def get_desc(self):
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P",
"R", "mAP50", "mAP50-95)")
def preprocess_preds(self, preds):
p = ops.non_max_suppression(preds[0],
self.args.conf_thres,
self.args.iou_thres,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
nm=self.nm)
return (p, preds[0], preds[2])
def update_metrics(self, preds, batch):
# Metrics
plot_masks = [] # masks for plotting
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
labels = self.targets[self.targets[:, 0] == si, 1:]
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
path = Path(batch["im_file"][si])
shape = batch["shape"][si][0]
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
self.seen += 1
if npr == 0:
if nl:
self.stats.append((correct_masks, correct_bboxes, *torch.zeros(
(2, 0), device=self.device), labels[:, 0]))
if self.args.plots:
self.confusion_matrix.process_batch(detections=None, labels=labels[:, 0])
continue
# Masks
midx = [si] if self.args.overlap_mask else self.targets[:, 0] == si
gt_masks = batch["masks"][midx]
pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch["img"][si].shape[1:])
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
predn = pred.clone()
ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape, batch["shape"][si][1]) # native-space pred
# Evaluate
if nl:
tbox = ops.xywh2xyxy(labels[:, 1:5]) # target boxes
ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape, batch["shapes"][si][1]) # native-space labels
labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels
correct_bboxes = self._process_batch(predn, labelsn, self.iouv)
correct_masks = self._process_batch(predn, labelsn, self.iouv, pred_masks, gt_masks, masks=True)
if self.args.plots:
self.confusion_matrix.process_batch(predn, labelsn)
self.stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], labels[:,
0])) # (conf, pcls, tcls)
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
if self.args.plots and self.batch_i < 3:
plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
# TODO: Save/log
'''
if self.args.save_txt:
save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
if self.args.save_json:
pred_masks = scale_image(im[si].shape[1:],
pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), shape, shapes[si][1])
save_one_json(predn, jdict, path, class_map, pred_masks) # append to COCO-JSON dictionary
# callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
'''
# TODO Plot images
'''
if self.args.plots and self.batch_i < 3:
if len(plot_masks):
plot_masks = torch.cat(plot_masks, dim=0)
plot_images_and_masks(im, targets, masks, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names)
plot_images_and_masks(im, output_to_target(preds, max_det=15), plot_masks, paths,
save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred
'''
def get_stats(self):
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
if len(stats) and stats[0].any():
# TODO: save_dir
results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir='', names=self.names)
self.metrics.update(results)
self.nt_per_class = np.bincount(stats[4].astype(int), minlength=self.nc) # number of targets per class
keys = ["mp_bbox", "mr_bbox", "map50_bbox", "map_bbox", "mp_mask", "mr_mask", "map50_mask", "map_mask"]
metrics = {"fitness": fitness_segmentation(np.array(self.metrics.mean_results()).reshape(1, -1))}
metrics |= zip(keys, self.metrics.mean_results())
return metrics
def print_results(self):
pf = '%22s' + '%11i' * 2 + '%11.3g' * 8 # print format
self.logger.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
if self.nt_per_class.sum() == 0:
self.logger.warning(
f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')
# Print results per class
if (self.args.verbose or (self.nc < 50 and not self.training)) and self.nc > 1 and len(self.stats):
for i, c in enumerate(self.metrics.ap_class_index):
self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
# plot TODO: save_dir
if self.args.plots:
self.confusion_matrix.plot(save_dir='', names=list(self.names.values()))
def _process_batch(self, detections, labels, iouv, pred_masks=None, gt_masks=None, overlap=False, masks=False):
"""
Return correct prediction matrix
Arguments:
detections (array[N, 6]), x1, y1, x2, y2, conf, class
labels (array[M, 5]), class, x1, y1, x2, y2
Returns:
correct (array[N, 10]), for 10 IoU levels
"""
if masks:
if overlap:
nl = len(labels)
index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
if gt_masks.shape[1:] != pred_masks.shape[1:]:
gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
gt_masks = gt_masks.gt_(0.5)
iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
else: # boxes
iou = box_iou(labels[:, 1:], detections[:, :4])
correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool)
correct_class = labels[:, 0:1] == detections[:, 5]
for i in range(len(iouv)):
x = torch.where((iou >= iouv[i]) & correct_class) # IoU > threshold and classes match
if x[0].shape[0]:
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
1).cpu().numpy() # [label, detect, iou]
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
# matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
correct[matches[:, 1].astype(int), i] = True
return torch.tensor(correct, dtype=torch.bool, device=iouv.device)

@@ -0,0 +1,48 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.25 # layer channel multiple
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
# YOLOv5 v6.0 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 3, C3, [1024]],
[-1, 1, SPPF, [1024, 5]], # 9
]
# YOLOv5 v6.0 head
head:
[[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 13
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 17 (P3/8-small)
[-1, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)
[-1, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)
[[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5)
]