From a8f3128c20171f8a41f13909cce9cf9588fae032 Mon Sep 17 00:00:00 2001
From: Anton Kullberg <anton.kullberg@liu.se>
Date: Tue, 9 Nov 2021 14:43:18 +0100
Subject: [PATCH] py: implemented GNN and JPDA

---
 src/trackers.py | 414 +++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 376 insertions(+), 38 deletions(-)

diff --git a/src/trackers.py b/src/trackers.py
index fce49b5..84bfc61 100644
--- a/src/trackers.py
+++ b/src/trackers.py
@@ -1,7 +1,9 @@
+import pdb
 import numpy as np
 import scipy.stats as stats
 import scipy
 import tqdm
+from pyehm.core import EHM2
 
 class BasicTracker():
     def __init__(self, filt, clutter_model, associator, gater):
@@ -63,6 +65,25 @@ class IMMTracker():
 
 class GNN():
     def __init__(self, logic, logic_params, init_track, filt, gater, clutter_model):
+        """An implementation of a Global Nearest Neighbour tracker.
+
+        Parameters
+        ----------
+        logic : logic
+            See src.logic. Some sort of track logic.
+        logic_params : dict
+            Contains parameters to the track logic.
+        init_track : callable
+            A function that initiates a track. Should take a measurement, the
+            time, an id and the filter to use for the track as input.
+        filt : filter
+            See src.filter. Some sort of filter to use for the tracks.
+        gater : gater
+            See src.gater. A gating function.
+        clutter_model : dict
+            A dict containing the clutter model.
+
+        """
         self.logic = logic
         self.logic_params = logic_params
         self.init_track = init_track
@@ -71,6 +92,19 @@ class GNN():
         self.clutter_model = clutter_model
 
     def _update_track(self, meas, track):
+        """Handles the update of a certain track with the given measurement(s).
+
+        Modifies the track in-place!
+
+        Parameters
+        ----------
+        meas : numpy.ndarray
+            Contains measurement(s) to update a specific track with. ny by N,
+            where N is the number of measurements to update the track with.
+        track : dict
+            A dict containing everything relevant to the track.
+
+        """
         if meas.size == 0:
             track = self.logic(np.array([]), track['filt'], track, self.logic_params) # If no meas associated, still update logic of track
             return
@@ -83,58 +117,272 @@ class GNN():
         # Update
         track['x'][-1], track['P'][-1] = track['filt'].update(track['x'][-1], track['P'][-1], eps)
 
+    def associate_update(self, meas_k, k, tracks, unused_meas):
+        """Associates measurements to tracks and updates the tracks with the
+        measurements.
+
+        Does *not* return anything, but modifies the objects in-place!
+        Uses Efficient Hypothesis Management (see
+        https://github.com/sglvladi/pyehm) for computing the hypothesis
+        probabilities.
+
+        Parameters
+        ----------
+        meas_k : numpy.ndarray
+            Measurements to attempt to associate
+        k : int
+            The current time step
+        tracks : list
+            A list of the tracks to associate the measurements with
+        unused_meas : numpy.ndarray
+            A logical array indicating what measurements are still
+            unused/non-associated.
+
+        """
+        association_matrix, validation_matrix = get_association_matrix(meas_k[:, unused_meas], tracks, self.logic_params, self.gater)
+        # Solve association problem
+        row_ind, col_ind = scipy.optimize.linear_sum_assignment(-association_matrix)
+        for row, col in zip(row_ind, col_ind):
+            if col >= len(tracks): # No target to associate the measurement to
+                continue
+            else:
+                # Update confirmed tracks
+                self._update_track(meas_k[:, unused_meas][:, row], tracks[col])
+                tracks[col]['associations'].append(k) # If we've associated something, add the time here (for plotting purposes)
+        for i in range(len(tracks)):
+            if i not in col_ind:
+                self._update_track(np.array([]), tracks[i])
+        # Remove any gated measurements from further consideration
+        tmp = unused_meas[unused_meas] # Extract the unused measurements
+        inds = np.where(validation_matrix.sum(axis=1))
+        tmp[inds] = 0
+        unused_meas[unused_meas] = tmp
+
     def evaluate(self, Y):
