update segment training (#57)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
single_channel
Laughing 2 years ago committed by GitHub
parent d0b0fe2592
commit 3a241e4cea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -635,7 +635,7 @@ class Format:
def _format_img(self, img): def _format_img(self, img):
if len(img.shape) < 3: if len(img.shape) < 3:
img = np.expand_dims(img, -1) img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1)) img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1])
img = torch.from_numpy(img) img = torch.from_numpy(img)
return img return img

@ -151,7 +151,7 @@ class BaseDataset(Dataset):
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
nb = bi[-1] + 1 # number of batches nb = bi[-1] + 1 # number of batches
s = np.array([x["shape"] for x in self.labels]) # hw s = np.array([x.pop("shape") for x in self.labels]) # hw
ar = s[:, 0] / s[:, 1] # aspect ratio ar = s[:, 0] / s[:, 1] # aspect ratio
irect = ar.argsort() irect = ar.argsort()
self.im_files = [self.im_files[i] for i in irect] self.im_files = [self.im_files[i] for i in irect]

@ -5,7 +5,7 @@ import numpy as np
import torch import torch
from torch.utils.data import DataLoader, dataloader, distributed from torch.utils.data import DataLoader, dataloader, distributed
from ..utils import LOGGER from ..utils import LOGGER, colorstr
from ..utils.torch_utils import torch_distributed_zero_first from ..utils.torch_utils import torch_distributed_zero_first
from .dataset import ClassificationDataset, YOLODataset from .dataset import ClassificationDataset, YOLODataset
from .utils import PIN_MEMORY, RANK from .utils import PIN_MEMORY, RANK
@ -52,53 +52,36 @@ def seed_worker(worker_id):
random.seed(worker_seed) random.seed(worker_seed)
# TODO: we can inject most args from a config file def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank=-1, mode="train"):
def build_dataloader( assert mode in ["train", "val"]
img_path, shuffle = mode == "train"
img_size, # if cfg.rect and shuffle:
batch_size, #
single_cls=False, #
hyp=None, #
augment=False,
cache=False, #
image_weights=False, #
stride=32,
label_path=None,
pad=0.0,
rect=False,
rank=-1,
workers=8,
prefix="",
shuffle=False,
use_segments=False,
use_keypoints=False,
):
if rect and shuffle:
LOGGER.warning("WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False") LOGGER.warning("WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False")
shuffle = False shuffle = False
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = YOLODataset( dataset = YOLODataset(
img_path=img_path, img_path=img_path,
img_size=img_size,
batch_size=batch_size,
label_path=label_path, label_path=label_path,
augment=augment, # augmentation img_size=cfg.img_size,
hyp=hyp, batch_size=batch_size,
rect=rect, # rectangular batches augment=True if mode == "train" else False, # augmentation
cache=cache, hyp=cfg.get("augment_hyp", None),
single_cls=single_cls, rect=cfg.rect if mode == "train" else True, # rectangular batches
cache=None if cfg.noval else cfg.get("cache", None),
single_cls=cfg.get("single_cls", False),
stride=int(stride), stride=int(stride),
pad=pad, pad=0.0 if mode == "train" else 0.5,
prefix=prefix, prefix=colorstr(f"{mode}: "),
use_segments=use_segments, use_segments=cfg.task == "segment",
use_keypoints=use_keypoints, use_keypoints=cfg.task == "keypoint",
) )
batch_size = min(batch_size, len(dataset)) batch_size = min(batch_size, len(dataset))
nd = torch.cuda.device_count() # number of CUDA devices nd = torch.cuda.device_count() # number of CUDA devices
workers = cfg.workers if mode == "train" else cfg.workers * 2
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates loader = DataLoader if cfg.image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
generator = torch.Generator() generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK) generator.manual_seed(6148914691236517205 + RANK)
return ( return (
@ -118,6 +101,7 @@ def build_dataloader(
# build classification # build classification
# TODO: using cfg like `build_dataloader`
def build_classification_dataloader(path, def build_classification_dataloader(path,
imgsz=224, imgsz=224,
batch_size=16, batch_size=16,

@ -24,11 +24,11 @@ from tqdm import tqdm
import ultralytics.yolo.utils as utils import ultralytics.yolo.utils as utils
import ultralytics.yolo.utils.callbacks as callbacks import ultralytics.yolo.utils.callbacks as callbacks
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
from ultralytics.yolo.utils.checks import print_args from ultralytics.yolo.utils.checks import print_args
from ultralytics.yolo.utils.files import increment_path, save_yaml from ultralytics.yolo.utils.files import increment_path, save_yaml
from ultralytics.yolo.utils.modeling import get_model from ultralytics.yolo.utils.modeling import get_model
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml" DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
RANK = int(os.getenv('RANK', -1)) RANK = int(os.getenv('RANK', -1))
@ -48,13 +48,15 @@ class BaseTrainer:
self.wdir = self.save_dir / 'weights' # weights dir self.wdir = self.save_dir / 'weights' # weights dir
self.wdir.mkdir(parents=True, exist_ok=True) # make dir self.wdir.mkdir(parents=True, exist_ok=True) # make dir
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
self.batch_size = self.args.batch_size
self.epochs = self.args.epochs
print_args(dict(self.args)) print_args(dict(self.args))
# Save run settings # Save run settings
save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))
# device # device
self.device = utils.torch_utils.select_device(self.args.device, self.args.batch_size) self.device = utils.torch_utils.select_device(self.args.device, self.batch_size)
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu') self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
# Model and Dataloaders. # Model and Dataloaders.
@ -73,10 +75,11 @@ class BaseTrainer:
self.scheduler = None self.scheduler = None
# epoch level metrics # epoch level metrics
self.metrics = {} # handle metrics returned by validator
self.best_fitness = None self.best_fitness = None
self.fitness = None self.fitness = None
self.loss = None self.loss = None
self.tloss = None
self.csv = self.save_dir / 'results.csv'
for callback, func in callbacks.default_callbacks.items(): for callback, func in callbacks.default_callbacks.items():
self.add_callback(callback, func) self.add_callback(callback, func)
@ -122,6 +125,7 @@ class BaseTrainer:
if world_size > 1: if world_size > 1:
mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True) mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True)
else: else:
# self._do_train(int(os.getenv("RANK", -1)), world_size)
self._do_train() self._do_train()
def _setup_ddp(self, rank, world_size): def _setup_ddp(self, rank, world_size):
@ -129,21 +133,20 @@ class BaseTrainer:
os.environ['MASTER_PORT'] = '9020' os.environ['MASTER_PORT'] = '9020'
torch.cuda.set_device(rank) torch.cuda.set_device(rank)
self.device = torch.device('cuda', rank) self.device = torch.device('cuda', rank)
print(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ") self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size) dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
self.model = self.model.to(self.device) self.model = self.model.to(self.device)
self.model = DDP(self.model, device_ids=[rank]) self.model = DDP(self.model, device_ids=[rank])
self.args.batch_size = self.args.batch_size // world_size
def _setup_train(self, rank): def _setup_train(self, rank, world_size):
""" """
Builds dataloaders and optimizer on correct rank process Builds dataloaders and optimizer on correct rank process
""" """
# Optimizer # Optimizer
self.set_model_attributes() self.set_model_attributes()
accumulate = max(round(self.args.nbs / self.args.batch_size), 1) # accumulate loss before optimizing self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
self.args.weight_decay *= self.args.batch_size * accumulate / self.args.nbs # scale weight_decay self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
self.optimizer = build_optimizer(model=self.model, self.optimizer = build_optimizer(model=self.model,
name=self.args.optimizer, name=self.args.optimizer,
lr=self.args.lr0, lr=self.args.lr0,
@ -151,18 +154,21 @@ class BaseTrainer:
decay=self.args.weight_decay) decay=self.args.weight_decay)
# Scheduler # Scheduler
if self.args.cos_lr: if self.args.cos_lr:
self.lf = one_cycle(1, self.args.lrf, self.args.epochs) # cosine 1->hyp['lrf'] self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
else: else:
self.lf = lambda x: (1 - x / self.args.epochs) * (1.0 - self.args.lrf + self.args.lrf) # linear self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear
self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf) self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
# dataloaders # dataloaders
self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank) batch_size = self.batch_size // world_size
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode="train")
if rank in {0, -1}: if rank in {0, -1}:
print(" Creating testloader rank :", rank) self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=-1) validator = self.get_validator()
self.validator = self.get_validator() # init metric, for plot_results
print("created testloader :", rank) metric_keys = validator.metric_keys + self.label_loss_items(prefix="val")
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
self.validator = validator
self.ema = ModelEMA(self.model) self.ema = ModelEMA(self.model)
def _do_train(self, rank=-1, world_size=1): def _do_train(self, rank=-1, world_size=1):
@ -172,7 +178,7 @@ class BaseTrainer:
self.model = self.model.to(self.device) self.model = self.model.to(self.device)
self.trigger_callbacks("before_train") self.trigger_callbacks("before_train")
self._setup_train(rank) self._setup_train(rank, world_size)
self.epoch = 0 self.epoch = 0
self.epoch_time = None self.epoch_time = None
@ -181,13 +187,17 @@ class BaseTrainer:
nb = len(self.train_loader) # number of batches nb = len(self.train_loader) # number of batches
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
last_opt_step = -1 last_opt_step = -1
for epoch in range(self.args.epochs): for epoch in range(self.epochs):
self.trigger_callbacks("on_epoch_start") self.trigger_callbacks("on_epoch_start")
self.model.train() self.model.train()
if rank != -1:
self.train_loader.sampler.set_epoch(epoch)
pbar = enumerate(self.train_loader) pbar = enumerate(self.train_loader)
if rank in {-1, 0}: if rank in {-1, 0}:
self.console.info(self.progress_string())
pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), bar_format=TQDM_BAR_FORMAT) pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), bar_format=TQDM_BAR_FORMAT)
self.tloss = None self.tloss = None
self.optimizer.zero_grad()
for i, batch in pbar: for i, batch in pbar:
self.trigger_callbacks("on_batch_start") self.trigger_callbacks("on_batch_start")
# forward # forward
@ -197,7 +207,7 @@ class BaseTrainer:
ni = i + nb * epoch ni = i + nb * epoch
if ni <= nw: if ni <= nw:
xi = [0, nw] # x interp xi = [0, nw] # x interp
accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.args.batch_size]).round()) self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())
for j, x in enumerate(self.optimizer.param_groups): for j, x in enumerate(self.optimizer.param_groups):
# bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0 # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
x['lr'] = np.interp( x['lr'] = np.interp(
@ -207,37 +217,47 @@ class BaseTrainer:
preds = self.model(batch["img"]) preds = self.model(batch["img"])
self.loss, self.loss_items = self.criterion(preds, batch) self.loss, self.loss_items = self.criterion(preds, batch)
if rank != -1:
self.loss *= world_size
self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \ self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
else self.loss_items else self.loss_items
# backward # backward
self.model.zero_grad(set_to_none=True)
self.scaler.scale(self.loss).backward() self.scaler.scale(self.loss).backward()
# optimize # optimize
if ni - last_opt_step >= accumulate: if ni - last_opt_step >= self.accumulate:
self.optimizer_step() self.optimizer_step()
last_opt_step = ni last_opt_step = ni
# log # log
mem = (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB) mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1 loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0) losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
if rank in {-1, 0}: if rank in {-1, 0}:
pbar.set_description( pbar.set_description(
(" {} " + "{:.3f} " * (1 + loss_len) + ' {} ').format(f'{epoch + 1}/{self.args.epochs}', mem, ('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
*losses, batch["img"].shape[-1])) (f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]))
self.trigger_callbacks('on_batch_end') self.trigger_callbacks('on_batch_end')
if self.args.plots and ni < 3:
self.plot_training_samples(batch, ni)
lr = {f"lr{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
self.scheduler.step()
if rank in [-1, 0]: if rank in [-1, 0]:
# validation # validation
self.trigger_callbacks('on_val_start') self.trigger_callbacks('on_val_start')
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights']) self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
final_epoch = (epoch + 1 == self.epochs)
if not self.args.noval or final_epoch:
self.metrics, self.fitness = self.validate() self.metrics, self.fitness = self.validate()
self.trigger_callbacks('on_val_end') self.trigger_callbacks('on_val_end')
log_vals = self.label_loss_items(self.tloss) | self.metrics | lr
self.save_metrics(metrics=log_vals)
# save model # save model
if (not self.args.nosave) or (self.epoch + 1 == self.args.epochs): if (not self.args.nosave) or (self.epoch + 1 == self.epochs):
self.save_model() self.save_model()
self.trigger_callbacks('on_model_save') self.trigger_callbacks('on_model_save')
@ -248,9 +268,15 @@ class BaseTrainer:
# TODO: termination condition # TODO: termination condition
if rank in [-1, 0]:
# do the last evaluation with best.pt
self.final_eval()
if self.args.plots:
self.plot_metrics()
self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours)") self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours)")
self.trigger_callbacks('on_train_end') self.trigger_callbacks('on_train_end')
dist.destroy_process_group() if world_size != 1 else None dist.destroy_process_group() if world_size != 1 else None
torch.cuda.empty_cache()
def save_model(self): def save_model(self):
ckpt = { ckpt = {
@ -306,7 +332,7 @@ class BaseTrainer:
"fitness" metric. "fitness" metric.
""" """
metrics = self.validator(self) metrics = self.validator(self)
fitness = metrics.get("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
if not self.best_fitness or self.best_fitness < fitness: if not self.best_fitness or self.best_fitness < fitness:
self.best_fitness = self.fitness self.best_fitness = self.fitness
return metrics, fitness return metrics, fitness
@ -339,12 +365,12 @@ class BaseTrainer:
""" """
raise NotImplementedError("criterion function not implemented in trainer") raise NotImplementedError("criterion function not implemented in trainer")
def label_loss_items(self, loss_items): def label_loss_items(self, loss_items=None, prefix="train"):
""" """
Returns a loss dict with labelled training loss items tensor Returns a loss dict with labelled training loss items tensor
""" """
# Not needed for classification but necessary for segmentation & detection # Not needed for classification but necessary for segmentation & detection
return {"loss": loss_items} return {"loss": loss_items} if loss_items is not None else ["loss"]
def set_model_attributes(self): def set_model_attributes(self):
""" """
@ -355,6 +381,31 @@ class BaseTrainer:
def build_targets(self, preds, targets): def build_targets(self, preds, targets):
pass pass
def progress_string(self):
return ""
# TODO: may need to put these following functions into callback
def plot_training_samples(self, batch, ni):
pass
def save_metrics(self, metrics):
keys, vals = list(metrics.keys()), list(metrics.values())
n = len(metrics) + 1 # number of cols
s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header
with open(self.csv, 'a') as f:
f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n')
def plot_metrics(self):
pass
def final_eval(self):
# TODO: need standalone evaluator to do this
for f in self.last, self.best:
if f.exists():
strip_optimizer(f) # strip optimizers
if f is self.best:
self.console.info(f'\nValidating {f}...')
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5): def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
# TODO: 1. docstring with example? 2. Move this inside Trainer? or utils? # TODO: 1. docstring with example? 2. Move this inside Trainer? or utils?
@ -382,7 +433,7 @@ def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights) optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
LOGGER.info(f"optimizer: {type(optimizer).__name__}(lr={lr}) with parameter groups " LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias") f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
return optimizer return optimizer

@ -1,4 +1,5 @@
import logging import logging
from pathlib import Path
import torch import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
@ -6,6 +7,7 @@ from tqdm import tqdm
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.utils import TQDM_BAR_FORMAT from ultralytics.yolo.utils import TQDM_BAR_FORMAT
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.ops import Profile from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import de_parallel, select_device from ultralytics.yolo.utils.torch_utils import de_parallel, select_device
@ -15,16 +17,17 @@ class BaseValidator:
Base validator class. Base validator class.
""" """
def __init__(self, dataloader, pbar=None, logger=None, args=None): def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None):
self.dataloader = dataloader self.dataloader = dataloader
self.pbar = pbar self.pbar = pbar
self.logger = logger or logging.getLogger() self.logger = logger or logging.getLogger()
self.args = args or OmegaConf.load(DEFAULT_CONFIG) self.args = args or OmegaConf.load(DEFAULT_CONFIG)
self.device = select_device(self.args.device, dataloader.batch_size) self.device = select_device(self.args.device, dataloader.batch_size)
self.save_dir = save_dir if save_dir is not None else \
increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
self.cuda = self.device.type != 'cpu' self.cuda = self.device.type != 'cpu'
self.batch_i = None self.batch_i = None
self.training = True self.training = True
self.loss = None
def __call__(self, trainer=None, model=None): def __call__(self, trainer=None, model=None):
""" """
@ -35,20 +38,22 @@ class BaseValidator:
if self.training: if self.training:
model = trainer.ema.ema or trainer.model model = trainer.ema.ema or trainer.model
self.args.half &= self.device.type != 'cpu' self.args.half &= self.device.type != 'cpu'
# NOTE: half() inference in evaluation will make training stuck,
# so I comment it out for now, I think we can reuse half mode after we add EMA.
model = model.half() if self.args.half else model.float() model = model.half() if self.args.half else model.float()
loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
else: # TODO: handle this when detectMultiBackend is supported else: # TODO: handle this when detectMultiBackend is supported
assert model is not None, "Either trainer or model is needed for validation" assert model is not None, "Either trainer or model is needed for validation"
# model = DetectMultiBacked(model) # model = DetectMultiBacked(model)
# TODO: implement init_model_attributes() # TODO: implement init_model_attributes()
model.eval() model.eval()
dt = Profile(), Profile(), Profile(), Profile() dt = Profile(), Profile(), Profile(), Profile()
self.loss = 0
n_batches = len(self.dataloader) n_batches = len(self.dataloader)
desc = self.get_desc() desc = self.get_desc()
bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT) # NOTE: keeping this `not self.training` in tqdm will eliminate pbar after finishing segmantation evaluation during training,
# so I removed it, not sure if this will affect classification task cause I saw we use this arg in yolov5/classify/val.py.
# bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
bar = tqdm(self.dataloader, desc, n_batches, bar_format=TQDM_BAR_FORMAT)
self.init_metrics(de_parallel(model)) self.init_metrics(de_parallel(model))
with torch.no_grad(): with torch.no_grad():
for batch_i, batch in enumerate(bar): for batch_i, batch in enumerate(bar):
@ -59,20 +64,23 @@ class BaseValidator:
# inference # inference
with dt[1]: with dt[1]:
preds = model(batch["img"].float()) preds = model(batch["img"])
# TODO: remember to add native augmentation support when implementing model, like: # TODO: remember to add native augmentation support when implementing model, like:
# preds, train_out = model(im, augment=augment) # preds, train_out = model(im, augment=augment)
# loss # loss
with dt[2]: with dt[2]:
if self.training: if self.training:
self.loss += trainer.criterion(preds, batch)[0] loss += trainer.criterion(preds, batch)[1]
# pre-process predictions # pre-process predictions
with dt[3]: with dt[3]:
preds = self.postprocess(preds) preds = self.postprocess(preds)
self.update_metrics(preds, batch) self.update_metrics(preds, batch)
if self.args.plots and batch_i < 3:
self.plot_val_samples(batch, batch_i)
self.plot_predictions(batch, preds, batch_i)
stats = self.get_stats() stats = self.get_stats()
self.check_stats(stats) self.check_stats(stats)
@ -81,7 +89,7 @@ class BaseValidator:
# print speeds # print speeds
if not self.training: if not self.training:
t = tuple(x.t / len(self.dataloader.dataset.samples) * 1E3 for x in dt) # speeds per image t = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image
# shape = (self.dataloader.batch_size, 3, imgsz, imgsz) # shape = (self.dataloader.batch_size, 3, imgsz, imgsz)
self.logger.info( self.logger.info(
'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image at shape ' % t) 'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image at shape ' % t)
@ -90,7 +98,8 @@ class BaseValidator:
model.float() model.float()
# TODO: implement save json # TODO: implement save json
return stats return stats | trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val") \
if self.training else stats
def preprocess(self, batch): def preprocess(self, batch):
return batch return batch
@ -105,7 +114,7 @@ class BaseValidator:
pass pass
def get_stats(self): def get_stats(self):
pass return {}
def check_stats(self, stats): def check_stats(self, stats):
pass pass
@ -115,3 +124,14 @@ class BaseValidator:
def get_desc(self): def get_desc(self):
pass pass
@property
def metric_keys(self):
return []
# TODO: may need to put these following functions into callback
def plot_val_samples(self, batch, ni):
pass
def plot_predictions(self, batch, preds, ni):
pass

