ultralytics 8.0.52
reduced TAL CUDA usage and AMP check fix (#1333)
Co-authored-by: CNH5 <74132034+CNH5@users.noreply.github.com> Co-authored-by: Huijae Lee <46982469+ZeroAct@users.noreply.github.com> Co-authored-by: Lorenzo Mammana <lorenzom96@hotmail.it> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Hardik Dava <39372750+hardikdava@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
@ -36,18 +36,26 @@ def _indices_to_matches(cost_matrix, indices, thresh):
|
||||
return matches, unmatched_a, unmatched_b
|
||||
|
||||
|
||||
def linear_assignment(cost_matrix, thresh):
|
||||
def linear_assignment(cost_matrix, thresh, use_lap=True):
|
||||
# Linear assignment implementations with scipy and lap.lapjv
|
||||
if cost_matrix.size == 0:
|
||||
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
|
||||
matches, unmatched_a, unmatched_b = [], [], []
|
||||
|
||||
# TODO: investigate scipy.optimize.linear_sum_assignment() for lap.lapjv()
|
||||
cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
|
||||
if use_lap:
|
||||
_, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
|
||||
matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
|
||||
unmatched_a = np.where(x < 0)[0]
|
||||
unmatched_b = np.where(y < 0)[0]
|
||||
else:
|
||||
# Scipy linear sum assignment is NOT working correctly, DO NOT USE
|
||||
y, x = scipy.optimize.linear_sum_assignment(cost_matrix) # row y, col x
|
||||
matches = np.asarray([[i, x] for i, x in enumerate(x) if cost_matrix[i, x] <= thresh])
|
||||
unmatched = np.ones(cost_matrix.shape)
|
||||
for i, xi in matches:
|
||||
unmatched[i, xi] = 0.0
|
||||
unmatched_a = np.where(unmatched.all(1))[0]
|
||||
unmatched_b = np.where(unmatched.all(0))[0]
|
||||
|
||||
matches.extend([ix, mx] for ix, mx in enumerate(x) if mx >= 0)
|
||||
unmatched_a = np.where(x < 0)[0]
|
||||
unmatched_b = np.where(y < 0)[0]
|
||||
matches = np.asarray(matches)
|
||||
return matches, unmatched_a, unmatched_b
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user