+        """ Evaluates the detections in Y.
+
+        Parameters
+        ----------
+        Y : list
+            List of detections at time k=0 to K where K is the length of Y.
+            Each entry of Y is ny by N_k where N_k is time-varying as the number
+            of detections vary.
+
+        Returns
+        -------
+        list, list
+            First list contains all initiated tracks, both tentative, deleted
+            and confirmed. The second list contains only the confirmed list,
+            even if they have died. Hence, the lists contain duplicates (but
+            point to the same object!).
+
+        """
+        rng = np.random.default_rng()
         tracks = [] # Store all tracks
         confirmed_tracks = [] # Store the confirmed tracks (for plotting purposes only)
         ids = 0
-        for k, meas_k in tqdm.tqdm(enumerate(Y), desc="Evaluating observations: "):
+        for k, meas_k in tqdm.tqdm(enumerate(Y), desc="GNN evaluating detections: "):
             ny = meas_k.shape[1]
             unused_meas = np.ones((ny), dtype=bool)
 
+            # Handle the confirmed and alive tracks
             live_tracks = [track for track in confirmed_tracks if track['stage']=='confirmed']
-
             if live_tracks:
-                association_matrix, _ = get_association_matrix(meas_k, live_tracks, self.logic_params, self.gater)
-                # Solve association problem
-                row_ind, col_ind = scipy.optimize.linear_sum_assignment(-association_matrix)
-
-                for row, col in zip(row_ind, col_ind):
-                    if col >= len(live_tracks): # No target to associate the measurement to
-                        continue
-                    else:
-                        unused_meas[row] = 0 # Remove this measurement from further consideration
-                        # Update confirmed tracks
-                        self._update_track(meas_k[:, row], live_tracks[col])
-                        live_tracks[col]['associations'].append(k) # If we've associated something, add the time here (for plotting purposes)
-                for i in range(len(live_tracks)):
-                    if i not in col_ind:
-                        self._update_track(np.array([]), live_tracks[i])
-
+                self.associate_update(meas_k, k, live_tracks, unused_meas)
 
+            # Handle the tentative tracks
             tentative_tracks = [track for track in tracks if track['stage'] == 'tentative']
-
             if tentative_tracks:
-                association_matrix, _ = get_association_matrix(meas_k[:, unused_meas], tentative_tracks, self.logic_params, self.gater)
-                # Solve association problem
-                row_ind, col_ind = scipy.optimize.linear_sum_assignment(-association_matrix)
-                meas = meas_k[:, unused_meas]
-                for row, col in zip(row_ind, col_ind):
-                    if col >= len(tentative_tracks): # No target to associate the measurement to
-                        continue
-                    else:
-                        unused_meas[(meas_k == meas[:,[row]]).all(axis=0)] = 0 # Remove this measurement from consideration
-                        # Update confirmed tracks
-                        self._update_track(meas[:, row], tentative_tracks[col])
-                        tentative_tracks[col]['associations'].append(k) # If we've associated something, add the time here (for plotting purposes)
-                        if tentative_tracks[col]['stage'] == 'confirmed':
-                            confirmed_tracks.append(tentative_tracks[col]) # If a track has been confirmed, add it to confirmed tracks
-                for i in range(len(tentative_tracks)):
-                    if i not in col_ind:
-                        self._update_track(np.array([]), tentative_tracks[i])
+                self.associate_update(meas_k, k, tentative_tracks, unused_meas)
+                for track in tentative_tracks:
+                    if track['stage'] == 'confirmed':
+                        confirmed_tracks.append(track) # If a track has been confirmed, add it to confirmed tracks
 
             # Use the unused measurements to initiate new tracks
