Unified model loading with backwards compatibility (#132)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent 8996c5c6cf
commit c3d961fb03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -42,9 +42,9 @@ Ultralytics YOLO comes with pythonic Model and Trainer interface.
import ultralytics import ultralytics
from ultralytics import YOLO from ultralytics import YOLO
model = YOLO("s-seg.yaml") # automatically detects task type model = YOLO("yolov8n-seg.yaml") # automatically detects task type
model = YOLO("s-seg.pt") # load checkpoint model = YOLO("yolov8n.pt") # load checkpoint
model.train(data="coco128-segments", epochs=1, lr0=0.01, ...) model.train(data="coco128-seg.yaml", epochs=1, lr0=0.01, ...)
model.train(data="coco128-segments", epochs=1, lr0=0.01, device="0,1,2,3") # DDP mode model.train(data="coco128-seg.yaml", epochs=1, lr0=0.01, device="0,1,2,3") # DDP mode
``` ```
[API Guide](sdk.md){ .md-button .md-button--primary} [API Guide](sdk.md){ .md-button .md-button--primary}

@ -1,11 +1,7 @@
import torch import torch
from ultralytics import YOLO from ultralytics import YOLO
from ultralytics.yolo.utils import ROOT
def test_model_init():
model = YOLO("yolov8n.yaml")
model.info()
def test_model_forward(): def test_model_forward():
@ -29,9 +25,9 @@ def test_model_fuse():
model.fuse() model.fuse()
def test_visualize_preds(): def test_predict_dir():
model = YOLO("yolov8n.pt") model = YOLO("yolov8n.pt")
model.predict(source="ultralytics/assets") model.predict(source=ROOT / "assets")
def test_val(): def test_val():
@ -39,7 +35,7 @@ def test_val():
model.val(data="coco128.yaml", imgsz=32) model.val(data="coco128.yaml", imgsz=32)
def test_model_resume(): def test_train_resume():
model = YOLO("yolov8n.yaml") model = YOLO("yolov8n.yaml")
model.train(epochs=1, imgsz=32, data="coco128.yaml") model.train(epochs=1, imgsz=32, data="coco128.yaml")
try: try:
@ -48,16 +44,21 @@ def test_model_resume():
print("Successfully caught resume assert!") print("Successfully caught resume assert!")
def test_model_train_pretrained(): def test_train_scratch():
model = YOLO("yolov8n.pt")
model.train(data="coco128.yaml", epochs=1, imgsz=32)
model = YOLO("yolov8n.yaml") model = YOLO("yolov8n.yaml")
model.train(data="coco128.yaml", epochs=1, imgsz=32) model.train(data="coco128.yaml", epochs=1, imgsz=32)
img = torch.rand(1, 3, 320, 320) img = torch.rand(1, 3, 320, 320)
model(img) model(img)
def test_exports(): def test_train_pretrained():
model = YOLO("yolov8n.pt")
model.train(data="coco128.yaml", epochs=1, imgsz=32)
img = torch.rand(1, 3, 320, 320)
model(img)
def test_export_torchscript():
""" """
Format Argument Suffix CPU GPU Format Argument Suffix CPU GPU
0 PyTorch - .pt True True 0 PyTorch - .pt True True
@ -74,26 +75,35 @@ def test_exports():
11 PaddlePaddle paddle _paddle_model True True 11 PaddlePaddle paddle _paddle_model True True
""" """
from ultralytics.yolo.engine.exporter import export_formats from ultralytics.yolo.engine.exporter import export_formats
print(export_formats()) print(export_formats())
model = YOLO("yolov8n.yaml") model = YOLO("yolov8n.yaml")
model.export(format='torchscript') model.export(format='torchscript')
def test_export_onnx():
model = YOLO("yolov8n.yaml")
model.export(format='onnx') model.export(format='onnx')
def test_export_openvino():
model = YOLO("yolov8n.yaml")
model.export(format='openvino') model.export(format='openvino')
def test_export_coreml():
model = YOLO("yolov8n.yaml")
model.export(format='coreml') model.export(format='coreml')
model.export(format='paddle')
def test(): def test_export_paddle():
test_model_forward() model = YOLO("yolov8n.yaml")
test_model_info() model.export(format='paddle')
test_model_fuse()
test_visualize_preds()
test_val()
test_model_resume()
test_model_train_pretrained()
if __name__ == "__main__": # def run_all_tests(): # do not name function test_...
test() # pass
#
#
# if __name__ == "__main__":
# run_all_tests()

@ -124,7 +124,7 @@ def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method="post
return func(*args, **kwargs) return func(*args, **kwargs)
def sync_analytics(cfg, all_keys=False, enabled=True): def sync_analytics(cfg, all_keys=False, enabled=False):
""" """
Sync analytics data if enabled in the global settings Sync analytics data if enabled in the global settings

@ -10,11 +10,13 @@ import torchvision
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify, from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus, Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
GhostBottleneck, GhostConv, Segment) GhostBottleneck, GhostConv, Segment)
from ultralytics.yolo.utils import LOGGER, colorstr, yaml_load from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, yaml_load
from ultralytics.yolo.utils.checks import check_yaml from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible, from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
model_info, scale_img, time_sync) model_info, scale_img, time_sync)
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG, append_filename=False)
class BaseModel(nn.Module): class BaseModel(nn.Module):
''' '''
@ -211,7 +213,7 @@ class DetectionModel(BaseModel):
return y return y
def load(self, weights, verbose=True): def load(self, weights, verbose=True):
csd = weights['model'].float().state_dict() # checkpoint state_dict as FP32 csd = weights.float().state_dict() # checkpoint state_dict as FP32
csd = intersect_dicts(csd, self.state_dict()) # intersect csd = intersect_dicts(csd, self.state_dict()) # intersect
self.load_state_dict(csd, strict=False) # load self.load_state_dict(csd, strict=False) # load
if verbose: if verbose:
@ -281,21 +283,21 @@ class ClassificationModel(BaseModel):
# Functions ------------------------------------------------------------------------------------------------------------ # Functions ------------------------------------------------------------------------------------------------------------
def attempt_load_weights(weights, device=None, inplace=True, fuse=True): def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
from ultralytics.yolo.utils.downloads import attempt_download from ultralytics.yolo.utils.downloads import attempt_download
default_keys = DEFAULT_CONFIG_DICT.keys()
model = Ensemble() model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]: for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch.load(attempt_download(w), map_location='cpu') # load ckpt = torch.load(attempt_download(w), map_location='cpu') # load
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']}
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
# Model compatibility updates # Model compatibility updates
if not hasattr(ckpt, 'stride'): ckpt.args = {k: v for k, v in args.items() if k in default_keys}
ckpt.stride = torch.tensor([32.])
if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
# Append
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
# Module compatibility updates # Module compatibility updates
@ -310,7 +312,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=True):
if len(model) == 1: if len(model) == 1:
return model[-1] return model[-1]
# Return detection ensemble # Return ensemble
print(f'Ensemble created with {weights}\n') print(f'Ensemble created with {weights}\n')
for k in 'names', 'nc', 'yaml': for k in 'names', 'nc', 'yaml':
setattr(model, k, getattr(model[0], k)) setattr(model, k, getattr(model[0], k))