@ -3,6 +3,7 @@ import logging.config
import os import os
import platform import platform
import sys import sys
import threading
from pathlib import Path from pathlib import Path
# Constants # Constants
@ -130,3 +131,13 @@ class TryExcept(contextlib.ContextDecorator):
if value: if value:
print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}")) print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
return True return True
def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread
return wrapper

@ -26,11 +26,11 @@ deterministic: True
local_rank: -1 local_rank: -1
single_cls: False # train multi-class data as single-class single_cls: False # train multi-class data as single-class
image_weights: False # use weighted image selection for training image_weights: False # use weighted image selection for training
shuffle: True
rect: False # support rectangular training rect: False # support rectangular training
cos_lr: False # Use cosine LR scheduler cos_lr: False # Use cosine LR scheduler
overlap_mask: True # Segmentation masks overlap overlap_mask: True # Segmentation masks overlap
mask_ratio: 4 # Segmentation mask downsample ratio mask_ratio: 4 # Segmentation mask downsample ratio
noval: False
# Val/Test settings ---------------------------------------------------------------------------------------------------- # Val/Test settings ----------------------------------------------------------------------------------------------------
save_json: False save_json: False
@ -43,7 +43,7 @@ plots: False
save_txt: False save_txt: False
# Hyperparameters ------------------------------------------------------------------------------------------------------ # Hyperparameters ------------------------------------------------------------------------------------------------------
lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3) lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf) lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
momentum: 0.937 # SGD momentum/Adam beta1 momentum: 0.937 # SGD momentum/Adam beta1
weight_decay: 0.0005 # optimizer weight decay 5e-4 weight_decay: 0.0005 # optimizer weight decay 5e-4
@ -59,6 +59,10 @@ iou_t: 0.20 # IoU training threshold
anchor_t: 4.0 # anchor-multiple threshold anchor_t: 4.0 # anchor-multiple threshold
# anchors: 3 # anchors per output layer (0 to ignore) # anchors: 3 # anchors per output layer (0 to ignore)
fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5) fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
label_smoothing: 0.0
nbs: 64 # nominal batch size
# anchors: 3
augment_hyp:
hsv_h: 0.015 # image HSV-Hue augmentation (fraction) hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction) hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
hsv_v: 0.4 # image HSV-Value augmentation (fraction) hsv_v: 0.4 # image HSV-Value augmentation (fraction)
@ -72,9 +76,6 @@ fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability) mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability) mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability) copy_paste: 0.0 # segment copy-paste (probability)
label_smoothing: 0.0
nbs: 64 # nominal batch size
# anchors: 3
# Hydra configs -------------------------------------------------------------------------------------------------------- # Hydra configs --------------------------------------------------------------------------------------------------------
hydra: hydra:

