ultralytics 8.0.37 add TFLite metadata in AutoBackend (#953)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: Aarni Koskela <akx@iki.fi>
This commit is contained in:
Glenn Jocher
2023-02-14 14:28:23 +04:00
committed by GitHub
parent 20fe708f31
commit bdc6cd4d8b
18 changed files with 86 additions and 46 deletions

View File

@ -1,7 +1,9 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import ast
import contextlib
import json
import platform
import zipfile
from collections import OrderedDict, namedtuple
from pathlib import Path
from urllib.parse import urlparse
@ -207,6 +209,12 @@ class AutoBackend(nn.Module):
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
# load metadata
with contextlib.suppress(zipfile.BadZipFile):
with zipfile.ZipFile(w, "r") as model:
meta_file = model.namelist()[0]
meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
stride, names = int(meta['stride']), meta['names']
elif tfjs: # TF.js
raise NotImplementedError('ERROR: YOLOv8 TF.js inference is not supported')
elif paddle: # PaddlePaddle
@ -214,7 +222,7 @@ class AutoBackend(nn.Module):
check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
import paddle.inference as pdi
if not Path(w).is_file(): # if not *.pdmodel
w = next(Path(w).rglob('*.pdmodel')) # get *.xml file from *_openvino_model dir
w = next(Path(w).rglob('*.pdmodel')) # get *.pdmodel file from *_paddle_model dir
weights = Path(w).with_suffix('.pdiparams')
config = pdi.Config(str(w), str(weights))
if cuda:
@ -328,6 +336,9 @@ class AutoBackend(nn.Module):
scale, zero_point = output['quantization']
x = (x.astype(np.float32) - zero_point) * scale # re-scale
y.append(x)
# TF segment fixes: export is reversed vs ONNX export and protos are transposed
if len(self.output_details) == 2: # segment
y = [y[1], np.transpose(y[0], (0, 3, 1, 2))]
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels

View File

@ -1,5 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import ast
import contextlib
from copy import deepcopy
from pathlib import Path
@ -427,6 +428,8 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
m = eval(m) if isinstance(m, str) else m # eval strings
for j, a in enumerate(args):
# TODO: re-implement with eval() removal if possible
# args[j] = (locals()[a] if a in locals() else ast.literal_eval(a)) if isinstance(a, str) else a
with contextlib.suppress(NameError):
args[j] = eval(a) if isinstance(a, str) else a # eval strings
@ -480,28 +483,9 @@ def guess_model_task(model):
Raises:
SyntaxError: If the task of the model could not be determined.
"""
cfg = None
if isinstance(model, dict):
cfg = model
elif isinstance(model, nn.Module): # PyTorch model
for x in 'model.args', 'model.model.args', 'model.model.model.args':
with contextlib.suppress(Exception):
return eval(x)['task']
for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
with contextlib.suppress(Exception):
cfg = eval(x)
break
elif isinstance(model, (str, Path)):
model = str(model)
if '-seg' in model:
return "segment"
elif '-cls' in model:
return "classify"
else:
return "detect"
# Guess from YAML dictionary
if cfg:
def cfg2task(cfg):
# Guess from YAML dictionary
m = cfg["head"][-1][-2].lower() # output module name
if m in ["classify", "classifier", "cls", "fc"]:
return "classify"
@ -510,8 +494,20 @@ def guess_model_task(model):
if m in ["segment"]:
return "segment"
# Guess from model cfg
if isinstance(model, dict):
with contextlib.suppress(Exception):
return cfg2task(model)
# Guess from PyTorch model
if isinstance(model, nn.Module):
if isinstance(model, nn.Module): # PyTorch model
for x in 'model.args', 'model.model.args', 'model.model.model.args':
with contextlib.suppress(Exception):
return eval(x)['task']
for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
with contextlib.suppress(Exception):
return cfg2task(eval(x))
for m in model.modules():
if isinstance(m, Detect):
return "detect"
@ -520,6 +516,16 @@ def guess_model_task(model):
elif isinstance(m, Classify):
return "classify"
# Guess from model filename
if isinstance(model, (str, Path)):
model = Path(model).stem
if '-seg' in model:
return "segment"
elif '-cls' in model:
return "classify"
else:
return "detect"
# Unable to determine task from model
raise SyntaxError("YOLO is unable to automatically guess model task. Explicitly define task for your model, "
"i.e. 'task=detect', 'task=segment' or 'task=classify'.")