You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
307 lines
11 KiB
307 lines
11 KiB
import contextlib
|
|
import math
|
|
import time
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
import torchvision
|
|
|
|
from ultralytics.yolo.utils import LOGGER
|
|
|
|
from .metrics import box_iou
|
|
|
|
|
|
class Profile(contextlib.ContextDecorator):
    """Accumulating timer. Usage: ``@Profile()`` decorator or ``with Profile():`` context manager.

    Attributes:
        t: total seconds accumulated over every timed section (starts at the ``t`` constructor argument).
        dt: duration in seconds of the most recently finished section (set on exit).
        cuda: True when CUDA is available; ``time()`` then calls ``torch.cuda.synchronize()`` before
            reading the clock so queued GPU work is included in the measurement.
    """

    def __init__(self, t=0.0):
        self.t = t
        self.cuda = torch.cuda.is_available()

    def __enter__(self):
        self.start = self.time()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Returning None lets any exception raised inside the timed section propagate.
        self.dt = self.time() - self.start  # delta-time
        self.t += self.dt  # accumulate dt

    def time(self):
        """Return a timestamp in seconds suitable for measuring intervals (not a wall-clock time)."""
        if self.cuda:
            torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution; time.time() is wall-clock and can jump
        # backwards/forwards when the system clock is adjusted, corrupting measured durations.
        return time.perf_counter()
|
|
|
|
|
|
def segment2box(segment, width=640, height=640):
    """Convert one segment label to one box label, keeping only points inside the image.

    Args:
        segment: (n, 2) array of xy points.
        width: image width; points with x < 0 or x > width are dropped.
        height: image height; points with y < 0 or y > height are dropped.

    Returns:
        np.ndarray of shape (4,) in xyxy format built from the surviving points, or ``np.zeros(4)``
        when no point lies inside the image.
    """
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    # Test emptiness by size: `any(x)` is False when the surviving points all lie on x == 0,
    # which wrongly discarded valid segments touching the left image border.
    if not x.size:
        return np.zeros(4)
    return np.array([x.min(), y.min(), x.max(), y.max()])  # xyxy
|
|
|
|
|
|
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
    """Rescale xyxy boxes from ``img1_shape`` back to ``img0_shape``, in place.

    The padding is subtracted, then the scale gain is divided out, and finally the boxes are
    clipped to ``img0_shape``.

    Args:
        img1_shape: (h, w) of the image the boxes currently refer to.
        boxes: box array/tensor whose first four columns are x1, y1, x2, y2; modified in place.
        img0_shape: (h, w, ...) of the target image.
        ratio_pad: optional ``(ratio, pad)``; ``ratio[0]`` is used as the gain and ``pad`` as (pad_w, pad_h).
            When omitted, both are derived from the two shapes (gain = old / new).

    Returns:
        The same ``boxes`` object, rescaled and clipped.
    """
    if ratio_pad is not None:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]
    else:
        h1, w1 = img1_shape[0], img1_shape[1]
        h0, w0 = img0_shape[0], img0_shape[1]
        gain = min(h1 / h0, w1 / w0)  # gain = old / new
        pad = (w1 - w0 * gain) / 2, (h1 - h0 * gain) / 2  # wh padding

    boxes[:, [0, 2]] -= pad[0]  # undo horizontal padding
    boxes[:, [1, 3]] -= pad[1]  # undo vertical padding
    boxes[:, :4] /= gain
    clip_boxes(boxes, img0_shape)
    return boxes
|
|
|
|
|
|
def clip_boxes(boxes, shape):
    """Clip xyxy boxes in place to an image of shape (height, width).

    Torch tensors are clamped one column at a time with in-place ``clamp_``; any other array type
    (NumPy) has its x and y column pairs clipped and written back.

    Args:
        boxes: torch.Tensor or np.ndarray whose first four columns are x1, y1, x2, y2.
        shape: image shape; ``shape[0]`` is the height and ``shape[1]`` the width.

    Returns:
        None — ``boxes`` is modified in place.
    """
    height, width = shape[0], shape[1]
    if not isinstance(boxes, torch.Tensor):  # np.array (faster grouped)
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, width)  # x1, x2
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, height)  # y1, y2
        return
    # torch.Tensor: clamping each column individually is faster
    for col, upper in ((0, width), (1, height), (2, width), (3, height)):
        boxes[:, col].clamp_(0, upper)
|
|
|
|
|
|
def make_divisible(x, divisor):
    """Return the smallest multiple of ``divisor`` that is greater than or equal to ``x``.

    The result is rounded up (ceil), never down. A ``torch.Tensor`` divisor is first reduced to
    ``int`` of its largest element.
    """
    step = int(divisor.max()) if isinstance(divisor, torch.Tensor) else divisor
    return math.ceil(x / step) * step
