Fix load and resume and update autodownload endpoint (#136)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
@ -82,7 +82,7 @@ class YOLO:
|
||||
self.ckpt_path = weights
|
||||
self.task = self.model.args["task"]
|
||||
self.overrides = self.model.args
|
||||
self.overrides["device"] = '' # reset device
|
||||
self._reset_ckpt_args(self.overrides)
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
||||
self._guess_ops_from_task(self.task)
|
||||
|
||||
@ -199,27 +199,6 @@ class YOLO:
|
||||
|
||||
self.trainer.train()
|
||||
|
||||
def resume(self, task=None, model=None):
|
||||
"""
|
||||
Resume a training task. Requires either `task` or `model`. `model` takes the higher precedence.
|
||||
Args:
|
||||
task (str): The task type you want to resume. Automatically finds the last run to resume if `model` is not specified.
|
||||
model (str): The model checkpoint to resume from. If not found, the last run of the given task type is resumed.
|
||||
If `model` is specified
|
||||
"""
|
||||
if task:
|
||||
if task.lower() not in MODEL_MAP:
|
||||
raise SyntaxError(f"unrecognised task - {task}. Supported tasks are {MODEL_MAP.keys()}")
|
||||
else:
|
||||
ckpt = torch.load(model, map_location="cpu")
|
||||
task = ckpt["train_args"]["task"]
|
||||
del ckpt
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task(
|
||||
task=task.lower())
|
||||
self.trainer = self.TrainerClass(overrides={"task": task.lower(), "resume": model or True})
|
||||
|
||||
self.trainer.train()
|
||||
|
||||
def to(self, device):
|
||||
self.model.to(device)
|
||||
|
||||
@ -240,3 +219,10 @@ class YOLO:
|
||||
|
||||
def forward(self, imgs):
|
||||
return self.__call__(imgs)
|
||||
|
||||
@staticmethod
|
||||
def _reset_ckpt_args(args):
|
||||
args.pop("device", None)
|
||||
args.pop("project", None)
|
||||
args.pop("name", None)
|
||||
args.pop("batch_size", None)
|
||||
|
@ -367,7 +367,7 @@ class BaseTrainer:
|
||||
if not pretrained:
|
||||
model = check_file(model)
|
||||
ckpt = self.load_ckpt(model) if pretrained else None
|
||||
self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt) # model
|
||||
self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt["model"]) # model
|
||||
return ckpt
|
||||
|
||||
def load_ckpt(self, ckpt):
|
||||
|
Reference in New Issue
Block a user