standalone val (#56)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent 3a241e4cea
commit 5a52e7663a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -93,9 +93,12 @@ jobs:
echo "TODO" echo "TODO"
- name: Test segmentation - name: Test segmentation
shell: bash # for Windows compatibility shell: bash # for Windows compatibility
# TODO: redo val test without hardcoded weights
run: | run: |
yolo task=segment mode=train model=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 img_size=64 yolo task=segment mode=train model=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 img_size=64
yolo task=segment mode=val model=runs/exp/weights/last.pt data=coco128-seg.yaml img_size=64
- name: Test classification - name: Test classification
shell: bash # for Windows compatibility shell: bash # for Windows compatibility
run: | run: |
yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 img_size=32 yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 img_size=32
yolo task=classify mode=val model=runs/exp2/weights/last.pt data=mnist160

@ -208,6 +208,9 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
sample = self.torch_transforms(im) sample = self.torch_transforms(im)
return OrderedDict(img=sample, cls=j) return OrderedDict(img=sample, cls=j)
def __len__(self) -> int:
return len(self.samples)
# TODO: support semantic segmentation # TODO: support semantic segmentation
class SemanticDataset(BaseDataset): class SemanticDataset(BaseDataset):

@ -0,0 +1,19 @@
import pandas as pd
def export_formats():
# YOLOv5 export formats
x = [
['PyTorch', '-', '.pt', True, True],
['TorchScript', 'torchscript', '.torchscript', True, True],
['ONNX', 'onnx', '.onnx', True, True],
['OpenVINO', 'openvino', '_openvino_model', True, False],
['TensorRT', 'engine', '.engine', False, True],
['CoreML', 'coreml', '.mlmodel', True, False],
['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
['TensorFlow GraphDef', 'pb', '.pb', True, True],
['TensorFlow Lite', 'tflite', '.tflite', True, False],
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
['TensorFlow.js', 'tfjs', '_web_model', False, False],
['PaddlePaddle', 'paddle', '_paddle_model', True, True],]
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])

