From ef56ffcd776d43fbaf1ad0a0ad3de7967730365e Mon Sep 17 00:00:00 2001
From: Anton Kullberg <anton.kullberg@liu.se>
Date: Tue, 9 Nov 2021 14:44:18 +0100
Subject: [PATCH] py: added documentation to the code

---
 src/associators.py  |  14 ++++++
 src/filters.py      | 113 +++++++++++++++++++++++++++++++++++++++++++-
 src/gaters.py       |  31 ++++++++++++
 src/logic.py        |  23 ++++++++-
 src/models.py       |  44 ++++++++++++++++-
 src/trajectories.py |  16 +++++++
 src/utility.py      |  22 +++++++++
 7 files changed, 258 insertions(+), 5 deletions(-)

diff --git a/src/associators.py b/src/associators.py
index 08890bf..9ef73a8 100644
--- a/src/associators.py
+++ b/src/associators.py
@@ -1,7 +1,21 @@
 import numpy as np
 
 class NNAssociator:
+    """A nearest neighbour associator."""
     def associate(self, eps):
+        """Find the minimal residual.
+
+        Parameters
+        ----------
+        eps : numpy.ndarray
+            Residuals.
+
+        Returns
+        -------
+        int
+            The index of the minimal residual.
+
+        """
         r = np.linalg.norm(eps, axis=0)
         yind = np.argmin(r)
         return yind
diff --git a/src/filters.py b/src/filters.py
index 3dce5a9..622d3b7 100644
--- a/src/filters.py
+++ b/src/filters.py
@@ -1,22 +1,68 @@
-import pdb
 import numpy as np
 import scipy.stats as stats
 import jax
 
 class EKF:
     def __init__(self, motion_model, sensor_model):
+        """An implementation of an Extended Kalman Filter.
+
+        Automatically sets up the Jacobians of the motion and sensor models
+        through autodiff with the aid of Jax.
+
+        Parameters
+        ----------
+        motion_model : dict
+            The motion model to use for the filtering.
+        sensor_model : dict
+            The sensor model to use for the filtering.
+
+        """
         self.motion_model = motion_model
         self.motion_model['dfdx'] = jax.jacfwd(motion_model['f'])
         self.sensor_model = sensor_model
         self.sensor_model['dhdx'] = jax.jacfwd(sensor_model['h'])
 
     def propagate(self, x, P):
+        """Time update in the EKF.
+
+        Parameters
+        ----------
+        x : numpy.ndarray
+            The mean of the state estimate.
+        P : numpy.ndarray
+            The state error covariance.
+
+        Returns
+        -------
+        numpy.ndarray, numpy.ndarray
+            The updated mean and covariance.
+
+        """
         xp = self.motion_model['f'](x)
         F = self.motion_model['dfdx'](x.flatten())
         Pp = F@P@F.T+self.motion_model['Q']
         return xp, Pp
 
     def update(self, x, P, res):
+        """Measurement update in the EKF.
+
+        Modifies the estimates in-place.
+
+        Parameters
+        ----------
+        x : numpy.ndarray
+            The mean of the state estimate.
+        P : numpy.ndarray
+            The state error covariance.
+        res : numpy.ndarray
+            The residual to use for the update.
+
+        Returns
+        -------
+        numpy.ndarray, numpy.ndarray
+            The updated mean and covariance.
+
+        """
         H = self.sensor_model['dhdx'](x.flatten())
         Sk = H@P@H.T+self.sensor_model['R']
         Kk = P@H.T@np.linalg.inv(Sk)
@@ -26,6 +72,18 @@ class EKF:
 
 class IMM:
     def __init__(self, filters, sensor_model, transition_prob):
+        """An implementation of an Interacting Multiple Model filter.
+
+        Parameters
+        ----------
+        filters : list
+            The filters to use in the IMM (EKF, KF, etc.)
+        sensor_model : dict
+            The sensor model to use for the filtering.
+        transition_prob : numpy.ndarray
+            The transition probability between modes.
+
+        """
         self.filters = filters
         self.sensor_model = sensor_model
         self.sensor_model['dhdx'] = jax.jacfwd(sensor_model['h'])
