Add initial model interface (#30)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Ayush Chaurasia authored 2 years ago, committed by GitHub
parent 7b560f7861
commit 1054819a59

@@ -0,0 +1,13 @@
+from ultralytics.yolo import YOLO
+
+
+def test_model():
+    model = YOLO()
+    model.new("assets/dummy_model.yaml")
+    model.model = "squeezenet1_0"  # temp solution before get_model is implemented
+    # model.load("yolov5n.pt")
+    model.train(data="imagenette160", epochs=1, lr0=0.01)
+
+
+if __name__ == "__main__":
+    test_model()

@@ -1,4 +1,7 @@
 import ultralytics.yolo.v8 as v8

+from .engine.model import YOLO
 from .engine.trainer import BaseTrainer
 from .engine.validator import BaseValidator

-__all__ = ["BaseTrainer", "BaseValidator"]  # allow simpler import
+__all__ = ["BaseTrainer", "BaseValidator", "YOLO"]  # allow simpler import

@@ -728,7 +728,7 @@ def classify_albumentations(
         if vflip > 0:
             T += [A.VerticalFlip(p=vflip)]
         if jitter > 0:
-            color_jitter = (float(jitter),) * 3  # repeat value for brightness, contrast, satuaration, 0 hue
+            color_jitter = (float(jitter),) * 3  # repeat value for brightness, contrast, saturation, 0 hue
             T += [A.ColorJitter(*color_jitter, 0)]
     else:  # Use fixed crop for eval set (reproducibility)
         T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
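The jitter tuple here expands into ColorJitter's first three positional arguments, with hue pinned to 0. A minimal sketch of that expansion, assuming albumentations is installed and an illustrative jitter value of 0.4:

    import albumentations as A

    jitter = 0.4  # illustrative value; the real function receives this as an argument
    color_jitter = (float(jitter),) * 3  # (0.4, 0.4, 0.4): brightness, contrast, saturation
    transform = A.ColorJitter(*color_jitter, 0)  # hue stays 0, matching the corrected comment above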

@@ -51,7 +51,8 @@ def exif_size(img):
 def verify_image_label(args):
     # Verify one image-label pair
     im_file, lb_file, prefix, keypoint = args
-    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", None, None  # number (missing, found, empty, corrupt), message, segments, keypoints
+    # number (missing, found, empty, corrupt), message, segments, keypoints
+    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", None, None
     try:
         # verify images
         im = Image.open(im_file)
@@ -86,10 +87,10 @@ def verify_image_label(args):
                 kpts = np.zeros((lb.shape[0], 39))
                 for i in range(len(lb)):
                     kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5,
-                                                         3))  # remove the occlusion paramater from the GT
+                                                         3))  # remove the occlusion parameter from the GT
                     kpts[i] = np.hstack((lb[i, :5], kpt))
                 lb = kpts
-                assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion paramater"
+                assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion parameter"
             else:
                 assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
             assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
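For reference, the np.delete call above strips every third keypoint value (the occlusion flag), which is what reduces a 56-column label row to the 39 columns the assert expects. A minimal numpy sketch with a synthetic row standing in for a real label:

    import numpy as np

    row = np.zeros(5 + 17 * 3)  # synthetic label: 5 box values + 17 keypoints as (x, y, occlusion)
    kpt = np.delete(row[5:], np.arange(2, row.shape[0] - 5, 3))  # drop indices 2, 5, 8, ... (occlusion flags)
    print(np.hstack((row[:5], kpt)).shape)  # (39,) -> 5 box values + 17 * (x, y)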

@@ -0,0 +1,63 @@
+"""
+Top-level YOLO model interface. First principle usage example - https://github.com/ultralytics/ultralytics/issues/13
+"""
+import torch
+import yaml
+
+import ultralytics.yolo as yolo
+from ultralytics.yolo.utils import LOGGER
+from ultralytics.yolo.utils.checks import check_yaml
+from ultralytics.yolo.utils.modeling.tasks import ClassificationModel, DetectionModel, SegmentationModel
+
+# map head: [model, trainer]
+MODEL_MAP = {
+    "Classify": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'],
+    "Detect": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'],  # temp
+    "Segment": []}
+
+
+class YOLO:
+
+    def __init__(self, version=8) -> None:
+        self.version = version
+        self.model = None
+        self.trainer = None
+        self.pretrained_weights = None
+
+    def new(self, cfg: str):
+        cfg = check_yaml(cfg)  # check YAML
+        self.model, self.trainer = self._get_model_and_trainer(cfg)
+
+    def load(self, weights, autodownload=True):
+        if not isinstance(self.pretrained_weights, type(None)):
+            LOGGER.info("Overwriting weights")
+        # TODO: weights = smart_file_loader(weights)
+        if self.model:
+            self.model.load(weights)
+            LOGGER.info("Checkpoint loaded successfully")
+        else:
+            # TODO: infer model and trainer
+            pass
+        self.pretrained_weights = weights
+
+    def reset(self):
+        pass
+
+    def train(self, **kwargs):
+        if 'data' not in kwargs:
+            raise Exception("data is required to train")
+        if not self.model:
+            raise Exception("model not initialized. Use .new() or .load()")
+        kwargs["model"] = self.model
+        trainer = self.trainer(overrides=kwargs)
+        trainer.train()
+
+    def _get_model_and_trainer(self, cfg):
+        with open(cfg, encoding='ascii', errors='ignore') as f:
+            cfg = yaml.safe_load(f)  # model dict
+        model, trainer = MODEL_MAP[cfg["head"][-1][-2]]
+        # warning: eval is unsafe. Use with caution
+        trainer = eval(trainer.replace("VERSION", f"v{self.version}"))
+        return model(cfg), trainer
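The commit's own comment flags the eval() lookup as unsafe. One possible alternative is importlib; the resolve_trainer helper below is hypothetical (not part of this commit) and assumes the MODEL_MAP dotted paths stay relative to the ultralytics package:

    import importlib

    def resolve_trainer(path: str, version: int = 8):
        # hypothetical replacement for the eval() call: import the module, then fetch the class
        module_path, _, class_name = path.replace("VERSION", f"v{version}").rpartition(".")
        module = importlib.import_module(f"ultralytics.{module_path}")
        return getattr(module, class_name)

    # resolve_trainer("yolo.VERSION.classify.train.ClassificationTrainer") -> ClassificationTrainer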

@@ -7,7 +7,7 @@ import time
 from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
-from typing import Union
+from typing import Dict, Union

 import torch
 import torch.distributed as dist
@@ -29,30 +29,29 @@ DEFAULT_CONFIG = "defaults.yaml"

 class BaseTrainer:

-    def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG):
+    def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG, overrides={}):
         self.console = LOGGER
-        self.model, self.data, self.train, self.hyps = self._get_config(config)
+        self.args = self._get_config(config, overrides)
         self.validator = None
         self.callbacks = defaultdict(list)
-        self.console.info(f"Training config: \n train: \n {self.train} \n hyps: \n {self.hyps}")  # to debug
+        self.console.info(f"Training config: \n args: \n {self.args}")  # to debug

         # Directories
-        self.save_dir = increment_path(Path(self.train.project) / self.train.name, exist_ok=self.train.exist_ok)
+        self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
         self.wdir = self.save_dir / 'weights'
         self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
         self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'

         # Save run settings
-        save_yaml(self.save_dir / 'train.yaml', OmegaConf.to_container(self.train, resolve=True))
+        save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))

         # device
-        self.device = utils.torch_utils.select_device(self.train.device, self.train.batch_size)
+        self.device = utils.torch_utils.select_device(self.args.device, self.args.batch_size)
         self.console.info(f"running on device {self.device}")
         self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')

         # Model and Dataloaders.
-        self.trainset, self.testset = self.get_dataset()  # initialize dataset before as nc is needed for model
-        self.model = self.get_model()
-        self.model = self.model.to(self.device)
+        self.trainset, self.testset = self.get_dataset(self.args.data)
+        self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device)

         # epoch level metrics
         self.metrics = {}  # handle metrics returned by validator
@@ -63,18 +62,24 @@ class BaseTrainer:
         for callback, func in loggers.default_callbacks.items():
             self.add_callback(callback, func)

-    def _get_config(self, config: Union[str, Path, DictConfig] = None):
+    def _get_config(self, config: Union[str, DictConfig], overrides: Union[str, Dict] = {}):
         """
         Accepts yaml file name or DictConfig containing experiment configuration.
-        Returns train and hyps namespace
+        Returns training args namespace
         :param config: Optional file name or DictConfig object
         """
-        try:
-            if isinstance(config, (str, Path)):
-                config = OmegaConf.load(config)
-            return config.model, config.data, config.train, config.hyps
-        except KeyError as e:
-            raise KeyError("Missing key(s) in config") from e
+        if isinstance(config, (str, Path)):
+            config = OmegaConf.load(config)
+        elif isinstance(config, Dict):
+            config = OmegaConf.create(config)
+
+        # override
+        if isinstance(overrides, str):
+            overrides = OmegaConf.load(overrides)
+        elif isinstance(overrides, Dict):
+            overrides = OmegaConf.create(overrides)
+
+        return OmegaConf.merge(config, overrides)

     def add_callback(self, onevent: str, callback):
         """
@@ -92,7 +97,7 @@ class BaseTrainer:
         for callback in self.callbacks.get(onevent, []):
             callback(self)

-    def run(self):
+    def train(self):
         world_size = torch.cuda.device_count()
         if world_size > 1:
             mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True)
@@ -109,21 +114,21 @@ class BaseTrainer:
         dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
         self.model = self.model.to(self.device)
         self.model = DDP(self.model, device_ids=[rank])
-        self.train.batch_size = self.train.batch_size // world_size
+        self.args.batch_size = self.args.batch_size // world_size

     def _setup_train(self, rank):
         """
         Builds dataloaders and optimizer on correct rank process
         """
         self.optimizer = build_optimizer(model=self.model,
-                                         name=self.train.optimizer,
-                                         lr=self.hyps.lr0,
-                                         momentum=self.hyps.momentum,
-                                         decay=self.hyps.weight_decay)
-        self.train_loader = self.get_dataloader(self.trainset, batch_size=self.train.batch_size, rank=rank)
+                                         name=self.args.optimizer,
+                                         lr=self.args.lr0,
+                                         momentum=self.args.momentum,
+                                         decay=self.args.weight_decay)
+        self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank)
         if rank in {0, -1}:
             print(" Creating testloader rank :", rank)
-            self.test_loader = self.get_dataloader(self.testset, batch_size=self.train.batch_size * 2, rank=rank)
+            self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=rank)
             self.validator = self.get_validator()
             print("created testloader :", rank)
@@ -138,7 +143,7 @@ class BaseTrainer:
         self.epoch_time = None
         self.epoch_time_start = time.time()
         self.train_time_start = time.time()
-        for epoch in range(self.train.epochs):
+        for epoch in range(self.args.epochs):
             # callback hook. on_epoch_start
             self.model.train()
             pbar = enumerate(self.train_loader)
@@ -165,7 +170,7 @@ class BaseTrainer:
                 # log
                 mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                 if rank in {-1, 0}:
-                    pbar.desc = f"{f'{epoch + 1}/{self.train.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36
+                    pbar.desc = f"{f'{epoch + 1}/{self.args.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36

             if rank in [-1, 0]:
                 # validation
@@ -174,7 +179,7 @@ class BaseTrainer:
                 # callback: on_val_end()

                 # save model
-                if (not self.train.nosave) or (self.epoch + 1 == self.train.epochs):
+                if (not self.args.nosave) or (self.epoch + 1 == self.args.epochs):
                     self.save_model()
                     # callback; on_model_save
@@ -198,7 +203,7 @@ class BaseTrainer:
             'ema': None,  # deepcopy(ema.ema).half(),
             'updates': None,  # ema.updates,
             'optimizer': None,  # optimizer.state_dict(),
-            'train_args': self.train,
+            'train_args': self.args,
             'date': datetime.now().isoformat()}

         # Save last, best and delete
@@ -207,22 +212,22 @@ class BaseTrainer:
             torch.save(ckpt, self.best)
         del ckpt

-    def get_dataloader(self, path):
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0):
         """
         Returns dataloader derived from torch.data.Dataloader
         """
         pass

-    def get_dataset(self):
+    def get_dataset(self, data):
         """
-        Uses self.dataset to download the dataset if needed and verify it.
+        Download the dataset if needed and verify it.
         Returns train and val split datasets
         """
         pass

-    def get_model(self):
+    def get_model(self, model, pretrained=True):
         """
-        Uses self.model to load/create/download dataset for any task
+        load/create/download model for any task
         """
         pass
@@ -238,7 +243,7 @@ class BaseTrainer:

     def preprocess_batch(self, images, labels):
         """
-        Allows custom preprocessing model inputs and ground truths depeding on task type
+        Allows custom preprocessing model inputs and ground truths depending on task type
         """
         return images.to(self.device, non_blocking=True), labels.to(self.device)
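Net effect of the trainer changes: get_dataset, get_model, and get_dataloader are now explicit hooks that receive their inputs as arguments instead of reading self.train/self.hyps. A sketch of the contract a subclass fills in; MyTrainer and its stub bodies are hypothetical:

    from ultralytics.yolo.engine.trainer import BaseTrainer

    class MyTrainer(BaseTrainer):
        # hypothetical subclass showing the hooks BaseTrainer.__init__ and _setup_train call

        def get_dataset(self, data):
            ...  # return (train_split, val_split) for the dataset named by args.data

        def get_model(self, model, pretrained=True):
            ...  # build and return an nn.Module for args.model

        def get_dataloader(self, dataset_path, batch_size=16, rank=0):
            ...  # wrap a split in a torch DataLoader on the given rank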

@@ -1,6 +1,7 @@
 model: null
 data: null
-train:
+
+# Training options
 epochs: 300
 batch_size: 16
 img_size: 640
@@ -16,8 +17,9 @@ train:
 verbose: False
 seed: 0
 local_rank: -1
 #-----------------------------------#
-hyps:
+# Hyper-parameters
+
 lr0: 0.001  # initial learning rate (SGD=1E-2, Adam=1E-3)
 lrf: 0.01  # final OneCycleLR learning rate (lr0 * lrf)
 momentum: 0.937  # SGD momentum/Adam beta1
@@ -48,6 +50,7 @@ hyps:
 mixup: 0.0  # image mixup (probability)
 copy_paste: 0.0  # segment copy-paste (probability)

+# Hydra configs -------------------------------------
 # to disable hydra directory creation
 hydra:
   output_subdir: null

@@ -8,7 +8,8 @@ from ultralytics.yolo.utils import LOGGER
 from ultralytics.yolo.utils.anchors import check_anchor_order
 from ultralytics.yolo.utils.modeling import parse_model
 from ultralytics.yolo.utils.modeling.modules import *
-from ultralytics.yolo.utils.torch_utils import fuse_conv_and_bn, initialize_weights, model_info, scale_img, time_sync
+from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_state_dicts, model_info,
+                                                scale_img, time_sync)


 class BaseModel(nn.Module):
@@ -67,6 +68,10 @@ class BaseModel(nn.Module):
             m.anchor_grid = list(map(fn, m.anchor_grid))
         return self

+    def load(self, weights):
+        # Force all tasks implement this function
+        raise NotImplementedError("This function needs to be implemented by derived classes!")
+

 class DetectionModel(BaseModel):
     # YOLO detection model
@@ -166,6 +171,12 @@ class DetectionModel(BaseModel):
             b.data[:, 5:5 + m.nc] += math.log(0.6 / (m.nc - 0.99999)) if cf is None else torch.log(cf / cf.sum())  # cls
             mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

+    def load(self, weights):
+        ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
+        csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
+        csd = intersect_state_dicts(csd, self.state_dict())  # intersect
+        self.load_state_dict(csd, strict=False)  # load
+

 class SegmentationModel(DetectionModel):
     # YOLOv5 segmentation model
@@ -197,3 +208,9 @@ class ClassificationModel(BaseModel):

     def _from_yaml(self, cfg):
         # Create a YOLOv5 classification model from a *.yaml file
         self.model = None
+
+    def load(self, weights):
+        ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
+        csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
+        csd = intersect_state_dicts(csd, self.state_dict())  # intersect
+        self.load_state_dict(csd, strict=False)  # load

@@ -174,3 +174,8 @@ def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
         return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)

     return decorate
+
+
+def intersect_state_dicts(da, db, exclude=()):
+    # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
+    return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
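This helper is what lets the new load() methods above apply partially matching checkpoints with strict=False. A minimal sketch of which keys survive the intersection, using toy tensors (the helper is copied inline so the snippet is self-contained):

    import torch

    def intersect_state_dicts(da, db, exclude=()):
        # copy of the helper added above, for a self-contained demo
        return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}

    ckpt_sd = {"conv.weight": torch.zeros(8, 3, 3, 3), "fc.weight": torch.zeros(10, 8)}
    model_sd = {"conv.weight": torch.ones(8, 3, 3, 3), "fc.weight": torch.ones(5, 8)}  # e.g. a new class count
    print(list(intersect_state_dicts(ckpt_sd, model_sd)))  # ['conv.weight']: mismatched fc.weight is dropped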

@@ -1,3 +1,4 @@
-from ultralytics.yolo.v8.classify import train
+from ultralytics.yolo.v8.classify.train import ClassificationTrainer
+from ultralytics.yolo.v8.classify.val import ClassificationValidator

 __all__ = ["train"]

@@ -5,11 +5,10 @@ from pathlib import Path

 import hydra
 import torch
 import torchvision
-from val import ClassificationValidator

-from ultralytics.yolo import BaseTrainer, v8
+from ultralytics.yolo import v8
 from ultralytics.yolo.data import build_classification_dataloader
-from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG
+from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG, BaseTrainer
 from ultralytics.yolo.utils.downloads import download
 from ultralytics.yolo.utils.files import WorkingDirectory
 from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zero_first
@@ -18,9 +17,9 @@ from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zero_first

 # BaseTrainer python usage
 class ClassificationTrainer(BaseTrainer):

-    def get_dataset(self):
+    def get_dataset(self, dataset):
         # temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module
-        data = Path("datasets") / self.data
+        data = Path("datasets") / dataset
         with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(Path.cwd()):
             data_dir = data if data.is_dir() else (Path.cwd() / data)
             if not data_dir.is_dir():
@@ -29,7 +28,7 @@ class ClassificationTrainer(BaseTrainer):
                 if str(data) == 'imagenet':
                     subprocess.run(f"bash {v8.ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
                 else:
-                    url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{self.data}.zip'
+                    url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
                     download(url, dir=data_dir.parent)
                 # TODO: add colorstr
                 s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {'bold', data_dir}\n"
@@ -39,17 +38,18 @@ class ClassificationTrainer(BaseTrainer):
         return train_set, test_set

-    def get_dataloader(self, dataset, batch_size=None, rank=0):
-        return build_classification_dataloader(path=dataset, batch_size=self.train.batch_size, rank=rank)
+    def get_dataloader(self, dataset_path, batch_size=None, rank=0):
+        return build_classification_dataloader(path=dataset_path, batch_size=self.args.batch_size, rank=rank)

-    def get_model(self):
+    def get_model(self, model, pretrained):
         # temp. minimal. only supports torchvision models
-        if self.model in torchvision.models.__dict__:  # TorchVision models i.e. resnet50, efficientnet_b0
-            model = torchvision.models.__dict__[self.model](weights='IMAGENET1K_V1' if self.train.pretrained else None)
+        model = self.args.model
+        if model in torchvision.models.__dict__:  # TorchVision models i.e. resnet50, efficientnet_b0
+            model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
         else:
-            raise ModuleNotFoundError(f'--model {self.model} not found.')
+            raise ModuleNotFoundError(f'--model {model} not found.')
         for m in model.modules():
-            if not self.train.pretrained and hasattr(m, 'reset_parameters'):
+            if not pretrained and hasattr(m, 'reset_parameters'):
                 m.reset_parameters()
         for p in model.parameters():
             p.requires_grad = True  # for training
@@ -57,7 +57,7 @@ class ClassificationTrainer(BaseTrainer):
         return model

     def get_validator(self):
-        return ClassificationValidator(self.test_loader, self.device, logger=self.console)  # validator
+        return v8.classify.ClassificationValidator(self.test_loader, self.device, logger=self.console)

     def criterion(self, preds, targets):
         return torch.nn.functional.cross_entropy(preds, targets)
@@ -66,17 +66,17 @@ class ClassificationTrainer(BaseTrainer):

 @hydra.main(version_base=None, config_path=CONFIG_PATH_ABS, config_name=str(DEFAULT_CONFIG).split(".")[0])
 def train(cfg):
     cfg.model = cfg.model or "squeezenet1_0"
-    cfg.data = cfg.data or "imagenette160"  # or yolo.ClassificationDataset("mnist")
+    cfg.data = cfg.data or "imagenette"  # or yolo.ClassificationDataset("mnist")
     trainer = ClassificationTrainer(cfg)
-    trainer.run()
+    trainer.train()


 if __name__ == "__main__":
     """
     CLI usage:
-    python ../path/to/train.py train.epochs=10 train.project="name" hyps.lr0=0.1
+    python ../path/to/train.py args.epochs=10 args.project="name" hyps.lr0=0.1

     TODO:
-    Direct cli support, i.e, yolov8 classify_train train.epochs 10
+    Direct cli support, i.e, yolov8 classify_train args.epochs 10
     """
     train()
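With run() renamed to train() and the config flattened into args, a trainer can also be driven from Python via the new overrides path (the same route YOLO.train() takes). A hedged usage sketch; the override keys are assumed to exist in defaults.yaml:

    from ultralytics.yolo.v8.classify.train import ClassificationTrainer

    trainer = ClassificationTrainer(overrides={"model": "squeezenet1_0", "data": "imagenette160", "epochs": 1})
    trainer.train()  # formerly trainer.run()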

@@ -1,9 +1,9 @@
 import torch

-from ultralytics import yolo
+from ultralytics.yolo.engine.validator import BaseValidator


-class ClassificationValidator(yolo.BaseValidator):
+class ClassificationValidator(BaseValidator):

     def init_metrics(self):
         self.correct = torch.tensor([])
