From 82c849c16374f4b84f0f29f52b186f3d332bd5d1 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Tue, 3 Jan 2023 01:42:01 +0530 Subject: [PATCH] Fix load and resume and update autodownload endpoint (#136) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher --- ultralytics/yolo/engine/model.py | 30 ++++++++-------------------- ultralytics/yolo/engine/trainer.py | 2 +- ultralytics/yolo/utils/downloads.py | 4 +++- ultralytics/yolo/v8/detect/train.py | 7 ++----- ultralytics/yolo/v8/segment/train.py | 7 ++----- 5 files changed, 16 insertions(+), 34 deletions(-) diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index 9c89e9b..31ce376 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -82,7 +82,7 @@ class YOLO: self.ckpt_path = weights self.task = self.model.args["task"] self.overrides = self.model.args - self.overrides["device"] = '' # reset device + self._reset_ckpt_args(self.overrides) self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \ self._guess_ops_from_task(self.task) @@ -199,27 +199,6 @@ class YOLO: self.trainer.train() - def resume(self, task=None, model=None): - """ - Resume a training task. Requires either `task` or `model`. `model` takes the higher precedence. - Args: - task (str): The task type you want to resume. Automatically finds the last run to resume if `model` is not specified. - model (str): The model checkpoint to resume from. If not found, the last run of the given task type is resumed. - If `model` is specified - """ - if task: - if task.lower() not in MODEL_MAP: - raise SyntaxError(f"unrecognised task - {task}. Supported tasks are {MODEL_MAP.keys()}") - else: - ckpt = torch.load(model, map_location="cpu") - task = ckpt["train_args"]["task"] - del ckpt - self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task( - task=task.lower()) - self.trainer = self.TrainerClass(overrides={"task": task.lower(), "resume": model or True}) - - self.trainer.train() - def to(self, device): self.model.to(device) @@ -240,3 +219,10 @@ class YOLO: def forward(self, imgs): return self.__call__(imgs) + + @staticmethod + def _reset_ckpt_args(args): + args.pop("device", None) + args.pop("project", None) + args.pop("name", None) + args.pop("batch_size", None) diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index 372d328..f0f6a80 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -367,7 +367,7 @@ class BaseTrainer: if not pretrained: model = check_file(model) ckpt = self.load_ckpt(model) if pretrained else None - self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt) # model + self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt["model"]) # model return ckpt def load_ckpt(self, ckpt): diff --git a/ultralytics/yolo/utils/downloads.py b/ultralytics/yolo/utils/downloads.py index 6eb0806..61f273b 100644 --- a/ultralytics/yolo/utils/downloads.py +++ b/ultralytics/yolo/utils/downloads.py @@ -45,11 +45,12 @@ def is_url(url, check=True): return False -def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'): +def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'): # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc. def github_assets(repository, version='latest'): # Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov5m.pt', ...]) + # Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...]) if version != 'latest': version = f'tags/{version}' # i.e. tags/v6.2 response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api @@ -70,6 +71,7 @@ def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'): # GitHub assets assets = [f'yolov5{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] # default + assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] # default try: tag, assets = github_assets(repo, release) except Exception: diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py index 0671b4e..db760d5 100644 --- a/ultralytics/yolo/v8/detect/train.py +++ b/ultralytics/yolo/v8/detect/train.py @@ -54,12 +54,9 @@ class DetectionTrainer(BaseTrainer): self.model.names = self.data["names"] def load_model(self, model_cfg=None, weights=None, verbose=True): - model = DetectionModel(model_cfg or getattr(weights, 'yaml', None) or weights['model'].yaml, - ch=3, - nc=self.data["nc"], - verbose=verbose) + model = DetectionModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose) if weights: - model.load(weights['model'] if isinstance(weights, dict) else weights, verbose) + model.load(weights, verbose) return model def get_validator(self): diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py index 487114e..f8f146d 100644 --- a/ultralytics/yolo/v8/segment/train.py +++ b/ultralytics/yolo/v8/segment/train.py @@ -17,12 +17,9 @@ from ultralytics.yolo.utils.torch_utils import de_parallel class SegmentationTrainer(v8.detect.DetectionTrainer): def load_model(self, model_cfg=None, weights=None, verbose=True): - model = SegmentationModel(model_cfg or getattr(weights, 'yaml', None) or weights['model'].yaml, - ch=3, - nc=self.data["nc"], - verbose=verbose) + model = SegmentationModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose) if weights: - model.load(weights['model'] if isinstance(weights, dict) else weights, verbose) + model.load(weights, verbose) return model def get_validator(self):