Skip to content
Snippets Groups Projects
Commit 7448e41c authored by Anton Kullberg's avatar Anton Kullberg
Browse files

refactor: moved function to class

parent 128b20b6
No related branches found
No related tags found
No related merge requests found
......@@ -167,6 +167,45 @@ class GNN():
self.gater = gater
self.clutter_model = clutter_model
def get_association_matrix(self, meas, tracks):
""" Computes the validation and association matrix (specifically for the
GNN implementation)
Parameters
----------
meas : numpy.ndarray
Measurements to attempt to associate
tracks : list
A list of the tracks to associate the measurements with
Returns
-------
numpy.ndarray, numpy.ndarray
Returns an association matrix of size ny by Nc+2*ny and a validation
matrix of size ny by Nc, where Nc is the number of tracks. The association
matrix also contains the false alarm and new track possibilities.
"""
ny = meas.shape[1]
Nc = len(tracks) # Number of tracks to associate
validation_matrix = np.zeros((ny, Nc), dtype=bool)
association_matrix = -np.inf*np.ones((ny, Nc+2*ny))
# Entry for false alarms
np.fill_diagonal(association_matrix[:, Nc:Nc+ny], np.log(self.logic_params['Bfa']))
# Entry for new targets
np.fill_diagonal(association_matrix[:, Nc+ny:], np.log(self.logic_params['Bnt']))
for ti, track in enumerate(tracks): # Iterate over confirmed tracks
validation_matrix[:, ti] = self.gater.gate(track['x'][-1], track['P'][-1], meas)
# Entry for validated tracks
val_meas = meas[:, validation_matrix[:, ti]] # Get the validated measurements for this track
yhat = track['filt'].sensor_model['h'](track['x'][-1]) # Calculate the predicted measurement for this track
H = track['filt'].sensor_model['dhdx'](track['x'][-1])
py = stats.multivariate_normal.pdf(val_meas.squeeze().T, mean=yhat.flatten(), cov=H@track['P'][-1]@H.T+track['filt'].sensor_model['R'])
association_matrix[validation_matrix[:, ti], ti] = np.log(track['filt'].sensor_model['PD']*py/(1-track['filt'].sensor_model['PD'])) # PG assumed = 1
return association_matrix, validation_matrix
def _update_track(self, meas, track):
"""Handles the update of a certain track with the given measurement(s).
......@@ -215,7 +254,7 @@ class GNN():
unused/non-associated.
"""
association_matrix, validation_matrix = get_association_matrix(meas_k[:, unused_meas], tracks, self.logic_params, self.gater)
association_matrix, validation_matrix = self.get_association_matrix(meas_k[:, unused_meas], tracks)
# Solve association problem
row_ind, col_ind = scipy.optimize.linear_sum_assignment(-association_matrix)
for row, col in zip(row_ind, col_ind):
......@@ -321,6 +360,45 @@ class JPDA():
self.gater = gater
self.clutter_model = clutter_model
def get_likelihood_matrix(self, meas, tracks):
""" Computes the likelihood and validation matrix (specifically for the
JPDA implementation)
Parameters
----------
meas : numpy.ndarray
Measurements to attempt to associate
tracks : list
A list of the tracks to associate the measurements with
Returns
-------
numpy.ndarray, numpy.ndarray
Returns a likelihood matrix containing the unnormalized likelihood of
associating measurement i with track j. Also returns a validation matrix
indicating the possible associations of measurement i with track j.
Both the likelihood and validation matrices are of size Nc by ny+1,
where Nc is the number of tracks. The first column of the validation
and likelihood matrices corresponds to a false alarm.
"""
ny = meas.shape[1]
Nc = len(tracks) # Number of tracks to associate
validation_matrix = np.zeros((Nc, ny+1), dtype=bool)
validation_matrix[:, 0] = 1
likelihood_matrix = np.zeros((Nc, ny+1))
for ti, track in enumerate(tracks): # Iterate over confirmed tracks
validation_matrix[ti, 1:] = self.gater.gate(track['x'][-1], track['P'][-1], meas)
# Entry for validated tracks
val_meas = meas[:, validation_matrix[ti, 1:]] # Get the validated measurements for this track
yhat = track['filt'].sensor_model['h'](track['x'][-1]) # Calculate the predicted measurement for this track
H = track['filt'].sensor_model['dhdx'](track['x'][-1])
py = stats.multivariate_normal.pdf(val_meas.squeeze().T, mean=yhat.flatten(), cov=H@track['P'][-1]@H.T+track['filt'].sensor_model['R'])
likelihood_matrix[ti, np.where(validation_matrix[ti, 1:])[0]+1] = track['filt'].sensor_model['PD']*py
likelihood_matrix[ti, 0] = 1-track['filt'].sensor_model['PD'] # PG assumed 1
return likelihood_matrix, validation_matrix
def _update_track(self, meas, track, association_probability):
"""Handles the update of a certain track with the given measurement(s).
......@@ -396,7 +474,7 @@ class JPDA():
unused/non-associated.
"""
likelihood_matrix, validation_matrix = get_likelihood_matrix(meas_k[:, unused_meas], tracks, self.logic_params, self.gater)
likelihood_matrix, validation_matrix = self.get_likelihood_matrix(meas_k[:, unused_meas], tracks)
association_matrix = EHM2.run(validation_matrix, likelihood_matrix)
meas = meas_k[:, unused_meas]
......@@ -718,89 +796,3 @@ class MHT():
for hyp in hypothesis[k]:
hyp['probability'] /= total_score
return hypothesis
def get_likelihood_matrix(meas, tracks, logic_params, gater):
""" Computes the likelihood and validation matrix (specifically for the
JPDA implementation)
Parameters
----------
meas : numpy.ndarray
Measurements to attempt to associate
tracks : list
A list of the tracks to associate the measurements with
logic_params : dict
Parameters of the track logic (for simplicity. The function needs the rate of false alarms and new targets, respectively.)
gater : gater
A gater that can gate the measurements with the tracks.
Returns
-------
numpy.ndarray, numpy.ndarray
Returns a likelihood matrix containing the unnormalized likelihood of
associating measurement i with track j. Also returns a validation matrix
indicating the possible associations of measurement i with track j.
Both the likelihood and validation matrices are of size Nc by ny+1,
where Nc is the number of tracks. The first column of the validation
and likelihood matrices corresponds to a false alarm.
"""
ny = meas.shape[1]
Nc = len(tracks) # Number of tracks to associate
validation_matrix = np.zeros((Nc, ny+1), dtype=bool)
validation_matrix[:, 0] = 1
likelihood_matrix = np.zeros((Nc, ny+1))
for ti, track in enumerate(tracks): # Iterate over confirmed tracks
validation_matrix[ti, 1:] = gater.gate(track['x'][-1], track['P'][-1], meas)
# Entry for validated tracks
val_meas = meas[:, validation_matrix[ti, 1:]] # Get the validated measurements for this track
yhat = track['filt'].sensor_model['h'](track['x'][-1]) # Calculate the predicted measurement for this track
H = track['filt'].sensor_model['dhdx'](track['x'][-1])
py = stats.multivariate_normal.pdf(val_meas.squeeze().T, mean=yhat.flatten(), cov=H@track['P'][-1]@H.T+track['filt'].sensor_model['R'])
likelihood_matrix[ti, np.where(validation_matrix[ti, 1:])[0]+1] = track['filt'].sensor_model['PD']*py
likelihood_matrix[ti, 0] = 1-track['filt'].sensor_model['PD'] # PG assumed 1
return likelihood_matrix, validation_matrix
def get_association_matrix(meas, tracks, logic_params, gater):
""" Computes the validation and association matrix (specifically for the
GNN implementation)
Parameters
----------
meas : numpy.ndarray
Measurements to attempt to associate
tracks : list
A list of the tracks to associate the measurements with
logic_params : dict
Parameters of the track logic (for simplicity. The function needs the rate of false alarms and new targets, respectively.)
gater : gater
A gater that can gate the measurements with the tracks.
Returns
-------
numpy.ndarray, numpy.ndarray
Returns an association matrix of size ny by Nc+2*ny and a validation
matrix of size ny by Nc, where Nc is the number of tracks. The association
matrix also contains the false alarm and new track possibilities.
"""
ny = meas.shape[1]
Nc = len(tracks) # Number of tracks to associate
validation_matrix = np.zeros((ny, Nc), dtype=bool)
association_matrix = -np.inf*np.ones((ny, Nc+2*ny))
# Entry for false alarms
np.fill_diagonal(association_matrix[:, Nc:Nc+ny], np.log(logic_params['Bfa']))
# Entry for new targets
np.fill_diagonal(association_matrix[:, Nc+ny:], np.log(logic_params['Bnt']))
for ti, track in enumerate(tracks): # Iterate over confirmed tracks
validation_matrix[:, ti] = gater.gate(track['x'][-1], track['P'][-1], meas)
# Entry for validated tracks
val_meas = meas[:, validation_matrix[:, ti]] # Get the validated measurements for this track
yhat = track['filt'].sensor_model['h'](track['x'][-1]) # Calculate the predicted measurement for this track
H = track['filt'].sensor_model['dhdx'](track['x'][-1])
py = stats.multivariate_normal.pdf(val_meas.squeeze().T, mean=yhat.flatten(), cov=H@track['P'][-1]@H.T+track['filt'].sensor_model['R'])
association_matrix[validation_matrix[:, ti], ti] = np.log(track['filt'].sensor_model['PD']*py/(1-track['filt'].sensor_model['PD'])) # PG assumed = 1
return association_matrix, validation_matrix
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment