Avoid CUDA round-trip for relevant export formats (#3727)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -300,17 +300,16 @@ class BasePredictor:
|
||||
|
||||
def setup_model(self, model, verbose=True):
|
||||
"""Initialize YOLO model with given parameters and set it to evaluation mode."""
|
||||
device = select_device(self.args.device, verbose=verbose)
|
||||
model = model or self.args.model
|
||||
self.args.half &= device.type != 'cpu' # half precision only supported on CUDA
|
||||
self.model = AutoBackend(model,
|
||||
device=device,
|
||||
self.model = AutoBackend(model or self.args.model,
|
||||
device=select_device(self.args.device, verbose=verbose),
|
||||
dnn=self.args.dnn,
|
||||
data=self.args.data,
|
||||
fp16=self.args.half,
|
||||
fuse=True,
|
||||
verbose=verbose)
|
||||
self.device = device
|
||||
|
||||
self.device = self.model.device # update device
|
||||
self.args.half = self.model.fp16 # update half
|
||||
self.model.eval()
|
||||
|
||||
def show(self, p):
|
||||
|
@ -109,19 +109,21 @@ class BaseValidator:
|
||||
callbacks.add_integration_callbacks(self)
|
||||
self.run_callbacks('on_val_start')
|
||||
assert model is not None, 'Either trainer or model is needed for validation'
|
||||
self.device = select_device(self.args.device, self.args.batch)
|
||||
self.args.half &= self.device.type != 'cpu'
|
||||
model = AutoBackend(model, device=self.device, dnn=self.args.dnn, data=self.args.data, fp16=self.args.half)
|
||||
model = AutoBackend(model,
|
||||
device=select_device(self.args.device, self.args.batch),
|
||||
dnn=self.args.dnn,
|
||||
data=self.args.data,
|
||||
fp16=self.args.half)
|
||||
self.model = model
|
||||
self.device = model.device # update device
|
||||
self.args.half = model.fp16 # update half
|
||||
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
|
||||
imgsz = check_imgsz(self.args.imgsz, stride=stride)
|
||||
if engine:
|
||||
self.args.batch = model.batch_size
|
||||
else:
|
||||
self.device = model.device
|
||||
if not pt and not jit:
|
||||
self.args.batch = 1 # export.py models default to batch-size 1
|
||||
LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
|
||||
elif not pt and not jit:
|
||||
self.args.batch = 1 # export.py models default to batch-size 1
|
||||
LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
|
||||
|
||||
if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
|
||||
self.data = check_det_dataset(self.args.data)
|
||||
|
Reference in New Issue
Block a user