update segment training (#57)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
@ -9,6 +9,9 @@ from ultralytics.yolo.utils.modeling.tasks import ClassificationModel
|
||||
|
||||
class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
def set_model_attributes(self):
|
||||
self.model.names = self.data["names"]
|
||||
|
||||
def load_model(self, model_cfg, weights, data):
|
||||
# TODO: why treat clf models as unique. We should have clf yamls?
|
||||
if weights and not weights.__class__.__name__.startswith("yolo"): # torchvision
|
||||
@ -18,7 +21,7 @@ class ClassificationTrainer(BaseTrainer):
|
||||
ClassificationModel.reshape_outputs(model, data["nc"])
|
||||
return model
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=None, rank=0):
|
||||
def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
|
||||
return build_classification_dataloader(path=dataset_path,
|
||||
imgsz=self.args.img_size,
|
||||
batch_size=batch_size,
|
||||
|
@ -23,3 +23,7 @@ class ClassificationValidator(BaseValidator):
|
||||
acc = torch.stack((self.correct[:, 0], self.correct.max(1).values), dim=1) # (top1, top5) accuracy
|
||||
top1, top5 = acc.mean(0).tolist()
|
||||
return {"top1": top1, "top5": top5, "fitness": top5}
|
||||
|
||||
@property
|
||||
def metric_keys(self):
|
||||
return ["top1", "top5"]
|
||||
|
@ -9,30 +9,18 @@ from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
|
||||
from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
|
||||
from ultralytics.yolo.utils.modeling.tasks import SegmentationModel
|
||||
from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy
|
||||
from ultralytics.yolo.utils.plotting import plot_images_and_masks, plot_results_with_masks
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel
|
||||
|
||||
|
||||
# BaseTrainer python usage
|
||||
class SegmentationTrainer(BaseTrainer):
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size, rank=0):
|
||||
def get_dataloader(self, dataset_path, batch_size, mode="train", rank=0):
|
||||
# TODO: manage splits differently
|
||||
# calculate stride - check if model is initialized
|
||||
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
||||
return build_dataloader(
|
||||
img_path=dataset_path,
|
||||
img_size=self.args.img_size,
|
||||
batch_size=batch_size,
|
||||
single_cls=self.args.single_cls,
|
||||
cache=self.args.cache,
|
||||
image_weights=self.args.image_weights,
|
||||
stride=gs,
|
||||
rect=self.args.rect,
|
||||
rank=rank,
|
||||
workers=self.args.workers,
|
||||
shuffle=self.args.shuffle,
|
||||
use_segments=True,
|
||||
)[0]
|
||||
return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode)[0]
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
|
||||
@ -58,7 +46,10 @@ class SegmentationTrainer(BaseTrainer):
|
||||
self.model.names = self.data["names"]
|
||||
|
||||
def get_validator(self):
|
||||
return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console)
|
||||
return v8.segment.SegmentationValidator(self.test_loader,
|
||||
save_dir=self.save_dir,
|
||||
logger=self.console,
|
||||
args=self.args)
|
||||
|
||||
def criterion(self, preds, batch):
|
||||
head = de_parallel(self.model).model[-1]
|
||||
@ -218,6 +209,8 @@ class SegmentationTrainer(BaseTrainer):
|
||||
else:
|
||||
mask_gti = masks[tidxs[i]][j]
|
||||
lseg += single_mask_loss(mask_gti, pmask[j], proto[bi], mxyxy[j], marea[j])
|
||||
else:
|
||||
lseg += (proto * 0).sum()
|
||||
|
||||
obji = BCEobj(pi[..., 4], tobj)
|
||||
lobj += obji * balance[i] # obj loss
|
||||
@ -234,15 +227,33 @@ class SegmentationTrainer(BaseTrainer):
|
||||
loss = lbox + lobj + lcls + lseg
|
||||
return loss * bs, torch.cat((lbox, lseg, lobj, lcls)).detach()
|
||||
|
||||
def label_loss_items(self, loss_items):
|
||||
def label_loss_items(self, loss_items=None, prefix="train"):
|
||||
# We should just use named tensors here in future
|
||||
keys = ["lbox", "lseg", "lobj", "lcls"]
|
||||
return dict(zip(keys, loss_items))
|
||||
keys = [f"{prefix}/lbox", f"{prefix}/lseg", f"{prefix}/lobj", f"{prefix}/lcls"]
|
||||
return dict(zip(keys, loss_items)) if loss_items is not None else keys
|
||||
|
||||
def progress_string(self):
|
||||
return ('\n' + '%11s' * 7) % \
|
||||
('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss', 'Size')
|
||||
|
||||
def plot_training_samples(self, batch, ni):
|
||||
images = batch["img"]
|
||||
masks = batch["masks"]
|
||||
cls = batch["cls"].squeeze(-1)
|
||||
bboxes = batch["bboxes"]
|
||||
paths = batch["im_file"]
|
||||
batch_idx = batch["batch_idx"]
|
||||
plot_images_and_masks(images,
|
||||
batch_idx,
|
||||
cls,
|
||||
bboxes,
|
||||
masks,
|
||||
paths,
|
||||
fname=self.save_dir / f"train_batch{ni}.jpg")
|
||||
|
||||
def plot_metrics(self):
|
||||
plot_results_with_masks(file=self.csv) # save results.png
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||
def train(cfg):
|
||||
|
@ -6,23 +6,24 @@ import torch.nn.functional as F
|
||||
|
||||
from ultralytics.yolo.engine.validator import BaseValidator
|
||||
from ultralytics.yolo.utils import ops
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
from ultralytics.yolo.utils.checks import check_file, check_requirements
|
||||
from ultralytics.yolo.utils.files import yaml_load
|
||||
from ultralytics.yolo.utils.metrics import (ConfusionMatrix, Metrics, ap_per_class_box_and_mask, box_iou,
|
||||
fitness_segmentation, mask_iou)
|
||||
from ultralytics.yolo.utils.plotting import output_to_target, plot_images_and_masks
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel
|
||||
|
||||
|
||||
class SegmentationValidator(BaseValidator):
|
||||
|
||||
def __init__(self, dataloader, pbar=None, logger=None, args=None):
|
||||
super().__init__(dataloader, pbar, logger, args)
|
||||
def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None):
|
||||
super().__init__(dataloader, save_dir, pbar, logger, args)
|
||||
if self.args.save_json:
|
||||
check_requirements(['pycocotools'])
|
||||
self.process = ops.process_mask_upsample # more accurate
|
||||
else:
|
||||
self.process = ops.process_mask # faster
|
||||
self.data_dict = yaml_load(self.args.data) if self.args.data else None
|
||||
self.data_dict = yaml_load(check_file(self.args.data)) if self.args.data else None
|
||||
self.is_coco = False
|
||||
self.class_map = None
|
||||
self.targets = None
|
||||
@ -62,6 +63,7 @@ class SegmentationValidator(BaseValidator):
|
||||
self.loss = torch.zeros(4, device=self.device)
|
||||
self.jdict = []
|
||||
self.stats = []
|
||||
self.plot_masks = []
|
||||
|
||||
def get_desc(self):
|
||||
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P",
|
||||
@ -80,11 +82,10 @@ class SegmentationValidator(BaseValidator):
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
# Metrics
|
||||
plot_masks = [] # masks for plotting
|
||||
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
|
||||
labels = self.targets[self.targets[:, 0] == si, 1:]
|
||||
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
|
||||
shape = batch["shape"][si]
|
||||
shape = batch["ori_shape"][si]
|
||||
# path = batch["shape"][si][0]
|
||||
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||
@ -130,7 +131,7 @@ class SegmentationValidator(BaseValidator):
|
||||
|
||||
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
|
||||
if self.args.plots and self.batch_i < 3:
|
||||
plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
|
||||
self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
|
||||
|
||||
# TODO: Save/log
|
||||
'''
|
||||
@ -143,26 +144,14 @@ class SegmentationValidator(BaseValidator):
|
||||
# callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
|
||||
'''
|
||||
|
||||
# TODO Plot images
|
||||
'''
|
||||
if self.args.plots and self.batch_i < 3:
|
||||
if len(plot_masks):
|
||||
plot_masks = torch.cat(plot_masks, dim=0)
|
||||
plot_images_and_masks(im, targets, masks, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names)
|
||||
plot_images_and_masks(im, output_to_target(preds, max_det=15), plot_masks, paths,
|
||||
save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred
|
||||
'''
|
||||
|
||||
def get_stats(self):
|
||||
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
|
||||
if len(stats) and stats[0].any():
|
||||
# TODO: save_dir
|
||||
results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir='', names=self.names)
|
||||
results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir=self.save_dir, names=self.names)
|
||||
self.metrics.update(results)
|
||||
self.nt_per_class = np.bincount(stats[4].astype(int), minlength=self.nc) # number of targets per class
|
||||
keys = ["mp_bbox", "mr_bbox", "map50_bbox", "map_bbox", "mp_mask", "mr_mask", "map50_mask", "map_mask"]
|
||||
metrics = {"fitness": fitness_segmentation(np.array(self.metrics.mean_results()).reshape(1, -1))}
|
||||
metrics |= zip(keys, self.metrics.mean_results())
|
||||
metrics |= zip(self.metric_keys, self.metrics.mean_results())
|
||||
return metrics
|
||||
|
||||
def print_results(self):
|
||||
@ -177,9 +166,8 @@ class SegmentationValidator(BaseValidator):
|
||||
for i, c in enumerate(self.metrics.ap_class_index):
|
||||
self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
|
||||
|
||||
# plot TODO: save_dir
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.plot(save_dir='', names=list(self.names.values()))
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
|
||||
|
||||
def _process_batch(self, detections, labels, iouv, pred_masks=None, gt_masks=None, overlap=False, masks=False):
|
||||
"""
|
||||
@ -217,3 +205,41 @@ class SegmentationValidator(BaseValidator):
|
||||
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
||||
correct[matches[:, 1].astype(int), i] = True
|
||||
return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
|
||||
|
||||
@property
|
||||
def metric_keys(self):
|
||||
return [
|
||||
"metrics/precision(B)",
|
||||
"metrics/recall(B)",
|
||||
"metrics/mAP_0.5(B)",
|
||||
"metrics/mAP_0.5:0.95(B)", # metrics
|
||||
"metrics/precision(M)",
|
||||
"metrics/recall(M)",
|
||||
"metrics/mAP_0.5(M)",
|
||||
"metrics/mAP_0.5:0.95(M)",]
|
||||
|
||||
def plot_val_samples(self, batch, ni):
|
||||
images = batch["img"]
|
||||
masks = batch["masks"]
|
||||
cls = batch["cls"].squeeze(-1)
|
||||
bboxes = batch["bboxes"]
|
||||
paths = batch["im_file"]
|
||||
batch_idx = batch["batch_idx"]
|
||||
plot_images_and_masks(images,
|
||||
batch_idx,
|
||||
cls,
|
||||
bboxes,
|
||||
masks,
|
||||
paths,
|
||||
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
|
||||
names=self.names)
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
images = batch["img"]
|
||||
paths = batch["im_file"]
|
||||
if len(self.plot_masks):
|
||||
plot_masks = torch.cat(self.plot_masks, dim=0)
|
||||
batch_idx, cls, bboxes, conf = output_to_target(preds[0], max_det=15)
|
||||
plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, paths, conf,
|
||||
self.save_dir / f'val_batch{ni}_pred.jpg', self.names) # pred
|
||||
self.plot_masks.clear()
|
||||
|
Reference in New Issue
Block a user