General trainer cleanup (#147)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent f8a13c49a0
commit 0e5a7ae623
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,94 @@
from ultralytics import YOLO
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT
from ultralytics.yolo.v8 import classify, detect, segment
CFG_DET = 'yolov8n.yaml'
CFG_SEG = 'yolov8n-seg.yaml'
CFG_CLS = 'squeezenet1_0'
CFG = get_config(DEFAULT_CONFIG)
SOURCE = ROOT / "assets"
def test_detect():
overrides = {"data": "coco128.yaml", "model": CFG_DET, "imgsz": 32, "epochs": 1, "save": False}
CFG.data = "coco128.yaml"
# trainer
trainer = detect.DetectionTrainer(overrides=overrides)
trainer.train()
trained_model = trainer.best
# Validator
val = detect.DetectionValidator(args=CFG)
val(model=trained_model)
# predictor
pred = detect.DetectionPredictor(overrides={"imgsz": [640, 640]})
pred(source=SOURCE, model=trained_model)
overrides["resume"] = trainer.last
trainer = detect.DetectionTrainer(overrides=overrides)
try:
trainer.train()
except Exception as e:
print(f"Expected exception caught: {e}")
return
Exception("Resume test failed!")
def test_segment():
overrides = {"data": "coco128-seg.yaml", "model": CFG_SEG, "imgsz": 32, "epochs": 1, "save": False}
CFG.data = "coco128-seg.yaml"
CFG.v5loader = False
# YOLO(CFG_SEG).train(**overrides) # This works
# trainer
trainer = segment.SegmentationTrainer(overrides=overrides)
trainer.train()
trained_model = trainer.best
# Validator
val = segment.SegmentationValidator(args=CFG)
val(model=trained_model)
# predictor
pred = segment.SegmentationPredictor(overrides={"imgsz": [640, 640]})
pred(source=SOURCE, model=trained_model)
# test resume
overrides["resume"] = trainer.last
trainer = segment.SegmentationTrainer(overrides=overrides)
try:
trainer.train()
except Exception as e:
print(f"Expected exception caught: {e}")
return
Exception("Resume test failed!")
def test_classify():
overrides = {
"data": "imagenette160",
"model": "squeezenet1_0",
"imgsz": 32,
"epochs": 1,
"batch": 64,
"save": False}
CFG.data = "imagenette160"
CFG.imgsz = 32
CFG.batch = 64
# YOLO(CFG_SEG).train(**overrides) # This works
# trainer
trainer = classify.ClassificationTrainer(overrides=overrides)
trainer.train()
trained_model = trainer.best
# Validator
val = classify.ClassificationValidator(args=CFG)
val(model=trained_model)
# predictor
pred = classify.ClassificationPredictor(overrides={"imgsz": [640, 640]})
pred(source=SOURCE, model=trained_model)

@ -282,6 +282,7 @@ class ClassificationModel(BaseModel):
def attempt_load_weights(weights, device=None, inplace=True, fuse=False): def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
LOGGER.info("WARNING: Deprecated in favor of attempt_load_one_weight()")
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
from ultralytics.yolo.utils.downloads import attempt_download from ultralytics.yolo.utils.downloads import attempt_download
@ -321,6 +322,34 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
return model return model
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
# Loads a single model weights
from ultralytics.yolo.utils.downloads import attempt_download
ckpt = torch.load(attempt_download(weight), map_location='cpu') # load
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
# Model compatibility updates
model.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # attach args to model
model.pt_path = weight # attach *.pt file path to model
if not hasattr(model, 'stride'):
model.stride = torch.tensor([32.])
model = model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval() # model in eval mode
# Module compatibility updates
for m in model.modules():
t = type(m)
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
m.inplace = inplace # torch 1.7.0 compatibility
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
m.recompute_scale_factor = None # torch 1.11.0 compatibility
# Return model and ckpt
return model, ckpt
def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
# Parse a YOLOv5 model.yaml dictionary # Parse a YOLOv5 model.yaml dictionary
if verbose: if verbose:
@ -375,16 +404,3 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
ch = [] ch = []
ch.append(c2) ch.append(c2)
return nn.Sequential(*layers), sorted(save) return nn.Sequential(*layers), sorted(save)
def get_model(model='s.pt', pretrained=True):
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
if model.endswith(".pt"):
model = model.split(".")[0]
if Path(f"{model}.pt").is_file(): # local file
return attempt_load_weights(f"{model}.pt", device='cpu')
elif model in torchvision.models.__dict__: # TorchVision models i.e. resnet50, efficientnet_b0
return torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
else: # Ultralytics assets
return attempt_load_weights(f"{model}.pt", device='cpu')

@ -1,7 +1,7 @@
from pathlib import Path from pathlib import Path
from ultralytics import yolo # noqa from ultralytics import yolo # noqa
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
from ultralytics.yolo.configs import get_config from ultralytics.yolo.configs import get_config
from ultralytics.yolo.engine.exporter import Exporter from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load
@ -45,8 +45,8 @@ class YOLO:
self.trainer = None # trainer object self.trainer = None # trainer object
self.task = None # task type self.task = None # task type
self.ckpt = None # if loaded from *.pt self.ckpt = None # if loaded from *.pt
self.ckpt_path = None
self.cfg = None # if loaded from *.yaml self.cfg = None # if loaded from *.yaml
self.ckpt_path = None
self.overrides = {} # overrides for trainer object self.overrides = {} # overrides for trainer object
# Load or create new YOLO model # Load or create new YOLO model
@ -78,7 +78,7 @@ class YOLO:
Args: Args:
weights (str): model checkpoint to be loaded weights (str): model checkpoint to be loaded
""" """
self.model = attempt_load_weights(weights) self.model, self.ckpt = attempt_load_one_weight(weights)
self.ckpt_path = weights self.ckpt_path = weights
self.task = self.model.args["task"] self.task = self.model.args["task"]
self.overrides = self.model.args self.overrides = self.model.args
@ -188,14 +188,14 @@ class YOLO:
overrides["mode"] = "train" overrides["mode"] = "train"
if not overrides.get("data"): if not overrides.get("data"):
raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.") raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
if overrides.get("resume"): if overrides.get("resume"):
overrides["resume"] = self.ckpt_path overrides["resume"] = self.ckpt_path
self.trainer = self.TrainerClass(overrides=overrides) self.trainer = self.TrainerClass(overrides=overrides)
if not overrides.get("resume"): if not overrides.get("resume"): # manually set model only if not resuming
self.trainer.model = self.trainer.load_model(weights=self.model, self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None,
model_cfg=self.model.yaml if self.task != "classify" else None) cfg=self.model.yaml if self.task != "classify" else None)
self.model = self.trainer.model # override here to save memory self.model = self.trainer.model
self.trainer.train() self.trainer.train()

