Segmentation support & other enchancements (#40)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>single_channel
parent
c617ee1c79
commit
f56c9bcc26
@ -1,7 +1,7 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from ultralytics.yolo.v8 import classify
|
from ultralytics.yolo.v8 import classify, segment
|
||||||
|
|
||||||
ROOT = Path(__file__).parents[0] # yolov8 ROOT
|
ROOT = Path(__file__).parents[0] # yolov8 ROOT
|
||||||
|
|
||||||
__all__ = ["classify"]
|
__all__ = ["classify", "segment"]
|
||||||
|
@ -0,0 +1,48 @@
|
|||||||
|
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
nc: 80 # number of classes
|
||||||
|
depth_multiple: 0.33 # model depth multiple
|
||||||
|
width_multiple: 0.25 # layer channel multiple
|
||||||
|
anchors:
|
||||||
|
- [10,13, 16,30, 33,23] # P3/8
|
||||||
|
- [30,61, 62,45, 59,119] # P4/16
|
||||||
|
- [116,90, 156,198, 373,326] # P5/32
|
||||||
|
|
||||||
|
# YOLOv5 v6.0 backbone
|
||||||
|
backbone:
|
||||||
|
# [from, number, module, args]
|
||||||
|
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
|
||||||
|
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
|
||||||
|
[-1, 3, C3, [128]],
|
||||||
|
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
|
||||||
|
[-1, 6, C3, [256]],
|
||||||
|
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
|
||||||
|
[-1, 9, C3, [512]],
|
||||||
|
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
|
||||||
|
[-1, 3, C3, [1024]],
|
||||||
|
[-1, 1, SPPF, [1024, 5]], # 9
|
||||||
|
]
|
||||||
|
|
||||||
|
# YOLOv5 v6.0 head
|
||||||
|
head:
|
||||||
|
[[-1, 1, Conv, [512, 1, 1]],
|
||||||
|
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
||||||
|
[[-1, 6], 1, Concat, [1]], # cat backbone P4
|
||||||
|
[-1, 3, C3, [512, False]], # 13
|
||||||
|
|
||||||
|
[-1, 1, Conv, [256, 1, 1]],
|
||||||
|
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
||||||
|
[[-1, 4], 1, Concat, [1]], # cat backbone P3
|
||||||
|
[-1, 3, C3, [256, False]], # 17 (P3/8-small)
|
||||||
|
|
||||||
|
[-1, 1, Conv, [256, 3, 2]],
|
||||||
|
[[-1, 14], 1, Concat, [1]], # cat head P4
|
||||||
|
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)
|
||||||
|
|
||||||
|
[-1, 1, Conv, [512, 3, 2]],
|
||||||
|
[[-1, 10], 1, Concat, [1]], # cat head P5
|
||||||
|
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)
|
||||||
|
|
||||||
|
[[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5)
|
||||||
|
]
|
@ -0,0 +1,48 @@
|
|||||||
|
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
nc: 80 # number of classes
|
||||||
|
depth_multiple: 0.33 # model depth multiple
|
||||||
|
width_multiple: 0.25 # layer channel multiple
|
||||||
|
anchors:
|
||||||
|
- [10,13, 16,30, 33,23] # P3/8
|
||||||
|
- [30,61, 62,45, 59,119] # P4/16
|
||||||
|
- [116,90, 156,198, 373,326] # P5/32
|
||||||
|
|
||||||
|
# YOLOv5 v6.0 backbone
|
||||||
|
backbone:
|
||||||
|
# [from, number, module, args]
|
||||||
|
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
|
||||||
|
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
|
||||||
|
[-1, 3, C3, [128]],
|
||||||
|
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
|
||||||
|
[-1, 6, C3, [256]],
|
||||||
|
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
|
||||||
|
[-1, 9, C3, [512]],
|
||||||
|
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
|
||||||
|
[-1, 3, C3, [1024]],
|
||||||
|
[-1, 1, SPPF, [1024, 5]], # 9
|
||||||
|
]
|
||||||
|
|
||||||
|
# YOLOv5 v6.0 head
|
||||||
|
head:
|
||||||
|
[[-1, 1, Conv, [512, 1, 1]],
|
||||||
|
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
||||||
|
[[-1, 6], 1, Concat, [1]], # cat backbone P4
|
||||||
|
[-1, 3, C3, [512, False]], # 13
|
||||||
|
|
||||||
|
[-1, 1, Conv, [256, 1, 1]],
|
||||||
|
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
||||||
|
[[-1, 4], 1, Concat, [1]], # cat backbone P3
|
||||||
|
[-1, 3, C3, [256, False]], # 17 (P3/8-small)
|
||||||
|
|
||||||
|
[-1, 1, Conv, [256, 3, 2]],
|
||||||
|
[[-1, 14], 1, Concat, [1]], # cat head P4
|
||||||
|
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)
|
||||||
|
|
||||||
|
[-1, 1, Conv, [512, 3, 2]],
|
||||||
|
[[-1, 10], 1, Concat, [1]], # cat head P5
|
||||||
|
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)
|
||||||
|
|
||||||
|
[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
|
||||||
|
]
|
@ -0,0 +1,2 @@
|
|||||||
|
from ultralytics.yolo.v8.segment.train import SegmentationTrainer
|
||||||
|
from ultralytics.yolo.v8.segment.val import SegmentationValidator
|
@ -0,0 +1,269 @@
|
|||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import hydra
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from ultralytics.yolo import v8
|
||||||
|
from ultralytics.yolo.data import build_dataloader
|
||||||
|
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
|
||||||
|
from ultralytics.yolo.utils.downloads import download
|
||||||
|
from ultralytics.yolo.utils.files import WorkingDirectory
|
||||||
|
from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
|
||||||
|
from ultralytics.yolo.utils.modeling.tasks import SegmentationModel
|
||||||
|
from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy
|
||||||
|
from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, de_parallel, torch_distributed_zero_first
|
||||||
|
|
||||||
|
|
||||||
|
# BaseTrainer python usage
|
||||||
|
class SegmentationTrainer(BaseTrainer):
|
||||||
|
|
||||||
|
def get_dataset(self, dataset):
|
||||||
|
# temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module
|
||||||
|
data = Path("datasets") / dataset
|
||||||
|
with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(Path.cwd()):
|
||||||
|
data_dir = data if data.is_dir() else (Path.cwd() / data)
|
||||||
|
if not data_dir.is_dir():
|
||||||
|
self.console.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
|
||||||
|
t = time.time()
|
||||||
|
if str(data) == 'imagenet':
|
||||||
|
subprocess.run(f"bash {v8.ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
|
||||||
|
else:
|
||||||
|
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
|
||||||
|
download(url, dir=data_dir.parent)
|
||||||
|
# TODO: add colorstr
|
||||||
|
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {'bold', data_dir}\n"
|
||||||
|
self.console.info(s)
|
||||||
|
train_set = data_dir.parent / "coco128-seg"
|
||||||
|
test_set = train_set
|
||||||
|
return train_set, test_set
|
||||||
|
|
||||||
|
def get_dataloader(self, dataset_path, batch_size, rank=0):
|
||||||
|
# TODO: manage splits differently
|
||||||
|
# calculate stride - check if model is initialized
|
||||||
|
gs = max(int(self.model.stride.max() if self.model else 0), 32)
|
||||||
|
loader = build_dataloader(
|
||||||
|
img_path=dataset_path,
|
||||||
|
img_size=self.args.img_size,
|
||||||
|
batch_size=batch_size,
|
||||||
|
single_cls=self.args.single_cls,
|
||||||
|
cache=self.args.cache,
|
||||||
|
image_weights=self.args.image_weights,
|
||||||
|
stride=gs,
|
||||||
|
rect=self.args.rect,
|
||||||
|
rank=rank,
|
||||||
|
workers=self.args.workers,
|
||||||
|
shuffle=self.args.shuffle,
|
||||||
|
use_segments=True,
|
||||||
|
)[0]
|
||||||
|
return loader
|
||||||
|
|
||||||
|
def preprocess_batch(self, batch):
|
||||||
|
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def load_cfg(self, cfg):
|
||||||
|
return SegmentationModel(cfg, nc=80)
|
||||||
|
|
||||||
|
def get_validator(self):
|
||||||
|
return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console)
|
||||||
|
|
||||||
|
def criterion(self, preds, batch):
|
||||||
|
head = de_parallel(self.model).model[-1]
|
||||||
|
sort_obj_iou = False
|
||||||
|
autobalance = False
|
||||||
|
|
||||||
|
# init losses
|
||||||
|
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.args.cls_pw], device=self.device))
|
||||||
|
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.args.obj_pw], device=self.device))
|
||||||
|
|
||||||
|
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
|
||||||
|
cp, cn = smooth_BCE(eps=self.args.label_smoothing) # positive, negative BCE targets
|
||||||
|
|
||||||
|
# Focal loss
|
||||||
|
g = self.args.fl_gamma
|
||||||
|
if self.args.fl_gamma > 0:
|
||||||
|
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
|
||||||
|
|
||||||
|
balance = {3: [4.0, 1.0, 0.4]}.get(head.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
|
||||||
|
ssi = list(head.stride).index(16) if autobalance else 0 # stride 16 index
|
||||||
|
BCEcls, BCEobj, gr, autobalance = BCEcls, BCEobj, 1.0, autobalance
|
||||||
|
|
||||||
|
def single_mask_loss(gt_mask, pred, proto, xyxy, area):
|
||||||
|
# Mask loss for one image
|
||||||
|
pred_mask = (pred @ proto.view(head.nm, -1)).view(-1, *proto.shape[1:]) # (n,32) @ (32,80,80) -> (n,80,80)
|
||||||
|
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
|
||||||
|
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
|
||||||
|
|
||||||
|
def build_targets(p, targets):
|
||||||
|
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
|
||||||
|
nonlocal head
|
||||||
|
na, nt = head.na, targets.shape[0] # number of anchors, targets
|
||||||
|
tcls, tbox, indices, anch, tidxs, xywhn = [], [], [], [], [], []
|
||||||
|
gain = torch.ones(8, device=self.device) # normalized to gridspace gain
|
||||||
|
ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1,
|
||||||
|
nt) # same as .repeat_interleave(nt)
|
||||||
|
if self.args.overlap_mask:
|
||||||
|
batch = p[0].shape[0]
|
||||||
|
ti = []
|
||||||
|
for i in range(batch):
|
||||||
|
num = (targets[:, 0] == i).sum() # find number of targets of each image
|
||||||
|
ti.append(torch.arange(num, device=self.device).float().view(1, num).repeat(na, 1) + 1) # (na, num)
|
||||||
|
ti = torch.cat(ti, 1) # (na, nt)
|
||||||
|
else:
|
||||||
|
ti = torch.arange(nt, device=self.device).float().view(1, nt).repeat(na, 1)
|
||||||
|
targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None], ti[..., None]), 2) # append anchor indices
|
||||||
|
|
||||||
|
g = 0.5 # bias
|
||||||
|
off = torch.tensor(
|
||||||
|
[
|
||||||
|
[0, 0],
|
||||||
|
[1, 0],
|
||||||
|
[0, 1],
|
||||||
|
[-1, 0],
|
||||||
|
[0, -1], # j,k,l,m
|
||||||
|
# [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
|
||||||
|
],
|
||||||
|
device=self.device).float() * g # offsets
|
||||||
|
|
||||||
|
for i in range(head.nl):
|
||||||
|
anchors, shape = head.anchors[i], p[i].shape
|
||||||
|
gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]] # xyxy gain
|
||||||
|
|
||||||
|
# Match targets to anchors
|
||||||
|
t = targets * gain # shape(3,n,7)
|
||||||
|
if nt:
|
||||||
|
# Matches
|
||||||
|
r = t[..., 4:6] / anchors[:, None] # wh ratio
|
||||||
|
j = torch.max(r, 1 / r).max(2)[0] < self.args.anchor_t # compare
|
||||||
|
# j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
|
||||||
|
t = t[j] # filter
|
||||||
|
|
||||||
|
# Offsets
|
||||||
|
gxy = t[:, 2:4] # grid xy
|
||||||
|
gxi = gain[[2, 3]] - gxy # inverse
|
||||||
|
j, k = ((gxy % 1 < g) & (gxy > 1)).T
|
||||||
|
l, m = ((gxi % 1 < g) & (gxi > 1)).T
|
||||||
|
j = torch.stack((torch.ones_like(j), j, k, l, m))
|
||||||
|
t = t.repeat((5, 1, 1))[j]
|
||||||
|
offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
|
||||||
|
else:
|
||||||
|
t = targets[0]
|
||||||
|
offsets = 0
|
||||||
|
|
||||||
|
# Define
|
||||||
|
bc, gxy, gwh, at = t.chunk(4, 1) # (image, class), grid xy, grid wh, anchors
|
||||||
|
(a, tidx), (b, c) = at.long().T, bc.long().T # anchors, image, class
|
||||||
|
gij = (gxy - offsets).long()
|
||||||
|
gi, gj = gij.T # grid indices
|
||||||
|
|
||||||
|
# Append
|
||||||
|
indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1))) # image, anchor, grid
|
||||||
|
tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
|
||||||
|
anch.append(anchors[a]) # anchors
|
||||||
|
tcls.append(c) # class
|
||||||
|
tidxs.append(tidx)
|
||||||
|
xywhn.append(torch.cat((gxy, gwh), 1) / gain[2:6]) # xywh normalized
|
||||||
|
|
||||||
|
return tcls, tbox, indices, anch, tidxs, xywhn
|
||||||
|
|
||||||
|
if self.model.training:
|
||||||
|
p, proto, = preds
|
||||||
|
else:
|
||||||
|
p, proto, train_out = preds
|
||||||
|
p = train_out
|
||||||
|
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
||||||
|
masks = batch["masks"]
|
||||||
|
targets, masks = targets.to(self.device), masks.to(self.device).float()
|
||||||
|
|
||||||
|
bs, nm, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
|
||||||
|
lcls = torch.zeros(1, device=self.device)
|
||||||
|
lbox = torch.zeros(1, device=self.device)
|
||||||
|
lobj = torch.zeros(1, device=self.device)
|
||||||
|
lseg = torch.zeros(1, device=self.device)
|
||||||
|
tcls, tbox, indices, anchors, tidxs, xywhn = build_targets(p, targets)
|
||||||
|
|
||||||
|
# Losses
|
||||||
|
for i, pi in enumerate(p): # layer index, layer predictions
|
||||||
|
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
|
||||||
|
tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device) # target obj
|
||||||
|
|
||||||
|
n = b.shape[0] # number of targets
|
||||||
|
if n:
|
||||||
|
pxy, pwh, _, pcls, pmask = pi[b, a, gj, gi].split((2, 2, 1, head.nc, nm), 1) # subset of predictions
|
||||||
|
|
||||||
|
# Box regression
|
||||||
|
pxy = pxy.sigmoid() * 2 - 0.5
|
||||||
|
pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]
|
||||||
|
pbox = torch.cat((pxy, pwh), 1) # predicted box
|
||||||
|
iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze() # iou(prediction, target)
|
||||||
|
lbox += (1.0 - iou).mean() # iou loss
|
||||||
|
|
||||||
|
# Objectness
|
||||||
|
iou = iou.detach().clamp(0).type(tobj.dtype)
|
||||||
|
if sort_obj_iou:
|
||||||
|
j = iou.argsort()
|
||||||
|
b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j]
|
||||||
|
if gr < 1:
|
||||||
|
iou = (1.0 - gr) + gr * iou
|
||||||
|
tobj[b, a, gj, gi] = iou # iou ratio
|
||||||
|
|
||||||
|
# Classification
|
||||||
|
if head.nc > 1: # cls loss (only if multiple classes)
|
||||||
|
t = torch.full_like(pcls, cn, device=self.device) # targets
|
||||||
|
t[range(n), tcls[i]] = cp
|
||||||
|
lcls += BCEcls(pcls, t) # BCE
|
||||||
|
|
||||||
|
# Mask regression
|
||||||
|
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
|
||||||
|
masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
|
||||||
|
marea = xywhn[i][:, 2:].prod(1) # mask width, height normalized
|
||||||
|
mxyxy = xywh2xyxy(xywhn[i] * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device))
|
||||||
|
for bi in b.unique():
|
||||||
|
j = b == bi # matching index
|
||||||
|
if True:
|
||||||
|
mask_gti = torch.where(masks[bi][None] == tidxs[i][j].view(-1, 1, 1), 1.0, 0.0)
|
||||||
|
else:
|
||||||
|
mask_gti = masks[tidxs[i]][j]
|
||||||
|
lseg += single_mask_loss(mask_gti, pmask[j], proto[bi], mxyxy[j], marea[j])
|
||||||
|
|
||||||
|
obji = BCEobj(pi[..., 4], tobj)
|
||||||
|
lobj += obji * balance[i] # obj loss
|
||||||
|
if autobalance:
|
||||||
|
balance[i] = balance[i] * 0.9999 + 0.0001 / obji.detach().item()
|
||||||
|
|
||||||
|
if autobalance:
|
||||||
|
balance = [x / balance[ssi] for x in balance]
|
||||||
|
lbox *= self.args.box
|
||||||
|
lobj *= self.args.obj
|
||||||
|
lcls *= self.args.cls
|
||||||
|
lseg *= self.args.box / bs
|
||||||
|
|
||||||
|
loss = lbox + lobj + lcls + lseg
|
||||||
|
return loss * bs, torch.cat((lbox, lseg, lobj, lcls)).detach()
|
||||||
|
|
||||||
|
def progress_string(self):
|
||||||
|
return ('\n' + '%11s' * 7) % \
|
||||||
|
('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss', 'Size')
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||||
|
def train(cfg):
|
||||||
|
cfg.cfg = v8.ROOT / "models/yolov5n-seg.yaml"
|
||||||
|
cfg.data = cfg.data or "coco128-segments" # or yolo.ClassificationDataset("mnist")
|
||||||
|
trainer = SegmentationTrainer(cfg)
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
"""
|
||||||
|
CLI usage:
|
||||||
|
python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-segments epochs=100 img_size=640
|
||||||
|
|
||||||
|
TODO:
|
||||||
|
Direct cli support, i.e, yolov8 classify_train args.epochs 10
|
||||||
|
"""
|
||||||
|
train()
|
@ -0,0 +1,211 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from ultralytics.yolo.engine.validator import BaseValidator
|
||||||
|
from ultralytics.yolo.utils import ops
|
||||||
|
from ultralytics.yolo.utils.checks import check_requirements
|
||||||
|
from ultralytics.yolo.utils.metrics import (ConfusionMatrix, Metrics, ap_per_class_box_and_mask, box_iou,
|
||||||
|
fitness_segmentation, mask_iou)
|
||||||
|
from ultralytics.yolo.utils.modeling import yaml_load
|
||||||
|
from ultralytics.yolo.utils.torch_utils import de_parallel
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentationValidator(BaseValidator):
|
||||||
|
|
||||||
|
def __init__(self, dataloader, pbar=None, logger=None, args=None):
|
||||||
|
super().__init__(dataloader, pbar, logger, args)
|
||||||
|
if self.args.save_json:
|
||||||
|
check_requirements(['pycocotools'])
|
||||||
|
self.process = ops.process_mask_upsample # more accurate
|
||||||
|
else:
|
||||||
|
self.process = ops.process_mask # faster
|
||||||
|
self.data_dict = yaml_load(self.args.data) if self.args.data else None
|
||||||
|
self.is_coco = False
|
||||||
|
self.class_map = None
|
||||||
|
self.targets = None
|
||||||
|
|
||||||
|
def preprocess_batch(self, batch):
|
||||||
|
batch["img"] = batch["img"].to(self.device, non_blocking=True)
|
||||||
|
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 225
|
||||||
|
batch["bboxes"] = batch["bboxes"].to(self.device)
|
||||||
|
batch["masks"] = batch["masks"].to(self.device).float()
|
||||||
|
self.nb, _, self.height, self.width = batch["img"].shape # batch size, channels, height, width
|
||||||
|
self.targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
||||||
|
self.lb = [self.targets[self.targets[:, 0] == i, 1:]
|
||||||
|
for i in range(self.nb)] if self.args.save_hybrid else [] # for autolabelling
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def init_metrics(self, model):
|
||||||
|
head = de_parallel(model).model[-1]
|
||||||
|
if self.data_dict:
|
||||||
|
self.is_coco = isinstance(self.data_dict.get('val'),
|
||||||
|
str) and self.data_dict['val'].endswith(f'coco{os.sep}val2017.txt')
|
||||||
|
self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
|
||||||
|
|
||||||
|
self.nc = head.nc
|
||||||
|
self.nm = head.nm
|
||||||
|
self.names = model.names
|
||||||
|
if isinstance(self.names, (list, tuple)): # old format
|
||||||
|
self.names = dict(enumerate(self.names))
|
||||||
|
|
||||||
|
self.iouv = torch.linspace(0.5, 0.95, 10, device=self.device) # iou vector for mAP@0.5:0.95
|
||||||
|
self.niou = self.iouv.numel()
|
||||||
|
self.seen = 0
|
||||||
|
self.confusion_matrix = ConfusionMatrix(nc=self.nc)
|
||||||
|
self.metrics = Metrics()
|
||||||
|
self.loss = torch.zeros(4, device=self.device)
|
||||||
|
self.jdict = []
|
||||||
|
self.stats = []
|
||||||
|
|
||||||
|
def get_desc(self):
|
||||||
|
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P",
|
||||||
|
"R", "mAP50", "mAP50-95)")
|
||||||
|
|
||||||
|
def preprocess_preds(self, preds):
|
||||||
|
p = ops.non_max_suppression(preds[0],
|
||||||
|
self.args.conf_thres,
|
||||||
|
self.args.iou_thres,
|
||||||
|
labels=self.lb,
|
||||||
|
multi_label=True,
|
||||||
|
agnostic=self.args.single_cls,
|
||||||
|
max_det=self.args.max_det,
|
||||||
|
nm=self.nm)
|
||||||
|
return (p, preds[0], preds[2])
|
||||||
|
|
||||||
|
def update_metrics(self, preds, batch):
|
||||||
|
# Metrics
|
||||||
|
plot_masks = [] # masks for plotting
|
||||||
|
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
|
||||||
|
labels = self.targets[self.targets[:, 0] == si, 1:]
|
||||||
|
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
|
||||||
|
shape = Path(batch["im_file"][si])
|
||||||
|
# path = batch["shape"][si][0]
|
||||||
|
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||||
|
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||||
|
self.seen += 1
|
||||||
|
|
||||||
|
if npr == 0:
|
||||||
|
if nl:
|
||||||
|
self.stats.append((correct_masks, correct_bboxes, *torch.zeros(
|
||||||
|
(2, 0), device=self.device), labels[:, 0]))
|
||||||
|
if self.args.plots:
|
||||||
|
self.confusion_matrix.process_batch(detections=None, labels=labels[:, 0])
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Masks
|
||||||
|
midx = [si] if self.args.overlap_mask else self.targets[:, 0] == si
|
||||||
|
gt_masks = batch["masks"][midx]
|
||||||
|
pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch["img"][si].shape[1:])
|
||||||
|
|
||||||
|
# Predictions
|
||||||
|
if self.args.single_cls:
|
||||||
|
pred[:, 5] = 0
|
||||||
|
predn = pred.clone()
|
||||||
|
ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape, batch["shape"][si][1]) # native-space pred
|
||||||
|
|
||||||
|
# Evaluate
|
||||||
|
if nl:
|
||||||
|
tbox = ops.xywh2xyxy(labels[:, 1:5]) # target boxes
|
||||||
|
ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape, batch["shapes"][si][1]) # native-space labels
|
||||||
|
labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels
|
||||||
|
correct_bboxes = self._process_batch(predn, labelsn, self.iouv)
|
||||||
|
correct_masks = self._process_batch(predn, labelsn, self.iouv, pred_masks, gt_masks, masks=True)
|
||||||
|
if self.args.plots:
|
||||||
|
self.confusion_matrix.process_batch(predn, labelsn)
|
||||||
|
self.stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], labels[:,
|
||||||
|
0])) # (conf, pcls, tcls)
|
||||||
|
|
||||||
|
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
|
||||||
|
if self.plots and self.batch_i < 3:
|
||||||
|
plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
|
||||||
|
|
||||||
|
# TODO: Save/log
|
||||||
|
'''
|
||||||
|
if self.args.save_txt:
|
||||||
|
save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
|
||||||
|
if self.args.save_json:
|
||||||
|
pred_masks = scale_image(im[si].shape[1:],
|
||||||
|
pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), shape, shapes[si][1])
|
||||||
|
save_one_json(predn, jdict, path, class_map, pred_masks) # append to COCO-JSON dictionary
|
||||||
|
# callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
|
||||||
|
'''
|
||||||
|
|
||||||
|
# TODO Plot images
|
||||||
|
'''
|
||||||
|
if self.args.plots and self.batch_i < 3:
|
||||||
|
if len(plot_masks):
|
||||||
|
plot_masks = torch.cat(plot_masks, dim=0)
|
||||||
|
plot_images_and_masks(im, targets, masks, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names)
|
||||||
|
plot_images_and_masks(im, output_to_target(preds, max_det=15), plot_masks, paths,
|
||||||
|
save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred
|
||||||
|
'''
|
||||||
|
|
||||||
|
def get_stats(self):
|
||||||
|
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
|
||||||
|
if len(stats) and stats[0].any():
|
||||||
|
# TODO: save_dir
|
||||||
|
results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir='', names=self.names)
|
||||||
|
self.metrics.update(results)
|
||||||
|
self.nt_per_class = np.bincount(stats[4].astype(int), minlength=self.nc) # number of targets per class
|
||||||
|
keys = ["mp_bbox", "mr_bbox", "map50_bbox", "map_bbox", "mp_mask", "mr_mask", "map50_mask", "map_mask"]
|
||||||
|
metrics = {"fitness": fitness_segmentation(np.array(self.metrics.mean_results()).reshape(1, -1))}
|
||||||
|
metrics |= zip(keys, self.metrics.mean_results())
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def print_results(self):
|
||||||
|
pf = '%22s' + '%11i' * 2 + '%11.3g' * 8 # print format
|
||||||
|
self.logger.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
|
||||||
|
if self.nt_per_class.sum() == 0:
|
||||||
|
self.logger.warning(
|
||||||
|
f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')
|
||||||
|
|
||||||
|
# Print results per class
|
||||||
|
if (self.args.verbose or (self.nc < 50 and not self.training)) and self.nc > 1 and len(self.stats):
|
||||||
|
for i, c in enumerate(self.metrics.ap_class_index):
|
||||||
|
self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
|
||||||
|
|
||||||
|
# plot TODO: save_dir
|
||||||
|
if self.args.plots:
|
||||||
|
self.confusion_matrix.plot(save_dir='', names=list(self.names.values()))
|
||||||
|
|
||||||
|
def _process_batch(self, detections, labels, iouv, pred_masks=None, gt_masks=None, overlap=False, masks=False):
|
||||||
|
"""
|
||||||
|
Return correct prediction matrix
|
||||||
|
Arguments:
|
||||||
|
detections (array[N, 6]), x1, y1, x2, y2, conf, class
|
||||||
|
labels (array[M, 5]), class, x1, y1, x2, y2
|
||||||
|
Returns:
|
||||||
|
correct (array[N, 10]), for 10 IoU levels
|
||||||
|
"""
|
||||||
|
if masks:
|
||||||
|
if overlap:
|
||||||
|
nl = len(labels)
|
||||||
|
index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
|
||||||
|
gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
|
||||||
|
gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
|
||||||
|
if gt_masks.shape[1:] != pred_masks.shape[1:]:
|
||||||
|
gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
|
||||||
|
gt_masks = gt_masks.gt_(0.5)
|
||||||
|
iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
|
||||||
|
else: # boxes
|
||||||
|
iou = box_iou(labels[:, 1:], detections[:, :4])
|
||||||
|
|
||||||
|
correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool)
|
||||||
|
correct_class = labels[:, 0:1] == detections[:, 5]
|
||||||
|
for i in range(len(iouv)):
|
||||||
|
x = torch.where((iou >= iouv[i]) & correct_class) # IoU > threshold and classes match
|
||||||
|
if x[0].shape[0]:
|
||||||
|
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
|
||||||
|
1).cpu().numpy() # [label, detect, iou]
|
||||||
|
if x[0].shape[0] > 1:
|
||||||
|
matches = matches[matches[:, 2].argsort()[::-1]]
|
||||||
|
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
||||||
|
# matches = matches[matches[:, 2].argsort()[::-1]]
|
||||||
|
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
||||||
|
correct[matches[:, 1].astype(int), i] = True
|
||||||
|
return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
|
@ -0,0 +1,48 @@
|
|||||||
|
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
nc: 80 # number of classes
|
||||||
|
depth_multiple: 0.33 # model depth multiple
|
||||||
|
width_multiple: 0.25 # layer channel multiple
|
||||||
|
anchors:
|
||||||
|
- [10,13, 16,30, 33,23] # P3/8
|
||||||
|
- [30,61, 62,45, 59,119] # P4/16
|
||||||
|
- [116,90, 156,198, 373,326] # P5/32
|
||||||
|
|
||||||
|
# YOLOv5 v6.0 backbone
|
||||||
|
backbone:
|
||||||
|
# [from, number, module, args]
|
||||||
|
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
|
||||||
|
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
|
||||||
|
[-1, 3, C3, [128]],
|
||||||
|
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
|
||||||
|
[-1, 6, C3, [256]],
|
||||||
|
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
|
||||||
|
[-1, 9, C3, [512]],
|
||||||
|
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
|
||||||
|
[-1, 3, C3, [1024]],
|
||||||
|
[-1, 1, SPPF, [1024, 5]], # 9
|
||||||
|
]
|
||||||
|
|
||||||
|
# YOLOv5 v6.0 head
|
||||||
|
head:
|
||||||
|
[[-1, 1, Conv, [512, 1, 1]],
|
||||||
|
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
||||||
|
[[-1, 6], 1, Concat, [1]], # cat backbone P4
|
||||||
|
[-1, 3, C3, [512, False]], # 13
|
||||||
|
|
||||||
|
[-1, 1, Conv, [256, 1, 1]],
|
||||||
|
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
||||||
|
[[-1, 4], 1, Concat, [1]], # cat backbone P3
|
||||||
|
[-1, 3, C3, [256, False]], # 17 (P3/8-small)
|
||||||
|
|
||||||
|
[-1, 1, Conv, [256, 3, 2]],
|
||||||
|
[[-1, 14], 1, Concat, [1]], # cat head P4
|
||||||
|
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)
|
||||||
|
|
||||||
|
[-1, 1, Conv, [512, 3, 2]],
|
||||||
|
[[-1, 10], 1, Concat, [1]], # cat head P5
|
||||||
|
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)
|
||||||
|
|
||||||
|
[[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5)
|
||||||
|
]
|
Loading…
Reference in new issue