feat(trackers): Implementation of Kalman Gating for DeepSORTTracker #17

Open · wants to merge 17 commits into main · Changes from 4 commits

58 changes: 57 additions & 1 deletion trackers/core/deepsort/kalman_box_tracker.py
@@ -1,6 +1,10 @@
from typing import Optional, Union
from typing import Optional, Tuple, Union

import numpy as np
from scipy.linalg import solve_triangular

# Chi-square 0.95 quantile for 4 degrees of freedom (Mahalanobis threshold)
MAHALANOBIS_THRESHOLD = 9.4877
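# For reference, this value can be reproduced with scipy:
#     from scipy.stats import chi2
#     chi2.ppf(0.95, df=4)  # ~9.487729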


class DeepSORTKalmanBoxTracker:
@@ -96,6 +100,58 @@ def _initialize_kalman_filter(self) -> None:
# Error covariance matrix (P)
self.P = np.eye(8, dtype=np.float32)

def project(self) -> Tuple[np.ndarray, np.ndarray]:
"""
Projects the current state distribution to measurement space.

        Following the Kalman filter formulation implicit in Section 2.1 of the
        DeepSORT paper, this function computes:
(y_i, S_i) = (H·μ_i, H·Σ_i·H^T + R)

Returns:
Tuple[np.ndarray, np.ndarray]: Projected mean (y_i) and innovation
covariance (S_i) for gating and association.
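
        Note:
            Both outputs live in the filter's 4-dimensional measurement space:
            the projected mean has four elements and the innovation covariance
            is a 4x4 matrix.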
"""
# Project state mean to measurement space: y_i = H·μ_i
projected_mean = self.H @ self.state

# Project state covariance to measurement space: H·Σ_i·H^T
projected_covariance = self.H @ self.P @ self.H.T

# Add measurement noise: S_i = H·Σ_i·H^T + R
innovation_covariance = projected_covariance + self.R

return projected_mean, innovation_covariance

def compute_gating_distance(self, measurements: np.ndarray) -> np.ndarray:
"""
Computes the squared Mahalanobis distance between the track and
measurements.

        This function is used for gating (ruling out) unlikely associations,
        as described in Eq. (1)-(2) of the DeepSORT paper:
            d^(1)(i,j) = (d_j - y_i)^T · S_i^(-1) · (d_j - y_i)
            b^(1)(i,j) = 1[d^(1)(i,j) <= t^(1)]
        where the threshold t^(1) (`MAHALANOBIS_THRESHOLD`) is applied by the
        caller.

Args:
            measurements (np.ndarray): An Nx4 matrix of N measurements, each in
                the measurement-space format `(center x, center y, aspect ratio,
                height)` produced by `convert_bbox_to_xyah`.

Returns:
np.ndarray: An array of length N, where the i-th element contains the
squared Mahalanobis distance between the track and measurements[i].
"""
# Project current state to measurement space
mean, covariance = self.project()
mean = mean.reshape(1, 4)
cholesky_factor = np.linalg.cholesky(covariance)
d = measurements - mean
# Solve the system L·z = d^T efficiently using triangular solver
# This gives us z where z = L^(-1)·d^T
z = solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False)
# Compute squared Mahalanobis distance as the squared norm of z
# d_m^2 = z^T·z = d^T·S^(-1)·d
return np.sum(z * z, axis=0)

def predict(self) -> None:
"""
Predict the next state of the bounding box (applies the state transition).
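
A quick way to sanity-check the Cholesky-based distance above (a standalone sketch, not part of this diff; the covariance and residual values are made up):

import numpy as np
from scipy.linalg import solve_triangular

S = np.diag([4.0, 4.0, 0.01, 9.0])      # toy innovation covariance S_i
d = np.array([[1.0, -2.0, 0.05, 3.0]])  # toy residual d_j - y_i
L = np.linalg.cholesky(S)
z = solve_triangular(L, d.T, lower=True)
# The squared norm of z matches the directly computed Mahalanobis distance.
assert np.isclose(np.sum(z * z), (d @ np.linalg.inv(S) @ d.T).item())
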
187 changes: 158 additions & 29 deletions trackers/core/deepsort/tracker.py
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import List, Optional, Tuple, Union

import numpy as np
import supervision as sv
@@ -7,8 +7,12 @@

from trackers.core.base import BaseTrackerWithFeatures
from trackers.core.deepsort.feature_extractor import DeepSORTFeatureExtractor
from trackers.core.deepsort.kalman_box_tracker import DeepSORTKalmanBoxTracker
from trackers.core.deepsort.kalman_box_tracker import (
MAHALANOBIS_THRESHOLD,
DeepSORTKalmanBoxTracker,
)
from trackers.utils.sort_utils import (
convert_bbox_to_xyah,
get_alive_trackers,
get_iou_matrix,
update_detections_with_track_ids,
@@ -152,6 +156,9 @@ def callback(frame: np.ndarray, _: int):
'kulczynski1', 'mahalanobis', 'matching', 'minkowski',
'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener',
'sokalsneath', 'sqeuclidean', 'yule'.
apply_kalman_gating (bool): Whether to apply Kalman gating to filter unlikely
track-detection associations based on Mahalanobis distance, as described
in Section 2.2 of the DeepSORT paper.
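            Set `apply_kalman_gating=False` to skip the gate; the two-stage
            matching cascade still runs, but no associations are ruled out based
            on motion.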
""" # noqa: E501

def __init__(
@@ -166,6 +173,7 @@ def __init__(
appearance_threshold: float = 0.7,
appearance_weight: float = 0.5,
distance_metric: str = "cosine",
apply_kalman_gating: bool = True,
):
self.feature_extractor = self._initialize_feature_extractor(
feature_extractor, device
@@ -179,6 +187,7 @@ def __init__(
self.appearance_threshold = appearance_threshold
self.appearance_weight = appearance_weight
self.distance_metric = distance_metric
self.apply_kalman_gating = apply_kalman_gating
# Calculate maximum frames without update based on lost_track_buffer and
# frame_rate. This scales the buffer based on the frame rate to ensure
# consistent time-based tracking across different frame rates.
@@ -198,7 +207,7 @@ def _initialize_feature_extractor(

Args:
feature_extractor: The feature extractor input, which can be a model path,
a torch module, or a DeepSORTFeatureExtractor instance.
a torch module, or a DeepSORTFeatureExtractor instance.
device: The device to run the model on.

Returns:
@@ -266,13 +275,78 @@ def _get_combined_distance_matrix(

return combined_dist

    def _match_tracks_stage(
        self,
        cost_matrix: np.ndarray,
        track_indices: List[int],
        detection_indices: List[int],
    ) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
        """
        Match tracks with detections for a single stage of the matching cascade.
        Candidate pairs with cost below 1.0 are matched greedily in order of
        increasing cost, restricted to the given group of tracks (tracks are
        grouped by their maturity).

        Args:
            cost_matrix (np.ndarray): Cost matrix between tracks and detections.
            track_indices (List[int]): Indices of tracks to match.
            detection_indices (List[int]): Indices of detections to match.

        Returns:
            Tuple[List[Tuple[int, int]], List[int], List[int]]: Matched
            (track, detection) index pairs, unmatched track indices, and
            unmatched detection indices.
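
        Example:
            With cost matrix [[0.2, 0.9], [0.4, 0.3]], track_indices [0, 1] and
            detection_indices [0, 1], the greedy pass first matches (0, 0)
            (lowest cost 0.2), then (1, 1), leaving no unmatched tracks or
            detections.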
"""
if len(track_indices) == 0 or len(detection_indices) == 0:
return [], track_indices, detection_indices

sub_cost_matrix = cost_matrix[np.ix_(track_indices, detection_indices)]

# Apply threshold of 1.0 to mark infeasible associations
# Only consider associations where cost < 1.0
valid_mask = sub_cost_matrix < 1.0

if not np.any(valid_mask):
return [], track_indices, detection_indices

row_indices, col_indices = np.where(valid_mask)

indices = np.stack([row_indices, col_indices], axis=1)
indices = indices[np.argsort(sub_cost_matrix[row_indices, col_indices])]

matches = []
unmatched_tracks = list(track_indices)
unmatched_detections = list(detection_indices)

matched_track_indices = set()
matched_detection_indices = set()

for row, col in indices:
track_idx = track_indices[row]
detection_idx = detection_indices[col]

# Skip if either track or detection is already matched
if row in matched_track_indices or col in matched_detection_indices:
continue

matches.append((track_idx, detection_idx))
matched_track_indices.add(row)
matched_detection_indices.add(col)

if track_idx in unmatched_tracks:
unmatched_tracks.remove(track_idx)
if detection_idx in unmatched_detections:
unmatched_detections.remove(detection_idx)

return matches, unmatched_tracks, unmatched_detections

    def _get_associated_indices(
        self,
        iou_matrix: np.ndarray,
        detection_features: np.ndarray,
        detection_boxes: np.ndarray,
    ) -> tuple[list[tuple[int, int]], set[int], set[int]]:
"""
Associate detections to trackers based on both IOU and appearance.
Associate detections to trackers using a two-stage matching approach.
If `apply_kalman_gating` is enabled, an additional Mahalanobis distance filter
is applied to confirmed tracks to exclude unlikely associations.

        Args:
            iou_matrix (np.ndarray): IOU matrix between tracks and detections.
            detection_boxes (np.ndarray): Detection bounding boxes in format
                `(x1, y1, x2, y2)`, used to compute Mahalanobis gating distances.
@@ -282,37 +356,92 @@ def _get_associated_indices(
tuple[list[tuple[int, int]], set[int], set[int]]: Matched indices,
unmatched trackers, unmatched detections.
"""
confirmed_tracks = []
unconfirmed_tracks = []
for tracker_idx, tracker in enumerate(self.trackers):
if tracker.number_of_successful_updates >= self.minimum_consecutive_frames:
confirmed_tracks.append(tracker_idx)
else:
unconfirmed_tracks.append(tracker_idx)

appearance_dist_matrix = self._get_appearance_distance_matrix(
detection_features
)
combined_dist = self._get_combined_distance_matrix(
combined_dist_matrix = self._get_combined_distance_matrix(
iou_matrix, appearance_dist_matrix
)
matched_indices = []
unmatched_trackers = set(range(len(self.trackers)))
unmatched_detections = set(range(len(detection_features)))

if combined_dist.size > 0:
row_indices, col_indices = np.where(combined_dist < 1.0)
sorted_pairs = sorted(
zip(map(int, row_indices), map(int, col_indices)),
key=lambda x: combined_dist[x[0], x[1]],

if self.apply_kalman_gating:
for tracker_idx in confirmed_tracks:
feasible_detections = np.where(combined_dist_matrix[tracker_idx] < 1.0)[
0
]
                if len(feasible_detections) > 0:
                    # Convert the candidate detections into measurement space
                    # (center x, center y, aspect ratio, height) for gating.
                    measurements = np.array(
                        [
                            convert_bbox_to_xyah(detection_boxes[det_idx])
                            for det_idx in feasible_detections
                        ]
                    )

gating_distances = self.trackers[
tracker_idx
].compute_gating_distance(measurements=measurements)

for j, det_idx in enumerate(feasible_detections):
if gating_distances[j] > MAHALANOBIS_THRESHOLD:
combined_dist_matrix[tracker_idx, det_idx] = (
1.0 # Mark as infeasible
)

confirmed_matches, unmatched_confirmed, unmatched_detections = (
self._match_tracks_stage(
combined_dist_matrix,
confirmed_tracks,
list(range(len(detection_features))),
)
)

used_rows = set()
used_cols = set()
for row, col in sorted_pairs:
if (row not in used_rows) and (col not in used_cols):
used_rows.add(row)
used_cols.add(col)
matched_indices.append((row, col))
# Find recently lost confirmed tracks (time_since_update == 1)
recently_lost = [
tracker_idx
for tracker_idx in unmatched_confirmed
if self.trackers[tracker_idx].time_since_update == 1
]

# Remove recently lost from unmatched_confirmed
unmatched_confirmed = [
tracker_idx
for tracker_idx in unmatched_confirmed
if tracker_idx not in recently_lost
]

iou_track_candidates = unconfirmed_tracks + recently_lost

# Match remaining tracks using IoU only
iou_matches: list[tuple[int, int]] = []
if iou_track_candidates and unmatched_detections:
iou_dist_matrix = 1 - iou_matrix
iou_dist_matrix_filtered = iou_dist_matrix.copy()
mask = iou_matrix < self.minimum_iou_threshold
iou_dist_matrix_filtered[mask] = 1.0

iou_matches, unmatched_candidates, unmatched_detections = (
self._match_tracks_stage(
iou_dist_matrix_filtered,
iou_track_candidates,
list(unmatched_detections),
)
)
else:
unmatched_candidates = iou_track_candidates

unmatched_trackers = unmatched_trackers - {int(row) for row in used_rows}
unmatched_detections = unmatched_detections - {
int(col) for col in used_cols
}
matches = confirmed_matches + iou_matches
unmatched_tracks = set(unmatched_confirmed).union(set(unmatched_candidates))

return matched_indices, unmatched_trackers, unmatched_detections
return matches, unmatched_tracks, set(unmatched_detections)

def _spawn_new_trackers(
self,
@@ -385,9 +514,9 @@ def update(self, detections: sv.Detections, frame: np.ndarray) -> sv.Detections:
trackers=self.trackers, detection_boxes=detection_boxes
)

# Associate detections to trackers based on IOU
matched_indices, _, unmatched_detections = self._get_associated_indices(
iou_matrix, detection_features
# Associate detections to trackers using the two-stage matching approach
matched_indices, unmatched_tracks, unmatched_detections = (
            self._get_associated_indices(
                iou_matrix, detection_features, detection_boxes
            )
)

# Update matched trackers with assigned detections
23 changes: 23 additions & 0 deletions trackers/utils/sort_utils.py
@@ -143,3 +143,26 @@ def update_detections_with_track_ids(
updated_detections.tracker_id = np.array(final_tracker_ids)

return updated_detections


def convert_bbox_to_xyah(state_bbox: np.ndarray) -> np.ndarray:
"""
    Convert a bounding box into the measurement-space format
    `(center x, center y, aspect ratio, height)`,
    where the aspect ratio is `width / height`.

Args:
state_bbox (np.ndarray): Bounding box in format
`(x1, y1, x2, y2)`.

Returns:
np.ndarray: Bounding box in format
`(center x, center y, aspect ratio, height)`.
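
    Example:
        `convert_bbox_to_xyah(np.array([10, 20, 30, 60]))` returns
        `[20.0, 40.0, 0.5, 40.0]` (center `(20, 40)`, aspect ratio `0.5`,
        height `40`).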
"""
x1, y1, x2, y2 = state_bbox
width = x2 - x1
height = y2 - y1
center_x = x1 + width / 2
center_y = y1 + height / 2
aspect_ratio = width / height if height > 0 else 1.0
return np.array([center_x, center_y, aspect_ratio, height])