@ -283,6 +283,50 @@ def smooth(y, f=0.05):
return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
# Precision-recall curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
py = np.stack(py, axis=1)
if 0 < len(names) < 21: # display per-class legend if < 21 classes
for i, y in enumerate(py.T):
ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision)
else:
ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)
ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
ax.set_title('Precision-Recall Curve')
fig.savefig(save_dir, dpi=250)
plt.close(fig)
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
# Metric-confidence curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
if 0 < len(names) < 21: # display per-class legend if < 21 classes
for i, y in enumerate(py):
ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric)
else:
ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric)
y = smooth(py.mean(0), 0.05)
ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
ax.set_title(f'{ylabel}-Confidence Curve')
fig.savefig(save_dir, dpi=250)
plt.close(fig)
def compute_ap(recall, precision): def compute_ap(recall, precision):
""" Compute the average precision, given the recall and precision curves """ Compute the average precision, given the recall and precision curves
# Arguments # Arguments
@ -365,14 +409,11 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
f1 = 2 * p * r / (p + r + eps) f1 = 2 * p * r / (p + r + eps)
names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
names = dict(enumerate(names)) # to dict names = dict(enumerate(names)) # to dict
# TODO: plot
'''
if plot: if plot:
plot_pr_curve(px, py, ap, Path(save_dir) / f'{prefix}PR_curve.png', names) plot_pr_curve(px, py, ap, Path(save_dir) / f'{prefix}PR_curve.png', names)
plot_mc_curve(px, f1, Path(save_dir) / f'{prefix}F1_curve.png', names, ylabel='F1') plot_mc_curve(px, f1, Path(save_dir) / f'{prefix}F1_curve.png', names, ylabel='F1')
plot_mc_curve(px, p, Path(save_dir) / f'{prefix}P_curve.png', names, ylabel='Precision') plot_mc_curve(px, p, Path(save_dir) / f'{prefix}P_curve.png', names, ylabel='Precision')
plot_mc_curve(px, r, Path(save_dir) / f'{prefix}R_curve.png', names, ylabel='Recall') plot_mc_curve(px, r, Path(save_dir) / f'{prefix}R_curve.png', names, ylabel='Recall')
'''
i = smooth(f1.mean(0), 0.1).argmax() # max F1 index i = smooth(f1.mean(0), 0.1).argmax() # max F1 index
p, r, f1 = p[:, i], r[:, i], f1[:, i] p, r, f1 = p[:, i], r[:, i], f1[:, i]

