diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 0a768a5..0ab515e 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -90,15 +90,16 @@ jobs: - name: Test detection shell: bash # for Windows compatibility run: | - echo "TODO" + yolo task=detect mode=train model=yolov5n.yaml data=coco128.yaml epochs=1 img_size=64 + yolo task=detect mode=val model=runs/exp/weights/last.pt img_size=64 - name: Test segmentation shell: bash # for Windows compatibility # TODO: redo val test without hardcoded weights run: | yolo task=segment mode=train model=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 img_size=64 - yolo task=segment mode=val model=runs/exp/weights/last.pt data=coco128-seg.yaml img_size=64 + yolo task=segment mode=val model=runs/exp2/weights/last.pt data=coco128-seg.yaml img_size=64 - name: Test classification shell: bash # for Windows compatibility run: | yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 img_size=32 - yolo task=classify mode=val model=runs/exp2/weights/last.pt data=mnist160 \ No newline at end of file + yolo task=classify mode=val model=runs/exp3/weights/last.pt data=mnist160 diff --git a/ultralytics/yolo/utils/metrics.py b/ultralytics/yolo/utils/metrics.py index 1e05843..1f80b92 100644 --- a/ultralytics/yolo/utils/metrics.py +++ b/ultralytics/yolo/utils/metrics.py @@ -459,14 +459,14 @@ def ap_per_class_box_and_mask( "boxes": { "p": results_boxes[0], "r": results_boxes[1], - "ap": results_boxes[3], "f1": results_boxes[2], + "ap": results_boxes[3], "ap_class": results_boxes[4]}, "masks": { "p": results_masks[0], "r": results_masks[1], - "ap": results_masks[3], "f1": results_masks[2], + "ap": results_masks[3], "ap_class": results_masks[4]}} return results @@ -547,7 +547,7 @@ class Metric: Args: results: tuple(p, r, ap, f1, ap_class) """ - p, r, all_ap, f1, ap_class_index = results + p, r, f1, all_ap, ap_class_index = results self.p = p self.r = r self.all_ap = all_ap diff --git a/ultralytics/yolo/utils/plotting.py b/ultralytics/yolo/utils/plotting.py index 2bad281..cdd2fd2 100644 --- a/ultralytics/yolo/utils/plotting.py +++ b/ultralytics/yolo/utils/plotting.py @@ -186,7 +186,15 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, @threaded -def plot_images_and_masks(images, batch_idx, cls, bboxes, masks, paths, confs=None, fname='images.jpg', names=None): +def plot_images_and_masks(images, + batch_idx, + cls, + bboxes, + masks, + confs=None, + paths=None, + fname='images.jpg', + names=None): # Plot image grid with labels if isinstance(images, torch.Tensor): images = images.cpu().float().numpy() @@ -327,3 +335,99 @@ def output_to_target(output, max_det=300): targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1)) targets = torch.cat(targets, 0).numpy() return targets[:, 0], targets[:, 1], targets[:, 2:6], targets[:, 6] + + +@threaded +def plot_images(images, batch_idx, cls, bboxes, confs=None, paths=None, fname='images.jpg', names=None): + # Plot image grid with labels + if isinstance(images, torch.Tensor): + images = images.cpu().float().numpy() + if isinstance(cls, torch.Tensor): + cls = cls.cpu().numpy() + if isinstance(bboxes, torch.Tensor): + bboxes = bboxes.cpu().numpy() + if isinstance(batch_idx, torch.Tensor): + batch_idx = batch_idx.cpu().numpy() + + max_size = 1920 # max image size + max_subplots = 16 # max image subplots, i.e. 
4x4 + bs, _, h, w = images.shape # batch size, _, height, width + bs = min(bs, max_subplots) # limit plot images + ns = np.ceil(bs ** 0.5) # number of subplots (square) + if np.max(images[0]) <= 1: + images *= 255 # de-normalise (optional) + + # Build Image + mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init + for i, im in enumerate(images): + if i == max_subplots: # if last batch has fewer images than we expect + break + x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin + im = im.transpose(1, 2, 0) + mosaic[y:y + h, x:x + w, :] = im + + # Resize (optional) + scale = max_size / ns / max(h, w) + if scale < 1: + h = math.ceil(scale * h) + w = math.ceil(scale * w) + mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h))) + + # Annotate + fs = int((h + w) * ns * 0.01) # font size + annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names) + for i in range(i + 1): + x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin + annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders + if paths: + annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames + if len(cls) > 0: + idx = batch_idx == i + + boxes = xywh2xyxy(bboxes[idx]).T + classes = cls[idx].astype('int') + labels = confs is None # labels if no conf column + conf = None if labels else confs[idx] # check for confidence presence (label vs pred) + + if boxes.shape[1]: + if boxes.max() <= 1.01: # if normalized with tolerance 0.01 + boxes[[0, 2]] *= w # scale to pixels + boxes[[1, 3]] *= h + elif scale < 1: # absolute coords need scale if image scales + boxes *= scale + boxes[[0, 2]] += x + boxes[[1, 3]] += y + for j, box in enumerate(boxes.T.tolist()): + c = classes[j] + color = colors(c) + c = names[c] if names else c + if labels or conf[j] > 0.25: # 0.25 conf thresh + label = f'{c}' if labels else f'{c} {conf[j]:.1f}' + annotator.box_label(box, label, color=color) + annotator.im.save(fname) # save + + +def plot_results(file='path/to/results.csv', dir=''): + # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv') + save_dir = Path(file).parent if file else Path(dir) + fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True) + ax = ax.ravel() + files = list(save_dir.glob('results*.csv')) + assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.' 
+ for f in files: + try: + data = pd.read_csv(f) + s = [x.strip() for x in data.columns] + x = data.values[:, 0] + for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]): + y = data.values[:, j].astype('float') + # y[y == 0] = np.nan # don't show zero values + ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8) + ax[i].set_title(s[j], fontsize=12) + # if j in [8, 9, 10]: # share train and val loss y axes + # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5]) + except Exception as e: + print(f'Warning: Plotting error for {f}: {e}') + ax[1].legend() + fig.savefig(save_dir / 'results.png', dpi=200) + plt.close() diff --git a/ultralytics/yolo/v8/__init__.py b/ultralytics/yolo/v8/__init__.py index a18b41a..cec773e 100644 --- a/ultralytics/yolo/v8/__init__.py +++ b/ultralytics/yolo/v8/__init__.py @@ -1,7 +1,7 @@ from pathlib import Path -from ultralytics.yolo.v8 import classify, segment +from ultralytics.yolo.v8 import classify, detect, segment ROOT = Path(__file__).parents[0] # yolov8 ROOT -__all__ = ["classify", "segment"] +__all__ = ["classify", "segment", "detect"] diff --git a/ultralytics/yolo/v8/detect/__init__.py b/ultralytics/yolo/v8/detect/__init__.py new file mode 100644 index 0000000..edce22a --- /dev/null +++ b/ultralytics/yolo/v8/detect/__init__.py @@ -0,0 +1,2 @@ +from ultralytics.yolo.v8.detect.train import DetectionTrainer, train +from ultralytics.yolo.v8.detect.val import DetectionValidator, val diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py new file mode 100644 index 0000000..22e9e36 --- /dev/null +++ b/ultralytics/yolo/v8/detect/train.py @@ -0,0 +1,209 @@ +import hydra +import torch +import torch.nn as nn + +from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG +from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE +from ultralytics.yolo.utils.modeling.tasks import DetectionModel +from ultralytics.yolo.utils.plotting import plot_images, plot_results +from ultralytics.yolo.utils.torch_utils import de_parallel + +from ..segment import SegmentationTrainer +from .val import DetectionValidator + + +# BaseTrainer python usage +class DetectionTrainer(SegmentationTrainer): + + def load_model(self, model_cfg, weights, data): + model = DetectionModel(model_cfg or weights["model"].yaml, + ch=3, + nc=data["nc"], + anchors=self.args.get("anchors")) + if weights: + model.load(weights) + for _, v in model.named_parameters(): + v.requires_grad = True # train all layers + return model + + def get_validator(self): + return DetectionValidator(self.test_loader, save_dir=self.save_dir, logger=self.console, args=self.args) + + def criterion(self, preds, batch): + head = de_parallel(self.model).model[-1] + sort_obj_iou = False + autobalance = False + + # init losses + BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.args.cls_pw], device=self.device)) + BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.args.obj_pw], device=self.device)) + + # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3 + cp, cn = smooth_BCE(eps=self.args.label_smoothing) # positive, negative BCE targets + + # Focal loss + g = self.args.fl_gamma + if self.args.fl_gamma > 0: + BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) + + balance = {3: [4.0, 1.0, 0.4]}.get(head.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7 + ssi = list(head.stride).index(16) if autobalance else 0 # stride 16 index + BCEcls, BCEobj, gr, autobalance = BCEcls, BCEobj, 1.0, autobalance + + def build_targets(p, targets): + # Build targets 
for compute_loss(), input targets(image,class,x,y,w,h) + nonlocal head + na, nt = head.na, targets.shape[0] # number of anchors, targets + tcls, tbox, indices, anch = [], [], [], [] + gain = torch.ones(7, device=self.device) # normalized to gridspace gain + ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt) + targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None]), 2) # append anchor indices + + g = 0.5 # bias + off = torch.tensor( + [ + [0, 0], + [1, 0], + [0, 1], + [-1, 0], + [0, -1], # j,k,l,m + # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm + ], + device=self.device).float() * g # offsets + + for i in range(head.nl): + anchors, shape = head.anchors[i], p[i].shape + gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]] # xyxy gain + + # Match targets to anchors + t = targets * gain # shape(3,n,7) + if nt: + # Matches + r = t[..., 4:6] / anchors[:, None] # wh ratio + j = torch.max(r, 1 / r).max(2)[0] < self.args.anchor_t # compare + # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2)) + t = t[j] # filter + + # Offsets + gxy = t[:, 2:4] # grid xy + gxi = gain[[2, 3]] - gxy # inverse + j, k = ((gxy % 1 < g) & (gxy > 1)).T + l, m = ((gxi % 1 < g) & (gxi > 1)).T + j = torch.stack((torch.ones_like(j), j, k, l, m)) + t = t.repeat((5, 1, 1))[j] + offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j] + else: + t = targets[0] + offsets = 0 + + # Define + bc, gxy, gwh, a = t.chunk(4, 1) # (image, class), grid xy, grid wh, anchors + a, (b, c) = a.long().view(-1), bc.long().T # anchors, image, class + gij = (gxy - offsets).long() + gi, gj = gij.T # grid indices + + # Append + indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1))) # image, anchor, grid + tbox.append(torch.cat((gxy - gij, gwh), 1)) # box + anch.append(anchors[a]) # anchors + tcls.append(c) # class + + return tcls, tbox, indices, anch + + if len(preds) == 2: # eval + _, p = preds + else: # len(3) train + p = preds + + targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) + targets = targets.to(self.device) + + lcls = torch.zeros(1, device=self.device) + lbox = torch.zeros(1, device=self.device) + lobj = torch.zeros(1, device=self.device) + tcls, tbox, indices, anchors = build_targets(p, targets) + + # Losses + for i, pi in enumerate(p): # layer index, layer predictions + b, a, gj, gi = indices[i] # image, anchor, gridy, gridx + tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device) # target obj + bs = tobj.shape[0] + n = b.shape[0] # number of targets + if n: + pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, head.nc), 1) # subset of predictions + + # Box regression + pxy = pxy.sigmoid() * 2 - 0.5 + pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i] + pbox = torch.cat((pxy, pwh), 1) # predicted box + iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze() # iou(prediction, target) + lbox += (1.0 - iou).mean() # iou loss + + # Objectness + iou = iou.detach().clamp(0).type(tobj.dtype) + if sort_obj_iou: + j = iou.argsort() + b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j] + if gr < 1: + iou = (1.0 - gr) + gr * iou + tobj[b, a, gj, gi] = iou # iou ratio + + # Classification + if head.nc > 1: # cls loss (only if multiple classes) + t = torch.full_like(pcls, cn, device=self.device) # targets + t[range(n), tcls[i]] = cp + lcls += BCEcls(pcls, t) # BCE + + obji = BCEobj(pi[..., 4], tobj) + lobj += obji * balance[i] # obj loss + if autobalance: + balance[i] = balance[i] * 0.9999 + 0.0001 / 
obji.detach().item() + + if autobalance: + balance = [x / balance[ssi] for x in balance] + lbox *= self.args.box + lobj *= self.args.obj + lcls *= self.args.cls + + loss = lbox + lobj + lcls + return loss * bs, torch.cat((lbox, lobj, lcls)).detach() + + # TODO: improve from API users perspective + def label_loss_items(self, loss_items=None, prefix="train"): + # We should just use named tensors here in future + keys = [f"{prefix}/lbox", f"{prefix}/lobj", f"{prefix}/lcls"] + return dict(zip(keys, loss_items)) if loss_items is not None else keys + + def progress_string(self): + return ('\n' + '%11s' * 6) % \ + ('Epoch', 'GPU_mem', 'box_loss', 'obj_loss', 'cls_loss', 'Size') + + def plot_training_samples(self, batch, ni): + images = batch["img"] + cls = batch["cls"].squeeze(-1) + bboxes = batch["bboxes"] + paths = batch["im_file"] + batch_idx = batch["batch_idx"] + plot_images(images, batch_idx, cls, bboxes, paths=paths, fname=self.save_dir / f"train_batch{ni}.jpg") + + def plot_metrics(self): + plot_results(file=self.csv) # save results.png + + +@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name) +def train(cfg): + cfg.model = cfg.model or "models/yolov5n.yaml" + cfg.data = cfg.data or "coco128.yaml" # or yolo.ClassificationDataset("mnist") + trainer = DetectionTrainer(cfg) + trainer.train() + + +if __name__ == "__main__": + """ + CLI usage: + python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-segments epochs=100 img_size=640 + + TODO: + Direct cli support, i.e, yolov8 classify_train args.epochs 10 + """ + train() diff --git a/ultralytics/yolo/v8/detect/val.py b/ultralytics/yolo/v8/detect/val.py new file mode 100644 index 0000000..63e9e46 --- /dev/null +++ b/ultralytics/yolo/v8/detect/val.py @@ -0,0 +1,218 @@ +import os + +import hydra +import numpy as np +import torch +import torch.nn.functional as F + +from ultralytics.yolo.data import build_dataloader +from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG +from ultralytics.yolo.engine.validator import BaseValidator +from ultralytics.yolo.utils import ops +from ultralytics.yolo.utils.checks import check_file, check_requirements +from ultralytics.yolo.utils.files import yaml_load +from ultralytics.yolo.utils.metrics import ConfusionMatrix, Metric, ap_per_class, box_iou, fitness_detection +from ultralytics.yolo.utils.plotting import output_to_target, plot_images +from ultralytics.yolo.utils.torch_utils import de_parallel + + +class DetectionValidator(BaseValidator): + + def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None): + super().__init__(dataloader, save_dir, pbar, logger, args) + if self.args.save_json: + check_requirements(['pycocotools']) + self.process = ops.process_mask_upsample # more accurate + else: + self.process = ops.process_mask # faster + self.data_dict = yaml_load(check_file(self.args.data)) if self.args.data else None + self.is_coco = False + self.class_map = None + self.targets = None + + def preprocess(self, batch): + batch["img"] = batch["img"].to(self.device, non_blocking=True) + batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255 + self.nb, _, self.height, self.width = batch["img"].shape # batch size, channels, height, width + self.targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) + self.targets = self.targets.to(self.device) + height, width = batch["img"].shape[2:] + self.targets[:, 2:] *= torch.tensor((width, height, width, 
height), device=self.device) # to pixels + self.lb = [self.targets[self.targets[:, 0] == i, 1:] + for i in range(self.nb)] if self.args.save_hybrid else [] # for autolabelling + + return batch + + def init_metrics(self, model): + if self.training: + head = de_parallel(model).model[-1] + else: + head = de_parallel(model).model.model[-1] + + if self.data: + self.is_coco = isinstance(self.data.get('val'), + str) and self.data['val'].endswith(f'coco{os.sep}val2017.txt') + self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000)) + self.nc = head.nc + self.names = model.names + if isinstance(self.names, (list, tuple)): # old format + self.names = dict(enumerate(self.names)) + + self.iouv = torch.linspace(0.5, 0.95, 10, device=self.device) # iou vector for mAP@0.5:0.95 + self.niou = self.iouv.numel() + self.seen = 0 + self.confusion_matrix = ConfusionMatrix(nc=self.nc) + self.metrics = Metric() + self.loss = torch.zeros(4, device=self.device) + self.jdict = [] + self.stats = [] + + def get_desc(self): + return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)") + + def postprocess(self, preds): + preds = ops.non_max_suppression(preds, + self.args.conf_thres, + self.args.iou_thres, + labels=self.lb, + multi_label=True, + agnostic=self.args.single_cls, + max_det=self.args.max_det) + return preds + + def update_metrics(self, preds, batch): + # Metrics + for si, (pred) in enumerate(preds): + labels = self.targets[self.targets[:, 0] == si, 1:] + nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions + shape = batch["ori_shape"][si] + # path = batch["shape"][si][0] + correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init + self.seen += 1 + + if npr == 0: + if nl: + self.stats.append((correct_bboxes, *torch.zeros((2, 0), device=self.device), labels[:, 0])) + if self.args.plots: + self.confusion_matrix.process_batch(detections=None, labels=labels[:, 0]) + continue + + # Predictions + if self.args.single_cls: + pred[:, 5] = 0 + predn = pred.clone() + ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape) # native-space pred + + # Evaluate + if nl: + tbox = ops.xywh2xyxy(labels[:, 1:5]) # target boxes + ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape) # native-space labels + labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels + correct_bboxes = self._process_batch(predn, labelsn, self.iouv) + # TODO: maybe remove these `self.` arguments as they already are member variable + if self.args.plots: + self.confusion_matrix.process_batch(predn, labelsn) + self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], labels[:, 0])) # (conf, pcls, tcls) + + # TODO: Save/log + ''' + if self.args.save_txt: + save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') + if self.args.save_json: + pred_masks = scale_image(im[si].shape[1:], + pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), shape, shapes[si][1]) + save_one_json(predn, jdict, path, class_map, pred_masks) # append to COCO-JSON dictionary + # callbacks.run('on_val_image_end', pred, predn, path, names, im[si]) + ''' + + def get_stats(self): + stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy + if len(stats) and stats[0].any(): + results = ap_per_class(*stats, plot=self.args.plots, save_dir=self.save_dir, names=self.names) + self.metrics.update(results[2:]) + self.nt_per_class = np.bincount(stats[3].astype(int), minlength=self.nc) # number of targets 
per class + metrics = {"fitness": fitness_detection(np.array(self.metrics.mean_results()).reshape(1, -1))} + metrics |= zip(self.metric_keys, self.metrics.mean_results()) + return metrics + + def print_results(self): + pf = '%22s' + '%11i' * 2 + '%11.3g' * 4 # print format + self.logger.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results())) + if self.nt_per_class.sum() == 0: + self.logger.warning( + f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels') + + # Print results per class + if (self.args.verbose or (self.nc < 50 and not self.training)) and self.nc > 1 and len(self.stats): + for i, c in enumerate(self.metrics.ap_class_index): + self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i))) + + if self.args.plots: + self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values())) + + def _process_batch(self, detections, labels, iouv): + """ + Return correct prediction matrix + Arguments: + detections (array[N, 6]), x1, y1, x2, y2, conf, class + labels (array[M, 5]), class, x1, y1, x2, y2 + Returns: + correct (array[N, 10]), for 10 IoU levels + """ + iou = box_iou(labels[:, 1:], detections[:, :4]) + correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool) + correct_class = labels[:, 0:1] == detections[:, 5] + for i in range(len(iouv)): + x = torch.where((iou >= iouv[i]) & correct_class) # IoU > threshold and classes match + if x[0].shape[0]: + matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), + 1).cpu().numpy() # [label, detect, iou] + if x[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + # matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + correct[matches[:, 1].astype(int), i] = True + return torch.tensor(correct, dtype=torch.bool, device=iouv.device) + + def get_dataloader(self, dataset_path, batch_size): + # TODO: manage splits differently + # calculate stride - check if model is initialized + gs = max(int(de_parallel(self.model).stride if self.model else 0), 32) + return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, mode="val")[0] + + # TODO: align with train loss metrics + @property + def metric_keys(self): + return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP_0.5(B)", "metrics/mAP_0.5:0.95(B)"] + + def plot_val_samples(self, batch, ni): + images = batch["img"] + cls = batch["cls"].squeeze(-1) + bboxes = batch["bboxes"] + paths = batch["im_file"] + batch_idx = batch["batch_idx"] + plot_images(images, + batch_idx, + cls, + bboxes, + paths=paths, + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + names=self.names) + + def plot_predictions(self, batch, preds, ni): + images = batch["img"] + paths = batch["im_file"] + plot_images(images, *output_to_target(preds, max_det=15), paths, self.save_dir / f'val_batch{ni}_pred.jpg', + self.names) # pred + + +@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name) +def val(cfg): + cfg.data = cfg.data or "coco128.yaml" + validator = DetectionValidator(args=cfg) + validator(model=cfg.model) + + +if __name__ == "__main__": + val() diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py index 5a4af69..a5481d2 100644 --- a/ultralytics/yolo/v8/segment/train.py +++ b/ultralytics/yolo/v8/segment/train.py @@ -250,7 
+250,7 @@ class SegmentationTrainer(BaseTrainer): cls, bboxes, masks, - paths, + paths=paths, fname=self.save_dir / f"train_batch{ni}.jpg") def plot_metrics(self): diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py index 18dead9..3784fd3 100644 --- a/ultralytics/yolo/v8/segment/val.py +++ b/ultralytics/yolo/v8/segment/val.py @@ -252,7 +252,7 @@ class SegmentationValidator(BaseValidator): if len(self.plot_masks): plot_masks = torch.cat(self.plot_masks, dim=0) batch_idx, cls, bboxes, conf = output_to_target(preds[0], max_det=15) - plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, paths, conf, + plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, conf, paths, self.save_dir / f'val_batch{ni}_pred.jpg', self.names) # pred self.plot_masks.clear()
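
Note on the CI hunk at the top: the detect/segment/classify val steps still hard-code runs/exp, runs/exp2 and runs/exp3 (the in-file TODO calls this out). A minimal sketch, not part of the diff, of one way a follow-up could resolve the latest run directory instead; it assumes the default runs/exp*/weights/last.pt layout produced by the commands above:

    from pathlib import Path

    # pick the most recently modified experiment directory, if any exist
    runs = sorted(Path("runs").glob("exp*"), key=lambda p: p.stat().st_mtime)
    if runs:
        last_weights = runs[-1] / "weights" / "last.pt"
        print(last_weights)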
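
On the metrics.py hunks: ap_per_class produces results in (..., p, r, f1, ap, unique_classes) order, so the diff moves "f1" ahead of "ap" in the results dicts and changes Metric.update to unpack (p, r, f1, all_ap, ap_class_index); note the docstring shown as context still lists the old (p, r, ap, f1, ap_class) order. A toy sketch (illustrative names and values, not from the diff) of why the consumer's unpacking order must track the producer's, since a swap of two float arrays would not fail at runtime:

    import numpy as np

    def toy_ap_per_class_results():
        # illustrative values in the (p, r, f1, ap, ap_class) order assumed by the diff
        p = np.array([0.8])
        r = np.array([0.7])
        f1 = np.array([0.74])
        ap = np.full((1, 10), 0.6)      # AP for one class at 10 IoU thresholds
        ap_class = np.array([0])
        return p, r, f1, ap, ap_class

    class MetricSketch:
        def update(self, results):
            # unpack in the producer's order; swapping f1 and all_ap here would silently
            # corrupt mAP because both are float arrays
            self.p, self.r, self.f1, self.all_ap, self.ap_class_index = results

    m = MetricSketch()
    m.update(toy_ap_per_class_results())
    print(m.all_ap.mean())              # 0.6 -> toy mAP@0.5:0.95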
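
On the plotting.py signature changes: in the new plot_images and plot_images_and_masks signatures, confs now precedes paths and both default to None, which is why the segment train/val call sites at the bottom of the patch switch to paths=paths and reorder conf before paths. A small usage sketch of the detection variant; the batch shapes and file names below are made up:

    import torch

    from ultralytics.yolo.utils.plotting import plot_images  # added in this diff

    images = torch.rand(2, 3, 64, 64)                 # BCHW batch in [0, 1]
    batch_idx = torch.tensor([0, 0, 1])               # image index for each box
    cls = torch.tensor([0.0, 1.0, 0.0])
    bboxes = torch.tensor([[0.50, 0.50, 0.20, 0.30],  # normalized xywh
                           [0.30, 0.30, 0.10, 0.10],
                           [0.60, 0.40, 0.20, 0.20]])

    # ground-truth style plot: no confs passed, so boxes are drawn as labels without scores
    plot_images(images, batch_idx, cls, bboxes, paths=["a.jpg", "b.jpg"], fname="train_batch_demo.jpg")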
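
On build_targets in detect/train.py: a target is kept for an anchor only when the worst width/height ratio between them stays below args.anchor_t (typically 4.0 in YOLOv5-style hyperparameters). A tiny numeric illustration with made-up anchors and a single target, mirroring the torch.max(r, 1 / r) filter above:

    import torch

    anchors = torch.tensor([[10., 13.], [16., 30.], [33., 23.]])  # one detection layer (illustrative)
    target_wh = torch.tensor([3.5, 10.0])                         # one target box, in grid units
    r = target_wh / anchors                                       # per-anchor wh ratios
    worst = torch.max(r, 1 / r).max(1)[0]                         # worst-case ratio per anchor
    print(worst < 4.0)                                            # tensor([ True, False, False])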