|
|
|
|
|
|
def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nm=0,  # number of masks
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections.

    Args:
        prediction: model output. A list/tuple is reduced to its first element. The indexing below
            implies a (batch, candidates, 5 + nc + nm) tensor laid out as
            [box(4), obj_conf, class scores (nc), mask coefficients (nm)]; boxes go through
            ``xywh2xyxy``, so they are expected in xywh form.
        conf_thres: confidence threshold in [0, 1]; only scores strictly above it are kept.
        iou_thres: IoU threshold in [0, 1] passed to ``torchvision.ops.nms``.
        classes: optional iterable of class indices to keep.
        agnostic: if True, boxes of different classes suppress each other.
        multi_label: allow several labels per box (only takes effect when nc > 1).
        labels: optional per-image a-priori labels (autolabelling). Each row is read as [cls, box(4)],
            and the box is converted with ``xywh2xyxy`` like the predictions.
        max_det: maximum number of detections kept per image.
        nm: number of trailing mask-coefficient columns.

    Returns:
        List with one entry per image: an (n, 6 + nm) tensor [xyxy, conf, cls, mask coefficients],
        on the same device as ``prediction``.

    Raises:
        AssertionError: if ``conf_thres`` or ``iou_thres`` is outside [0, 1].
    """
    # Checks (first, so invalid thresholds fail before any work is done)
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - nm - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    max_wh = 7680  # (pixels) maximum box width and height; per-class coordinate offset for batched NMS
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.5 + 0.05 * bs  # seconds to quit after
    redundant = True  # require redundant detections (merge-NMS only)
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    mi = 5 + nc  # mask start index
    # Allocate placeholders on the caller's original device. With MPS, `prediction` was moved to the CPU
    # above, so `prediction.device` would return images without detections on the CPU while images with
    # detections are moved back to MPS below.
    output = [torch.zeros((0, 6 + nm), device=device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf = obj_conf * cls_conf. Only the class columns are scaled: the mask coefficients
        # starting at `mi` are returned to the caller and must stay as the model produced them.
        x[:, 5:mi] *= x[:, 4:5]

        # Box/Mask
        box = xywh2xyxy(x[:, :4])  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
        mask = x[:, mi:]  # zero columns if no masks

        # Detections matrix n x (6 + nm): (xyxy, conf, cls, mask coefficients)
        if multi_label:
            i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = x[:, 5:mi].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence
        else:
            x = x[x[:, 4].argsort(descending=True)]  # sort by confidence

        # Batched NMS: offsetting boxes by class * max_wh keeps different classes from overlapping
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output
|
|
|
|
|
|
def clip_coords(boxes, shape):
    """Clip xyxy bounding boxes in place to an image of shape (height, width).

    Kept for backward compatibility. Its behavior is identical to ``clip_boxes``, so it delegates
    there instead of keeping a second copy of the same logic that could drift out of sync.

    Args:
        boxes: torch.Tensor or np.ndarray whose first four columns are x1, y1, x2, y2.
        shape: image shape; ``shape[0]`` is the height and ``shape[1]`` the width.

    Returns:
        None — ``boxes`` is modified in place.
    """
    clip_boxes(boxes, shape)
|
|
|
|
|
|
def scale_image(im1_shape, masks, im0_shape, ratio_pad=None):
    """Crop the padding off ``masks`` and resize them to the original image size.

    Args:
        im1_shape: model input shape, [h, w].
        masks: mask array, [h, w] or [h, w, num].
        im0_shape: original image shape, [h, w, 3].
        ratio_pad: optional ``(ratio, pad)``; only ``pad`` = (pad_w, pad_h) is used. When omitted,
            the padding is derived from the two shapes.

    Returns:
        The masks resized to ``im0_shape[:2]``, always with a trailing channel axis.

    Raises:
        ValueError: if ``masks`` has fewer than 2 dimensions.
    """
    if ratio_pad is not None:
        pad = ratio_pad[1]
    else:  # calculate from im0_shape
        gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])  # gain = old / new
        pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2  # wh padding
    pad_w, pad_h = pad[0], pad[1]
    top, left = int(pad_h), int(pad_w)  # y, x
    bottom, right = int(im1_shape[0] - pad_h), int(im1_shape[1] - pad_w)

    if len(masks.shape) < 2:
        raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
    cropped = masks[top:bottom, left:right]
    resized = cv2.resize(cropped, (im0_shape[1], im0_shape[0]))  # cv2 dsize is (width, height)

    # cv2.resize drops a singleton channel axis (and 2-D input has none); restore it
    if len(resized.shape) == 2:
        resized = resized[:, :, None]
    return resized
|
|
|
|
|
|
def xyxy2xywh(x):
    """Convert boxes from [x1, y1, x2, y2] to [x, y, w, h], where xy1=top-left, xy2=bottom-right, xy=center.

    Args:
        x: torch.Tensor or numpy array whose last axis holds the 4 box values, e.g. (n, 4), (4,) or (b, n, 4).

    Returns:
        A converted copy of ``x`` (``clone`` for tensors, ``np.copy`` otherwise); the input is not modified.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # `...` indexes the last axis, so single boxes and batched stacks work as well as (n, 4).
    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # y center
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y
|
|
|
|
|
|
def xywh2xyxy(x):
    """Convert boxes from [x, y, w, h] to [x1, y1, x2, y2], where xy=center, xy1=top-left, xy2=bottom-right.

    Args:
        x: torch.Tensor or numpy array whose last axis holds the 4 box values, e.g. (n, 4), (4,) or (b, n, 4).

    Returns:
        A converted copy of ``x`` (``clone`` for tensors, ``np.copy`` otherwise); the input is not modified.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # `...` indexes the last axis, so single boxes and batched stacks work as well as (n, 4).
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
    return y
|
|
|
|
|
|
def xywh2ltwh(x):
    """Convert boxes from [x, y, w, h] to [x1, y1, w, h], where xy=center and xy1=top-left.

    Args:
        x: torch.Tensor or numpy array whose last axis holds the 4 box values, e.g. (n, 4), (4,) or (b, n, 4).

    Returns:
        A converted copy of ``x`` (``clone`` for tensors, ``np.copy`` otherwise); w and h are unchanged.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # `...` indexes the last axis, so single boxes and batched stacks work as well as (n, 4).
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    return y
|
|
|
|
|
|
def xyxy2ltwh(x):
    """Convert boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left and xy2=bottom-right.

    Args:
        x: torch.Tensor or numpy array whose last axis holds the 4 box values, e.g. (n, 4), (4,) or (b, n, 4).

    Returns:
        A converted copy of ``x`` (``clone`` for tensors, ``np.copy`` otherwise); x1 and y1 are unchanged.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # `...` indexes the last axis, so single boxes and batched stacks work as well as (n, 4).
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y
|
|
|
|
|
|
def ltwh2xywh(x):
    """Convert boxes from [x1, y1, w, h] to [x, y, w, h], where xy1=top-left and xy=center.

    Args:
        x: torch.Tensor or numpy array whose last axis holds the 4 box values, e.g. (n, 4), (4,) or (b, n, 4).

    Returns:
        A converted copy of ``x`` (``clone`` for tensors, ``np.copy`` otherwise); w and h are unchanged.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # `...` indexes the last axis, so single boxes and batched stacks work as well as (n, 4).
    y[..., 0] = x[..., 0] + x[..., 2] / 2  # center x
    y[..., 1] = x[..., 1] + x[..., 3] / 2  # center y
    return y
