From 820eb2d8d8867e2b1437590125fe7c3f35b32dc2 Mon Sep 17 00:00:00 2001
From: Anton Kullberg <anton.kullberg@liu.se>
Date: Mon, 1 Nov 2021 16:58:13 +0100
Subject: [PATCH] py: added module with models

---
 src/models.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)
 create mode 100644 src/models.py

diff --git a/src/models.py b/src/models.py
new file mode 100644
index 0000000..228c888
--- /dev/null
+++ b/src/models.py
@@ -0,0 +1,27 @@
+import numpy as np
+import jax.numpy as jnp
+import jax
+
+# Setup sensor and clutter model
+def radar_model(R, PD): # Assumes positional coordinates first
+    def h(x):
+        if len(x.shape) < 2:
+            xt = x.reshape(-1, 1)
+        else:
+            xt = x
+        target_range = jnp.linalg.norm(xt[:2, :], axis=0) # Use JAX numpy to be able to auto-differentiate
+        target_bearing = jnp.arctan2(xt[1, :], xt[0, :])
+        return jnp.vstack([target_range, target_bearing]).squeeze()
+    sensor_model = dict(h=h, R=R, PD=PD, dhdx=jax.jacfwd(h))
+    return sensor_model
+
+def cv_model(Q, D, T):
+    # D - dimensions
+    # T - sampling time
+    # CV model
+    F = np.identity(2*D)
+    F[:D, D:] = np.identity(D)*T
+    G = np.vstack([np.identity(D)*T, T**2/2*np.identity(D)])
+    f = lambda x: F@x
+    motion_model = dict(f=f, Q=G@Q@G.T, dfdx=jax.jacfwd(f))
+    return motion_model
-- 
GitLab