@ -25,7 +25,7 @@ import ultralytics.yolo.utils as utils
import ultralytics.yolo.utils.callbacks as callbacks import ultralytics.yolo.utils.callbacks as callbacks
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
from ultralytics.yolo.utils.checks import print_args from ultralytics.yolo.utils.checks import check_file, print_args
from ultralytics.yolo.utils.files import increment_path, save_yaml from ultralytics.yolo.utils.files import increment_path, save_yaml
from ultralytics.yolo.utils.modeling import get_model from ultralytics.yolo.utils.modeling import get_model
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
@ -299,13 +299,16 @@ class BaseTrainer:
""" """
Get train, val path from data dict if it exists. Returns None if data format is not recognized Get train, val path from data dict if it exists. Returns None if data format is not recognized
""" """
return data["train"], data["val"] return data["train"], data.get("val") or data.get("test")
def get_model(self, model: Union[str, Path]): def get_model(self, model: Union[str, Path]):
""" """
load/create/download model for any task load/create/download model for any task
""" """
pretrained = not str(model).endswith(".yaml") pretrained = True
if str(model).endswith(".yaml"):
model = check_file(model)
pretrained = False
return self.load_model(model_cfg=None if pretrained else model, return self.load_model(model_cfg=None if pretrained else model,
weights=get_model(model) if pretrained else None, weights=get_model(model) if pretrained else None,
data=self.data) # model data=self.data) # model
@ -376,7 +379,7 @@ class BaseTrainer:
""" """
To set or update model parameters before training. To set or update model parameters before training.
""" """
pass self.model.names = self.data["names"]
def build_targets(self, preds, targets): def build_targets(self, preds, targets):
pass pass

@ -5,11 +5,14 @@ import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from tqdm import tqdm from tqdm import tqdm
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.utils import TQDM_BAR_FORMAT from ultralytics.yolo.utils import LOGGER, TQDM_BAR_FORMAT
from ultralytics.yolo.utils.files import increment_path from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.modeling import get_model
from ultralytics.yolo.utils.modeling.autobackend import AutoBackend
from ultralytics.yolo.utils.ops import Profile from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import de_parallel, select_device from ultralytics.yolo.utils.torch_utils import check_img_size, de_parallel, select_device
class BaseValidator: class BaseValidator:
@ -17,17 +20,18 @@ class BaseValidator:
Base validator class. Base validator class.
""" """
def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None): def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
self.dataloader = dataloader self.dataloader = dataloader
self.pbar = pbar self.pbar = pbar
self.logger = logger or logging.getLogger() self.logger = logger or LOGGER
self.args = args or OmegaConf.load(DEFAULT_CONFIG) self.args = args or OmegaConf.load(DEFAULT_CONFIG)
self.device = select_device(self.args.device, dataloader.batch_size) self.model = None
self.save_dir = save_dir if save_dir is not None else \ self.data = None
increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok) self.device = None
self.cuda = self.device.type != 'cpu'
self.batch_i = None self.batch_i = None
self.training = True self.training = True
self.save_dir = save_dir if save_dir is not None else \
increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
def __call__(self, trainer=None, model=None): def __call__(self, trainer=None, model=None):
""" """
@ -36,14 +40,35 @@ class BaseValidator:
""" """
self.training = trainer is not None self.training = trainer is not None
if self.training: if self.training:
self.device = trainer.device
self.data = trainer.data
model = trainer.ema.ema or trainer.model model = trainer.ema.ema or trainer.model
self.args.half &= self.device.type != 'cpu' self.args.half &= self.device.type != 'cpu'
model = model.half() if self.args.half else model.float() model = model.half() if self.args.half else model.float()
self.model = model
loss = torch.zeros_like(trainer.loss_items, device=trainer.device) loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
else: # TODO: handle this when detectMultiBackend is supported else: # TODO: handle this when detectMultiBackend is supported
assert model is not None, "Either trainer or model is needed for validation" assert model is not None, "Either trainer or model is needed for validation"
# model = DetectMultiBacked(model) self.device = select_device(self.args.device, self.args.batch_size)
# TODO: implement init_model_attributes() self.args.half &= self.device.type != 'cpu'
model = AutoBackend(model, device=self.device, dnn=self.args.dnn, fp16=self.args.half)
self.model = model
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
imgsz = check_img_size(self.args.img_size, s=stride)
if engine:
self.args.batch_size = model.batch_size
else:
self.device = model.device
if not (pt or jit):
self.args.batch_size = 1 # export.py models default to batch-size 1
self.logger.info(
f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
if self.args.data.endswith(".yaml"):
data = check_dataset_yaml(self.args.data)
else:
data = check_dataset(self.args.data)
self.dataloader = self.get_dataloader(data.get("val") or data.set("test"), self.args.batch_size)
model.eval() model.eval()
@ -101,6 +126,9 @@ class BaseValidator:
return stats | trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val") \ return stats | trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val") \
if self.training else stats if self.training else stats
def get_dataloader(self, dataset_path, batch_size):
raise Exception("get_dataloder function not implemented for this validator")
def preprocess(self, batch): def preprocess(self, batch):
return batch return batch

@ -28,17 +28,22 @@ single_cls: False # train multi-class data as single-class
image_weights: False # use weighted image selection for training image_weights: False # use weighted image selection for training
rect: False # support rectangular training rect: False # support rectangular training
cos_lr: False # Use cosine LR scheduler cos_lr: False # Use cosine LR scheduler
overlap_mask: True # Segmentation masks overlap # Segmentation
mask_ratio: 4 # Segmentation mask downsample ratio overlap_mask: True # masks overlap
noval: False mask_ratio: 4 # mask downsample ratio
# Classification
dropout: False # use dropout
# Val/Test settings ---------------------------------------------------------------------------------------------------- # Val/Test settings ----------------------------------------------------------------------------------------------------
noval: False
save_json: False save_json: False
save_hybrid: False save_hybrid: False
conf_thres: 0.001 conf_thres: 0.001
iou_thres: 0.6 iou_thres: 0.6
max_det: 300 max_det: 300
half: True half: True
dnn: False # use OpenCV DNN for ONNX inference
plots: False plots: False
save_txt: False save_txt: False

