Add initial model interface (#30)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
ultralytics/yolo/engine/model.py (new file, 63 lines)

@@ -0,0 +1,63 @@
+"""
+Top-level YOLO model interface. First principle usage example - https://github.com/ultralytics/ultralytics/issues/13
+"""
+
+import torch
+import yaml
+
+import ultralytics.yolo as yolo
+from ultralytics.yolo.utils import LOGGER
+from ultralytics.yolo.utils.checks import check_yaml
+from ultralytics.yolo.utils.modeling.tasks import ClassificationModel, DetectionModel, SegmentationModel
+
+# map head: [model, trainer]
+MODEL_MAP = {
+    "Classify": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'],
+    "Detect": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'],  # temp
+    "Segment": []}
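Note: the trainer entries above are strings containing a VERSION placeholder that _get_model_and_trainer() later substitutes (e.g. v8) and resolves via eval(). A rough, safer equivalent using importlib is sketched below; the fully-qualified module path is an assumption for illustration, not something this commit ships.

import importlib

# Sketch: resolve 'yolo.VERSION.classify.train.ClassificationTrainer' for version=8
trainer_path = 'yolo.VERSION.classify.train.ClassificationTrainer'.replace("VERSION", "v8")
module_path, cls_name = trainer_path.rsplit(".", 1)  # 'yolo.v8.classify.train', 'ClassificationTrainer'
trainer_cls = getattr(importlib.import_module(f"ultralytics.{module_path}"), cls_name)  # assumed import root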
+
+
+class YOLO:
+
+    def __init__(self, version=8) -> None:
+        self.version = version
+        self.model = None
+        self.trainer = None
+        self.pretrained_weights = None
+
+    def new(self, cfg: str):
+        cfg = check_yaml(cfg)  # check YAML
+        self.model, self.trainer = self._get_model_and_trainer(cfg)
+
+    def load(self, weights, autodownload=True):
+        if self.pretrained_weights is not None:
+            LOGGER.info("Overwriting weights")
+        # TODO: weights = smart_file_loader(weights)
+        if self.model:
+            self.model.load(weights)
+            LOGGER.info("Checkpoint loaded successfully")
+        else:
+            # TODO: infer model and trainer
+            pass
+
+        self.pretrained_weights = weights
+
+    def reset(self):
+        pass
+
+    def train(self, **kwargs):
+        if 'data' not in kwargs:
+            raise Exception("data is required to train")
+        if not self.model:
+            raise Exception("model not initialized. Use .new() or .load()")
+        kwargs["model"] = self.model
+        trainer = self.trainer(overrides=kwargs)
+        trainer.train()
+
+    def _get_model_and_trainer(self, cfg):
+        with open(cfg, encoding='ascii', errors='ignore') as f:
+            cfg = yaml.safe_load(f)  # model dict
+        model, trainer = MODEL_MAP[cfg["head"][-1][-2]]
+        # warning: eval is unsafe. Use with caution
+        trainer = eval(trainer.replace("VERSION", f"v{self.version}"))
+
+        return model(cfg), trainer
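Taken together, the interface is meant to be driven roughly like this. A minimal usage sketch; the config, weights, and dataset file names are illustrative placeholders, not files shipped by this commit:

from ultralytics.yolo.engine.model import YOLO

model = YOLO(version=8)
model.new("model.yaml")                     # build model + trainer from a YAML config (placeholder name)
model.load("weights.pt")                    # optionally load pretrained weights (placeholder name)
model.train(data="dataset.yaml", epochs=1)  # 'data' is required; kwargs become trainer overrides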
ultralytics/yolo/engine/trainer.py

@@ -7,7 +7,7 @@ import time
 from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
-from typing import Union
+from typing import Dict, Union
 
 import torch
 import torch.distributed as dist
@@ -29,30 +29,29 @@ DEFAULT_CONFIG = "defaults.yaml"
 
 class BaseTrainer:
 
-    def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG):
+    def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG, overrides={}):
         self.console = LOGGER
-        self.model, self.data, self.train, self.hyps = self._get_config(config)
+        self.args = self._get_config(config, overrides)
         self.validator = None
         self.callbacks = defaultdict(list)
-        self.console.info(f"Training config: \n train: \n {self.train} \n hyps: \n {self.hyps}")  # to debug
+        self.console.info(f"Training config: \n args: \n {self.args}")  # to debug
         # Directories
-        self.save_dir = increment_path(Path(self.train.project) / self.train.name, exist_ok=self.train.exist_ok)
+        self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
         self.wdir = self.save_dir / 'weights'
         self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
         self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'
 
         # Save run settings
-        save_yaml(self.save_dir / 'train.yaml', OmegaConf.to_container(self.train, resolve=True))
+        save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))
 
         # device
-        self.device = utils.torch_utils.select_device(self.train.device, self.train.batch_size)
+        self.device = utils.torch_utils.select_device(self.args.device, self.args.batch_size)
         self.console.info(f"running on device {self.device}")
         self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
 
         # Model and Dataloaders.
-        self.trainset, self.testset = self.get_dataset()  # initialize dataset before as nc is needed for model
-        self.model = self.get_model()
-        self.model = self.model.to(self.device)
+        self.trainset, self.testset = self.get_dataset(self.args.data)
+        self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device)
 
         # epoch level metrics
         self.metrics = {}  # handle metrics returned by validator
@@ -63,18 +62,24 @@ class BaseTrainer:
         for callback, func in loggers.default_callbacks.items():
             self.add_callback(callback, func)
 
-    def _get_config(self, config: Union[str, Path, DictConfig] = None):
+    def _get_config(self, config: Union[str, DictConfig], overrides: Union[str, Dict] = {}):
         """
         Accepts yaml file name or DictConfig containing experiment configuration.
-        Returns train and hyps namespace
+        Returns training args namespace
         :param config: Optional file name or DictConfig object
         """
-        try:
-            if isinstance(config, (str, Path)):
-                config = OmegaConf.load(config)
-            return config.model, config.data, config.train, config.hyps
-        except KeyError as e:
-            raise KeyError("Missing key(s) in config") from e
+        if isinstance(config, (str, Path)):
+            config = OmegaConf.load(config)
+        elif isinstance(config, Dict):
+            config = OmegaConf.create(config)
+
+        # override
+        if isinstance(overrides, str):
+            overrides = OmegaConf.load(overrides)
+        elif isinstance(overrides, Dict):
+            overrides = OmegaConf.create(overrides)
+
+        return OmegaConf.merge(config, overrides)
 
     def add_callback(self, onevent: str, callback):
         """
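The reworked _get_config() collapses the old model/data/train/hyps sections into a single namespace and layers overrides on top via OmegaConf.merge, where later arguments win. A small sketch of the merge semantics, with invented values:

from omegaconf import OmegaConf

defaults = OmegaConf.create({"epochs": 100, "batch_size": 16, "optimizer": "SGD"})
overrides = {"epochs": 3, "data": "dataset.yaml"}              # e.g. the kwargs passed to YOLO.train()
args = OmegaConf.merge(defaults, OmegaConf.create(overrides))
print(args.epochs, args.batch_size, args.data)                 # -> 3 16 dataset.yaml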
@@ -92,7 +97,7 @@ class BaseTrainer:
         for callback in self.callbacks.get(onevent, []):
             callback(self)
 
-    def run(self):
+    def train(self):
         world_size = torch.cuda.device_count()
         if world_size > 1:
             mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True)
@@ -109,21 +114,21 @@ class BaseTrainer:
         dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
         self.model = self.model.to(self.device)
         self.model = DDP(self.model, device_ids=[rank])
-        self.train.batch_size = self.train.batch_size // world_size
+        self.args.batch_size = self.args.batch_size // world_size
 
     def _setup_train(self, rank):
         """
         Builds dataloaders and optimizer on correct rank process
         """
         self.optimizer = build_optimizer(model=self.model,
-                                         name=self.train.optimizer,
-                                         lr=self.hyps.lr0,
-                                         momentum=self.hyps.momentum,
-                                         decay=self.hyps.weight_decay)
-        self.train_loader = self.get_dataloader(self.trainset, batch_size=self.train.batch_size, rank=rank)
+                                         name=self.args.optimizer,
+                                         lr=self.args.lr0,
+                                         momentum=self.args.momentum,
+                                         decay=self.args.weight_decay)
+        self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank)
         if rank in {0, -1}:
             print(" Creating testloader rank :", rank)
-            self.test_loader = self.get_dataloader(self.testset, batch_size=self.train.batch_size * 2, rank=rank)
+            self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=rank)
             self.validator = self.get_validator()
             print("created testloader :", rank)
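With the renamed train() entrypoint, multi-GPU runs spawn one process per device and _setup_ddp divides the global batch across ranks. The arithmetic, with an invented world size:

world_size = 2                             # torch.cuda.device_count() on a 2-GPU box
batch_size = 16                            # self.args.batch_size before DDP setup
per_rank_batch = batch_size // world_size  # 8 samples per process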
@@ -138,7 +143,7 @@ class BaseTrainer:
         self.epoch_time = None
         self.epoch_time_start = time.time()
         self.train_time_start = time.time()
-        for epoch in range(self.train.epochs):
+        for epoch in range(self.args.epochs):
             # callback hook. on_epoch_start
             self.model.train()
             pbar = enumerate(self.train_loader)
@@ -165,7 +170,7 @@ class BaseTrainer:
             # log
             mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
             if rank in {-1, 0}:
-                pbar.desc = f"{f'{epoch + 1}/{self.train.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36
+                pbar.desc = f"{f'{epoch + 1}/{self.args.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36
 
             if rank in [-1, 0]:
                 # validation
@@ -174,7 +179,7 @@ class BaseTrainer:
             # callback: on_val_end()
 
             # save model
-            if (not self.train.nosave) or (self.epoch + 1 == self.train.epochs):
+            if (not self.args.nosave) or (self.epoch + 1 == self.args.epochs):
                 self.save_model()
                 # callback; on_model_save
@@ -198,7 +203,7 @@ class BaseTrainer:
             'ema': None,  # deepcopy(ema.ema).half(),
             'updates': None,  # ema.updates,
             'optimizer': None,  # optimizer.state_dict(),
-            'train_args': self.train,
+            'train_args': self.args,
             'date': datetime.now().isoformat()}
 
         # Save last, best and delete
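Because save_model() now stores the merged namespace under 'train_args', a run's settings travel with its checkpoints. A sketch of reading them back; the path is a placeholder:

import torch

ckpt = torch.load("last.pt")             # placeholder path to a saved checkpoint
print(ckpt["date"], ckpt["train_args"])  # save timestamp and the merged args namespace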
@@ -207,22 +212,22 @@ class BaseTrainer:
             torch.save(ckpt, self.best)
         del ckpt
 
-    def get_dataloader(self, path):
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0):
         """
         Returns dataloader derived from torch.data.Dataloader
         """
         pass
 
-    def get_dataset(self):
+    def get_dataset(self, data):
         """
-        Uses self.dataset to download the dataset if needed and verify it.
+        Download the dataset if needed and verify it.
         Returns train and val split datasets
         """
         pass
 
-    def get_model(self):
+    def get_model(self, model, pretrained=True):
         """
-        Uses self.model to load/create/download dataset for any task
+        load/create/download model for any task
         """
         pass
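The widened template-method signatures above are what task-specific trainers are expected to fill in. A hypothetical, runnable subclass sketch; the class name, tensor shapes, and stand-in model are invented for illustration:

import torch
from torch.utils.data import DataLoader, TensorDataset

class ToyTrainer(BaseTrainer):  # hypothetical subclass, not part of this commit

    def get_dataset(self, data):
        # 'data' would normally point at a dataset config; fabricate tensors instead
        full = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
        return full, full  # (train split, val split)

    def get_dataloader(self, dataset_path, batch_size=16, rank=0):
        return DataLoader(dataset_path, batch_size=batch_size)

    def get_model(self, model, pretrained=True):
        return torch.nn.Conv2d(3, 10, 3)  # stand-in module for illustration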
@@ -238,7 +243,7 @@ class BaseTrainer:
 
     def preprocess_batch(self, images, labels):
         """
-        Allows custom preprocessing model inputs and ground truths depeding on task type
+        Allows custom preprocessing of model inputs and ground truths depending on task type
         """
         return images.to(self.device, non_blocking=True), labels.to(self.device)