|
|
@ -82,7 +82,7 @@ class YOLO:
|
|
|
|
self.ckpt_path = weights
|
|
|
|
self.ckpt_path = weights
|
|
|
|
self.task = self.model.args["task"]
|
|
|
|
self.task = self.model.args["task"]
|
|
|
|
self.overrides = self.model.args
|
|
|
|
self.overrides = self.model.args
|
|
|
|
self.overrides["device"] = '' # reset device
|
|
|
|
self._reset_ckpt_args(self.overrides)
|
|
|
|
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
|
|
|
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
|
|
|
self._guess_ops_from_task(self.task)
|
|
|
|
self._guess_ops_from_task(self.task)
|
|
|
|
|
|
|
|
|
|
|
@ -199,27 +199,6 @@ class YOLO:
|
|
|
|
|
|
|
|
|
|
|
|
self.trainer.train()
|
|
|
|
self.trainer.train()
|
|
|
|
|
|
|
|
|
|
|
|
def resume(self, task=None, model=None):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Resume a training task. Requires either `task` or `model`. `model` takes the higher precedence.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
task (str): The task type you want to resume. Automatically finds the last run to resume if `model` is not specified.
|
|
|
|
|
|
|
|
model (str): The model checkpoint to resume from. If not found, the last run of the given task type is resumed.
|
|
|
|
|
|
|
|
If `model` is specified
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
if task:
|
|
|
|
|
|
|
|
if task.lower() not in MODEL_MAP:
|
|
|
|
|
|
|
|
raise SyntaxError(f"unrecognised task - {task}. Supported tasks are {MODEL_MAP.keys()}")
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
ckpt = torch.load(model, map_location="cpu")
|
|
|
|
|
|
|
|
task = ckpt["train_args"]["task"]
|
|
|
|
|
|
|
|
del ckpt
|
|
|
|
|
|
|
|
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task(
|
|
|
|
|
|
|
|
task=task.lower())
|
|
|
|
|
|
|
|
self.trainer = self.TrainerClass(overrides={"task": task.lower(), "resume": model or True})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.trainer.train()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def to(self, device):
|
|
|
|
def to(self, device):
|
|
|
|
self.model.to(device)
|
|
|
|
self.model.to(device)
|
|
|
|
|
|
|
|
|
|
|
@ -240,3 +219,10 @@ class YOLO:
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, imgs):
|
|
|
|
def forward(self, imgs):
|
|
|
|
return self.__call__(imgs)
|
|
|
|
return self.__call__(imgs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
|
|
def _reset_ckpt_args(args):
|
|
|
|
|
|
|
|
args.pop("device", None)
|
|
|
|
|
|
|
|
args.pop("project", None)
|
|
|
|
|
|
|
|
args.pop("name", None)
|
|
|
|
|
|
|
|
args.pop("batch_size", None)
|
|
|
|