Unified model loading with backwards compatibility (#132)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-01-02 17:37:23 +01:00
committed by GitHub
parent 8996c5c6cf
commit c3d961fb03
10 changed files with 65 additions and 50 deletions

View File

@ -124,7 +124,7 @@ def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method="post
return func(*args, **kwargs)
def sync_analytics(cfg, all_keys=False, enabled=True):
def sync_analytics(cfg, all_keys=False, enabled=False):
"""
Sync analytics data if enabled in the global settings

View File

@ -10,11 +10,13 @@ import torchvision
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
GhostBottleneck, GhostConv, Segment)
from ultralytics.yolo.utils import LOGGER, colorstr, yaml_load
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, yaml_load
from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
model_info, scale_img, time_sync)
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG, append_filename=False)
class BaseModel(nn.Module):
'''
@ -211,7 +213,7 @@ class DetectionModel(BaseModel):
return y
def load(self, weights, verbose=True):
csd = weights['model'].float().state_dict() # checkpoint state_dict as FP32
csd = weights.float().state_dict() # checkpoint state_dict as FP32
csd = intersect_dicts(csd, self.state_dict()) # intersect
self.load_state_dict(csd, strict=False) # load
if verbose:
@ -281,21 +283,21 @@ class ClassificationModel(BaseModel):
# Functions ------------------------------------------------------------------------------------------------------------
def attempt_load_weights(weights, device=None, inplace=True, fuse=True):
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
from ultralytics.yolo.utils.downloads import attempt_download
default_keys = DEFAULT_CONFIG_DICT.keys()
model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']}
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
# Model compatibility updates
if not hasattr(ckpt, 'stride'):
ckpt.stride = torch.tensor([32.])
if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
ckpt.args = {k: v for k, v in args.items() if k in default_keys}
# Append
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
# Module compatibility updates
@ -310,7 +312,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=True):
if len(model) == 1:
return model[-1]
# Return detection ensemble
# Return ensemble
print(f'Ensemble created with {weights}\n')
for k in 'names', 'nc', 'yaml':
setattr(model, k, getattr(model[0], k))

View File

@ -164,8 +164,8 @@ class Exporter:
assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'
# Checks
if self.args.batch_size == 16:
self.args.batch_size = 1 # TODO: resolve batch_size 16 default in config.yaml
# if self.args.batch_size == model.args['batch_size']: # user has not modified training batch_size
self.args.batch_size = 1
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
if self.args.optimize:
assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
@ -778,7 +778,7 @@ def export(cfg):
if Path(cfg.model).suffix == '.yaml':
model = DetectionModel(cfg.model)
elif Path(cfg.model).suffix == '.pt':
model = attempt_load_weights(cfg.model)
model = attempt_load_weights(cfg.model, fuse=True)
else:
TypeError(f'Unsupported model type {cfg.model}')
exporter(model=model)

View File

@ -77,13 +77,12 @@ class YOLO:
Args:
weights (str): model checkpoint to be loaded
"""
self.ckpt = torch.load(weights, map_location="cpu")
self.task = self.ckpt["train_args"]["task"]
self.overrides = dict(self.ckpt["train_args"])
self.model = attempt_load_weights(weights)
self.task = self.model.args["task"]
self.overrides = self.model.args
self.overrides["device"] = '' # reset device
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
self._guess_ops_from_task(self.task)
self.model = attempt_load_weights(weights, fuse=False)
def reset(self):
"""
@ -189,7 +188,7 @@ class YOLO:
raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
self.trainer = self.TrainerClass(overrides=overrides)
self.trainer.model = self.trainer.load_model(weights=self.ckpt,
self.trainer.model = self.trainer.load_model(weights=self.model,
model_cfg=self.model.yaml if self.task != "classify" else None)
self.model = self.trainer.model # override here to save memory

View File

@ -106,6 +106,9 @@ class BaseValidator:
data = check_dataset_yaml(self.args.data)
else:
data = check_dataset(self.args.data)
if self.device.type == 'cpu':
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
self.dataloader = self.dataloader or \
self.get_dataloader(data.get("val") or data.set("test"), self.args.batch_size)

View File

@ -271,19 +271,20 @@ def yaml_save(file='data.yaml', data=None):
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
def yaml_load(file='data.yaml'):
def yaml_load(file='data.yaml', append_filename=True):
"""
Load YAML data from a file.
Args:
file (str, optional): File name. Default is 'data.yaml'.
append_filename (bool): Add the YAML filename to the YAML dictionary. Default is True.
Returns:
dict: YAML data and file name.
"""
with open(file, errors='ignore') as f:
# Add YAML filename to dict and return
return {**yaml.safe_load(f), 'yaml_file': str(file)}
return {**yaml.safe_load(f), 'yaml_file': str(file)} if append_filename else yaml.safe_load(f)
def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'):

View File

@ -54,7 +54,7 @@ class DetectionTrainer(BaseTrainer):
self.model.names = self.data["names"]
def load_model(self, model_cfg=None, weights=None, verbose=True):
model = DetectionModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"], verbose=verbose)
model = DetectionModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
if weights:
model.load(weights, verbose)
return model

View File

@ -17,7 +17,7 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
class SegmentationTrainer(v8.detect.DetectionTrainer):
def load_model(self, model_cfg=None, weights=None, verbose=True):
model = SegmentationModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"], verbose=verbose)
model = SegmentationModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
if weights:
model.load(weights, verbose)
return model