New YOLOv8 `Results()` class for prediction outputs (#314)
Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Viet Nhat Thai <60825385+vietnhatthai@users.noreply.github.com> Co-authored-by: Paula Derrenger <107626595+pderrenger@users.noreply.github.com>single_channel
parent
0cb87f7dd3
commit
c6985da9de
@ -0,0 +1,72 @@
|
|||||||
|
Inference or prediction of a task returns a list of `Results` objects. Alternatively, in streaming mode, it returns a memory-efficient generator of `Results` objects. Streaming mode can be enabled by passing `stream=True` to the predictor's call method.
|
||||||
|
|
||||||
|
!!! example "Predict"
|
||||||
|
=== "Getting a List"
|
||||||
|
```python
|
||||||
|
inputs = [img, img] # list of np arrays
|
||||||
|
results = model(inputs) # List of Results objects
|
||||||
|
for result in results:
|
||||||
|
    boxes = result.boxes  # Boxes object for bbox outputs
|
||||||
|
    masks = result.masks  # Masks object for segmentation masks outputs
|
||||||
|
    probs = result.probs  # Class probabilities for classification outputs
|
||||||
|
...
|
||||||
|
```
|
||||||
|
=== "Getting a Generator"
|
||||||
|
```python
|
||||||
|
inputs = [img, img] # list of np arrays
|
||||||
|
results = model(inputs, stream=True)  # Generator of Results objects
|
||||||
|
for result in results:
|
||||||
|
    boxes = result.boxes  # Boxes object for bbox outputs
|
||||||
|
    masks = result.masks  # Masks object for segmentation masks outputs
|
||||||
|
    probs = result.probs  # Class probabilities for classification outputs
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Working with Results
|
||||||
|
|
||||||
|
Results object consists of these component objects:
|
||||||
|
|
||||||
|
- `results.boxes` : It is an object of class `Boxes`. It has properties and methods for manipulating bboxes
|
||||||
|
- `results.masks` : It is an object of class `Masks`. It can be used to index masks or to get segment coordinates.
|
||||||
|
- `results.probs` : It is a `Tensor` object. It contains the class probabilities/logits.
|
||||||
|
|
||||||
|
By default, each result is composed of `torch.Tensor` objects, so you can easily use the following functionality:
|
||||||
|
```python
|
||||||
|
results = results.cuda()
|
||||||
|
results = results.cpu()
|
||||||
|
results = results.to("cpu")
|
||||||
|
results = results.numpy()
|
||||||
|
```
|
||||||
|
### Boxes
|
||||||
|
`Boxes` object can be used to index, manipulate and convert bboxes to different formats. The box format conversion operations are cached, which means they're only calculated once per object and those values are reused for future calls.
|
||||||
|
|
||||||
|
- Indexing a `Boxes` object returns a `Boxes` object
|
||||||
|
```python
|
||||||
|
boxes = results.boxes
|
||||||
|
box = boxes[0] # returns one box
|
||||||
|
box.xyxy
|
||||||
|
```
|
||||||
|
- Properties and conversions
|
||||||
|
```python
|
||||||
|
results.boxes.xyxy # box with xyxy format, (N, 4)
|
||||||
|
results.boxes.xywh # box with xywh format, (N, 4)
|
||||||
|
results.boxes.xyxyn # box with xyxy format but normalized, (N, 4)
|
||||||
|
results.boxes.xywhn # box with xywh format but normalized, (N, 4)
|
||||||
|
results.boxes.conf # confidence score, (N,)
|
||||||
|
results.boxes.cls # cls, (N,)
|
||||||
|
```
|
||||||
|
### Masks
|
||||||
|
`Masks` object can be used to index, manipulate and convert masks to segments. The segment conversion operation is cached.
|
||||||
|
|
||||||
|
```python
|
||||||
|
results.masks.masks # masks, (N, H, W)
|
||||||
|
results.masks.segments # bounding coordinates of masks, List[segment] * N
|
||||||
|
```
|
||||||
|
|
||||||
|
### probs
|
||||||
|
`probs` attribute of `Results` class is a `Tensor` containing class probabilities of a classification operation.
|
||||||
|
```python
|
||||||
|
results.probs # cls prob, (num_class, )
|
||||||
|
```
|
||||||
|
|
||||||
|
Class reference documentation for `Results` module and its components can be found [here](reference/results.md)
|
@ -0,0 +1,11 @@
|
|||||||
|
### Results API Reference
|
||||||
|
|
||||||
|
:::ultralytics.yolo.engine.results.Results
|
||||||
|
|
||||||
|
### Boxes API Reference
|
||||||
|
|
||||||
|
:::ultralytics.yolo.engine.results.Boxes
|
||||||
|
|
||||||
|
### Masks API Reference
|
||||||
|
|
||||||
|
:::ultralytics.yolo.engine.results.Masks
|
@ -0,0 +1,284 @@
|
|||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ultralytics.yolo.utils import LOGGER, ops
|
||||||
|
|
||||||
|
|
||||||
|
class Results:
    """
    A class for storing and manipulating inference results.

    Args:
        boxes (torch.Tensor | numpy.ndarray, optional): Raw detection boxes. When given, they are wrapped in a
            Boxes object together with orig_shape.
        masks (torch.Tensor, optional): Raw detection masks. When given, they are wrapped in a Masks object
            together with orig_shape.
        probs (torch.Tensor, optional): Raw class scores. When given, softmax over dim 0 is applied on construction.
        orig_shape (tuple, optional): Original image size.

    Attributes:
        boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
        masks (Masks, optional): A Masks object containing the detection masks.
        probs (torch.Tensor, optional): A tensor containing the class probabilities.
        orig_shape (tuple, optional): Original image size.
        comp (list): Names of the component attributes that conversions, indexing and len() operate on.
    """

    def __init__(self, boxes=None, masks=None, probs=None, orig_shape=None) -> None:
        self.boxes = Boxes(boxes, orig_shape) if boxes is not None else None  # native size boxes
        self.masks = Masks(masks, orig_shape) if masks is not None else None  # native size or imgsz masks
        self.probs = probs.softmax(0) if probs is not None else None
        self.orig_shape = orig_shape
        self.comp = ["boxes", "masks", "probs"]

    def pandas(self):
        """Placeholder; not implemented yet and returns None."""
        pass
        # TODO masks.pandas + boxes.pandas + cls.pandas

    def _apply(self, fn):
        """Return a new Results whose non-None components are fn(component).

        The new object is created empty and populated with setattr so that the
        constructor's wrapping (Boxes/Masks) and softmax are not applied a second time.
        """
        r = Results(orig_shape=self.orig_shape)
        for item in self.comp:
            value = getattr(self, item)
            if value is not None:
                setattr(r, item, fn(value))
        return r

    def __getitem__(self, idx):
        """Return a new Results with every present component indexed by idx."""
        return self._apply(lambda c: c[idx])

    def cpu(self):
        """Return a new Results with .cpu() called on every present component."""
        return self._apply(lambda c: c.cpu())

    def numpy(self):
        """Return a new Results with .numpy() called on every present component."""
        return self._apply(lambda c: c.numpy())

    def cuda(self):
        """Return a new Results with .cuda() called on every present component."""
        return self._apply(lambda c: c.cuda())

    def to(self, *args, **kwargs):
        """Return a new Results with .to(*args, **kwargs) called on every present component."""
        return self._apply(lambda c: c.to(*args, **kwargs))

    def __len__(self):
        """Return the length of the first present component, or 0 when none is set."""
        for item in self.comp:
            value = getattr(self, item)
            if value is not None:
                return len(value)
        return 0  # len() must return an int; falling through would return None and raise TypeError

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        """Build a multi-line description of the present components and the original size."""
        s = f'Ultralytics YOLO {self.__class__} instance\n'  # string
        # Compare against None explicitly: truth-testing a tensor with more than one
        # element raises RuntimeError, and an empty component is still worth printing.
        if self.boxes is not None:
            s = s + self.boxes.__repr__() + '\n'
        if self.masks is not None:
            s = s + self.masks.__repr__() + '\n'
        if self.probs is not None:
            s = s + self.probs.__repr__() + '\n'
        s += f'original size: {self.orig_shape}\n'

        return s
|
||||||
|
|
||||||
|
|
||||||
|
class Boxes:
    """
    A class for storing and manipulating detection boxes.

    Args:
        boxes (torch.Tensor) or (numpy.ndarray): A tensor or numpy array containing the detection boxes,
            with shape (num_boxes, 6). The last two columns should contain confidence and class values.
            A 1-D input of length 6 is promoted to shape (1, 6).
        orig_shape (tuple): Original image size, in the format (height, width).

    Attributes:
        boxes (torch.Tensor) or (numpy.ndarray): A tensor or numpy array containing the detection boxes,
            with shape (num_boxes, 6).
        orig_shape (torch.Tensor) or (numpy.ndarray): Original image size, in the format (height, width).
            Stored as a tensor on the same device as boxes when boxes is a tensor, otherwise as a numpy array.

    Properties:
        xyxy (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format (first four columns).
        conf (torch.Tensor) or (numpy.ndarray): The confidence values of the boxes (second-to-last column).
        cls (torch.Tensor) or (numpy.ndarray): The class values of the boxes (last column).
        xywh (torch.Tensor) or (numpy.ndarray): The boxes in xywh format.
        xyxyn (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format normalized by original image size.
        xywhn (torch.Tensor) or (numpy.ndarray): The boxes in xywh format normalized by original image size.
    """

    def __init__(self, boxes, orig_shape) -> None:
        if boxes.ndim == 1:
            boxes = boxes[None, :]  # single box -> (1, 6)
        assert boxes.shape[-1] == 6  # xyxy, conf, cls
        self.boxes = boxes
        self.orig_shape = torch.as_tensor(orig_shape, device=boxes.device) if isinstance(boxes, torch.Tensor) \
            else np.asarray(orig_shape)
        self._cache = {}  # per-instance memo for the derived box formats

    def _cached(self, key, compute):
        """Return the memoized value for key, computing it with compute() on first access.

        The memo lives on the instance, so it is released together with the object and
        is never shared with (or evicted by) other Boxes instances.
        """
        cache = self.__dict__.setdefault('_cache', {})
        if key not in cache:
            cache[key] = compute()
        return cache[key]

    @property
    def xyxy(self):
        return self.boxes[:, :4]

    @property
    def conf(self):
        return self.boxes[:, -2]

    @property
    def cls(self):
        return self.boxes[:, -1]

    @property
    def xywh(self):
        return self._cached('xywh', lambda: ops.xyxy2xywh(self.xyxy))

    @property
    def xyxyn(self):
        # orig_shape is (height, width); pick (w, h, w, h) to divide x1, y1, x2, y2
        return self._cached('xyxyn', lambda: self.xyxy / self.orig_shape[[1, 0, 1, 0]])

    @property
    def xywhn(self):
        # same (w, h, w, h) divisor applied to the xywh columns
        return self._cached('xywhn', lambda: self.xywh / self.orig_shape[[1, 0, 1, 0]])

    def cpu(self):
        """Return a new Boxes built from self.boxes.cpu()."""
        boxes = self.boxes.cpu()
        return Boxes(boxes, self.orig_shape)

    def numpy(self):
        """Return a new Boxes built from self.boxes.numpy()."""
        boxes = self.boxes.numpy()
        return Boxes(boxes, self.orig_shape)

    def cuda(self):
        """Return a new Boxes built from self.boxes.cuda()."""
        boxes = self.boxes.cuda()
        return Boxes(boxes, self.orig_shape)

    def to(self, *args, **kwargs):
        """Return a new Boxes built from self.boxes.to(*args, **kwargs)."""
        boxes = self.boxes.to(*args, **kwargs)
        return Boxes(boxes, self.orig_shape)

    def pandas(self):
        """Not implemented yet: logs a notice and returns None."""
        LOGGER.info('results.pandas() method not yet implemented')
        # TODO: build per-format DataFrames (xyxy/xyxyn/xywh/xywhn) with confidence, class and name columns

    @property
    def shape(self):
        return self.boxes.shape

    def __len__(self):  # override len(results)
        return len(self.boxes)

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return (f"Ultralytics YOLO {self.__class__} boxes\n" + f"type: {type(self.boxes)}\n" +
                f"shape: {self.boxes.shape}\n" + f"dtype: {self.boxes.dtype}")

    def __getitem__(self, idx):
        """Return a new Boxes holding self.boxes[idx] with the same original shape."""
        boxes = self.boxes[idx]
        return Boxes(boxes, self.orig_shape)
|
||||||
|
|
||||||
|
|
||||||
|
class Masks:
    """
    A class for storing and manipulating detection masks.

    Args:
        masks (torch.Tensor): A tensor containing the detection masks, with shape (num_masks, height, width).
        orig_shape (tuple): Original image size, in the format (height, width).

    Attributes:
        masks (torch.Tensor): A tensor containing the detection masks, with shape (num_masks, height, width).
        orig_shape (tuple): Original image size, in the format (height, width).

    Properties:
        segments (list): One entry per mask, obtained by passing each result of ops.masks2segments(masks)
            through ops.scale_segments(..., orig_shape, normalize=True). Computed once per instance.
    """

    def __init__(self, masks, orig_shape) -> None:
        self.masks = masks  # N, h, w
        self.orig_shape = orig_shape

    @property
    def segments(self):
        # Memoize on the instance so the result is released with the object and is
        # not evicted when another Masks instance computes its own segments.
        cache = self.__dict__.setdefault('_cache', {})
        if 'segments' not in cache:
            # NOTE(review): the list is built from reversed(masks2segments(...)), so entries
            # run in the opposite order to self.masks - confirm this is intended.
            cache['segments'] = [
                ops.scale_segments(self.masks.shape[1:], x, self.orig_shape, normalize=True)
                for x in reversed(ops.masks2segments(self.masks))]
        return cache['segments']

    @property
    def shape(self):
        return self.masks.shape

    def cpu(self):
        """Return a new Masks built from self.masks.cpu()."""
        masks = self.masks.cpu()
        return Masks(masks, self.orig_shape)

    def numpy(self):
        """Return a new Masks built from self.masks.numpy()."""
        masks = self.masks.numpy()
        return Masks(masks, self.orig_shape)

    def cuda(self):
        """Return a new Masks built from self.masks.cuda()."""
        masks = self.masks.cuda()
        return Masks(masks, self.orig_shape)

    def to(self, *args, **kwargs):
        """Return a new Masks built from self.masks.to(*args, **kwargs)."""
        masks = self.masks.to(*args, **kwargs)
        return Masks(masks, self.orig_shape)

    def __len__(self):  # override len(results)
        return len(self.masks)

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return (f"Ultralytics YOLO {self.__class__} masks\n" + f"type: {type(self.masks)}\n" +
                f"shape: {self.masks.shape}\n" + f"dtype: {self.masks.dtype}")

    def __getitem__(self, idx):
        """Return a new Masks holding self.masks[idx] with the same original shape."""
        masks = self.masks[idx]
        # The constructor takes (masks, orig_shape); there is no im_shape attribute on this class.
        return Masks(masks, self.orig_shape)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Smoke test: round-trip a Results object through each device/array conversion.
    results = Results(boxes=torch.randn((2, 6)), masks=torch.randn((2, 160, 160)), orig_shape=[640, 640])
    has_cuda = torch.cuda.is_available()  # skip the GPU steps on CPU-only machines instead of crashing
    if has_cuda:
        results = results.cuda()
        print("--cuda--pass--")
    results = results.cpu()
    print("--cpu--pass--")
    if has_cuda:
        results = results.to("cuda:0")
        print("--to-cuda--pass--")
    results = results.to("cpu")
    print("--to-cpu--pass--")
    results = results.numpy()
    print("--numpy--pass--")
    # box = Boxes(boxes=torch.randn((2, 6)), orig_shape=[5, 5])
    # box = box.cuda()
    # box = box.cpu()
    # box = box.numpy()
    # for b in box:
    #     print(b)
|
Loading…
Reference in new issue