Commit a8f3128c authored by Anton Kullberg

py: implemented GNN and JPDA

parent 8149955b
import pdb

import numpy as np
import scipy.stats as stats
import scipy
import tqdm

from pyehm.core import EHM2

class BasicTracker():
    def __init__(self, filt, clutter_model, associator, gater):
@@ -63,6 +65,25 @@ class IMMTracker():
class GNN():
    def __init__(self, logic, logic_params, init_track, filt, gater, clutter_model):
"""An implementation of a Global Nearest Neighbour tracker.
Parameters
----------
logic : logic
See src.logic. Some sort of track logic.
logic_params : dict
Contains parameters to the track logic.
init_track : callable
A function that initiates a track. Should take a measurement, the
time, an id and the filter to use for the track as input.
filt : filter
See src.filter. Some sort of filter to use for the tracks.
gater : gater
See src.gater. A gating function.
clutter_model : dict
A dict containing the clutter model.
"""
        self.logic = logic
        self.logic_params = logic_params
        self.init_track = init_track
@@ -71,6 +92,19 @@ class GNN():
        self.clutter_model = clutter_model

    def _update_track(self, meas, track):
"""Handles the update of a certain track with the given measurement(s).
Modifies the track in-place!
Parameters
----------
meas : numpy.ndarray
Contains measurement(s) to update a specific track with. ny by N,
where N is the number of measurements to update the track with.
track : dict
A dict containing everything relevant to the track.
"""
        if meas.size == 0:
            track = self.logic(np.array([]), track['filt'], track, self.logic_params) # If no meas associated, still update logic of track
            return
@@ -83,58 +117,272 @@ class GNN():
        # Update
        track['x'][-1], track['P'][-1] = track['filt'].update(track['x'][-1], track['P'][-1], eps)
    def associate_update(self, meas_k, k, tracks, unused_meas):
        """Associates measurements to tracks and updates the tracks with the
        measurements.

        Does *not* return anything, but modifies the objects in-place!

        Solves the measurement-to-track assignment problem globally with
        scipy.optimize.linear_sum_assignment on the association matrix.

        Parameters
        ----------
        meas_k : numpy.ndarray
            Measurements to attempt to associate
        k : int
            The current time step
        tracks : list
            A list of the tracks to associate the measurements with
        unused_meas : numpy.ndarray
            A logical array indicating which measurements are still
            unused/non-associated.
        """
        association_matrix, validation_matrix = get_association_matrix(meas_k[:, unused_meas], tracks, self.logic_params, self.gater)
        # Solve the association problem (maximize the total log-likelihood)
        row_ind, col_ind = scipy.optimize.linear_sum_assignment(-association_matrix)
        for row, col in zip(row_ind, col_ind):
            if col >= len(tracks): # No target to associate the measurement to
                continue
            else:
                # Update the associated track
                self._update_track(meas_k[:, unused_meas][:, row], tracks[col])
                tracks[col]['associations'].append(k) # If we've associated something, add the time here (for plotting purposes)
        for i in range(len(tracks)):
            if i not in col_ind:
                self._update_track(np.array([]), tracks[i])
        # Remove any gated measurements from further consideration
        tmp = unused_meas[unused_meas] # Extract the unused measurements
        inds = np.where(validation_matrix.sum(axis=1))
        tmp[inds] = 0
        unused_meas[unused_meas] = tmp
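
    # Note on the assignment step: get_association_matrix returns an
    # ny-by-(Nc + 2*ny) matrix of log-likelihood ratios, whose first Nc
    # columns are the track hypotheses and whose trailing 2*ny columns hold
    # the false-alarm and new-target hypotheses for each measurement.
    # Negating it turns scipy's minimum-cost assignment into a
    # maximum-log-likelihood one, so each measurement is assigned to at most
    # one track and each track to at most one measurement.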
    def evaluate(self, Y):
""" Evaluates the detections in Y.
Parameters
----------
Y : list
List of detections at time k=0 to K where K is the length of Y.
Each entry of Y is ny by N_k where N_k is time-varying as the number
of detections vary.
Returns
-------
list, list
First list contains all initiated tracks, both tentative, deleted
and confirmed. The second list contains only the confirmed list,
even if they have died. Hence, the lists contain duplicates (but
point to the same object!).
"""
        rng = np.random.default_rng()
        tracks = [] # Store all tracks
        confirmed_tracks = [] # Store the confirmed tracks (for plotting purposes only)
        ids = 0
        for k, meas_k in tqdm.tqdm(enumerate(Y), desc="GNN evaluating detections: "):
            ny = meas_k.shape[1]
            unused_meas = np.ones((ny), dtype=bool)
            # Handle the confirmed and alive tracks
            live_tracks = [track for track in confirmed_tracks if track['stage']=='confirmed']
            if live_tracks:
                self.associate_update(meas_k, k, live_tracks, unused_meas)
            # Handle the tentative tracks
            tentative_tracks = [track for track in tracks if track['stage'] == 'tentative']
            if tentative_tracks:
                self.associate_update(meas_k, k, tentative_tracks, unused_meas)
                for track in tentative_tracks:
                    if track['stage'] == 'confirmed':
                        confirmed_tracks.append(track) # If a track has been confirmed, add it to confirmed tracks
            # Use the unused measurements to initiate new tracks (in random order)
            while unused_meas.any():
                ind = rng.choice(np.arange(unused_meas.size), p=unused_meas/unused_meas.sum())
                track = self.init_track(meas_k[:, ind], k, ids, self.filt) # Initialize track
                tracks.append(track)
                unused_meas[ind] = 0 # Remove measurement from association hypothesis
                self.associate_update(meas_k, k, [track], unused_meas)
                ids += 1
                if track['stage'] == 'confirmed':
                    confirmed_tracks.append(track)
            for track in tracks:
                if track['stage'] != 'deleted':
                    x, P = track['filt'].propagate(track['x'][-1], track['P'][-1])
                    track['x'].append(x)
                    track['P'].append(P)
                    track['t'].append(k+1)
        return tracks, confirmed_tracks
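
# --- Usage sketch (illustrative) ---------------------------------------------
# A minimal sketch of driving the GNN tracker. The names EKF, score_logic,
# init_track and MahalanobisGater, and all parameter values, are hypothetical
# stand-ins for the project's src.filter, src.logic and src.gater components
# and must be adapted to the actual modules:
#
#   filt = EKF(motion_model, sensor_model)
#   gater = MahalanobisGater(gamma=9.21)  # ~99% gate for 2D measurements
#   tracker = GNN(score_logic, logic_params, init_track, filt, gater,
#                 clutter_model={'lam': 2, 'volume': 100.0})
#   tracks, confirmed = tracker.evaluate(Y)  # Y: list of ny-by-N_k arrays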
class JPDA():
    def __init__(self, logic, logic_params, init_track, filt, gater, clutter_model):
        """An implementation of a Joint Probabilistic Data Association tracker.

        Parameters
        ----------
        logic : logic
            See src.logic. Some sort of track logic.
        logic_params : dict
            Contains parameters to the track logic.
        init_track : callable
            A function that initiates a track. Should take a measurement, the
            time, an id and the filter to use for the track as input.
        filt : filter
            See src.filter. Some sort of filter to use for the tracks.
        gater : gater
            See src.gater. A gating function.
        clutter_model : dict
            A dict containing the clutter model.
        """
        self.logic = logic
        self.logic_params = logic_params
        self.init_track = init_track
        self.filt = filt
        self.gater = gater
        self.clutter_model = clutter_model
    def _update_track(self, meas, track, association_probability):
        """Handles the update of a certain track with the given measurement(s).

        Modifies the track in-place!

        Parameters
        ----------
        meas : numpy.ndarray
            Contains measurement(s) to update a specific track with. ny by N,
            where N is the number of measurements to update the track with.
        track : dict
            A dict containing everything relevant to the track.
        association_probability : numpy.ndarray
            The association probability of each measurement to the track.
        """
        if meas.size == 0:
            track = self.logic(np.array([]), track['filt'], track, self.logic_params) # If no meas associated, still update logic of track
            return
        # Calculate prediction error of each measurement
        yhat = track['filt'].sensor_model['h'](track['x'][-1])
        eps = meas-yhat[:, None]
        # Update track
        Lt = []
        x = []
        P = []
        oLt = track['Lt']
        # Handle false measurement separately
        x.append(track['x'][-1])
        P.append(track['P'][-1])
        track = self.logic(np.array([]), track['filt'], track, self.logic_params)
        Lt.append(track['Lt'])
        track['Lt'] = oLt # Reset track score
        for j in range(eps.shape[1]):
            # Compute the track score given an association
            track = self.logic(meas[:, j], track['filt'], track, self.logic_params)
            Lt.append(track['Lt'])
            track['Lt'] = oLt # Reset the current track score
            xj, Pj = track['filt'].update(track['x'][-1], track['P'][-1], eps[:, j])
            x.append(xj)
            P.append(Pj)
        # Update the track score by marginalizing the measurement associations
        track['Lt'] = np.array(Lt)@association_probability
        # Compute the state estimate
        xhat = (association_probability@np.vstack(x)).T
        # Compute the state error covariance
        err = np.vstack(x).T-xhat[:, None]
        Pk = np.stack([col[:, None]@col[None, :] for col in err.T])
        Phat = np.tensordot(np.stack(P)+Pk, association_probability, (0, 0)) # Dot product over axis 0
        track['x'][-1] = xhat
        track['P'][-1] = Phat
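
    # The update above is standard (J)PDA mixture reduction: with association
    # probabilities beta_j (j = 0 being the miss hypothesis) and per-hypothesis
    # estimates (x_j, P_j), the mixture is collapsed by moment matching,
    #     xhat = sum_j beta_j * x_j
    #     Phat = sum_j beta_j * (P_j + (x_j - xhat)(x_j - xhat)^T),
    # which is what the vstack/tensordot lines compute.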
    def associate_update(self, meas_k, k, tracks, unused_meas):
        """Associates measurements to tracks and updates the tracks with the
        measurements.

        Does *not* return anything, but modifies the objects in-place!

        Uses Efficient Hypothesis Management (see
        https://github.com/sglvladi/pyehm) for computing the hypothesis
        probabilities.

        Parameters
        ----------
        meas_k : numpy.ndarray
            Measurements to attempt to associate
        k : int
            The current time step
        tracks : list
            A list of the tracks to associate the measurements with
        unused_meas : numpy.ndarray
            A logical array indicating which measurements are still
            unused/non-associated.
        """
        likelihood_matrix, validation_matrix = get_likelihood_matrix(meas_k[:, unused_meas], tracks, self.logic_params, self.gater)
        association_matrix = EHM2.run(validation_matrix, likelihood_matrix)
        meas = meas_k[:, unused_meas]
        for ti, track in enumerate(tracks):
            if validation_matrix[ti, 1:].any(): # If any measurements are validated to this track, update it accordingly
                self._update_track(meas[:, validation_matrix[ti, 1:].flatten()], track, association_matrix[ti, validation_matrix[ti, :]])
                track['associations'].append(k) # If we've associated something, add the time here (for plotting purposes)
            else:
                self._update_track(np.array([]), track, None)
        # Measurements that are validated to any track can be removed from further association
        tmp = unused_meas[unused_meas]
        used_inds = np.where(validation_matrix[:, 1:].sum(axis=0))
        tmp[used_inds] = 0
        unused_meas[unused_meas] = tmp
    def evaluate(self, Y):
        """Evaluates the detections in Y.

        Parameters
        ----------
        Y : list
            List of detections at time k=0 to K where K is the length of Y.
            Each entry of Y is ny by N_k where N_k is time-varying as the number
            of detections vary.

        Returns
        -------
        list, list
            The first list contains all initiated tracks: tentative, deleted
            and confirmed. The second list contains only the confirmed tracks,
            even if they have since been deleted. Hence, the lists contain
            duplicates (but point to the same objects!).
        """
        rng = np.random.default_rng()
        tracks = [] # Store all tracks
        confirmed_tracks = [] # Store the confirmed tracks (for plotting purposes only)
        ids = 0
        for k, meas_k in tqdm.tqdm(enumerate(Y), desc="JPDA evaluating detections: "):
            ny = meas_k.shape[1]
            unused_meas = np.ones((ny), dtype=bool)
            # Handle confirmed and alive tracks
            live_tracks = [track for track in confirmed_tracks if track['stage']=='confirmed']
            if live_tracks:
                self.associate_update(meas_k, k, live_tracks, unused_meas)
            # Handle tentative tracks
            tentative_tracks = [track for track in tracks if track['stage']=='tentative']
            if tentative_tracks:
                self.associate_update(meas_k, k, tentative_tracks, unused_meas)
                for track in tentative_tracks:
                    if track['stage'] == 'confirmed':
                        confirmed_tracks.append(track) # If a track has been confirmed, add it to confirmed tracks
            # Initiate new tracks at random
            while unused_meas.any():
                ind = rng.choice(np.arange(unused_meas.size), p=unused_meas/unused_meas.sum())
                track = self.init_track(meas_k[:, ind], k, ids, self.filt) # Initialize track
                tracks.append(track)
                unused_meas[ind] = 0 # Remove measurement from association hypothesis
                self.associate_update(meas_k, k, [track], unused_meas)
                ids += 1
        for track in tracks:
@@ -145,8 +393,72 @@ class GNN():
                    track['t'].append(k+1)
        return tracks, confirmed_tracks
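
# --- Illustrative sketch: pyehm association ----------------------------------
# A toy two-track, two-measurement example of the EHM2.run call used in
# JPDA.associate_update. Column 0 is the miss hypothesis and the numbers are
# made up; EHM2.run should return a matrix of the same shape holding the
# marginal association probabilities, with each row summing to one.
if __name__ == "__main__":
    validation = np.array([[1, 1, 0],
                           [1, 1, 1]], dtype=bool)
    likelihood = np.array([[0.10, 0.70, 0.00],
                           [0.10, 0.30, 0.60]])
    assoc = EHM2.run(validation, likelihood)
    print(assoc)  # rows: tracks, columns: [miss, measurement 1, measurement 2]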
def get_likelihood_matrix(meas, tracks, logic_params, gater):
    """Computes the likelihood and validation matrices (specifically for the
    JPDA implementation).

    Parameters
    ----------
    meas : numpy.ndarray
        Measurements to attempt to associate
    tracks : list
        A list of the tracks to associate the measurements with
    logic_params : dict
        Parameters of the track logic (for simplicity; the function needs the
        rates of false alarms and new targets, respectively).
    gater : gater
        A gater that can gate the measurements with the tracks.

    Returns
    -------
    numpy.ndarray, numpy.ndarray
        A likelihood matrix where entry (t, i+1) contains the unnormalized
        likelihood of associating measurement i with track t, and a validation
        matrix indicating the possible associations. Both matrices are of size
        Nc by ny+1, where Nc is the number of tracks. The first column of both
        corresponds to the miss/false-alarm hypothesis.
    """
    ny = meas.shape[1]
    Nc = len(tracks) # Number of tracks to associate
    validation_matrix = np.zeros((Nc, ny+1), dtype=bool)
    validation_matrix[:, 0] = 1
    likelihood_matrix = np.zeros((Nc, ny+1))
    for ti, track in enumerate(tracks): # Iterate over the tracks
        validation_matrix[ti, 1:] = gater.gate(track['x'][-1], track['P'][-1], meas)
        # Entries for the validated measurements
        val_meas = meas[:, validation_matrix[ti, 1:]] # Get the validated measurements for this track
        yhat = track['filt'].sensor_model['h'](track['x'][-1]) # Calculate the predicted measurement for this track
        H = track['filt'].sensor_model['dhdx'](track['x'][-1])
        py = stats.multivariate_normal.pdf(val_meas.squeeze().T, mean=yhat.flatten(), cov=H@track['P'][-1]@H.T+track['filt'].sensor_model['R'])
        likelihood_matrix[ti, np.where(validation_matrix[ti, 1:])[0]+1] = track['filt'].sensor_model['PD']*py
        likelihood_matrix[ti, 0] = 1-track['filt'].sensor_model['PD'] # PG assumed 1
    return likelihood_matrix, validation_matrix
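
# A small worked example of one row of the likelihood matrix: for a track with
# PD = 0.9 and two gated measurements whose Gaussian densities under the
# predicted measurement distribution N(yhat, H P H^T + R) are 0.05 and 0.01,
# the row becomes [1-PD, PD*0.05, PD*0.01] = [0.1, 0.045, 0.009]. The values
# are unnormalized; EHM normalizes over the joint association events.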
def get_association_matrix(meas, tracks, logic_params, gater):
    """Computes the association and validation matrices (specifically for the
    GNN implementation).

    Parameters
    ----------
    meas : numpy.ndarray
        Measurements to attempt to associate
    tracks : list
        A list of the tracks to associate the measurements with
    logic_params : dict
        Parameters of the track logic (for simplicity; the function needs the
        rates of false alarms and new targets, respectively).
    gater : gater
        A gater that can gate the measurements with the tracks.

    Returns
    -------
    numpy.ndarray, numpy.ndarray
        An association matrix of size ny by Nc+2*ny and a validation matrix of
        size ny by Nc, where Nc is the number of tracks. The association
        matrix also contains the false-alarm and new-track possibilities.
    """
    ny = meas.shape[1]
    Nc = len(tracks) # Number of tracks to associate
    validation_matrix = np.zeros((ny, Nc), dtype=bool)
@@ -166,3 +478,29 @@ def get_association_matrix(meas, tracks, logic_params, gater):
        py = stats.multivariate_normal.pdf(val_meas.squeeze().T, mean=yhat.flatten(), cov=H@track['P'][-1]@H.T+track['filt'].sensor_model['R'])
        association_matrix[validation_matrix[:, ti], ti] = np.log(track['filt'].sensor_model['PD']*py/(1-track['filt'].sensor_model['PD'])) # PG assumed = 1
    return association_matrix, validation_matrix
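
# --- Illustrative sketch: the GNN assignment ----------------------------------
# A toy example of how scipy.optimize.linear_sum_assignment resolves the
# measurement-to-track problem on a matrix with the layout built above
# (Nc track columns followed by 2*ny false-alarm/new-target columns). All
# numbers are made up and -1e9 stands in for "not validated".
if __name__ == "__main__":
    toy = np.full((2, 2 + 2*2), -1e9)    # ny = 2 measurements, Nc = 2 tracks
    toy[:, :2] = np.array([[3.1, 0.2],   # measurement/track log-likelihood ratios
                           [0.5, 2.7]])
    toy[range(2), [2, 3]] = np.log(1e-2) # false-alarm hypothesis per measurement
    toy[range(2), [4, 5]] = np.log(5e-3) # new-target hypothesis per measurement
    row, col = scipy.optimize.linear_sum_assignment(-toy) # negate to maximize
    print(list(zip(row, col)))           # [(0, 0), (1, 1)]: both measurements to tracks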
### Obsolete
def compute_prob(association_matrix, validation_matrix, logic_params):
    # The association matrix is assumed to consist of tracks and FA, no NT.
    ny = association_matrix.shape[0]
    ntracks = association_matrix.shape[1]-ny
    P = np.zeros((ny, ntracks))

    def rec_find_associations(association_matrix, assoc_done, logic_params):
        inds = np.where(association_matrix[0, :] != -np.inf)[0] # These are the nodes necessary to look at
        this_assoc = []
        for i in inds:
            if i not in assoc_done:
                if association_matrix.shape[0] != 1:
                    assoc = rec_find_associations(association_matrix[1:, :], [[i]], logic_params) # Recurse on the remaining measurements
                    this_assoc.extend(assoc)
                else:
                    this_assoc.append([i])
        result = []
        for assoc in assoc_done:
            for th_assoc in this_assoc:
                result.append(assoc + th_assoc)
        return result

    possible_associations = rec_find_associations(association_matrix, [[]], logic_params) # Recursively finds the possible measurement hypotheses
    return possible_associations