Fix load and resume and update autodownload endpoint (#136)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent 6d5123297e
commit 82c849c163
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -82,7 +82,7 @@ class YOLO:
self.ckpt_path = weights self.ckpt_path = weights
self.task = self.model.args["task"] self.task = self.model.args["task"]
self.overrides = self.model.args self.overrides = self.model.args
self.overrides["device"] = '' # reset device self._reset_ckpt_args(self.overrides)
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \ self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
self._guess_ops_from_task(self.task) self._guess_ops_from_task(self.task)
@ -199,27 +199,6 @@ class YOLO:
self.trainer.train() self.trainer.train()
def resume(self, task=None, model=None):
"""
Resume a training task. Requires either `task` or `model`. `model` takes the higher precedence.
Args:
task (str): The task type you want to resume. Automatically finds the last run to resume if `model` is not specified.
model (str): The model checkpoint to resume from. If not found, the last run of the given task type is resumed.
If `model` is specified
"""
if task:
if task.lower() not in MODEL_MAP:
raise SyntaxError(f"unrecognised task - {task}. Supported tasks are {MODEL_MAP.keys()}")
else:
ckpt = torch.load(model, map_location="cpu")
task = ckpt["train_args"]["task"]
del ckpt
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task(
task=task.lower())
self.trainer = self.TrainerClass(overrides={"task": task.lower(), "resume": model or True})
self.trainer.train()
def to(self, device): def to(self, device):
self.model.to(device) self.model.to(device)
@ -240,3 +219,10 @@ class YOLO:
def forward(self, imgs): def forward(self, imgs):
return self.__call__(imgs) return self.__call__(imgs)
@staticmethod
def _reset_ckpt_args(args):
args.pop("device", None)
args.pop("project", None)
args.pop("name", None)
args.pop("batch_size", None)

@ -367,7 +367,7 @@ class BaseTrainer:
if not pretrained: if not pretrained:
model = check_file(model) model = check_file(model)
ckpt = self.load_ckpt(model) if pretrained else None ckpt = self.load_ckpt(model) if pretrained else None
self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt) # model self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt["model"]) # model
return ckpt return ckpt
def load_ckpt(self, ckpt): def load_ckpt(self, ckpt):

@ -45,11 +45,12 @@ def is_url(url, check=True):
return False return False
def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'): def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc. # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
def github_assets(repository, version='latest'): def github_assets(repository, version='latest'):
# Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov5m.pt', ...]) # Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov5m.pt', ...])
# Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...])
if version != 'latest': if version != 'latest':
version = f'tags/{version}' # i.e. tags/v6.2 version = f'tags/{version}' # i.e. tags/v6.2
response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api
@ -70,6 +71,7 @@ def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
# GitHub assets # GitHub assets
assets = [f'yolov5{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] # default assets = [f'yolov5{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] # default
assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] # default
try: try:
tag, assets = github_assets(repo, release) tag, assets = github_assets(repo, release)
except Exception: except Exception:

@ -54,12 +54,9 @@ class DetectionTrainer(BaseTrainer):
self.model.names = self.data["names"] self.model.names = self.data["names"]
def load_model(self, model_cfg=None, weights=None, verbose=True): def load_model(self, model_cfg=None, weights=None, verbose=True):
model = DetectionModel(model_cfg or getattr(weights, 'yaml', None) or weights['model'].yaml, model = DetectionModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
ch=3,
nc=self.data["nc"],
verbose=verbose)
if weights: if weights:
model.load(weights['model'] if isinstance(weights, dict) else weights, verbose) model.load(weights, verbose)
return model return model
def get_validator(self): def get_validator(self):

@ -17,12 +17,9 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
class SegmentationTrainer(v8.detect.DetectionTrainer): class SegmentationTrainer(v8.detect.DetectionTrainer):
def load_model(self, model_cfg=None, weights=None, verbose=True): def load_model(self, model_cfg=None, weights=None, verbose=True):
model = SegmentationModel(model_cfg or getattr(weights, 'yaml', None) or weights['model'].yaml, model = SegmentationModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
ch=3,
nc=self.data["nc"],
verbose=verbose)
if weights: if weights:
model.load(weights['model'] if isinstance(weights, dict) else weights, verbose) model.load(weights, verbose)
return model return model
def get_validator(self): def get_validator(self):

Loading…
Cancel
Save