From 3a241e4ceacdb2c8a39f32490f2b87fd29710131 Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Tue, 29 Nov 2022 05:30:08 -0600 Subject: [PATCH] update segment training (#57) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: ayush chaurasia --- ultralytics/yolo/data/augment.py | 6 +- ultralytics/yolo/data/base.py | 2 +- ultralytics/yolo/data/build.py | 54 +++---- ultralytics/yolo/engine/trainer.py | 115 ++++++++++----- ultralytics/yolo/engine/validator.py | 42 ++++-- ultralytics/yolo/utils/__init__.py | 11 ++ ultralytics/yolo/utils/configs/default.yaml | 31 ++-- ultralytics/yolo/utils/metrics.py | 47 +++++- ultralytics/yolo/utils/plotting.py | 150 +++++++++++++++++++- ultralytics/yolo/utils/torch_utils.py | 16 +++ ultralytics/yolo/v8/classify/train.py | 5 +- ultralytics/yolo/v8/classify/val.py | 4 + ultralytics/yolo/v8/segment/train.py | 49 ++++--- ultralytics/yolo/v8/segment/val.py | 72 +++++++--- 14 files changed, 460 insertions(+), 144 deletions(-) diff --git a/ultralytics/yolo/data/augment.py b/ultralytics/yolo/data/augment.py index c67b5c6..f30f630 100644 --- a/ultralytics/yolo/data/augment.py +++ b/ultralytics/yolo/data/augment.py @@ -578,8 +578,8 @@ class Albumentations: # TODO: add supports of segments and keypoints if self.transform and random.random() < self.p: new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed - labels["img"] = new["image"] - labels["cls"] = np.array(new["class_labels"]) + labels["img"] = new["image"] + labels["cls"] = np.array(new["class_labels"]) labels["instances"].update(bboxes=bboxes) return labels @@ -635,7 +635,7 @@ class Format: def _format_img(self, img): if len(img.shape) < 3: img = np.expand_dims(img, -1) - img = np.ascontiguousarray(img.transpose(2, 0, 1)) + img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1]) img = torch.from_numpy(img) return img diff --git a/ultralytics/yolo/data/base.py b/ultralytics/yolo/data/base.py index 32b9344..9a11f6e 100644 --- a/ultralytics/yolo/data/base.py +++ b/ultralytics/yolo/data/base.py @@ -151,7 +151,7 @@ class BaseDataset(Dataset): bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index nb = bi[-1] + 1 # number of batches - s = np.array([x["shape"] for x in self.labels]) # hw + s = np.array([x.pop("shape") for x in self.labels]) # hw ar = s[:, 0] / s[:, 1] # aspect ratio irect = ar.argsort() self.im_files = [self.im_files[i] for i in irect] diff --git a/ultralytics/yolo/data/build.py b/ultralytics/yolo/data/build.py index 9bc4960..3f3b881 100644 --- a/ultralytics/yolo/data/build.py +++ b/ultralytics/yolo/data/build.py @@ -5,7 +5,7 @@ import numpy as np import torch from torch.utils.data import DataLoader, dataloader, distributed -from ..utils import LOGGER +from ..utils import LOGGER, colorstr from ..utils.torch_utils import torch_distributed_zero_first from .dataset import ClassificationDataset, YOLODataset from .utils import PIN_MEMORY, RANK @@ -52,53 +52,36 @@ def seed_worker(worker_id): random.seed(worker_seed) -# TODO: we can inject most args from a config file -def build_dataloader( - img_path, - img_size, # - batch_size, # - single_cls=False, # - hyp=None, # - augment=False, - cache=False, # - image_weights=False, # - stride=32, - label_path=None, - pad=0.0, - rect=False, - rank=-1, - workers=8, - prefix="", - shuffle=False, - use_segments=False, - use_keypoints=False, -): - if rect and shuffle: +def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank=-1, mode="train"): + assert mode in ["train", "val"] + shuffle = mode == "train" + if cfg.rect and shuffle: LOGGER.warning("WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False") shuffle = False with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP dataset = YOLODataset( img_path=img_path, - img_size=img_size, - batch_size=batch_size, label_path=label_path, - augment=augment, # augmentation - hyp=hyp, - rect=rect, # rectangular batches - cache=cache, - single_cls=single_cls, + img_size=cfg.img_size, + batch_size=batch_size, + augment=True if mode == "train" else False, # augmentation + hyp=cfg.get("augment_hyp", None), + rect=cfg.rect if mode == "train" else True, # rectangular batches + cache=None if cfg.noval else cfg.get("cache", None), + single_cls=cfg.get("single_cls", False), stride=int(stride), - pad=pad, - prefix=prefix, - use_segments=use_segments, - use_keypoints=use_keypoints, + pad=0.0 if mode == "train" else 0.5, + prefix=colorstr(f"{mode}: "), + use_segments=cfg.task == "segment", + use_keypoints=cfg.task == "keypoint", ) batch_size = min(batch_size, len(dataset)) nd = torch.cuda.device_count() # number of CUDA devices + workers = cfg.workers if mode == "train" else cfg.workers * 2 nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) - loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates + loader = DataLoader if cfg.image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates generator = torch.Generator() generator.manual_seed(6148914691236517205 + RANK) return ( @@ -118,6 +101,7 @@ def build_dataloader( # build classification +# TODO: using cfg like `build_dataloader` def build_classification_dataloader(path, imgsz=224, batch_size=16, diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index 8f1f7b8..8f31987 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -24,11 +24,11 @@ from tqdm import tqdm import ultralytics.yolo.utils as utils import ultralytics.yolo.utils.callbacks as callbacks from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml -from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT +from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr from ultralytics.yolo.utils.checks import print_args from ultralytics.yolo.utils.files import increment_path, save_yaml from ultralytics.yolo.utils.modeling import get_model -from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle +from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml" RANK = int(os.getenv('RANK', -1)) @@ -48,13 +48,15 @@ class BaseTrainer: self.wdir = self.save_dir / 'weights' # weights dir self.wdir.mkdir(parents=True, exist_ok=True) # make dir self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths + self.batch_size = self.args.batch_size + self.epochs = self.args.epochs print_args(dict(self.args)) # Save run settings save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) # device - self.device = utils.torch_utils.select_device(self.args.device, self.args.batch_size) + self.device = utils.torch_utils.select_device(self.args.device, self.batch_size) self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu') # Model and Dataloaders. @@ -73,10 +75,11 @@ class BaseTrainer: self.scheduler = None # epoch level metrics - self.metrics = {} # handle metrics returned by validator self.best_fitness = None self.fitness = None self.loss = None + self.tloss = None + self.csv = self.save_dir / 'results.csv' for callback, func in callbacks.default_callbacks.items(): self.add_callback(callback, func) @@ -122,6 +125,7 @@ class BaseTrainer: if world_size > 1: mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True) else: + # self._do_train(int(os.getenv("RANK", -1)), world_size) self._do_train() def _setup_ddp(self, rank, world_size): @@ -129,21 +133,20 @@ class BaseTrainer: os.environ['MASTER_PORT'] = '9020' torch.cuda.set_device(rank) self.device = torch.device('cuda', rank) - print(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ") + self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ") dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size) self.model = self.model.to(self.device) self.model = DDP(self.model, device_ids=[rank]) - self.args.batch_size = self.args.batch_size // world_size - def _setup_train(self, rank): + def _setup_train(self, rank, world_size): """ Builds dataloaders and optimizer on correct rank process """ # Optimizer self.set_model_attributes() - accumulate = max(round(self.args.nbs / self.args.batch_size), 1) # accumulate loss before optimizing - self.args.weight_decay *= self.args.batch_size * accumulate / self.args.nbs # scale weight_decay + self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing + self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs # scale weight_decay self.optimizer = build_optimizer(model=self.model, name=self.args.optimizer, lr=self.args.lr0, @@ -151,18 +154,21 @@ class BaseTrainer: decay=self.args.weight_decay) # Scheduler if self.args.cos_lr: - self.lf = one_cycle(1, self.args.lrf, self.args.epochs) # cosine 1->hyp['lrf'] + self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf'] else: - self.lf = lambda x: (1 - x / self.args.epochs) * (1.0 - self.args.lrf + self.args.lrf) # linear + self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf) # dataloaders - self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank) + batch_size = self.batch_size // world_size + self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode="train") if rank in {0, -1}: - print(" Creating testloader rank :", rank) - self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=-1) - self.validator = self.get_validator() - print("created testloader :", rank) + self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val") + validator = self.get_validator() + # init metric, for plot_results + metric_keys = validator.metric_keys + self.label_loss_items(prefix="val") + self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) + self.validator = validator self.ema = ModelEMA(self.model) def _do_train(self, rank=-1, world_size=1): @@ -172,7 +178,7 @@ class BaseTrainer: self.model = self.model.to(self.device) self.trigger_callbacks("before_train") - self._setup_train(rank) + self._setup_train(rank, world_size) self.epoch = 0 self.epoch_time = None @@ -181,13 +187,17 @@ class BaseTrainer: nb = len(self.train_loader) # number of batches nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations last_opt_step = -1 - for epoch in range(self.args.epochs): + for epoch in range(self.epochs): self.trigger_callbacks("on_epoch_start") self.model.train() + if rank != -1: + self.train_loader.sampler.set_epoch(epoch) pbar = enumerate(self.train_loader) if rank in {-1, 0}: + self.console.info(self.progress_string()) pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), bar_format=TQDM_BAR_FORMAT) self.tloss = None + self.optimizer.zero_grad() for i, batch in pbar: self.trigger_callbacks("on_batch_start") # forward @@ -197,7 +207,7 @@ class BaseTrainer: ni = i + nb * epoch if ni <= nw: xi = [0, nw] # x interp - accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.args.batch_size]).round()) + self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()) for j, x in enumerate(self.optimizer.param_groups): # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0 x['lr'] = np.interp( @@ -207,37 +217,47 @@ class BaseTrainer: preds = self.model(batch["img"]) self.loss, self.loss_items = self.criterion(preds, batch) + if rank != -1: + self.loss *= world_size self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \ else self.loss_items # backward - self.model.zero_grad(set_to_none=True) self.scaler.scale(self.loss).backward() # optimize - if ni - last_opt_step >= accumulate: + if ni - last_opt_step >= self.accumulate: self.optimizer_step() last_opt_step = ni # log - mem = (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB) + mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB) loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1 losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0) if rank in {-1, 0}: pbar.set_description( - (" {} " + "{:.3f} " * (1 + loss_len) + ' {} ').format(f'{epoch + 1}/{self.args.epochs}', mem, - *losses, batch["img"].shape[-1])) + ('%11s' * 2 + '%11.4g' * (2 + loss_len)) % + (f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])) self.trigger_callbacks('on_batch_end') + if self.args.plots and ni < 3: + self.plot_training_samples(batch, ni) + + lr = {f"lr{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers + self.scheduler.step() if rank in [-1, 0]: # validation self.trigger_callbacks('on_val_start') self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights']) - self.metrics, self.fitness = self.validate() + final_epoch = (epoch + 1 == self.epochs) + if not self.args.noval or final_epoch: + self.metrics, self.fitness = self.validate() self.trigger_callbacks('on_val_end') + log_vals = self.label_loss_items(self.tloss) | self.metrics | lr + self.save_metrics(metrics=log_vals) # save model - if (not self.args.nosave) or (self.epoch + 1 == self.args.epochs): + if (not self.args.nosave) or (self.epoch + 1 == self.epochs): self.save_model() self.trigger_callbacks('on_model_save') @@ -248,9 +268,15 @@ class BaseTrainer: # TODO: termination condition - self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours)") - self.trigger_callbacks('on_train_end') + if rank in [-1, 0]: + # do the last evaluation with best.pt + self.final_eval() + if self.args.plots: + self.plot_metrics() + self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours)") + self.trigger_callbacks('on_train_end') dist.destroy_process_group() if world_size != 1 else None + torch.cuda.empty_cache() def save_model(self): ckpt = { @@ -306,7 +332,7 @@ class BaseTrainer: "fitness" metric. """ metrics = self.validator(self) - fitness = metrics.get("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found + fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found if not self.best_fitness or self.best_fitness < fitness: self.best_fitness = self.fitness return metrics, fitness @@ -339,12 +365,12 @@ class BaseTrainer: """ raise NotImplementedError("criterion function not implemented in trainer") - def label_loss_items(self, loss_items): + def label_loss_items(self, loss_items=None, prefix="train"): """ Returns a loss dict with labelled training loss items tensor """ # Not needed for classification but necessary for segmentation & detection - return {"loss": loss_items} + return {"loss": loss_items} if loss_items is not None else ["loss"] def set_model_attributes(self): """ @@ -355,6 +381,31 @@ class BaseTrainer: def build_targets(self, preds, targets): pass + def progress_string(self): + return "" + + # TODO: may need to put these following functions into callback + def plot_training_samples(self, batch, ni): + pass + + def save_metrics(self, metrics): + keys, vals = list(metrics.keys()), list(metrics.values()) + n = len(metrics) + 1 # number of cols + s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header + with open(self.csv, 'a') as f: + f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n') + + def plot_metrics(self): + pass + + def final_eval(self): + # TODO: need standalone evaluator to do this + for f in self.last, self.best: + if f.exists(): + strip_optimizer(f) # strip optimizers + if f is self.best: + self.console.info(f'\nValidating {f}...') + def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5): # TODO: 1. docstring with example? 2. Move this inside Trainer? or utils? @@ -382,7 +433,7 @@ def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5): optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights) - LOGGER.info(f"optimizer: {type(optimizer).__name__}(lr={lr}) with parameter groups " + LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups " f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias") return optimizer diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index 24a840e..bc2fdf7 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -1,4 +1,5 @@ import logging +from pathlib import Path import torch from omegaconf import OmegaConf @@ -6,6 +7,7 @@ from tqdm import tqdm from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG from ultralytics.yolo.utils import TQDM_BAR_FORMAT +from ultralytics.yolo.utils.files import increment_path from ultralytics.yolo.utils.ops import Profile from ultralytics.yolo.utils.torch_utils import de_parallel, select_device @@ -15,16 +17,17 @@ class BaseValidator: Base validator class. """ - def __init__(self, dataloader, pbar=None, logger=None, args=None): + def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None): self.dataloader = dataloader self.pbar = pbar self.logger = logger or logging.getLogger() self.args = args or OmegaConf.load(DEFAULT_CONFIG) self.device = select_device(self.args.device, dataloader.batch_size) + self.save_dir = save_dir if save_dir is not None else \ + increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok) self.cuda = self.device.type != 'cpu' self.batch_i = None self.training = True - self.loss = None def __call__(self, trainer=None, model=None): """ @@ -35,20 +38,22 @@ class BaseValidator: if self.training: model = trainer.ema.ema or trainer.model self.args.half &= self.device.type != 'cpu' - # NOTE: half() inference in evaluation will make training stuck, - # so I comment it out for now, I think we can reuse half mode after we add EMA. model = model.half() if self.args.half else model.float() + loss = torch.zeros_like(trainer.loss_items, device=trainer.device) else: # TODO: handle this when detectMultiBackend is supported assert model is not None, "Either trainer or model is needed for validation" # model = DetectMultiBacked(model) # TODO: implement init_model_attributes() model.eval() + dt = Profile(), Profile(), Profile(), Profile() - self.loss = 0 n_batches = len(self.dataloader) desc = self.get_desc() - bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT) + # NOTE: keeping this `not self.training` in tqdm will eliminate pbar after finishing segmantation evaluation during training, + # so I removed it, not sure if this will affect classification task cause I saw we use this arg in yolov5/classify/val.py. + # bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT) + bar = tqdm(self.dataloader, desc, n_batches, bar_format=TQDM_BAR_FORMAT) self.init_metrics(de_parallel(model)) with torch.no_grad(): for batch_i, batch in enumerate(bar): @@ -59,20 +64,23 @@ class BaseValidator: # inference with dt[1]: - preds = model(batch["img"].float()) + preds = model(batch["img"]) # TODO: remember to add native augmentation support when implementing model, like: # preds, train_out = model(im, augment=augment) # loss with dt[2]: if self.training: - self.loss += trainer.criterion(preds, batch)[0] + loss += trainer.criterion(preds, batch)[1] # pre-process predictions with dt[3]: preds = self.postprocess(preds) self.update_metrics(preds, batch) + if self.args.plots and batch_i < 3: + self.plot_val_samples(batch, batch_i) + self.plot_predictions(batch, preds, batch_i) stats = self.get_stats() self.check_stats(stats) @@ -81,7 +89,7 @@ class BaseValidator: # print speeds if not self.training: - t = tuple(x.t / len(self.dataloader.dataset.samples) * 1E3 for x in dt) # speeds per image + t = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image # shape = (self.dataloader.batch_size, 3, imgsz, imgsz) self.logger.info( 'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image at shape ' % t) @@ -90,7 +98,8 @@ class BaseValidator: model.float() # TODO: implement save json - return stats + return stats | trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val") \ + if self.training else stats def preprocess(self, batch): return batch @@ -105,7 +114,7 @@ class BaseValidator: pass def get_stats(self): - pass + return {} def check_stats(self, stats): pass @@ -115,3 +124,14 @@ class BaseValidator: def get_desc(self): pass + + @property + def metric_keys(self): + return [] + + # TODO: may need to put these following functions into callback + def plot_val_samples(self, batch, ni): + pass + + def plot_predictions(self, batch, preds, ni): + pass diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py index 0216ec3..c171a48 100644 --- a/ultralytics/yolo/utils/__init__.py +++ b/ultralytics/yolo/utils/__init__.py @@ -3,6 +3,7 @@ import logging.config import os import platform import sys +import threading from pathlib import Path # Constants @@ -130,3 +131,13 @@ class TryExcept(contextlib.ContextDecorator): if value: print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}")) return True + + +def threaded(func): + # Multi-threads a target function and returns thread. Usage: @threaded decorator + def wrapper(*args, **kwargs): + thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True) + thread.start() + return thread + + return wrapper diff --git a/ultralytics/yolo/utils/configs/default.yaml b/ultralytics/yolo/utils/configs/default.yaml index 99c1c10..fe50a3d 100644 --- a/ultralytics/yolo/utils/configs/default.yaml +++ b/ultralytics/yolo/utils/configs/default.yaml @@ -26,11 +26,11 @@ deterministic: True local_rank: -1 single_cls: False # train multi-class data as single-class image_weights: False # use weighted image selection for training -shuffle: True rect: False # support rectangular training cos_lr: False # Use cosine LR scheduler overlap_mask: True # Segmentation masks overlap mask_ratio: 4 # Segmentation mask downsample ratio +noval: False # Val/Test settings ---------------------------------------------------------------------------------------------------- save_json: False @@ -43,7 +43,7 @@ plots: False save_txt: False # Hyperparameters ------------------------------------------------------------------------------------------------------ -lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3) +lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3) lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf) momentum: 0.937 # SGD momentum/Adam beta1 weight_decay: 0.0005 # optimizer weight decay 5e-4 @@ -59,22 +59,23 @@ iou_t: 0.20 # IoU training threshold anchor_t: 4.0 # anchor-multiple threshold # anchors: 3 # anchors per output layer (0 to ignore) fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5) -hsv_h: 0.015 # image HSV-Hue augmentation (fraction) -hsv_s: 0.7 # image HSV-Saturation augmentation (fraction) -hsv_v: 0.4 # image HSV-Value augmentation (fraction) -degrees: 0.0 # image rotation (+/- deg) -translate: 0.1 # image translation (+/- fraction) -scale: 0.5 # image scale (+/- gain) -shear: 0.0 # image shear (+/- deg) -perspective: 0.0 # image perspective (+/- fraction), range 0-0.001 -flipud: 0.0 # image flip up-down (probability) -fliplr: 0.5 # image flip left-right (probability) -mosaic: 1.0 # image mosaic (probability) -mixup: 0.0 # image mixup (probability) -copy_paste: 0.0 # segment copy-paste (probability) label_smoothing: 0.0 nbs: 64 # nominal batch size # anchors: 3 +augment_hyp: + hsv_h: 0.015 # image HSV-Hue augmentation (fraction) + hsv_s: 0.7 # image HSV-Saturation augmentation (fraction) + hsv_v: 0.4 # image HSV-Value augmentation (fraction) + degrees: 0.0 # image rotation (+/- deg) + translate: 0.1 # image translation (+/- fraction) + scale: 0.5 # image scale (+/- gain) + shear: 0.0 # image shear (+/- deg) + perspective: 0.0 # image perspective (+/- fraction), range 0-0.001 + flipud: 0.0 # image flip up-down (probability) + fliplr: 0.5 # image flip left-right (probability) + mosaic: 1.0 # image mosaic (probability) + mixup: 0.0 # image mixup (probability) + copy_paste: 0.0 # segment copy-paste (probability) # Hydra configs -------------------------------------------------------------------------------------------------------- hydra: diff --git a/ultralytics/yolo/utils/metrics.py b/ultralytics/yolo/utils/metrics.py index 62bdcc9..1e05843 100644 --- a/ultralytics/yolo/utils/metrics.py +++ b/ultralytics/yolo/utils/metrics.py @@ -283,6 +283,50 @@ def smooth(y, f=0.05): return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed +def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()): + # Precision-recall curve + fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) + py = np.stack(py, axis=1) + + if 0 < len(names) < 21: # display per-class legend if < 21 classes + for i, y in enumerate(py.T): + ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision) + else: + ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision) + + ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean()) + ax.set_xlabel('Recall') + ax.set_ylabel('Precision') + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left") + ax.set_title('Precision-Recall Curve') + fig.savefig(save_dir, dpi=250) + plt.close(fig) + + +def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'): + # Metric-confidence curve + fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) + + if 0 < len(names) < 21: # display per-class legend if < 21 classes + for i, y in enumerate(py): + ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric) + else: + ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric) + + y = smooth(py.mean(0), 0.05) + ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}') + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left") + ax.set_title(f'{ylabel}-Confidence Curve') + fig.savefig(save_dir, dpi=250) + plt.close(fig) + + def compute_ap(recall, precision): """ Compute the average precision, given the recall and precision curves # Arguments @@ -365,14 +409,11 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names f1 = 2 * p * r / (p + r + eps) names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data names = dict(enumerate(names)) # to dict - # TODO: plot - ''' if plot: plot_pr_curve(px, py, ap, Path(save_dir) / f'{prefix}PR_curve.png', names) plot_mc_curve(px, f1, Path(save_dir) / f'{prefix}F1_curve.png', names, ylabel='F1') plot_mc_curve(px, p, Path(save_dir) / f'{prefix}P_curve.png', names, ylabel='Precision') plot_mc_curve(px, r, Path(save_dir) / f'{prefix}R_curve.png', names, ylabel='Recall') - ''' i = smooth(f1.mean(0), 0.1).argmax() # max F1 index p, r, f1 = p[:, i], r[:, i], f1[:, i] diff --git a/ultralytics/yolo/utils/plotting.py b/ultralytics/yolo/utils/plotting.py index c983eee..2bad281 100644 --- a/ultralytics/yolo/utils/plotting.py +++ b/ultralytics/yolo/utils/plotting.py @@ -1,12 +1,16 @@ +import contextlib +import math from pathlib import Path from urllib.error import URLError import cv2 +import matplotlib.pyplot as plt import numpy as np +import pandas as pd import torch from PIL import Image, ImageDraw, ImageFont -from ultralytics.yolo.utils import FONT, USER_CONFIG_DIR +from ultralytics.yolo.utils import FONT, USER_CONFIG_DIR, threaded from .checks import check_font, check_requirements, is_ascii from .files import increment_path @@ -179,3 +183,147 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, # cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB return crop + + +@threaded +def plot_images_and_masks(images, batch_idx, cls, bboxes, masks, paths, confs=None, fname='images.jpg', names=None): + # Plot image grid with labels + if isinstance(images, torch.Tensor): + images = images.cpu().float().numpy() + if isinstance(cls, torch.Tensor): + cls = cls.cpu().numpy() + if isinstance(bboxes, torch.Tensor): + bboxes = bboxes.cpu().numpy() + if isinstance(masks, torch.Tensor): + masks = masks.cpu().numpy().astype(int) + if isinstance(batch_idx, torch.Tensor): + batch_idx = batch_idx.cpu().numpy() + + max_size = 1920 # max image size + max_subplots = 16 # max image subplots, i.e. 4x4 + bs, _, h, w = images.shape # batch size, _, height, width + bs = min(bs, max_subplots) # limit plot images + ns = np.ceil(bs ** 0.5) # number of subplots (square) + if np.max(images[0]) <= 1: + images *= 255 # de-normalise (optional) + + # Build Image + mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init + for i, im in enumerate(images): + if i == max_subplots: # if last batch has fewer images than we expect + break + x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin + im = im.transpose(1, 2, 0) + mosaic[y:y + h, x:x + w, :] = im + + # Resize (optional) + scale = max_size / ns / max(h, w) + if scale < 1: + h = math.ceil(scale * h) + w = math.ceil(scale * w) + mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h))) + + # Annotate + fs = int((h + w) * ns * 0.01) # font size + annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names) + for i in range(i + 1): + x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin + annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders + if paths: + annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames + if len(cls) > 0: + idx = batch_idx == i + + boxes = xywh2xyxy(bboxes[idx]).T + classes = cls[idx].astype('int') + labels = confs is None # labels if no conf column + conf = None if labels else confs[idx] # check for confidence presence (label vs pred) + + if boxes.shape[1]: + if boxes.max() <= 1.01: # if normalized with tolerance 0.01 + boxes[[0, 2]] *= w # scale to pixels + boxes[[1, 3]] *= h + elif scale < 1: # absolute coords need scale if image scales + boxes *= scale + boxes[[0, 2]] += x + boxes[[1, 3]] += y + for j, box in enumerate(boxes.T.tolist()): + c = classes[j] + color = colors(c) + c = names[c] if names else c + if labels or conf[j] > 0.25: # 0.25 conf thresh + label = f'{c}' if labels else f'{c} {conf[j]:.1f}' + annotator.box_label(box, label, color=color) + + # Plot masks + if len(masks): + if masks.max() > 1.0: # mean that masks are overlap + image_masks = masks[[i]] # (1, 640, 640) + nl = idx.sum() + index = np.arange(nl).reshape(nl, 1, 1) + 1 + image_masks = np.repeat(image_masks, nl, axis=0) + image_masks = np.where(image_masks == index, 1.0, 0.0) + else: + image_masks = masks[idx] + + im = np.asarray(annotator.im).copy() + for j, box in enumerate(boxes.T.tolist()): + if labels or conf[j] > 0.25: # 0.25 conf thresh + color = colors(classes[j]) + mh, mw = image_masks[j].shape + if mh != h or mw != w: + mask = image_masks[j].astype(np.uint8) + mask = cv2.resize(mask, (w, h)) + mask = mask.astype(bool) + else: + mask = image_masks[j].astype(bool) + with contextlib.suppress(Exception): + im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6 + annotator.fromarray(im) + annotator.im.save(fname) # save + + +def plot_results_with_masks(file="path/to/results.csv", dir="", best=True): + # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv') + save_dir = Path(file).parent if file else Path(dir) + fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True) + ax = ax.ravel() + files = list(save_dir.glob("results*.csv")) + assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot." + for f in files: + try: + data = pd.read_csv(f) + index = np.argmax(0.9 * data.values[:, 8] + 0.1 * data.values[:, 7] + 0.9 * data.values[:, 12] + + 0.1 * data.values[:, 11]) + s = [x.strip() for x in data.columns] + x = data.values[:, 0] + for i, j in enumerate([1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]): + y = data.values[:, j] + # y[y == 0] = np.nan # don't show zero values + ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=2) + if best: + # best + ax[i].scatter(index, y[index], color="r", label=f"best:{index}", marker="*", linewidth=3) + ax[i].set_title(s[j] + f"\n{round(y[index], 5)}") + else: + # last + ax[i].scatter(x[-1], y[-1], color="r", label="last", marker="*", linewidth=3) + ax[i].set_title(s[j] + f"\n{round(y[-1], 5)}") + # if j in [8, 9, 10]: # share train and val loss y axes + # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5]) + except Exception as e: + print(f"Warning: Plotting error for {f}: {e}") + ax[1].legend() + fig.savefig(save_dir / "results.png", dpi=200) + plt.close() + + +def output_to_target(output, max_det=300): + # Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting + targets = [] + for i, o in enumerate(output): + box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1) + j = torch.full((conf.shape[0], 1), i) + targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1)) + targets = torch.cat(targets, 0).numpy() + return targets[:, 0], targets[:, 1], targets[:, 2:6], targets[:, 6] diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py index 645b860..dea42d8 100644 --- a/ultralytics/yolo/utils/torch_utils.py +++ b/ultralytics/yolo/utils/torch_utils.py @@ -245,3 +245,19 @@ class ModelEMA: def update_attr(self, model, include=(), exclude=('process_group', 'reducer')): # Update EMA attributes copy_attr(self.ema, model, include, exclude) + + +def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer() + # Strip optimizer from 'f' to finalize training, optionally save as 's' + x = torch.load(f, map_location=torch.device('cpu')) + if x.get('ema'): + x['model'] = x['ema'] # replace model with ema + for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys + x[k] = None + x['epoch'] = -1 + x['model'].half() # to FP16 + for p in x['model'].parameters(): + p.requires_grad = False + torch.save(x, s or f) + mb = os.path.getsize(s or f) / 1E6 # filesize + LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB") diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py index 4037b83..fa00ab2 100644 --- a/ultralytics/yolo/v8/classify/train.py +++ b/ultralytics/yolo/v8/classify/train.py @@ -9,6 +9,9 @@ from ultralytics.yolo.utils.modeling.tasks import ClassificationModel class ClassificationTrainer(BaseTrainer): + def set_model_attributes(self): + self.model.names = self.data["names"] + def load_model(self, model_cfg, weights, data): # TODO: why treat clf models as unique. We should have clf yamls? if weights and not weights.__class__.__name__.startswith("yolo"): # torchvision @@ -18,7 +21,7 @@ class ClassificationTrainer(BaseTrainer): ClassificationModel.reshape_outputs(model, data["nc"]) return model - def get_dataloader(self, dataset_path, batch_size=None, rank=0): + def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"): return build_classification_dataloader(path=dataset_path, imgsz=self.args.img_size, batch_size=batch_size, diff --git a/ultralytics/yolo/v8/classify/val.py b/ultralytics/yolo/v8/classify/val.py index 9fcfc6e..ae5e5bd 100644 --- a/ultralytics/yolo/v8/classify/val.py +++ b/ultralytics/yolo/v8/classify/val.py @@ -23,3 +23,7 @@ class ClassificationValidator(BaseValidator): acc = torch.stack((self.correct[:, 0], self.correct.max(1).values), dim=1) # (top1, top5) accuracy top1, top5 = acc.mean(0).tolist() return {"top1": top1, "top5": top5, "fitness": top5} + + @property + def metric_keys(self): + return ["top1", "top5"] diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py index 16f2ea9..0e1cb54 100644 --- a/ultralytics/yolo/v8/segment/train.py +++ b/ultralytics/yolo/v8/segment/train.py @@ -9,30 +9,18 @@ from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE from ultralytics.yolo.utils.modeling.tasks import SegmentationModel from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy +from ultralytics.yolo.utils.plotting import plot_images_and_masks, plot_results_with_masks from ultralytics.yolo.utils.torch_utils import de_parallel # BaseTrainer python usage class SegmentationTrainer(BaseTrainer): - def get_dataloader(self, dataset_path, batch_size, rank=0): + def get_dataloader(self, dataset_path, batch_size, mode="train", rank=0): # TODO: manage splits differently # calculate stride - check if model is initialized gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) - return build_dataloader( - img_path=dataset_path, - img_size=self.args.img_size, - batch_size=batch_size, - single_cls=self.args.single_cls, - cache=self.args.cache, - image_weights=self.args.image_weights, - stride=gs, - rect=self.args.rect, - rank=rank, - workers=self.args.workers, - shuffle=self.args.shuffle, - use_segments=True, - )[0] + return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode)[0] def preprocess_batch(self, batch): batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255 @@ -58,7 +46,10 @@ class SegmentationTrainer(BaseTrainer): self.model.names = self.data["names"] def get_validator(self): - return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console) + return v8.segment.SegmentationValidator(self.test_loader, + save_dir=self.save_dir, + logger=self.console, + args=self.args) def criterion(self, preds, batch): head = de_parallel(self.model).model[-1] @@ -218,6 +209,8 @@ class SegmentationTrainer(BaseTrainer): else: mask_gti = masks[tidxs[i]][j] lseg += single_mask_loss(mask_gti, pmask[j], proto[bi], mxyxy[j], marea[j]) + else: + lseg += (proto * 0).sum() obji = BCEobj(pi[..., 4], tobj) lobj += obji * balance[i] # obj loss @@ -234,15 +227,33 @@ class SegmentationTrainer(BaseTrainer): loss = lbox + lobj + lcls + lseg return loss * bs, torch.cat((lbox, lseg, lobj, lcls)).detach() - def label_loss_items(self, loss_items): + def label_loss_items(self, loss_items=None, prefix="train"): # We should just use named tensors here in future - keys = ["lbox", "lseg", "lobj", "lcls"] - return dict(zip(keys, loss_items)) + keys = [f"{prefix}/lbox", f"{prefix}/lseg", f"{prefix}/lobj", f"{prefix}/lcls"] + return dict(zip(keys, loss_items)) if loss_items is not None else keys def progress_string(self): return ('\n' + '%11s' * 7) % \ ('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss', 'Size') + def plot_training_samples(self, batch, ni): + images = batch["img"] + masks = batch["masks"] + cls = batch["cls"].squeeze(-1) + bboxes = batch["bboxes"] + paths = batch["im_file"] + batch_idx = batch["batch_idx"] + plot_images_and_masks(images, + batch_idx, + cls, + bboxes, + masks, + paths, + fname=self.save_dir / f"train_batch{ni}.jpg") + + def plot_metrics(self): + plot_results_with_masks(file=self.csv) # save results.png + @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name) def train(cfg): diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py index 372f306..2669eca 100644 --- a/ultralytics/yolo/v8/segment/val.py +++ b/ultralytics/yolo/v8/segment/val.py @@ -6,23 +6,24 @@ import torch.nn.functional as F from ultralytics.yolo.engine.validator import BaseValidator from ultralytics.yolo.utils import ops -from ultralytics.yolo.utils.checks import check_requirements +from ultralytics.yolo.utils.checks import check_file, check_requirements from ultralytics.yolo.utils.files import yaml_load from ultralytics.yolo.utils.metrics import (ConfusionMatrix, Metrics, ap_per_class_box_and_mask, box_iou, fitness_segmentation, mask_iou) +from ultralytics.yolo.utils.plotting import output_to_target, plot_images_and_masks from ultralytics.yolo.utils.torch_utils import de_parallel class SegmentationValidator(BaseValidator): - def __init__(self, dataloader, pbar=None, logger=None, args=None): - super().__init__(dataloader, pbar, logger, args) + def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None): + super().__init__(dataloader, save_dir, pbar, logger, args) if self.args.save_json: check_requirements(['pycocotools']) self.process = ops.process_mask_upsample # more accurate else: self.process = ops.process_mask # faster - self.data_dict = yaml_load(self.args.data) if self.args.data else None + self.data_dict = yaml_load(check_file(self.args.data)) if self.args.data else None self.is_coco = False self.class_map = None self.targets = None @@ -62,6 +63,7 @@ class SegmentationValidator(BaseValidator): self.loss = torch.zeros(4, device=self.device) self.jdict = [] self.stats = [] + self.plot_masks = [] def get_desc(self): return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P", @@ -80,11 +82,10 @@ class SegmentationValidator(BaseValidator): def update_metrics(self, preds, batch): # Metrics - plot_masks = [] # masks for plotting for si, (pred, proto) in enumerate(zip(preds[0], preds[1])): labels = self.targets[self.targets[:, 0] == si, 1:] nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions - shape = batch["shape"][si] + shape = batch["ori_shape"][si] # path = batch["shape"][si][0] correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init @@ -130,7 +131,7 @@ class SegmentationValidator(BaseValidator): pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8) if self.args.plots and self.batch_i < 3: - plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot + self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot # TODO: Save/log ''' @@ -143,26 +144,14 @@ class SegmentationValidator(BaseValidator): # callbacks.run('on_val_image_end', pred, predn, path, names, im[si]) ''' - # TODO Plot images - ''' - if self.args.plots and self.batch_i < 3: - if len(plot_masks): - plot_masks = torch.cat(plot_masks, dim=0) - plot_images_and_masks(im, targets, masks, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names) - plot_images_and_masks(im, output_to_target(preds, max_det=15), plot_masks, paths, - save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred - ''' - def get_stats(self): stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy if len(stats) and stats[0].any(): - # TODO: save_dir - results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir='', names=self.names) + results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir=self.save_dir, names=self.names) self.metrics.update(results) self.nt_per_class = np.bincount(stats[4].astype(int), minlength=self.nc) # number of targets per class - keys = ["mp_bbox", "mr_bbox", "map50_bbox", "map_bbox", "mp_mask", "mr_mask", "map50_mask", "map_mask"] metrics = {"fitness": fitness_segmentation(np.array(self.metrics.mean_results()).reshape(1, -1))} - metrics |= zip(keys, self.metrics.mean_results()) + metrics |= zip(self.metric_keys, self.metrics.mean_results()) return metrics def print_results(self): @@ -177,9 +166,8 @@ class SegmentationValidator(BaseValidator): for i, c in enumerate(self.metrics.ap_class_index): self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i))) - # plot TODO: save_dir if self.args.plots: - self.confusion_matrix.plot(save_dir='', names=list(self.names.values())) + self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values())) def _process_batch(self, detections, labels, iouv, pred_masks=None, gt_masks=None, overlap=False, masks=False): """ @@ -217,3 +205,41 @@ class SegmentationValidator(BaseValidator): matches = matches[np.unique(matches[:, 0], return_index=True)[1]] correct[matches[:, 1].astype(int), i] = True return torch.tensor(correct, dtype=torch.bool, device=iouv.device) + + @property + def metric_keys(self): + return [ + "metrics/precision(B)", + "metrics/recall(B)", + "metrics/mAP_0.5(B)", + "metrics/mAP_0.5:0.95(B)", # metrics + "metrics/precision(M)", + "metrics/recall(M)", + "metrics/mAP_0.5(M)", + "metrics/mAP_0.5:0.95(M)",] + + def plot_val_samples(self, batch, ni): + images = batch["img"] + masks = batch["masks"] + cls = batch["cls"].squeeze(-1) + bboxes = batch["bboxes"] + paths = batch["im_file"] + batch_idx = batch["batch_idx"] + plot_images_and_masks(images, + batch_idx, + cls, + bboxes, + masks, + paths, + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + names=self.names) + + def plot_predictions(self, batch, preds, ni): + images = batch["img"] + paths = batch["im_file"] + if len(self.plot_masks): + plot_masks = torch.cat(self.plot_masks, dim=0) + batch_idx, cls, bboxes, conf = output_to_target(preds[0], max_det=15) + plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, paths, conf, + self.save_dir / f'val_batch{ni}_pred.jpg', self.names) # pred + self.plot_masks.clear()