|
|
|
@ -313,13 +313,39 @@ class ClassificationModel(BaseModel):
|
|
|
|
|
# Functions ------------------------------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def torch_safe_load(weight):
|
|
|
|
|
"""
|
|
|
|
|
This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it
|
|
|
|
|
catches the error, logs a warning message, and attempts to install the missing module via the check_requirements()
|
|
|
|
|
function. After installation, the function again attempts to load the model using torch.load().
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
weight (str): The file path of the PyTorch model.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
The loaded PyTorch model.
|
|
|
|
|
"""
|
|
|
|
|
from ultralytics.yolo.utils.downloads import attempt_download
|
|
|
|
|
|
|
|
|
|
file = attempt_download(weight) # search online if missing locally
|
|
|
|
|
try:
|
|
|
|
|
return torch.load(file, map_location='cpu') # load
|
|
|
|
|
except ModuleNotFoundError as e:
|
|
|
|
|
if e.name == 'omegaconf': # e.name is missing module name
|
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements."
|
|
|
|
|
f"\nAutoInstall will run now for {e.name} but this feature will be removed in the future."
|
|
|
|
|
f"\nRecommend fixes are to train a new model using updated ultraltyics package or to "
|
|
|
|
|
f"download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0")
|
|
|
|
|
check_requirements(e.name) # install missing module
|
|
|
|
|
return torch.load(file, map_location='cpu') # load
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|
|
|
|
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
|
|
|
|
from ultralytics.yolo.utils.downloads import attempt_download
|
|
|
|
|
|
|
|
|
|
model = Ensemble()
|
|
|
|
|
for w in weights if isinstance(weights, list) else [weights]:
|
|
|
|
|
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
|
|
|
|
ckpt = torch_safe_load(w) # load ckpt
|
|
|
|
|
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
|
|
|
|
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
|
|
|
|
|
|
|
|
@ -355,18 +381,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|
|
|
|
|
|
|
|
|
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
|
|
|
|
# Loads a single model weights
|
|
|
|
|
from ultralytics.yolo.utils.downloads import attempt_download
|
|
|
|
|
|
|
|
|
|
weight = attempt_download(weight)
|
|
|
|
|
try:
|
|
|
|
|
ckpt = torch.load(weight, map_location='cpu') # load
|
|
|
|
|
except ModuleNotFoundError:
|
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ {weight} is deprecated as it requires omegaconf, which is now removed from "
|
|
|
|
|
"ultralytics requirements.\nAutoInstall will occur now but this feature will be removed for "
|
|
|
|
|
"omegaconf models in the future.\nPlease train a new model or download updated models "
|
|
|
|
|
"from https://github.com/ultralytics/assets/releases/tag/v0.0.0")
|
|
|
|
|
check_requirements('omegaconf')
|
|
|
|
|
ckpt = torch.load(weight, map_location='cpu') # load
|
|
|
|
|
ckpt = torch_safe_load(weight) # load ckpt
|
|
|
|
|
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
|
|
|
|
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
|
|
|
|
|
|
|
|
|