@ -1,12 +1,16 @@
import contextlib
import math
from pathlib import Path from pathlib import Path
from urllib.error import URLError from urllib.error import URLError
import cv2 import cv2
import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pandas as pd
import torch import torch
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from ultralytics.yolo.utils import FONT, USER_CONFIG_DIR from ultralytics.yolo.utils import FONT, USER_CONFIG_DIR, threaded
from .checks import check_font, check_requirements, is_ascii from .checks import check_font, check_requirements, is_ascii
from .files import increment_path from .files import increment_path
@ -179,3 +183,147 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,
# cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue # cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
return crop return crop
@threaded
def plot_images_and_masks(images, batch_idx, cls, bboxes, masks, paths, confs=None, fname='images.jpg', names=None):
# Plot image grid with labels
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
if isinstance(cls, torch.Tensor):
cls = cls.cpu().numpy()
if isinstance(bboxes, torch.Tensor):
bboxes = bboxes.cpu().numpy()
if isinstance(masks, torch.Tensor):
masks = masks.cpu().numpy().astype(int)
if isinstance(batch_idx, torch.Tensor):
batch_idx = batch_idx.cpu().numpy()
max_size = 1920 # max image size
max_subplots = 16 # max image subplots, i.e. 4x4
bs, _, h, w = images.shape # batch size, _, height, width
bs = min(bs, max_subplots) # limit plot images
ns = np.ceil(bs ** 0.5) # number of subplots (square)
if np.max(images[0]) <= 1:
images *= 255 # de-normalise (optional)
# Build Image
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
for i, im in enumerate(images):
if i == max_subplots: # if last batch has fewer images than we expect
break
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
im = im.transpose(1, 2, 0)
mosaic[y:y + h, x:x + w, :] = im
# Resize (optional)
scale = max_size / ns / max(h, w)
if scale < 1:
h = math.ceil(scale * h)
w = math.ceil(scale * w)
mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
# Annotate
fs = int((h + w) * ns * 0.01) # font size
annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
for i in range(i + 1):
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
if paths:
annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
if len(cls) > 0:
idx = batch_idx == i
boxes = xywh2xyxy(bboxes[idx]).T
classes = cls[idx].astype('int')
labels = confs is None # labels if no conf column
conf = None if labels else confs[idx] # check for confidence presence (label vs pred)
if boxes.shape[1]:
if boxes.max() <= 1.01: # if normalized with tolerance 0.01
boxes[[0, 2]] *= w # scale to pixels
boxes[[1, 3]] *= h
elif scale < 1: # absolute coords need scale if image scales
boxes *= scale
boxes[[0, 2]] += x
boxes[[1, 3]] += y
for j, box in enumerate(boxes.T.tolist()):
c = classes[j]
color = colors(c)
c = names[c] if names else c
if labels or conf[j] > 0.25: # 0.25 conf thresh
label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
annotator.box_label(box, label, color=color)
# Plot masks
if len(masks):
if masks.max() > 1.0: # mean that masks are overlap
image_masks = masks[[i]] # (1, 640, 640)
nl = idx.sum()
index = np.arange(nl).reshape(nl, 1, 1) + 1
image_masks = np.repeat(image_masks, nl, axis=0)
image_masks = np.where(image_masks == index, 1.0, 0.0)
else:
image_masks = masks[idx]
im = np.asarray(annotator.im).copy()
for j, box in enumerate(boxes.T.tolist()):
if labels or conf[j] > 0.25: # 0.25 conf thresh
color = colors(classes[j])
mh, mw = image_masks[j].shape
if mh != h or mw != w:
mask = image_masks[j].astype(np.uint8)
mask = cv2.resize(mask, (w, h))
mask = mask.astype(bool)
else:
mask = image_masks[j].astype(bool)
with contextlib.suppress(Exception):
im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
annotator.fromarray(im)
annotator.im.save(fname) # save
def plot_results_with_masks(file="path/to/results.csv", dir="", best=True):
# Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
save_dir = Path(file).parent if file else Path(dir)
fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
ax = ax.ravel()
files = list(save_dir.glob("results*.csv"))
assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
for f in files:
try:
data = pd.read_csv(f)
index = np.argmax(0.9 * data.values[:, 8] + 0.1 * data.values[:, 7] + 0.9 * data.values[:, 12] +
0.1 * data.values[:, 11])
s = [x.strip() for x in data.columns]
x = data.values[:, 0]
for i, j in enumerate([1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]):
y = data.values[:, j]
# y[y == 0] = np.nan # don't show zero values
ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=2)
if best:
# best
ax[i].scatter(index, y[index], color="r", label=f"best:{index}", marker="*", linewidth=3)
ax[i].set_title(s[j] + f"\n{round(y[index], 5)}")
else:
# last
ax[i].scatter(x[-1], y[-1], color="r", label="last", marker="*", linewidth=3)
ax[i].set_title(s[j] + f"\n{round(y[-1], 5)}")
# if j in [8, 9, 10]: # share train and val loss y axes
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
except Exception as e:
print(f"Warning: Plotting error for {f}: {e}")
ax[1].legend()
fig.savefig(save_dir / "results.png", dpi=200)
plt.close()
def output_to_target(output, max_det=300):
# Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting
targets = []
for i, o in enumerate(output):
box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
j = torch.full((conf.shape[0], 1), i)
targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))
targets = torch.cat(targets, 0).numpy()
return targets[:, 0], targets[:, 1], targets[:, 2:6], targets[:, 6]

