Fix yolo mode=train CLI bug on model load (#133)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-01-02 19:24:44 +01:00
committed by GitHub
parent c3d961fb03
commit 8f3cd52844
8 changed files with 77 additions and 12 deletions

View File

@ -5,11 +5,10 @@ import time
import requests
from ultralytics.hub.config import HUB_API_ROOT
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, colorstr, emojis, yaml_load
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, LOGGER, RANK, SETTINGS, colorstr, emojis, yaml_load
PREFIX = colorstr('Ultralytics: ')
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG)
def check_dataset_disk_space(url='https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip', sf=2.0):

View File

@ -10,13 +10,11 @@ import torchvision
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
GhostBottleneck, GhostConv, Segment)
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, yaml_load
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER, colorstr, yaml_load
from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
model_info, scale_img, time_sync)
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG, append_filename=False)
class BaseModel(nn.Module):
'''
@ -286,16 +284,15 @@ class ClassificationModel(BaseModel):
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
from ultralytics.yolo.utils.downloads import attempt_download
default_keys = DEFAULT_CONFIG_DICT.keys()
model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']}
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
# Model compatibility updates
ckpt.args = {k: v for k, v in args.items() if k in default_keys}
ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS}
# Append
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode

View File

@ -362,7 +362,7 @@ class BaseTrainer:
return
# We should improve the code flow here. This function looks hacky
model = self.model
pretrained = not (str(model).endswith(".yaml"))
pretrained = not str(model).endswith(".yaml")
# config
if not pretrained:
model = check_file(model)

View File

@ -63,6 +63,11 @@ pd.options.display.max_columns = 10
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
# Default config dictionary
with open(DEFAULT_CONFIG, errors='ignore') as f:
DEFAULT_CONFIG_DICT = yaml.safe_load(f)
DEFAULT_CONFIG_KEYS = DEFAULT_CONFIG_DICT.keys()
def is_colab():
"""

View File

@ -16,7 +16,7 @@ import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
import ultralytics
from ultralytics.yolo.utils import LOGGER
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER
from ultralytics.yolo.utils.checks import git_describe
from .checks import check_version
@ -270,6 +270,7 @@ class ModelEMA:
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
# Strip optimizer from 'f' to finalize training, optionally save as 's'
x = torch.load(f, map_location=torch.device('cpu'))
args = {**DEFAULT_CONFIG_DICT, **x['train_args']} # combine model args with default args, preferring model args
if x.get('ema'):
x['model'] = x['ema'] # replace model with ema
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
@ -278,6 +279,7 @@ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_op
x['model'].half() # to FP16
for p in x['model'].parameters():
p.requires_grad = False
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # strip non-default keys
torch.save(x, s or f)
mb = os.path.getsize(s or f) / 1E6 # filesize
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")

View File

@ -54,9 +54,12 @@ class DetectionTrainer(BaseTrainer):
self.model.names = self.data["names"]
def load_model(self, model_cfg=None, weights=None, verbose=True):
model = DetectionModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
model = DetectionModel(model_cfg or getattr(weights, 'yaml', None) or weights['model'].yaml,
ch=3,
nc=self.data["nc"],
verbose=verbose)
if weights:
model.load(weights, verbose)
model.load(weights['model'] if isinstance(weights, dict) else weights, verbose)
return model
def get_validator(self):