|
|
@ -34,19 +34,21 @@ def linear_assignment(cost_matrix, thresh, use_lap=True):
|
|
|
|
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
|
|
|
|
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
|
|
|
|
|
|
|
|
|
|
|
|
if use_lap:
|
|
|
|
if use_lap:
|
|
|
|
|
|
|
|
# https://github.com/gatagat/lap
|
|
|
|
_, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
|
|
|
|
_, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
|
|
|
|
matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
|
|
|
|
matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
|
|
|
|
unmatched_a = np.where(x < 0)[0]
|
|
|
|
unmatched_a = np.where(x < 0)[0]
|
|
|
|
unmatched_b = np.where(y < 0)[0]
|
|
|
|
unmatched_b = np.where(y < 0)[0]
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
# Scipy linear sum assignment is NOT working correctly, DO NOT USE
|
|
|
|
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html
|
|
|
|
y, x = scipy.optimize.linear_sum_assignment(cost_matrix) # row y, col x
|
|
|
|
x, y = scipy.optimize.linear_sum_assignment(cost_matrix) # row x, col y
|
|
|
|
matches = np.asarray([[i, x] for i, x in enumerate(x) if cost_matrix[i, x] <= thresh])
|
|
|
|
matches = np.asarray([[x[i], y[i]] for i in range(len(x)) if cost_matrix[x[i], y[i]] <= thresh])
|
|
|
|
unmatched = np.ones(cost_matrix.shape)
|
|
|
|
if len(matches) == 0:
|
|
|
|
for i, xi in matches:
|
|
|
|
unmatched_a = list(np.arange(cost_matrix.shape[0]))
|
|
|
|
unmatched[i, xi] = 0.0
|
|
|
|
unmatched_b = list(np.arange(cost_matrix.shape[1]))
|
|
|
|
unmatched_a = np.where(unmatched.all(1))[0]
|
|
|
|
else:
|
|
|
|
unmatched_b = np.where(unmatched.all(0))[0]
|
|
|
|
unmatched_a = list(set(np.arange(cost_matrix.shape[0])) - set(matches[:, 0]))
|
|
|
|
|
|
|
|
unmatched_b = list(set(np.arange(cost_matrix.shape[1])) - set(matches[:, 1]))
|
|
|
|
|
|
|
|
|
|
|
|
return matches, unmatched_a, unmatched_b
|
|
|
|
return matches, unmatched_a, unmatched_b
|
|
|
|
|
|
|
|
|
|
|
|