From 1054819a598f39fcb6d94179f3c4344604c452f2 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 26 Oct 2022 01:21:15 +0530 Subject: [PATCH] Add initial model interface (#30) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/tests/test_model.py | 13 +++ ultralytics/yolo/__init__.py | 5 +- ultralytics/yolo/data/augment.py | 2 +- ultralytics/yolo/data/utils.py | 7 +- ultralytics/yolo/engine/model.py | 63 +++++++++++++ ultralytics/yolo/engine/trainer.py | 77 ++++++++-------- ultralytics/yolo/utils/configs/defaults.yaml | 95 ++++++++++---------- ultralytics/yolo/utils/modeling/tasks.py | 19 +++- ultralytics/yolo/utils/torch_utils.py | 5 ++ ultralytics/yolo/v8/classify/__init__.py | 3 +- ultralytics/yolo/v8/classify/train.py | 36 ++++---- ultralytics/yolo/v8/classify/val.py | 4 +- 12 files changed, 220 insertions(+), 109 deletions(-) create mode 100644 ultralytics/tests/test_model.py create mode 100644 ultralytics/yolo/engine/model.py diff --git a/ultralytics/tests/test_model.py b/ultralytics/tests/test_model.py new file mode 100644 index 0000000..353fab1 --- /dev/null +++ b/ultralytics/tests/test_model.py @@ -0,0 +1,13 @@ +from ultralytics.yolo import YOLO + + +def test_model(): + model = YOLO() + model.new("assets/dummy_model.yaml") + model.model = "squeezenet1_0" # temp solution before get_model is implemented + # model.load("yolov5n.pt") + model.train(data="imagenette160", epochs=1, lr0=0.01) + + +if __name__ == "__main__": + test_model() diff --git a/ultralytics/yolo/__init__.py b/ultralytics/yolo/__init__.py index 85f6f6c..fa1c3b2 100644 --- a/ultralytics/yolo/__init__.py +++ b/ultralytics/yolo/__init__.py @@ -1,4 +1,7 @@ +import ultralytics.yolo.v8 as v8 + +from .engine.model import YOLO from .engine.trainer import BaseTrainer from .engine.validator import BaseValidator -__all__ = ["BaseTrainer", "BaseValidator"] # allow simpler import +__all__ = ["BaseTrainer", "BaseValidator", "YOLO"] # allow simpler import diff --git a/ultralytics/yolo/data/augment.py b/ultralytics/yolo/data/augment.py index 553b057..6c936ad 100644 --- a/ultralytics/yolo/data/augment.py +++ b/ultralytics/yolo/data/augment.py @@ -728,7 +728,7 @@ def classify_albumentations( if vflip > 0: T += [A.VerticalFlip(p=vflip)] if jitter > 0: - color_jitter = (float(jitter),) * 3 # repeat value for brightness, contrast, satuaration, 0 hue + color_jitter = (float(jitter),) * 3 # repeat value for brightness, contrast, saturation, 0 hue T += [A.ColorJitter(*color_jitter, 0)] else: # Use fixed crop for eval set (reproducibility) T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)] diff --git a/ultralytics/yolo/data/utils.py b/ultralytics/yolo/data/utils.py index 70b79d4..0027742 100644 --- a/ultralytics/yolo/data/utils.py +++ b/ultralytics/yolo/data/utils.py @@ -51,7 +51,8 @@ def exif_size(img): def verify_image_label(args): # Verify one image-label pair im_file, lb_file, prefix, keypoint = args - nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", None, None # number (missing, found, empty, corrupt), message, segments, keypoints + # number (missing, found, empty, corrupt), message, segments, keypoints + nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", None, None try: # verify images im = Image.open(im_file) @@ -86,10 +87,10 @@ def verify_image_label(args): kpts = np.zeros((lb.shape[0], 39)) for i in range(len(lb)): kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5, - 3)) # remove the occlusion paramater from the GT + 3)) # remove the occlusion parameter from the GT kpts[i] = np.hstack((lb[i, :5], kpt)) lb = kpts - assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion paramater" + assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion parameter" else: assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected" assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}" diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py new file mode 100644 index 0000000..838014b --- /dev/null +++ b/ultralytics/yolo/engine/model.py @@ -0,0 +1,63 @@ +""" +Top-level YOLO model interface. First principle usage example - https://github.com/ultralytics/ultralytics/issues/13 +""" +import torch +import yaml + +import ultralytics.yolo as yolo +from ultralytics.yolo.utils import LOGGER +from ultralytics.yolo.utils.checks import check_yaml +from ultralytics.yolo.utils.modeling.tasks import ClassificationModel, DetectionModel, SegmentationModel + +# map head: [model, trainer] +MODEL_MAP = { + "Classify": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'], + "Detect": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'], # temp + "Segment": []} + + +class YOLO: + + def __init__(self, version=8) -> None: + self.version = version + self.model = None + self.trainer = None + self.pretrained_weights = None + + def new(self, cfg: str): + cfg = check_yaml(cfg) # check YAML + self.model, self.trainer = self._get_model_and_trainer(cfg) + + def load(self, weights, autodownload=True): + if not isinstance(self.pretrained_weights, type(None)): + LOGGER.info("Overwriting weights") + # TODO: weights = smart_file_loader(weights) + if self.model: + self.model.load(weights) + LOGGER.info("Checkpoint loaded successfully") + else: + # TODO: infer model and trainer + pass + + self.pretrained_weights = weights + + def reset(self): + pass + + def train(self, **kwargs): + if 'data' not in kwargs: + raise Exception("data is required to train") + if not self.model: + raise Exception("model not initialized. Use .new() or .load()") + kwargs["model"] = self.model + trainer = self.trainer(overrides=kwargs) + trainer.train() + + def _get_model_and_trainer(self, cfg): + with open(cfg, encoding='ascii', errors='ignore') as f: + cfg = yaml.safe_load(f) # model dict + model, trainer = MODEL_MAP[cfg["head"][-1][-2]] + # warning: eval is unsafe. Use with caution + trainer = eval(trainer.replace("VERSION", f"v{self.version}")) + + return model(cfg), trainer diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index 6ba33cd..875bc35 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -7,7 +7,7 @@ import time from collections import defaultdict from datetime import datetime from pathlib import Path -from typing import Union +from typing import Dict, Union import torch import torch.distributed as dist @@ -29,30 +29,29 @@ DEFAULT_CONFIG = "defaults.yaml" class BaseTrainer: - def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG): + def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG, overrides={}): self.console = LOGGER - self.model, self.data, self.train, self.hyps = self._get_config(config) + self.args = self._get_config(config, overrides) self.validator = None self.callbacks = defaultdict(list) - self.console.info(f"Training config: \n train: \n {self.train} \n hyps: \n {self.hyps}") # to debug + self.console.info(f"Training config: \n args: \n {self.args}") # to debug # Directories - self.save_dir = increment_path(Path(self.train.project) / self.train.name, exist_ok=self.train.exist_ok) + self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok) self.wdir = self.save_dir / 'weights' self.wdir.mkdir(parents=True, exist_ok=True) # make dir self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # Save run settings - save_yaml(self.save_dir / 'train.yaml', OmegaConf.to_container(self.train, resolve=True)) + save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) # device - self.device = utils.torch_utils.select_device(self.train.device, self.train.batch_size) + self.device = utils.torch_utils.select_device(self.args.device, self.args.batch_size) self.console.info(f"running on device {self.device}") self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu') # Model and Dataloaders. - self.trainset, self.testset = self.get_dataset() # initialize dataset before as nc is needed for model - self.model = self.get_model() - self.model = self.model.to(self.device) + self.trainset, self.testset = self.get_dataset(self.args.data) + self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device) # epoch level metrics self.metrics = {} # handle metrics returned by validator @@ -63,18 +62,24 @@ class BaseTrainer: for callback, func in loggers.default_callbacks.items(): self.add_callback(callback, func) - def _get_config(self, config: Union[str, Path, DictConfig] = None): + def _get_config(self, config: Union[str, DictConfig], overrides: Union[str, Dict] = {}): """ Accepts yaml file name or DictConfig containing experiment configuration. - Returns train and hyps namespace + Returns training args namespace :param config: Optional file name or DictConfig object """ - try: - if isinstance(config, (str, Path)): - config = OmegaConf.load(config) - return config.model, config.data, config.train, config.hyps - except KeyError as e: - raise KeyError("Missing key(s) in config") from e + if isinstance(config, (str, Path)): + config = OmegaConf.load(config) + elif isinstance(config, Dict): + config = OmegaConf.create(config) + + # override + if isinstance(overrides, str): + overrides = OmegaConf.load(overrides) + elif isinstance(overrides, Dict): + overrides = OmegaConf.create(overrides) + + return OmegaConf.merge(config, overrides) def add_callback(self, onevent: str, callback): """ @@ -92,7 +97,7 @@ class BaseTrainer: for callback in self.callbacks.get(onevent, []): callback(self) - def run(self): + def train(self): world_size = torch.cuda.device_count() if world_size > 1: mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True) @@ -109,21 +114,21 @@ class BaseTrainer: dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size) self.model = self.model.to(self.device) self.model = DDP(self.model, device_ids=[rank]) - self.train.batch_size = self.train.batch_size // world_size + self.args.batch_size = self.args.batch_size // world_size def _setup_train(self, rank): """ Builds dataloaders and optimizer on correct rank process """ self.optimizer = build_optimizer(model=self.model, - name=self.train.optimizer, - lr=self.hyps.lr0, - momentum=self.hyps.momentum, - decay=self.hyps.weight_decay) - self.train_loader = self.get_dataloader(self.trainset, batch_size=self.train.batch_size, rank=rank) + name=self.args.optimizer, + lr=self.args.lr0, + momentum=self.args.momentum, + decay=self.args.weight_decay) + self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank) if rank in {0, -1}: print(" Creating testloader rank :", rank) - self.test_loader = self.get_dataloader(self.testset, batch_size=self.train.batch_size * 2, rank=rank) + self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=rank) self.validator = self.get_validator() print("created testloader :", rank) @@ -138,7 +143,7 @@ class BaseTrainer: self.epoch_time = None self.epoch_time_start = time.time() self.train_time_start = time.time() - for epoch in range(self.train.epochs): + for epoch in range(self.args.epochs): # callback hook. on_epoch_start self.model.train() pbar = enumerate(self.train_loader) @@ -165,7 +170,7 @@ class BaseTrainer: # log mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB) if rank in {-1, 0}: - pbar.desc = f"{f'{epoch + 1}/{self.train.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36 + pbar.desc = f"{f'{epoch + 1}/{self.args.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36 if rank in [-1, 0]: # validation @@ -174,7 +179,7 @@ class BaseTrainer: # callback: on_val_end() # save model - if (not self.train.nosave) or (self.epoch + 1 == self.train.epochs): + if (not self.args.nosave) or (self.epoch + 1 == self.args.epochs): self.save_model() # callback; on_model_save @@ -198,7 +203,7 @@ class BaseTrainer: 'ema': None, # deepcopy(ema.ema).half(), 'updates': None, # ema.updates, 'optimizer': None, # optimizer.state_dict(), - 'train_args': self.train, + 'train_args': self.args, 'date': datetime.now().isoformat()} # Save last, best and delete @@ -207,22 +212,22 @@ class BaseTrainer: torch.save(ckpt, self.best) del ckpt - def get_dataloader(self, path): + def get_dataloader(self, dataset_path, batch_size=16, rank=0): """ Returns dataloader derived from torch.data.Dataloader """ pass - def get_dataset(self): + def get_dataset(self, data): """ - Uses self.dataset to download the dataset if needed and verify it. + Download the dataset if needed and verify it. Returns train and val split datasets """ pass - def get_model(self): + def get_model(self, model, pretrained=True): """ - Uses self.model to load/create/download dataset for any task + load/create/download model for any task """ pass @@ -238,7 +243,7 @@ class BaseTrainer: def preprocess_batch(self, images, labels): """ - Allows custom preprocessing model inputs and ground truths depeding on task type + Allows custom preprocessing model inputs and ground truths depending on task type """ return images.to(self.device, non_blocking=True), labels.to(self.device) diff --git a/ultralytics/yolo/utils/configs/defaults.yaml b/ultralytics/yolo/utils/configs/defaults.yaml index 35b6e02..b2e7474 100644 --- a/ultralytics/yolo/utils/configs/defaults.yaml +++ b/ultralytics/yolo/utils/configs/defaults.yaml @@ -1,53 +1,56 @@ model: null data: null -train: - epochs: 300 - batch_size: 16 - img_size: 640 - nosave: False - cache: False # True/ram for ram, or disc - device: '' # cuda device, i.e. 0 or 0,1,2,3 or cpu - workers: 8 - project: "ultralytics-yolo" - name: "exp" # TODO: make this informative, maybe exp{#number}_{datetime} ? - exist_ok: False - pretrained: False - optimizer: "Adam" # choices=['SGD', 'Adam', 'AdamW', 'RMSProp'] - verbose: False - seed: 0 - local_rank: -1 -hyps: - lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3) - lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf) - momentum: 0.937 # SGD momentum/Adam beta1 - weight_decay: 0.0005 # optimizer weight decay 5e-4 - warmup_epochs: 3.0 # warmup epochs (fractions ok) - warmup_momentum: 0.8 # warmup initial momentum - warmup_bias_lr: 0.1 # warmup initial bias lr - box: 0.05 # box loss gain - cls: 0.5 # cls loss gain - cls_pw: 1.0 # cls BCELoss positive_weight - obj: 1.0 # obj loss gain (scale with pixels) - obj_pw: 1.0 # obj BCELoss positive_weight - iou_t: 0.20 # IoU training threshold - anchor_t: 4.0 # anchor-multiple threshold - # anchors: 3 # anchors per output layer (0 to ignore) - fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5) - hsv_h: 0.015 # image HSV-Hue augmentation (fraction) - hsv_s: 0.7 # image HSV-Saturation augmentation (fraction) - hsv_v: 0.4 # image HSV-Value augmentation (fraction) - degrees: 0.0 # image rotation (+/- deg) - translate: 0.1 # image translation (+/- fraction) - scale: 0.5 # image scale (+/- gain) - shear: 0.0 # image shear (+/- deg) - perspective: 0.0 # image perspective (+/- fraction), range 0-0.001 - flipud: 0.0 # image flip up-down (probability) - fliplr: 0.5 # image flip left-right (probability) - mosaic: 1.0 # image mosaic (probability) - mixup: 0.0 # image mixup (probability) - copy_paste: 0.0 # segment copy-paste (probability) +# Training options +epochs: 300 +batch_size: 16 +img_size: 640 +nosave: False +cache: False # True/ram for ram, or disc +device: '' # cuda device, i.e. 0 or 0,1,2,3 or cpu +workers: 8 +project: "ultralytics-yolo" +name: "exp" # TODO: make this informative, maybe exp{#number}_{datetime} ? +exist_ok: False +pretrained: False +optimizer: "Adam" # choices=['SGD', 'Adam', 'AdamW', 'RMSProp'] +verbose: False +seed: 0 +local_rank: -1 +#-----------------------------------# +# Hyper-parameters +lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3) +lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf) +momentum: 0.937 # SGD momentum/Adam beta1 +weight_decay: 0.0005 # optimizer weight decay 5e-4 +warmup_epochs: 3.0 # warmup epochs (fractions ok) +warmup_momentum: 0.8 # warmup initial momentum +warmup_bias_lr: 0.1 # warmup initial bias lr +box: 0.05 # box loss gain +cls: 0.5 # cls loss gain +cls_pw: 1.0 # cls BCELoss positive_weight +obj: 1.0 # obj loss gain (scale with pixels) +obj_pw: 1.0 # obj BCELoss positive_weight +iou_t: 0.20 # IoU training threshold +anchor_t: 4.0 # anchor-multiple threshold +# anchors: 3 # anchors per output layer (0 to ignore) +fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5) +hsv_h: 0.015 # image HSV-Hue augmentation (fraction) +hsv_s: 0.7 # image HSV-Saturation augmentation (fraction) +hsv_v: 0.4 # image HSV-Value augmentation (fraction) +degrees: 0.0 # image rotation (+/- deg) +translate: 0.1 # image translation (+/- fraction) +scale: 0.5 # image scale (+/- gain) +shear: 0.0 # image shear (+/- deg) +perspective: 0.0 # image perspective (+/- fraction), range 0-0.001 +flipud: 0.0 # image flip up-down (probability) +fliplr: 0.5 # image flip left-right (probability) +mosaic: 1.0 # image mosaic (probability) +mixup: 0.0 # image mixup (probability) +copy_paste: 0.0 # segment copy-paste (probability) + +# Hydra configs ------------------------------------- # to disable hydra directory creation hydra: output_subdir: null diff --git a/ultralytics/yolo/utils/modeling/tasks.py b/ultralytics/yolo/utils/modeling/tasks.py index 6f2184e..0fe7e00 100644 --- a/ultralytics/yolo/utils/modeling/tasks.py +++ b/ultralytics/yolo/utils/modeling/tasks.py @@ -8,7 +8,8 @@ from ultralytics.yolo.utils import LOGGER from ultralytics.yolo.utils.anchors import check_anchor_order from ultralytics.yolo.utils.modeling import parse_model from ultralytics.yolo.utils.modeling.modules import * -from ultralytics.yolo.utils.torch_utils import fuse_conv_and_bn, initialize_weights, model_info, scale_img, time_sync +from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_state_dicts, model_info, + scale_img, time_sync) class BaseModel(nn.Module): @@ -67,6 +68,10 @@ class BaseModel(nn.Module): m.anchor_grid = list(map(fn, m.anchor_grid)) return self + def load(self, weights): + # Force all tasks implement this function + raise NotImplementedError("This function needs to be implemented by derived classes!") + class DetectionModel(BaseModel): # YOLO detection model @@ -166,6 +171,12 @@ class DetectionModel(BaseModel): b.data[:, 5:5 + m.nc] += math.log(0.6 / (m.nc - 0.99999)) if cf is None else torch.log(cf / cf.sum()) # cls mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) + def load(self, weights): + ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak + csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 + csd = intersect_state_dicts(csd, self.state_dict()) # intersect + self.load_state_dict(csd, strict=False) # load + class SegmentationModel(DetectionModel): # YOLOv5 segmentation model @@ -197,3 +208,9 @@ class ClassificationModel(BaseModel): def _from_yaml(self, cfg): # Create a YOLOv5 classification model from a *.yaml file self.model = None + + def load(self, weights): + ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak + csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 + csd = intersect_state_dicts(csd, self.state_dict()) # intersect + self.load_state_dict(csd, strict=False) # load diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py index 469c23d..0466810 100644 --- a/ultralytics/yolo/utils/torch_utils.py +++ b/ultralytics/yolo/utils/torch_utils.py @@ -174,3 +174,8 @@ def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')): return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn) return decorate + + +def intersect_state_dicts(da, db, exclude=()): + # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values + return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape} diff --git a/ultralytics/yolo/v8/classify/__init__.py b/ultralytics/yolo/v8/classify/__init__.py index 278a980..23a43a3 100644 --- a/ultralytics/yolo/v8/classify/__init__.py +++ b/ultralytics/yolo/v8/classify/__init__.py @@ -1,3 +1,4 @@ -from ultralytics.yolo.v8.classify import train +from ultralytics.yolo.v8.classify.train import ClassificationTrainer +from ultralytics.yolo.v8.classify.val import ClassificationValidator __all__ = ["train"] diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py index feb05e2..534dbfd 100644 --- a/ultralytics/yolo/v8/classify/train.py +++ b/ultralytics/yolo/v8/classify/train.py @@ -5,11 +5,10 @@ from pathlib import Path import hydra import torch import torchvision -from val import ClassificationValidator -from ultralytics.yolo import BaseTrainer, v8 +from ultralytics.yolo import v8 from ultralytics.yolo.data import build_classification_dataloader -from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG +from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG, BaseTrainer from ultralytics.yolo.utils.downloads import download from ultralytics.yolo.utils.files import WorkingDirectory from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zero_first @@ -18,9 +17,9 @@ from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zer # BaseTrainer python usage class ClassificationTrainer(BaseTrainer): - def get_dataset(self): + def get_dataset(self, dataset): # temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module - data = Path("datasets") / self.data + data = Path("datasets") / dataset with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(Path.cwd()): data_dir = data if data.is_dir() else (Path.cwd() / data) if not data_dir.is_dir(): @@ -29,7 +28,7 @@ class ClassificationTrainer(BaseTrainer): if str(data) == 'imagenet': subprocess.run(f"bash {v8.ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True) else: - url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{self.data}.zip' + url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip' download(url, dir=data_dir.parent) # TODO: add colorstr s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {'bold', data_dir}\n" @@ -39,17 +38,18 @@ class ClassificationTrainer(BaseTrainer): return train_set, test_set - def get_dataloader(self, dataset, batch_size=None, rank=0): - return build_classification_dataloader(path=dataset, batch_size=self.train.batch_size, rank=rank) + def get_dataloader(self, dataset_path, batch_size=None, rank=0): + return build_classification_dataloader(path=dataset_path, batch_size=self.args.batch_size, rank=rank) - def get_model(self): + def get_model(self, model, pretrained): # temp. minimal. only supports torchvision models - if self.model in torchvision.models.__dict__: # TorchVision models i.e. resnet50, efficientnet_b0 - model = torchvision.models.__dict__[self.model](weights='IMAGENET1K_V1' if self.train.pretrained else None) + model = self.args.model + if model in torchvision.models.__dict__: # TorchVision models i.e. resnet50, efficientnet_b0 + model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None) else: - raise ModuleNotFoundError(f'--model {self.model} not found.') + raise ModuleNotFoundError(f'--model {model} not found.') for m in model.modules(): - if not self.train.pretrained and hasattr(m, 'reset_parameters'): + if not pretrained and hasattr(m, 'reset_parameters'): m.reset_parameters() for p in model.parameters(): p.requires_grad = True # for training @@ -57,7 +57,7 @@ class ClassificationTrainer(BaseTrainer): return model def get_validator(self): - return ClassificationValidator(self.test_loader, self.device, logger=self.console) # validator + return v8.classify.ClassificationValidator(self.test_loader, self.device, logger=self.console) def criterion(self, preds, targets): return torch.nn.functional.cross_entropy(preds, targets) @@ -66,17 +66,17 @@ class ClassificationTrainer(BaseTrainer): @hydra.main(version_base=None, config_path=CONFIG_PATH_ABS, config_name=str(DEFAULT_CONFIG).split(".")[0]) def train(cfg): cfg.model = cfg.model or "squeezenet1_0" - cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist") + cfg.data = cfg.data or "imagenette" # or yolo.ClassificationDataset("mnist") trainer = ClassificationTrainer(cfg) - trainer.run() + trainer.train() if __name__ == "__main__": """ CLI usage: - python ../path/to/train.py train.epochs=10 train.project="name" hyps.lr0=0.1 + python ../path/to/train.py args.epochs=10 args.project="name" hyps.lr0=0.1 TODO: - Direct cli support, i.e, yolov8 classify_train train.epochs 10 + Direct cli support, i.e, yolov8 classify_train args.epochs 10 """ train() diff --git a/ultralytics/yolo/v8/classify/val.py b/ultralytics/yolo/v8/classify/val.py index 4657ffc..3d3b4e9 100644 --- a/ultralytics/yolo/v8/classify/val.py +++ b/ultralytics/yolo/v8/classify/val.py @@ -1,9 +1,9 @@ import torch -from ultralytics import yolo +from ultralytics.yolo.engine.validator import BaseValidator -class ClassificationValidator(yolo.BaseValidator): +class ClassificationValidator(BaseValidator): def init_metrics(self): self.correct = torch.tensor([])