Improve NMS speed (#3467)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
single_channel
Andy 1 year ago committed by GitHub
parent 586c95b8a9
commit 5e38c7d71b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -200,12 +200,16 @@ def non_max_suppression(
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
merge = False # use merge-NMS merge = False # use merge-NMS
prediction = prediction.clone() # don't modify original
prediction = prediction.transpose(-1, -2) # to (batch, boxes, items)
prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
t = time.time() t = time.time()
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
for xi, x in enumerate(prediction): # image index, image inference for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints # Apply constraints
# x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
x = x.transpose(0, -1)[xc[xi]] # confidence x = x[xc[xi]] # confidence
# Cat apriori labels if autolabelling # Cat apriori labels if autolabelling
if labels and len(labels[xi]): if labels and len(labels[xi]):
@ -221,9 +225,9 @@ def non_max_suppression(
# Detections matrix nx6 (xyxy, conf, cls) # Detections matrix nx6 (xyxy, conf, cls)
box, cls, mask = x.split((4, nc, nm), 1) box, cls, mask = x.split((4, nc, nm), 1)
box = xywh2xyxy(box) # center_x, center_y, width, height) to (x1, y1, x2, y2)
if multi_label: if multi_label:
i, j = (cls > conf_thres).nonzero(as_tuple=False).T i, j = torch.where(cls > conf_thres)
x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1) x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
else: # best class only else: # best class only
conf, j = cls.max(1, keepdim=True) conf, j = cls.max(1, keepdim=True)
@ -241,6 +245,8 @@ def non_max_suppression(
n = x.shape[0] # number of boxes n = x.shape[0] # number of boxes
if not n: # no boxes if not n: # no boxes
continue continue
if n > max_nms: # excess boxes
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
# Batched NMS # Batched NMS

Loading…
Cancel
Save