Smart Model loading (#31)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2022-11-01 04:22:12 +05:30
committed by GitHub
parent 1054819a59
commit 92c60758dd
4 changed files with 80 additions and 42 deletions

View File

@ -1,10 +1,10 @@
import contextlib
import torchvision
import yaml
from ultralytics.yolo.utils.downloads import attempt_download
from .modules import *
from ultralytics.yolo.utils.modeling.modules import *
def attempt_load_weights(weights, device=None, inplace=True, fuse=True):
@ -26,7 +26,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=True):
# Module compatibility updates
for m in model.modules():
t = type(m)
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect):
m.inplace = inplace # torch 1.7.0 compatibility
if t is Detect and not isinstance(m.anchor_grid, list):
delattr(m, 'anchor_grid')
@ -107,6 +107,20 @@ def parse_model(d, ch): # model_dict, input_channels(3)
return nn.Sequential(*layers), sorted(save)
def get_model(model: str):
if model.endswith(".pt"):
model = model.split(".")[0]
if Path(model + ".pt").is_file():
trained_model = torch.load(model + ".pt", map_location='cpu')
elif model in torchvision.models.__dict__: # try torch hub classifier models
trained_model = torch.hub.load("pytorch/vision", model, pretrained=True)
else:
model_ckpt = attempt_download(model + ".pt") # try ultralytics assets
trained_model = torch.load(model_ckpt, map_location='cpu')
return trained_model
def yaml_load(file='data.yaml'):
# Single-line safe yaml loading
with open(file, errors='ignore') as f: