Avoid CUDA round-trip for relevant export formats (#3727)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-07-14 20:38:31 +02:00
committed by GitHub
parent c5991d7cd8
commit 135a10f1fa
5 changed files with 40 additions and 32 deletions

View File

@ -300,17 +300,16 @@ class BasePredictor:
def setup_model(self, model, verbose=True):
"""Initialize YOLO model with given parameters and set it to evaluation mode."""
device = select_device(self.args.device, verbose=verbose)
model = model or self.args.model
self.args.half &= device.type != 'cpu' # half precision only supported on CUDA
self.model = AutoBackend(model,
device=device,
self.model = AutoBackend(model or self.args.model,
device=select_device(self.args.device, verbose=verbose),
dnn=self.args.dnn,
data=self.args.data,
fp16=self.args.half,
fuse=True,
verbose=verbose)
self.device = device
self.device = self.model.device # update device
self.args.half = self.model.fp16 # update half
self.model.eval()
def show(self, p):

View File

@ -109,19 +109,21 @@ class BaseValidator:
callbacks.add_integration_callbacks(self)
self.run_callbacks('on_val_start')
assert model is not None, 'Either trainer or model is needed for validation'
self.device = select_device(self.args.device, self.args.batch)
self.args.half &= self.device.type != 'cpu'
model = AutoBackend(model, device=self.device, dnn=self.args.dnn, data=self.args.data, fp16=self.args.half)
model = AutoBackend(model,
device=select_device(self.args.device, self.args.batch),
dnn=self.args.dnn,
data=self.args.data,
fp16=self.args.half)
self.model = model
self.device = model.device # update device
self.args.half = model.fp16 # update half
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
imgsz = check_imgsz(self.args.imgsz, stride=stride)
if engine:
self.args.batch = model.batch_size
else:
self.device = model.device
if not pt and not jit:
self.args.batch = 1 # export.py models default to batch-size 1
LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
elif not pt and not jit:
self.args.batch = 1 # export.py models default to batch-size 1
LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
self.data = check_det_dataset(self.args.data)