ultralytics 8.0.37
add TFLite metadata in AutoBackend (#953)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: Aarni Koskela <akx@iki.fi>
ultralytics/nn/autobackend.py
@@ -1,7 +1,9 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license

+import ast
 import contextlib
 import json
 import platform
+import zipfile
 from collections import OrderedDict, namedtuple
 from pathlib import Path
 from urllib.parse import urlparse
@@ -207,6 +209,12 @@ class AutoBackend(nn.Module):
             interpreter.allocate_tensors()  # allocate
             input_details = interpreter.get_input_details()  # inputs
             output_details = interpreter.get_output_details()  # outputs
+            # load metadata
+            with contextlib.suppress(zipfile.BadZipFile):
+                with zipfile.ZipFile(w, "r") as model:
+                    meta_file = model.namelist()[0]
+                    meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
+                    stride, names = int(meta['stride']), meta['names']
         elif tfjs:  # TF.js
             raise NotImplementedError('ERROR: YOLOv8 TF.js inference is not supported')
         elif paddle:  # PaddlePaddle
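For context: the added block works because the Ultralytics TFLite exporter zips a metadata file into the .tflite flatbuffer, so the weights file can also be opened as a zip archive. A minimal standalone sketch of the same read path (the function name and example path are illustrative, not part of this diff):

    import ast
    import contextlib
    import zipfile

    def read_tflite_metadata(model_path):
        """Best-effort read of the dict-style metadata zipped into a .tflite file."""
        meta = {}
        # If the file carries no zip payload, BadZipFile is raised and meta stays empty.
        with contextlib.suppress(zipfile.BadZipFile, IndexError):
            with zipfile.ZipFile(model_path, "r") as zf:
                name = zf.namelist()[0]  # the exporter stores a single metadata member
                meta = ast.literal_eval(zf.read(name).decode("utf-8"))
        return meta

    # Usage (path is hypothetical):
    # meta = read_tflite_metadata("yolov8n_float32.tflite")
    # stride, names = int(meta["stride"]), meta["names"]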
@@ -214,7 +222,7 @@ class AutoBackend(nn.Module):
             check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
             import paddle.inference as pdi
             if not Path(w).is_file():  # if not *.pdmodel
-                w = next(Path(w).rglob('*.pdmodel'))  # get *.xml file from *_openvino_model dir
+                w = next(Path(w).rglob('*.pdmodel'))  # get *.pdmodel file from *_paddle_model dir
             weights = Path(w).with_suffix('.pdiparams')
             config = pdi.Config(str(w), str(weights))
             if cuda:
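The corrected comment describes the *_paddle_model export layout. A small sketch of that path-resolution convention, assuming the export directory contains one *.pdmodel file with its *.pdiparams alongside (the helper name is illustrative):

    from pathlib import Path

    def resolve_paddle_files(w):
        """Return (model_file, params_file) from a *.pdmodel path or a *_paddle_model dir."""
        p = Path(w)
        if not p.is_file():  # a directory such as 'yolov8n_paddle_model/' was passed
            p = next(p.rglob('*.pdmodel'))  # first *.pdmodel found under the export dir
        return p, p.with_suffix('.pdiparams')

    # m, params = resolve_paddle_files('yolov8n_paddle_model')  # hypothetical export dir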
@@ -328,6 +336,9 @@ class AutoBackend(nn.Module):
                     scale, zero_point = output['quantization']
                     x = (x.astype(np.float32) - zero_point) * scale  # re-scale
                 y.append(x)
+            # TF segment fixes: export is reversed vs ONNX export and protos are transposed
+            if len(self.output_details) == 2:  # segment
+                y = [y[1], np.transpose(y[0], (0, 3, 1, 2))]
             y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
             y[0][..., :4] *= [w, h, w, h]  # xywh normalized to pixels
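The re-scale line above is standard affine dequantization. A tiny worked example (the scale and zero-point values are illustrative, not taken from a real model):

    import numpy as np

    def dequantize(x, scale, zero_point):
        """Affine INT8 -> float32 mapping: real = (quantized - zero_point) * scale."""
        return (x.astype(np.float32) - zero_point) * scale

    q = np.array([-128, 0, 127], dtype=np.int8)
    print(dequantize(q, 0.00392, -128))  # [0. 0.50176 0.9996]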
ultralytics/nn/tasks.py
@@ -1,5 +1,6 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license

+import ast
 import contextlib
 from copy import deepcopy
 from pathlib import Path
@@ -427,6 +428,8 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
     for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
         m = eval(m) if isinstance(m, str) else m  # eval strings
         for j, a in enumerate(args):
+            # TODO: re-implement with eval() removal if possible
+            # args[j] = (locals()[a] if a in locals() else ast.literal_eval(a)) if isinstance(a, str) else a
             with contextlib.suppress(NameError):
                 args[j] = eval(a) if isinstance(a, str) else a  # eval strings
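The new TODO sketches an eval()-free variant. One possible reading of that commented line, with the lookup scope passed explicitly instead of relying on locals() (the helper and scope dict are assumptions, not code from this diff):

    import ast

    def parse_arg(a, scope):
        """Resolve a string arg as a known name first, else as a Python literal."""
        if not isinstance(a, str):
            return a
        if a in scope:  # e.g. 'nc' resolved from the caller's namespace
            return scope[a]
        try:
            return ast.literal_eval(a)  # e.g. 'None', '0.25', '[1, 2]'
        except (ValueError, SyntaxError):
            return a  # leave non-literal strings (e.g. 'nearest') unchanged

    print([parse_arg(a, {'nc': 80}) for a in ['nc', '0.25', 'None', 'nearest']])
    # [80, 0.25, None, 'nearest']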
@@ -480,28 +483,9 @@ def guess_model_task(model):
     Raises:
         SyntaxError: If the task of the model could not be determined.
     """
-    cfg = None
-    if isinstance(model, dict):
-        cfg = model
-    elif isinstance(model, nn.Module):  # PyTorch model
-        for x in 'model.args', 'model.model.args', 'model.model.model.args':
-            with contextlib.suppress(Exception):
-                return eval(x)['task']
-        for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
-            with contextlib.suppress(Exception):
-                cfg = eval(x)
-                break
-    elif isinstance(model, (str, Path)):
-        model = str(model)
-        if '-seg' in model:
-            return "segment"
-        elif '-cls' in model:
-            return "classify"
-        else:
-            return "detect"
-
-    # Guess from YAML dictionary
-    if cfg:
+
+    def cfg2task(cfg):
+        # Guess from YAML dictionary
         m = cfg["head"][-1][-2].lower()  # output module name
         if m in ["classify", "classifier", "cls", "fc"]:
             return "classify"
@@ -510,8 +494,20 @@ def guess_model_task(model):
         if m in ["segment"]:
             return "segment"

-    if isinstance(model, nn.Module):
+    # Guess from model cfg
+    if isinstance(model, dict):
+        with contextlib.suppress(Exception):
+            return cfg2task(model)
+
+    # Guess from PyTorch model
+    if isinstance(model, nn.Module):  # PyTorch model
+        for x in 'model.args', 'model.model.args', 'model.model.model.args':
+            with contextlib.suppress(Exception):
+                return eval(x)['task']
+        for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
+            with contextlib.suppress(Exception):
+                return cfg2task(eval(x))
+
         for m in model.modules():
             if isinstance(m, Detect):
                 return "detect"
@@ -520,6 +516,16 @@ def guess_model_task(model):
         elif isinstance(m, Classify):
             return "classify"

+    # Guess from model filename
+    if isinstance(model, (str, Path)):
+        model = Path(model).stem
+        if '-seg' in model:
+            return "segment"
+        elif '-cls' in model:
+            return "classify"
+        else:
+            return "detect"
+
+    # Unable to determine task from model
     raise SyntaxError("YOLO is unable to automatically guess model task. Explicitly define task for your model, "
                       "i.e. 'task=detect', 'task=segment' or 'task=classify'.")