ultralytics 8.0.132 add AutoBackend NCNN inference (#3615)

Co-authored-by: triple Mu <gpu@163.com>
This commit is contained in:
Glenn Jocher
2023-07-10 16:49:51 +02:00
committed by GitHub
parent 391b7e67cf
commit 495edc261f
9 changed files with 39 additions and 12 deletions

View File

@ -3,6 +3,7 @@
import ast
import contextlib
import json
import os
import platform
import zipfile
from collections import OrderedDict, namedtuple
@ -15,7 +16,7 @@ import torch
import torch.nn as nn
from PIL import Image
from ultralytics.yolo.utils import LINUX, LOGGER, ROOT, yaml_load
from ultralytics.yolo.utils import ARM64, LINUX, LOGGER, ROOT, yaml_load
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version, check_yaml
from ultralytics.yolo.utils.downloads import attempt_download_asset, is_url
from ultralytics.yolo.utils.ops import xywh2xyxy
@ -75,6 +76,7 @@ class AutoBackend(nn.Module):
| TensorFlow Lite | *.tflite |
| TensorFlow Edge TPU | *_edgetpu.tflite |
| PaddlePaddle | *_paddle_model |
| ncnn | *_ncnn_model |
"""
super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights)
@ -253,8 +255,19 @@ class AutoBackend(nn.Module):
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
output_names = predictor.get_output_names()
metadata = w.parents[1] / 'metadata.yaml'
elif ncnn: # PaddlePaddle
raise NotImplementedError('YOLOv8 NCNN inference is not currently supported.')
elif ncnn: # ncnn
LOGGER.info(f'Loading {w} for ncnn inference...')
check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn') # requires NCNN
import ncnn as pyncnn
net = pyncnn.Net()
net.opt.num_threads = os.cpu_count()
net.opt.use_vulkan_compute = cuda
w = Path(w)
if not w.is_file(): # if not *.param
w = next(w.glob('*.param')) # get *.param file from *_ncnn_model dir
net.load_param(str(w))
net.load_model(str(w.with_suffix('.bin')))
metadata = w.parent / 'metadata.yaml'
elif triton: # NVIDIA Triton Inference Server
LOGGER.info('Triton Inference Server not supported...')
'''
@ -358,6 +371,19 @@ class AutoBackend(nn.Module):
self.input_handle.copy_from_cpu(im)
self.predictor.run()
y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
elif self.ncnn: # ncnn
im = (im[0] * 255.).cpu().numpy().astype(np.uint8)
im = np.ascontiguousarray(im.transpose(1, 2, 0))
mat_in = self.pyncnn.Mat.from_pixels(im, self.pyncnn.Mat.PixelType.PIXEL_RGB, *im.shape[:2])
mat_in.substract_mean_normalize([], [1 / 255.0, 1 / 255.0, 1 / 255.0])
ex = self.net.create_extractor()
input_names, output_names = self.net.input_names(), self.net.output_names()
ex.input(input_names[0], mat_in)
y = []
for output_name in output_names:
mat_out = self.pyncnn.Mat()
ex.extract(output_name, mat_out)
y.append(np.array(mat_out)[None])
elif self.triton: # NVIDIA Triton Inference Server
y = self.model(im)
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)