@@ -34,7 +92,23 @@ class IMM:
         self.mode_probabilities = [np.ones(self.nmodes)/self.nmodes]
 
     def propagate(self, x, P):
-        mu = self.transition_prob*self.mode_probabilities[-1][:, None] # Calc. mixing probabilities
+        """Time update in the IMM.
+
+        Parameters
+        ----------
+        x : numpy.ndarray
+            The mean of the state estimates for the different modes.
+        P : numpy.ndarray
+            The state error covariance of the different modes.
+
+        Returns
+        -------
+        numpy.ndarray, numpy.ndarray
+            The updated mean and covariance.
+
+        """
+        # Calc. mixing probabilities
+        mu = self.transition_prob*self.mode_probabilities[-1][:, None]
         mu = mu/np.sum(mu, axis=0) # Normalize
         # Mixed state
         xp = np.hstack([x@mu[:, [k]] for k in range(self.nmodes)])
@@ -50,6 +124,25 @@ class IMM:
         return xp, Pp
 
     def update(self, x, P, meas):
+        """Measurement update in the IMM.
+
+        Modifies the estimates in-place.
+
+        Parameters
+        ----------
+        x : numpy.ndarray
+            The mean of the state estimate.
+        P : numpy.ndarray
+            The state error covariance.
+        res : numpy.ndarray
+            The residual to use for the update.
+
+        Returns
+        -------
+        numpy.ndarray, numpy.ndarray
+            The updated mean and covariance.
+
+        """
         # Mode probability update
         omega = []
         for i, xi in enumerate(x.T):
@@ -64,6 +157,22 @@ class IMM:
         return x, P
 
     def mix(self, x, P):
+        """Computes the combined state estimate based on the current mode
+        probabilities.
+
+        Parameters
+        ----------
+        x : numpy.ndarray
+            The mean of the state estimate for the modes.
+        P : numpy.ndarray
+            The state error covariances for the modes.
+
+        Returns
+        -------
+        numpy.ndarray, numpy.ndarray
+            The updated mean and covariance.
+
+        """
         xhat = x@self.mode_probabilities[-1]
         tP = np.zeros(P[:,:,0].shape)
         Phat = np.zeros(P[:,:,0].shape)
diff --git a/src/gaters.py b/src/gaters.py
index 6727280..96c7284 100644
--- a/src/gaters.py
+++ b/src/gaters.py
@@ -2,10 +2,41 @@ import numpy as np
 
 class MahalanobisGater:
     def __init__(self, sensor_model, gamma):
+        """A Mahalanobis gater.
+
+        Parameters
+        ----------
+        sensor_model : dict
+            A dict containing the sensor model. Needs the following fields:
+            h : callable
+                The actual sensor model in functional form.
+            dhdx : callable
+                The Jacobian of the sensor model.
+        gamma : float
+            A constant related to the gating probability.
+
+        """
         self.sensor_model = sensor_model
         self.gamma = gamma
 
     def gate(self, x, P, meas):
+        """ Gates the measurement(s) with the state estimate.
+
+        Parameters
+        ----------
+        x : numpy.ndarray
+            The mean of the state estimate.
+        P : numpy.ndarray
+            The state error covariance.
+        meas : numpy.ndarray
+            The measurement(s) to gate.
+
+        Returns
+        -------
+        numpy.ndarray
+            Logical array describing if a measurement is accepted or not.
+
+        """
         if meas.ndim < 2:
             y = meas.expand_dims(meas, 0)
         else:
diff --git a/src/logic.py b/src/logic.py
index 0d6a76a..d7d1244 100644
--- a/src/logic.py
+++ b/src/logic.py
@@ -2,7 +2,17 @@ import scipy.stats as stats
 import numpy as np
 
 def nm_logic(y, filt, state, params):
-    """
+    """ Implements an N/M logic for tracks.
+
+    Modifies everything in-place! Also returns the state (track).
+
+    Parameters
+    ----------
+    y : numpy.ndarray
+        The measurement to use for the logic. May be empty.
+    filt : filter
+        See src.filters. To ensure the logics have the same callsign, this is
+        included.
     state : dict
         Fields:
             stage : tentative/confirmed/deleted
