Add initial model interface (#30)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Ayush Chaurasia
2022-10-26 01:21:15 +05:30
committed by GitHub
parent 7b560f7861
commit 1054819a59
12 changed files with 220 additions and 109 deletions

View File

@@ -0,0 +1,63 @@
+"""
+Top-level YOLO model interface. First principle usage example - https://github.com/ultralytics/ultralytics/issues/13
+"""
+import torch
+import yaml
+
+import ultralytics.yolo as yolo
+from ultralytics.yolo.utils import LOGGER
+from ultralytics.yolo.utils.checks import check_yaml
+from ultralytics.yolo.utils.modeling.tasks import ClassificationModel, DetectionModel, SegmentationModel
+
+# map head: [model, trainer]
+MODEL_MAP = {
+    "Classify": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'],
+    "Detect": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'],  # temp
+    "Segment": []}
+
+
+class YOLO:
+
+    def __init__(self, version=8) -> None:
+        self.version = version
+        self.model = None
+        self.trainer = None
+        self.pretrained_weights = None
+
+    def new(self, cfg: str):
+        cfg = check_yaml(cfg)  # check YAML
+        self.model, self.trainer = self._get_model_and_trainer(cfg)
+
+    def load(self, weights, autodownload=True):
+        if not isinstance(self.pretrained_weights, type(None)):
+            LOGGER.info("Overwriting weights")
+        # TODO: weights = smart_file_loader(weights)
+        if self.model:
+            self.model.load(weights)
+            LOGGER.info("Checkpoint loaded successfully")
+        else:
+            # TODO: infer model and trainer
+            pass
+        self.pretrained_weights = weights
+
+    def reset(self):
+        pass
+
+    def train(self, **kwargs):
+        if 'data' not in kwargs:
+            raise Exception("data is required to train")
+        if not self.model:
+            raise Exception("model not initialized. Use .new() or .load()")
+        kwargs["model"] = self.model
+        trainer = self.trainer(overrides=kwargs)
+        trainer.train()
+
+    def _get_model_and_trainer(self, cfg):
+        with open(cfg, encoding='ascii', errors='ignore') as f:
+            cfg = yaml.safe_load(f)  # model dict
+        model, trainer = MODEL_MAP[cfg["head"][-1][-2]]
+        # warning: eval is unsafe. Use with caution
+        trainer = eval(trainer.replace("VERSION", f"v{self.version}"))
+        return model(cfg), trainer
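
Editor's note: for orientation, the flow this interface aims for is: new() builds a model and selects a trainer from MODEL_MAP keyed on the last head layer's module name (cfg["head"][-1][-2]), load() copies checkpoint weights in, and train() forwards its kwargs as trainer overrides. A hedged usage sketch follows; the import path, cfg filename, checkpoint name, and dataset are illustrative placeholders, not files shipped in this commit:

    from ultralytics import YOLO  # import path assumed; the class is defined in the new file above

    model = YOLO(version=8)
    model.new("yolov8-cls.yaml")    # hypothetical model YAML with a "head" section
    model.load("yolov8-cls.pt")     # hypothetical checkpoint; replaces the fresh weights
    model.train(data="imagenette160", epochs=1)  # 'data' is required; kwargs become trainer overrides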

View File

@@ -7,7 +7,7 @@ import time
 from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
-from typing import Union
+from typing import Dict, Union
 
 import torch
 import torch.distributed as dist
@@ -29,30 +29,29 @@ DEFAULT_CONFIG = "defaults.yaml"
 
 class BaseTrainer:
 
-    def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG):
+    def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG, overrides={}):
         self.console = LOGGER
-        self.model, self.data, self.train, self.hyps = self._get_config(config)
+        self.args = self._get_config(config, overrides)
         self.validator = None
         self.callbacks = defaultdict(list)
-        self.console.info(f"Training config: \n train: \n {self.train} \n hyps: \n {self.hyps}")  # to debug
+        self.console.info(f"Training config: \n args: \n {self.args}")  # to debug
 
         # Directories
-        self.save_dir = increment_path(Path(self.train.project) / self.train.name, exist_ok=self.train.exist_ok)
+        self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
         self.wdir = self.save_dir / 'weights'
         self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
         self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'
 
         # Save run settings
-        save_yaml(self.save_dir / 'train.yaml', OmegaConf.to_container(self.train, resolve=True))
+        save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))
 
         # device
-        self.device = utils.torch_utils.select_device(self.train.device, self.train.batch_size)
+        self.device = utils.torch_utils.select_device(self.args.device, self.args.batch_size)
         self.console.info(f"running on device {self.device}")
         self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
 
         # Model and Dataloaders.
-        self.trainset, self.testset = self.get_dataset()  # initialize dataset before as nc is needed for model
-        self.model = self.get_model()
-        self.model = self.model.to(self.device)
+        self.trainset, self.testset = self.get_dataset(self.args.data)
+        self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device)
 
         # epoch level metrics
         self.metrics = {}  # handle metrics returned by validator
@@ -63,18 +62,24 @@ class BaseTrainer:
         for callback, func in loggers.default_callbacks.items():
             self.add_callback(callback, func)
 
-    def _get_config(self, config: Union[str, Path, DictConfig] = None):
+    def _get_config(self, config: Union[str, DictConfig], overrides: Union[str, Dict] = {}):
         """
         Accepts yaml file name or DictConfig containing experiment configuration.
-        Returns train and hyps namespace
+        Returns training args namespace
         :param config: Optional file name or DictConfig object
         """
-        try:
-            if isinstance(config, (str, Path)):
-                config = OmegaConf.load(config)
-            return config.model, config.data, config.train, config.hyps
-        except KeyError as e:
-            raise KeyError("Missing key(s) in config") from e
+        if isinstance(config, (str, Path)):
+            config = OmegaConf.load(config)
+        elif isinstance(config, Dict):
+            config = OmegaConf.create(config)
+        # override
+        if isinstance(overrides, str):
+            overrides = OmegaConf.load(overrides)
+        elif isinstance(overrides, Dict):
+            overrides = OmegaConf.create(overrides)
+        return OmegaConf.merge(config, overrides)
 
     def add_callback(self, onevent: str, callback):
         """
@@ -92,7 +97,7 @@ class BaseTrainer:
         for callback in self.callbacks.get(onevent, []):
             callback(self)
 
-    def run(self):
+    def train(self):
         world_size = torch.cuda.device_count()
         if world_size > 1:
             mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True)
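
Editor's note: a reminder of the torch.multiprocessing contract used here: mp.spawn launches nprocs copies of the target and passes each one its process index as the first positional argument, which is why _do_train receives rank without it appearing in args. A tiny standalone sketch:

    import torch.multiprocessing as mp

    def _worker(rank, world_size):  # rank is injected by mp.spawn
        print(f"process {rank} of {world_size}")

    if __name__ == "__main__":
        mp.spawn(_worker, args=(2,), nprocs=2, join=True)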
@@ -109,21 +114,21 @@ class BaseTrainer:
         dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
         self.model = self.model.to(self.device)
         self.model = DDP(self.model, device_ids=[rank])
-        self.train.batch_size = self.train.batch_size // world_size
+        self.args.batch_size = self.args.batch_size // world_size
 
     def _setup_train(self, rank):
         """
         Builds dataloaders and optimizer on correct rank process
         """
         self.optimizer = build_optimizer(model=self.model,
-                                         name=self.train.optimizer,
-                                         lr=self.hyps.lr0,
-                                         momentum=self.hyps.momentum,
-                                         decay=self.hyps.weight_decay)
-        self.train_loader = self.get_dataloader(self.trainset, batch_size=self.train.batch_size, rank=rank)
+                                         name=self.args.optimizer,
+                                         lr=self.args.lr0,
+                                         momentum=self.args.momentum,
+                                         decay=self.args.weight_decay)
+        self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank)
         if rank in {0, -1}:
             print(" Creating testloader rank :", rank)
-            self.test_loader = self.get_dataloader(self.testset, batch_size=self.train.batch_size * 2, rank=rank)
+            self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=rank)
             self.validator = self.get_validator()
             print("created testloader :", rank)
@@ -138,7 +143,7 @@ class BaseTrainer:
         self.epoch_time = None
         self.epoch_time_start = time.time()
         self.train_time_start = time.time()
-        for epoch in range(self.train.epochs):
+        for epoch in range(self.args.epochs):
             # callback hook. on_epoch_start
             self.model.train()
             pbar = enumerate(self.train_loader)
@@ -165,7 +170,7 @@ class BaseTrainer:
                 # log
                 mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                 if rank in {-1, 0}:
-                    pbar.desc = f"{f'{epoch + 1}/{self.train.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36
+                    pbar.desc = f"{f'{epoch + 1}/{self.args.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36
 
             if rank in [-1, 0]:
                 # validation
@@ -174,7 +179,7 @@ class BaseTrainer:
                 # callback: on_val_end()
 
                 # save model
-                if (not self.train.nosave) or (self.epoch + 1 == self.train.epochs):
+                if (not self.args.nosave) or (self.epoch + 1 == self.args.epochs):
                     self.save_model()
                     # callback; on_model_save
@@ -198,7 +203,7 @@ class BaseTrainer:
             'ema': None,  # deepcopy(ema.ema).half(),
             'updates': None,  # ema.updates,
             'optimizer': None,  # optimizer.state_dict(),
-            'train_args': self.train,
+            'train_args': self.args,
             'date': datetime.now().isoformat()}
 
         # Save last, best and delete
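
Editor's note: because 'train_args' now stores the single merged args object, a saved checkpoint carries the fully resolved run configuration alongside the weights. Reading it back is plain torch.load; the path here is illustrative, and newer PyTorch versions may require weights_only=False to unpickle the config object:

    import torch

    ckpt = torch.load("runs/exp/weights/last.pt", map_location="cpu")  # illustrative path
    print(ckpt["train_args"], ckpt["date"])  # the merged args travel with the weights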
@@ -207,22 +212,22 @@ class BaseTrainer:
             torch.save(ckpt, self.best)
         del ckpt
 
-    def get_dataloader(self, path):
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0):
         """
         Returns dataloader derived from torch.data.Dataloader
         """
         pass
 
-    def get_dataset(self):
+    def get_dataset(self, data):
         """
-        Uses self.dataset to download the dataset if needed and verify it.
+        Download the dataset if needed and verify it.
         Returns train and val split datasets
         """
         pass
 
-    def get_model(self):
+    def get_model(self, model, pretrained=True):
         """
-        Uses self.model to load/create/download dataset for any task
+        load/create/download model for any task
         """
         pass
@@ -238,7 +243,7 @@ class BaseTrainer:
     def preprocess_batch(self, images, labels):
         """
-        Allows custom preprocessing model inputs and ground truths depeding on task type
+        Allows custom preprocessing model inputs and ground truths depending on task type
         """
         return images.to(self.device, non_blocking=True), labels.to(self.device)
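
Editor's note: these stubs define the template-method surface a task trainer fills in; a hedged sketch of a subclass overriding the preprocessing hook (the class name and the normalization choice are invented for illustration):

    class ClassificationTrainerSketch(BaseTrainer):  # hypothetical subclass, not part of this commit

        def preprocess_batch(self, images, labels):
            images = images.float() / 255.0  # e.g. scale uint8 pixels to [0, 1]
            return images.to(self.device, non_blocking=True), labels.to(self.device)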