General trainer cleanup (#147)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
@ -282,6 +282,7 @@ class ClassificationModel(BaseModel):
|
||||
|
||||
|
||||
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||
LOGGER.info("WARNING: Deprecated in favor of attempt_load_one_weight()")
|
||||
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
||||
from ultralytics.yolo.utils.downloads import attempt_download
|
||||
|
||||
@ -321,6 +322,34 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||
return model
|
||||
|
||||
|
||||
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
||||
# Loads a single model weights
|
||||
from ultralytics.yolo.utils.downloads import attempt_download
|
||||
|
||||
ckpt = torch.load(attempt_download(weight), map_location='cpu') # load
|
||||
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||
|
||||
# Model compatibility updates
|
||||
model.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # attach args to model
|
||||
model.pt_path = weight # attach *.pt file path to model
|
||||
if not hasattr(model, 'stride'):
|
||||
model.stride = torch.tensor([32.])
|
||||
|
||||
model = model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval() # model in eval mode
|
||||
|
||||
# Module compatibility updates
|
||||
for m in model.modules():
|
||||
t = type(m)
|
||||
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
|
||||
m.inplace = inplace # torch 1.7.0 compatibility
|
||||
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
||||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||
|
||||
# Return model and ckpt
|
||||
return model, ckpt
|
||||
|
||||
|
||||
def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||
# Parse a YOLOv5 model.yaml dictionary
|
||||
if verbose:
|
||||
@ -375,16 +404,3 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||
ch = []
|
||||
ch.append(c2)
|
||||
return nn.Sequential(*layers), sorted(save)
|
||||
|
||||
|
||||
def get_model(model='s.pt', pretrained=True):
|
||||
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
||||
if model.endswith(".pt"):
|
||||
model = model.split(".")[0]
|
||||
|
||||
if Path(f"{model}.pt").is_file(): # local file
|
||||
return attempt_load_weights(f"{model}.pt", device='cpu')
|
||||
elif model in torchvision.models.__dict__: # TorchVision models i.e. resnet50, efficientnet_b0
|
||||
return torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
|
||||
else: # Ultralytics assets
|
||||
return attempt_load_weights(f"{model}.pt", device='cpu')
|
||||
|
Reference in New Issue
Block a user