Classify training cleanup (#33)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2022-11-07 00:15:57 +01:00
committed by GitHub
parent 2e9b18ce4e
commit 6fe8bead35
4 changed files with 29 additions and 31 deletions

View File

@ -107,18 +107,17 @@ def parse_model(d, ch): # model_dict, input_channels(3)
return nn.Sequential(*layers), sorted(save)
def get_model(model: str):
def get_model(model='s.pt', pretrained=True):
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
if model.endswith(".pt"):
model = model.split(".")[0]
if Path(model + ".pt").is_file():
trained_model = torch.load(model + ".pt", map_location='cpu')
elif model in torchvision.models.__dict__: # try torch hub classifier models
trained_model = torch.hub.load("pytorch/vision", model, pretrained=True)
else:
model_ckpt = attempt_download(model + ".pt") # try ultralytics assets
trained_model = torch.load(model_ckpt, map_location='cpu')
return trained_model
if Path(f"{model}.pt").is_file(): # local file
return torch.load(f"{model}.pt", map_location='cpu')
elif model in torchvision.models.__dict__: # TorchVision models i.e. resnet50, efficientnet_b0
return torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
else: # Ultralytics assets
return torch.load(attempt_download(f"{model}.pt"), map_location='cpu')
def yaml_load(file='data.yaml'):