standalone val (#56)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train
|
||||
from ultralytics.yolo.v8.classify.val import ClassificationValidator
|
||||
from ultralytics.yolo.v8.classify.val import ClassificationValidator, val
|
||||
|
||||
__all__ = ["train"]
|
||||
|
@ -19,6 +19,13 @@ class ClassificationTrainer(BaseTrainer):
|
||||
else:
|
||||
model = ClassificationModel(model_cfg, weights, data["nc"])
|
||||
ClassificationModel.reshape_outputs(model, data["nc"])
|
||||
for m in model.modules():
|
||||
if not weights and hasattr(m, 'reset_parameters'):
|
||||
m.reset_parameters()
|
||||
if isinstance(m, torch.nn.Dropout) and self.args.dropout is not None:
|
||||
m.p = self.args.dropout # set dropout
|
||||
for p in model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
return model
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
|
||||
|
@ -1,5 +1,8 @@
|
||||
import hydra
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.data import build_classification_dataloader
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.engine.validator import BaseValidator
|
||||
|
||||
|
||||
@ -24,6 +27,21 @@ class ClassificationValidator(BaseValidator):
|
||||
top1, top5 = acc.mean(0).tolist()
|
||||
return {"top1": top1, "top5": top5, "fitness": top5}
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
return build_classification_dataloader(path=dataset_path, imgsz=self.args.img_size, batch_size=batch_size)
|
||||
|
||||
@property
|
||||
def metric_keys(self):
|
||||
return ["top1", "top5"]
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||
def val(cfg):
|
||||
cfg.data = cfg.data or "imagenette160"
|
||||
cfg.model = cfg.model or "resnet18"
|
||||
validator = ClassificationValidator(args=cfg)
|
||||
validator(model=cfg.model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
val()
|
||||
|
@ -1,2 +1,2 @@
|
||||
from ultralytics.yolo.v8.segment.train import SegmentationTrainer, train
|
||||
from ultralytics.yolo.v8.segment.val import SegmentationValidator
|
||||
from ultralytics.yolo.v8.segment.val import SegmentationValidator, val
|
||||
|
@ -33,6 +33,8 @@ class SegmentationTrainer(BaseTrainer):
|
||||
anchors=self.args.get("anchors"))
|
||||
if weights:
|
||||
model.load(weights)
|
||||
for _, v in model.named_parameters():
|
||||
v.requires_grad = True # train all layers
|
||||
return model
|
||||
|
||||
def set_model_attributes(self):
|
||||
@ -257,7 +259,7 @@ class SegmentationTrainer(BaseTrainer):
|
||||
|
||||
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||
def train(cfg):
|
||||
cfg.model = v8.ROOT / "models/yolov5n-seg.yaml"
|
||||
cfg.model = cfg.model or "models/yolov5n-seg.yaml"
|
||||
cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist")
|
||||
trainer = SegmentationTrainer(cfg)
|
||||
trainer.train()
|
||||
|
@ -1,9 +1,12 @@
|
||||
import os
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.yolo.data import build_dataloader
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.engine.validator import BaseValidator
|
||||
from ultralytics.yolo.utils import ops
|
||||
from ultralytics.yolo.utils.checks import check_file, check_requirements
|
||||
@ -16,7 +19,7 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
|
||||
|
||||
class SegmentationValidator(BaseValidator):
|
||||
|
||||
def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None):
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
|
||||
super().__init__(dataloader, save_dir, pbar, logger, args)
|
||||
if self.args.save_json:
|
||||
check_requirements(['pycocotools'])
|
||||
@ -43,14 +46,17 @@ class SegmentationValidator(BaseValidator):
|
||||
return batch
|
||||
|
||||
def init_metrics(self, model):
|
||||
head = de_parallel(model).model[-1]
|
||||
if self.data_dict:
|
||||
self.is_coco = isinstance(self.data_dict.get('val'),
|
||||
str) and self.data_dict['val'].endswith(f'coco{os.sep}val2017.txt')
|
||||
self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
|
||||
if self.training:
|
||||
head = de_parallel(model).model[-1]
|
||||
else:
|
||||
head = de_parallel(model).model.model[-1]
|
||||
|
||||
if self.data:
|
||||
self.is_coco = isinstance(self.data.get('val'),
|
||||
str) and self.data['val'].endswith(f'coco{os.sep}val2017.txt')
|
||||
self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
|
||||
self.nm = head.nm if hasattr(head, "nm") else 32
|
||||
self.nc = head.nc
|
||||
self.nm = head.nm
|
||||
self.names = model.names
|
||||
if isinstance(self.names, (list, tuple)): # old format
|
||||
self.names = dict(enumerate(self.names))
|
||||
@ -206,6 +212,12 @@ class SegmentationValidator(BaseValidator):
|
||||
correct[matches[:, 1].astype(int), i] = True
|
||||
return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
# TODO: manage splits differently
|
||||
# calculate stride - check if model is initialized
|
||||
gs = max(int(de_parallel(self.model).stride if self.model else 0), 32)
|
||||
return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, mode="val")[0]
|
||||
|
||||
@property
|
||||
def metric_keys(self):
|
||||
return [
|
||||
@ -243,3 +255,14 @@ class SegmentationValidator(BaseValidator):
|
||||
plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, paths, conf,
|
||||
self.save_dir / f'val_batch{ni}_pred.jpg', self.names) # pred
|
||||
self.plot_masks.clear()
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||
def val(cfg):
|
||||
cfg.data = cfg.data or "coco128-seg.yaml"
|
||||
validator = SegmentationValidator(args=cfg)
|
||||
validator(model=cfg.model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
val()
|
||||
|
Reference in New Issue
Block a user