@ -113,8 +113,8 @@ def get_model(model='s.pt', pretrained=True):
model = model.split(".")[0] model = model.split(".")[0]
if Path(f"{model}.pt").is_file(): # local file if Path(f"{model}.pt").is_file(): # local file
return torch.load(f"{model}.pt", map_location='cpu') return attempt_load_weights(f"{model}.pt", device='cpu')
elif model in torchvision.models.__dict__: # TorchVision models i.e. resnet50, efficientnet_b0 elif model in torchvision.models.__dict__: # TorchVision models i.e. resnet50, efficientnet_b0
return torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None) return torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
else: # Ultralytics assets else: # Ultralytics assets
return torch.load(attempt_download(f"{model}.pt"), map_location='cpu') return attempt_load_weights(f"{model}.pt", device='cpu')

@ -304,7 +304,7 @@ class AutoBackend(nn.Module):
def _model_type(p='path/to/model.pt'): def _model_type(p='path/to/model.pt'):
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
# types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle] # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
from export import export_formats from ultralytics.yolo.engine.exporter import export_formats
sf = list(export_formats().Suffix) # export suffixes sf = list(export_formats().Suffix) # export suffixes
if not is_url(p, check=False): if not is_url(p, check=False):
check_suffix(p, sf) # checks check_suffix(p, sf) # checks

@ -172,7 +172,7 @@ class DetectionModel(BaseModel):
csd = weights['model'].float().state_dict() # checkpoint state_dict as FP32 csd = weights['model'].float().state_dict() # checkpoint state_dict as FP32
csd = intersect_state_dicts(csd, self.state_dict()) # intersect csd = intersect_state_dicts(csd, self.state_dict()) # intersect
self.load_state_dict(csd, strict=False) # load self.load_state_dict(csd, strict=False) # load
LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from {weights}') LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
class SegmentationModel(DetectionModel): class SegmentationModel(DetectionModel):

@ -164,6 +164,25 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
def check_img_size(imgsz, s=32, floor=0):
# Verify image size is a multiple of stride s in each dimension
if isinstance(imgsz, int): # integer i.e. img_size=640
new_size = max(make_divisible(imgsz, int(s)), floor)
else: # list i.e. img_size=[640, 480]
imgsz = list(imgsz) # convert to list if tuple
new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
if new_size != imgsz:
LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
return new_size
def make_divisible(x, divisor):
# Returns nearest x divisible by divisor
if isinstance(divisor, torch.Tensor):
divisor = int(divisor.max()) # to int
return math.ceil(x / divisor) * divisor
def copy_attr(a, b, include=(), exclude=()): def copy_attr(a, b, include=(), exclude=()):
# Copy attributes from b to a, options to only include [...] and to exclude [...] # Copy attributes from b to a, options to only include [...] and to exclude [...]
for k, v in b.__dict__.items(): for k, v in b.__dict__.items():

@ -1,4 +1,4 @@
from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train
from ultralytics.yolo.v8.classify.val import ClassificationValidator from ultralytics.yolo.v8.classify.val import ClassificationValidator, val
__all__ = ["train"] __all__ = ["train"]

@ -19,6 +19,13 @@ class ClassificationTrainer(BaseTrainer):
else: else:
model = ClassificationModel(model_cfg, weights, data["nc"]) model = ClassificationModel(model_cfg, weights, data["nc"])
ClassificationModel.reshape_outputs(model, data["nc"]) ClassificationModel.reshape_outputs(model, data["nc"])
for m in model.modules():
if not weights and hasattr(m, 'reset_parameters'):
m.reset_parameters()
if isinstance(m, torch.nn.Dropout) and self.args.dropout is not None:
m.p = self.args.dropout # set dropout
for p in model.parameters():
p.requires_grad = True # for training
return model return model
def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"): def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):

@ -1,5 +1,8 @@
import hydra
import torch import torch
from ultralytics.yolo.data import build_classification_dataloader
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.engine.validator import BaseValidator from ultralytics.yolo.engine.validator import BaseValidator
@ -24,6 +27,21 @@ class ClassificationValidator(BaseValidator):
top1, top5 = acc.mean(0).tolist() top1, top5 = acc.mean(0).tolist()
return {"top1": top1, "top5": top5, "fitness": top5} return {"top1": top1, "top5": top5, "fitness": top5}
def get_dataloader(self, dataset_path, batch_size):
return build_classification_dataloader(path=dataset_path, imgsz=self.args.img_size, batch_size=batch_size)
@property @property
def metric_keys(self): def metric_keys(self):
return ["top1", "top5"] return ["top1", "top5"]
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
def val(cfg):
cfg.data = cfg.data or "imagenette160"
cfg.model = cfg.model or "resnet18"
validator = ClassificationValidator(args=cfg)
validator(model=cfg.model)
if __name__ == "__main__":
val()

