ultralytics 8.0.88 Pose TFLite model fix (#2261)

This commit is contained in:
Glenn Jocher
2023-04-26 17:00:08 +02:00
committed by GitHub
parent efc941aa81
commit 0a36b83e7a
5 changed files with 22 additions and 13 deletions

View File

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.87'
__version__ = '8.0.88'
from ultralytics.hub import start
from ultralytics.yolo.engine.model import YOLO

View File

@ -161,7 +161,7 @@ class Events:
Initializes the Events object with default values for events, rate_limit, and metadata.
"""
self.events = [] # events list
self.rate_limit = 10.0 # rate limit (seconds)
self.rate_limit = 60.0 # rate limit (seconds)
self.t = 0.0 # rate limit timer (seconds)
self.metadata = {
'cli': Path(sys.argv[0]).name == 'yolo',
@ -204,7 +204,9 @@ class Events:
# Time is over rate limiter, send now
data = {'client_id': SETTINGS['uuid'], 'events': self.events} # SHA-256 anonymized UUID hash and events list
smart_request('post', self.url, json=data, retry=0, code=3) # equivalent to requests.post(self.url, json=data)
# POST equivalent to requests.post(self.url, json=data)
smart_request('post', self.url, json=data, retry=0, verbose=False)
# Reset events and rate limit timer
self.events = []

View File

@ -541,18 +541,25 @@ class Pose(Detect):
x = self.detect(self, x)
if self.training:
return x, kpt
pred_kpt = self.kpts_decode(kpt)
pred_kpt = self.kpts_decode(bs, kpt)
return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))
def kpts_decode(self, kpts):
def kpts_decode(self, bs, kpts):
"""Decodes keypoints."""
ndim = self.kpt_shape[1]
y = kpts.clone()
if ndim == 3:
y[:, 2::3].sigmoid_() # inplace sigmoid
y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
return y
if self.export: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
y = kpts.view(bs, *self.kpt_shape, -1)
a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
if ndim == 3:
a = torch.cat((a, y[:, :, 1:2].sigmoid()), 2)
return a.view(bs, self.nk, -1)
else:
y = kpts.clone()
if ndim == 3:
y[:, 2::3].sigmoid_() # inplace sigmoid
y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
return y
class Classify(nn.Module):