@ -164,8 +164,8 @@ class Exporter:
assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic' assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'
# Checks # Checks
if self.args.batch_size == 16: # if self.args.batch_size == model.args['batch_size']: # user has not modified training batch_size
self.args.batch_size = 1 # TODO: resolve batch_size 16 default in config.yaml self.args.batch_size = 1
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
if self.args.optimize: if self.args.optimize:
assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu' assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
@ -778,7 +778,7 @@ def export(cfg):
if Path(cfg.model).suffix == '.yaml': if Path(cfg.model).suffix == '.yaml':
model = DetectionModel(cfg.model) model = DetectionModel(cfg.model)
elif Path(cfg.model).suffix == '.pt': elif Path(cfg.model).suffix == '.pt':
model = attempt_load_weights(cfg.model) model = attempt_load_weights(cfg.model, fuse=True)
else: else:
TypeError(f'Unsupported model type {cfg.model}') TypeError(f'Unsupported model type {cfg.model}')
exporter(model=model) exporter(model=model)

@ -77,13 +77,12 @@ class YOLO:
Args: Args:
weights (str): model checkpoint to be loaded weights (str): model checkpoint to be loaded
""" """
self.ckpt = torch.load(weights, map_location="cpu") self.model = attempt_load_weights(weights)
self.task = self.ckpt["train_args"]["task"] self.task = self.model.args["task"]
self.overrides = dict(self.ckpt["train_args"]) self.overrides = self.model.args
self.overrides["device"] = '' # reset device self.overrides["device"] = '' # reset device
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \ self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
self._guess_ops_from_task(self.task) self._guess_ops_from_task(self.task)
self.model = attempt_load_weights(weights, fuse=False)
def reset(self): def reset(self):
""" """
@ -189,7 +188,7 @@ class YOLO:
raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.") raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
self.trainer = self.TrainerClass(overrides=overrides) self.trainer = self.TrainerClass(overrides=overrides)
self.trainer.model = self.trainer.load_model(weights=self.ckpt, self.trainer.model = self.trainer.load_model(weights=self.model,
model_cfg=self.model.yaml if self.task != "classify" else None) model_cfg=self.model.yaml if self.task != "classify" else None)
self.model = self.trainer.model # override here to save memory self.model = self.trainer.model # override here to save memory