@ -245,3 +245,19 @@ class ModelEMA:
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')): def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
# Update EMA attributes # Update EMA attributes
copy_attr(self.ema, model, include, exclude) copy_attr(self.ema, model, include, exclude)
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
# Strip optimizer from 'f' to finalize training, optionally save as 's'
x = torch.load(f, map_location=torch.device('cpu'))
if x.get('ema'):
x['model'] = x['ema'] # replace model with ema
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
x[k] = None
x['epoch'] = -1
x['model'].half() # to FP16
for p in x['model'].parameters():
p.requires_grad = False
torch.save(x, s or f)
mb = os.path.getsize(s or f) / 1E6 # filesize
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")

@ -9,6 +9,9 @@ from ultralytics.yolo.utils.modeling.tasks import ClassificationModel
class ClassificationTrainer(BaseTrainer): class ClassificationTrainer(BaseTrainer):
def set_model_attributes(self):
self.model.names = self.data["names"]
def load_model(self, model_cfg, weights, data): def load_model(self, model_cfg, weights, data):
# TODO: why treat clf models as unique. We should have clf yamls? # TODO: why treat clf models as unique. We should have clf yamls?
if weights and not weights.__class__.__name__.startswith("yolo"): # torchvision if weights and not weights.__class__.__name__.startswith("yolo"): # torchvision
@ -18,7 +21,7 @@ class ClassificationTrainer(BaseTrainer):
ClassificationModel.reshape_outputs(model, data["nc"]) ClassificationModel.reshape_outputs(model, data["nc"])
return model return model
def get_dataloader(self, dataset_path, batch_size=None, rank=0): def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
return build_classification_dataloader(path=dataset_path, return build_classification_dataloader(path=dataset_path,
imgsz=self.args.img_size, imgsz=self.args.img_size,
batch_size=batch_size, batch_size=batch_size,

@ -23,3 +23,7 @@ class ClassificationValidator(BaseValidator):
acc = torch.stack((self.correct[:, 0], self.correct.max(1).values), dim=1) # (top1, top5) accuracy acc = torch.stack((self.correct[:, 0], self.correct.max(1).values), dim=1) # (top1, top5) accuracy
top1, top5 = acc.mean(0).tolist() top1, top5 = acc.mean(0).tolist()
return {"top1": top1, "top5": top5, "fitness": top5} return {"top1": top1, "top5": top5, "fitness": top5}
@property
def metric_keys(self):
return ["top1", "top5"]

@ -9,30 +9,18 @@ from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
from ultralytics.yolo.utils.modeling.tasks import SegmentationModel from ultralytics.yolo.utils.modeling.tasks import SegmentationModel
from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy
from ultralytics.yolo.utils.plotting import plot_images_and_masks, plot_results_with_masks
from ultralytics.yolo.utils.torch_utils import de_parallel from ultralytics.yolo.utils.torch_utils import de_parallel
# BaseTrainer python usage # BaseTrainer python usage
class SegmentationTrainer(BaseTrainer): class SegmentationTrainer(BaseTrainer):
def get_dataloader(self, dataset_path, batch_size, rank=0): def get_dataloader(self, dataset_path, batch_size, mode="train", rank=0):
# TODO: manage splits differently # TODO: manage splits differently
# calculate stride - check if model is initialized # calculate stride - check if model is initialized
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
return build_dataloader( return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode)[0]
img_path=dataset_path,
img_size=self.args.img_size,
batch_size=batch_size,
single_cls=self.args.single_cls,
cache=self.args.cache,
image_weights=self.args.image_weights,
stride=gs,
rect=self.args.rect,
rank=rank,
workers=self.args.workers,
shuffle=self.args.shuffle,
use_segments=True,
)[0]
def preprocess_batch(self, batch): def preprocess_batch(self, batch):
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255 batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
@ -58,7 +46,10 @@ class SegmentationTrainer(BaseTrainer):
self.model.names = self.data["names"] self.model.names = self.data["names"]
def get_validator(self): def get_validator(self):
return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console) return v8.segment.SegmentationValidator(self.test_loader,
save_dir=self.save_dir,
logger=self.console,
args=self.args)
def criterion(self, preds, batch): def criterion(self, preds, batch):
head = de_parallel(self.model).model[-1] head = de_parallel(self.model).model[-1]
@ -218,6 +209,8 @@ class SegmentationTrainer(BaseTrainer):
else: else:
mask_gti = masks[tidxs[i]][j] mask_gti = masks[tidxs[i]][j]
lseg += single_mask_loss(mask_gti, pmask[j], proto[bi], mxyxy[j], marea[j]) lseg += single_mask_loss(mask_gti, pmask[j], proto[bi], mxyxy[j], marea[j])
else:
lseg += (proto * 0).sum()
obji = BCEobj(pi[..., 4], tobj) obji = BCEobj(pi[..., 4], tobj)
lobj += obji * balance[i] # obj loss lobj += obji * balance[i] # obj loss
@ -234,15 +227,33 @@ class SegmentationTrainer(BaseTrainer):
loss = lbox + lobj + lcls + lseg loss = lbox + lobj + lcls + lseg
return loss * bs, torch.cat((lbox, lseg, lobj, lcls)).detach() return loss * bs, torch.cat((lbox, lseg, lobj, lcls)).detach()
def label_loss_items(self, loss_items): def label_loss_items(self, loss_items=None, prefix="train"):
# We should just use named tensors here in future # We should just use named tensors here in future
keys = ["lbox", "lseg", "lobj", "lcls"] keys = [f"{prefix}/lbox", f"{prefix}/lseg", f"{prefix}/lobj", f"{prefix}/lcls"]
return dict(zip(keys, loss_items)) return dict(zip(keys, loss_items)) if loss_items is not None else keys
def progress_string(self): def progress_string(self):
return ('\n' + '%11s' * 7) % \ return ('\n' + '%11s' * 7) % \
('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss', 'Size') ('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss', 'Size')
def plot_training_samples(self, batch, ni):
images = batch["img"]
masks = batch["masks"]
cls = batch["cls"].squeeze(-1)
bboxes = batch["bboxes"]
paths = batch["im_file"]
batch_idx = batch["batch_idx"]
plot_images_and_masks(images,
batch_idx,
cls,
bboxes,
masks,
paths,
fname=self.save_dir / f"train_batch{ni}.jpg")
def plot_metrics(self):
plot_results_with_masks(file=self.csv) # save results.png
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name) @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
def train(cfg): def train(cfg):

