Fix `yolo mode=train` CLI bug on model load (#133)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent c3d961fb03
commit 8f3cd52844
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,59 @@
import os
from ultralytics.yolo.utils import ROOT
def test_checks():
os.system('yolo mode=checks')
# Train checks ---------------------------------------------------------------------------------------------------------
def test_train_detect():
os.system('yolo mode=train task=detect model=yolov8n.yaml data=coco128.yaml imgsz=32 epochs=1')
def test_train_segment():
os.system('yolo mode=train task=segment model=yolov8n-seg.yaml data=coco128-seg.yaml imgsz=32 epochs=1')
def test_train_classify():
pass
# Val checks -----------------------------------------------------------------------------------------------------------
def test_val_detect():
os.system('yolo mode=val task=detect model=yolov8n.pt data=coco128.yaml imgsz=32 epochs=1')
def test_val_segment():
pass
def test_val_classify():
pass
# Predict checks -------------------------------------------------------------------------------------------------------
def test_predict_detect():
os.system(f"yolo mode=predict model=yolov8n.pt source={ROOT / 'assets'}")
def test_predict_segment():
pass
def test_predict_classify():
pass
# Export checks --------------------------------------------------------------------------------------------------------
def test_export_detect_torchscript():
os.system('yolo mode=export model=yolov8n.pt format=torchscript')
def test_export_segment_torchscript():
pass
def test_export_classify_torchscript():
pass

@ -5,11 +5,10 @@ import time
import requests
from ultralytics.hub.config import HUB_API_ROOT
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, colorstr, emojis, yaml_load
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, LOGGER, RANK, SETTINGS, colorstr, emojis, yaml_load
PREFIX = colorstr('Ultralytics: ')
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG)
def check_dataset_disk_space(url='https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip', sf=2.0):

@ -10,13 +10,11 @@ import torchvision
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
GhostBottleneck, GhostConv, Segment)
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, yaml_load
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER, colorstr, yaml_load
from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
model_info, scale_img, time_sync)
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG, append_filename=False)
class BaseModel(nn.Module):
'''
@ -286,16 +284,15 @@ class ClassificationModel(BaseModel):
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
from ultralytics.yolo.utils.downloads import attempt_download
default_keys = DEFAULT_CONFIG_DICT.keys()
model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']}
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
# Model compatibility updates
ckpt.args = {k: v for k, v in args.items() if k in default_keys}
ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS}
# Append
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode

@ -362,7 +362,7 @@ class BaseTrainer:
return
# We should improve the code flow here. This function looks hacky
model = self.model
pretrained = not (str(model).endswith(".yaml"))
pretrained = not str(model).endswith(".yaml")
# config
if not pretrained:
model = check_file(model)

@ -63,6 +63,11 @@ pd.options.display.max_columns = 10
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
# Default config dictionary
with open(DEFAULT_CONFIG, errors='ignore') as f:
DEFAULT_CONFIG_DICT = yaml.safe_load(f)
DEFAULT_CONFIG_KEYS = DEFAULT_CONFIG_DICT.keys()
def is_colab():
"""

@ -16,7 +16,7 @@ import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
import ultralytics
from ultralytics.yolo.utils import LOGGER
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER
from ultralytics.yolo.utils.checks import git_describe
from .checks import check_version
@ -270,6 +270,7 @@ class ModelEMA:
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
# Strip optimizer from 'f' to finalize training, optionally save as 's'
x = torch.load(f, map_location=torch.device('cpu'))
args = {**DEFAULT_CONFIG_DICT, **x['train_args']} # combine model args with default args, preferring model args
if x.get('ema'):
x['model'] = x['ema'] # replace model with ema
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
@ -278,6 +279,7 @@ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_op
x['model'].half() # to FP16
for p in x['model'].parameters():
p.requires_grad = False
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # strip non-default keys
torch.save(x, s or f)
mb = os.path.getsize(s or f) / 1E6 # filesize
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")

@ -54,9 +54,12 @@ class DetectionTrainer(BaseTrainer):
self.model.names = self.data["names"]
def load_model(self, model_cfg=None, weights=None, verbose=True):
model = DetectionModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
model = DetectionModel(model_cfg or getattr(weights, 'yaml', None) or weights['model'].yaml,
ch=3,
nc=self.data["nc"],
verbose=verbose)
if weights:
model.load(weights, verbose)
model.load(weights['model'] if isinstance(weights, dict) else weights, verbose)
return model
def get_validator(self):

Loading…
Cancel
Save