Fix yolo mode=train CLI bug on model load (#133)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-01-02 19:24:44 +01:00
committed by GitHub
parent c3d961fb03
commit 8f3cd52844
8 changed files with 77 additions and 12 deletions

View File

@ -10,13 +10,11 @@ import torchvision
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
GhostBottleneck, GhostConv, Segment)
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, yaml_load
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER, colorstr, yaml_load
from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
model_info, scale_img, time_sync)
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG, append_filename=False)
class BaseModel(nn.Module):
'''
@ -286,16 +284,15 @@ class ClassificationModel(BaseModel):
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
from ultralytics.yolo.utils.downloads import attempt_download
default_keys = DEFAULT_CONFIG_DICT.keys()
model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']}
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
# Model compatibility updates
ckpt.args = {k: v for k, v in args.items() if k in default_keys}
ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS}
# Append
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode