Fix yolo mode=train
CLI bug on model load (#133)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -10,13 +10,11 @@ import torchvision
|
||||
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
|
||||
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
||||
GhostBottleneck, GhostConv, Segment)
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, yaml_load
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER, colorstr, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_yaml
|
||||
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
|
||||
model_info, scale_img, time_sync)
|
||||
|
||||
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG, append_filename=False)
|
||||
|
||||
|
||||
class BaseModel(nn.Module):
|
||||
'''
|
||||
@ -286,16 +284,15 @@ class ClassificationModel(BaseModel):
|
||||
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
||||
from ultralytics.yolo.utils.downloads import attempt_download
|
||||
default_keys = DEFAULT_CONFIG_DICT.keys()
|
||||
|
||||
model = Ensemble()
|
||||
for w in weights if isinstance(weights, list) else [weights]:
|
||||
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
||||
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']}
|
||||
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||
|
||||
# Model compatibility updates
|
||||
ckpt.args = {k: v for k, v in args.items() if k in default_keys}
|
||||
ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS}
|
||||
|
||||
# Append
|
||||
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
|
||||
|
Reference in New Issue
Block a user