From 6d5123297e798578085437bdec61341ef53f72c0 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Tue, 3 Jan 2023 00:31:17 +0530 Subject: [PATCH] Fix CLI detect and segment resume (#134) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher --- ultralytics/yolo/engine/model.py | 15 ++++++++++----- ultralytics/yolo/v8/segment/train.py | 7 +++++-- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index 728eda7..9c89e9b 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -47,6 +47,7 @@ class YOLO: self.trainer = None # trainer object self.task = None # task type self.ckpt = None # if loaded from *.pt + self.ckpt_path = None self.cfg = None # if loaded from *.yaml self.overrides = {} # overrides for trainer object self.init_disabled = False # disable model initialization @@ -78,6 +79,7 @@ class YOLO: weights (str): model checkpoint to be loaded """ self.model = attempt_load_weights(weights) + self.ckpt_path = weights self.task = self.model.args["task"] self.overrides = self.model.args self.overrides["device"] = '' # reset device @@ -177,8 +179,8 @@ class YOLO: """ if not self.model: raise AttributeError("model not initialized. Use .new() or .load()") - - overrides = kwargs + overrides = self.overrides.copy() + overrides.update(kwargs) if kwargs.get("cfg"): LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.") overrides = yaml_load(check_yaml(kwargs["cfg"])) @@ -187,10 +189,13 @@ class YOLO: if not overrides.get("data"): raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.") + if overrides.get("resume"): + overrides["resume"] = self.ckpt_path self.trainer = self.TrainerClass(overrides=overrides) - self.trainer.model = self.trainer.load_model(weights=self.model, - model_cfg=self.model.yaml if self.task != "classify" else None) - self.model = self.trainer.model # override here to save memory + if not overrides.get("resume"): + self.trainer.model = self.trainer.load_model(weights=self.model, + model_cfg=self.model.yaml if self.task != "classify" else None) + self.model = self.trainer.model # override here to save memory self.trainer.train() diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py index f8f146d..487114e 100644 --- a/ultralytics/yolo/v8/segment/train.py +++ b/ultralytics/yolo/v8/segment/train.py @@ -17,9 +17,12 @@ from ultralytics.yolo.utils.torch_utils import de_parallel class SegmentationTrainer(v8.detect.DetectionTrainer): def load_model(self, model_cfg=None, weights=None, verbose=True): - model = SegmentationModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose) + model = SegmentationModel(model_cfg or getattr(weights, 'yaml', None) or weights['model'].yaml, + ch=3, + nc=self.data["nc"], + verbose=verbose) if weights: - model.load(weights, verbose) + model.load(weights['model'] if isinstance(weights, dict) else weights, verbose) return model def get_validator(self):