ultralytics 8.0.136
refactor and simplify package (#3748)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
0	ultralytics/trackers/utils/__init__.py	Normal file
319	ultralytics/trackers/utils/gmc.py	Normal file
@@ -0,0 +1,319 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import copy

import cv2
import numpy as np

from ultralytics.utils import LOGGER


class GMC:

    def __init__(self, method='sparseOptFlow', downscale=2, verbose=None):
        """Initialize a video tracker with specified parameters."""
        super().__init__()

        self.method = method
        self.downscale = max(1, int(downscale))

        if self.method == 'orb':
            self.detector = cv2.FastFeatureDetector_create(20)
            self.extractor = cv2.ORB_create()
            self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)

        elif self.method == 'sift':
            self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
            self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
            self.matcher = cv2.BFMatcher(cv2.NORM_L2)

        elif self.method == 'ecc':
            number_of_iterations = 5000
            termination_eps = 1e-6
            self.warp_mode = cv2.MOTION_EUCLIDEAN
            self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)

        elif self.method == 'sparseOptFlow':
            self.feature_params = dict(maxCorners=1000,
                                       qualityLevel=0.01,
                                       minDistance=1,
                                       blockSize=3,
                                       useHarrisDetector=False,
                                       k=0.04)
            # self.gmc_file = open('GMC_results.txt', 'w')

        elif self.method in ['file', 'files']:
            seqName = verbose[0]
            ablation = verbose[1]
            if ablation:
                filePath = r'tracker/GMC_files/MOT17_ablation'
            else:
                filePath = r'tracker/GMC_files/MOTChallenge'

            if '-FRCNN' in seqName:
                seqName = seqName[:-6]
            elif '-DPM' in seqName or '-SDP' in seqName:
                seqName = seqName[:-4]
            self.gmcFile = open(f'{filePath}/GMC-{seqName}.txt')

            if self.gmcFile is None:
                raise ValueError(f'Error: Unable to open GMC file in directory:{filePath}')
        elif self.method in ['none', 'None']:
            self.method = 'none'
        else:
            raise ValueError(f'Error: Unknown CMC method:{method}')

        self.prevFrame = None
        self.prevKeyPoints = None
        self.prevDescriptors = None

        self.initializedFirstFrame = False

    def apply(self, raw_frame, detections=None):
        """Apply the chosen global motion compensation method to a raw frame and return a 2x3 warp matrix."""
        if self.method in ['orb', 'sift']:
            return self.applyFeatures(raw_frame, detections)
        elif self.method == 'ecc':
            return self.applyEcc(raw_frame, detections)
        elif self.method == 'sparseOptFlow':
            return self.applySparseOptFlow(raw_frame, detections)
        elif self.method == 'file':
            return self.applyFile(raw_frame, detections)
        elif self.method == 'none':
            return np.eye(2, 3)
        else:
            return np.eye(2, 3)

    def applyEcc(self, raw_frame, detections=None):
        """Estimate the warp matrix for a raw frame using the ECC algorithm."""
        height, width, _ = raw_frame.shape
        frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
        H = np.eye(2, 3, dtype=np.float32)

        # Downscale image (TODO: consider using pyramids)
        if self.downscale > 1.0:
            frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
            frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
            width = width // self.downscale
            height = height // self.downscale

        # Handle first frame
        if not self.initializedFirstFrame:
            # Initialize data
            self.prevFrame = frame.copy()

            # Initialization done
            self.initializedFirstFrame = True

            return H

        # Run the ECC algorithm. The results are stored in warp_matrix.
        # (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
        try:
            (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
        except Exception as e:
            LOGGER.warning(f'WARNING: find transform failed. Set warp as identity {e}')

        return H

    def applyFeatures(self, raw_frame, detections=None):
        """Estimate the warp matrix for a raw frame using ORB/SIFT feature matching."""
        height, width, _ = raw_frame.shape
        frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
        H = np.eye(2, 3)

        # Downscale image (TODO: consider using pyramids)
        if self.downscale > 1.0:
            # frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
            frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
            width = width // self.downscale
            height = height // self.downscale

        # Find the keypoints
        mask = np.zeros_like(frame)
        # mask[int(0.05 * height): int(0.95 * height), int(0.05 * width): int(0.95 * width)] = 255
        mask[int(0.02 * height):int(0.98 * height), int(0.02 * width):int(0.98 * width)] = 255
        if detections is not None:
            for det in detections:
                tlbr = (det[:4] / self.downscale).astype(np.int_)
                mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0

        keypoints = self.detector.detect(frame, mask)

        # Compute the descriptors
        keypoints, descriptors = self.extractor.compute(frame, keypoints)

        # Handle first frame
        if not self.initializedFirstFrame:
            # Initialize data
            self.prevFrame = frame.copy()
            self.prevKeyPoints = copy.copy(keypoints)
            self.prevDescriptors = copy.copy(descriptors)

            # Initialization done
            self.initializedFirstFrame = True

            return H

        # Match descriptors
        knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)

        # Filter matches based on smallest spatial distance
        matches = []
        spatialDistances = []

        maxSpatialDistance = 0.25 * np.array([width, height])

        # Handle empty matches case
        if len(knnMatches) == 0:
            # Store to next iteration
            self.prevFrame = frame.copy()
            self.prevKeyPoints = copy.copy(keypoints)
            self.prevDescriptors = copy.copy(descriptors)

            return H

        for m, n in knnMatches:
            if m.distance < 0.9 * n.distance:
                prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
                currKeyPointLocation = keypoints[m.trainIdx].pt

                spatialDistance = (prevKeyPointLocation[0] - currKeyPointLocation[0],
                                   prevKeyPointLocation[1] - currKeyPointLocation[1])

                if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \
                        (np.abs(spatialDistance[1]) < maxSpatialDistance[1]):
                    spatialDistances.append(spatialDistance)
                    matches.append(m)

        meanSpatialDistances = np.mean(spatialDistances, 0)
        stdSpatialDistances = np.std(spatialDistances, 0)

        inliers = (spatialDistances - meanSpatialDistances) < 2.5 * stdSpatialDistances

        goodMatches = []
        prevPoints = []
        currPoints = []
        for i in range(len(matches)):
            if inliers[i, 0] and inliers[i, 1]:
                goodMatches.append(matches[i])
                prevPoints.append(self.prevKeyPoints[matches[i].queryIdx].pt)
                currPoints.append(keypoints[matches[i].trainIdx].pt)

        prevPoints = np.array(prevPoints)
        currPoints = np.array(currPoints)

        # Draw the keypoint matches on the output image
        # if False:
        #     import matplotlib.pyplot as plt
        #     matches_img = np.hstack((self.prevFrame, frame))
        #     matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR)
        #     W = np.size(self.prevFrame, 1)
        #     for m in goodMatches:
        #         prev_pt = np.array(self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_)
        #         curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_)
        #         curr_pt[0] += W
        #         color = np.random.randint(0, 255, 3)
        #         color = (int(color[0]), int(color[1]), int(color[2]))
        #
        #         matches_img = cv2.line(matches_img, prev_pt, curr_pt, tuple(color), 1, cv2.LINE_AA)
        #         matches_img = cv2.circle(matches_img, prev_pt, 2, tuple(color), -1)
        #         matches_img = cv2.circle(matches_img, curr_pt, 2, tuple(color), -1)
        #
        #     plt.figure()
        #     plt.imshow(matches_img)
        #     plt.show()

        # Find rigid matrix
        if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(currPoints, 0)):
            H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)

            # Handle downscale
            if self.downscale > 1.0:
                H[0, 2] *= self.downscale
                H[1, 2] *= self.downscale
        else:
            LOGGER.warning('WARNING: not enough matching points')

        # Store to next iteration
        self.prevFrame = frame.copy()
        self.prevKeyPoints = copy.copy(keypoints)
        self.prevDescriptors = copy.copy(descriptors)

        return H

    def applySparseOptFlow(self, raw_frame, detections=None):
        """Estimate the warp matrix for a raw frame using sparse optical flow."""
        # t0 = time.time()
        height, width, _ = raw_frame.shape
        frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
        H = np.eye(2, 3)

        # Downscale image
        if self.downscale > 1.0:
            # frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
            frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))

        # Find the keypoints
        keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params)

        # Handle first frame
        if not self.initializedFirstFrame:
            # Initialize data
            self.prevFrame = frame.copy()
            self.prevKeyPoints = copy.copy(keypoints)

            # Initialization done
            self.initializedFirstFrame = True

            return H

        # Find correspondences
        matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)

        # Leave good correspondences only
        prevPoints = []
        currPoints = []

        for i in range(len(status)):
            if status[i]:
                prevPoints.append(self.prevKeyPoints[i])
                currPoints.append(matchedKeypoints[i])

        prevPoints = np.array(prevPoints)
        currPoints = np.array(currPoints)

        # Find rigid matrix
        if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(currPoints, 0)):
            H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)

            # Handle downscale
            if self.downscale > 1.0:
                H[0, 2] *= self.downscale
                H[1, 2] *= self.downscale
        else:
            LOGGER.warning('WARNING: not enough matching points')

        # Store to next iteration
        self.prevFrame = frame.copy()
        self.prevKeyPoints = copy.copy(keypoints)

        # gmc_line = str(1000 * (time.time() - t0)) + "\t" + str(H[0, 0]) + "\t" + str(H[0, 1]) + "\t" + str(
        #     H[0, 2]) + "\t" + str(H[1, 0]) + "\t" + str(H[1, 1]) + "\t" + str(H[1, 2]) + "\n"
        # self.gmc_file.write(gmc_line)

        return H

    def applyFile(self, raw_frame, detections=None):
        """Return the homography matrix based on the GCPs in the next line of the input GMC file."""
        line = self.gmcFile.readline()
        tokens = line.split('\t')
        H = np.eye(2, 3, dtype=np.float_)
        H[0, 0] = float(tokens[1])
        H[0, 1] = float(tokens[2])
        H[0, 2] = float(tokens[3])
        H[1, 0] = float(tokens[4])
        H[1, 1] = float(tokens[5])
        H[1, 2] = float(tokens[6])

        return H
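Below is a minimal usage sketch (not part of the commit) showing how the GMC class above can be driven frame-by-frame with the default sparse optical flow method; the video path is a hypothetical placeholder, and detections are only needed by the feature-based ('orb'/'sift') methods to mask out object regions.

import cv2

from ultralytics.trackers.utils.gmc import GMC

gmc = GMC(method='sparseOptFlow', downscale=2)
cap = cv2.VideoCapture('video.mp4')  # hypothetical input video
while True:
    ok, frame = cap.read()
    if not ok:
        break
    # 2x3 affine matrix compensating camera motion since the previous frame
    warp = gmc.apply(frame)
    print(warp)
cap.release()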
462	ultralytics/trackers/utils/kalman_filter.py	Normal file
@@ -0,0 +1,462 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import numpy as np
import scipy.linalg

# Table for the 0.95 quantile of the chi-square distribution with N degrees of freedom (contains values for N=1, ..., 9)
# Taken from MATLAB/Octave's chi2inv function and used as Mahalanobis gating threshold.
chi2inv95 = {1: 3.8415, 2: 5.9915, 3: 7.8147, 4: 9.4877, 5: 11.070, 6: 12.592, 7: 14.067, 8: 15.507, 9: 16.919}


class KalmanFilterXYAH:
    """
    A simple Kalman filter for tracking bounding boxes in image space (used by ByteTrack).

    The 8-dimensional state space

        x, y, a, h, vx, vy, va, vh

    contains the bounding box center position (x, y), aspect ratio a, height h,
    and their respective velocities.

    Object motion follows a constant velocity model. The bounding box location
    (x, y, a, h) is taken as direct observation of the state space (linear
    observation model).
    """

    def __init__(self):
        """Initialize Kalman filter model matrices with motion and observation uncertainty weights."""
        ndim, dt = 4, 1.

        # Create Kalman filter model matrices.
        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
        for i in range(ndim):
            self._motion_mat[i, ndim + i] = dt
        self._update_mat = np.eye(ndim, 2 * ndim)

        # Motion and observation uncertainty are chosen relative to the current
        # state estimate. These weights control the amount of uncertainty in
        # the model. This is a bit hacky.
        self._std_weight_position = 1. / 20
        self._std_weight_velocity = 1. / 160

    def initiate(self, measurement):
        """Create track from unassociated measurement.

        Parameters
        ----------
        measurement : ndarray
            Bounding box coordinates (x, y, a, h) with center position (x, y),
            aspect ratio a, and height h.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector (8 dimensional) and covariance matrix (8x8
            dimensional) of the new track. Unobserved velocities are initialized
            to 0 mean.
        """
        mean_pos = measurement
        mean_vel = np.zeros_like(mean_pos)
        mean = np.r_[mean_pos, mean_vel]

        std = [
            2 * self._std_weight_position * measurement[3], 2 * self._std_weight_position * measurement[3], 1e-2,
            2 * self._std_weight_position * measurement[3], 10 * self._std_weight_velocity * measurement[3],
            10 * self._std_weight_velocity * measurement[3], 1e-5, 10 * self._std_weight_velocity * measurement[3]]
        covariance = np.diag(np.square(std))
        return mean, covariance

    def predict(self, mean, covariance):
        """Run Kalman filter prediction step.

        Parameters
        ----------
        mean : ndarray
            The 8 dimensional mean vector of the object state at the previous
            time step.
        covariance : ndarray
            The 8x8 dimensional covariance matrix of the object state at the
            previous time step.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector and covariance matrix of the predicted
            state. Unobserved velocities are initialized to 0 mean.
        """
        std_pos = [
            self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-2,
            self._std_weight_position * mean[3]]
        std_vel = [
            self._std_weight_velocity * mean[3], self._std_weight_velocity * mean[3], 1e-5,
            self._std_weight_velocity * mean[3]]
        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))

        # mean = np.dot(self._motion_mat, mean)
        mean = np.dot(mean, self._motion_mat.T)
        covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov

        return mean, covariance

    def project(self, mean, covariance):
        """Project state distribution to measurement space.

        Parameters
        ----------
        mean : ndarray
            The state's mean vector (8 dimensional array).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).

        Returns
        -------
        (ndarray, ndarray)
            Returns the projected mean and covariance matrix of the given state
            estimate.
        """
        std = [
            self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-1,
            self._std_weight_position * mean[3]]
        innovation_cov = np.diag(np.square(std))

        mean = np.dot(self._update_mat, mean)
        covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
        return mean, covariance + innovation_cov

    def multi_predict(self, mean, covariance):
        """Run Kalman filter prediction step (vectorized version).

        Parameters
        ----------
        mean : ndarray
            The Nx8 dimensional mean matrix of the object states at the previous
            time step.
        covariance : ndarray
            The Nx8x8 dimensional covariance matrix of the object states at the
            previous time step.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector and covariance matrix of the predicted
            state. Unobserved velocities are initialized to 0 mean.
        """
        std_pos = [
            self._std_weight_position * mean[:, 3], self._std_weight_position * mean[:, 3],
            1e-2 * np.ones_like(mean[:, 3]), self._std_weight_position * mean[:, 3]]
        std_vel = [
            self._std_weight_velocity * mean[:, 3], self._std_weight_velocity * mean[:, 3],
            1e-5 * np.ones_like(mean[:, 3]), self._std_weight_velocity * mean[:, 3]]
        sqr = np.square(np.r_[std_pos, std_vel]).T

        motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
        motion_cov = np.asarray(motion_cov)

        mean = np.dot(mean, self._motion_mat.T)
        left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
        covariance = np.dot(left, self._motion_mat.T) + motion_cov

        return mean, covariance

    def update(self, mean, covariance, measurement):
        """Run Kalman filter correction step.

        Parameters
        ----------
        mean : ndarray
            The predicted state's mean vector (8 dimensional).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).
        measurement : ndarray
            The 4 dimensional measurement vector (x, y, a, h), where (x, y)
            is the center position, a the aspect ratio, and h the height of the
            bounding box.

        Returns
        -------
        (ndarray, ndarray)
            Returns the measurement-corrected state distribution.
        """
        projected_mean, projected_cov = self.project(mean, covariance)

        chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
        kalman_gain = scipy.linalg.cho_solve((chol_factor, lower),
                                             np.dot(covariance, self._update_mat.T).T,
                                             check_finite=False).T
        innovation = measurement - projected_mean

        new_mean = mean + np.dot(innovation, kalman_gain.T)
        new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
        return new_mean, new_covariance

    def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
        """Compute gating distance between state distribution and measurements.

        A suitable distance threshold can be obtained from `chi2inv95`. If
        `only_position` is False, the chi-square distribution has 4 degrees of
        freedom, otherwise 2.

        Parameters
        ----------
        mean : ndarray
            Mean vector over the state distribution (8 dimensional).
        covariance : ndarray
            Covariance of the state distribution (8x8 dimensional).
        measurements : ndarray
            An Nx4 dimensional matrix of N measurements, each in
            format (x, y, a, h) where (x, y) is the bounding box center
            position, a the aspect ratio, and h the height.
        only_position : Optional[bool]
            If True, distance computation is done with respect to the bounding
            box center position only.

        Returns
        -------
        ndarray
            Returns an array of length N, where the i-th element contains the
            squared Mahalanobis distance between (mean, covariance) and
            `measurements[i]`.
        """
        mean, covariance = self.project(mean, covariance)
        if only_position:
            mean, covariance = mean[:2], covariance[:2, :2]
            measurements = measurements[:, :2]

        d = measurements - mean
        if metric == 'gaussian':
            return np.sum(d * d, axis=1)
        elif metric == 'maha':
            cholesky_factor = np.linalg.cholesky(covariance)
            z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
            return np.sum(z * z, axis=0)  # square maha
        else:
            raise ValueError('invalid distance metric')


class KalmanFilterXYWH:
    """
    A simple Kalman filter for tracking bounding boxes in image space (used by BoT-SORT).

    The 8-dimensional state space

        x, y, w, h, vx, vy, vw, vh

    contains the bounding box center position (x, y), width w, height h,
    and their respective velocities.

    Object motion follows a constant velocity model. The bounding box location
    (x, y, w, h) is taken as direct observation of the state space (linear
    observation model).
    """

    def __init__(self):
        """Initialize Kalman filter model matrices with motion and observation uncertainties."""
        ndim, dt = 4, 1.

        # Create Kalman filter model matrices.
        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
        for i in range(ndim):
            self._motion_mat[i, ndim + i] = dt
        self._update_mat = np.eye(ndim, 2 * ndim)

        # Motion and observation uncertainty are chosen relative to the current
        # state estimate. These weights control the amount of uncertainty in
        # the model. This is a bit hacky.
        self._std_weight_position = 1. / 20
        self._std_weight_velocity = 1. / 160

    def initiate(self, measurement):
        """Create track from unassociated measurement.

        Parameters
        ----------
        measurement : ndarray
            Bounding box coordinates (x, y, w, h) with center position (x, y),
            width w, and height h.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector (8 dimensional) and covariance matrix (8x8
            dimensional) of the new track. Unobserved velocities are initialized
            to 0 mean.
        """
        mean_pos = measurement
        mean_vel = np.zeros_like(mean_pos)
        mean = np.r_[mean_pos, mean_vel]

        std = [
            2 * self._std_weight_position * measurement[2], 2 * self._std_weight_position * measurement[3],
            2 * self._std_weight_position * measurement[2], 2 * self._std_weight_position * measurement[3],
            10 * self._std_weight_velocity * measurement[2], 10 * self._std_weight_velocity * measurement[3],
            10 * self._std_weight_velocity * measurement[2], 10 * self._std_weight_velocity * measurement[3]]
        covariance = np.diag(np.square(std))
        return mean, covariance

    def predict(self, mean, covariance):
        """Run Kalman filter prediction step.

        Parameters
        ----------
        mean : ndarray
            The 8 dimensional mean vector of the object state at the previous
            time step.
        covariance : ndarray
            The 8x8 dimensional covariance matrix of the object state at the
            previous time step.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector and covariance matrix of the predicted
            state. Unobserved velocities are initialized to 0 mean.
        """
        std_pos = [
            self._std_weight_position * mean[2], self._std_weight_position * mean[3],
            self._std_weight_position * mean[2], self._std_weight_position * mean[3]]
        std_vel = [
            self._std_weight_velocity * mean[2], self._std_weight_velocity * mean[3],
            self._std_weight_velocity * mean[2], self._std_weight_velocity * mean[3]]
        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))

        mean = np.dot(mean, self._motion_mat.T)
        covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov

        return mean, covariance

    def project(self, mean, covariance):
        """Project state distribution to measurement space.

        Parameters
        ----------
        mean : ndarray
            The state's mean vector (8 dimensional array).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).

        Returns
        -------
        (ndarray, ndarray)
            Returns the projected mean and covariance matrix of the given state
            estimate.
        """
        std = [
            self._std_weight_position * mean[2], self._std_weight_position * mean[3],
            self._std_weight_position * mean[2], self._std_weight_position * mean[3]]
        innovation_cov = np.diag(np.square(std))

        mean = np.dot(self._update_mat, mean)
        covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
        return mean, covariance + innovation_cov

    def multi_predict(self, mean, covariance):
        """Run Kalman filter prediction step (vectorized version).

        Parameters
        ----------
        mean : ndarray
            The Nx8 dimensional mean matrix of the object states at the previous
            time step.
        covariance : ndarray
            The Nx8x8 dimensional covariance matrix of the object states at the
            previous time step.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector and covariance matrix of the predicted
            state. Unobserved velocities are initialized to 0 mean.
        """
        std_pos = [
            self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3],
            self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3]]
        std_vel = [
            self._std_weight_velocity * mean[:, 2], self._std_weight_velocity * mean[:, 3],
            self._std_weight_velocity * mean[:, 2], self._std_weight_velocity * mean[:, 3]]
        sqr = np.square(np.r_[std_pos, std_vel]).T

        motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
        motion_cov = np.asarray(motion_cov)

        mean = np.dot(mean, self._motion_mat.T)
        left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
        covariance = np.dot(left, self._motion_mat.T) + motion_cov

        return mean, covariance

    def update(self, mean, covariance, measurement):
        """Run Kalman filter correction step.

        Parameters
        ----------
        mean : ndarray
            The predicted state's mean vector (8 dimensional).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).
        measurement : ndarray
            The 4 dimensional measurement vector (x, y, w, h), where (x, y)
            is the center position, w the width, and h the height of the
            bounding box.

        Returns
        -------
        (ndarray, ndarray)
            Returns the measurement-corrected state distribution.
        """
        projected_mean, projected_cov = self.project(mean, covariance)

        chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
        kalman_gain = scipy.linalg.cho_solve((chol_factor, lower),
                                             np.dot(covariance, self._update_mat.T).T,
                                             check_finite=False).T
        innovation = measurement - projected_mean

        new_mean = mean + np.dot(innovation, kalman_gain.T)
        new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
        return new_mean, new_covariance

    def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
        """Compute gating distance between state distribution and measurements.

        A suitable distance threshold can be obtained from `chi2inv95`. If
        `only_position` is False, the chi-square distribution has 4 degrees of
        freedom, otherwise 2.

        Parameters
        ----------
        mean : ndarray
            Mean vector over the state distribution (8 dimensional).
        covariance : ndarray
            Covariance of the state distribution (8x8 dimensional).
        measurements : ndarray
            An Nx4 dimensional matrix of N measurements, each in
            format (x, y, w, h) where (x, y) is the bounding box center
            position, w the width, and h the height.
        only_position : Optional[bool]
            If True, distance computation is done with respect to the bounding
            box center position only.

        Returns
        -------
        ndarray
            Returns an array of length N, where the i-th element contains the
            squared Mahalanobis distance between (mean, covariance) and
            `measurements[i]`.
        """
        mean, covariance = self.project(mean, covariance)
        if only_position:
            mean, covariance = mean[:2], covariance[:2, :2]
            measurements = measurements[:, :2]

        d = measurements - mean
        if metric == 'gaussian':
            return np.sum(d * d, axis=1)
        elif metric == 'maha':
            cholesky_factor = np.linalg.cholesky(covariance)
            z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
            return np.sum(z * z, axis=0)  # square maha
        else:
            raise ValueError('invalid distance metric')
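A short, hedged example (not part of the commit) of the predict/update/gating cycle with KalmanFilterXYAH; the measurements are made-up numbers in (x, y, a, h) format.

import numpy as np

from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYAH, chi2inv95

kf = KalmanFilterXYAH()
mean, cov = kf.initiate(np.array([50., 40., 0.5, 80.]))           # start a track from one box
mean, cov = kf.predict(mean, cov)                                 # constant-velocity prediction
mean, cov = kf.update(mean, cov, np.array([52., 41., 0.5, 82.]))  # correct with a new measurement

# Gate candidate measurements with the 0.95 chi-square threshold (4 degrees of freedom)
candidates = np.array([[52., 41., 0.5, 82.], [300., 300., 1.0, 40.]])
d2 = kf.gating_distance(mean, cov, candidates)
print(d2 < chi2inv95[4])  # nearby box passes the gate, distant box does not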
229	ultralytics/trackers/utils/matching.py	Normal file
@@ -0,0 +1,229 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import numpy as np
import scipy
from scipy.spatial.distance import cdist

from .kalman_filter import chi2inv95

try:
    import lap  # for linear_assignment

    assert lap.__version__  # verify package is not directory
except (ImportError, AssertionError, AttributeError):
    from ultralytics.utils.checks import check_requirements

    check_requirements('lapx>=0.5.2')  # update to lap package from https://github.com/rathaROG/lapx
    import lap


def merge_matches(m1, m2, shape):
    """Merge two sets of matches and return matched and unmatched indices."""
    O, P, Q = shape
    m1 = np.asarray(m1)
    m2 = np.asarray(m2)

    M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P))
    M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q))

    mask = M1 * M2
    match = mask.nonzero()
    match = list(zip(match[0], match[1]))
    unmatched_O = tuple(set(range(O)) - {i for i, j in match})
    unmatched_Q = tuple(set(range(Q)) - {j for i, j in match})

    return match, unmatched_O, unmatched_Q


def _indices_to_matches(cost_matrix, indices, thresh):
    """Return matched and unmatched indices given a cost matrix, indices, and a threshold."""
    matched_cost = cost_matrix[tuple(zip(*indices))]
    matched_mask = (matched_cost <= thresh)

    matches = indices[matched_mask]
    unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
    unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))

    return matches, unmatched_a, unmatched_b


def linear_assignment(cost_matrix, thresh, use_lap=True):
    """Linear assignment implementations with scipy and lap.lapjv."""
    if cost_matrix.size == 0:
        return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))

    if use_lap:
        _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
        matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
        unmatched_a = np.where(x < 0)[0]
        unmatched_b = np.where(y < 0)[0]
    else:
        # Scipy linear sum assignment is NOT working correctly, DO NOT USE
        y, x = scipy.optimize.linear_sum_assignment(cost_matrix)  # row y, col x
        matches = np.asarray([[i, x] for i, x in enumerate(x) if cost_matrix[i, x] <= thresh])
        unmatched = np.ones(cost_matrix.shape)
        for i, xi in matches:
            unmatched[i, xi] = 0.0
        unmatched_a = np.where(unmatched.all(1))[0]
        unmatched_b = np.where(unmatched.all(0))[0]

    return matches, unmatched_a, unmatched_b


def ious(atlbrs, btlbrs):
    """
    Compute cost based on IoU
    :type atlbrs: list[tlbr] | np.ndarray
    :type btlbrs: list[tlbr] | np.ndarray

    :rtype ious np.ndarray
    """
    ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
    if ious.size == 0:
        return ious

    ious = bbox_ious(np.ascontiguousarray(atlbrs, dtype=np.float32), np.ascontiguousarray(btlbrs, dtype=np.float32))
    return ious


def iou_distance(atracks, btracks):
    """
    Compute cost based on IoU
    :type atracks: list[STrack]
    :type btracks: list[STrack]

    :rtype cost_matrix np.ndarray
    """

    if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) \
            or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
        atlbrs = atracks
        btlbrs = btracks
    else:
        atlbrs = [track.tlbr for track in atracks]
        btlbrs = [track.tlbr for track in btracks]
    _ious = ious(atlbrs, btlbrs)
    return 1 - _ious  # cost matrix


def v_iou_distance(atracks, btracks):
    """
    Compute cost based on IoU
    :type atracks: list[STrack]
    :type btracks: list[STrack]

    :rtype cost_matrix np.ndarray
    """

    if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) \
            or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
        atlbrs = atracks
        btlbrs = btracks
    else:
        atlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in atracks]
        btlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in btracks]
    _ious = ious(atlbrs, btlbrs)
    return 1 - _ious  # cost matrix


def embedding_distance(tracks, detections, metric='cosine'):
    """
    :param tracks: list[STrack]
    :param detections: list[BaseTrack]
    :param metric: str, distance metric passed to scipy.spatial.distance.cdist
    :return: cost_matrix np.ndarray
    """

    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
    if cost_matrix.size == 0:
        return cost_matrix
    det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32)
    # for i, track in enumerate(tracks):
    #     cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric))
    track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32)
    cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric))  # Normalized features
    return cost_matrix


def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False):
    """Apply gating to the cost matrix based on predicted tracks and detected objects."""
    if cost_matrix.size == 0:
        return cost_matrix
    gating_dim = 2 if only_position else 4
    gating_threshold = chi2inv95[gating_dim]
    measurements = np.asarray([det.to_xyah() for det in detections])
    for row, track in enumerate(tracks):
        gating_distance = kf.gating_distance(track.mean, track.covariance, measurements, only_position)
        cost_matrix[row, gating_distance > gating_threshold] = np.inf
    return cost_matrix


def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
    """Fuse motion between tracks and detections with gating and Kalman filtering."""
    if cost_matrix.size == 0:
        return cost_matrix
    gating_dim = 2 if only_position else 4
    gating_threshold = chi2inv95[gating_dim]
    measurements = np.asarray([det.to_xyah() for det in detections])
    for row, track in enumerate(tracks):
        gating_distance = kf.gating_distance(track.mean, track.covariance, measurements, only_position, metric='maha')
        cost_matrix[row, gating_distance > gating_threshold] = np.inf
        cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance
    return cost_matrix


def fuse_iou(cost_matrix, tracks, detections):
    """Fuses ReID and IoU similarity matrices to yield a cost matrix for object tracking."""
    if cost_matrix.size == 0:
        return cost_matrix
    reid_sim = 1 - cost_matrix
    iou_dist = iou_distance(tracks, detections)
    iou_sim = 1 - iou_dist
    fuse_sim = reid_sim * (1 + iou_sim) / 2
    # det_scores = np.array([det.score for det in detections])
    # det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
    return 1 - fuse_sim  # fuse cost


def fuse_score(cost_matrix, detections):
    """Fuses cost matrix with detection scores to produce a single similarity matrix."""
    if cost_matrix.size == 0:
        return cost_matrix
    iou_sim = 1 - cost_matrix
    det_scores = np.array([det.score for det in detections])
    det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
    fuse_sim = iou_sim * det_scores
    return 1 - fuse_sim  # fuse_cost


def bbox_ious(box1, box2, eps=1e-7):
    """
    Calculate the Intersection over Union (IoU) between pairs of bounding boxes.

    Args:
        box1 (np.array): A numpy array of shape (n, 4) representing 'n' bounding boxes.
                         Each row is in the format (x1, y1, x2, y2).
        box2 (np.array): A numpy array of shape (m, 4) representing 'm' bounding boxes.
                         Each row is in the format (x1, y1, x2, y2).
        eps (float, optional): A small constant to prevent division by zero. Defaults to 1e-7.

    Returns:
        (np.array): A numpy array of shape (n, m) representing the IoU scores for each pair
                    of bounding boxes from box1 and box2.

    Note:
        The bounding box coordinates are expected to be in the format (x1, y1, x2, y2).
    """

    # Get the coordinates of bounding boxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
    b2_x1, b2_y1, b2_x2, b2_y2 = box2.T

    # Intersection area
    inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * \
                 (np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)).clip(0)

    # Box areas
    box1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
    box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
    return inter_area / (box2_area + box1_area[:, None] - inter_area + eps)
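A small, hedged example (not part of the commit) that builds an IoU cost matrix from made-up boxes in (x1, y1, x2, y2) format and solves the assignment with linear_assignment.

import numpy as np

from ultralytics.trackers.utils import matching

tracks = np.array([[10., 10., 50., 50.], [100., 100., 150., 160.]], dtype=np.float32)
dets = np.array([[12., 11., 49., 52.], [400., 400., 440., 460.]], dtype=np.float32)

cost = matching.iou_distance(tracks, dets)  # cost = 1 - IoU, shape (2, 2)
matches, u_tracks, u_dets = matching.linear_assignment(cost, thresh=0.8)
print(matches)           # (track_index, detection_index) pairs kept under the threshold
print(u_tracks, u_dets)  # indices of unmatched tracks and detections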