|
|
|
@ -200,12 +200,16 @@ def non_max_suppression(
|
|
|
|
|
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
|
|
|
|
merge = False # use merge-NMS
|
|
|
|
|
|
|
|
|
|
prediction = prediction.clone() # don't modify original
|
|
|
|
|
prediction = prediction.transpose(-1, -2) # to (batch, boxes, items)
|
|
|
|
|
prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
|
|
|
|
|
|
|
|
|
|
t = time.time()
|
|
|
|
|
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
|
|
|
|
|
for xi, x in enumerate(prediction): # image index, image inference
|
|
|
|
|
# Apply constraints
|
|
|
|
|
# x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
|
|
|
|
|
x = x.transpose(0, -1)[xc[xi]] # confidence
|
|
|
|
|
x = x[xc[xi]] # confidence
|
|
|
|
|
|
|
|
|
|
# Cat apriori labels if autolabelling
|
|
|
|
|
if labels and len(labels[xi]):
|
|
|
|
@ -221,9 +225,9 @@ def non_max_suppression(
|
|
|
|
|
|
|
|
|
|
# Detections matrix nx6 (xyxy, conf, cls)
|
|
|
|
|
box, cls, mask = x.split((4, nc, nm), 1)
|
|
|
|
|
box = xywh2xyxy(box) # center_x, center_y, width, height) to (x1, y1, x2, y2)
|
|
|
|
|
|
|
|
|
|
if multi_label:
|
|
|
|
|
i, j = (cls > conf_thres).nonzero(as_tuple=False).T
|
|
|
|
|
i, j = torch.where(cls > conf_thres)
|
|
|
|
|
x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
|
|
|
|
|
else: # best class only
|
|
|
|
|
conf, j = cls.max(1, keepdim=True)
|
|
|
|
@ -241,7 +245,9 @@ def non_max_suppression(
|
|
|
|
|
n = x.shape[0] # number of boxes
|
|
|
|
|
if not n: # no boxes
|
|
|
|
|
continue
|
|
|
|
|
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
|
|
|
|
|
|
|
|
|
|
if n > max_nms: # excess boxes
|
|
|
|
|
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
|
|
|
|
|
|
|
|
|
|
# Batched NMS
|
|
|
|
|
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
|
|
|
|