Global settings typechecking (#148)

single_channel
Glenn Jocher 2 years ago committed by GitHub
parent 19334ebb16
commit 3cbf3ec455
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -48,7 +48,6 @@ class YOLO:
self.ckpt_path = None
self.cfg = None # if loaded from *.yaml
self.overrides = {} # overrides for trainer object
self.init_disabled = False # disable model initialization
# Load or create new YOLO model
{'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model)

@ -365,8 +365,15 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'):
yaml_save(file, defaults)
settings = yaml_load(file)
if settings.keys() != defaults.keys():
settings = {**defaults, **settings} # merge **defaults with **settings (prefer **settings)
# Check that settings keys and types match defaults
correct = settings.keys() == defaults.keys() and \
all(type(a) == type(b) for a, b in zip(settings.values(), defaults.values()))
if not correct:
LOGGER.warning('WARNING ⚠️ Different global settings detected, resetting to defaults. '
'This may be due to an ultralytics package update. '
f'View and update your global settings directly in {file}')
settings = defaults # merge **defaults with **settings (prefer **settings)
yaml_save(file, settings) # save updated defaults
return settings

@ -268,8 +268,23 @@ class ModelEMA:
copy_attr(self.ema, model, include, exclude)
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
# Strip optimizer from 'f' to finalize training, optionally save as 's'
def strip_optimizer(f='best.pt', s=''):
"""
Strip optimizer from 'f' to finalize training, optionally save as 's'.
Usage:
from ultralytics.yolo.utils.torch_utils import strip_optimizer
from pathlib import Path
for f in Path('/Users/glennjocher/Downloads/weights').glob('*.pt'):
strip_optimizer(f)
Args:
f (str): file path to model state to strip the optimizer from. Default is 'best.pt'.
s (str): file path to save the model with stripped optimizer to. Default is ''. If not provided, the original file will be overwritten.
Returns:
None
"""
x = torch.load(f, map_location=torch.device('cpu'))
args = {**DEFAULT_CONFIG_DICT, **x['train_args']} # combine model args with default args, preferring model args
if x.get('ema'):

Loading…
Cancel
Save