From ebd3cfb2fd2def6c8838ed9e3920f30d1b8959db Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 24 Dec 2022 18:10:44 +0100 Subject: [PATCH] YOLOv8 architecture updates from R&D branch (#88) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/ci.yaml | 13 +- ultralytics/tests/test_model.py | 10 +- .../yolo/data/scripts/download_weights.sh | 22 ++ ultralytics/yolo/data/scripts/get_coco.sh | 60 +++++ ultralytics/yolo/data/scripts/get_coco128.sh | 17 ++ ultralytics/yolo/data/scripts/get_imagenet.sh | 51 ++++ ultralytics/yolo/engine/model.py | 16 +- ultralytics/yolo/utils/anchors.py | 169 ------------ ultralytics/yolo/utils/callbacks/tb.py | 4 +- ultralytics/yolo/utils/loss.py | 53 ++++ ultralytics/yolo/utils/metrics.py | 20 +- ultralytics/yolo/utils/modeling/__init__.py | 19 +- ultralytics/yolo/utils/modeling/modules.py | 83 +++--- ultralytics/yolo/utils/modeling/tasks.py | 73 ++--- ultralytics/yolo/utils/ops.py | 29 +- ultralytics/yolo/utils/tal.py | 211 +++++++++++++++ ultralytics/yolo/v8/detect/train.py | 251 ++++++++---------- ultralytics/yolo/v8/detect/val.py | 1 - ultralytics/yolo/v8/models/yolov5n-seg.yaml | 48 ---- ultralytics/yolo/v8/models/yolov5n.yaml | 48 ---- ultralytics/yolo/v8/models/yolov8n-seg.yaml | 43 +++ ultralytics/yolo/v8/models/yolov8n.yaml | 42 +++ ultralytics/yolo/v8/segment/train.py | 11 +- 23 files changed, 722 insertions(+), 572 deletions(-) create mode 100755 ultralytics/yolo/data/scripts/download_weights.sh create mode 100755 ultralytics/yolo/data/scripts/get_coco.sh create mode 100755 ultralytics/yolo/data/scripts/get_coco128.sh create mode 100755 ultralytics/yolo/data/scripts/get_imagenet.sh delete mode 100644 ultralytics/yolo/utils/anchors.py create mode 100644 ultralytics/yolo/utils/loss.py create mode 100644 ultralytics/yolo/utils/tal.py delete mode 100644 ultralytics/yolo/v8/models/yolov5n-seg.yaml delete mode 100644 ultralytics/yolo/v8/models/yolov5n.yaml create mode 100644 ultralytics/yolo/v8/models/yolov8n-seg.yaml create mode 100644 ultralytics/yolo/v8/models/yolov8n.yaml diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 49d0c47..3a2df75 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -90,16 +90,15 @@ jobs: - name: Test detection shell: bash # for Windows compatibility run: | - yolo task=detect mode=train model=yolov5n.yaml data=coco128.yaml epochs=1 imgsz=64 + yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=1 imgsz=64 yolo task=detect mode=val model=runs/train/exp/weights/last.pt imgsz=64 - - name: Test segmentation + - name: Test segmentation # TODO: segmentation CI shell: bash # for Windows compatibility - # TODO: redo val test without hardcoded weights run: | - yolo task=segment mode=train model=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 imgsz=64 - yolo task=segment mode=val model=runs/train/exp2/weights/last.pt data=coco128-seg.yaml imgsz=64 - - name: Test classification + # yolo task=segment mode=train model=yolov8n-seg.yaml data=coco128-seg.yaml epochs=1 imgsz=64 + # yolo task=segment mode=val model=runs/train/exp2/weights/last.pt data=coco128-seg.yaml imgsz=64 + - name: Test classification # TODO: change to exp3 on Segmentation CI update shell: bash # for Windows compatibility run: | yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 imgsz=32 - yolo task=classify mode=val model=runs/train/exp3/weights/last.pt data=mnist160 + yolo task=classify mode=val model=runs/train/exp2/weights/last.pt data=mnist160 diff --git a/ultralytics/tests/test_model.py b/ultralytics/tests/test_model.py index 0bb8cbf..306f11f 100644 --- a/ultralytics/tests/test_model.py +++ b/ultralytics/tests/test_model.py @@ -5,7 +5,7 @@ from ultralytics.yolo import YOLO def test_model_forward(): model = YOLO() - model.new("yolov5n-seg.yaml") + model.new("yolov8n-seg.yaml") img = torch.rand(512 * 512 * 3).view(1, 3, 512, 512) model.forward(img) model(img) @@ -13,7 +13,7 @@ def test_model_forward(): def test_model_info(): model = YOLO() - model.new("yolov5n.yaml") + model.new("yolov8n.yaml") model.info() model.load("balloon-detect.pt") model.info(verbose=True) @@ -21,7 +21,7 @@ def test_model_info(): def test_model_fuse(): model = YOLO() - model.new("yolov5n.yaml") + model.new("yolov8n.yaml") model.fuse() model.load("balloon-detect.pt") model.fuse() @@ -41,7 +41,7 @@ def test_val(): def test_model_resume(): model = YOLO() - model.new("yolov5n-seg.yaml") + model.new("yolov8n-seg.yaml") model.train(epochs=1, imgsz=32, data="coco128-seg.yaml") try: model.resume(task="segment") @@ -53,7 +53,7 @@ def test_model_train_pretrained(): model = YOLO() model.load("balloon-detect.pt") model.train(data="coco128.yaml", epochs=1, imgsz=32) - model.new("yolov5n.yaml") + model.new("yolov8n.yaml") model.train(data="coco128.yaml", epochs=1, imgsz=32) img = torch.rand(512 * 512 * 3).view(1, 3, 512, 512) model(img) diff --git a/ultralytics/yolo/data/scripts/download_weights.sh b/ultralytics/yolo/data/scripts/download_weights.sh new file mode 100755 index 0000000..31e0a15 --- /dev/null +++ b/ultralytics/yolo/data/scripts/download_weights.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +# Download latest models from https://github.com/ultralytics/yolov5/releases +# Example usage: bash data/scripts/download_weights.sh +# parent +# └── yolov5 +# ├── yolov5s.pt ← downloads here +# ├── yolov5m.pt +# └── ... + +python - < 1 / thr).float().sum(1).mean() # anchors above threshold - bpr = (best > 1 / thr).float().mean() # best possible recall - return bpr, aat - - stride = m.stride.to(m.anchors.device).view(-1, 1, 1) # model strides - anchors = m.anchors.clone() * stride # current anchors - bpr, aat = metric(anchors.cpu().view(-1, 2)) - s = f'\n{PREFIX}{aat:.2f} anchors/target, {bpr:.3f} Best Possible Recall (BPR). ' - if bpr > 0.98: # threshold to recompute - LOGGER.info(f'{s}Current anchors are a good fit to dataset ✅') - else: - LOGGER.info(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...') - na = m.anchors.numel() // 2 # number of anchors - anchors = kmean_anchors(dataset, n=na, imgsz=imgsz, thr=thr, gen=1000, verbose=False) - new_bpr = metric(anchors)[0] - if new_bpr > bpr: # replace anchors - anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors) - m.anchors[:] = anchors.clone().view_as(m.anchors) - check_anchor_order(m) # must be in pixel-space (not grid-space) - m.anchors /= stride - s = f'{PREFIX}Done ✅ (optional: update model *.yaml to use these anchors in the future)' - else: - s = f'{PREFIX}Done ⚠️ (original anchors better than new anchors, proceeding with original anchors)' - LOGGER.info(s) - - -def kmean_anchors(dataset='./data/coco128.yaml', n=9, imgsz=640, thr=4.0, gen=1000, verbose=True): - """ Creates kmeans-evolved anchors from training dataset - - Arguments: - dataset: path to data.yaml, or a loaded dataset - n: number of anchors - imgsz: image size used for training - thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0 - gen: generations to evolve anchors using genetic algorithm - verbose: print all results - - Return: - k: kmeans evolved anchors - - Usage: - from utils.autoanchor import *; _ = kmean_anchors() - """ - from scipy.cluster.vq import kmeans - - npr = np.random - thr = 1 / thr - - def metric(k, wh): # compute metrics - r = wh[:, None] / k[None] - x = torch.min(r, 1 / r).min(2)[0] # ratio metric - # x = wh_iou(wh, torch.tensor(k)) # iou metric - return x, x.max(1)[0] # x, best_x - - def anchor_fitness(k): # mutation fitness - _, best = metric(torch.tensor(k, dtype=torch.float32), wh) - return (best * (best > thr).float()).mean() # fitness - - def print_results(k, verbose=True): - k = k[np.argsort(k.prod(1))] # sort small to large - x, best = metric(k, wh0) - bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr - s = f'{PREFIX}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr\n' \ - f'{PREFIX}n={n}, imgsz={imgsz}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' \ - f'past_thr={x[x > thr].mean():.3f}-mean: ' - for x in k: - s += '%i,%i, ' % (round(x[0]), round(x[1])) - if verbose: - LOGGER.info(s[:-2]) - return k - - if isinstance(dataset, str): # *.yaml file - with open(dataset, errors='ignore') as f: - data_dict = yaml.safe_load(f) # model dict - - dataset = BaseDataset(data_dict['train'], augment=True, rect=True) - - # Get label wh - shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True) - wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh - - # Filter - i = (wh0 < 3.0).any(1).sum() - if i: - LOGGER.info(f'{PREFIX}WARNING ⚠️ Extremely small objects found: {i} of {len(wh0)} labels are <3 pixels in size') - wh = wh0[(wh0 >= 2.0).any(1)].astype(np.float32) # filter > 2 pixels - # wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1 - - # Kmeans init - try: - LOGGER.info(f'{PREFIX}Running kmeans for {n} anchors on {len(wh)} points...') - assert n <= len(wh) # apply overdetermined constraint - s = wh.std(0) # sigmas for whitening - k = kmeans(wh / s, n, iter=30)[0] * s # points - assert n == len(k) # kmeans may return fewer points than requested if wh is insufficient or too similar - except Exception: - LOGGER.warning(f'{PREFIX}WARNING ⚠️ switching strategies from kmeans to random init') - k = np.sort(npr.rand(n * 2)).reshape(n, 2) * imgsz # random init - wh, wh0 = (torch.tensor(x, dtype=torch.float32) for x in (wh, wh0)) - k = print_results(k, verbose=False) - - # Plot - # k, d = [None] * 20, [None] * 20 - # for i in tqdm(range(1, 21)): - # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance - # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True) - # ax = ax.ravel() - # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.') - # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh - # ax[0].hist(wh[wh[:, 0]<100, 0],400) - # ax[1].hist(wh[wh[:, 1]<100, 1],400) - # fig.savefig('wh.png', dpi=200) - - # Evolve - f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma - pbar = tqdm(range(gen), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar - for _ in pbar: - v = np.ones(sh) - while (v == 1).all(): # mutate until a change occurs (prevent duplicates) - v = ((npr.random(sh) < mp) * random.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0) - kg = (k.copy() * v).clip(min=2.0) - fg = anchor_fitness(kg) - if fg > f: - f, k = fg, kg.copy() - pbar.desc = f'{PREFIX}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}' - if verbose: - print_results(k, verbose) - - return print_results(k).astype(np.float32) diff --git a/ultralytics/yolo/utils/callbacks/tb.py b/ultralytics/yolo/utils/callbacks/tb.py index a86a0d6..5fe4d28 100644 --- a/ultralytics/yolo/utils/callbacks/tb.py +++ b/ultralytics/yolo/utils/callbacks/tb.py @@ -16,11 +16,11 @@ def on_train_start(trainer): def on_batch_end(trainer): - _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch) + _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1) def on_val_end(trainer): - _log_scalars(trainer.metrics, trainer.epoch) + _log_scalars(trainer.metrics, trainer.epoch + 1) callbacks = {"on_train_start": on_train_start, "on_val_end": on_val_end, "on_batch_end": on_batch_end} diff --git a/ultralytics/yolo/utils/loss.py b/ultralytics/yolo/utils/loss.py new file mode 100644 index 0000000..bb8505c --- /dev/null +++ b/ultralytics/yolo/utils/loss.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .metrics import bbox_iou +from .tal import bbox2dist + + +class VarifocalLoss(nn.Module): + # Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367 + def __init__(self): + super().__init__() + + def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0): + weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label + with torch.cuda.amp.autocast(enabled=False): + loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * + weight).sum() + return loss + + +class BboxLoss(nn.Module): + + def __init__(self, reg_max, use_dfl=False): + super().__init__() + self.reg_max = reg_max + self.use_dfl = use_dfl + + def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask): + # IoU loss + weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1) + iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True) + loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum + + # DFL loss + if self.use_dfl: + target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max) + loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight + loss_dfl = loss_dfl.sum() / target_scores_sum + else: + loss_dfl = torch.tensor(0.0).to(pred_dist.device) + + return loss_iou, loss_dfl + + @staticmethod + def _df_loss(pred_dist, target): + # Return sum of left and right DFL losses + tl = target.long() # target left + tr = tl + 1 # target right + wl = tr - target # weight left + wr = 1 - wl # weight right + return (F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl + + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr).mean(-1, keepdim=True) diff --git a/ultralytics/yolo/utils/metrics.py b/ultralytics/yolo/utils/metrics.py index 3db69cf..30b1b72 100644 --- a/ultralytics/yolo/utils/metrics.py +++ b/ultralytics/yolo/utils/metrics.py @@ -56,11 +56,11 @@ def box_iou(box1, box2, eps=1e-7): """ # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2) - (a1, a2), (b1, b2) = box1[:, None].chunk(2, 2), box2.chunk(2, 1) + (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2) inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2) # IoU = inter / (area1 + area2 - inter) - return inter / (box_area(box1.T)[:, None] + box_area(box2.T) - inter + eps) + return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps) def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7): @@ -68,19 +68,19 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7 # Get the coordinates of bounding boxes if xywh: # transform from xywh to xyxy - (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, 1), box2.chunk(4, 1) + (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1) w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2 b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_ b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_ else: # x1, y1, x2, y2 = box1 - b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, 1) - b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, 1) + b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1) + b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1) w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps # Intersection area - inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \ - (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0) + inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \ + (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0) # Union Area union = w1 * h1 + w2 * h2 - inter + eps @@ -88,13 +88,13 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7 # IoU iou = inter / union if CIoU or DIoU or GIoU: - cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width - ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height + cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width + ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2 if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 - v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2) + v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2) with torch.no_grad(): alpha = v / (v - iou + (1 + eps)) return iou - (rho2 / c2 + v * alpha) # CIoU diff --git a/ultralytics/yolo/utils/modeling/__init__.py b/ultralytics/yolo/utils/modeling/__init__.py index 2719650..48a5917 100644 --- a/ultralytics/yolo/utils/modeling/__init__.py +++ b/ultralytics/yolo/utils/modeling/__init__.py @@ -46,12 +46,11 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=True): def parse_model(d, ch): # model_dict, input_channels(3) # Parse a YOLOv5 model.yaml dictionary LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}") - anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation') + nc, gd, gw, act = d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation') if act: Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU() LOGGER.info(f"{colorstr('activation:')} {act}") # print - na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors - no = na * (nc + 5) # number of outputs = anchors * (classes + 5) + no = nc + 4 # number of outputs = classes + box layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args @@ -62,14 +61,14 @@ def parse_model(d, ch): # model_dict, input_channels(3) n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain if m in { - Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus, BottleneckCSP, C3, C3TR, - C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}: + Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus, BottleneckCSP, + C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}: c1, c2 = ch[f], args[0] if c2 != no: # if not output c2 = make_divisible(c2 * gw, 8) args = [c1, c2, *args[1:]] - if m in {BottleneckCSP, C3, C3TR, C3Ghost, C3x}: + if m in {BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x}: args.insert(2, n) # number of repeats n = 1 elif m is nn.BatchNorm2d: @@ -79,8 +78,6 @@ def parse_model(d, ch): # model_dict, input_channels(3) # TODO: channel, gw, gd elif m in {Detect, Segment}: args.append([ch[x] for x in f]) - if isinstance(args[1], int): # number of anchors - args[1] = [list(range(args[1] * 2))] * len(f) if m is Segment: args[3] = make_divisible(args[3] * gw, 8) else: @@ -88,9 +85,9 @@ def parse_model(d, ch): # model_dict, input_channels(3) m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module t = str(m)[8:-2].replace('__main__.', '') # module type - np = sum(x.numel() for x in m_.parameters()) # number params - m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params - LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}') # print + m.np = sum(x.numel() for x in m_.parameters()) # number params + m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type + LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{m.np:10.0f} {t:<40}{str(args):<30}') # print save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist layers.append(m_) if i == 0: diff --git a/ultralytics/yolo/utils/modeling/modules.py b/ultralytics/yolo/utils/modeling/modules.py index 6fe61fa..06270a6 100644 --- a/ultralytics/yolo/utils/modeling/modules.py +++ b/ultralytics/yolo/utils/modeling/modules.py @@ -19,10 +19,10 @@ from torch.cuda import amp from ultralytics.yolo.data.augment import LetterBox from ultralytics.yolo.utils import LOGGER, colorstr -from ultralytics.yolo.utils.checks import check_version from ultralytics.yolo.utils.files import increment_path from ultralytics.yolo.utils.ops import Profile, make_divisible, non_max_suppression, scale_boxes, xyxy2xywh from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box +from ultralytics.yolo.utils.tal import dist2bbox, make_anchors from ultralytics.yolo.utils.torch_utils import copy_attr, smart_inference_mode from .autobackend import AutoBackend @@ -605,62 +605,55 @@ class Ensemble(nn.ModuleList): # heads class Detect(nn.Module): # YOLOv5 Detect head for detection models - stride = None # strides computed during build dynamic = False # force grid reconstruction export = False # export mode + shape = None + anchors = torch.empty(0) # init + strides = torch.empty(0) # init - def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer + def __init__(self, nc=80, ch=()): # detection layer super().__init__() self.nc = nc # number of classes - self.no = nc + 5 # number of outputs per anchor - self.nl = len(anchors) # number of detection layers - self.na = len(anchors[0]) // 2 # number of anchors - self.grid = [torch.empty(0) for _ in range(self.nl)] # init grid - self.anchor_grid = [torch.empty(0) for _ in range(self.nl)] # init anchor grid - self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2) - self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv - self.inplace = inplace # use inplace ops (e.g. slice assignment) + self.nl = len(ch) # number of detection layers + self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x) + self.no = nc + self.reg_max * 4 # number of outputs per anchor + self.stride = torch.zeros(self.nl) # strides computed during build + + c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], self.nc) # channels + self.cv2 = nn.ModuleList( + nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch) + self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch) + self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity() def forward(self, x): - z = [] # inference output + shape = x[0].shape # BCHW for i in range(self.nl): - x[i] = self.m[i](x[i]) # conv - bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85) - x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() - - if not self.training: # inference - if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]: - self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i) - - if isinstance(self, Segment): # (boxes + masks) - xy, wh, conf, mask = x[i].split((2, 2, self.nc + 1, self.no - self.nc - 5), 4) - xy = (xy.sigmoid() * 2 + self.grid[i]) * self.stride[i] # xy - wh = (wh.sigmoid() * 2) ** 2 * self.anchor_grid[i] # wh - y = torch.cat((xy, wh, conf.sigmoid(), mask), 4) - else: # Detect (boxes only) - xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4) - xy = (xy * 2 + self.grid[i]) * self.stride[i] # xy - wh = (wh * 2) ** 2 * self.anchor_grid[i] # wh - y = torch.cat((xy, wh, conf), 4) - z.append(y.view(bs, self.na * nx * ny, self.no)) - - return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x) - - def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, '1.10.0')): - d = self.anchors[i].device - t = self.anchors[i].dtype - shape = 1, self.na, ny, nx, 2 # grid shape - y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t) - yv, xv = torch.meshgrid(y, x, indexing='ij') if torch_1_10 else torch.meshgrid(y, x) # torch>=0.7 compatibility - grid = torch.stack((xv, yv), 2).expand(shape) - 0.5 # add grid offset, i.e. y = 2.0 * x - 0.5 - anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape) - return grid, anchor_grid + x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1) + box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1) + if self.training: + return x, box, cls + elif self.dynamic or self.shape != shape: + self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) + self.shape = shape + + dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides + y = torch.cat((dbox, cls.sigmoid()), 1) + return y if self.export else (y, (x, box, cls)) + + def bias_init(self): + # Initialize Detect() biases, WARNING: requires stride availability + m = self # self.model[-1] # Detect() module + # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1 + # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency + for a, b, s in zip(m.cv2, m.cv3, m.stride): # from + a[-1].bias.data[:] = 1.0 # box + b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) class Segment(Detect): # YOLOv5 Segment head for segmentation models - def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), inplace=True): - super().__init__(nc, anchors, ch, inplace) + def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=()): + super().__init__(nc, anchors, ch) self.nm = nm # number of masks self.npr = npr # number of protos self.no = 5 + nc + self.nm # number of outputs per anchor diff --git a/ultralytics/yolo/utils/modeling/tasks.py b/ultralytics/yolo/utils/modeling/tasks.py index 7f22387..59e502c 100644 --- a/ultralytics/yolo/utils/modeling/tasks.py +++ b/ultralytics/yolo/utils/modeling/tasks.py @@ -2,7 +2,6 @@ from copy import deepcopy import thop -from ultralytics.yolo.utils.anchors import check_anchor_order from ultralytics.yolo.utils.modeling import parse_model from ultralytics.yolo.utils.modeling.modules import * from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_state_dicts, model_info, @@ -60,9 +59,8 @@ class BaseModel(nn.Module): m = self.model[-1] # Detect() if isinstance(m, (Detect, Segment)): m.stride = fn(m.stride) - m.grid = list(map(fn, m.grid)) - if isinstance(m.anchor_grid, list): - m.anchor_grid = list(map(fn, m.anchor_grid)) + m.anchors = fn(m.anchors) + m.strides = fn(m.strides) return self def load(self, weights): @@ -71,8 +69,8 @@ class BaseModel(nn.Module): class DetectionModel(BaseModel): - # YOLO detection model - def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes + # YOLOv5 detection model + def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes super().__init__() if isinstance(cfg, dict): self.yaml = cfg # model dict @@ -87,24 +85,19 @@ class DetectionModel(BaseModel): if nc and nc != self.yaml['nc']: LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") self.yaml['nc'] = nc # override yaml value - if anchors: - LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}') - self.yaml['anchors'] = round(anchors) # override yaml value self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist self.names = [str(i) for i in range(self.yaml['nc'])] # default names self.inplace = self.yaml.get('inplace', True) - # Build strides, anchors + # Build strides m = self.model[-1] # Detect() if isinstance(m, (Detect, Segment)): s = 256 # 2x min stride m.inplace = self.inplace - forward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x) + forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Detect)) else self.forward(x) m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward - check_anchor_order(m) - m.anchors /= m.stride.view(-1, 1, 1) self.stride = m.stride - self._initialize_biases() # only run once + m.bias_init() # only run once # Init weights, biases initialize_weights(self) @@ -117,7 +110,7 @@ class DetectionModel(BaseModel): return self._forward_once(x, profile, visualize) # single-scale inference, train def _forward_augment(self, x): - imgsz = x.shape[-2:] # height, width + img_size = x.shape[-2:] # height, width s = [1, 0.83, 0.67] # scales f = [None, 3, None] # flips (2-ud, 3-lr) y = [] # outputs @@ -125,49 +118,33 @@ class DetectionModel(BaseModel): xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max())) yi = self._forward_once(xi)[0] # forward # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save - yi = self._descale_pred(yi, fi, si, imgsz) + yi = self._descale_pred(yi, fi, si, img_size) y.append(yi) y = self._clip_augmented(y) # clip augmented tails - return torch.cat(y, 1), None # augmented inference, train + return torch.cat(y, -1), None # augmented inference, train - def _descale_pred(self, p, flips, scale, imgsz): + @staticmethod + def _descale_pred(p, flips, scale, img_size, dim=1): # de-scale predictions following augmented inference (inverse operation) - if self.inplace: - p[..., :4] /= scale # de-scale - if flips == 2: - p[..., 1] = imgsz[0] - p[..., 1] # de-flip ud - elif flips == 3: - p[..., 0] = imgsz[1] - p[..., 0] # de-flip lr - else: - x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale - if flips == 2: - y = imgsz[0] - y # de-flip ud - elif flips == 3: - x = imgsz[1] - x # de-flip lr - p = torch.cat((x, y, wh, p[..., 4:]), -1) - return p + p[:, :4] /= scale # de-scale + x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim) + if flips == 2: + y = img_size[0] - y # de-flip ud + elif flips == 3: + x = img_size[1] - x # de-flip lr + return torch.cat((x, y, wh, cls), dim) def _clip_augmented(self, y): # Clip YOLOv5 augmented inference tails nl = self.model[-1].nl # number of detection layers (P3-P5) g = sum(4 ** x for x in range(nl)) # grid points e = 1 # exclude layer count - i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e)) # indices - y[0] = y[0][:, :-i] # large - i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices - y[-1] = y[-1][:, i:] # small + i = (y[0].shape[-1] // g) * sum(4 ** x for x in range(e)) # indices + y[0] = y[0][..., :-i] # large + i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices + y[-1] = y[-1][..., i:] # small return y - def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency - # https://arxiv.org/abs/1708.02002 section 3.3 - # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1. - m = self.model[-1] # Detect() module - for mi, s in zip(m.m, m.stride): # from - b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85) - b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image) - b.data[:, 5:5 + m.nc] += math.log(0.6 / (m.nc - 0.99999)) if cf is None else torch.log(cf / cf.sum()) # cls - mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) - def load(self, weights): csd = weights['model'].float().state_dict() # checkpoint state_dict as FP32 csd = intersect_state_dicts(csd, self.state_dict()) # intersect @@ -177,8 +154,8 @@ class DetectionModel(BaseModel): class SegmentationModel(DetectionModel): # YOLOv5 segmentation model - def __init__(self, cfg='yolov5s-seg.yaml', ch=3, nc=None, anchors=None): - super().__init__(cfg, ch, nc, anchors) + def __init__(self, cfg='yolov5s-seg.yaml', ch=3, nc=None): + super().__init__(cfg, ch, nc) class ClassificationModel(BaseModel): diff --git a/ultralytics/yolo/utils/ops.py b/ultralytics/yolo/utils/ops.py index 106dda5..eb289f6 100644 --- a/ultralytics/yolo/utils/ops.py +++ b/ultralytics/yolo/utils/ops.py @@ -101,6 +101,7 @@ def non_max_suppression( nm=0, # number of masks ): """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections + Returns: list of detections, on (n,6) tensor per image [xyxy, conf, cls] """ @@ -113,8 +114,9 @@ def non_max_suppression( if mps: # MPS not fully supported yet, convert tensors to CPU before NMS prediction = prediction.cpu() bs = prediction.shape[0] # batch size - nc = prediction.shape[2] - nm - 5 # number of classes - xc = prediction[..., 4] > conf_thres # candidates + nc = prediction.shape[1] - nm - 4 # number of classes + mi = 4 + nc # mask start index + xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates # Checks assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0' @@ -130,39 +132,32 @@ def non_max_suppression( merge = False # use merge-NMS t = time.time() - mi = 5 + nc # mask start index output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs for xi, x in enumerate(prediction): # image index, image inference # Apply constraints - # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height - x = x[xc[xi]] # confidence + # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height + x = x.T[xc[xi]] # confidence # Cat apriori labels if autolabelling if labels and len(labels[xi]): lb = labels[xi] v = torch.zeros((len(lb), nc + nm + 5), device=x.device) v[:, :4] = lb[:, 1:5] # box - v[:, 4] = 1.0 # conf - v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls + v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls x = torch.cat((x, v), 0) # If none remain process next image if not x.shape[0]: continue - # Compute conf - x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf - - # Box/Mask - box = xywh2xyxy(x[:, :4]) # center_x, center_y, width, height) to (x1, y1, x2, y2) - mask = x[:, mi:] # zero columns if no masks - # Detections matrix nx6 (xyxy, conf, cls) + box, cls, mask = x.split((4, nc, nm), 1) + box = xywh2xyxy(box) # center_x, center_y, width, height) to (x1, y1, x2, y2) if multi_label: - i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T - x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1) + i, j = (cls > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1) else: # best class only - conf, j = x[:, 5:mi].max(1, keepdim=True) + conf, j = cls.max(1, keepdim=True) x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres] # Filter by class diff --git a/ultralytics/yolo/utils/tal.py b/ultralytics/yolo/utils/tal.py new file mode 100644 index 0000000..35e9f28 --- /dev/null +++ b/ultralytics/yolo/utils/tal.py @@ -0,0 +1,211 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .checks import check_version +from .metrics import bbox_iou + +TORCH_1_10 = check_version(torch.__version__, '1.10.0') + + +def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9): + """select the positive anchor center in gt + + Args: + xy_centers (Tensor): shape(h*w, 4) + gt_bboxes (Tensor): shape(b, n_boxes, 4) + Return: + (Tensor): shape(b, n_boxes, h*w) + """ + n_anchors = xy_centers.shape[0] + bs, n_boxes, _ = gt_bboxes.shape + lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom + bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1) + # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype) + return bbox_deltas.amin(3).gt_(eps) + + +def select_highest_overlaps(mask_pos, overlaps, n_max_boxes): + """if an anchor box is assigned to multiple gts, + the one with the highest iou will be selected. + + Args: + mask_pos (Tensor): shape(b, n_max_boxes, h*w) + overlaps (Tensor): shape(b, n_max_boxes, h*w) + Return: + target_gt_idx (Tensor): shape(b, h*w) + fg_mask (Tensor): shape(b, h*w) + mask_pos (Tensor): shape(b, n_max_boxes, h*w) + """ + # (b, n_max_boxes, h*w) -> (b, h*w) + fg_mask = mask_pos.sum(-2) + if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes + mask_multi_gts = (fg_mask.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1]) # (b, n_max_boxes, h*w) + max_overlaps_idx = overlaps.argmax(1) # (b, h*w) + is_max_overlaps = F.one_hot(max_overlaps_idx, n_max_boxes) # (b, h*w, n_max_boxes) + is_max_overlaps = is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype) # (b, n_max_boxes, h*w) + mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos) # (b, n_max_boxes, h*w) + fg_mask = mask_pos.sum(-2) + # find each grid serve which gt(index) + target_gt_idx = mask_pos.argmax(-2) # (b, h*w) + return target_gt_idx, fg_mask, mask_pos + + +class TaskAlignedAssigner(nn.Module): + + def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9): + super().__init__() + self.topk = topk + self.num_classes = num_classes + self.bg_idx = num_classes + self.alpha = alpha + self.beta = beta + self.eps = eps + + @torch.no_grad() + def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt): + """This code referenced to + https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py + + Args: + pd_scores (Tensor): shape(bs, num_total_anchors, num_classes) + pd_bboxes (Tensor): shape(bs, num_total_anchors, 4) + anc_points (Tensor): shape(num_total_anchors, 2) + gt_labels (Tensor): shape(bs, n_max_boxes, 1) + gt_bboxes (Tensor): shape(bs, n_max_boxes, 4) + mask_gt (Tensor): shape(bs, n_max_boxes, 1) + Returns: + target_labels (Tensor): shape(bs, num_total_anchors) + target_bboxes (Tensor): shape(bs, num_total_anchors, 4) + target_scores (Tensor): shape(bs, num_total_anchors, num_classes) + fg_mask (Tensor): shape(bs, num_total_anchors) + """ + self.bs = pd_scores.size(0) + self.n_max_boxes = gt_bboxes.size(1) + + if self.n_max_boxes == 0: + device = gt_bboxes.device + return (torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), torch.zeros_like(pd_bboxes).to(device), + torch.zeros_like(pd_scores).to(device), torch.zeros_like(pd_scores[..., 0]).to(device)) + + mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, + mask_gt) + + target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes) + + # assigned target + target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask) + + # normalize + align_metric *= mask_pos + pos_align_metrics = align_metric.amax(axis=-1, keepdim=True) # b, max_num_obj + pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True) # b, max_num_obj + norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1) + target_scores = target_scores * norm_align_metric + + return target_labels, target_bboxes, target_scores, fg_mask.bool() + + def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt): + # get anchor_align metric, (b, max_num_obj, h*w) + align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes) + # get in_gts mask, (b, max_num_obj, h*w) + mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes) + # get topk_metric mask, (b, max_num_obj, h*w) + mask_topk = self.select_topk_candidates(align_metric * mask_in_gts, + topk_mask=mask_gt.repeat([1, 1, self.topk]).bool()) + # merge all mask to a final mask, (b, max_num_obj, h*w) + mask_pos = mask_topk * mask_in_gts * mask_gt + + return mask_pos, align_metric, overlaps + + def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes): + ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj + ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes) # b, max_num_obj + ind[1] = gt_labels.long().squeeze(-1) # b, max_num_obj + # get the scores of each grid for each gt cls + bbox_scores = pd_scores[ind[0], :, ind[1]] # b, max_num_obj, h*w + + overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False, CIoU=True).squeeze(3).clamp(0) + align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta) + return align_metric, overlaps + + def select_topk_candidates(self, metrics, largest=True, topk_mask=None): + """ + Args: + metrics: (b, max_num_obj, h*w). + topk_mask: (b, max_num_obj, topk) or None + """ + + num_anchors = metrics.shape[-1] # h*w + # (b, max_num_obj, topk) + topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest) + if topk_mask is None: + topk_mask = (topk_metrics.max(-1, keepdim=True) > self.eps).tile([1, 1, self.topk]) + # (b, max_num_obj, topk) + topk_idxs = torch.where(topk_mask, topk_idxs, 0) + # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w) + is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2) + # filter invalid bboxes + # assigned topk should be unique, this is for dealing with empty labels + # since empty labels will generate index `0` through `F.one_hot` + # NOTE: but what if the topk_idxs include `0`? + is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk) + return is_in_topk.to(metrics.dtype) + + def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask): + """ + Args: + gt_labels: (b, max_num_obj, 1) + gt_bboxes: (b, max_num_obj, 4) + target_gt_idx: (b, h*w) + fg_mask: (b, h*w) + """ + + # assigned target labels, (b, 1) + batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None] + target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w) + target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w) + + # assigned target boxes, (b, max_num_obj, 4) -> (b, h*w) + target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx] + + # assigned target scores + target_labels.clamp(0) + target_scores = F.one_hot(target_labels, self.num_classes) # (b, h*w, 80) + fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80) + target_scores = torch.where(fg_scores_mask > 0, target_scores, 0) + + return target_labels, target_bboxes, target_scores + + +def make_anchors(feats, strides, grid_cell_offset=0.5): + """Generate anchors from features.""" + anchor_points, stride_tensor = [], [] + assert feats is not None + dtype, device = feats[0].dtype, feats[0].device + for i, stride in enumerate(strides): + _, _, h, w = feats[i].shape + sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x + sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y + sy, sx = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx) + anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2)) + stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) + return torch.cat(anchor_points), torch.cat(stride_tensor) + + +def dist2bbox(distance, anchor_points, xywh=True, dim=-1): + """Transform distance(ltrb) to box(xywh or xyxy).""" + lt, rb = torch.split(distance, 2, dim) + x1y1 = anchor_points - lt + x2y2 = anchor_points + rb + if xywh: + c_xy = (x1y1 + x2y2) / 2 + wh = x2y2 - x1y1 + return torch.cat((c_xy, wh), dim) # xywh bbox + return torch.cat((x1y1, x2y2), dim) # xyxy bbox + + +def bbox2dist(anchor_points, bbox, reg_max): + """Transform bbox(xyxy) to dist(ltrb).""" + x1y1, x2y2 = torch.split(bbox, 2, -1) + return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp(0, reg_max - 0.01) # dist (lt, rb) diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py index 3beac4a..b361eab 100644 --- a/ultralytics/yolo/v8/detect/train.py +++ b/ultralytics/yolo/v8/detect/train.py @@ -5,9 +5,12 @@ import torch.nn as nn from ultralytics.yolo import v8 from ultralytics.yolo.data import build_dataloader from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer -from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE +from ultralytics.yolo.utils.loss import BboxLoss +from ultralytics.yolo.utils.metrics import smooth_BCE from ultralytics.yolo.utils.modeling.tasks import DetectionModel +from ultralytics.yolo.utils.ops import xywh2xyxy from ultralytics.yolo.utils.plotting import plot_images, plot_results +from ultralytics.yolo.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors from ultralytics.yolo.utils.torch_utils import de_parallel @@ -35,10 +38,7 @@ class DetectionTrainer(BaseTrainer): self.model.names = self.data["names"] def load_model(self, model_cfg=None, weights=None): - model = DetectionModel(model_cfg or weights["model"].yaml, - ch=3, - nc=self.data["nc"], - anchors=self.args.get("anchors")) + model = DetectionModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"]) if weights: model.load(weights) for _, v in model.named_parameters(): @@ -46,150 +46,14 @@ class DetectionTrainer(BaseTrainer): return model def get_validator(self): - self.loss_names = 'box_loss', 'obj_loss', 'cls_loss' + self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss' return v8.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, logger=self.console, args=self.args) def criterion(self, preds, batch): - head = de_parallel(self.model).model[-1] - sort_obj_iou = False - autobalance = False - - # init losses - BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.args.cls_pw], device=self.device)) - BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.args.obj_pw], device=self.device)) - - # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3 - cp, cn = smooth_BCE(eps=self.args.label_smoothing) # positive, negative BCE targets - - # Focal loss - g = self.args.fl_gamma - if self.args.fl_gamma > 0: - BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) - - balance = {3: [4.0, 1.0, 0.4]}.get(head.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7 - ssi = list(head.stride).index(16) if autobalance else 0 # stride 16 index - BCEcls, BCEobj, gr, autobalance = BCEcls, BCEobj, 1.0, autobalance - - def build_targets(p, targets): - # Build targets for compute_loss(), input targets(image,class,x,y,w,h) - nonlocal head - na, nt = head.na, targets.shape[0] # number of anchors, targets - tcls, tbox, indices, anch = [], [], [], [] - gain = torch.ones(7, device=self.device) # normalized to gridspace gain - ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt) - targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None]), 2) # append anchor indices - - g = 0.5 # bias - off = torch.tensor( - [ - [0, 0], - [1, 0], - [0, 1], - [-1, 0], - [0, -1], # j,k,l,m - # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm - ], - device=self.device).float() * g # offsets - - for i in range(head.nl): - anchors, shape = head.anchors[i], p[i].shape - gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]] # xyxy gain - - # Match targets to anchors - t = targets * gain # shape(3,n,7) - if nt: - # Matches - r = t[..., 4:6] / anchors[:, None] # wh ratio - j = torch.max(r, 1 / r).max(2)[0] < self.args.anchor_t # compare - # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2)) - t = t[j] # filter - - # Offsets - gxy = t[:, 2:4] # grid xy - gxi = gain[[2, 3]] - gxy # inverse - j, k = ((gxy % 1 < g) & (gxy > 1)).T - l, m = ((gxi % 1 < g) & (gxi > 1)).T - j = torch.stack((torch.ones_like(j), j, k, l, m)) - t = t.repeat((5, 1, 1))[j] - offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j] - else: - t = targets[0] - offsets = 0 - - # Define - bc, gxy, gwh, a = t.chunk(4, 1) # (image, class), grid xy, grid wh, anchors - a, (b, c) = a.long().view(-1), bc.long().T # anchors, image, class - gij = (gxy - offsets).long() - gi, gj = gij.T # grid indices - - # Append - indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1))) # image, anchor, grid - tbox.append(torch.cat((gxy - gij, gwh), 1)) # box - anch.append(anchors[a]) # anchors - tcls.append(c) # class - - return tcls, tbox, indices, anch - - if len(preds) == 2: # eval - _, p = preds - else: # len(3) train - p = preds - - targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) - targets = targets.to(self.device) - - lcls = torch.zeros(1, device=self.device) - lbox = torch.zeros(1, device=self.device) - lobj = torch.zeros(1, device=self.device) - tcls, tbox, indices, anchors = build_targets(p, targets) - - # Losses - for i, pi in enumerate(p): # layer index, layer predictions - b, a, gj, gi = indices[i] # image, anchor, gridy, gridx - tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device) # target obj - bs = tobj.shape[0] - n = b.shape[0] # number of targets - if n: - pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, head.nc), 1) # subset of predictions - - # Box regression - pxy = pxy.sigmoid() * 2 - 0.5 - pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i] - pbox = torch.cat((pxy, pwh), 1) # predicted box - iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze() # iou(prediction, target) - lbox += (1.0 - iou).mean() # iou loss - - # Objectness - iou = iou.detach().clamp(0).type(tobj.dtype) - if sort_obj_iou: - j = iou.argsort() - b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j] - if gr < 1: - iou = (1.0 - gr) + gr * iou - tobj[b, a, gj, gi] = iou # iou ratio - - # Classification - if head.nc > 1: # cls loss (only if multiple classes) - t = torch.full_like(pcls, cn, device=self.device) # targets - t[range(n), tcls[i]] = cp - lcls += BCEcls(pcls, t) # BCE - - obji = BCEobj(pi[..., 4], tobj) - lobj += obji * balance[i] # obj loss - if autobalance: - balance[i] = balance[i] * 0.9999 + 0.0001 / obji.detach().item() - - if autobalance: - balance = [x / balance[ssi] for x in balance] - lbox *= self.args.box - lobj *= self.args.obj - lcls *= self.args.cls - - loss = lbox + lobj + lcls - return loss * bs, torch.cat((lbox, lobj, lcls)).detach() + return Loss(self.model)(preds, batch) def label_loss_items(self, loss_items=None, prefix="train"): # We should just use named tensors here in future @@ -212,10 +76,105 @@ class DetectionTrainer(BaseTrainer): plot_results(file=self.csv) # save results.png +# Criterion class for computing training losses +class Loss: + + def __init__(self, model): + + device = next(model.parameters()).device # get model device + h = model.args # hyperparameters + + # Define criteria + BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h["cls_pw"]], device=device), reduction='none') + + # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3 + self.cp, self.cn = smooth_BCE(eps=h.get("label_smoothing", 0.0)) # positive, negative BCE targets + + m = de_parallel(model).model[-1] # Detect() module + self.BCEcls = BCEcls + self.hyp = h + self.stride = m.stride # model strides + self.nc = m.nc # number of classes + self.nl = m.nl # number of layers + self.device = device + + self.use_dfl = m.reg_max > 1 + self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0) + self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device) + self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device) + + def preprocess(self, targets, batch_size, scale_tensor): + if targets.shape[0] == 0: + out = torch.zeros(batch_size, 0, 5, device=self.device) + else: + i = targets[:, 0] # image index + _, counts = i.unique(return_counts=True) + out = torch.zeros(batch_size, counts.max(), 5, device=self.device) + for j in range(batch_size): + matches = i == j + n = matches.sum() + if n: + out[j, :n] = targets[matches, 1:] + out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor)) + return out + + def bbox_decode(self, anchor_points, pred_dist): + if self.use_dfl: + b, a, c = pred_dist.shape # batch, anchors, channels + pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype)) + # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype)) + # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2) + return dist2bbox(pred_dist, anchor_points, xywh=False) + + def __call__(self, preds, batch): + loss = torch.zeros(3, device=self.device) # box, cls, dfl + feats, pred_distri, pred_scores = preds if len(preds) == 3 else preds[1] + pred_scores = pred_scores.permute(0, 2, 1).contiguous() + pred_distri = pred_distri.permute(0, 2, 1).contiguous() + + dtype = pred_scores.dtype + batch_size, grid_size = pred_scores.shape[:2] + imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w) + anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) + + # targets + targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) + targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) + gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy + mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) + + # pboxes + pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) + + target_labels, target_bboxes, target_scores, fg_mask = self.assigner( + pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), + anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt) + + target_bboxes /= stride_tensor + target_scores_sum = target_scores.sum() + + # cls loss + # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way + loss[1] = self.BCEcls(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE + + # bbox loss + if fg_mask.sum(): + loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, + target_scores_sum, fg_mask) + + loss[0] *= 7.5 # box gain + loss[1] *= 0.5 # cls gain + loss[2] *= 1.5 # dfl gain + + return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) + + @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name) def train(cfg): - cfg.model = cfg.model or "models/yolov5n.yaml" + cfg.model = cfg.model or "models/yolov8n.yaml" cfg.data = cfg.data or "coco128.yaml" # or yolo.ClassificationDataset("mnist") + cfg.imgsz = 160 + cfg.epochs = 5 trainer = DetectionTrainer(cfg) trainer.train() @@ -223,9 +182,9 @@ def train(cfg): if __name__ == "__main__": """ CLI usage: - python ultralytics/yolo/v8/detect/train.py model=yolov5n.yaml data=coco128 epochs=100 imgsz=640 + python ultralytics/yolo/v8/detect/train.py model=yolov8n.yaml data=coco128 epochs=100 imgsz=640 TODO: - yolo task=detect mode=train model=yolov5n.yaml data=coco128.yaml epochs=100 + yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100 """ train() diff --git a/ultralytics/yolo/v8/detect/val.py b/ultralytics/yolo/v8/detect/val.py index 4feace6..63bfe2f 100644 --- a/ultralytics/yolo/v8/detect/val.py +++ b/ultralytics/yolo/v8/detect/val.py @@ -3,7 +3,6 @@ import os import hydra import numpy as np import torch -import torch.nn.functional as F from ultralytics.yolo.data import build_dataloader from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG diff --git a/ultralytics/yolo/v8/models/yolov5n-seg.yaml b/ultralytics/yolo/v8/models/yolov5n-seg.yaml deleted file mode 100644 index c28225a..0000000 --- a/ultralytics/yolo/v8/models/yolov5n-seg.yaml +++ /dev/null @@ -1,48 +0,0 @@ -# YOLOv5 🚀 by Ultralytics, GPL-3.0 license - -# Parameters -nc: 80 # number of classes -depth_multiple: 0.33 # model depth multiple -width_multiple: 0.25 # layer channel multiple -anchors: - - [10,13, 16,30, 33,23] # P3/8 - - [30,61, 62,45, 59,119] # P4/16 - - [116,90, 156,198, 373,326] # P5/32 - -# YOLOv5 v6.0 backbone -backbone: - # [from, number, module, args] - [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 - [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 - [-1, 3, C3, [128]], - [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 - [-1, 6, C3, [256]], - [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 - [-1, 9, C3, [512]], - [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 - [-1, 3, C3, [1024]], - [-1, 1, SPPF, [1024, 5]], # 9 - ] - -# YOLOv5 v6.0 head -head: - [[-1, 1, Conv, [512, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 6], 1, Concat, [1]], # cat backbone P4 - [-1, 3, C3, [512, False]], # 13 - - [-1, 1, Conv, [256, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 4], 1, Concat, [1]], # cat backbone P3 - [-1, 3, C3, [256, False]], # 17 (P3/8-small) - - [-1, 1, Conv, [256, 3, 2]], - [[-1, 14], 1, Concat, [1]], # cat head P4 - [-1, 3, C3, [512, False]], # 20 (P4/16-medium) - - [-1, 1, Conv, [512, 3, 2]], - [[-1, 10], 1, Concat, [1]], # cat head P5 - [-1, 3, C3, [1024, False]], # 23 (P5/32-large) - - [[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5) - ] diff --git a/ultralytics/yolo/v8/models/yolov5n.yaml b/ultralytics/yolo/v8/models/yolov5n.yaml deleted file mode 100644 index 8a28a40..0000000 --- a/ultralytics/yolo/v8/models/yolov5n.yaml +++ /dev/null @@ -1,48 +0,0 @@ -# YOLOv5 🚀 by Ultralytics, GPL-3.0 license - -# Parameters -nc: 80 # number of classes -depth_multiple: 0.33 # model depth multiple -width_multiple: 0.25 # layer channel multiple -anchors: - - [10,13, 16,30, 33,23] # P3/8 - - [30,61, 62,45, 59,119] # P4/16 - - [116,90, 156,198, 373,326] # P5/32 - -# YOLOv5 v6.0 backbone -backbone: - # [from, number, module, args] - [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 - [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 - [-1, 3, C3, [128]], - [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 - [-1, 6, C3, [256]], - [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 - [-1, 9, C3, [512]], - [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 - [-1, 3, C3, [1024]], - [-1, 1, SPPF, [1024, 5]], # 9 - ] - -# YOLOv5 v6.0 head -head: - [[-1, 1, Conv, [512, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 6], 1, Concat, [1]], # cat backbone P4 - [-1, 3, C3, [512, False]], # 13 - - [-1, 1, Conv, [256, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 4], 1, Concat, [1]], # cat backbone P3 - [-1, 3, C3, [256, False]], # 17 (P3/8-small) - - [-1, 1, Conv, [256, 3, 2]], - [[-1, 14], 1, Concat, [1]], # cat head P4 - [-1, 3, C3, [512, False]], # 20 (P4/16-medium) - - [-1, 1, Conv, [512, 3, 2]], - [[-1, 10], 1, Concat, [1]], # cat head P5 - [-1, 3, C3, [1024, False]], # 23 (P5/32-large) - - [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) - ] diff --git a/ultralytics/yolo/v8/models/yolov8n-seg.yaml b/ultralytics/yolo/v8/models/yolov8n-seg.yaml new file mode 100644 index 0000000..18f877e --- /dev/null +++ b/ultralytics/yolo/v8/models/yolov8n-seg.yaml @@ -0,0 +1,43 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.25 # layer channel multiple +anchors: [[16,19], [55,65], [178,192]] + +# YOLOv8n v0.0 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 3, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C2f, [128, True]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C2f, [256, True]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 6, C2f, [512, True]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C2f, [1024, True]], + [-1, 1, SPPF, [1024, 5]], # 9 + ] + +# YOLOv8n v0.0 head +head: + [[-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C2f, [512]], # 13 + + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C2f, [256]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 12], 1, Concat, [1]], # cat head P4 + [-1, 3, C2f, [512]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 9], 1, Concat, [1]], # cat head P5 + [-1, 3, C2f, [1024]], # 23 (P5/32-large) + + [[15, 18, 21], 1, Segment, [nc, 32, 256]], # Detect(P3, P4, P5) + ] diff --git a/ultralytics/yolo/v8/models/yolov8n.yaml b/ultralytics/yolo/v8/models/yolov8n.yaml new file mode 100644 index 0000000..ba0fcee --- /dev/null +++ b/ultralytics/yolo/v8/models/yolov8n.yaml @@ -0,0 +1,42 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.25 # layer channel multiple + +# YOLOv8.0n backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 3, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C2f, [128, True]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C2f, [256, True]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 6, C2f, [512, True]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C2f, [1024, True]], + [-1, 1, SPPF, [1024, 5]], # 9 + ] + +# YOLOv8.0n head +head: + [[-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C2f, [512]], # 13 + + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C2f, [256]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 12], 1, Concat, [1]], # cat head P4 + [-1, 3, C2f, [512]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 9], 1, Concat, [1]], # cat head P5 + [-1, 3, C2f, [1024]], # 23 (P5/32-large) + + [[15, 18, 21], 1, Detect, [nc]], # Detect(P3, P4, P5) + ] diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py index df2b1b2..5dd0f59 100644 --- a/ultralytics/yolo/v8/segment/train.py +++ b/ultralytics/yolo/v8/segment/train.py @@ -18,10 +18,7 @@ from ..detect import DetectionTrainer class SegmentationTrainer(DetectionTrainer): def load_model(self, model_cfg=None, weights=None): - model = SegmentationModel(model_cfg or weights["model"].yaml, - ch=3, - nc=self.data["nc"], - anchors=self.args.get("anchors")) + model = SegmentationModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"]) if weights: model.load(weights) for _, v in model.named_parameters(): @@ -29,7 +26,7 @@ class SegmentationTrainer(DetectionTrainer): return model def get_validator(self): - self.loss_names = 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss' + self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss' return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, logger=self.console, @@ -235,7 +232,7 @@ class SegmentationTrainer(DetectionTrainer): @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name) def train(cfg): - cfg.model = cfg.model or "models/yolov5n-seg.yaml" + cfg.model = cfg.model or "models/yolov8n-seg.yaml" cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist") trainer = SegmentationTrainer(cfg) trainer.train() @@ -244,7 +241,7 @@ def train(cfg): if __name__ == "__main__": """ CLI usage: - python ultralytics/yolo/v8/segment/train.py model=yolov5n-seg.yaml data=coco128-segments epochs=100 imgsz=640 + python ultralytics/yolo/v8/segment/train.py model=yolov8n-seg.yaml data=coco128-segments epochs=100 imgsz=640 TODO: Direct cli support, i.e, yolov8 classify_train args.epochs 10