From 135a10f1fad1c5e4c77d8a40fb6d222143549147 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 14 Jul 2023 20:38:31 +0200 Subject: [PATCH] Avoid CUDA round-trip for relevant export formats (#3727) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/nn/autobackend.py | 22 ++++++++++++++-------- ultralytics/vit/sam/predict.py | 10 ++++++---- ultralytics/yolo/engine/predictor.py | 11 +++++------ ultralytics/yolo/engine/validator.py | 18 ++++++++++-------- ultralytics/yolo/utils/checks.py | 7 +++---- 5 files changed, 38 insertions(+), 30 deletions(-) diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py index 04e8dca..78776bd 100644 --- a/ultralytics/nn/autobackend.py +++ b/ultralytics/nn/autobackend.py @@ -83,16 +83,23 @@ class AutoBackend(nn.Module): nn_module = isinstance(weights, torch.nn.Module) pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, triton = \ self._model_type(w) - fp16 &= pt or jit or onnx or engine or nn_module or triton # FP16 + fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16 nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH) stride = 32 # default stride model, metadata = None, None + + # Set device cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA + if cuda and not any([nn_module, pt, jit, engine]): # GPU dataloader formats + device = torch.device('cpu') + cuda = False + + # Download if not local if not (pt or triton or nn_module): - w = attempt_download_asset(w) # download if not local + w = attempt_download_asset(w) - # NOTE: special case: in-memory pytorch model - if nn_module: + # Load model + if nn_module: # in-memory PyTorch model model = weights.to(device) model = model.fuse(verbose=verbose) if fuse else model if hasattr(model, 'kpt_shape'): @@ -269,14 +276,13 @@ class AutoBackend(nn.Module): net.load_model(str(w.with_suffix('.bin'))) metadata = w.parent / 'metadata.yaml' elif triton: # NVIDIA Triton Inference Server - LOGGER.info('Triton Inference Server not supported...') - ''' - TODO: + """TODO check_requirements('tritonclient[all]') from utils.triton import TritonRemoteModel model = TritonRemoteModel(url=w) nhwc = model.runtime.startswith("tensorflow") - ''' + """ + raise NotImplementedError('Triton Inference Server is not currently supported.') else: from ultralytics.yolo.engine.exporter import export_formats raise TypeError(f"model='{w}' is not a supported model format. " diff --git a/ultralytics/vit/sam/predict.py b/ultralytics/vit/sam/predict.py index 47a9d55..c6db86e 100644 --- a/ultralytics/vit/sam/predict.py +++ b/ultralytics/vit/sam/predict.py @@ -18,7 +18,9 @@ from .build import build_sam class Predictor(BasePredictor): - def __init__(self, cfg=DEFAULT_CFG, overrides={}, _callbacks=None): + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + if overrides is None: + overrides = {} overrides.update(dict(task='segment', mode='predict', imgsz=1024)) super().__init__(cfg, overrides, _callbacks) # SAM needs retina_masks=True, or the results would be a mess. @@ -90,7 +92,7 @@ class Predictor(BasePredictor): of masks and H=W=256. These low resolution logits can be passed to a subsequent iteration as mask input. """ - if all([i is None for i in [bboxes, points, masks]]): + if all(i is None for i in [bboxes, points, masks]): return self.generate(im, *args, **kwargs) return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output) @@ -284,7 +286,7 @@ class Predictor(BasePredictor): return pred_masks, pred_scores, pred_bboxes - def setup_model(self, model): + def setup_model(self, model, verbose=True): """Set up YOLO model with specified thresholds and device.""" device = select_device(self.args.device) if model is None: @@ -306,7 +308,7 @@ class Predictor(BasePredictor): # (N, 1, H, W), (N, 1) pred_masks, pred_scores = preds[:2] pred_bboxes = preds[2] if self.segment_all else None - names = dict(enumerate([str(i) for i in range(len(pred_masks))])) + names = dict(enumerate(str(i) for i in range(len(pred_masks)))) results = [] for i, masks in enumerate([pred_masks]): orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index e326e5c..011d0ab 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -300,17 +300,16 @@ class BasePredictor: def setup_model(self, model, verbose=True): """Initialize YOLO model with given parameters and set it to evaluation mode.""" - device = select_device(self.args.device, verbose=verbose) - model = model or self.args.model - self.args.half &= device.type != 'cpu' # half precision only supported on CUDA - self.model = AutoBackend(model, - device=device, + self.model = AutoBackend(model or self.args.model, + device=select_device(self.args.device, verbose=verbose), dnn=self.args.dnn, data=self.args.data, fp16=self.args.half, fuse=True, verbose=verbose) - self.device = device + + self.device = self.model.device # update device + self.args.half = self.model.fp16 # update half self.model.eval() def show(self, p): diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index f84c8d0..a3faebf 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -109,19 +109,21 @@ class BaseValidator: callbacks.add_integration_callbacks(self) self.run_callbacks('on_val_start') assert model is not None, 'Either trainer or model is needed for validation' - self.device = select_device(self.args.device, self.args.batch) - self.args.half &= self.device.type != 'cpu' - model = AutoBackend(model, device=self.device, dnn=self.args.dnn, data=self.args.data, fp16=self.args.half) + model = AutoBackend(model, + device=select_device(self.args.device, self.args.batch), + dnn=self.args.dnn, + data=self.args.data, + fp16=self.args.half) self.model = model + self.device = model.device # update device + self.args.half = model.fp16 # update half stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine imgsz = check_imgsz(self.args.imgsz, stride=stride) if engine: self.args.batch = model.batch_size - else: - self.device = model.device - if not pt and not jit: - self.args.batch = 1 # export.py models default to batch-size 1 - LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models') + elif not pt and not jit: + self.args.batch = 1 # export.py models default to batch-size 1 + LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models') if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'): self.data = check_det_dataset(self.args.data) diff --git a/ultralytics/yolo/utils/checks.py b/ultralytics/yolo/utils/checks.py index 427a175..80e2cc3 100644 --- a/ultralytics/yolo/utils/checks.py +++ b/ultralytics/yolo/utils/checks.py @@ -213,7 +213,6 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=() prefix = colorstr('red', 'bold', 'requirements:') check_python() # check python version check_torchvision() # check torch-torchvision compatibility - file = None if isinstance(requirements, Path): # requirements.txt file file = requirements.resolve() assert file.exists(), f'{prefix} {file} not found, check failed.' @@ -225,13 +224,13 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=() s = '' # console string pkgs = [] for r in requirements: - rmin = r.split('/')[-1].replace('.git', '') # replace git+https://org/repo.git -> 'repo' + r_stripped = r.split('/')[-1].replace('.git', '') # replace git+https://org/repo.git -> 'repo' try: - pkg.require(rmin) + pkg.require(r_stripped) except (pkg.VersionConflict, pkg.DistributionNotFound): # exception if requirements not met try: # attempt to import (slower but more accurate) import importlib - importlib.import_module(next(pkg.parse_requirements(rmin)).name) + importlib.import_module(next(pkg.parse_requirements(r_stripped)).name) except ImportError: s += f'"{r}" ' pkgs.append(r)