New YOLOv8 Results() class for prediction outputs (#314)

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Viet Nhat Thai <60825385+vietnhatthai@users.noreply.github.com>
Co-authored-by: Paula Derrenger <107626595+pderrenger@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2023-01-17 19:02:34 +05:30
committed by GitHub
parent 0cb87f7dd3
commit c6985da9de
32 changed files with 813 additions and 259 deletions

View File

@ -5,7 +5,6 @@ import inspect
import logging.config
import os
import platform
import subprocess
import sys
import tempfile
import threading
@ -13,6 +12,7 @@ import uuid
from pathlib import Path
import cv2
import git
import numpy as np
import pandas as pd
import torch
@ -134,10 +134,8 @@ def is_git_directory() -> bool:
Returns:
bool: True if the current working directory is inside a git repository, False otherwise.
"""
import git
try:
from git import Repo
Repo(search_parent_directories=True)
git.Repo(search_parent_directories=True)
# subprocess.run(["git", "rev-parse", "--git-dir"], capture_output=True, check=True) # CLI alternative
return True
except git.exc.InvalidGitRepositoryError: # subprocess.CalledProcessError:
@ -187,9 +185,10 @@ def get_git_root_dir():
If the current file is not part of a git repository, returns None.
"""
try:
output = subprocess.run(["git", "rev-parse", "--git-dir"], capture_output=True, check=True)
return Path(output.stdout.strip().decode('utf-8')).parent.resolve() # parent/.git
except subprocess.CalledProcessError:
# output = subprocess.run(["git", "rev-parse", "--git-dir"], capture_output=True, check=True)
# return Path(output.stdout.strip().decode('utf-8')).parent.resolve() # CLI alternative
return Path(git.Repo(search_parent_directories=True).working_tree_dir)
except git.exc.InvalidGitRepositoryError: # (subprocess.CalledProcessError, FileNotFoundError):
return None

View File

@ -15,20 +15,39 @@ from .metrics import box_iou
class Profile(contextlib.ContextDecorator):
# YOLOv8 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
"""
YOLOv8 Profile class.
Usage: as a decorator with @Profile() or as a context manager with 'with Profile():'
"""
def __init__(self, t=0.0):
"""
Initialize the Profile class.
Args:
t (float): Initial time. Defaults to 0.0.
"""
self.t = t
self.cuda = torch.cuda.is_available()
def __enter__(self):
"""
Start timing.
"""
self.start = self.time()
return self
def __exit__(self, type, value, traceback):
"""
Stop timing.
"""
self.dt = self.time() - self.start # delta-time
self.t += self.dt # accumulate dt
def time(self):
"""
Get current time.
"""
if self.cuda:
torch.cuda.synchronize()
return time.time()
@ -48,15 +67,15 @@ def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
def segment2box(segment, width=640, height=640):
"""
Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to
(xyxy)
Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
Args:
segment (torch.tensor): the segment label
segment (torch.Tensor): the segment label
width (int): the width of the image. Defaults to 640
height (int): The height of the image. Defaults to 640
Returns:
(np.array): the minimum and maximum x and y values of the segment.
(np.ndarray): the minimum and maximum x and y values of the segment.
"""
# Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
x, y = segment.T # segment xy
@ -67,15 +86,18 @@ def segment2box(segment, width=640, height=640):
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
"""
Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in (img1_shape) to the shape of a different image (img0_shape).
Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in
(img1_shape) to the shape of a different image (img0_shape).
Args:
img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
boxes (torch.tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
img0_shape (tuple): the shape of the target image, in the format of (height, width).
ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be calculated based on the size difference between the two images.
ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
calculated based on the size difference between the two images.
Returns:
boxes (torch.tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
"""
if ratio_pad is None: # calculate from img0_shape
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
@ -92,7 +114,16 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
def make_divisible(x, divisor):
# Returns nearest x divisible by divisor
"""
Returns the nearest number that is divisible by the given divisor.
Args:
x (int): The number to make divisible.
divisor (int or torch.Tensor): The divisor.
Returns:
int: The nearest number divisible by the divisor.
"""
if isinstance(divisor, torch.Tensor):
divisor = int(divisor.max()) # to int
return math.ceil(x / divisor) * divisor
@ -232,7 +263,7 @@ def clip_boxes(boxes, shape):
shape
Args:
boxes (torch.tensor): the bounding boxes to clip
boxes (torch.Tensor): the bounding boxes to clip
shape (tuple): the shape of the image
"""
if isinstance(boxes, torch.Tensor): # faster individually
@ -246,7 +277,19 @@ def clip_boxes(boxes, shape):
def clip_coords(boxes, shape):
# Clip bounding xyxy bounding boxes to image shape (height, width)
"""
Clip bounding xyxy bounding boxes to image shape (height, width).
Args:
boxes (torch.Tensor or numpy.ndarray): Bounding boxes to be clipped.
shape (tuple): The shape of the image. (height, width)
Returns:
None
Note:
The input `boxes` is modified in-place, there is no return value.
"""
if isinstance(boxes, torch.Tensor): # faster individually
boxes[:, 0].clamp_(0, shape[1]) # x1
boxes[:, 1].clamp_(0, shape[0]) # y1
@ -263,12 +306,12 @@ def scale_image(im1_shape, masks, im0_shape, ratio_pad=None):
Args:
im1_shape (tuple): model input shape, [h, w]
masks (torch.tensor): [h, w, num]
masks (torch.Tensor): [h, w, num]
im0_shape (tuple): the original image shape
ratio_pad (tuple): the ratio of the padding to the original image.
Returns:
masks (torch.tensor): The masks that are being returned.
masks (torch.Tensor): The masks that are being returned.
"""
# Rescale coordinates (xyxy) from im1_shape to im0_shape
if ratio_pad is None: # calculate from im0_shape
@ -297,9 +340,9 @@ def xyxy2xywh(x):
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format.
Args:
x (np.ndarray) or (torch.Tensor): The input tensor containing the bounding box coordinates in (x1, y1, x2, y2) format.
x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
Returns:
y (numpy.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
@ -311,12 +354,13 @@ def xyxy2xywh(x):
def xywh2xyxy(x):
"""
Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner.
Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
top-left corner and (x2, y2) is the bottom-right corner.
Args:
x (np.ndarray) or (torch.Tensor): The input tensor containing the bounding box coordinates in (x, y, width, height) format.
x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.
Returns:
y (numpy.ndarray) or (torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
@ -337,7 +381,8 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
padw (int): Padding width. Defaults to 0
padh (int): Padding height. Defaults to 0
Returns:
y (numpy.ndarray) or (torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
y (np.ndarray) or (torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
@ -349,16 +394,17 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
"""
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y, width and height are normalized to image dimensions
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format.
x, y, width and height are normalized to image dimensions
Args:
x (np.ndarray) or (torch.Tensor): The input tensor containing the bounding box coordinates in (x1, y1, x2, y2) format.
x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
w (int): The width of the image. Defaults to 640
h (int): The height of the image. Defaults to 640
clip (bool): If True, the boxes will be clipped to the image boundaries. Defaults to False
eps (float): The minimum value of the box's width and height. Defaults to 0.0
Returns:
y (numpy.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
"""
if clip:
clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
@ -375,13 +421,13 @@ def xyn2xy(x, w=640, h=640, padw=0, padh=0):
Convert normalized coordinates to pixel coordinates of shape (n,2)
Args:
x (numpy.ndarray) or (torch.Tensor): The input tensor of normalized bounding box coordinates
x (np.ndarray) or (torch.Tensor): The input tensor of normalized bounding box coordinates
w (int): The width of the image. Defaults to 640
h (int): The height of the image. Defaults to 640
padw (int): The width of the padding. Defaults to 0
padh (int): The height of the padding. Defaults to 0
Returns:
y (numpy.ndarray) or (torch.Tensor): The x and y coordinates of the top left corner of the bounding box
y (np.ndarray) or (torch.Tensor): The x and y coordinates of the top left corner of the bounding box
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[..., 0] = w * x[..., 0] + padw # top left x
@ -394,9 +440,9 @@ def xywh2ltwh(x):
Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates.
Args:
x (numpy.ndarray) or (torch.Tensor): The input tensor with the bounding box coordinates in the xywh format
x (np.ndarray) or (torch.Tensor): The input tensor with the bounding box coordinates in the xywh format
Returns:
y (numpy.ndarray) or (torch.Tensor): The bounding box coordinates in the xyltwh format
y (np.ndarray) or (torch.Tensor): The bounding box coordinates in the xyltwh format
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
@ -409,9 +455,9 @@ def xyxy2ltwh(x):
Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right
Args:
x (numpy.ndarray) or (torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
x (np.ndarray) or (torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
Returns:
y (numpy.ndarray) or (torch.Tensor): The bounding box coordinates in the xyltwh format.
y (np.ndarray) or (torch.Tensor): The bounding box coordinates in the xyltwh format.
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 2] = x[:, 2] - x[:, 0] # width
@ -424,7 +470,7 @@ def ltwh2xywh(x):
Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center
Args:
x (torch.tensor): the input tensor
x (torch.Tensor): the input tensor
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = x[:, 0] + x[:, 2] / 2 # center x
@ -437,10 +483,10 @@ def ltwh2xyxy(x):
It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
Args:
x (numpy.ndarray) or (torch.Tensor): the input image
x (np.ndarray) or (torch.Tensor): the input image
Returns:
y (numpy.ndarray) or (torch.Tensor): the xyxy coordinates of the bounding boxes.
y (np.ndarray) or (torch.Tensor): the xyxy coordinates of the bounding boxes.
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 2] = x[:, 2] + x[:, 0] # width
@ -456,7 +502,7 @@ def segments2boxes(segments):
segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates
Returns:
(np.array): the xywh coordinates of the bounding boxes.
(np.ndarray): the xywh coordinates of the bounding boxes.
"""
boxes = []
for s in segments:
@ -467,7 +513,7 @@ def segments2boxes(segments):
def resample_segments(segments, n=1000):
"""
It takes a list of segments (n,2) and returns a list of segments (n,2) where each segment has been up-sampled to n points
Inputs a list of segments (n,2) and returns a list of segments (n,2) up-sampled to n points each.
Args:
segments (list): a list of (n,2) arrays, where n is the number of points in the segment.
@ -489,11 +535,11 @@ def crop_mask(masks, boxes):
It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box
Args:
masks (torch.tensor): [h, w, n] tensor of masks
boxes (torch.tensor): [n, 4] tensor of bbox coordinates in relative point form
masks (torch.Tensor): [h, w, n] tensor of masks
boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form
Returns:
(torch.tensor): The masks are being cropped to the bounding box.
(torch.Tensor): The masks are being cropped to the bounding box.
"""
n, h, w = masks.shape
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(1,1,n)
@ -509,13 +555,13 @@ def process_mask_upsample(protos, masks_in, bboxes, shape):
quality but is slower.
Args:
protos (torch.tensor): [mask_dim, mask_h, mask_w]
masks_in (torch.tensor): [n, mask_dim], n is number of masks after nms
bboxes (torch.tensor): [n, 4], n is number of masks after nms
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
bboxes (torch.Tensor): [n, 4], n is number of masks after nms
shape (tuple): the size of the input image (h,w)
Returns:
(torch.tensor): The upsampled masks.
(torch.Tensor): The upsampled masks.
"""
c, mh, mw = protos.shape # CHW
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
@ -530,13 +576,13 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
downsampled quality of mask
Args:
protos (torch.tensor): [mask_dim, mask_h, mask_w]
masks_in (torch.tensor): [n, mask_dim], n is number of masks after nms
bboxes (torch.tensor): [n, 4], n is number of masks after nms
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
bboxes (torch.Tensor): [n, 4], n is number of masks after nms
shape (tuple): the size of the input image (h,w)
Returns:
(torch.tensor): The processed masks.
(torch.Tensor): The processed masks.
"""
c, mh, mw = protos.shape # CHW
@ -560,13 +606,13 @@ def process_mask_native(protos, masks_in, bboxes, shape):
It takes the output of the mask head, and crops it after upsampling to the bounding boxes.
Args:
protos (torch.tensor): [mask_dim, mask_h, mask_w]
masks_in (torch.tensor): [n, mask_dim], n is number of masks after nms
bboxes (torch.tensor): [n, 4], n is number of masks after nms
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
bboxes (torch.Tensor): [n, 4], n is number of masks after nms
shape (tuple): the size of the input image (h,w)
Returns:
masks (torch.tensor): The returned masks with dimensions [h, w, n]
masks (torch.Tensor): The returned masks with dimensions [h, w, n]
"""
c, mh, mw = protos.shape # CHW
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
@ -587,13 +633,13 @@ def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=F
Args:
img1_shape (tuple): The shape of the image that the segments are from.
segments (torch.tensor): the segments to be scaled
segments (torch.Tensor): the segments to be scaled
img0_shape (tuple): the shape of the image that the segmentation is being applied to
ratio_pad (tuple): the ratio of the image size to the padded image size.
normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False
Returns:
segments (torch.tensor): the segmented image.
segments (torch.Tensor): the segmented image.
"""
if ratio_pad is None: # calculate from img0_shape
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
@ -617,7 +663,7 @@ def masks2segments(masks, strategy='largest'):
It takes a list of masks(n,h,w) and returns a list of segments(n,xy)
Args:
masks (torch.tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160)
masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160)
strategy (str): 'concat' or 'largest'. Defaults to largest
Returns: