diff --git a/src/associators.py b/src/associators.py index 9ef73a8e7da1684a975792cb8e2f735983197cb5..029849330179a7a7e58f454688b9013fc17835b9 100644 --- a/src/associators.py +++ b/src/associators.py @@ -1,3 +1,4 @@ +"""Contains associatiors (Nearest Neighbour etc.)""" import numpy as np class NNAssociator: diff --git a/src/filters.py b/src/filters.py index 622d3b7abae6ad625e4374a1f60e099a11b55147..a1d0e156b83a77229cc03db41b6c1a6df43d0e0f 100644 --- a/src/filters.py +++ b/src/filters.py @@ -1,3 +1,5 @@ +"""Contains filter implementations. +""" import numpy as np import scipy.stats as stats import jax @@ -66,9 +68,9 @@ class EKF: H = self.sensor_model['dhdx'](x.flatten()) Sk = H@P@H.T+self.sensor_model['R'] Kk = P@H.T@np.linalg.inv(Sk) - x +=(Kk@res) - P -= Kk@Sk@Kk.T - return x, P + xf = x+Kk@res + Pf = P-Kk@Sk@Kk.T + return xf, Pf class IMM: def __init__(self, filters, sensor_model, transition_prob): diff --git a/src/gaters.py b/src/gaters.py index e84211d0e556b0f8a5c0333db2289024c2e3a4ac..01c8e4baf8b7afa06874663dda787af9cdd9aa31 100644 --- a/src/gaters.py +++ b/src/gaters.py @@ -1,3 +1,4 @@ +"""Contains gaters that validates measurements to state estimates.""" import numpy as np class MahalanobisGater: diff --git a/src/logic.py b/src/logic.py index d7d12443cfe3aa8d75fdaf8a38ebb5d10935243f..8c6102d73b3653a55f6331b0054dc16955eabf41 100644 --- a/src/logic.py +++ b/src/logic.py @@ -1,3 +1,4 @@ +"""Contains track logic.""" import scipy.stats as stats import numpy as np diff --git a/src/models.py b/src/models.py index bedb2647c5b4386c591d437250116374ca7b85a8..dc892fc38ea2a300594ea86bf2cc263483f42804 100644 --- a/src/models.py +++ b/src/models.py @@ -1,3 +1,4 @@ +"""Contains both sensor and motion models.""" import numpy as np import jax.numpy as jnp import jax diff --git a/src/plotters.py b/src/plotters.py index bf93c24b36032c94acece440f9567bac724f50c8..a0f08463474d9d298f90f3c404612ffa0a96f750 100644 --- a/src/plotters.py +++ b/src/plotters.py @@ -1,3 +1,4 @@ +"""Plotting functionality for the exercise sessions.""" import numpy as np import matplotlib.pyplot as plt @@ -63,26 +64,27 @@ def plot_result_ex2_2(result, trajs): for key, T in trajs.items(): ax[1].plot(T[0, :], T[1, :], color='k', lw=2) + confirmed_id = [track['identity'] for track in result['confirmed_tracks']] for track in result['tracks']: x = np.vstack(track['x']) t = np.hstack(track['t']).flatten() assoc = np.hstack(track['associations']).flatten() - if track in result['confirmed_tracks']: + if track['identity'] in confirmed_id: ls = '-' - l = ax[0].plot(t, track['identity']*np.ones(t.shape), ls=ls, markersize=3)[0] - ax[0].plot(assoc, track['identity']*np.ones(assoc.shape), ls='', color=l.get_color(), marker='x', markersize=6) - ax[1].plot(x[:, 0], x[:, 1], ls=ls, color=l.get_color(), lw=3) + l = ax[0].plot(t, [str(track['identity'])]*t.size, ls=ls, markersize=3)[0] + ax[0].plot(assoc, [str(track['identity'])]*assoc.size, ls='', color=l.get_color(), marker='x', markersize=6) + ax[1].plot(x[:, 0], x[:, 1], ls=ls, color=l.get_color(), lw=3, label='Track {}'.format(track['identity'])) else: ls = '--' ax[1].plot(x[:, 0], x[:, 1], ls=ls, lw=2) - ax[0].set_ylabel('Track identity') ax[0].set_title('Confirmed tracks over time') ax[0].set_xlabel('Time index, k') - ax[1].plot(yx, yy, '.', color='k') + ax[1].plot(yx, yy, '.', color='k', label='Measurements') ax[1].set_xlabel(r'$p_x$') ax[1].set_ylabel(r'$p_y$') ax[1].set_title('Measurements and measurement predictions + tracks') + ax[1].legend(loc=2) # Plot the RMSE for the matched trajectories for track_id, key in result['matches'].items(): @@ -122,23 +124,22 @@ def plot_result_ex2_24(result): x = np.vstack(track['x']) t = np.hstack(track['t']).flatten() assoc = np.hstack(track['associations']).flatten() - if track['identity'] in confirmed_id: ls = '-' - l = ax[0].plot(t, track['identity']*np.ones(t.shape), ls=ls, markersize=3)[0] - ax[0].plot(assoc, track['identity']*np.ones(assoc.shape), ls='', color=l.get_color(), marker='x', markersize=6) - ax[1].plot(x[:, 0], x[:, 1], ls=ls, color=l.get_color(), lw=3) + l = ax[0].plot(t, [str(track['identity'])]*t.size, ls=ls, markersize=3)[0] + ax[0].plot(assoc, [str(track['identity'])]*assoc.size, ls='', color=l.get_color(), marker='x', markersize=6) + ax[1].plot(x[:, 0], x[:, 1], ls=ls, color=l.get_color(), lw=3, label='Track {}'.format(track['identity'])) else: ls = '--' ax[1].plot(x[:, 0], x[:, 1], ls=ls, lw=2) - ax[0].set_ylabel('Track identity') ax[0].set_title('Confirmed tracks over time') ax[0].set_xlabel('Time index, k') - ax[1].plot(yx, yy, '.', color='k') + ax[1].plot(yx, yy, '.', color='k', label='Measurements') ax[1].set_xlabel(r'$p_x$') ax[1].set_ylabel(r'$p_y$') ax[1].set_title('Measurements and measurement predictions + tracks') + ax[1].legend(loc=2) ax[1].set_xlim([-2000, 2000]) ax[1].set_ylim([-21000, -17000]) return fig diff --git a/src/sim.py b/src/sim.py index be98ad38ffd4e3de61985b53e93f531ae14bcfd3..89735570255a79b63482b05e81e37316a2b0651e 100644 --- a/src/sim.py +++ b/src/sim.py @@ -1,3 +1,4 @@ +"""Simulation functionality (data generation)""" import numpy as np def generate_data(trajectories, sensor_model, clutter_model, rng=None): diff --git a/src/trajectories.py b/src/trajectories.py index f9c13739a5a3a0ed549907debb37317a62045523..f9b9781d14d726ebb9be717a63bda6c66eb8c0a5 100644 --- a/src/trajectories.py +++ b/src/trajectories.py @@ -1,3 +1,4 @@ +"""Contains exercise trajectories (deterministic)""" import numpy as np def get_ex1_trajectories(): diff --git a/src/utility.py b/src/utility.py index dae556d95075e0380d7315b687aef75c79e1588c..1d969c0f93884fc8dacfb2f4a5b8cd090a8d458e 100644 --- a/src/utility.py +++ b/src/utility.py @@ -1,3 +1,4 @@ +"""Contains utility functionality.""" import copy import numpy as np import murty as murty_ @@ -44,6 +45,29 @@ def match_tracks_to_ground_truth(tracks, ground_truth): return matches def recreate_trajectories(hypothesis, marginalize=True): + """Recreates target trajectories from an MHT output (see src.trackers). + + It takes a dict of hypothesis where each key corresponds to a certain time + step. Each such value then is a list of plausible hypothesis for that + particular time instance. + + Parameters + ---------- + hypothesis : dict + A dict with keys corresponding to specific time steps. Values are lists + of varying length of plausible hypothesis in each time step. + marginalize : bool + If True, it marginalizes over the hypothesis in each time step. Otherwise + returns the MAP estimate. + + Returns + ------- + list + A list of all tracks, tentative, confirmed and deleted. + list + A list of only confirmed tracks, even if they have died. + + """ confirmed_tracks = dict() tracks = dict() for t, hyp_t in hypothesis.items(): @@ -109,10 +133,14 @@ def recreate_trajectories(hypothesis, marginalize=True): if tracks[track_identity]['stage'] == 'confirmed': if track_identity not in confirmed_tracks.keys(): confirmed_tracks[track_identity] = tracks[track_identity] - return tracks.values(), confirmed_tracks.values() + return list(tracks.values()), list(confirmed_tracks.values()) def murty(C): - """Algorithm due to Murty.""" + """Algorithm due to Murty. + + This particular implementation is taken from Jonatan Olofsson, see + https://github.com/jonatanolofsson/mht. + """ mgen = murty_.Murty(C) while True: ok, cost, sol = mgen.draw() @@ -120,13 +148,38 @@ def murty(C): return None yield cost, sol -# Save result def save_result(filename, result): + """Save a tracking result. + + Parameters + ---------- + filename : str + A filename to save the result to + result : dict + A dictionary to save + + """ for track in result['tracks']: track.pop('filt', None) # Pop the filter object so it is only data that is saved. np.save(filename, result) def load_result(filename): + """Load a tracking result. + + NOTE: THIS IS NOT SAFE LOADING. PICKLED OBJECTS MAY EXECUTE MALICIOUS CODE. + MAKE SURE TO VERIFY THAT YOU ARE LOADING SPECIFICALLY WHAT YOU WANT. + + Parameters + ---------- + filename : str + A filename to load + + Returns + ---------- + - + Returns whatever data is stored in the file. + + """ if '.npy' not in filename: filename += '.npy' result = np.load(filename, allow_pickle=True)