Fix yolo mode=train
CLI bug on model load (#133)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -63,6 +63,11 @@ pd.options.display.max_columns = 10
|
||||
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
|
||||
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
|
||||
|
||||
# Default config dictionary
|
||||
with open(DEFAULT_CONFIG, errors='ignore') as f:
|
||||
DEFAULT_CONFIG_DICT = yaml.safe_load(f)
|
||||
DEFAULT_CONFIG_KEYS = DEFAULT_CONFIG_DICT.keys()
|
||||
|
||||
|
||||
def is_colab():
|
||||
"""
|
||||
|
@ -16,7 +16,7 @@ import torch.nn.functional as F
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
import ultralytics
|
||||
from ultralytics.yolo.utils import LOGGER
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER
|
||||
from ultralytics.yolo.utils.checks import git_describe
|
||||
|
||||
from .checks import check_version
|
||||
@ -270,6 +270,7 @@ class ModelEMA:
|
||||
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
|
||||
# Strip optimizer from 'f' to finalize training, optionally save as 's'
|
||||
x = torch.load(f, map_location=torch.device('cpu'))
|
||||
args = {**DEFAULT_CONFIG_DICT, **x['train_args']} # combine model args with default args, preferring model args
|
||||
if x.get('ema'):
|
||||
x['model'] = x['ema'] # replace model with ema
|
||||
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
|
||||
@ -278,6 +279,7 @@ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_op
|
||||
x['model'].half() # to FP16
|
||||
for p in x['model'].parameters():
|
||||
p.requires_grad = False
|
||||
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # strip non-default keys
|
||||
torch.save(x, s or f)
|
||||
mb = os.path.getsize(s or f) / 1E6 # filesize
|
||||
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
||||
|
Reference in New Issue
Block a user