diff --git a/src/models.py b/src/models.py new file mode 100644 index 0000000000000000000000000000000000000000..228c88864ed980a9a70cefac827e17692cf5ace1 --- /dev/null +++ b/src/models.py @@ -0,0 +1,27 @@ +import numpy as np +import jax.numpy as jnp +import jax + +# Setup sensor and clutter model +def radar_model(R, PD): # Assumes positional coordinates first + def h(x): + if len(x.shape) < 2: + xt = x.reshape(-1, 1) + else: + xt = x + target_range = jnp.linalg.norm(xt[:2, :], axis=0) # Use JAX numpy to be able to auto-differentiate + target_bearing = jnp.arctan2(xt[1, :], xt[0, :]) + return jnp.vstack([target_range, target_bearing]).squeeze() + sensor_model = dict(h=h, R=R, PD=PD, dhdx=jax.jacfwd(h)) + return sensor_model + +def cv_model(Q, D, T): + # D - dimensions + # T - sampling time + # CV model + F = np.identity(2*D) + F[:D, D:] = np.identity(D)*T + G = np.vstack([np.identity(D)*T, T**2/2*np.identity(D)]) + f = lambda x: F@x + motion_model = dict(f=f, Q=G@Q@G.T, dfdx=jax.jacfwd(f)) + return motion_model