-            for meas in meas_k[:, unused_meas].T:
-                tracks.append(self.init_track(meas, k, ids, self.filt))
+            while unused_meas.any():
+                ind = rng.choice(np.arange(unused_meas.size), p=unused_meas/unused_meas.sum())
+                track = self.init_track(meas_k[:, ind], k, ids, self.filt) # Initialize track
+                tracks.append(track)
+                unused_meas[ind] = 0 # Remove measurement from association hypothesis
+                self.associate_update(meas_k, k, [track], unused_meas)
+                ids += 1
+                if track['stage'] == 'confirmed':
+                    confirmed_tracks.append(track)
+
+            for track in tracks:
+                if track['stage'] != 'deleted':
+                    x, P = track['filt'].propagate(track['x'][-1], track['P'][-1])
+                    track['x'].append(x)
+                    track['P'].append(P)
+                    track['t'].append(k+1)
+        return tracks, confirmed_tracks
+
+class JPDA():
+    def __init__(self, logic, logic_params, init_track, filt, gater, clutter_model):
+        """An implementation of a Joint Probabilistic Data Association tracker.
+
+        Parameters
+        ----------
+        logic : logic
+            See src.logic. Some sort of track logic.
+        logic_params : dict
+            Contains parameters to the track logic.
+        init_track : callable
+            A function that initiates a track. Should take a measurement, the
+            time, an id and the filter to use for the track as input.
+        filt : filter
+            See src.filter. Some sort of filter to use for the tracks.
+        gater : gater
+            See src.gater. A gating function.
+        clutter_model : dict
+            A dict containing the clutter model.
+
+        """
+        self.logic = logic
+        self.logic_params = logic_params
+        self.init_track = init_track
+        self.filt = filt
+        self.gater = gater
+        self.clutter_model = clutter_model
+
+    def _update_track(self, meas, track, association_probability):
+        """Handles the update of a certain track with the given measurement(s).
+
+        Modifies the track in-place!
+
+        Parameters
+        ----------
+        meas : numpy.ndarray
+            Contains measurement(s) to update a specific track with. ny by N,
+            where N is the number of measurements to update the track with.
+        track : dict
+            A dict containing everything relevant to the track.
+        association_probability : numpy.ndarray
+            The association probability of each measurement to the track.
+
+        """
+        if meas.size == 0:
+            track = self.logic(np.array([]), track['filt'], track, self.logic_params) # If no meas associated, still update logic of track
+            return
+        # Calculate prediction error of each measurement
+        yhat = track['filt'].sensor_model['h'](track['x'][-1])
+        eps = meas-yhat[:, None]
+
+        # Update track
+        Lt = []
+        x = []
+        P = []
+        oLt = track['Lt']
+        # Handle false measurement separately
+        x.append(track['x'][-1])
+        P.append(track['P'][-1])
+        track = self.logic(np.array([]), track['filt'], track, self.logic_params)
+        Lt.append(track['Lt'])
+        track['Lt'] = oLt # Reset track score
+        for j in range(eps.shape[1]):
+            # Compute the track score given an association
+            track = self.logic(meas[:, j], track['filt'], track, self.logic_params)
+            Lt.append(track['Lt'])
+            track['Lt'] = oLt # Reset the current track score
+            xj, Pj = track['filt'].update(track['x'][-1], track['P'][-1], eps[:, j])
+            x.append(xj)
+            P.append(Pj)
+        # Update the track score by marginalizing the measurement associations
+        track['Lt'] = np.array(Lt)@association_probability
+        # Compute the state estimate
+        xhat = (association_probability@np.vstack(x)).T
+        # Compute the state error covariance
+        err = np.vstack(x).T-xhat[:, None]
+        Pk = np.stack([col[:,None]@col[None,:] for col in err.T])
+        Phat = np.tensordot(np.stack(P)+Pk, association_probability, (0, 0)) # Dot product over axis 0
+        track['x'][-1] = xhat
+        track['P'][-1] = Phat
+
+    def associate_update(self, meas_k, k, tracks, unused_meas):
+        """Associates measurements to tracks and updates the tracks with the
+        measurements.
+
+        Does *not* return anything, but modifies the objects in-place!
+        Uses Efficient Hypothesis Management (see
+        https://github.com/sglvladi/pyehm) for computing the hypothesis
+        probabilities.
+
+        Parameters
+        ----------
+        meas_k : numpy.ndarray
+            Measurements to attempt to associate
+        k : int
+            The current time step
+        tracks : list
+            A list of the tracks to associate the measurements with
+        unused_meas : numpy.ndarray
+            A logical array indicating what measurements are still
+            unused/non-associated.
+
+        """
+        likelihood_matrix, validation_matrix = get_likelihood_matrix(meas_k[:, unused_meas], tracks, self.logic_params, self.gater)
+        association_matrix = EHM2.run(validation_matrix, likelihood_matrix)
+
+        meas = meas_k[:, unused_meas]
+        for ti, track in enumerate(tracks):
+            if validation_matrix[ti, 1:].any(): # If any measurements are validated to this track, update it accordingly
+                self._update_track(meas[:, validation_matrix[ti, 1:].flatten()], track, association_matrix[ti, validation_matrix[ti, :]])
+                track['associations'].append(k) # If we've associated something, add the time here (for plotting purposes)
+            else:
+                self._update_track(np.array([]), track, None)
+        # Measurements that are validated to any track can be removed from further association
+        tmp = unused_meas[unused_meas]
+        used_inds = np.where(validation_matrix[:, 1:].sum(axis=0))
+        tmp[used_inds] = 0
+        unused_meas[unused_meas] = tmp
+
+    def evaluate(self, Y):
+        """ Evaluates the detections in Y.
+
+        Parameters
+        ----------
+        Y : list
+            List of detections at time k=0 to K where K is the length of Y.
+            Each entry of Y is ny by N_k where N_k is time-varying as the number
+            of detections vary.
+
+        Returns
+        -------
+        list, list
+            First list contains all initiated tracks, both tentative, deleted
+            and confirmed. The second list contains only the confirmed list,
+            even if they have died. Hence, the lists contain duplicates (but
+            point to the same object!).
+
+        """
+        rng = np.random.default_rng()
+        tracks = [] # Store all tracks
+        confirmed_tracks = [] # Store the confirmed tracks (for plotting purposes only)
+        ids = 0
+        for k, meas_k in tqdm.tqdm(enumerate(Y), desc="JPDA evaluating detections: "):
+            ny = meas_k.shape[1]
+            unused_meas = np.ones((ny), dtype=bool)
+
+            # Handle confirmed and alive tracks
+            live_tracks = [track for track in confirmed_tracks if track['stage']=='confirmed']
+            if live_tracks:
+                self.associate_update(meas_k, k, live_tracks, unused_meas)
+
+            # Handle tentative tracks
+            tentative_tracks = [track for track in tracks if track['stage']=='tentative']
+            if tentative_tracks:
+                self.associate_update(meas_k, k, tentative_tracks, unused_meas)
+                for track in tentative_tracks:
+                    if track['stage'] == 'confirmed':
+                        confirmed_tracks.append(track) # If a track has been confirmed, add it to confirmed tracks
+
+            # Initiate new tracks at random
+            while unused_meas.any():
+                ind = rng.choice(np.arange(unused_meas.size), p=unused_meas/unused_meas.sum())
+                track = self.init_track(meas_k[:, ind], k, ids, self.filt) # Initialize track
+                tracks.append(track)
+                unused_meas[ind] = 0 # Remove measurement from association hypothesis
+                self.associate_update(meas_k, k, [track], unused_meas)
                 ids += 1
 
             for track in tracks:
