Fix tflite int8 scaling (#3837)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Author: Colin Wong (committed by GitHub, 1 year ago)
Commit: b251b7415a · Parent: 0ce66f5266

@@ -415,6 +415,14 @@ class AutoBackend(nn.Module):
                 if int8:
                     scale, zero_point = output['quantization']
                     x = (x.astype(np.float32) - zero_point) * scale  # re-scale
+                if x.ndim > 2:  # if task is not classification
+                    # Unnormalize xywh with input image size
+                    # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
+                    # See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
+                    x[:, 0] *= w
+                    x[:, 1] *= h
+                    x[:, 2] *= w
+                    x[:, 3] *= h
                 y.append(x)
         # TF segment fixes: export is reversed vs ONNX export and protos are transposed
         if len(y) == 2:  # segment with (det, proto) output order reversed
@@ -422,7 +430,6 @@ class AutoBackend(nn.Module):
             y = list(reversed(y))  # should be y = (1, 116, 8400), (1, 160, 160, 32)
             y[1] = np.transpose(y[1], (0, 3, 1, 2))  # should be y = (1, 116, 8400), (1, 32, 160, 160)
         y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
-        # y[0][..., :4] *= [w, h, w, h]  # xywh normalized to pixels
         # for x in y:
         #     print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape)  # debug shapes
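
For context, the dequantize-and-unnormalize step added above can be exercised on its own. The sketch below is illustrative only: the tensor shape, quantization parameters, and image size are assumptions, and the `output['quantization']` dict mimics the per-tensor (scale, zero_point) pair a TFLite interpreter reports for each output.

import numpy as np

h, w = 640, 640  # assumed inference image height/width
output = {'quantization': (0.005, -128)}  # assumed per-tensor (scale, zero_point)
x = np.random.randint(-128, 128, (1, 84, 8400), dtype=np.int8)  # assumed raw int8 detect output

# Dequantize: real_value = (quantized_value - zero_point) * scale
scale, zero_point = output['quantization']
x = (x.astype(np.float32) - zero_point) * scale

if x.ndim > 2:  # detection output, not classification
    # xywh were normalized to 0-1 inside the model; restore pixel units
    x[:, 0] *= w  # x center
    x[:, 1] *= h  # y center
    x[:, 2] *= w  # width
    x[:, 3] *= h  # height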

@@ -58,6 +58,16 @@ class Detect(nn.Module):
         else:
             box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
         dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
+
+        if self.export and self.format in ('tflite', 'edgetpu'):
+            # Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:
+            # https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309
+            # See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
+            img_h = shape[2] * self.stride[0]
+            img_w = shape[3] * self.stride[0]
+            img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)
+            dbox /= img_size
+
         y = torch.cat((dbox, cls.sigmoid()), 1)
         return y if self.export else (y, x)
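
Why normalizing helps: per-tensor int8 quantization assigns a single scale to the whole output tensor, so mixing pixel coordinates up to ~640 with 0-1 sigmoid class scores forces a quantization step of roughly 640/255 ≈ 2.5 in pixel units and crushes the scores. Dividing xywh by the image size keeps the entire tensor near 0-1. Below is a minimal standalone sketch of the same arithmetic outside the Detect class; the stride, feature-map shape, and box values are assumptions for a stride-8 head on a 640x640 input.

import torch

stride0 = 8.0  # assumed smallest stride (P3 head)
shape = (1, 144, 80, 80)  # assumed first feature-map shape (b, c, h, w), as in Detect.forward
dbox = torch.rand(1, 4, 8400) * 640  # assumed decoded xywh boxes in input-image pixels

img_h = shape[2] * stride0  # 80 * 8 = 640
img_w = shape[3] * stride0  # 80 * 8 = 640
img_size = torch.tensor([img_w, img_h, img_w, img_h]).reshape(1, 4, 1)  # broadcasts over anchors
dbox /= img_size  # xywh now in 0-1, so int8 quantization spends its range on a tight interval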
