From a8370789a50c84d66d1d13d1eeb9fed476e1b1e2 Mon Sep 17 00:00:00 2001
From: Anton Kullberg <anton.kullberg@liu.se>
Date: Mon, 1 Nov 2021 16:57:21 +0100
Subject: [PATCH] py: updated gater interface

---
 src/gaters.py   | 11 ++++++++---
 src/trackers.py |  4 ++--
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/gaters.py b/src/gaters.py
index 4e651f6..6727280 100644
--- a/src/gaters.py
+++ b/src/gaters.py
@@ -5,12 +5,17 @@ class MahalanobisGater:
         self.sensor_model = sensor_model
         self.gamma = gamma
 
-    def gate(self, x, P, eps):
+    def gate(self, x, P, meas):
+        if meas.ndim < 2:
+            y = meas.expand_dims(meas, 0)
+        else:
+            y = meas
+        eps = y-self.sensor_model['h'](x.flatten())[:, None]
         # Derivative
         H = self.sensor_model['dhdx'](x.flatten())
         # Mahalanobis Gating
-        T = np.zeros((eps.shape[1],))
-        for j in range(eps.shape[1]):
+        T = np.zeros((y.shape[1],))
+        for j in range(y.shape[1]):
             Sk = H@P@H.T+self.sensor_model['R']
             T[j] = eps[:, j].T@np.linalg.inv(Sk)@eps[:, j]
         accepted_meas = T < self.gamma
diff --git a/src/trackers.py b/src/trackers.py
index b072166..d07b574 100644
--- a/src/trackers.py
+++ b/src/trackers.py
@@ -13,7 +13,7 @@ class BasicTracker():
             yhat = self.filt.sensor_model['h'](x[:2, k])
             eps = meas_k-yhat[:, None]
             # Gating step
-            accepted_meas = self.gater.gate(x[:, k], P[:, :, k], eps)
+            accepted_meas = self.gater.gate(x[:, k], P[:, :, k], meas_k)
             # If any measurements are accepted, select the nearest one
             if accepted_meas.any():
                 # Association step
@@ -43,7 +43,7 @@ class IMMTracker():
 
             eps = meas_k-yhat[:, None]
             # Gating step
-            accepted_meas = self.gater.gate(xhat, Phat, eps)
+            accepted_meas = self.gater.gate(xhat, Phat, meas_k)
             # If any measurements are accepted, select the nearest one
             if accepted_meas.any():
                 # Association step
-- 
GitLab