@@ -145,8 +393,72 @@ class GNN():
                     track['t'].append(k+1)
         return tracks, confirmed_tracks
 
+def get_likelihood_matrix(meas, tracks, logic_params, gater):
+    """ Computes the likelihood and validation matrix (specifically for the
+        JPDA implementation)
+
+    Parameters
+    ----------
+    meas : numpy.ndarray
+        Measurements to attempt to associate
+    tracks : list
+        A list of the tracks to associate the measurements with
+    logic_params : dict
+        Parameters of the track logic (for simplicity. The function needs the rate of false alarms and new targets, respectively.)
+    gater : gater
+        A gater that can gate the measurements with the tracks.
+
+    Returns
+    -------
+    numpy.ndarray, numpy.ndarray
+        Returns a likelihood matrix containing the unnormalized likelihood of
+        associating measurement i with track j. Also returns a validation matrix
+        indicating the possible associations of measurement i with track j.
+        Both the likelihood and validation matrices are of size Nc by ny+1,
+        where Nc is the number of tracks. The first column of the validation
+        and likelihood matrices corresponds to a false alarm.
+
+    """
+    ny = meas.shape[1]
+    Nc = len(tracks) # Number of tracks to associate
+    validation_matrix = np.zeros((Nc, ny+1), dtype=bool)
+    validation_matrix[:, 0] = 1
+    likelihood_matrix = np.zeros((Nc, ny+1))
+
+    for ti, track in enumerate(tracks): # Iterate over confirmed tracks
+        validation_matrix[ti, 1:] = gater.gate(track['x'][-1], track['P'][-1], meas)
+        # Entry for validated tracks
+        val_meas = meas[:, validation_matrix[ti, 1:]] # Get the validated measurements for this track
+        yhat = track['filt'].sensor_model['h'](track['x'][-1]) # Calculate the predicted measurement for this track
+        H = track['filt'].sensor_model['dhdx'](track['x'][-1])
+        py = stats.multivariate_normal.pdf(val_meas.squeeze().T, mean=yhat.flatten(), cov=H@track['P'][-1]@H.T+track['filt'].sensor_model['R'])
+        likelihood_matrix[ti, np.where(validation_matrix[ti, 1:])[0]+1] = track['filt'].sensor_model['PD']*py
+        likelihood_matrix[ti, 0] = 1-track['filt'].sensor_model['PD'] # PG assumed 1
+    return likelihood_matrix, validation_matrix
 
 def get_association_matrix(meas, tracks, logic_params, gater):
