[WIP] Model interface (#68)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent e6737f1207
commit 7ae45c6cc4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,3 +1,5 @@
from ultralytics.yolo import v8
from .engine.model import YOLO from .engine.model import YOLO
from .engine.trainer import BaseTrainer from .engine.trainer import BaseTrainer
from .engine.validator import BaseValidator from .engine.validator import BaseValidator

@ -1,55 +1,45 @@
""" """
Top-level YOLO model interface. First principle usage example - https://github.com/ultralytics/ultralytics/issues/13 Top-level YOLO model interface. First principle usage example - https://github.com/ultralytics/ultralytics/issues/13
""" """
import torch
import yaml import yaml
from ultralytics.yolo.utils import LOGGER from ultralytics.yolo.utils import LOGGER
from ultralytics.yolo.utils.checks import check_yaml from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.modeling import get_model from ultralytics.yolo.utils.modeling import attempt_load_weights
from ultralytics.yolo.utils.modeling.tasks import ClassificationModel, DetectionModel, SegmentationModel from ultralytics.yolo.utils.modeling.tasks import ClassificationModel, DetectionModel, SegmentationModel
# map head: [model, trainer] # map head: [model, trainer]
MODEL_MAP = { MODEL_MAP = {
"classify": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'], "classify": [ClassificationModel, 'yolo.VERSION.classify.ClassificationTrainer'],
"detect": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'], # temp "detect": [DetectionModel, 'yolo.VERSION.detect.DetectionTrainer'],
"segment": []} "segment": [SegmentationModel, 'yolo.VERSION.segment.SegmentationTrainer']}
class YOLO: class YOLO:
def __init__(self, task=None, version=8) -> None: def __init__(self, version=8) -> None:
self.version = version self.version = version
self.ModelClass = None self.ModelClass = None
self.TrainerClass = None self.TrainerClass = None
self.model = None self.model = None
self.pretrained_weights = None self.trainer = None
if task: self.task = None
if task.lower() not in MODEL_MAP: self.ckpt = None
raise Exception(f"Unsupported task {task}. The supported tasks are: \n {MODEL_MAP.keys()}")
self.ModelClass, self.TrainerClass = MODEL_MAP[task]
self.TrainerClass = eval(self.trainer.replace("VERSION", f"v{self.version}"))
def new(self, cfg: str): def new(self, cfg: str):
cfg = check_yaml(cfg) # check YAML cfg = check_yaml(cfg) # check YAML
if self.model:
self.model = self.model(cfg)
else:
with open(cfg, encoding='ascii', errors='ignore') as f: with open(cfg, encoding='ascii', errors='ignore') as f:
cfg = yaml.safe_load(f) # model dict cfg = yaml.safe_load(f) # model dict
self.ModelClass, self.TrainerClass = self._get_model_and_trainer(cfg["head"]) self.ModelClass, self.TrainerClass, self.task = self._guess_model_trainer_and_task(cfg["head"][-1][-2])
self.model = self.ModelClass(cfg) # initialize self.model = self.ModelClass(cfg) # initialize
def load(self, weights, autodownload=True): def load(self, weights):
if not isinstance(self.pretrained_weights, type(None)): self.ckpt = torch.load(weights, map_location="cpu")
LOGGER.info("Overwriting weights") self.task = self.ckpt["train_args"]["task"]
# TODO: weights = smart_file_loader(weights) _, trainer_class_literal = MODEL_MAP[self.task]
if self.model: self.TrainerClass = eval(trainer_class_literal.replace("VERSION", f"v{self.version}"))
self.model.load(weights) self.model = attempt_load_weights(weights)
LOGGER.info("Checkpoint loaded successfully")
else:
self.model = get_model(weights)
self.ModelClass, self.TrainerClass = self._guess_model_and_trainer(list(self.model.named_children()))
self.pretrained_weights = weights
def reset(self): def reset(self):
for m in self.model.modules(): for m in self.model.modules():
@ -61,16 +51,31 @@ class YOLO:
def train(self, **kwargs): def train(self, **kwargs):
if 'data' not in kwargs: if 'data' not in kwargs:
raise Exception("data is required to train") raise Exception("data is required to train")
if not self.model: if not self.model and not self.ckpt:
raise Exception("model not initialized. Use .new() or .load()") raise Exception("model not initialized. Use .new() or .load()")
# kwargs["model"] = self.model
trainer = self.TrainerClass(overrides=kwargs)
trainer.model = self.model
trainer.train()
def _guess_model_and_trainer(self, cfg): kwargs["task"] = self.task
kwargs["mode"] = "train"
self.trainer = self.TrainerClass(overrides=kwargs)
# load pre-trained weights if found, else use the loaded model
self.trainer.model = self.trainer.load_model(weights=self.ckpt) if self.ckpt else self.model
self.trainer.train()
def resume(self, task=None, model=None):
if not task:
raise Exception(
"pass the task type and/or model(optional) from which you want to resume: `model.resume(task="
")`")
if task.lower() not in MODEL_MAP:
raise Exception(f"unrecognised task - {task}. Supported tasks are {MODEL_MAP.keys()}")
_, trainer_class_literal = MODEL_MAP[task.lower()]
self.TrainerClass = eval(trainer_class_literal.replace("VERSION", f"v{self.version}"))
self.trainer = self.TrainerClass(overrides={"task": task.lower(), "resume": model if model else True})
self.trainer.train()
def _guess_model_trainer_and_task(self, head):
# TODO: warn # TODO: warn
head = cfg[-1][-2] task = None
if head.lower() in ["classify", "classifier", "cls", "fc"]: if head.lower() in ["classify", "classifier", "cls", "fc"]:
task = "classify" task = "classify"
if head.lower() in ["detect"]: if head.lower() in ["detect"]:
@ -81,11 +86,9 @@ class YOLO:
# warning: eval is unsafe. Use with caution # warning: eval is unsafe. Use with caution
trainer_class = eval(trainer_class.replace("VERSION", f"v{self.version}")) trainer_class = eval(trainer_class.replace("VERSION", f"v{self.version}"))
return model_class, trainer_class return model_class, trainer_class, task
if __name__ == "__main__": def __call__(self, imgs):
model = YOLO() if not self.model:
# model.new("assets/dummy_model.yaml") LOGGER.info("model not initialized!")
model.load("yolov5n-cls.pt") return self.model(imgs)
model.train(data="imagenette160", epochs=1, lr0=0.01)

@ -8,7 +8,6 @@ from collections import defaultdict
from copy import deepcopy from copy import deepcopy
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, Union
import numpy as np import numpy as np
import torch import torch
@ -28,7 +27,6 @@ from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
from ultralytics.yolo.utils.checks import check_file, print_args from ultralytics.yolo.utils.checks import check_file, print_args
from ultralytics.yolo.utils.configs import get_config from ultralytics.yolo.utils.configs import get_config
from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_yaml from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_yaml
from ultralytics.yolo.utils.modeling import get_model
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml" DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
@ -63,6 +61,7 @@ class BaseTrainer:
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu') self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
# Model and Dataloaders. # Model and Dataloaders.
self.model = self.args.model
self.data = self.args.data self.data = self.args.data
if self.data.endswith(".yaml"): if self.data.endswith(".yaml"):
self.data = check_dataset_yaml(self.data) self.data = check_dataset_yaml(self.data)
@ -125,6 +124,7 @@ class BaseTrainer:
""" """
# model # model
ckpt = self.setup_model() ckpt = self.setup_model()
self.model = self.model.to(self.device)
self.set_model_attributes() self.set_model_attributes()
if world_size > 1: if world_size > 1:
self.model = DDP(self.model, device_ids=[rank]) self.model = DDP(self.model, device_ids=[rank])
@ -288,13 +288,16 @@ class BaseTrainer:
""" """
load/create/download model for any task load/create/download model for any task
""" """
model = self.args.model if isinstance(self.model, torch.nn.Module): # if loaded model is passed
return
# We should improve the code flow here. This function looks hacky
model = self.model
pretrained = not (str(model).endswith(".yaml")) pretrained = not (str(model).endswith(".yaml"))
# config # config
if not pretrained: if not pretrained:
model = check_file(model) model = check_file(model)
ckpt = self.load_ckpt(model) if pretrained else None ckpt = self.load_ckpt(model) if pretrained else None
self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt).to(self.device) # model self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt) # model
return ckpt return ckpt
def load_ckpt(self, ckpt): def load_ckpt(self, ckpt):
@ -402,7 +405,7 @@ class BaseTrainer:
last = Path(check_file(resume) if isinstance(resume, str) else get_latest_run()) last = Path(check_file(resume) if isinstance(resume, str) else get_latest_run())
args_yaml = last.parent.parent / 'args.yaml' # train options yaml args_yaml = last.parent.parent / 'args.yaml' # train options yaml
if args_yaml.is_file(): if args_yaml.is_file():
args = self._get_config(args_yaml) # replace args = get_config(args_yaml) # replace
args.model, args.resume, args.exist_ok = str(last), True, True # reinstate args.model, args.resume, args.exist_ok = str(last), True, True # reinstate
self.args = args self.args = args
@ -424,8 +427,7 @@ class BaseTrainer:
f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs') f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs')
if self.epochs < start_epoch: if self.epochs < start_epoch:
LOGGER.info( LOGGER.info(
f"{self.args.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs." f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")
)
self.epochs += ckpt['epoch'] # finetune additional epochs self.epochs += ckpt['epoch'] # finetune additional epochs
self.best_fitness = best_fitness self.best_fitness = best_fitness
self.start_epoch = start_epoch self.start_epoch = start_epoch
@ -460,9 +462,3 @@ def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups " LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias") f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
return optimizer return optimizer
# Dummy validator
def val(trainer: BaseTrainer):
trainer.console.info("validating")
return {"metric_1": 0.1, "metric_2": 0.2, "fitness": 1}