@ -106,6 +106,9 @@ class BaseValidator:
data = check_dataset_yaml(self.args.data) data = check_dataset_yaml(self.args.data)
else: else:
data = check_dataset(self.args.data) data = check_dataset(self.args.data)
if self.device.type == 'cpu':
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
self.dataloader = self.dataloader or \ self.dataloader = self.dataloader or \
self.get_dataloader(data.get("val") or data.set("test"), self.args.batch_size) self.get_dataloader(data.get("val") or data.set("test"), self.args.batch_size)

@ -271,19 +271,20 @@ def yaml_save(file='data.yaml', data=None):
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False) yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
def yaml_load(file='data.yaml'): def yaml_load(file='data.yaml', append_filename=True):
""" """
Load YAML data from a file. Load YAML data from a file.
Args: Args:
file (str, optional): File name. Default is 'data.yaml'. file (str, optional): File name. Default is 'data.yaml'.
append_filename (bool): Add the YAML filename to the YAML dictionary. Default is True.
Returns: Returns:
dict: YAML data and file name. dict: YAML data and file name.
""" """
with open(file, errors='ignore') as f: with open(file, errors='ignore') as f:
# Add YAML filename to dict and return # Add YAML filename to dict and return
return {**yaml.safe_load(f), 'yaml_file': str(file)} return {**yaml.safe_load(f), 'yaml_file': str(file)} if append_filename else yaml.safe_load(f)
def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'): def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'):

@ -54,7 +54,7 @@ class DetectionTrainer(BaseTrainer):
self.model.names = self.data["names"] self.model.names = self.data["names"]
def load_model(self, model_cfg=None, weights=None, verbose=True): def load_model(self, model_cfg=None, weights=None, verbose=True):
model = DetectionModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"], verbose=verbose) model = DetectionModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
if weights: if weights:
model.load(weights, verbose) model.load(weights, verbose)
return model return model

@ -17,7 +17,7 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
class SegmentationTrainer(v8.detect.DetectionTrainer): class SegmentationTrainer(v8.detect.DetectionTrainer):
def load_model(self, model_cfg=None, weights=None, verbose=True): def load_model(self, model_cfg=None, weights=None, verbose=True):
model = SegmentationModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"], verbose=verbose) model = SegmentationModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
if weights: if weights:
model.load(weights, verbose) model.load(weights, verbose)
return model return model

Loading…
Cancel
Save