def test_model_train_pretrained():
    """Train from loaded weights, then from a fresh config, then run a forward pass.

    Steps exercised, in order:
      1. Load the "balloon-detect.pt" checkpoint and train for one epoch.
      2. Re-initialise the same wrapper from "yolov5n.yaml" and train again.
      3. Call the model on a random image batch.
    """
    model = YOLO()

    # Train starting from the loaded checkpoint.
    model.load("balloon-detect.pt")
    model.train(data="coco128.yaml", epochs=1, img_size=32)

    # Train again after rebuilding the model from a yaml definition.
    model.new("yolov5n.yaml")
    model.train(data="coco128.yaml", epochs=1, img_size=32)

    # Ask for the (1, 3, 512, 512) shape directly instead of allocating a
    # flat tensor and reshaping it with .view().
    img = torch.rand(1, 3, 512, 512)
    model(img)
def to(self, device):
    """Move the wrapped model to ``device``.

    Delegates to ``self.model.to(device)``.

    Args:
        device: Target device, forwarded unchanged to the model's ``to``.

    Raises:
        Exception: If no model has been set on this wrapper yet.
    """
    # Fail with the same message train() uses rather than an opaque
    # AttributeError on None.
    if self.model is None:
        raise Exception("model not initialized. Use .new() or .load()")
    self.model.to(device)