@ -23,6 +23,7 @@ from tqdm import tqdm
import ultralytics.yolo.utils as utils import ultralytics.yolo.utils as utils
from ultralytics import __version__ from ultralytics import __version__
from ultralytics.nn.tasks import attempt_load_one_weight
from ultralytics.yolo.configs import get_config from ultralytics.yolo.configs import get_config
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
@ -380,21 +381,18 @@ class BaseTrainer:
""" """
load/create/download model for any task load/create/download model for any task
""" """
if isinstance(self.model, torch.nn.Module): # if loaded model is passed if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return return
# We should improve the code flow here. This function looks hacky
model = self.model
pretrained = not str(model).endswith(".yaml")
# config
if not pretrained:
model = check_file(model)
ckpt = self.load_ckpt(model) if pretrained else None
weights = ckpt["model"] if isinstance(ckpt, dict) else ckpt # torchvision weights are not dicts
self.model = self.load_model(model_cfg=None if pretrained else model, weights=weights)
return ckpt
def load_ckpt(self, ckpt): model, weights = self.model, None
return torch.load(ckpt, map_location='cpu') ckpt = None
if str(model).endswith(".pt"):
weights, ckpt = attempt_load_one_weight(model)
cfg = ckpt["model"].yaml
else:
cfg = model
self.model = self.get_model(cfg=cfg, weights=weights) # calls Model(cfg, weights)
return ckpt
def optimizer_step(self): def optimizer_step(self):
self.scaler.unscale_(self.optimizer) # unscale gradients self.scaler.unscale_(self.optimizer) # unscale gradients
@ -433,7 +431,7 @@ class BaseTrainer:
if rank in {-1, 0}: if rank in {-1, 0}:
self.console.info(text) self.console.info(text)
def load_model(self, model_cfg=None, weights=None, verbose=True): def get_model(self, cfg=None, weights=None, verbose=True):
raise NotImplementedError("This task trainer doesn't support loading cfg files") raise NotImplementedError("This task trainer doesn't support loading cfg files")
def get_validator(self): def get_validator(self):