@ -1,2 +1,2 @@
from ultralytics.yolo.v8.segment.train import SegmentationTrainer, train from ultralytics.yolo.v8.segment.train import SegmentationTrainer, train
from ultralytics.yolo.v8.segment.val import SegmentationValidator from ultralytics.yolo.v8.segment.val import SegmentationValidator, val

@ -33,6 +33,8 @@ class SegmentationTrainer(BaseTrainer):
anchors=self.args.get("anchors")) anchors=self.args.get("anchors"))
if weights: if weights:
model.load(weights) model.load(weights)
for _, v in model.named_parameters():
v.requires_grad = True # train all layers
return model return model
def set_model_attributes(self): def set_model_attributes(self):
@ -257,7 +259,7 @@ class SegmentationTrainer(BaseTrainer):
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name) @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
def train(cfg): def train(cfg):
cfg.model = v8.ROOT / "models/yolov5n-seg.yaml" cfg.model = cfg.model or "models/yolov5n-seg.yaml"
cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist") cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist")
trainer = SegmentationTrainer(cfg) trainer = SegmentationTrainer(cfg)
trainer.train() trainer.train()

@ -1,9 +1,12 @@
import os import os
import hydra
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from ultralytics.yolo.data import build_dataloader
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.engine.validator import BaseValidator from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import ops from ultralytics.yolo.utils import ops
from ultralytics.yolo.utils.checks import check_file, check_requirements from ultralytics.yolo.utils.checks import check_file, check_requirements
@ -16,7 +19,7 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
class SegmentationValidator(BaseValidator): class SegmentationValidator(BaseValidator):
def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None): def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
super().__init__(dataloader, save_dir, pbar, logger, args) super().__init__(dataloader, save_dir, pbar, logger, args)
if self.args.save_json: if self.args.save_json:
check_requirements(['pycocotools']) check_requirements(['pycocotools'])
@ -43,14 +46,17 @@ class SegmentationValidator(BaseValidator):
return batch return batch
def init_metrics(self, model): def init_metrics(self, model):
if self.training:
head = de_parallel(model).model[-1] head = de_parallel(model).model[-1]
if self.data_dict: else:
self.is_coco = isinstance(self.data_dict.get('val'), head = de_parallel(model).model.model[-1]
str) and self.data_dict['val'].endswith(f'coco{os.sep}val2017.txt')
self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
if self.data:
self.is_coco = isinstance(self.data.get('val'),
str) and self.data['val'].endswith(f'coco{os.sep}val2017.txt')
self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
self.nm = head.nm if hasattr(head, "nm") else 32
self.nc = head.nc self.nc = head.nc
self.nm = head.nm
self.names = model.names self.names = model.names
if isinstance(self.names, (list, tuple)): # old format if isinstance(self.names, (list, tuple)): # old format
self.names = dict(enumerate(self.names)) self.names = dict(enumerate(self.names))
@ -206,6 +212,12 @@ class SegmentationValidator(BaseValidator):
correct[matches[:, 1].astype(int), i] = True correct[matches[:, 1].astype(int), i] = True
return torch.tensor(correct, dtype=torch.bool, device=iouv.device) return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
def get_dataloader(self, dataset_path, batch_size):
# TODO: manage splits differently
# calculate stride - check if model is initialized
gs = max(int(de_parallel(self.model).stride if self.model else 0), 32)
return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, mode="val")[0]
@property @property
def metric_keys(self): def metric_keys(self):
return [ return [
@ -243,3 +255,14 @@ class SegmentationValidator(BaseValidator):
plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, paths, conf, plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, paths, conf,
self.save_dir / f'val_batch{ni}_pred.jpg', self.names) # pred self.save_dir / f'val_batch{ni}_pred.jpg', self.names) # pred
self.plot_masks.clear() self.plot_masks.clear()
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
def val(cfg):
cfg.data = cfg.data or "coco128-seg.yaml"
validator = SegmentationValidator(args=cfg)
validator(model=cfg.model)
if __name__ == "__main__":
val()

Loading…
Cancel
Save