Fix `workspace` and `verbose` arguments in TensorRT export (#2954)

Co-authored-by: crbrz <cristiab@gmail.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
single_channel
crbrz 1 year ago committed by GitHub
parent 67cf53b475
commit a9129fb40e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -423,7 +423,7 @@ class Exporter:
return f, ct_model return f, ct_model
@try_export @try_export
def export_engine(self, workspace=4, verbose=False, prefix=colorstr('TensorRT:')): def export_engine(self, prefix=colorstr('TensorRT:')):
"""YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt.""" """YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt."""
assert self.im.device.type != 'cpu', "export running on CPU but must be on GPU, i.e. use 'device=0'" assert self.im.device.type != 'cpu', "export running on CPU but must be on GPU, i.e. use 'device=0'"
try: try:
@ -441,12 +441,12 @@ class Exporter:
assert Path(f_onnx).exists(), f'failed to export ONNX file: {f_onnx}' assert Path(f_onnx).exists(), f'failed to export ONNX file: {f_onnx}'
f = self.file.with_suffix('.engine') # TensorRT engine file f = self.file.with_suffix('.engine') # TensorRT engine file
logger = trt.Logger(trt.Logger.INFO) logger = trt.Logger(trt.Logger.INFO)
if verbose: if self.args.verbose:
logger.min_severity = trt.Logger.Severity.VERBOSE logger.min_severity = trt.Logger.Severity.VERBOSE
builder = trt.Builder(logger) builder = trt.Builder(logger)
config = builder.create_builder_config() config = builder.create_builder_config()
config.max_workspace_size = workspace * 1 << 30 config.max_workspace_size = self.args.workspace * 1 << 30
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

Loading…
Cancel
Save