|
|
@ -593,14 +593,43 @@ class Exporter:
|
|
|
|
f_onnx, _ = self.export_onnx()
|
|
|
|
f_onnx, _ = self.export_onnx()
|
|
|
|
|
|
|
|
|
|
|
|
# Export to TF
|
|
|
|
# Export to TF
|
|
|
|
int8 = '-oiqt -qt per-tensor' if self.args.int8 else ''
|
|
|
|
tmp_file = f / 'tmp_tflite_int8_calibration_images.npy' # int8 calibration images file
|
|
|
|
cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo --non_verbose {int8}'
|
|
|
|
if self.args.int8:
|
|
|
|
LOGGER.info(f"\n{prefix} running '{cmd}'")
|
|
|
|
if self.args.data:
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from ultralytics.data.dataset import YOLODataset
|
|
|
|
|
|
|
|
from ultralytics.data.utils import check_det_dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Generate calibration data for integer quantization
|
|
|
|
|
|
|
|
LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
|
|
|
|
|
|
|
|
dataset = YOLODataset(check_det_dataset(self.args.data)['val'], imgsz=self.imgsz[0], augment=False)
|
|
|
|
|
|
|
|
images = []
|
|
|
|
|
|
|
|
n_images = 100 # maximum number of images
|
|
|
|
|
|
|
|
for n, batch in enumerate(dataset):
|
|
|
|
|
|
|
|
if n >= n_images:
|
|
|
|
|
|
|
|
break
|
|
|
|
|
|
|
|
im = batch['img'].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC,
|
|
|
|
|
|
|
|
images.append(im)
|
|
|
|
|
|
|
|
f.mkdir()
|
|
|
|
|
|
|
|
images = torch.cat(images, 0).float()
|
|
|
|
|
|
|
|
# mean = images.view(-1, 3).mean(0) # imagenet mean [123.675, 116.28, 103.53]
|
|
|
|
|
|
|
|
# std = images.view(-1, 3).std(0) # imagenet std [58.395, 57.12, 57.375]
|
|
|
|
|
|
|
|
np.save(str(tmp_file), images.numpy()) # BHWC
|
|
|
|
|
|
|
|
int8 = f'-oiqt -qt per-tensor -cind images "{tmp_file}" "[[[[0, 0, 0]]]]" "[[[[255, 255, 255]]]]"'
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
int8 = '-oiqt -qt per-tensor'
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
int8 = ''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo --non_verbose {int8}'.strip()
|
|
|
|
|
|
|
|
LOGGER.info(f"{prefix} running '{cmd}'")
|
|
|
|
subprocess.run(cmd, shell=True)
|
|
|
|
subprocess.run(cmd, shell=True)
|
|
|
|
yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
|
|
|
|
yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
|
|
|
|
|
|
|
|
|
|
|
|
# Remove/rename TFLite models
|
|
|
|
# Remove/rename TFLite models
|
|
|
|
if self.args.int8:
|
|
|
|
if self.args.int8:
|
|
|
|
|
|
|
|
tmp_file.unlink(missing_ok=True)
|
|
|
|
for file in f.rglob('*_dynamic_range_quant.tflite'):
|
|
|
|
for file in f.rglob('*_dynamic_range_quant.tflite'):
|
|
|
|
file.rename(file.with_name(file.stem.replace('_dynamic_range_quant', '_int8') + file.suffix))
|
|
|
|
file.rename(file.with_name(file.stem.replace('_dynamic_range_quant', '_int8') + file.suffix))
|
|
|
|
for file in f.rglob('*_integer_quant_with_int16_act.tflite'):
|
|
|
|
for file in f.rglob('*_integer_quant_with_int16_act.tflite'):
|
|
|
|