Unified model loading with backwards compatibility (#132)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-01-02 17:37:23 +01:00
committed by GitHub
parent 8996c5c6cf
commit c3d961fb03
10 changed files with 65 additions and 50 deletions

View File

@ -10,11 +10,13 @@ import torchvision
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
GhostBottleneck, GhostConv, Segment)
from ultralytics.yolo.utils import LOGGER, colorstr, yaml_load
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, yaml_load
from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
model_info, scale_img, time_sync)
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG, append_filename=False)
class BaseModel(nn.Module):
'''
@ -211,7 +213,7 @@ class DetectionModel(BaseModel):
return y
def load(self, weights, verbose=True):
csd = weights['model'].float().state_dict() # checkpoint state_dict as FP32
csd = weights.float().state_dict() # checkpoint state_dict as FP32
csd = intersect_dicts(csd, self.state_dict()) # intersect
self.load_state_dict(csd, strict=False) # load
if verbose:
@ -281,21 +283,21 @@ class ClassificationModel(BaseModel):
# Functions ------------------------------------------------------------------------------------------------------------
def attempt_load_weights(weights, device=None, inplace=True, fuse=True):
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
from ultralytics.yolo.utils.downloads import attempt_download
default_keys = DEFAULT_CONFIG_DICT.keys()
model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']}
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
# Model compatibility updates
if not hasattr(ckpt, 'stride'):
ckpt.stride = torch.tensor([32.])
if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
ckpt.args = {k: v for k, v in args.items() if k in default_keys}
# Append
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
# Module compatibility updates
@ -310,7 +312,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=True):
if len(model) == 1:
return model[-1]
# Return detection ensemble
# Return ensemble
print(f'Ensemble created with {weights}\n')
for k in 'names', 'nc', 'yaml':
setattr(model, k, getattr(model[0], k))