|
|
@ -200,8 +200,7 @@ def non_max_suppression(
|
|
|
|
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
|
|
|
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
|
|
|
merge = False # use merge-NMS
|
|
|
|
merge = False # use merge-NMS
|
|
|
|
|
|
|
|
|
|
|
|
prediction = prediction.clone() # don't modify original
|
|
|
|
prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
|
|
|
|
prediction = prediction.transpose(-1, -2) # to (batch, boxes, items)
|
|
|
|
|
|
|
|
prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
|
|
|
|
prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
|
|
|
|
|
|
|
|
|
|
|
|
t = time.time()
|
|
|
|
t = time.time()
|
|
|
@ -245,7 +244,6 @@ def non_max_suppression(
|
|
|
|
n = x.shape[0] # number of boxes
|
|
|
|
n = x.shape[0] # number of boxes
|
|
|
|
if not n: # no boxes
|
|
|
|
if not n: # no boxes
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
if n > max_nms: # excess boxes
|
|
|
|
if n > max_nms: # excess boxes
|
|
|
|
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
|
|
|
|
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
|
|
|
|
|
|
|
|
|
|
|
|