General trainer cleanup (#147)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics import yolo # noqa
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
|
||||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load
|
||||
@ -45,8 +45,8 @@ class YOLO:
|
||||
self.trainer = None # trainer object
|
||||
self.task = None # task type
|
||||
self.ckpt = None # if loaded from *.pt
|
||||
self.ckpt_path = None
|
||||
self.cfg = None # if loaded from *.yaml
|
||||
self.ckpt_path = None
|
||||
self.overrides = {} # overrides for trainer object
|
||||
|
||||
# Load or create new YOLO model
|
||||
@ -78,7 +78,7 @@ class YOLO:
|
||||
Args:
|
||||
weights (str): model checkpoint to be loaded
|
||||
"""
|
||||
self.model = attempt_load_weights(weights)
|
||||
self.model, self.ckpt = attempt_load_one_weight(weights)
|
||||
self.ckpt_path = weights
|
||||
self.task = self.model.args["task"]
|
||||
self.overrides = self.model.args
|
||||
@ -188,14 +188,14 @@ class YOLO:
|
||||
overrides["mode"] = "train"
|
||||
if not overrides.get("data"):
|
||||
raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
|
||||
|
||||
if overrides.get("resume"):
|
||||
overrides["resume"] = self.ckpt_path
|
||||
|
||||
self.trainer = self.TrainerClass(overrides=overrides)
|
||||
if not overrides.get("resume"):
|
||||
self.trainer.model = self.trainer.load_model(weights=self.model,
|
||||
model_cfg=self.model.yaml if self.task != "classify" else None)
|
||||
self.model = self.trainer.model # override here to save memory
|
||||
if not overrides.get("resume"): # manually set model only if not resuming
|
||||
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None,
|
||||
cfg=self.model.yaml if self.task != "classify" else None)
|
||||
self.model = self.trainer.model
|
||||
|
||||
self.trainer.train()
|
||||
|
||||
|
@ -23,6 +23,7 @@ from tqdm import tqdm
|
||||
|
||||
import ultralytics.yolo.utils as utils
|
||||
from ultralytics import __version__
|
||||
from ultralytics.nn.tasks import attempt_load_one_weight
|
||||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
|
||||
@ -380,21 +381,18 @@ class BaseTrainer:
|
||||
"""
|
||||
load/create/download model for any task
|
||||
"""
|
||||
if isinstance(self.model, torch.nn.Module): # if loaded model is passed
|
||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
# We should improve the code flow here. This function looks hacky
|
||||
model = self.model
|
||||
pretrained = not str(model).endswith(".yaml")
|
||||
# config
|
||||
if not pretrained:
|
||||
model = check_file(model)
|
||||
ckpt = self.load_ckpt(model) if pretrained else None
|
||||
weights = ckpt["model"] if isinstance(ckpt, dict) else ckpt # torchvision weights are not dicts
|
||||
self.model = self.load_model(model_cfg=None if pretrained else model, weights=weights)
|
||||
return ckpt
|
||||
|
||||
def load_ckpt(self, ckpt):
|
||||
return torch.load(ckpt, map_location='cpu')
|
||||
model, weights = self.model, None
|
||||
ckpt = None
|
||||
if str(model).endswith(".pt"):
|
||||
weights, ckpt = attempt_load_one_weight(model)
|
||||
cfg = ckpt["model"].yaml
|
||||
else:
|
||||
cfg = model
|
||||
self.model = self.get_model(cfg=cfg, weights=weights) # calls Model(cfg, weights)
|
||||
return ckpt
|
||||
|
||||
def optimizer_step(self):
|
||||
self.scaler.unscale_(self.optimizer) # unscale gradients
|
||||
@ -433,7 +431,7 @@ class BaseTrainer:
|
||||
if rank in {-1, 0}:
|
||||
self.console.info(text)
|
||||
|
||||
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
||||
|
||||
def get_validator(self):
|
||||
|
@ -1,7 +1,10 @@
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
from ultralytics.nn.tasks import ClassificationModel, get_model
|
||||
from ultralytics.nn.tasks import ClassificationModel, attempt_load_weights
|
||||
from ultralytics.yolo import v8
|
||||
from ultralytics.yolo.data import build_classification_dataloader
|
||||
from ultralytics.yolo.engine.trainer import BaseTrainer
|
||||
@ -10,29 +13,47 @@ from ultralytics.yolo.utils import DEFAULT_CONFIG
|
||||
|
||||
class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
|
||||
overrides["task"] = "classify"
|
||||
super().__init__(config, overrides)
|
||||
|
||||
def set_model_attributes(self):
|
||||
self.model.names = self.data["names"]
|
||||
|
||||
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
||||
# TODO: why treat clf models as unique. We should have clf yamls? YES WE SHOULD!
|
||||
if isinstance(weights, dict): # yolo ckpt
|
||||
weights = weights["model"]
|
||||
if weights and not weights.__class__.__name__.startswith("yolo"): # torchvision
|
||||
model = weights
|
||||
else:
|
||||
model = ClassificationModel(model_cfg, weights, self.data["nc"])
|
||||
ClassificationModel.reshape_outputs(model, self.data["nc"])
|
||||
for m in model.modules():
|
||||
if not weights and hasattr(m, 'reset_parameters'):
|
||||
m.reset_parameters()
|
||||
if isinstance(m, torch.nn.Dropout) and self.args.dropout is not None:
|
||||
m.p = self.args.dropout # set dropout
|
||||
for p in model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
def get_model(self, cfg=None, weights=None):
|
||||
model = ClassificationModel(cfg, nc=self.data["nc"])
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
||||
return model
|
||||
|
||||
def load_ckpt(self, ckpt):
|
||||
return get_model(ckpt)
|
||||
def setup_model(self):
|
||||
"""
|
||||
load/create/download model for any task
|
||||
"""
|
||||
# classification models require special handling
|
||||
|
||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
|
||||
model = self.model
|
||||
pretrained = False
|
||||
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
||||
if model.endswith(".pt"):
|
||||
model = model.split(".")[0]
|
||||
pretrained = True
|
||||
else:
|
||||
self.model = self.get_model(cfg=model)
|
||||
|
||||
# order: check local file -> torchvision assets -> ultralytics asset
|
||||
if Path(f"{model}.pt").is_file(): # local file
|
||||
self.model = attempt_load_weights(f"{model}.pt", device='cpu')
|
||||
elif model in torchvision.models.__dict__:
|
||||
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
|
||||
else:
|
||||
self.model = attempt_load_weights(f"{model}.pt", device='cpu')
|
||||
|
||||
return # dont return ckpt. Classification doesn't support resume
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
|
||||
return build_classification_dataloader(path=dataset_path,
|
||||
|
@ -55,10 +55,11 @@ class DetectionTrainer(BaseTrainer):
|
||||
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
|
||||
self.model.names = self.data["names"]
|
||||
|
||||
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
||||
model = DetectionModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
model = DetectionModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)
|
||||
if weights:
|
||||
model.load(weights, verbose)
|
||||
model.load(model)
|
||||
|
||||
return model
|
||||
|
||||
def get_validator(self):
|
||||
|
@ -18,10 +18,15 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
|
||||
# BaseTrainer python usage
|
||||
class SegmentationTrainer(v8.detect.DetectionTrainer):
|
||||
|
||||
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
||||
model = SegmentationModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
|
||||
overrides["task"] = "segment"
|
||||
super().__init__(config, overrides)
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)
|
||||
if weights:
|
||||
model.load(weights, verbose)
|
||||
model.load(weights)
|
||||
|
||||
return model
|
||||
|
||||
def get_validator(self):
|
||||
|
@ -19,6 +19,7 @@ class SegmentationValidator(DetectionValidator):
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
|
||||
super().__init__(dataloader, save_dir, pbar, logger, args)
|
||||
self.args.task = "segment"
|
||||
self.metrics = SegmentMetrics(save_dir=self.save_dir, plot=self.args.plots)
|
||||
|
||||
def preprocess(self, batch):
|
||||
|
Reference in New Issue
Block a user