Unified model loading with backwards compatibility (#132)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher
2023-01-02 17:37:23 +01:00
committed by GitHub
parent 8996c5c6cf
commit c3d961fb03
10 changed files with 65 additions and 50 deletions


@@ -164,8 +164,8 @@ class Exporter:
             assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'
 
         # Checks
-        if self.args.batch_size == 16:
-            self.args.batch_size = 1  # TODO: resolve batch_size 16 default in config.yaml
+        # if self.args.batch_size == model.args['batch_size']:  # user has not modified training batch_size
+        self.args.batch_size = 1
         self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2)  # check image size
         if self.args.optimize:
             assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
@@ -778,7 +778,7 @@ def export(cfg):
     if Path(cfg.model).suffix == '.yaml':
         model = DetectionModel(cfg.model)
     elif Path(cfg.model).suffix == '.pt':
-        model = attempt_load_weights(cfg.model)
+        model = attempt_load_weights(cfg.model, fuse=True)
     else:
         TypeError(f'Unsupported model type {cfg.model}')
     exporter(model=model)
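A minimal sketch of what the updated export path now does with a .pt checkpoint. The checkpoint name and import path are assumptions for illustration, not part of this diff:

from pathlib import Path
from ultralytics.nn.tasks import attempt_load_weights  # import path assumed

weights = 'yolov8n.pt'  # hypothetical checkpoint
if Path(weights).suffix == '.pt':
    # fuse=True folds BatchNorm into the preceding Conv layers before export,
    # mirroring the new attempt_load_weights(cfg.model, fuse=True) call above
    model = attempt_load_weights(weights, fuse=True)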


@@ -77,13 +77,12 @@ class YOLO:
         Args:
             weights (str): model checkpoint to be loaded
         """
-        self.ckpt = torch.load(weights, map_location="cpu")
-        self.task = self.ckpt["train_args"]["task"]
-        self.overrides = dict(self.ckpt["train_args"])
+        self.model = attempt_load_weights(weights)
+        self.task = self.model.args["task"]
+        self.overrides = self.model.args
         self.overrides["device"] = ''  # reset device
         self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
             self._guess_ops_from_task(self.task)
-        self.model = attempt_load_weights(weights, fuse=False)
 
     def reset(self):
         """
@@ -189,7 +188,7 @@ class YOLO:
             raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
 
         self.trainer = self.TrainerClass(overrides=overrides)
-        self.trainer.model = self.trainer.load_model(weights=self.ckpt,
+        self.trainer.model = self.trainer.load_model(weights=self.model,
                                                      model_cfg=self.model.yaml if self.task != "classify" else None)
         self.model = self.trainer.model  # override here to save memory
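A rough sketch of the unified flow these two hunks enable: the checkpoint is loaded once as an nn.Module, and the task and overrides are read from the training args attached to it, instead of from a separately torch.load-ed dict. The checkpoint name and import path are illustrative:

from ultralytics.nn.tasks import attempt_load_weights  # import path assumed

model = attempt_load_weights('yolov8n.pt')  # nn.Module with its training args restored on model.args
task = model.args['task']                   # e.g. 'detect'
overrides = model.args
overrides['device'] = ''                    # reset device, as in the diff above
# the same module object is later handed to the trainer via load_model(weights=self.model, ...)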


@@ -106,6 +106,9 @@ class BaseValidator:
                 data = check_dataset_yaml(self.args.data)
             else:
                 data = check_dataset(self.args.data)
 
+            if self.device.type == 'cpu':
+                self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
+
             self.dataloader = self.dataloader or \
                 self.get_dataloader(data.get("val") or data.set("test"), self.args.batch_size)
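A generic PyTorch illustration (not project code) of why workers=0 is preferred for CPU validation: the model forward pass dominates wall time, so spawning dataloader worker processes only adds startup and IPC overhead:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(8, 3, 640, 640))
# num_workers=0 keeps batch loading in the main process; on CPU the inference itself
# is the bottleneck, so extra dataloader workers rarely pay for themselves
loader = DataLoader(dataset, batch_size=4, num_workers=0)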


@@ -271,19 +271,20 @@ def yaml_save(file='data.yaml', data=None):
         yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
 
 
-def yaml_load(file='data.yaml'):
+def yaml_load(file='data.yaml', append_filename=True):
     """
     Load YAML data from a file.
 
     Args:
         file (str, optional): File name. Default is 'data.yaml'.
+        append_filename (bool): Add the YAML filename to the YAML dictionary. Default is True.
 
     Returns:
         dict: YAML data and file name.
     """
     with open(file, errors='ignore') as f:
         # Add YAML filename to dict and return
-        return {**yaml.safe_load(f), 'yaml_file': str(file)}
+        return {**yaml.safe_load(f), 'yaml_file': str(file)} if append_filename else yaml.safe_load(f)
 
 
 def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'):
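A usage sketch of the new flag; the file name and import path are illustrative:

from ultralytics.yolo.utils import yaml_load  # import path assumed

cfg = yaml_load('data.yaml')                            # dict plus 'yaml_file': 'data.yaml'
raw = yaml_load('data.yaml', append_filename=False)     # raw YAML contents only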


@@ -54,7 +54,7 @@ class DetectionTrainer(BaseTrainer):
         self.model.names = self.data["names"]
 
     def load_model(self, model_cfg=None, weights=None, verbose=True):
-        model = DetectionModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"], verbose=verbose)
+        model = DetectionModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
         if weights:
             model.load(weights, verbose)
         return model
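A hedged sketch of the new calling convention: weights is now the nn.Module returned by attempt_load_weights, whose .yaml attribute carries the model config, rather than a raw checkpoint dict indexed with weights["model"]. Names, import paths, and nc=80 are illustrative assumptions:

from ultralytics.nn.tasks import DetectionModel, attempt_load_weights  # import paths assumed

weights = attempt_load_weights('yolov8n.pt')            # nn.Module, not a checkpoint dict
model = DetectionModel(weights.yaml, ch=3, nc=80, verbose=True)  # nc=80 is an illustrative value
model.load(weights, True)                               # second positional arg mirrors the verbose flag above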


@@ -17,7 +17,7 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
 class SegmentationTrainer(v8.detect.DetectionTrainer):
 
     def load_model(self, model_cfg=None, weights=None, verbose=True):
-        model = SegmentationModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"], verbose=verbose)
+        model = SegmentationModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
         if weights:
             model.load(weights, verbose)
         return model
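The segmentation trainer follows the identical pattern, only swapping the model class; a short sketch with an assumed segmentation checkpoint name:

from ultralytics.nn.tasks import SegmentationModel, attempt_load_weights  # import paths assumed

weights = attempt_load_weights('yolov8n-seg.pt')        # hypothetical segmentation checkpoint
model = SegmentationModel(weights.yaml, ch=3, nc=80, verbose=True)  # nc illustrative
model.load(weights, True)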