From b251b7415ad4fe78f4f7224562dbb8e6c1cd81b2 Mon Sep 17 00:00:00 2001 From: Colin Wong Date: Thu, 20 Jul 2023 18:37:41 -0500 Subject: [PATCH] Fix tflite int8 scaling (#3837) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher --- ultralytics/nn/autobackend.py | 9 ++++++++- ultralytics/nn/modules/head.py | 10 ++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py index 4a47f03..242f356 100644 --- a/ultralytics/nn/autobackend.py +++ b/ultralytics/nn/autobackend.py @@ -415,6 +415,14 @@ class AutoBackend(nn.Module): if int8: scale, zero_point = output['quantization'] x = (x.astype(np.float32) - zero_point) * scale # re-scale + if x.ndim > 2: # if task is not classification + # Unnormalize xywh with input image size + # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models + # See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695 + x[:, 0] *= w + x[:, 1] *= h + x[:, 2] *= w + x[:, 3] *= h y.append(x) # TF segment fixes: export is reversed vs ONNX export and protos are transposed if len(y) == 2: # segment with (det, proto) output order reversed @@ -422,7 +430,6 @@ class AutoBackend(nn.Module): y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32) y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160) y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y] - # y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels # for x in y: # print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape) # debug shapes diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index e31ae8b..4e3661b 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -58,6 +58,16 @@ class Detect(nn.Module): else: box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides + + if self.export and self.format in ('tflite', 'edgetpu'): + # Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5: + # https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309 + # See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695 + img_h = shape[2] * self.stride[0] + img_w = shape[3] * self.stride[0] + img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1) + dbox /= img_size + y = torch.cat((dbox, cls.sigmoid()), 1) return y if self.export else (y, x)