Fix catastrophic accuracy degradation of TFLite static quantized integer models (#1695)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -593,14 +593,43 @@ class Exporter:
|
||||
f_onnx, _ = self.export_onnx()
|
||||
|
||||
# Export to TF
|
||||
int8 = '-oiqt -qt per-tensor' if self.args.int8 else ''
|
||||
cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo --non_verbose {int8}'
|
||||
LOGGER.info(f"\n{prefix} running '{cmd}'")
|
||||
tmp_file = f / 'tmp_tflite_int8_calibration_images.npy' # int8 calibration images file
|
||||
if self.args.int8:
|
||||
if self.args.data:
|
||||
import numpy as np
|
||||
|
||||
from ultralytics.data.dataset import YOLODataset
|
||||
from ultralytics.data.utils import check_det_dataset
|
||||
|
||||
# Generate calibration data for integer quantization
|
||||
LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
|
||||
dataset = YOLODataset(check_det_dataset(self.args.data)['val'], imgsz=self.imgsz[0], augment=False)
|
||||
images = []
|
||||
n_images = 100 # maximum number of images
|
||||
for n, batch in enumerate(dataset):
|
||||
if n >= n_images:
|
||||
break
|
||||
im = batch['img'].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC,
|
||||
images.append(im)
|
||||
f.mkdir()
|
||||
images = torch.cat(images, 0).float()
|
||||
# mean = images.view(-1, 3).mean(0) # imagenet mean [123.675, 116.28, 103.53]
|
||||
# std = images.view(-1, 3).std(0) # imagenet std [58.395, 57.12, 57.375]
|
||||
np.save(str(tmp_file), images.numpy()) # BHWC
|
||||
int8 = f'-oiqt -qt per-tensor -cind images "{tmp_file}" "[[[[0, 0, 0]]]]" "[[[[255, 255, 255]]]]"'
|
||||
else:
|
||||
int8 = '-oiqt -qt per-tensor'
|
||||
else:
|
||||
int8 = ''
|
||||
|
||||
cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo --non_verbose {int8}'.strip()
|
||||
LOGGER.info(f"{prefix} running '{cmd}'")
|
||||
subprocess.run(cmd, shell=True)
|
||||
yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
|
||||
# Remove/rename TFLite models
|
||||
if self.args.int8:
|
||||
tmp_file.unlink(missing_ok=True)
|
||||
for file in f.rglob('*_dynamic_range_quant.tflite'):
|
||||
file.rename(file.with_name(file.stem.replace('_dynamic_range_quant', '_int8') + file.suffix))
|
||||
for file in f.rglob('*_integer_quant_with_int16_act.tflite'):
|
||||
|
@ -343,6 +343,8 @@ class YOLO:
|
||||
overrides['imgsz'] = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
||||
if 'batch' not in kwargs:
|
||||
overrides['batch'] = 1 # default to 1 if not modified
|
||||
if 'data' not in kwargs:
|
||||
overrides['data'] = None # default to None if not modified (avoid int8 calibration with coco.yaml)
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.task = self.task
|
||||
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
|
||||
|
Reference in New Issue
Block a user