ultralytics 8.0.132
add AutoBackend NCNN inference (#3615)
Co-authored-by: triple Mu <gpu@163.com>
@@ -3,6 +3,7 @@
 import ast
 import contextlib
 import json
+import os
 import platform
 import zipfile
 from collections import OrderedDict, namedtuple
@@ -15,7 +16,7 @@ import torch
 import torch.nn as nn
 from PIL import Image
 
-from ultralytics.yolo.utils import LINUX, LOGGER, ROOT, yaml_load
+from ultralytics.yolo.utils import ARM64, LINUX, LOGGER, ROOT, yaml_load
 from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version, check_yaml
 from ultralytics.yolo.utils.downloads import attempt_download_asset, is_url
 from ultralytics.yolo.utils.ops import xywh2xyxy
@@ -75,6 +76,7 @@ class AutoBackend(nn.Module):
             | TensorFlow Lite       | *.tflite          |
             | TensorFlow Edge TPU   | *_edgetpu.tflite  |
             | PaddlePaddle          | *_paddle_model    |
+            | ncnn                  | *_ncnn_model      |
         """
         super().__init__()
         w = str(weights[0] if isinstance(weights, list) else weights)
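With the table entry above, AutoBackend now maps a *_ncnn_model directory to the ncnn backend. A minimal end-to-end sketch, assuming NCNN export is already available in this release (file and image names are illustrative):

    from ultralytics import YOLO

    # Export a PyTorch checkpoint to an NCNN model directory; the export is
    # assumed to produce a 'yolov8n_ncnn_model/' folder with *.param / *.bin files.
    YOLO('yolov8n.pt').export(format='ncnn')

    # AutoBackend selects the ncnn backend from the '_ncnn_model' suffix,
    # so prediction works like any other exported format.
    model = YOLO('yolov8n_ncnn_model')
    results = model('bus.jpg')  # illustrative input image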
@@ -253,8 +255,19 @@ class AutoBackend(nn.Module):
             input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
             output_names = predictor.get_output_names()
             metadata = w.parents[1] / 'metadata.yaml'
-        elif ncnn:  # PaddlePaddle
-            raise NotImplementedError('YOLOv8 NCNN inference is not currently supported.')
+        elif ncnn:  # ncnn
+            LOGGER.info(f'Loading {w} for ncnn inference...')
+            check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn')  # requires NCNN
+            import ncnn as pyncnn
+            net = pyncnn.Net()
+            net.opt.num_threads = os.cpu_count()
+            net.opt.use_vulkan_compute = cuda
+            w = Path(w)
+            if not w.is_file():  # if not *.param
+                w = next(w.glob('*.param'))  # get *.param file from *_ncnn_model dir
+            net.load_param(str(w))
+            net.load_model(str(w.with_suffix('.bin')))
+            metadata = w.parent / 'metadata.yaml'
         elif triton:  # NVIDIA Triton Inference Server
             LOGGER.info('Triton Inference Server not supported...')
             '''
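The new loader branch is plain pyncnn, so the same steps can be reproduced standalone to verify that an exported model loads (the model path is illustrative; every API call mirrors the branch above):

    import os
    from pathlib import Path

    import ncnn  # the 'ncnn' pip package, imported as pyncnn in AutoBackend

    net = ncnn.Net()
    net.opt.num_threads = os.cpu_count()  # use all CPU cores
    net.opt.use_vulkan_compute = False    # True offloads to GPU via Vulkan

    w = Path('yolov8n_ncnn_model')        # illustrative export directory
    if not w.is_file():                   # a directory was given, not a *.param file
        w = next(w.glob('*.param'))       # locate the graph definition
    net.load_param(str(w))                # load network structure
    net.load_model(str(w.with_suffix('.bin')))  # load weights from the paired .bin

    print(net.input_names(), net.output_names())  # sanity-check the blob names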
@@ -358,6 +371,19 @@ class AutoBackend(nn.Module):
             self.input_handle.copy_from_cpu(im)
             self.predictor.run()
             y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
+        elif self.ncnn:  # ncnn
+            im = (im[0] * 255.).cpu().numpy().astype(np.uint8)
+            im = np.ascontiguousarray(im.transpose(1, 2, 0))
+            mat_in = self.pyncnn.Mat.from_pixels(im, self.pyncnn.Mat.PixelType.PIXEL_RGB, *im.shape[:2])
+            mat_in.substract_mean_normalize([], [1 / 255.0, 1 / 255.0, 1 / 255.0])
+            ex = self.net.create_extractor()
+            input_names, output_names = self.net.input_names(), self.net.output_names()
+            ex.input(input_names[0], mat_in)
+            y = []
+            for output_name in output_names:
+                mat_out = self.pyncnn.Mat()
+                ex.extract(output_name, mat_out)
+                y.append(np.array(mat_out)[None])
         elif self.triton:  # NVIDIA Triton Inference Server
             y = self.model(im)
         else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
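The preprocessing in the new ncnn branch is a round trip: AutoBackend receives an already-normalized float CHW tensor, de-normalizes it to uint8 HWC because Mat.from_pixels consumes 8-bit pixel buffers, then lets substract_mean_normalize (ncnn's spelling) rescale to 0..1 inside the Mat. One caveat worth noting: after the transpose, im.shape[:2] is (height, width), while from_pixels, if it follows the C++ argument order, expects width before height; the two only coincide for the usual square inference sizes. A NumPy sketch of the equivalent math:

    import numpy as np

    chw = np.random.rand(3, 640, 640).astype(np.float32)  # model input, 0..1, CHW
    hwc_u8 = np.ascontiguousarray((chw * 255.).astype(np.uint8).transpose(1, 2, 0))
    restored = hwc_u8.astype(np.float32) * (1 / 255.0)    # what ncnn sees after normalize

    # Identical up to uint8 quantization: max error is below one grey level.
    assert np.abs(restored - chw.transpose(1, 2, 0)).max() < 1 / 255.0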