standalone val (#56)

Author: Ayush Chaurasia
Date: 2022-11-30 15:04:44 +05:30
Committed by: GitHub
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 3a241e4cea
commit 5a52e7663a

16 changed files with 161 additions and 31 deletions

ultralytics/yolo/v8/classify/__init__.py

@@ -1,4 +1,4 @@
 from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train
-from ultralytics.yolo.v8.classify.val import ClassificationValidator
+from ultralytics.yolo.v8.classify.val import ClassificationValidator, val
 __all__ = ["train"]

ultralytics/yolo/v8/classify/train.py

@@ -19,6 +19,13 @@ class ClassificationTrainer(BaseTrainer):
         else:
             model = ClassificationModel(model_cfg, weights, data["nc"])
             ClassificationModel.reshape_outputs(model, data["nc"])
+        for m in model.modules():
+            if not weights and hasattr(m, 'reset_parameters'):
+                m.reset_parameters()
+            if isinstance(m, torch.nn.Dropout) and self.args.dropout is not None:
+                m.p = self.args.dropout  # set dropout
+        for p in model.parameters():
+            p.requires_grad = True  # for training
         return model

     def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
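The dropout/requires_grad loop added to get_model can be illustrated outside the trainer; a minimal sketch, assuming a plain nn.Sequential stand-in for the model and a local dropout value in place of self.args.dropout:

    import torch.nn as nn

    # hypothetical stand-in model containing a Dropout layer
    model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.2), nn.Linear(8, 2))
    dropout = 0.5  # stands in for self.args.dropout

    for m in model.modules():
        if isinstance(m, nn.Dropout) and dropout is not None:
            m.p = dropout  # nn.Dropout reads self.p on every forward pass
    for p in model.parameters():
        p.requires_grad = True  # keep every layer trainable

    print([m.p for m in model.modules() if isinstance(m, nn.Dropout)])  # [0.5]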

ultralytics/yolo/v8/classify/val.py

@@ -1,5 +1,8 @@
+import hydra
 import torch

 from ultralytics.yolo.data import build_classification_dataloader
+from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
 from ultralytics.yolo.engine.validator import BaseValidator
@@ -24,6 +27,21 @@ class ClassificationValidator(BaseValidator):
         top1, top5 = acc.mean(0).tolist()
         return {"top1": top1, "top5": top5, "fitness": top5}

+    def get_dataloader(self, dataset_path, batch_size):
+        return build_classification_dataloader(path=dataset_path, imgsz=self.args.img_size, batch_size=batch_size)
+
     @property
     def metric_keys(self):
         return ["top1", "top5"]
+
+
+@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
+def val(cfg):
+    cfg.data = cfg.data or "imagenette160"
+    cfg.model = cfg.model or "resnet18"
+    validator = ClassificationValidator(args=cfg)
+    validator(model=cfg.model)
+
+
+if __name__ == "__main__":
+    val()
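With the hydra entry point above, classification validation can be launched on its own; a sketch of the expected invocation, using hydra's key=value overrides (the script path is an assumption, not confirmed by this diff):

    # run the standalone classification validator with config overrides
    python ultralytics/yolo/v8/classify/val.py model=resnet18 data=imagenette160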

ultralytics/yolo/v8/segment/__init__.py

@@ -1,2 +1,2 @@
 from ultralytics.yolo.v8.segment.train import SegmentationTrainer, train
-from ultralytics.yolo.v8.segment.val import SegmentationValidator
+from ultralytics.yolo.v8.segment.val import SegmentationValidator, val

ultralytics/yolo/v8/segment/train.py

@@ -33,6 +33,8 @@ class SegmentationTrainer(BaseTrainer):
                               anchors=self.args.get("anchors"))
         if weights:
             model.load(weights)
+        for _, v in model.named_parameters():
+            v.requires_grad = True  # train all layers
         return model

     def set_model_attributes(self):
@@ -257,7 +259,7 @@ class SegmentationTrainer(BaseTrainer):
 @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
 def train(cfg):
-    cfg.model = v8.ROOT / "models/yolov5n-seg.yaml"
+    cfg.model = cfg.model or "models/yolov5n-seg.yaml"
     cfg.data = cfg.data or "coco128-seg.yaml"  # or yolo.ClassificationDataset("mnist")
     trainer = SegmentationTrainer(cfg)
     trainer.train()
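Since cfg.model is now a fallback rather than a hard-coded path, the segmentation trainer can presumably be pointed at a different model config at launch; a hypothetical hydra-style invocation (script path and override values are assumptions):

    python ultralytics/yolo/v8/segment/train.py model=models/yolov5n-seg.yaml data=coco128-seg.yaml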

ultralytics/yolo/v8/segment/val.py

@@ -1,9 +1,12 @@
 import os

+import hydra
 import numpy as np
 import torch
 import torch.nn.functional as F

+from ultralytics.yolo.data import build_dataloader
+from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
 from ultralytics.yolo.engine.validator import BaseValidator
 from ultralytics.yolo.utils import ops
 from ultralytics.yolo.utils.checks import check_file, check_requirements
@@ -16,7 +19,7 @@ from ultralytics.yolo.utils.torch_utils import de_parallel

 class SegmentationValidator(BaseValidator):

-    def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None):
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
         super().__init__(dataloader, save_dir, pbar, logger, args)
         if self.args.save_json:
             check_requirements(['pycocotools'])
@@ -43,14 +46,17 @@ class SegmentationValidator(BaseValidator):
         return batch

     def init_metrics(self, model):
-        head = de_parallel(model).model[-1]
-        if self.data_dict:
-            self.is_coco = isinstance(self.data_dict.get('val'),
-                                      str) and self.data_dict['val'].endswith(f'coco{os.sep}val2017.txt')
-            self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
+        if self.training:
+            head = de_parallel(model).model[-1]
+        else:
+            head = de_parallel(model).model.model[-1]
+        if self.data:
+            self.is_coco = isinstance(self.data.get('val'),
+                                      str) and self.data['val'].endswith(f'coco{os.sep}val2017.txt')
+            self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
+        self.nm = head.nm if hasattr(head, "nm") else 32
         self.nc = head.nc
-        self.nm = head.nm
         self.names = model.names
         if isinstance(self.names, (list, tuple)):  # old format
             self.names = dict(enumerate(self.names))
@@ -206,6 +212,12 @@ class SegmentationValidator(BaseValidator):
             correct[matches[:, 1].astype(int), i] = True
         return torch.tensor(correct, dtype=torch.bool, device=iouv.device)

+    def get_dataloader(self, dataset_path, batch_size):
+        # TODO: manage splits differently
+        # calculate stride - check if model is initialized
+        gs = max(int(de_parallel(self.model).stride if self.model else 0), 32)
+        return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, mode="val")[0]
+
     @property
     def metric_keys(self):
         return [
@@ -243,3 +255,14 @@ class SegmentationValidator(BaseValidator):
             plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, paths, conf,
                                   self.save_dir / f'val_batch{ni}_pred.jpg', self.names)  # pred
         self.plot_masks.clear()
+
+
+@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
+def val(cfg):
+    cfg.data = cfg.data or "coco128-seg.yaml"
+    validator = SegmentationValidator(args=cfg)
+    validator(model=cfg.model)
+
+
+if __name__ == "__main__":
+    val()
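As with the classification validator, this entry point lets segmentation validation run standalone; a hypothetical invocation (script path and the model/data override values are placeholders, passed as hydra key=value overrides):

    python ultralytics/yolo/v8/segment/val.py model=models/yolov5n-seg.yaml data=coco128-seg.yaml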