|
|
|
@ -10,11 +10,13 @@ import torchvision
|
|
|
|
|
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
|
|
|
|
|
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
|
|
|
|
GhostBottleneck, GhostConv, Segment)
|
|
|
|
|
from ultralytics.yolo.utils import LOGGER, colorstr, yaml_load
|
|
|
|
|
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, yaml_load
|
|
|
|
|
from ultralytics.yolo.utils.checks import check_yaml
|
|
|
|
|
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
|
|
|
|
|
model_info, scale_img, time_sync)
|
|
|
|
|
|
|
|
|
|
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG, append_filename=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseModel(nn.Module):
|
|
|
|
|
'''
|
|
|
|
@ -211,7 +213,7 @@ class DetectionModel(BaseModel):
|
|
|
|
|
return y
|
|
|
|
|
|
|
|
|
|
def load(self, weights, verbose=True):
|
|
|
|
|
csd = weights['model'].float().state_dict() # checkpoint state_dict as FP32
|
|
|
|
|
csd = weights.float().state_dict() # checkpoint state_dict as FP32
|
|
|
|
|
csd = intersect_dicts(csd, self.state_dict()) # intersect
|
|
|
|
|
self.load_state_dict(csd, strict=False) # load
|
|
|
|
|
if verbose:
|
|
|
|
@ -281,21 +283,21 @@ class ClassificationModel(BaseModel):
|
|
|
|
|
# Functions ------------------------------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def attempt_load_weights(weights, device=None, inplace=True, fuse=True):
|
|
|
|
|
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|
|
|
|
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
|
|
|
|
from ultralytics.yolo.utils.downloads import attempt_download
|
|
|
|
|
default_keys = DEFAULT_CONFIG_DICT.keys()
|
|
|
|
|
|
|
|
|
|
model = Ensemble()
|
|
|
|
|
for w in weights if isinstance(weights, list) else [weights]:
|
|
|
|
|
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
|
|
|
|
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']}
|
|
|
|
|
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
|
|
|
|
|
|
|
|
|
# Model compatibility updates
|
|
|
|
|
if not hasattr(ckpt, 'stride'):
|
|
|
|
|
ckpt.stride = torch.tensor([32.])
|
|
|
|
|
if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
|
|
|
|
|
ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
|
|
|
|
|
ckpt.args = {k: v for k, v in args.items() if k in default_keys}
|
|
|
|
|
|
|
|
|
|
# Append
|
|
|
|
|
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
|
|
|
|
|
|
|
|
|
|
# Module compatibility updates
|
|
|
|
@ -310,7 +312,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=True):
|
|
|
|
|
if len(model) == 1:
|
|
|
|
|
return model[-1]
|
|
|
|
|
|
|
|
|
|
# Return detection ensemble
|
|
|
|
|
# Return ensemble
|
|
|
|
|
print(f'Ensemble created with {weights}\n')
|
|
|
|
|
for k in 'names', 'nc', 'yaml':
|
|
|
|
|
setattr(model, k, getattr(model[0], k))
|
|
|
|
|