diff --git a/tests/test_engine.py b/tests/test_engine.py
new file mode 100644
index 0000000..7253750
--- /dev/null
+++ b/tests/test_engine.py
@@ -0,0 +1,94 @@
+from ultralytics import YOLO
+from ultralytics.yolo.configs import get_config
+from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT
+from ultralytics.yolo.v8 import classify, detect, segment
+
+CFG_DET = 'yolov8n.yaml'
+CFG_SEG = 'yolov8n-seg.yaml'
+CFG_CLS = 'squeezenet1_0'
+CFG = get_config(DEFAULT_CONFIG)
+SOURCE = ROOT / "assets"
+
+
+def test_detect():
+    overrides = {"data": "coco128.yaml", "model": CFG_DET, "imgsz": 32, "epochs": 1, "save": False}
+    CFG.data = "coco128.yaml"
+    # trainer
+    trainer = detect.DetectionTrainer(overrides=overrides)
+    trainer.train()
+    trained_model = trainer.best
+
+    # Validator
+    val = detect.DetectionValidator(args=CFG)
+    val(model=trained_model)
+
+    # predictor
+    pred = detect.DetectionPredictor(overrides={"imgsz": [640, 640]})
+    pred(source=SOURCE, model=trained_model)
+
+    overrides["resume"] = trainer.last
+    trainer = detect.DetectionTrainer(overrides=overrides)
+    try:
+        trainer.train()
+    except Exception as e:
+        print(f"Expected exception caught: {e}")
+        return
+
+    raise Exception("Resume test failed!")
+
+
+def test_segment():
+    overrides = {"data": "coco128-seg.yaml", "model": CFG_SEG, "imgsz": 32, "epochs": 1, "save": False}
+    CFG.data = "coco128-seg.yaml"
+    CFG.v5loader = False
+
+    # YOLO(CFG_SEG).train(**overrides)  # This works
+    # trainer
+    trainer = segment.SegmentationTrainer(overrides=overrides)
+    trainer.train()
+    trained_model = trainer.best
+
+    # Validator
+    val = segment.SegmentationValidator(args=CFG)
+    val(model=trained_model)
+
+    # predictor
+    pred = segment.SegmentationPredictor(overrides={"imgsz": [640, 640]})
+    pred(source=SOURCE, model=trained_model)
+
+    # test resume
+    overrides["resume"] = trainer.last
+    trainer = segment.SegmentationTrainer(overrides=overrides)
+    try:
+        trainer.train()
+    except Exception as e:
+        print(f"Expected exception caught: {e}")
+        return
+
+    raise Exception("Resume test failed!")
+
+
+def test_classify():
+    overrides = {
+        "data": "imagenette160",
+        "model": "squeezenet1_0",
+        "imgsz": 32,
+        "epochs": 1,
+        "batch": 64,
+        "save": False}
+    CFG.data = "imagenette160"
+    CFG.imgsz = 32
+    CFG.batch = 64
+    # YOLO(CFG_SEG).train(**overrides)  # This works
+    # trainer
+    trainer = classify.ClassificationTrainer(overrides=overrides)
+    trainer.train()
+    trained_model = trainer.best
+
+    # Validator
+    val = classify.ClassificationValidator(args=CFG)
+    val(model=trained_model)
+
+    # predictor
+    pred = classify.ClassificationPredictor(overrides={"imgsz": [640, 640]})
+    pred(source=SOURCE, model=trained_model)
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 5243129..b9b34b5 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -282,6 +282,7 @@ class ClassificationModel(BaseModel):
 
 
 def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
+    LOGGER.info("WARNING: Deprecated in favor of attempt_load_one_weight()")
     # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
     from ultralytics.yolo.utils.downloads import attempt_download
 
@@ -321,6 +322,34 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
     return model
 
 
+def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
+    # Loads a single model weights
+    from ultralytics.yolo.utils.downloads import attempt_download
+
+    ckpt = torch.load(attempt_download(weight), map_location='cpu')  # load
+    args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']}  # combine model and default args, preferring model args
+    model = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model
+
+    # Model compatibility updates
+    model.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS}  # attach args to model
+    model.pt_path = weight  # attach *.pt file path to model
+    if not hasattr(model, 'stride'):
+        model.stride = torch.tensor([32.])
+
+    model = model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval()  # model in eval mode
+
+    # Module compatibility updates
+    for m in model.modules():
+        t = type(m)
+        if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
+            m.inplace = inplace  # torch 1.7.0 compatibility
+        elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
+            m.recompute_scale_factor = None  # torch 1.11.0 compatibility
+
+    # Return model and ckpt
+    return model, ckpt
+
+
 def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
     # Parse a YOLOv5 model.yaml dictionary
     if verbose:
@@ -375,16 +404,3 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             ch = []
         ch.append(c2)
     return nn.Sequential(*layers), sorted(save)
-
-
-def get_model(model='s.pt', pretrained=True):
-    # Load a YOLO model locally, from torchvision, or from Ultralytics assets
-    if model.endswith(".pt"):
-        model = model.split(".")[0]
-
-    if Path(f"{model}.pt").is_file():  # local file
-        return attempt_load_weights(f"{model}.pt", device='cpu')
-    elif model in torchvision.models.__dict__:  # TorchVision models i.e. resnet50, efficientnet_b0
-        return torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
-    else:  # Ultralytics assets
-        return attempt_load_weights(f"{model}.pt", device='cpu')
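A minimal sketch of how the new single-weight loader is meant to be used. The local `yolov8n.pt` path is an assumption (any checkpoint produced by these trainers would do); the attributes printed follow the implementation above.

```python
# Usage sketch for attempt_load_one_weight(); 'yolov8n.pt' is an assumed local checkpoint path.
from ultralytics.nn.tasks import attempt_load_one_weight

model, ckpt = attempt_load_one_weight('yolov8n.pt', device='cpu', fuse=False)
print(type(model).__name__)  # FP32 nn.Module in eval mode (EMA weights are preferred when present)
print(model.args['task'])    # train args, filtered to DEFAULT_CONFIG_KEYS, are attached to the model
print(sorted(ckpt))          # the raw checkpoint dict, e.g. 'ema', 'model', 'train_args', ...
```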
diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py
index 7c0bd30..4f97a3d 100644
--- a/ultralytics/yolo/engine/model.py
+++ b/ultralytics/yolo/engine/model.py
@@ -1,7 +1,7 @@
 from pathlib import Path
 
 from ultralytics import yolo  # noqa
-from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights
+from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
 from ultralytics.yolo.configs import get_config
 from ultralytics.yolo.engine.exporter import Exporter
 from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load
@@ -45,8 +45,8 @@ class YOLO:
         self.trainer = None  # trainer object
         self.task = None  # task type
         self.ckpt = None  # if loaded from *.pt
-        self.ckpt_path = None
         self.cfg = None  # if loaded from *.yaml
+        self.ckpt_path = None
         self.overrides = {}  # overrides for trainer object
 
         # Load or create new YOLO model
@@ -78,7 +78,7 @@ class YOLO:
         Args:
             weights (str): model checkpoint to be loaded
         """
-        self.model = attempt_load_weights(weights)
+        self.model, self.ckpt = attempt_load_one_weight(weights)
         self.ckpt_path = weights
         self.task = self.model.args["task"]
         self.overrides = self.model.args
@@ -188,14 +188,14 @@ class YOLO:
         overrides["mode"] = "train"
         if not overrides.get("data"):
             raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
-
         if overrides.get("resume"):
             overrides["resume"] = self.ckpt_path
+
         self.trainer = self.TrainerClass(overrides=overrides)
-        if not overrides.get("resume"):
-            self.trainer.model = self.trainer.load_model(weights=self.model,
-                                                         model_cfg=self.model.yaml if self.task != "classify" else None)
-            self.model = self.trainer.model  # override here to save memory
+        if not overrides.get("resume"):  # manually set model only if not resuming
+            self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None,
+                                                        cfg=self.model.yaml if self.task != "classify" else None)
+            self.model = self.trainer.model
 
         self.trainer.train()
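With this change `YOLO` keeps the raw checkpoint in `self.ckpt`, so the trainer only receives pretrained weights when the model actually came from a `*.pt` file. A sketch of the two entry points, assuming the constructor accepts either a `*.pt` or `*.yaml` path (as the surrounding code and the commented-out line in the tests suggest) and that `coco128.yaml` can be auto-downloaded:

```python
# Two entry points distinguished by the change above; kept tiny (imgsz=32, epochs=1) as in the tests.
from ultralytics import YOLO

# From a checkpoint: self.ckpt is set, so trainer.get_model() receives the pretrained weights.
YOLO('yolov8n.pt').train(data='coco128.yaml', imgsz=32, epochs=1)

# From a YAML config: self.ckpt stays None, so the trainer builds the model from cfg alone.
YOLO('yolov8n.yaml').train(data='coco128.yaml', imgsz=32, epochs=1)
```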
diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
index 7c60454..81dbb4a 100644
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -23,6 +23,7 @@ from tqdm import tqdm
 
 import ultralytics.yolo.utils as utils
 from ultralytics import __version__
+from ultralytics.nn.tasks import attempt_load_one_weight
 from ultralytics.yolo.configs import get_config
 from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
 from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
@@ -380,21 +381,18 @@ class BaseTrainer:
         """
         load/create/download model for any task
         """
-        if isinstance(self.model, torch.nn.Module):  # if loaded model is passed
+        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
             return
 
-        # We should improve the code flow here. This function looks hacky
-        model = self.model
-        pretrained = not str(model).endswith(".yaml")
-        # config
-        if not pretrained:
-            model = check_file(model)
-        ckpt = self.load_ckpt(model) if pretrained else None
-        weights = ckpt["model"] if isinstance(ckpt, dict) else ckpt  # torchvision weights are not dicts
-        self.model = self.load_model(model_cfg=None if pretrained else model, weights=weights)
-        return ckpt
-
-    def load_ckpt(self, ckpt):
-        return torch.load(ckpt, map_location='cpu')
+        model, weights = self.model, None
+        ckpt = None
+        if str(model).endswith(".pt"):
+            weights, ckpt = attempt_load_one_weight(model)
+            cfg = ckpt["model"].yaml
+        else:
+            cfg = model
+        self.model = self.get_model(cfg=cfg, weights=weights)  # calls Model(cfg, weights)
+        return ckpt
 
     def optimizer_step(self):
         self.scaler.unscale_(self.optimizer)  # unscale gradients
@@ -433,7 +431,7 @@
         if rank in {-1, 0}:
             self.console.info(text)
 
-    def load_model(self, model_cfg=None, weights=None, verbose=True):
+    def get_model(self, cfg=None, weights=None, verbose=True):
         raise NotImplementedError("This task trainer doesn't support loading cfg files")
 
     def get_validator(self):
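`BaseTrainer.setup_model()` now defines a single contract for task trainers: `get_model(cfg, weights)` builds the network from a config and, when `weights` (an `nn.Module` returned by `attempt_load_one_weight()`) is given, loads it. An illustrative-only subclass; `TinyModel` and `MyTrainer` are placeholders, not library code:

```python
# Sketch of the get_model() contract that setup_model() relies on; nothing here is library code.
import torch.nn as nn

from ultralytics.yolo.engine.trainer import BaseTrainer


class TinyModel(nn.Module):
    # Placeholder network; a real task model is built from `cfg` (a dict or YAML path).

    def __init__(self, cfg=None):
        super().__init__()
        self.yaml = cfg  # setup_model() reads ckpt["model"].yaml, so real models carry their cfg here
        self.layer = nn.Identity()

    def load(self, weights):
        # Loosely mirrors the detect/segment models: copy the state dict of a pretrained nn.Module
        self.load_state_dict(weights.float().state_dict(), strict=False)


class MyTrainer(BaseTrainer):

    def get_model(self, cfg=None, weights=None, verbose=True):
        model = TinyModel(cfg)   # build from cfg
        if weights:
            model.load(weights)  # load pretrained weights only when provided
        return model
```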
overrides["task"] = "classify" + super().__init__(config, overrides) + def set_model_attributes(self): self.model.names = self.data["names"] - def load_model(self, model_cfg=None, weights=None, verbose=True): - # TODO: why treat clf models as unique. We should have clf yamls? YES WE SHOULD! - if isinstance(weights, dict): # yolo ckpt - weights = weights["model"] - if weights and not weights.__class__.__name__.startswith("yolo"): # torchvision - model = weights - else: - model = ClassificationModel(model_cfg, weights, self.data["nc"]) - ClassificationModel.reshape_outputs(model, self.data["nc"]) - for m in model.modules(): - if not weights and hasattr(m, 'reset_parameters'): - m.reset_parameters() - if isinstance(m, torch.nn.Dropout) and self.args.dropout is not None: - m.p = self.args.dropout # set dropout - for p in model.parameters(): - p.requires_grad = True # for training + def get_model(self, cfg=None, weights=None): + model = ClassificationModel(cfg, nc=self.data["nc"]) + if weights: + model.load(weights) + return model - def load_ckpt(self, ckpt): - return get_model(ckpt) + def setup_model(self): + """ + load/create/download model for any task + """ + # classification models require special handling + + if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed + return + + model = self.model + pretrained = False + # Load a YOLO model locally, from torchvision, or from Ultralytics assets + if model.endswith(".pt"): + model = model.split(".")[0] + pretrained = True + else: + self.model = self.get_model(cfg=model) + + # order: check local file -> torchvision assets -> ultralytics asset + if Path(f"{model}.pt").is_file(): # local file + self.model = attempt_load_weights(f"{model}.pt", device='cpu') + elif model in torchvision.models.__dict__: + self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None) + else: + self.model = attempt_load_weights(f"{model}.pt", device='cpu') + + return # dont return ckpt. 
diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py
index f8a5e17..1e3dd18 100644
--- a/ultralytics/yolo/v8/detect/train.py
+++ b/ultralytics/yolo/v8/detect/train.py
@@ -55,10 +55,11 @@ class DetectionTrainer(BaseTrainer):
         # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
         self.model.names = self.data["names"]
 
-    def load_model(self, model_cfg=None, weights=None, verbose=True):
-        model = DetectionModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        model = DetectionModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)
         if weights:
-            model.load(weights, verbose)
+            model.load(weights)
+
         return model
 
     def get_validator(self):
diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py
index 6445355..7e32bed 100644
--- a/ultralytics/yolo/v8/segment/train.py
+++ b/ultralytics/yolo/v8/segment/train.py
@@ -18,10 +18,15 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
 # BaseTrainer python usage
 class SegmentationTrainer(v8.detect.DetectionTrainer):
 
-    def load_model(self, model_cfg=None, weights=None, verbose=True):
-        model = SegmentationModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
+    def __init__(self, config=DEFAULT_CONFIG, overrides={}):
+        overrides["task"] = "segment"
+        super().__init__(config, overrides)
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)
         if weights:
-            model.load(weights, verbose)
+            model.load(weights)
+
         return model
 
     def get_validator(self):
diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py
index 801e006..288d378 100644
--- a/ultralytics/yolo/v8/segment/val.py
+++ b/ultralytics/yolo/v8/segment/val.py
@@ -19,6 +19,7 @@ class SegmentationValidator(DetectionValidator):
 
     def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
         super().__init__(dataloader, save_dir, pbar, logger, args)
+        self.args.task = "segment"
         self.metrics = SegmentMetrics(save_dir=self.save_dir, plot=self.args.plots)
 
     def preprocess(self, batch):
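The task trainers now pin their own `task` override in `__init__`, so direct use matches the tests at the top of this diff. A usage sketch (assumes the `coco128-seg.yaml` dataset and `yolov8n-seg.yaml` config are available or auto-downloadable):

```python
# Direct trainer usage mirroring tests/test_engine.py above.
from ultralytics.yolo.v8 import segment

overrides = {"data": "coco128-seg.yaml", "model": "yolov8n-seg.yaml", "imgsz": 32, "epochs": 1, "save": False}
trainer = segment.SegmentationTrainer(overrides=overrides)  # __init__ forces overrides["task"] = "segment"
trainer.train()
print(trainer.best, trainer.last)  # best/last checkpoint paths; `last` is what the resume tests feed back in
```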