|
|
|
|
|
|
def ltwh2xyxy(x):
    """Convert boxes from [x1, y1, w, h] to [x1, y1, x2, y2], where xy1=top-left and xy2=bottom-right.

    Args:
        x: torch.Tensor or numpy array whose last axis holds the 4 box values, e.g. (n, 4), (4,) or (b, n, 4).

    Returns:
        A converted copy of ``x`` (``clone`` for tensors, ``np.copy`` otherwise); x1 and y1 are unchanged.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # `...` indexes the last axis, so single boxes and batched stacks work as well as (n, 4).
    y[..., 2] = x[..., 2] + x[..., 0]  # bottom right x = x1 + w
    y[..., 3] = x[..., 3] + x[..., 1]  # bottom right y = y1 + h
    return y
|
|
|
|
|
|
def segments2boxes(segments):
    """Convert segment labels to box labels, i.e. (xy1, xy2, ...) to (xywh).

    Only xy points are handled here; any class ids must be kept by the caller.

    Args:
        segments: iterable of (n, 2) arrays of xy points, one per object.

    Returns:
        np.ndarray of shape (len(segments), 4) with one xywh box (tight bounds of the points) per
        segment, or an empty (0, 4) array when ``segments`` is empty.
    """
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # xyxy
    # reshape keeps an empty input 2-D (0, 4) instead of 1-D (0,), which raised IndexError in the
    # column indexing of xyxy2xywh
    return xyxy2xywh(np.array(boxes).reshape(-1, 4))  # xywh
|
|
|
|
|
|
def resample_segments(segments, n=1000):
    """Resample every segment polygon to exactly ``n`` points by linear interpolation.

    Each polygon is first closed by appending its first point, so the samples run around the whole
    outline. ``segments`` is modified in place and also returned.

    Args:
        segments: list of (m, 2) arrays of xy points.
        n: number of points in each output segment.

    Returns:
        The same list, with each entry replaced by an (n, 2) array of interpolated xy points.
    """
    for idx, seg in enumerate(segments):
        closed = np.concatenate((seg, seg[0:1, :]), axis=0)
        knots = np.arange(len(closed))
        samples = np.linspace(0, len(closed) - 1, n)
        xs = np.interp(samples, knots, closed[:, 0])
        ys = np.interp(samples, knots, closed[:, 1])
        segments[idx] = np.stack((xs, ys)).T  # segment xy
    return segments
|