General trainer cleanup (#147)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2023-01-07 19:25:48 +05:30
committed by GitHub
parent f8a13c49a0
commit 0e5a7ae623
8 changed files with 196 additions and 60 deletions

View File

@ -282,6 +282,7 @@ class ClassificationModel(BaseModel):
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
LOGGER.info("WARNING: Deprecated in favor of attempt_load_one_weight()")
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
from ultralytics.yolo.utils.downloads import attempt_download
@ -321,6 +322,34 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
return model
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
# Loads a single model weights
from ultralytics.yolo.utils.downloads import attempt_download
ckpt = torch.load(attempt_download(weight), map_location='cpu') # load
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
# Model compatibility updates
model.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # attach args to model
model.pt_path = weight # attach *.pt file path to model
if not hasattr(model, 'stride'):
model.stride = torch.tensor([32.])
model = model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval() # model in eval mode
# Module compatibility updates
for m in model.modules():
t = type(m)
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
m.inplace = inplace # torch 1.7.0 compatibility
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
m.recompute_scale_factor = None # torch 1.11.0 compatibility
# Return model and ckpt
return model, ckpt
def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
# Parse a YOLOv5 model.yaml dictionary
if verbose:
@ -375,16 +404,3 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
ch = []
ch.append(c2)
return nn.Sequential(*layers), sorted(save)
def get_model(model='s.pt', pretrained=True):
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
if model.endswith(".pt"):
model = model.split(".")[0]
if Path(f"{model}.pt").is_file(): # local file
return attempt_load_weights(f"{model}.pt", device='cpu')
elif model in torchvision.models.__dict__: # TorchVision models i.e. resnet50, efficientnet_b0
return torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
else: # Ultralytics assets
return attempt_load_weights(f"{model}.pt", device='cpu')