diff --git a/src/models.py b/src/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..228c88864ed980a9a70cefac827e17692cf5ace1
--- /dev/null
+++ b/src/models.py
@@ -0,0 +1,27 @@
+import numpy as np
+import jax.numpy as jnp
+import jax
+
+# Setup sensor and clutter model
+def radar_model(R, PD): # Assumes positional coordinates first
+    def h(x):
+        if len(x.shape) < 2:
+            xt = x.reshape(-1, 1)
+        else:
+            xt = x
+        target_range = jnp.linalg.norm(xt[:2, :], axis=0) # Use JAX numpy to be able to auto-differentiate
+        target_bearing = jnp.arctan2(xt[1, :], xt[0, :])
+        return jnp.vstack([target_range, target_bearing]).squeeze()
+    sensor_model = dict(h=h, R=R, PD=PD, dhdx=jax.jacfwd(h))
+    return sensor_model
+
+def cv_model(Q, D, T):
+    # D - dimensions
+    # T - sampling time
+    # CV model
+    F = np.identity(2*D)
+    F[:D, D:] = np.identity(D)*T
+    G = np.vstack([np.identity(D)*T, T**2/2*np.identity(D)])
+    f = lambda x: F@x
+    motion_model = dict(f=f, Q=G@Q@G.T, dfdx=jax.jacfwd(f))
+    return motion_model