Fix save_txt
in track mode and add Keypoints and Probs (#2921)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -23,7 +23,13 @@ class BaseTensor(SimpleClass):
|
||||
"""
|
||||
|
||||
def __init__(self, data, orig_shape) -> None:
|
||||
"""Initialize BaseTensor with data and original shape."""
|
||||
"""Initialize BaseTensor with data and original shape.
|
||||
|
||||
Args:
|
||||
data (torch.Tensor | np.ndarray): Predictions, such as bboxes, masks and keypoints.
|
||||
orig_shape (tuple): Original shape of image.
|
||||
"""
|
||||
assert isinstance(data, (torch.Tensor, np.ndarray))
|
||||
self.data = data
|
||||
self.orig_shape = orig_shape
|
||||
|
||||
@ -34,19 +40,19 @@ class BaseTensor(SimpleClass):
|
||||
|
||||
def cpu(self):
|
||||
"""Return a copy of the tensor on CPU memory."""
|
||||
return self.__class__(self.data.cpu(), self.orig_shape)
|
||||
return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.cpu(), self.orig_shape)
|
||||
|
||||
def numpy(self):
|
||||
"""Return a copy of the tensor as a numpy array."""
|
||||
return self.__class__(self.data.numpy(), self.orig_shape)
|
||||
return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.numpy(), self.orig_shape)
|
||||
|
||||
def cuda(self):
|
||||
"""Return a copy of the tensor on GPU memory."""
|
||||
return self.__class__(self.data.cuda(), self.orig_shape)
|
||||
return self.__class__(torch.as_tensor(self.data).cuda(), self.orig_shape)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""Return a copy of the tensor with the specified device and dtype."""
|
||||
return self.__class__(self.data.to(*args, **kwargs), self.orig_shape)
|
||||
return self.__class__(torch.as_tensor(self.data).to(*args, **kwargs), self.orig_shape)
|
||||
|
||||
def __len__(self): # override len(results)
|
||||
"""Return the length of the data tensor."""
|
||||
@ -90,8 +96,8 @@ class Results(SimpleClass):
|
||||
self.orig_shape = orig_img.shape[:2]
|
||||
self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes
|
||||
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
|
||||
self.probs = probs if probs is not None else None
|
||||
self.keypoints = keypoints if keypoints is not None else None
|
||||
self.probs = Probs(probs) if probs is not None else None
|
||||
self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
|
||||
self.speed = {'preprocess': None, 'inference': None, 'postprocess': None} # milliseconds per image
|
||||
self.names = names
|
||||
self.path = path
|
||||
@ -229,13 +235,11 @@ class Results(SimpleClass):
|
||||
annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
||||
|
||||
if pred_probs is not None and show_probs:
|
||||
n5 = min(len(names), 5)
|
||||
top5i = pred_probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices
|
||||
text = f"{', '.join(f'{names[j] if names else j} {pred_probs[j]:.2f}' for j in top5i)}, "
|
||||
text = f"{', '.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)}, "
|
||||
annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors
|
||||
|
||||
if keypoints is not None:
|
||||
for k in reversed(keypoints):
|
||||
for k in reversed(keypoints.data):
|
||||
annotator.kpts(k, self.orig_shape, kpt_line=kpt_line)
|
||||
|
||||
return annotator.result()
|
||||
@ -250,9 +254,7 @@ class Results(SimpleClass):
|
||||
if len(self) == 0:
|
||||
return log_string if probs is not None else f'{log_string}(no detections), '
|
||||
if probs is not None:
|
||||
n5 = min(len(self.names), 5)
|
||||
top5i = probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices
|
||||
log_string += f"{', '.join(f'{self.names[j]} {probs[j]:.2f}' for j in top5i)}, "
|
||||
log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, "
|
||||
if boxes:
|
||||
for c in boxes.cls.unique():
|
||||
n = (boxes.cls == c).sum() # detections per class
|
||||
@ -274,9 +276,7 @@ class Results(SimpleClass):
|
||||
texts = []
|
||||
if probs is not None:
|
||||
# Classify
|
||||
n5 = min(len(self.names), 5)
|
||||
top5i = probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices
|
||||
[texts.append(f'{probs[j]:.2f} {self.names[j]}') for j in top5i]
|
||||
[texts.append(f'{probs.data[j]:.2f} {self.names[j]}') for j in probs.top5]
|
||||
elif boxes:
|
||||
# Detect/segment/pose
|
||||
for j, d in enumerate(boxes):
|
||||
@ -286,7 +286,7 @@ class Results(SimpleClass):
|
||||
seg = masks[j].xyn[0].copy().reshape(-1) # reversed mask.xyn, (n,2) to (n*2)
|
||||
line = (c, *seg)
|
||||
if kpts is not None:
|
||||
kpt = (kpts[j][:, :2].cpu() / d.orig_shape[[1, 0]]).reshape(-1).tolist()
|
||||
kpt = kpts[j].xyn.reshape(-1).tolist()
|
||||
line += (*kpt, )
|
||||
line += (conf, ) * save_conf + (() if id is None else (id, ))
|
||||
texts.append(('%g ' * len(line)).rstrip() % line)
|
||||
@ -322,6 +322,10 @@ class Results(SimpleClass):
|
||||
|
||||
def tojson(self, normalize=False):
|
||||
"""Convert the object to JSON format."""
|
||||
if self.probs is not None:
|
||||
LOGGER.warning('Warning: Classify task do not support `tojson` yet.')
|
||||
return
|
||||
|
||||
import json
|
||||
|
||||
# Create list of detection dictionaries
|
||||
@ -338,7 +342,7 @@ class Results(SimpleClass):
|
||||
x, y = self.masks.xy[i][:, 0], self.masks.xy[i][:, 1] # numpy array
|
||||
result['segments'] = {'x': (x / w).tolist(), 'y': (y / h).tolist()}
|
||||
if self.keypoints is not None:
|
||||
x, y, visible = self.keypoints[i].cpu().unbind(dim=1) # torch Tensor
|
||||
x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1) # torch Tensor
|
||||
result['keypoints'] = {'x': (x / w).tolist(), 'y': (y / h).tolist(), 'visible': visible.tolist()}
|
||||
results.append(result)
|
||||
|
||||
@ -386,8 +390,7 @@ class Boxes(BaseTensor):
|
||||
assert n in (6, 7), f'expected `n` in [6, 7], but got {n}' # xyxy, (track_id), conf, cls
|
||||
super().__init__(boxes, orig_shape)
|
||||
self.is_track = n == 7
|
||||
self.orig_shape = torch.as_tensor(orig_shape, device=boxes.device) if isinstance(boxes, torch.Tensor) \
|
||||
else np.asarray(orig_shape)
|
||||
self.orig_shape = orig_shape
|
||||
|
||||
@property
|
||||
def xyxy(self):
|
||||
@ -419,13 +422,19 @@ class Boxes(BaseTensor):
|
||||
@lru_cache(maxsize=2)
|
||||
def xyxyn(self):
|
||||
"""Return the boxes in xyxy format normalized by original image size."""
|
||||
return self.xyxy / self.orig_shape[[1, 0, 1, 0]]
|
||||
xyxy = self.xyxy.clone() if isinstance(self.xyxy, torch.Tensor) else np.copy(self.xyxy)
|
||||
xyxy[..., [0, 2]] /= self.orig_shape[1]
|
||||
xyxy[..., [1, 3]] /= self.orig_shape[0]
|
||||
return xyxy
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=2)
|
||||
def xywhn(self):
|
||||
"""Return the boxes in xywh format normalized by original image size."""
|
||||
return self.xywh / self.orig_shape[[1, 0, 1, 0]]
|
||||
xywh = ops.xyxy2xywh(self.xyxy)
|
||||
xywh[..., [0, 2]] /= self.orig_shape[1]
|
||||
xywh[..., [1, 3]] /= self.orig_shape[0]
|
||||
return xywh
|
||||
|
||||
@property
|
||||
def boxes(self):
|
||||
@ -439,11 +448,11 @@ class Masks(BaseTensor):
|
||||
A class for storing and manipulating detection masks.
|
||||
|
||||
Args:
|
||||
masks (torch.Tensor): A tensor containing the detection masks, with shape (num_masks, height, width).
|
||||
masks (torch.Tensor | np.ndarray): A tensor containing the detection masks, with shape (num_masks, height, width).
|
||||
orig_shape (tuple): Original image size, in the format (height, width).
|
||||
|
||||
Attributes:
|
||||
masks (torch.Tensor): A tensor containing the detection masks, with shape (num_masks, height, width).
|
||||
masks (torch.Tensor | np.ndarray): A tensor containing the detection masks, with shape (num_masks, height, width).
|
||||
orig_shape (tuple): Original image size, in the format (height, width).
|
||||
|
||||
Properties:
|
||||
@ -496,3 +505,100 @@ class Masks(BaseTensor):
|
||||
def pandas(self):
|
||||
"""Convert the object to a pandas DataFrame (not yet implemented)."""
|
||||
LOGGER.warning("WARNING ⚠️ 'Masks.pandas' method is not yet implemented.")
|
||||
|
||||
|
||||
class Keypoints(BaseTensor):
|
||||
"""
|
||||
A class for storing and manipulating detection keypoints.
|
||||
|
||||
Args:
|
||||
keypoints (torch.Tensor | np.ndarray): A tensor containing the detection keypoints, with shape (num_dets, num_kpts, 2/3).
|
||||
orig_shape (tuple): Original image size, in the format (height, width).
|
||||
|
||||
Attributes:
|
||||
keypoints (torch.Tensor | np.ndarray): A tensor containing the detection keypoints, with shape (num_dets, num_kpts, 2/3).
|
||||
orig_shape (tuple): Original image size, in the format (height, width).
|
||||
|
||||
Properties:
|
||||
xy (list): A list of keypoints (pixels) which includes x, y keypoints of each detection.
|
||||
xyn (list): A list of keypoints (normalized) which includes x, y keypoints of each detection.
|
||||
|
||||
Methods:
|
||||
cpu(): Returns a copy of the keypoints tensor on CPU memory.
|
||||
numpy(): Returns a copy of the keypoints tensor as a numpy array.
|
||||
cuda(): Returns a copy of the keypoints tensor on GPU memory.
|
||||
to(): Returns a copy of the keypoints tensor with the specified device and dtype.
|
||||
"""
|
||||
|
||||
def __init__(self, keypoints, orig_shape) -> None:
|
||||
if keypoints.ndim == 2:
|
||||
keypoints = keypoints[None, :]
|
||||
super().__init__(keypoints, orig_shape)
|
||||
self.has_visible = self.data.shape[-1] == 3
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def xy(self):
|
||||
return self.data[..., :2]
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def xyn(self):
|
||||
xy = self.xy.clone() if isinstance(self.xy, torch.Tensor) else np.copy(self.xy)
|
||||
xy[..., 0] /= self.orig_shape[1]
|
||||
xy[..., 1] /= self.orig_shape[0]
|
||||
return xy
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def conf(self):
|
||||
return self.data[..., 3] if self.has_visible else None
|
||||
|
||||
|
||||
class Probs(BaseTensor):
|
||||
"""
|
||||
A class for storing and manipulating classify predictions.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor | np.ndarray): A tensor containing the detection keypoints, with shape (num_class, ).
|
||||
|
||||
Attributes:
|
||||
probs (torch.Tensor | np.ndarray): A tensor containing the detection keypoints, with shape (num_class).
|
||||
|
||||
Properties:
|
||||
top5 (list[int]): Top 1 indice.
|
||||
top1 (int): Top 5 indices.
|
||||
|
||||
Methods:
|
||||
cpu(): Returns a copy of the probs tensor on CPU memory.
|
||||
numpy(): Returns a copy of the probs tensor as a numpy array.
|
||||
cuda(): Returns a copy of the probs tensor on GPU memory.
|
||||
to(): Returns a copy of the probs tensor with the specified device and dtype.
|
||||
"""
|
||||
|
||||
def __init__(self, probs, orig_shape=None) -> None:
|
||||
super().__init__(probs, orig_shape)
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def top5(self):
|
||||
"""Return the indices of top 5."""
|
||||
return (-self.data).argsort(0)[:5].tolist() # this way works with both torch and numpy.
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def top1(self):
|
||||
"""Return the indices of top 1."""
|
||||
return int(self.data.argmax())
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def top5conf(self):
|
||||
"""Return the confidences of top 5."""
|
||||
return self.data[self.top5]
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def top1conf(self):
|
||||
"""Return the confidences of top 1."""
|
||||
return self.data[self.top1]
|
||||
|
Reference in New Issue
Block a user