Avoid CUDA round-trip for relevant export formats (#3727)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher · commit 135a10f1fa (parent c5991d7cd8)

@@ -83,16 +83,23 @@ class AutoBackend(nn.Module):
         nn_module = isinstance(weights, torch.nn.Module)
         pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, triton = \
             self._model_type(w)
-        fp16 &= pt or jit or onnx or engine or nn_module or triton  # FP16
+        fp16 &= pt or jit or onnx or xml or engine or nn_module or triton  # FP16
         nhwc = coreml or saved_model or pb or tflite or edgetpu  # BHWC formats (vs torch BCWH)
         stride = 32  # default stride
         model, metadata = None, None
+
+        # Set device
         cuda = torch.cuda.is_available() and device.type != 'cpu'  # use CUDA
+        if cuda and not any([nn_module, pt, jit, engine]):  # GPU dataloader formats
+            device = torch.device('cpu')
+            cuda = False
+
+        # Download if not local
         if not (pt or triton or nn_module):
-            w = attempt_download_asset(w)  # download if not local
+            w = attempt_download_asset(w)
 
-        # NOTE: special case: in-memory pytorch model
-        if nn_module:
+        # Load model
+        if nn_module:  # in-memory PyTorch model
             model = weights.to(device)
             model = model.fuse(verbose=verbose) if fuse else model
             if hasattr(model, 'kpt_shape'):
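Read on its own, the device logic added above says: only in-memory PyTorch models, .pt checkpoints, TorchScript and TensorRT engines keep a CUDA device, because every other format's runtime feeds inputs from host memory anyway, so moving batches to the GPU first is a wasted round-trip. A minimal standalone sketch of that rule, with an illustrative suffix check standing in for AutoBackend._model_type:

from pathlib import Path

import torch


def resolve_device(weights: str, requested: torch.device, is_nn_module: bool = False) -> torch.device:
    """Keep CUDA only for formats whose dataloaders actually run on the GPU (sketch, not AutoBackend)."""
    suffix = Path(weights).suffix.lower()
    pt = suffix == '.pt'
    jit = suffix == '.torchscript'
    engine = suffix == '.engine'  # TensorRT
    cuda = torch.cuda.is_available() and requested.type != 'cpu'
    if cuda and not any([is_nn_module, pt, jit, engine]):  # e.g. ONNX, OpenVINO, TFLite
        return torch.device('cpu')  # avoid copying tensors to CUDA only to bring them back
    return requested


print(resolve_device('yolov8n.onnx', torch.device('cuda:0')))  # -> cpu (when CUDA is available)
print(resolve_device('yolov8n.pt', torch.device('cuda:0')))    # -> cuda:0
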
@@ -269,14 +276,13 @@ class AutoBackend(nn.Module):
             net.load_model(str(w.with_suffix('.bin')))
             metadata = w.parent / 'metadata.yaml'
         elif triton:  # NVIDIA Triton Inference Server
-            LOGGER.info('Triton Inference Server not supported...')
-            '''
-            TODO:
+            """TODO
             check_requirements('tritonclient[all]')
             from utils.triton import TritonRemoteModel
             model = TritonRemoteModel(url=w)
             nhwc = model.runtime.startswith("tensorflow")
-            '''
+            """
+            raise NotImplementedError('Triton Inference Server is not currently supported.')
         else:
             from ultralytics.yolo.engine.exporter import export_formats
             raise TypeError(f"model='{w}' is not a supported model format. "

@@ -18,7 +18,9 @@ from .build import build_sam
 class Predictor(BasePredictor):
 
-    def __init__(self, cfg=DEFAULT_CFG, overrides={}, _callbacks=None):
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        if overrides is None:
+            overrides = {}
         overrides.update(dict(task='segment', mode='predict', imgsz=1024))
         super().__init__(cfg, overrides, _callbacks)
         # SAM needs retina_masks=True, or the results would be a mess.
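The `overrides=None` signature above side-steps Python's shared mutable default argument: a default `{}` is created once at function definition time and then reused by every call, so updates leak between calls. A small self-contained illustration (function names are made up for the example):

def bad(overrides={}):  # one dict object is created at definition time and shared by all calls
    overrides.update(task='segment')
    return overrides


def good(overrides=None):  # each call gets its own dict
    if overrides is None:
        overrides = {}
    overrides.update(task='segment')
    return overrides


print(bad() is bad())    # True  -> state leaks between calls
print(good() is good())  # False -> no shared state
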
@@ -90,7 +92,7 @@ class Predictor(BasePredictor):
                 of masks and H=W=256. These low resolution logits can be passed to
                 a subsequent iteration as mask input.
         """
-        if all([i is None for i in [bboxes, points, masks]]):
+        if all(i is None for i in [bboxes, points, masks]):
             return self.generate(im, *args, **kwargs)
         return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)

@@ -284,7 +286,7 @@ class Predictor(BasePredictor):
         return pred_masks, pred_scores, pred_bboxes
 
-    def setup_model(self, model):
+    def setup_model(self, model, verbose=True):
         """Set up YOLO model with specified thresholds and device."""
         device = select_device(self.args.device)
         if model is None:

@@ -306,7 +308,7 @@ class Predictor(BasePredictor):
         # (N, 1, H, W), (N, 1)
         pred_masks, pred_scores = preds[:2]
         pred_bboxes = preds[2] if self.segment_all else None
-        names = dict(enumerate([str(i) for i in range(len(pred_masks))]))
+        names = dict(enumerate(str(i) for i in range(len(pred_masks))))
         results = []
         for i, masks in enumerate([pred_masks]):
            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs

@@ -300,17 +300,16 @@ class BasePredictor:
 
     def setup_model(self, model, verbose=True):
         """Initialize YOLO model with given parameters and set it to evaluation mode."""
-        device = select_device(self.args.device, verbose=verbose)
-        model = model or self.args.model
-        self.args.half &= device.type != 'cpu'  # half precision only supported on CUDA
-        self.model = AutoBackend(model,
-                                 device=device,
+        self.model = AutoBackend(model or self.args.model,
+                                 device=select_device(self.args.device, verbose=verbose),
                                  dnn=self.args.dnn,
                                  data=self.args.data,
                                  fp16=self.args.half,
                                  fuse=True,
                                  verbose=verbose)
-        self.device = device
+
+        self.device = self.model.device  # update device
+        self.args.half = self.model.fp16  # update half
         self.model.eval()
 
     def show(self, p):
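With this change the predictor no longer decides device and half precision up front; AutoBackend resolves both and the caller reads them back from the constructed backend. A rough usage sketch of that pattern (the weights path is illustrative and assumed to exist locally, and the import paths assume the repository layout at the time of this commit):

from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.utils.torch_utils import select_device

# AutoBackend may override the requested device (e.g. falling back to CPU for an
# ONNX export) and may disable FP16 for formats that do not support it.
backend = AutoBackend('yolov8n.onnx',              # assumed local export, e.g. from `yolo export`
                      device=select_device(''),    # '' auto-selects CUDA if available, else CPU
                      fp16=True,
                      fuse=True)

device = backend.device  # the effective device, possibly different from the request
half = backend.fp16      # whether half precision is actually in use
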

@@ -109,19 +109,21 @@ class BaseValidator:
             callbacks.add_integration_callbacks(self)
             self.run_callbacks('on_val_start')
             assert model is not None, 'Either trainer or model is needed for validation'
-            self.device = select_device(self.args.device, self.args.batch)
-            self.args.half &= self.device.type != 'cpu'
-            model = AutoBackend(model, device=self.device, dnn=self.args.dnn, data=self.args.data, fp16=self.args.half)
+            model = AutoBackend(model,
+                                device=select_device(self.args.device, self.args.batch),
+                                dnn=self.args.dnn,
+                                data=self.args.data,
+                                fp16=self.args.half)
             self.model = model
+            self.device = model.device  # update device
+            self.args.half = model.fp16  # update half
             stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
             imgsz = check_imgsz(self.args.imgsz, stride=stride)
             if engine:
                 self.args.batch = model.batch_size
-            else:
-                self.device = model.device
-                if not pt and not jit:
-                    self.args.batch = 1  # export.py models default to batch-size 1
-                    LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
+            elif not pt and not jit:
+                self.args.batch = 1  # export.py models default to batch-size 1
+                LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
 
             if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
                 self.data = check_det_dataset(self.args.data)
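The batch handling in this hunk reduces to a small rule: TensorRT engines dictate their exported batch size, other exported (non-PyTorch, non-TorchScript) formats validate at batch 1, and native models keep the requested batch. A standalone sketch of that rule (the function name and defaults are illustrative, not part of BaseValidator):

def resolve_val_batch(requested: int, engine: bool, pt: bool, jit: bool, engine_batch: int = 1) -> int:
    """Pick the validation batch size the way the validator above does."""
    if engine:              # TensorRT engines are exported with a fixed batch size
        return engine_batch
    if not pt and not jit:  # other exported formats default to batch-size 1
        return 1
    return requested        # native PyTorch / TorchScript keep the requested batch


print(resolve_val_batch(16, engine=False, pt=True, jit=False))   # 16
print(resolve_val_batch(16, engine=False, pt=False, jit=False))  # 1
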

@@ -213,7 +213,6 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
     prefix = colorstr('red', 'bold', 'requirements:')
     check_python()  # check python version
     check_torchvision()  # check torch-torchvision compatibility
-    file = None
     if isinstance(requirements, Path):  # requirements.txt file
         file = requirements.resolve()
         assert file.exists(), f'{prefix} {file} not found, check failed.'

@@ -225,13 +224,13 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
     s = ''  # console string
     pkgs = []
     for r in requirements:
-        rmin = r.split('/')[-1].replace('.git', '')  # replace git+https://org/repo.git -> 'repo'
+        r_stripped = r.split('/')[-1].replace('.git', '')  # replace git+https://org/repo.git -> 'repo'
         try:
-            pkg.require(rmin)
+            pkg.require(r_stripped)
         except (pkg.VersionConflict, pkg.DistributionNotFound):  # exception if requirements not met
             try:  # attempt to import (slower but more accurate)
                 import importlib
-                importlib.import_module(next(pkg.parse_requirements(rmin)).name)
+                importlib.import_module(next(pkg.parse_requirements(r_stripped)).name)
             except ImportError:
                 s += f'"{r}" '
                 pkgs.append(r)
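The renamed variable above is just the requirement string with any git URL reduced to a bare name so pkg.require sees a plain distribution. The normalization on its own (helper name is illustrative):

def strip_git_url(requirement: str) -> str:
    """Reduce 'git+https://org/repo.git' to 'repo'; plain requirement strings pass through unchanged."""
    return requirement.split('/')[-1].replace('.git', '')


print(strip_git_url('git+https://github.com/ultralytics/ultralytics.git'))  # ultralytics
print(strip_git_url('numpy>=1.22.2'))                                       # numpy>=1.22.2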
