diff --git a/src/trackers.py b/src/trackers.py index d07b574172bc30598e8f35df3eefc914fad08c06..fce49b54cc6b8cc3deb6b1948fc9d82a5f330cfb 100644 --- a/src/trackers.py +++ b/src/trackers.py @@ -1,4 +1,7 @@ import numpy as np +import scipy.stats as stats +import scipy +import tqdm class BasicTracker(): def __init__(self, filt, clutter_model, associator, gater): @@ -57,3 +60,109 @@ class IMMTracker(): if k < len(Y)-1: xm[:, k+1, :], Pm[:, :, k+1, :] = self.filt.propagate(xm[:, k, :], Pm[:, :, k, :]) return x, P + +class GNN(): + def __init__(self, logic, logic_params, init_track, filt, gater, clutter_model): + self.logic = logic + self.logic_params = logic_params + self.init_track = init_track + self.filt = filt + self.gater = gater + self.clutter_model = clutter_model + + def _update_track(self, meas, track): + if meas.size == 0: + track = self.logic(np.array([]), track['filt'], track, self.logic_params) # If no meas associated, still update logic of track + return + # Calculate prediction error of each measurement + yhat = track['filt'].sensor_model['h'](track['x'][-1]) + + eps = meas-yhat + track = self.logic(meas, track['filt'], track, self.logic_params) + + # Update + track['x'][-1], track['P'][-1] = track['filt'].update(track['x'][-1], track['P'][-1], eps) + + def evaluate(self, Y): + tracks = [] # Store all tracks + confirmed_tracks = [] # Store the confirmed tracks (for plotting purposes only) + ids = 0 + for k, meas_k in tqdm.tqdm(enumerate(Y), desc="Evaluating observations: "): + ny = meas_k.shape[1] + unused_meas = np.ones((ny), dtype=bool) + + live_tracks = [track for track in confirmed_tracks if track['stage']=='confirmed'] + + if live_tracks: + association_matrix, _ = get_association_matrix(meas_k, live_tracks, self.logic_params, self.gater) + # Solve association problem + row_ind, col_ind = scipy.optimize.linear_sum_assignment(-association_matrix) + + for row, col in zip(row_ind, col_ind): + if col >= len(live_tracks): # No target to associate the measurement to + continue + else: + unused_meas[row] = 0 # Remove this measurement from further consideration + # Update confirmed tracks + self._update_track(meas_k[:, row], live_tracks[col]) + live_tracks[col]['associations'].append(k) # If we've associated something, add the time here (for plotting purposes) + for i in range(len(live_tracks)): + if i not in col_ind: + self._update_track(np.array([]), live_tracks[i]) + + + tentative_tracks = [track for track in tracks if track['stage'] == 'tentative'] + + if tentative_tracks: + association_matrix, _ = get_association_matrix(meas_k[:, unused_meas], tentative_tracks, self.logic_params, self.gater) + # Solve association problem + row_ind, col_ind = scipy.optimize.linear_sum_assignment(-association_matrix) + meas = meas_k[:, unused_meas] + for row, col in zip(row_ind, col_ind): + if col >= len(tentative_tracks): # No target to associate the measurement to + continue + else: + unused_meas[(meas_k == meas[:,[row]]).all(axis=0)] = 0 # Remove this measurement from consideration + # Update confirmed tracks + self._update_track(meas[:, row], tentative_tracks[col]) + tentative_tracks[col]['associations'].append(k) # If we've associated something, add the time here (for plotting purposes) + if tentative_tracks[col]['stage'] == 'confirmed': + confirmed_tracks.append(tentative_tracks[col]) # If a track has been confirmed, add it to confirmed tracks + for i in range(len(tentative_tracks)): + if i not in col_ind: + self._update_track(np.array([]), tentative_tracks[i]) + + # Use the unused measurements to initiate new tracks + for meas in meas_k[:, unused_meas].T: + tracks.append(self.init_track(meas, k, ids, self.filt)) + ids += 1 + + for track in tracks: + if track['stage'] != 'deleted': + x, P = track['filt'].propagate(track['x'][-1], track['P'][-1]) + track['x'].append(x) + track['P'].append(P) + track['t'].append(k+1) + return tracks, confirmed_tracks + + +def get_association_matrix(meas, tracks, logic_params, gater): + ny = meas.shape[1] + Nc = len(tracks) # Number of tracks to associate + validation_matrix = np.zeros((ny, Nc), dtype=bool) + + association_matrix = -np.inf*np.ones((ny, Nc+2*ny)) + # Entry for false alarms + np.fill_diagonal(association_matrix[:, Nc:Nc+ny], np.log(logic_params['Bfa'])) + # Entry for new targets + np.fill_diagonal(association_matrix[:, Nc+ny:], np.log(logic_params['Bnt'])) + + for ti, track in enumerate(tracks): # Iterate over confirmed tracks + validation_matrix[:, ti] = gater.gate(track['x'][-1], track['P'][-1], meas) + # Entry for validated tracks + val_meas = meas[:, validation_matrix[:, ti]] # Get the validated measurements for this track + yhat = track['filt'].sensor_model['h'](track['x'][-1]) # Calculate the predicted measurement for this track + H = track['filt'].sensor_model['dhdx'](track['x'][-1]) + py = stats.multivariate_normal.pdf(val_meas.squeeze().T, mean=yhat.flatten(), cov=H@track['P'][-1]@H.T+track['filt'].sensor_model['R']) + association_matrix[validation_matrix[:, ti], ti] = np.log(track['filt'].sensor_model['PD']*py/(1-track['filt'].sensor_model['PD'])) # PG assumed = 1 + return association_matrix, validation_matrix diff --git a/src/utility.py b/src/utility.py new file mode 100644 index 0000000000000000000000000000000000000000..4c29b123be0189cb880e5d12478ffc9bb61309a5 --- /dev/null +++ b/src/utility.py @@ -0,0 +1,21 @@ +import numpy as np + + +def match_tracks_to_ground_truth(tracks, ground_truth): + matches = {} + # Match tracks to ground truth + for track in tracks: + x = np.vstack(track['x']).T + t = np.hstack(track['t']).flatten() + ormse = 1e10 + for key, T in ground_truth.items(): + if T.shape[1] > x.shape[1]: + N = x.shape[1] + else: + N = T.shape[1] + # Only compare times present in both ground truth and estimate + rmse = np.sum((T[:, t[t<N]]-x[:2, t[t<N]])**2)/N + if rmse < ormse: # The ground truth with the lowest to track RMSE is assumed to be correct + matches[track['identity']] = key + ormse = rmse + return matches