Smart Model loading (#31)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -1,10 +1,10 @@
|
||||
import contextlib
|
||||
|
||||
import torchvision
|
||||
import yaml
|
||||
|
||||
from ultralytics.yolo.utils.downloads import attempt_download
|
||||
|
||||
from .modules import *
|
||||
from ultralytics.yolo.utils.modeling.modules import *
|
||||
|
||||
|
||||
def attempt_load_weights(weights, device=None, inplace=True, fuse=True):
|
||||
@ -26,7 +26,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=True):
|
||||
# Module compatibility updates
|
||||
for m in model.modules():
|
||||
t = type(m)
|
||||
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
|
||||
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect):
|
||||
m.inplace = inplace # torch 1.7.0 compatibility
|
||||
if t is Detect and not isinstance(m.anchor_grid, list):
|
||||
delattr(m, 'anchor_grid')
|
||||
@ -107,6 +107,20 @@ def parse_model(d, ch): # model_dict, input_channels(3)
|
||||
return nn.Sequential(*layers), sorted(save)
|
||||
|
||||
|
||||
def get_model(model: str):
|
||||
if model.endswith(".pt"):
|
||||
model = model.split(".")[0]
|
||||
|
||||
if Path(model + ".pt").is_file():
|
||||
trained_model = torch.load(model + ".pt", map_location='cpu')
|
||||
elif model in torchvision.models.__dict__: # try torch hub classifier models
|
||||
trained_model = torch.hub.load("pytorch/vision", model, pretrained=True)
|
||||
else:
|
||||
model_ckpt = attempt_download(model + ".pt") # try ultralytics assets
|
||||
trained_model = torch.load(model_ckpt, map_location='cpu')
|
||||
return trained_model
|
||||
|
||||
|
||||
def yaml_load(file='data.yaml'):
|
||||
# Single-line safe yaml loading
|
||||
with open(file, errors='ignore') as f:
|
||||
|
Reference in New Issue
Block a user