+    """ Computes the validation and association matrix (specifically for the
+        GNN implementation)
+
+    Parameters
+    ----------
+    meas : numpy.ndarray
+        Measurements to attempt to associate
+    tracks : list
+        A list of the tracks to associate the measurements with
+    logic_params : dict
+        Parameters of the track logic (for simplicity. The function needs the rate of false alarms and new targets, respectively.)
+    gater : gater
+        A gater that can gate the measurements with the tracks.
+
+    Returns
+    -------
+    numpy.ndarray, numpy.ndarray
+        Returns an association matrix of size ny by Nc+2*ny and a validation
+        matrix of size ny by Nc, where Nc is the number of tracks. The association
+        matrix also contains the false alarm and new track possibilities.
+
+    """
     ny = meas.shape[1]
     Nc = len(tracks) # Number of tracks to associate
     validation_matrix = np.zeros((ny, Nc), dtype=bool)
@@ -166,3 +478,29 @@ def get_association_matrix(meas, tracks, logic_params, gater):
         py = stats.multivariate_normal.pdf(val_meas.squeeze().T, mean=yhat.flatten(), cov=H@track['P'][-1]@H.T+track['filt'].sensor_model['R'])
         association_matrix[validation_matrix[:, ti], ti] = np.log(track['filt'].sensor_model['PD']*py/(1-track['filt'].sensor_model['PD'])) # PG assumed = 1
     return association_matrix, validation_matrix
+
+### Obsolete
+def compute_prob(association_matrix, validation_matrix, logic_params):
+    # Association matrix is assumed to consist of tracks and FA, no NT.
+    ny = association_matrix.shape[0]
+    ntracks = association_matrix.shape[1]-ny
+    P = np.zeros((ny, ntracks))
+
+    def rec_find_associations(association_matrix, assoc_done, logic_params):
+        inds = np.where(association_matrix[0, :] != -np.inf)[0] # These are the nodes necessary to look at
+        this_assoc = []
+        for k, i in enumerate(inds):
+            if i not in assoc_done:
+                if association_matrix.shape[0] != 1:
+                    assoc = rec_compute_prob(association_matrix[1:, :], [[i]], logic_params)
+                    this_assoc.extend(assoc)
+                else:
+                    this_assoc.append([i])
+        result = []
+        for assoc in assoc_done:
+            for th_assoc in this_assoc:
+                result.append(assoc + th_assoc)
+        return result
+
+    possible_associations = rec_find_associations(association_matrix, [[]], logic_params) # Recursively finds possible measurement hypothesis
+    return res
-- 
GitLab