@ -1,7 +1,10 @@
from pathlib import Path
import hydra import hydra
import torch import torch
import torchvision
from ultralytics.nn.tasks import ClassificationModel, get_model from ultralytics.nn.tasks import ClassificationModel, attempt_load_weights
from ultralytics.yolo import v8 from ultralytics.yolo import v8
from ultralytics.yolo.data import build_classification_dataloader from ultralytics.yolo.data import build_classification_dataloader
from ultralytics.yolo.engine.trainer import BaseTrainer from ultralytics.yolo.engine.trainer import BaseTrainer
@ -10,29 +13,47 @@ from ultralytics.yolo.utils import DEFAULT_CONFIG
class ClassificationTrainer(BaseTrainer): class ClassificationTrainer(BaseTrainer):
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
overrides["task"] = "classify"
super().__init__(config, overrides)
def set_model_attributes(self): def set_model_attributes(self):
self.model.names = self.data["names"] self.model.names = self.data["names"]
def load_model(self, model_cfg=None, weights=None, verbose=True): def get_model(self, cfg=None, weights=None):
# TODO: why treat clf models as unique. We should have clf yamls? YES WE SHOULD! model = ClassificationModel(cfg, nc=self.data["nc"])
if isinstance(weights, dict): # yolo ckpt if weights:
weights = weights["model"] model.load(weights)
if weights and not weights.__class__.__name__.startswith("yolo"): # torchvision
model = weights
else:
model = ClassificationModel(model_cfg, weights, self.data["nc"])
ClassificationModel.reshape_outputs(model, self.data["nc"])
for m in model.modules():
if not weights and hasattr(m, 'reset_parameters'):
m.reset_parameters()
if isinstance(m, torch.nn.Dropout) and self.args.dropout is not None:
m.p = self.args.dropout # set dropout
for p in model.parameters():
p.requires_grad = True # for training
return model return model
def load_ckpt(self, ckpt): def setup_model(self):
return get_model(ckpt) """
load/create/download model for any task
"""
# classification models require special handling
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
model = self.model
pretrained = False
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
if model.endswith(".pt"):
model = model.split(".")[0]
pretrained = True
else:
self.model = self.get_model(cfg=model)
# order: check local file -> torchvision assets -> ultralytics asset
if Path(f"{model}.pt").is_file(): # local file
self.model = attempt_load_weights(f"{model}.pt", device='cpu')
elif model in torchvision.models.__dict__:
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
else:
self.model = attempt_load_weights(f"{model}.pt", device='cpu')
return # dont return ckpt. Classification doesn't support resume
def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"): def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
return build_classification_dataloader(path=dataset_path, return build_classification_dataloader(path=dataset_path,

@ -55,10 +55,11 @@ class DetectionTrainer(BaseTrainer):
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
self.model.names = self.data["names"] self.model.names = self.data["names"]
def load_model(self, model_cfg=None, weights=None, verbose=True): def get_model(self, cfg=None, weights=None, verbose=True):
model = DetectionModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose) model = DetectionModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)
if weights: if weights:
model.load(weights, verbose) model.load(model)
return model return model
def get_validator(self): def get_validator(self):

@ -18,10 +18,15 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
# BaseTrainer python usage # BaseTrainer python usage
class SegmentationTrainer(v8.detect.DetectionTrainer): class SegmentationTrainer(v8.detect.DetectionTrainer):
def load_model(self, model_cfg=None, weights=None, verbose=True): def __init__(self, config=DEFAULT_CONFIG, overrides={}):
model = SegmentationModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose) overrides["task"] = "segment"
super().__init__(config, overrides)
def get_model(self, cfg=None, weights=None, verbose=True):
model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)
if weights: if weights:
model.load(weights, verbose) model.load(weights)
return model return model
def get_validator(self): def get_validator(self):

@ -19,6 +19,7 @@ class SegmentationValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None): def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
super().__init__(dataloader, save_dir, pbar, logger, args) super().__init__(dataloader, save_dir, pbar, logger, args)
self.args.task = "segment"
self.metrics = SegmentMetrics(save_dir=self.save_dir, plot=self.args.plots) self.metrics = SegmentMetrics(save_dir=self.save_dir, plot=self.args.plots)
def preprocess(self, batch): def preprocess(self, batch):

Loading…
Cancel
Save