From a8f3128c20171f8a41f13909cce9cf9588fae032 Mon Sep 17 00:00:00 2001 From: Anton Kullberg <anton.kullberg@liu.se> Date: Tue, 9 Nov 2021 14:43:18 +0100 Subject: [PATCH] py: implemented GNN and JPDA --- src/trackers.py | 414 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 376 insertions(+), 38 deletions(-) diff --git a/src/trackers.py b/src/trackers.py index fce49b5..84bfc61 100644 --- a/src/trackers.py +++ b/src/trackers.py @@ -1,7 +1,9 @@ +import pdb import numpy as np import scipy.stats as stats import scipy import tqdm +from pyehm.core import EHM2 class BasicTracker(): def __init__(self, filt, clutter_model, associator, gater): @@ -63,6 +65,25 @@ class IMMTracker(): class GNN(): def __init__(self, logic, logic_params, init_track, filt, gater, clutter_model): + """An implementation of a Global Nearest Neighbour tracker. + + Parameters + ---------- + logic : logic + See src.logic. Some sort of track logic. + logic_params : dict + Contains parameters to the track logic. + init_track : callable + A function that initiates a track. Should take a measurement, the + time, an id and the filter to use for the track as input. + filt : filter + See src.filter. Some sort of filter to use for the tracks. + gater : gater + See src.gater. A gating function. + clutter_model : dict + A dict containing the clutter model. + + """ self.logic = logic self.logic_params = logic_params self.init_track = init_track @@ -71,6 +92,19 @@ class GNN(): self.clutter_model = clutter_model def _update_track(self, meas, track): + """Handles the update of a certain track with the given measurement(s). + + Modifies the track in-place! + + Parameters + ---------- + meas : numpy.ndarray + Contains measurement(s) to update a specific track with. ny by N, + where N is the number of measurements to update the track with. + track : dict + A dict containing everything relevant to the track. + + """ if meas.size == 0: track = self.logic(np.array([]), track['filt'], track, self.logic_params) # If no meas associated, still update logic of track return @@ -83,58 +117,272 @@ class GNN(): # Update track['x'][-1], track['P'][-1] = track['filt'].update(track['x'][-1], track['P'][-1], eps) + def associate_update(self, meas_k, k, tracks, unused_meas): + """Associates measurements to tracks and updates the tracks with the + measurements. + + Does *not* return anything, but modifies the objects in-place! + Uses Efficient Hypothesis Management (see + https://github.com/sglvladi/pyehm) for computing the hypothesis + probabilities. + + Parameters + ---------- + meas_k : numpy.ndarray + Measurements to attempt to associate + k : int + The current time step + tracks : list + A list of the tracks to associate the measurements with + unused_meas : numpy.ndarray + A logical array indicating what measurements are still + unused/non-associated. + + """ + association_matrix, validation_matrix = get_association_matrix(meas_k[:, unused_meas], tracks, self.logic_params, self.gater) + # Solve association problem + row_ind, col_ind = scipy.optimize.linear_sum_assignment(-association_matrix) + for row, col in zip(row_ind, col_ind): + if col >= len(tracks): # No target to associate the measurement to + continue + else: + # Update confirmed tracks + self._update_track(meas_k[:, unused_meas][:, row], tracks[col]) + tracks[col]['associations'].append(k) # If we've associated something, add the time here (for plotting purposes) + for i in range(len(tracks)): + if i not in col_ind: + self._update_track(np.array([]), tracks[i]) + # Remove any gated measurements from further consideration + tmp = unused_meas[unused_meas] # Extract the unused measurements + inds = np.where(validation_matrix.sum(axis=1)) + tmp[inds] = 0 + unused_meas[unused_meas] = tmp + def evaluate(self, Y): + """ Evaluates the detections in Y. + + Parameters + ---------- + Y : list + List of detections at time k=0 to K where K is the length of Y. + Each entry of Y is ny by N_k where N_k is time-varying as the number + of detections vary. + + Returns + ------- + list, list + First list contains all initiated tracks, both tentative, deleted + and confirmed. The second list contains only the confirmed list, + even if they have died. Hence, the lists contain duplicates (but + point to the same object!). + + """ + rng = np.random.default_rng() tracks = [] # Store all tracks confirmed_tracks = [] # Store the confirmed tracks (for plotting purposes only) ids = 0 - for k, meas_k in tqdm.tqdm(enumerate(Y), desc="Evaluating observations: "): + for k, meas_k in tqdm.tqdm(enumerate(Y), desc="GNN evaluating detections: "): ny = meas_k.shape[1] unused_meas = np.ones((ny), dtype=bool) + # Handle the confirmed and alive tracks live_tracks = [track for track in confirmed_tracks if track['stage']=='confirmed'] - if live_tracks: - association_matrix, _ = get_association_matrix(meas_k, live_tracks, self.logic_params, self.gater) - # Solve association problem - row_ind, col_ind = scipy.optimize.linear_sum_assignment(-association_matrix) - - for row, col in zip(row_ind, col_ind): - if col >= len(live_tracks): # No target to associate the measurement to - continue - else: - unused_meas[row] = 0 # Remove this measurement from further consideration - # Update confirmed tracks - self._update_track(meas_k[:, row], live_tracks[col]) - live_tracks[col]['associations'].append(k) # If we've associated something, add the time here (for plotting purposes) - for i in range(len(live_tracks)): - if i not in col_ind: - self._update_track(np.array([]), live_tracks[i]) - + self.associate_update(meas_k, k, live_tracks, unused_meas) + # Handle the tentative tracks tentative_tracks = [track for track in tracks if track['stage'] == 'tentative'] - if tentative_tracks: - association_matrix, _ = get_association_matrix(meas_k[:, unused_meas], tentative_tracks, self.logic_params, self.gater) - # Solve association problem - row_ind, col_ind = scipy.optimize.linear_sum_assignment(-association_matrix) - meas = meas_k[:, unused_meas] - for row, col in zip(row_ind, col_ind): - if col >= len(tentative_tracks): # No target to associate the measurement to - continue - else: - unused_meas[(meas_k == meas[:,[row]]).all(axis=0)] = 0 # Remove this measurement from consideration - # Update confirmed tracks - self._update_track(meas[:, row], tentative_tracks[col]) - tentative_tracks[col]['associations'].append(k) # If we've associated something, add the time here (for plotting purposes) - if tentative_tracks[col]['stage'] == 'confirmed': - confirmed_tracks.append(tentative_tracks[col]) # If a track has been confirmed, add it to confirmed tracks - for i in range(len(tentative_tracks)): - if i not in col_ind: - self._update_track(np.array([]), tentative_tracks[i]) + self.associate_update(meas_k, k, tentative_tracks, unused_meas) + for track in tentative_tracks: + if track['stage'] == 'confirmed': + confirmed_tracks.append(track) # If a track has been confirmed, add it to confirmed tracks # Use the unused measurements to initiate new tracks - for meas in meas_k[:, unused_meas].T: - tracks.append(self.init_track(meas, k, ids, self.filt)) + while unused_meas.any(): + ind = rng.choice(np.arange(unused_meas.size), p=unused_meas/unused_meas.sum()) + track = self.init_track(meas_k[:, ind], k, ids, self.filt) # Initialize track + tracks.append(track) + unused_meas[ind] = 0 # Remove measurement from association hypothesis + self.associate_update(meas_k, k, [track], unused_meas) + ids += 1 + if track['stage'] == 'confirmed': + confirmed_tracks.append(track) + + for track in tracks: + if track['stage'] != 'deleted': + x, P = track['filt'].propagate(track['x'][-1], track['P'][-1]) + track['x'].append(x) + track['P'].append(P) + track['t'].append(k+1) + return tracks, confirmed_tracks + +class JPDA(): + def __init__(self, logic, logic_params, init_track, filt, gater, clutter_model): + """An implementation of a Joint Probabilistic Data Association tracker. + + Parameters + ---------- + logic : logic + See src.logic. Some sort of track logic. + logic_params : dict + Contains parameters to the track logic. + init_track : callable + A function that initiates a track. Should take a measurement, the + time, an id and the filter to use for the track as input. + filt : filter + See src.filter. Some sort of filter to use for the tracks. + gater : gater + See src.gater. A gating function. + clutter_model : dict + A dict containing the clutter model. + + """ + self.logic = logic + self.logic_params = logic_params + self.init_track = init_track + self.filt = filt + self.gater = gater + self.clutter_model = clutter_model + + def _update_track(self, meas, track, association_probability): + """Handles the update of a certain track with the given measurement(s). + + Modifies the track in-place! + + Parameters + ---------- + meas : numpy.ndarray + Contains measurement(s) to update a specific track with. ny by N, + where N is the number of measurements to update the track with. + track : dict + A dict containing everything relevant to the track. + association_probability : numpy.ndarray + The association probability of each measurement to the track. + + """ + if meas.size == 0: + track = self.logic(np.array([]), track['filt'], track, self.logic_params) # If no meas associated, still update logic of track + return + # Calculate prediction error of each measurement + yhat = track['filt'].sensor_model['h'](track['x'][-1]) + eps = meas-yhat[:, None] + + # Update track + Lt = [] + x = [] + P = [] + oLt = track['Lt'] + # Handle false measurement separately + x.append(track['x'][-1]) + P.append(track['P'][-1]) + track = self.logic(np.array([]), track['filt'], track, self.logic_params) + Lt.append(track['Lt']) + track['Lt'] = oLt # Reset track score + for j in range(eps.shape[1]): + # Compute the track score given an association + track = self.logic(meas[:, j], track['filt'], track, self.logic_params) + Lt.append(track['Lt']) + track['Lt'] = oLt # Reset the current track score + xj, Pj = track['filt'].update(track['x'][-1], track['P'][-1], eps[:, j]) + x.append(xj) + P.append(Pj) + # Update the track score by marginalizing the measurement associations + track['Lt'] = np.array(Lt)@association_probability + # Compute the state estimate + xhat = (association_probability@np.vstack(x)).T + # Compute the state error covariance + err = np.vstack(x).T-xhat[:, None] + Pk = np.stack([col[:,None]@col[None,:] for col in err.T]) + Phat = np.tensordot(np.stack(P)+Pk, association_probability, (0, 0)) # Dot product over axis 0 + track['x'][-1] = xhat + track['P'][-1] = Phat + + def associate_update(self, meas_k, k, tracks, unused_meas): + """Associates measurements to tracks and updates the tracks with the + measurements. + + Does *not* return anything, but modifies the objects in-place! + Uses Efficient Hypothesis Management (see + https://github.com/sglvladi/pyehm) for computing the hypothesis + probabilities. + + Parameters + ---------- + meas_k : numpy.ndarray + Measurements to attempt to associate + k : int + The current time step + tracks : list + A list of the tracks to associate the measurements with + unused_meas : numpy.ndarray + A logical array indicating what measurements are still + unused/non-associated. + + """ + likelihood_matrix, validation_matrix = get_likelihood_matrix(meas_k[:, unused_meas], tracks, self.logic_params, self.gater) + association_matrix = EHM2.run(validation_matrix, likelihood_matrix) + + meas = meas_k[:, unused_meas] + for ti, track in enumerate(tracks): + if validation_matrix[ti, 1:].any(): # If any measurements are validated to this track, update it accordingly + self._update_track(meas[:, validation_matrix[ti, 1:].flatten()], track, association_matrix[ti, validation_matrix[ti, :]]) + track['associations'].append(k) # If we've associated something, add the time here (for plotting purposes) + else: + self._update_track(np.array([]), track, None) + # Measurements that are validated to any track can be removed from further association + tmp = unused_meas[unused_meas] + used_inds = np.where(validation_matrix[:, 1:].sum(axis=0)) + tmp[used_inds] = 0 + unused_meas[unused_meas] = tmp + + def evaluate(self, Y): + """ Evaluates the detections in Y. + + Parameters + ---------- + Y : list + List of detections at time k=0 to K where K is the length of Y. + Each entry of Y is ny by N_k where N_k is time-varying as the number + of detections vary. + + Returns + ------- + list, list + First list contains all initiated tracks, both tentative, deleted + and confirmed. The second list contains only the confirmed list, + even if they have died. Hence, the lists contain duplicates (but + point to the same object!). + + """ + rng = np.random.default_rng() + tracks = [] # Store all tracks + confirmed_tracks = [] # Store the confirmed tracks (for plotting purposes only) + ids = 0 + for k, meas_k in tqdm.tqdm(enumerate(Y), desc="JPDA evaluating detections: "): + ny = meas_k.shape[1] + unused_meas = np.ones((ny), dtype=bool) + + # Handle confirmed and alive tracks + live_tracks = [track for track in confirmed_tracks if track['stage']=='confirmed'] + if live_tracks: + self.associate_update(meas_k, k, live_tracks, unused_meas) + + # Handle tentative tracks + tentative_tracks = [track for track in tracks if track['stage']=='tentative'] + if tentative_tracks: + self.associate_update(meas_k, k, tentative_tracks, unused_meas) + for track in tentative_tracks: + if track['stage'] == 'confirmed': + confirmed_tracks.append(track) # If a track has been confirmed, add it to confirmed tracks + + # Initiate new tracks at random + while unused_meas.any(): + ind = rng.choice(np.arange(unused_meas.size), p=unused_meas/unused_meas.sum()) + track = self.init_track(meas_k[:, ind], k, ids, self.filt) # Initialize track + tracks.append(track) + unused_meas[ind] = 0 # Remove measurement from association hypothesis + self.associate_update(meas_k, k, [track], unused_meas) ids += 1 for track in tracks: @@ -145,8 +393,72 @@ class GNN(): track['t'].append(k+1) return tracks, confirmed_tracks +def get_likelihood_matrix(meas, tracks, logic_params, gater): + """ Computes the likelihood and validation matrix (specifically for the + JPDA implementation) + + Parameters + ---------- + meas : numpy.ndarray + Measurements to attempt to associate + tracks : list + A list of the tracks to associate the measurements with + logic_params : dict + Parameters of the track logic (for simplicity. The function needs the rate of false alarms and new targets, respectively.) + gater : gater + A gater that can gate the measurements with the tracks. + + Returns + ------- + numpy.ndarray, numpy.ndarray + Returns a likelihood matrix containing the unnormalized likelihood of + associating measurement i with track j. Also returns a validation matrix + indicating the possible associations of measurement i with track j. + Both the likelihood and validation matrices are of size Nc by ny+1, + where Nc is the number of tracks. The first column of the validation + and likelihood matrices corresponds to a false alarm. + + """ + ny = meas.shape[1] + Nc = len(tracks) # Number of tracks to associate + validation_matrix = np.zeros((Nc, ny+1), dtype=bool) + validation_matrix[:, 0] = 1 + likelihood_matrix = np.zeros((Nc, ny+1)) + + for ti, track in enumerate(tracks): # Iterate over confirmed tracks + validation_matrix[ti, 1:] = gater.gate(track['x'][-1], track['P'][-1], meas) + # Entry for validated tracks + val_meas = meas[:, validation_matrix[ti, 1:]] # Get the validated measurements for this track + yhat = track['filt'].sensor_model['h'](track['x'][-1]) # Calculate the predicted measurement for this track + H = track['filt'].sensor_model['dhdx'](track['x'][-1]) + py = stats.multivariate_normal.pdf(val_meas.squeeze().T, mean=yhat.flatten(), cov=H@track['P'][-1]@H.T+track['filt'].sensor_model['R']) + likelihood_matrix[ti, np.where(validation_matrix[ti, 1:])[0]+1] = track['filt'].sensor_model['PD']*py + likelihood_matrix[ti, 0] = 1-track['filt'].sensor_model['PD'] # PG assumed 1 + return likelihood_matrix, validation_matrix def get_association_matrix(meas, tracks, logic_params, gater): + """ Computes the validation and association matrix (specifically for the + GNN implementation) + + Parameters + ---------- + meas : numpy.ndarray + Measurements to attempt to associate + tracks : list + A list of the tracks to associate the measurements with + logic_params : dict + Parameters of the track logic (for simplicity. The function needs the rate of false alarms and new targets, respectively.) + gater : gater + A gater that can gate the measurements with the tracks. + + Returns + ------- + numpy.ndarray, numpy.ndarray + Returns an association matrix of size ny by Nc+2*ny and a validation + matrix of size ny by Nc, where Nc is the number of tracks. The association + matrix also contains the false alarm and new track possibilities. + + """ ny = meas.shape[1] Nc = len(tracks) # Number of tracks to associate validation_matrix = np.zeros((ny, Nc), dtype=bool) @@ -166,3 +478,29 @@ def get_association_matrix(meas, tracks, logic_params, gater): py = stats.multivariate_normal.pdf(val_meas.squeeze().T, mean=yhat.flatten(), cov=H@track['P'][-1]@H.T+track['filt'].sensor_model['R']) association_matrix[validation_matrix[:, ti], ti] = np.log(track['filt'].sensor_model['PD']*py/(1-track['filt'].sensor_model['PD'])) # PG assumed = 1 return association_matrix, validation_matrix + +### Obsolete +def compute_prob(association_matrix, validation_matrix, logic_params): + # Association matrix is assumed to consist of tracks and FA, no NT. + ny = association_matrix.shape[0] + ntracks = association_matrix.shape[1]-ny + P = np.zeros((ny, ntracks)) + + def rec_find_associations(association_matrix, assoc_done, logic_params): + inds = np.where(association_matrix[0, :] != -np.inf)[0] # These are the nodes necessary to look at + this_assoc = [] + for k, i in enumerate(inds): + if i not in assoc_done: + if association_matrix.shape[0] != 1: + assoc = rec_compute_prob(association_matrix[1:, :], [[i]], logic_params) + this_assoc.extend(assoc) + else: + this_assoc.append([i]) + result = [] + for assoc in assoc_done: + for th_assoc in this_assoc: + result.append(assoc + th_assoc) + return result + + possible_associations = rec_find_associations(association_matrix, [[]], logic_params) # Recursively finds possible measurement hypothesis + return res -- GitLab