@@ -46,7 +56,16 @@ def nm_logic(y, filt, state, params):
     return state
 
 def score_logic(y, filt, state, params):
-    """
+    """ Implements a score logic for tracks.
+
+    Modifies everything in-place! Also returns the state (track).
+
+    Parameters
+    ----------
+    y : numpy.ndarray
+        The measurement to use for the logic. May be empty.
+    filt : filter
+        See src.filters. The filter used for the track.
     state : dict
         Fields:
             stage : tentative/confirmed/deleted
diff --git a/src/models.py b/src/models.py
index 228c888..bedb264 100644
--- a/src/models.py
+++ b/src/models.py
@@ -3,7 +3,32 @@ import jax.numpy as jnp
 import jax
 
 # Setup sensor and clutter model
-def radar_model(R, PD): # Assumes positional coordinates first
+def radar_model(R, PD):
+    """A distance and bearing radar implementation.
+
+    Assumes positional coordinate first.
+
+    Parameters
+    ----------
+    R : numpy.ndarray
+        The measurement noise covariance.
+    PD : float
+        The probability of detection.
+
+    Returns
+    -------
+    dict
+        A dict containing the sensor model with the following entries:
+        h : callable
+            The actual sensor model in functional form.
+        dhdx : callable
+            The Jacobian of the sensor model w.r.t. the state.
+        R : numpy.ndarray
+            The measurement noise covariance.
+        PD : float
+            Probability of detection
+
+    """
     def h(x):
         if len(x.shape) < 2:
             xt = x.reshape(-1, 1)
@@ -16,6 +41,23 @@ def radar_model(R, PD): # Assumes positional coordinates first
     return sensor_model
 
 def cv_model(Q, D, T):
+    """A constant velocity model.
+
+    Parameters
+    ----------
+    Q : numpy.ndarray
+        The process noise covariance.
+    D : int
+        The dimension of the CV model (1, 2, 3)
+    T : float
+        Sampling interval
+
+    Returns
+    -------
+    dict
+        A dict with the motion model.
+
+    """
     # D - dimensions
     # T - sampling time
     # CV model
diff --git a/src/trajectories.py b/src/trajectories.py
index 620a741..f9c1373 100644
--- a/src/trajectories.py
+++ b/src/trajectories.py
@@ -1,6 +1,14 @@
 import numpy as np
 
 def get_ex1_trajectories():
+    """Generates the trajectories for exercise 1.
+
+    Returns
+    -------
+    dict
+        A dict with the trajectories.
+
+    """
     # Generate trajectories (Assuming a sampling time of Ts=1)
     # Trajectory 1
     vx = 100/3.6
@@ -24,6 +32,14 @@ def get_ex1_trajectories():
     return dict(T1=T1, T2=T2, T3=T3, T4=T4)
 
 def get_ex2_trajectories():
+    """Generates the trajectories for exercise 2.
+
+    Returns
+    -------
+    dict
+        A dict with the trajectories.
+
+    """
     T = get_ex1_trajectories()
     v = 100/3.6
     p0 = np.array([200, 1200])
diff --git a/src/utility.py b/src/utility.py
index 4c29b12..28f3b50 100644
--- a/src/utility.py
+++ b/src/utility.py
@@ -2,6 +2,28 @@ import numpy as np
 
 
 def match_tracks_to_ground_truth(tracks, ground_truth):
+    """Matches tracks to ground truth naively.
+
+    Compares each track to each ground truth by computing the RMSE. Assumes
+    that the track is found immediately at trajectory birth (restrictive
+    assumption). The lowest RMSE is assumed to be the correct ground truth
+    trajectory. Can match several trajectories to the same ground truth.
+
+    Parameters
+    ----------
+    tracks : list
+        A list of the tracks to match.
+    ground_truth : list
+        A list of ground truth trajectories.
+
+    Returns
+    -------
+    dict
+        A dictionary with matches.
+            key : track identity
+            value : ground truth trajectory identity
+
+    """
     matches = {}
     # Match tracks to ground truth
     for track in tracks:
-- 
GitLab