Avoid CUDA round-trip for relevant export formats (#3727)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
commit 135a10f1fa (parent c5991d7cd8)
Author: Glenn Jocher
Date:   2023-07-14 20:38:31 +02:00
Committed by: GitHub
5 changed files with 40 additions and 32 deletions

@@ -83,16 +83,23 @@ class AutoBackend(nn.Module):
         nn_module = isinstance(weights, torch.nn.Module)
         pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, triton = \
             self._model_type(w)
-        fp16 &= pt or jit or onnx or engine or nn_module or triton  # FP16
+        fp16 &= pt or jit or onnx or xml or engine or nn_module or triton  # FP16
         nhwc = coreml or saved_model or pb or tflite or edgetpu  # BHWC formats (vs torch BCWH)
         stride = 32  # default stride
         model, metadata = None, None
-        cuda = torch.cuda.is_available() and device.type != 'cpu'  # use CUDA
-        if not (pt or triton or nn_module):
-            w = attempt_download_asset(w)  # download if not local
-        # NOTE: special case: in-memory pytorch model
-        if nn_module:
+        # Set device
+        cuda = torch.cuda.is_available() and device.type != 'cpu'  # use CUDA
+        if cuda and not any([nn_module, pt, jit, engine]):  # GPU dataloader formats
+            device = torch.device('cpu')
+            cuda = False
+
+        # Download if not local
+        if not (pt or triton or nn_module):
+            w = attempt_download_asset(w)
+
+        # Load model
+        if nn_module:  # in-memory PyTorch model
             model = weights.to(device)
             model = model.fuse(verbose=verbose) if fuse else model
             if hasattr(model, 'kpt_shape'):
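
The net effect of this hunk: when the target runtime cannot consume CUDA tensors (any format other than an in-memory module, PyTorch, TorchScript, or a TensorRT engine), the device is downgraded to CPU before any data loading, so batches are never staged on the GPU only to be copied back to host memory for inference. A minimal sketch of the same guard in isolation; the `select_device` helper name and keyword flags are illustrative stand-ins, not the real AutoBackend API:

import torch

def select_device(device, nn_module=False, pt=False, jit=False, engine=False):
    # Mirrors the diff: only in-memory modules, PyTorch, TorchScript and
    # TensorRT engines can consume CUDA tensors directly.
    cuda = torch.cuda.is_available() and device.type != 'cpu'
    if cuda and not any([nn_module, pt, jit, engine]):
        return torch.device('cpu')  # avoid the GPU -> CPU round-trip in the dataloader
    return device

# A TFLite/CoreML/OpenVINO-style export requested on GPU falls back to CPU:
requested = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(select_device(requested))  # cpu, even when CUDA is available
# An in-memory nn.Module keeps the requested device:
print(select_device(requested, nn_module=True))  # cuda:0 when available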
@@ -269,14 +276,13 @@ class AutoBackend(nn.Module):
             net.load_model(str(w.with_suffix('.bin')))
             metadata = w.parent / 'metadata.yaml'
         elif triton:  # NVIDIA Triton Inference Server
-            '''
-            TODO:
+            LOGGER.info('Triton Inference Server not supported...')
+            """TODO
             check_requirements('tritonclient[all]')
             from utils.triton import TritonRemoteModel
             model = TritonRemoteModel(url=w)
             nhwc = model.runtime.startswith("tensorflow")
-            '''
-            raise NotImplementedError('Triton Inference Server is not currently supported.')
+            """
         else:
             from ultralytics.yolo.engine.exporter import export_formats
             raise TypeError(f"model='{w}' is not a supported model format. "
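
Taken together, the caller-visible change in the second hunk is that a Triton URL no longer raises NotImplementedError at construction time; it only logs that the backend is unsupported. For the CUDA round-trip itself, a hedged usage sketch of the new behavior (the model path and input shape are illustrative, and the import path assumes the ultralytics package layout from this commit's era):

import torch
from ultralytics.nn.autobackend import AutoBackend

# 'yolov8n.tflite' stands in for any export whose runtime can't take CUDA
# tensors; AutoBackend now downgrades its device to CPU up front.
backend = AutoBackend('yolov8n.tflite', device=torch.device('cuda:0'))
im = torch.zeros(1, 3, 640, 640)  # dataloader tensors stay on the CPU ...
out = backend(im)                 # ... so no GPU -> host copy before inference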