Fix yolo mode=train CLI bug on model load (#133)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-01-02 19:24:44 +01:00
committed by GitHub
parent c3d961fb03
commit 8f3cd52844
8 changed files with 77 additions and 12 deletions

View File

@ -63,6 +63,11 @@ pd.options.display.max_columns = 10
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
# Default config dictionary
with open(DEFAULT_CONFIG, errors='ignore') as f:
DEFAULT_CONFIG_DICT = yaml.safe_load(f)
DEFAULT_CONFIG_KEYS = DEFAULT_CONFIG_DICT.keys()
def is_colab():
"""

View File

@ -16,7 +16,7 @@ import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
import ultralytics
from ultralytics.yolo.utils import LOGGER
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER
from ultralytics.yolo.utils.checks import git_describe
from .checks import check_version
@ -270,6 +270,7 @@ class ModelEMA:
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
# Strip optimizer from 'f' to finalize training, optionally save as 's'
x = torch.load(f, map_location=torch.device('cpu'))
args = {**DEFAULT_CONFIG_DICT, **x['train_args']} # combine model args with default args, preferring model args
if x.get('ema'):
x['model'] = x['ema'] # replace model with ema
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
@ -278,6 +279,7 @@ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_op
x['model'].half() # to FP16
for p in x['model'].parameters():
p.requires_grad = False
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # strip non-default keys
torch.save(x, s or f)
mb = os.path.getsize(s or f) / 1E6 # filesize
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")