@ -13,8 +13,10 @@ class ClassificationTrainer(BaseTrainer):
def set_model_attributes(self): def set_model_attributes(self):
self.model.names = self.data["names"] self.model.names = self.data["names"]
def load_model(self, model_cfg, weights): def load_model(self, model_cfg=None, weights=None):
# TODO: why treat clf models as unique. We should have clf yamls? # TODO: why treat clf models as unique. We should have clf yamls?
if isinstance(weights, dict): # yolo ckpt
weights = weights["model"]
if weights and not weights.__class__.__name__.startswith("yolo"): # torchvision if weights and not weights.__class__.__name__.startswith("yolo"): # torchvision
model = weights model = weights
else: else:

@ -15,7 +15,7 @@ from .val import DetectionValidator
# BaseTrainer python usage # BaseTrainer python usage
class DetectionTrainer(SegmentationTrainer): class DetectionTrainer(SegmentationTrainer):
def load_model(self, model_cfg, weights): def load_model(self, model_cfg=None, weights=None):
model = DetectionModel(model_cfg or weights["model"].yaml, model = DetectionModel(model_cfg or weights["model"].yaml,
ch=3, ch=3,
nc=self.data["nc"], nc=self.data["nc"],

@ -26,7 +26,7 @@ class SegmentationTrainer(BaseTrainer):
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255 batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
return batch return batch
def load_model(self, model_cfg, weights): def load_model(self, model_cfg=None, weights=None):
model = SegmentationModel(model_cfg or weights["model"].yaml, model = SegmentationModel(model_cfg or weights["model"].yaml,
ch=3, ch=3,
nc=self.data["nc"], nc=self.data["nc"],

Loading…
Cancel
Save