@ -6,23 +6,24 @@ import torch.nn.functional as F
from ultralytics.yolo.engine.validator import BaseValidator from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import ops from ultralytics.yolo.utils import ops
from ultralytics.yolo.utils.checks import check_requirements from ultralytics.yolo.utils.checks import check_file, check_requirements
from ultralytics.yolo.utils.files import yaml_load from ultralytics.yolo.utils.files import yaml_load
from ultralytics.yolo.utils.metrics import (ConfusionMatrix, Metrics, ap_per_class_box_and_mask, box_iou, from ultralytics.yolo.utils.metrics import (ConfusionMatrix, Metrics, ap_per_class_box_and_mask, box_iou,
fitness_segmentation, mask_iou) fitness_segmentation, mask_iou)
from ultralytics.yolo.utils.plotting import output_to_target, plot_images_and_masks
from ultralytics.yolo.utils.torch_utils import de_parallel from ultralytics.yolo.utils.torch_utils import de_parallel
class SegmentationValidator(BaseValidator): class SegmentationValidator(BaseValidator):
def __init__(self, dataloader, pbar=None, logger=None, args=None): def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None):
super().__init__(dataloader, pbar, logger, args) super().__init__(dataloader, save_dir, pbar, logger, args)
if self.args.save_json: if self.args.save_json:
check_requirements(['pycocotools']) check_requirements(['pycocotools'])
self.process = ops.process_mask_upsample # more accurate self.process = ops.process_mask_upsample # more accurate
else: else:
self.process = ops.process_mask # faster self.process = ops.process_mask # faster
self.data_dict = yaml_load(self.args.data) if self.args.data else None self.data_dict = yaml_load(check_file(self.args.data)) if self.args.data else None
self.is_coco = False self.is_coco = False
self.class_map = None self.class_map = None
self.targets = None self.targets = None
@ -62,6 +63,7 @@ class SegmentationValidator(BaseValidator):
self.loss = torch.zeros(4, device=self.device) self.loss = torch.zeros(4, device=self.device)
self.jdict = [] self.jdict = []
self.stats = [] self.stats = []
self.plot_masks = []
def get_desc(self): def get_desc(self):
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P", return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P",
@ -80,11 +82,10 @@ class SegmentationValidator(BaseValidator):
def update_metrics(self, preds, batch): def update_metrics(self, preds, batch):
# Metrics # Metrics
plot_masks = [] # masks for plotting
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])): for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
labels = self.targets[self.targets[:, 0] == si, 1:] labels = self.targets[self.targets[:, 0] == si, 1:]
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
shape = batch["shape"][si] shape = batch["ori_shape"][si]
# path = batch["shape"][si][0] # path = batch["shape"][si][0]
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
@ -130,7 +131,7 @@ class SegmentationValidator(BaseValidator):
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8) pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
if self.args.plots and self.batch_i < 3: if self.args.plots and self.batch_i < 3:
plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
# TODO: Save/log # TODO: Save/log
''' '''
@ -143,26 +144,14 @@ class SegmentationValidator(BaseValidator):
# callbacks.run('on_val_image_end', pred, predn, path, names, im[si]) # callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
''' '''
# TODO Plot images
'''
if self.args.plots and self.batch_i < 3:
if len(plot_masks):
plot_masks = torch.cat(plot_masks, dim=0)
plot_images_and_masks(im, targets, masks, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names)
plot_images_and_masks(im, output_to_target(preds, max_det=15), plot_masks, paths,
save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred
'''
def get_stats(self): def get_stats(self):
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
if len(stats) and stats[0].any(): if len(stats) and stats[0].any():
# TODO: save_dir results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir=self.save_dir, names=self.names)
results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir='', names=self.names)
self.metrics.update(results) self.metrics.update(results)
self.nt_per_class = np.bincount(stats[4].astype(int), minlength=self.nc) # number of targets per class self.nt_per_class = np.bincount(stats[4].astype(int), minlength=self.nc) # number of targets per class
keys = ["mp_bbox", "mr_bbox", "map50_bbox", "map_bbox", "mp_mask", "mr_mask", "map50_mask", "map_mask"]
metrics = {"fitness": fitness_segmentation(np.array(self.metrics.mean_results()).reshape(1, -1))} metrics = {"fitness": fitness_segmentation(np.array(self.metrics.mean_results()).reshape(1, -1))}
metrics |= zip(keys, self.metrics.mean_results()) metrics |= zip(self.metric_keys, self.metrics.mean_results())
return metrics return metrics
def print_results(self): def print_results(self):
@ -177,9 +166,8 @@ class SegmentationValidator(BaseValidator):
for i, c in enumerate(self.metrics.ap_class_index): for i, c in enumerate(self.metrics.ap_class_index):
self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i))) self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
# plot TODO: save_dir
if self.args.plots: if self.args.plots:
self.confusion_matrix.plot(save_dir='', names=list(self.names.values())) self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
def _process_batch(self, detections, labels, iouv, pred_masks=None, gt_masks=None, overlap=False, masks=False): def _process_batch(self, detections, labels, iouv, pred_masks=None, gt_masks=None, overlap=False, masks=False):
""" """
@ -217,3 +205,41 @@ class SegmentationValidator(BaseValidator):
matches = matches[np.unique(matches[:, 0], return_index=True)[1]] matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
correct[matches[:, 1].astype(int), i] = True correct[matches[:, 1].astype(int), i] = True
return torch.tensor(correct, dtype=torch.bool, device=iouv.device) return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
@property
def metric_keys(self):
return [
"metrics/precision(B)",
"metrics/recall(B)",
"metrics/mAP_0.5(B)",
"metrics/mAP_0.5:0.95(B)", # metrics
"metrics/precision(M)",
"metrics/recall(M)",
"metrics/mAP_0.5(M)",
"metrics/mAP_0.5:0.95(M)",]
def plot_val_samples(self, batch, ni):
images = batch["img"]
masks = batch["masks"]
cls = batch["cls"].squeeze(-1)
bboxes = batch["bboxes"]
paths = batch["im_file"]
batch_idx = batch["batch_idx"]
plot_images_and_masks(images,
batch_idx,
cls,
bboxes,
masks,
paths,
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names)
def plot_predictions(self, batch, preds, ni):
images = batch["img"]
paths = batch["im_file"]
if len(self.plot_masks):
plot_masks = torch.cat(self.plot_masks, dim=0)
batch_idx, cls, bboxes, conf = output_to_target(preds[0], max_det=15)
plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, paths, conf,
self.save_dir / f'val_batch{ni}_pred.jpg', self.names) # pred
self.plot_masks.clear()

Loading…
Cancel
Save