|
|
@ -6,9 +6,10 @@ import torch.nn.functional as F
|
|
|
|
from ultralytics.nn.tasks import SegmentationModel
|
|
|
|
from ultralytics.nn.tasks import SegmentationModel
|
|
|
|
from ultralytics.yolo import v8
|
|
|
|
from ultralytics.yolo import v8
|
|
|
|
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
|
|
|
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
|
|
|
from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
|
|
|
|
from ultralytics.yolo.utils.loss import BboxLoss
|
|
|
|
from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy
|
|
|
|
from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
|
|
|
|
from ultralytics.yolo.utils.plotting import plot_images, plot_results
|
|
|
|
from ultralytics.yolo.utils.plotting import plot_images, plot_results
|
|
|
|
|
|
|
|
from ultralytics.yolo.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors
|
|
|
|
from ultralytics.yolo.utils.torch_utils import de_parallel
|
|
|
|
from ultralytics.yolo.utils.torch_utils import de_parallel
|
|
|
|
|
|
|
|
|
|
|
|
from ..detect import DetectionTrainer
|
|
|
|
from ..detect import DetectionTrainer
|
|
|
@ -31,188 +32,9 @@ class SegmentationTrainer(DetectionTrainer):
|
|
|
|
args=self.args)
|
|
|
|
args=self.args)
|
|
|
|
|
|
|
|
|
|
|
|
def criterion(self, preds, batch):
|
|
|
|
def criterion(self, preds, batch):
|
|
|
|
head = de_parallel(self.model).model[-1]
|
|
|
|
if not hasattr(self, 'compute_loss'):
|
|
|
|
sort_obj_iou = False
|
|
|
|
self.compute_loss = SegLoss(de_parallel(self.model), overlap=self.args.overlap_mask)
|
|
|
|
autobalance = False
|
|
|
|
return self.compute_loss(preds, batch)
|
|
|
|
|
|
|
|
|
|
|
|
# init losses
|
|
|
|
|
|
|
|
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.args.cls_pw], device=self.device))
|
|
|
|
|
|
|
|
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.args.obj_pw], device=self.device))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
|
|
|
|
|
|
|
|
cp, cn = smooth_BCE(eps=self.args.label_smoothing) # positive, negative BCE targets
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Focal loss
|
|
|
|
|
|
|
|
g = self.args.fl_gamma
|
|
|
|
|
|
|
|
if self.args.fl_gamma > 0:
|
|
|
|
|
|
|
|
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
balance = {3: [4.0, 1.0, 0.4]}.get(head.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
|
|
|
|
|
|
|
|
ssi = list(head.stride).index(16) if autobalance else 0 # stride 16 index
|
|
|
|
|
|
|
|
BCEcls, BCEobj, gr, autobalance = BCEcls, BCEobj, 1.0, autobalance
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def single_mask_loss(gt_mask, pred, proto, xyxy, area):
|
|
|
|
|
|
|
|
# Mask loss for one image
|
|
|
|
|
|
|
|
pred_mask = (pred @ proto.view(head.nm, -1)).view(-1, *proto.shape[1:]) # (n,32) @ (32,80,80) -> (n,80,80)
|
|
|
|
|
|
|
|
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
|
|
|
|
|
|
|
|
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_targets(p, targets):
|
|
|
|
|
|
|
|
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
|
|
|
|
|
|
|
|
nonlocal head
|
|
|
|
|
|
|
|
na, nt = head.na, targets.shape[0] # number of anchors, targets
|
|
|
|
|
|
|
|
tcls, tbox, indices, anch, tidxs, xywhn = [], [], [], [], [], []
|
|
|
|
|
|
|
|
gain = torch.ones(8, device=self.device) # normalized to gridspace gain
|
|
|
|
|
|
|
|
ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1,
|
|
|
|
|
|
|
|
nt) # same as .repeat_interleave(nt)
|
|
|
|
|
|
|
|
if self.args.overlap_mask:
|
|
|
|
|
|
|
|
batch = p[0].shape[0]
|
|
|
|
|
|
|
|
ti = []
|
|
|
|
|
|
|
|
for i in range(batch):
|
|
|
|
|
|
|
|
num = (targets[:, 0] == i).sum() # find number of targets of each image
|
|
|
|
|
|
|
|
ti.append(torch.arange(num, device=self.device).float().view(1, num).repeat(na, 1) + 1) # (na, num)
|
|
|
|
|
|
|
|
ti = torch.cat(ti, 1) # (na, nt)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
ti = torch.arange(nt, device=self.device).float().view(1, nt).repeat(na, 1)
|
|
|
|
|
|
|
|
targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None], ti[..., None]), 2) # append anchor indices
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
g = 0.5 # bias
|
|
|
|
|
|
|
|
off = torch.tensor(
|
|
|
|
|
|
|
|
[
|
|
|
|
|
|
|
|
[0, 0],
|
|
|
|
|
|
|
|
[1, 0],
|
|
|
|
|
|
|
|
[0, 1],
|
|
|
|
|
|
|
|
[-1, 0],
|
|
|
|
|
|
|
|
[0, -1], # j,k,l,m
|
|
|
|
|
|
|
|
# [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
|
|
|
|
|
|
|
|
],
|
|
|
|
|
|
|
|
device=self.device).float() * g # offsets
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(head.nl):
|
|
|
|
|
|
|
|
anchors, shape = head.anchors[i], p[i].shape
|
|
|
|
|
|
|
|
gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]] # xyxy gain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Match targets to anchors
|
|
|
|
|
|
|
|
t = targets * gain # shape(3,n,7)
|
|
|
|
|
|
|
|
if nt:
|
|
|
|
|
|
|
|
# Matches
|
|
|
|
|
|
|
|
r = t[..., 4:6] / anchors[:, None] # wh ratio
|
|
|
|
|
|
|
|
j = torch.max(r, 1 / r).max(2)[0] < self.args.anchor_t # compare
|
|
|
|
|
|
|
|
# j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
|
|
|
|
|
|
|
|
t = t[j] # filter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Offsets
|
|
|
|
|
|
|
|
gxy = t[:, 2:4] # grid xy
|
|
|
|
|
|
|
|
gxi = gain[[2, 3]] - gxy # inverse
|
|
|
|
|
|
|
|
j, k = ((gxy % 1 < g) & (gxy > 1)).T
|
|
|
|
|
|
|
|
l, m = ((gxi % 1 < g) & (gxi > 1)).T
|
|
|
|
|
|
|
|
j = torch.stack((torch.ones_like(j), j, k, l, m))
|
|
|
|
|
|
|
|
t = t.repeat((5, 1, 1))[j]
|
|
|
|
|
|
|
|
offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
t = targets[0]
|
|
|
|
|
|
|
|
offsets = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Define
|
|
|
|
|
|
|
|
bc, gxy, gwh, at = t.chunk(4, 1) # (image, class), grid xy, grid wh, anchors
|
|
|
|
|
|
|
|
(a, tidx), (b, c) = at.long().T, bc.long().T # anchors, image, class
|
|
|
|
|
|
|
|
gij = (gxy - offsets).long()
|
|
|
|
|
|
|
|
gi, gj = gij.T # grid indices
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Append
|
|
|
|
|
|
|
|
indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1))) # image, anchor, grid
|
|
|
|
|
|
|
|
tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
|
|
|
|
|
|
|
|
anch.append(anchors[a]) # anchors
|
|
|
|
|
|
|
|
tcls.append(c) # class
|
|
|
|
|
|
|
|
tidxs.append(tidx)
|
|
|
|
|
|
|
|
xywhn.append(torch.cat((gxy, gwh), 1) / gain[2:6]) # xywh normalized
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return tcls, tbox, indices, anch, tidxs, xywhn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if len(preds) == 2: # eval
|
|
|
|
|
|
|
|
p, proto, = preds
|
|
|
|
|
|
|
|
else: # len(3) train
|
|
|
|
|
|
|
|
_, proto, p = preds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
|
|
|
|
|
|
|
masks = batch["masks"]
|
|
|
|
|
|
|
|
targets, masks = targets.to(self.device), masks.to(self.device).float()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bs, nm, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
|
|
|
|
|
|
|
|
lcls = torch.zeros(1, device=self.device)
|
|
|
|
|
|
|
|
lbox = torch.zeros(1, device=self.device)
|
|
|
|
|
|
|
|
lobj = torch.zeros(1, device=self.device)
|
|
|
|
|
|
|
|
lseg = torch.zeros(1, device=self.device)
|
|
|
|
|
|
|
|
tcls, tbox, indices, anchors, tidxs, xywhn = build_targets(p, targets)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Losses
|
|
|
|
|
|
|
|
for i, pi in enumerate(p): # layer index, layer predictions
|
|
|
|
|
|
|
|
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
|
|
|
|
|
|
|
|
tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device) # target obj
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
n = b.shape[0] # number of targets
|
|
|
|
|
|
|
|
if n:
|
|
|
|
|
|
|
|
pxy, pwh, _, pcls, pmask = pi[b, a, gj, gi].split((2, 2, 1, head.nc, nm), 1) # subset of predictions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Box regression
|
|
|
|
|
|
|
|
pxy = pxy.sigmoid() * 2 - 0.5
|
|
|
|
|
|
|
|
pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]
|
|
|
|
|
|
|
|
pbox = torch.cat((pxy, pwh), 1) # predicted box
|
|
|
|
|
|
|
|
iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze() # iou(prediction, target)
|
|
|
|
|
|
|
|
lbox += (1.0 - iou).mean() # iou loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Objectness
|
|
|
|
|
|
|
|
iou = iou.detach().clamp(0).type(tobj.dtype)
|
|
|
|
|
|
|
|
if sort_obj_iou:
|
|
|
|
|
|
|
|
j = iou.argsort()
|
|
|
|
|
|
|
|
b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j]
|
|
|
|
|
|
|
|
if gr < 1:
|
|
|
|
|
|
|
|
iou = (1.0 - gr) + gr * iou
|
|
|
|
|
|
|
|
tobj[b, a, gj, gi] = iou # iou ratio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Classification
|
|
|
|
|
|
|
|
if head.nc > 1: # cls loss (only if multiple classes)
|
|
|
|
|
|
|
|
t = torch.full_like(pcls, cn, device=self.device) # targets
|
|
|
|
|
|
|
|
t[range(n), tcls[i]] = cp
|
|
|
|
|
|
|
|
lcls += BCEcls(pcls, t) # BCE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Mask regression
|
|
|
|
|
|
|
|
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
|
|
|
|
|
|
|
|
masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
|
|
|
|
|
|
|
|
marea = xywhn[i][:, 2:].prod(1) # mask width, height normalized
|
|
|
|
|
|
|
|
mxyxy = xywh2xyxy(xywhn[i] * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device))
|
|
|
|
|
|
|
|
for bi in b.unique():
|
|
|
|
|
|
|
|
j = b == bi # matching index
|
|
|
|
|
|
|
|
if self.args.overlap_mask:
|
|
|
|
|
|
|
|
mask_gti = torch.where(masks[bi][None] == tidxs[i][j].view(-1, 1, 1), 1.0, 0.0)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
mask_gti = masks[tidxs[i]][j]
|
|
|
|
|
|
|
|
lseg += single_mask_loss(mask_gti, pmask[j], proto[bi], mxyxy[j], marea[j])
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
lseg += (proto * 0).sum()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
obji = BCEobj(pi[..., 4], tobj)
|
|
|
|
|
|
|
|
lobj += obji * balance[i] # obj loss
|
|
|
|
|
|
|
|
if autobalance:
|
|
|
|
|
|
|
|
balance[i] = balance[i] * 0.9999 + 0.0001 / obji.detach().item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if autobalance:
|
|
|
|
|
|
|
|
balance = [x / balance[ssi] for x in balance]
|
|
|
|
|
|
|
|
lbox *= self.args.box
|
|
|
|
|
|
|
|
lobj *= self.args.obj
|
|
|
|
|
|
|
|
lcls *= self.args.cls
|
|
|
|
|
|
|
|
lseg *= self.args.box / bs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
loss = lbox + lobj + lcls + lseg
|
|
|
|
|
|
|
|
return loss * bs, torch.cat((lbox, lseg, lobj, lcls)).detach()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def label_loss_items(self, loss_items=None, prefix="train"):
|
|
|
|
|
|
|
|
# We should just use named tensors here in future
|
|
|
|
|
|
|
|
keys = [f"{prefix}/{x}" for x in self.loss_names]
|
|
|
|
|
|
|
|
return dict(zip(keys, loss_items)) if loss_items is not None else keys
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def progress_string(self):
|
|
|
|
|
|
|
|
return ('\n' + '%11s' * 8) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_training_samples(self, batch, ni):
|
|
|
|
def plot_training_samples(self, batch, ni):
|
|
|
|
images = batch["img"]
|
|
|
|
images = batch["img"]
|
|
|
@ -227,6 +49,129 @@ class SegmentationTrainer(DetectionTrainer):
|
|
|
|
plot_results(file=self.csv, segment=True) # save results.png
|
|
|
|
plot_results(file=self.csv, segment=True) # save results.png
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Criterion class for computing training losses
|
|
|
|
|
|
|
|
class SegLoss:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, model, overlap=True): # model must be de-paralleled
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
device = next(model.parameters()).device # get model device
|
|
|
|
|
|
|
|
h = model.args # hyperparameters
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
m = model.model[-1] # Detect() module
|
|
|
|
|
|
|
|
self.bce = nn.BCEWithLogitsLoss(reduction='none')
|
|
|
|
|
|
|
|
self.hyp = h
|
|
|
|
|
|
|
|
self.stride = m.stride # model strides
|
|
|
|
|
|
|
|
self.nc = m.nc # number of classes
|
|
|
|
|
|
|
|
self.no = m.no
|
|
|
|
|
|
|
|
self.nm = m.nm # number of masks
|
|
|
|
|
|
|
|
self.reg_max = m.reg_max
|
|
|
|
|
|
|
|
self.overlap = overlap
|
|
|
|
|
|
|
|
self.device = device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.use_dfl = m.reg_max > 1
|
|
|
|
|
|
|
|
self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
|
|
|
|
|
|
|
|
self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
|
|
|
|
|
|
|
|
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess(self, targets, batch_size, scale_tensor):
|
|
|
|
|
|
|
|
if targets.shape[0] == 0:
|
|
|
|
|
|
|
|
out = torch.zeros(batch_size, 0, 5, device=self.device)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
i = targets[:, 0] # image index
|
|
|
|
|
|
|
|
_, counts = i.unique(return_counts=True)
|
|
|
|
|
|
|
|
out = torch.zeros(batch_size, counts.max(), 5, device=self.device)
|
|
|
|
|
|
|
|
for j in range(batch_size):
|
|
|
|
|
|
|
|
matches = i == j
|
|
|
|
|
|
|
|
n = matches.sum()
|
|
|
|
|
|
|
|
if n:
|
|
|
|
|
|
|
|
out[j, :n] = targets[matches, 1:]
|
|
|
|
|
|
|
|
out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
|
|
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def bbox_decode(self, anchor_points, pred_dist):
|
|
|
|
|
|
|
|
if self.use_dfl:
|
|
|
|
|
|
|
|
b, a, c = pred_dist.shape # batch, anchors, channels
|
|
|
|
|
|
|
|
pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
|
|
|
|
|
|
|
|
# pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
|
|
|
|
|
|
|
|
# pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
|
|
|
|
|
|
|
|
return dist2bbox(pred_dist, anchor_points, xywh=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __call__(self, preds, batch):
|
|
|
|
|
|
|
|
loss = torch.zeros(4, device=self.device) # box, cls, dfl
|
|
|
|
|
|
|
|
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
|
|
|
|
|
|
|
|
batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
|
|
|
|
|
|
|
|
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
|
|
|
|
|
|
|
|
(self.reg_max * 4, self.nc), 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# b, grids, ..
|
|
|
|
|
|
|
|
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
|
|
|
|
|
|
|
|
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
|
|
|
|
|
|
|
|
pred_masks = pred_masks.permute(0, 2, 1).contiguous()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dtype = pred_scores.dtype
|
|
|
|
|
|
|
|
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
|
|
|
|
|
|
|
|
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# targets
|
|
|
|
|
|
|
|
batch_idx = batch["batch_idx"].view(-1, 1)
|
|
|
|
|
|
|
|
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
|
|
|
|
|
|
|
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
|
|
|
|
|
|
|
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
|
|
|
|
|
|
|
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
masks = batch["masks"].to(self.device).float()
|
|
|
|
|
|
|
|
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
|
|
|
|
|
|
|
|
masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# pboxes
|
|
|
|
|
|
|
|
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
|
|
|
|
|
|
|
|
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
|
|
|
|
|
|
|
|
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
target_scores_sum = target_scores.sum()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# cls loss
|
|
|
|
|
|
|
|
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
|
|
|
|
|
|
|
|
loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# bbox loss
|
|
|
|
|
|
|
|
if fg_mask.sum():
|
|
|
|
|
|
|
|
loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
|
|
|
|
|
|
|
|
target_scores, target_scores_sum, fg_mask)
|
|
|
|
|
|
|
|
for i in range(batch_size):
|
|
|
|
|
|
|
|
if fg_mask[i].sum():
|
|
|
|
|
|
|
|
mask_idx = target_gt_idx[i][fg_mask[i]] + 1
|
|
|
|
|
|
|
|
if self.overlap:
|
|
|
|
|
|
|
|
gt_mask = torch.where(masks[[i]] == mask_idx.view(-1, 1, 1), 1.0, 0.0)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
gt_mask = masks[batch_idx == i][mask_idx]
|
|
|
|
|
|
|
|
xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]]
|
|
|
|
|
|
|
|
marea = xyxy2xywh(xyxyn)[:, 2:].prod(1)
|
|
|
|
|
|
|
|
mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)
|
|
|
|
|
|
|
|
loss[1] += self.single_mask_loss(gt_mask, pred_masks[i][fg_mask[i]], proto[i], mxyxy,
|
|
|
|
|
|
|
|
marea) # seg loss
|
|
|
|
|
|
|
|
# WARNING: Uncomment lines below in case of Multi-GPU DDP unused gradient errors
|
|
|
|
|
|
|
|
# else:
|
|
|
|
|
|
|
|
# loss[1] += proto.sum() * 0
|
|
|
|
|
|
|
|
# else:
|
|
|
|
|
|
|
|
# loss[1] += proto.sum() * 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
loss[0] *= 7.5 # box gain
|
|
|
|
|
|
|
|
loss[1] *= 7.5 / batch_size # seg gain
|
|
|
|
|
|
|
|
loss[2] *= 0.5 # cls gain
|
|
|
|
|
|
|
|
loss[3] *= 1.5 # dfl gain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
|
|
|
|
|
|
|
|
# Mask loss for one image
|
|
|
|
|
|
|
|
pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:]) # (n, 32) @ (32,80,80) -> (n,80,80)
|
|
|
|
|
|
|
|
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
|
|
|
|
|
|
|
|
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
|
|
|
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
|
|
|
def train(cfg):
|
|
|
|
def train(cfg):
|
|
|
|
cfg.model = cfg.model or "models/yolov8n-seg.yaml"
|
|
|
|
cfg.model = cfg.model or "models/yolov8n-seg.yaml"
|
|
|
|