Add max_dim==2 argument to check_imgsz() (#789)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: andreaswimmer <53872150+andreaswimmer@users.noreply.github.com>
Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com>
This commit is contained in:
Glenn Jocher
2023-02-04 01:48:44 +04:00
committed by GitHub
parent 5a80ad98db
commit 0d182e80f1
11 changed files with 96 additions and 52 deletions

View File

@ -18,6 +18,16 @@ from ultralytics.yolo.utils.downloads import attempt_download_asset, is_url
from ultralytics.yolo.utils.ops import xywh2xyxy
def check_class_names(names):
# Check class names. Map imagenet class codes to human-readable names if required. Convert lists to dicts.
if isinstance(names, list): # names is a list
names = dict(enumerate(names)) # convert to dict
if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764'
map = yaml_load(ROOT / 'yolo/data/datasets/ImageNet.yaml')['map'] # human-readable names
names = {k: map[v] for k, v in names.items()}
return names
class AutoBackend(nn.Module):
def __init__(self, weights='yolov8n.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
@ -228,11 +238,7 @@ class AutoBackend(nn.Module):
# class names
if 'names' not in locals(): # names missing
names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)} # assign default
elif isinstance(names, list): # names is a list
names = dict(enumerate(names)) # convert to dict
if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764'
map = yaml_load(ROOT / 'yolo/data/datasets/ImageNet.yaml')['map'] # human-readable names
names = {k: map[v] for k, v in names.items()}
names = check_class_names(names)
self.__dict__.update(locals()) # assign all variables to self

View File

@ -347,23 +347,24 @@ def torch_safe_load(weight):
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
model = Ensemble()
ensemble = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch_safe_load(w) # load ckpt
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
# Model compatibility updates
ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
ckpt.pt_path = weights # attach *.pt file path to model
if not hasattr(ckpt, 'stride'):
ckpt.stride = torch.tensor([32.])
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
model.pt_path = weights # attach *.pt file path to model
model.task = guess_model_task(model)
if not hasattr(model, 'stride'):
model.stride = torch.tensor([32.])
# Append
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
ensemble.append(model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval()) # model in eval mode
# Module compatibility updates
for m in model.modules():
for m in ensemble.modules():
t = type(m)
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
m.inplace = inplace # torch 1.7.0 compatibility
@ -371,16 +372,16 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
m.recompute_scale_factor = None # torch 1.11.0 compatibility
# Return model
if len(model) == 1:
return model[-1]
if len(ensemble) == 1:
return ensemble[-1]
# Return ensemble
print(f'Ensemble created with {weights}\n')
for k in 'names', 'nc', 'yaml':
setattr(model, k, getattr(model[0], k))
model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
return model
setattr(ensemble, k, getattr(ensemble[0], k))
ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride
assert all(ensemble[0].nc == m.nc for m in ensemble), f'Models differ in class counts: {[m.nc for m in ensemble]}'
return ensemble
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
@ -392,6 +393,7 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
# Model compatibility updates
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
model.pt_path = weight # attach *.pt file path to model
model.task = guess_model_task(model)
if not hasattr(model, 'stride'):
model.stride = torch.tensor([32.])