Fix load and resume and update autodownload endpoint (#136)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Ayush Chaurasia
2023-01-03 01:42:01 +05:30
committed by GitHub
parent 6d5123297e
commit 82c849c163
5 changed files with 16 additions and 34 deletions

View File

@ -82,7 +82,7 @@ class YOLO:
self.ckpt_path = weights
self.task = self.model.args["task"]
self.overrides = self.model.args
self.overrides["device"] = '' # reset device
self._reset_ckpt_args(self.overrides)
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
self._guess_ops_from_task(self.task)
@ -199,27 +199,6 @@ class YOLO:
self.trainer.train()
def resume(self, task=None, model=None):
"""
Resume a training task. Requires either `task` or `model`. `model` takes the higher precedence.
Args:
task (str): The task type you want to resume. Automatically finds the last run to resume if `model` is not specified.
model (str): The model checkpoint to resume from. If not found, the last run of the given task type is resumed.
If `model` is specified
"""
if task:
if task.lower() not in MODEL_MAP:
raise SyntaxError(f"unrecognised task - {task}. Supported tasks are {MODEL_MAP.keys()}")
else:
ckpt = torch.load(model, map_location="cpu")
task = ckpt["train_args"]["task"]
del ckpt
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task(
task=task.lower())
self.trainer = self.TrainerClass(overrides={"task": task.lower(), "resume": model or True})
self.trainer.train()
def to(self, device):
self.model.to(device)
@ -240,3 +219,10 @@ class YOLO:
def forward(self, imgs):
return self.__call__(imgs)
@staticmethod
def _reset_ckpt_args(args):
args.pop("device", None)
args.pop("project", None)
args.pop("name", None)
args.pop("batch_size", None)

View File

@ -367,7 +367,7 @@ class BaseTrainer:
if not pretrained:
model = check_file(model)
ckpt = self.load_ckpt(model) if pretrained else None
self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt) # model
self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt["model"]) # model
return ckpt
def load_ckpt(self, ckpt):