diff --git a/exercise-session-1.ipynb b/exercise-session-1.ipynb index 780e1c864d666adfa3f0a2b1bf43712164db4151..36b938ab0c9d3dba9a1982b0c4b34ef7c3093fd5 100644 --- a/exercise-session-1.ipynb +++ b/exercise-session-1.ipynb @@ -31,7 +31,13 @@ "import scipy.stats as stats\n", "from src.sim import generate_data\n", "from src.trajectories import get_ex1_trajectories\n", - "ex1_trajectories = get_ex1_trajectories()" + "ex1_trajectories = get_ex1_trajectories()\n", + "import src.gaters as gaters\n", + "import src.filters as filters\n", + "import src.associators as associators\n", + "import src.trackers as trackers\n", + "import src.plotters as plotters\n", + "import src.models as models" ] }, { @@ -41,8 +47,7 @@ "metadata": {}, "outputs": [], "source": [ - "def evaluate_tracker(trajectory, tracker, MC):\n", - " T = trajectory\n", + "def evaluate_tracker(T, tracker, MC):\n", " # Initialize correctly\n", " xhat = np.zeros((4, T.shape[1], MC))\n", " xhat[:2, 0, :] = np.repeat(T[:, [0]], MC, axis=1)\n", @@ -52,7 +57,7 @@ " Phat[:, :, 0, :] = np.repeat(np.diag([10, 10, 1, 1])[:, :, None], MC, axis=2)\n", " Y = []\n", " for m in tqdm.tqdm(range(MC), desc=\"MC iteration: \"):\n", - " Y.append(generate_data(T, tracker.filt.sensor_model, tracker.clutter_model, rng))\n", + " Y.append(generate_data({'T': T}, tracker.filt.sensor_model, tracker.clutter_model, rng))\n", " xhat[:, :, m], Phat[:, :, :, m] = tracker.evaluate(Y[-1], xhat[:, :, m], Phat[:, :, :, m])\n", " # Evaluate RMSE\n", " rmse = np.sqrt(np.mean((T[:, :, None] - xhat[:2, :, :])**2, axis=1))\n", @@ -67,36 +72,16 @@ "metadata": {}, "outputs": [], "source": [ - "import src.gaters as gaters\n", - "import src.filters as filters\n", - "import src.associators as associators\n", - "import src.trackers as trackers\n", - "import src.plotters as plotters\n", - "\n", - "# Setup sensor and clutter model\n", - "def h(x): # Assumes positional coordinates first\n", - " if len(x.shape) < 2:\n", - " xt = x.reshape(-1, 1)\n", - " else:\n", - " xt = x\n", - " target_range = jnp.linalg.norm(xt[:2, :], axis=0) # Use JAX numpy to be able to auto-differentiate\n", - " target_bearing = jnp.arctan2(xt[1, :], xt[0, :])\n", - " return jnp.vstack([target_range, target_bearing]).squeeze()\n", - " \n", "R = np.diag([10, 0.001])**2\n", "PD = 0.9\n", "lam = 2\n", "volume = dict(xmin=0, xmax=2500, ymin=0, ymax=1000)\n", - "sensor_model = dict(h=h, R=R, PD=PD)\n", + "sensor_model = models.radar_model(R=R, PD=PD)\n", "clutter_model = dict(volume=volume, lam=lam)\n", "\n", "# CV model\n", - "F = np.identity(4)\n", - "F[:2, 2:] = np.identity(2)\n", - "G = np.vstack([np.identity(2), 1/2*np.identity(2)])\n", - "Q = lambda q: G@(q*np.identity(2))@G.T\n", - "f = lambda x: F@x\n", - "motion_model = dict(f=f, Q=Q(1))\n", + "Q = np.identity(2)\n", + "motion_model = models.cv_model(Q=Q, D=2, T=1)\n", "\n", "# Setup Gater\n", "gamma = 4.7\n", @@ -127,7 +112,7 @@ "metadata": {}, "outputs": [], "source": [ - "motion_model['Q'] = Q(1)\n", + "motion_model = models.cv_model(Q=np.identity(2), D=2, T=1)\n", "# Setup filter\n", "filt = filters.EKF(motion_model, sensor_model)\n", "# Setup tracker\n", @@ -168,7 +153,7 @@ "metadata": {}, "outputs": [], "source": [ - "motion_model['Q'] = Q(100)\n", + "motion_model = models.cv_model(Q=np.identity(2)*100, D=2, T=1)\n", "# Setup filter\n", "filt = filters.EKF(motion_model, sensor_model)\n", "# Setup tracker\n", @@ -211,23 +196,21 @@ "metadata": {}, "outputs": [], "source": [ - "motion_model_two = motion_model.copy()\n", - "motion_model_two['Q'] = Q(100)\n", + "motion_model = models.cv_model(Q=np.identity(2), D=2, T=1) # Use Q=I\n", + "motion_model_two = models.cv_model(Q=100*np.identity(2), D=2, T=1) # Use Q=100I\n", "filtone = filters.EKF(motion_model, sensor_model)\n", "filttwo = filters.EKF(motion_model_two, sensor_model)\n", - "p = 0.8\n", + "p = 0.9\n", "trans_prob = np.array([[p, 1-p], [1-p, p]])\n", "imm = filters.IMM([filtone, filttwo], sensor_model, trans_prob)\n", - "gater.gamma = 10.6#9.2" + "gater.gamma = 12.43 # Corresponds to alpha=0.998. Tuned for performance." ] }, { "cell_type": "code", "execution_count": null, "id": "c74df784", - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ "# Setup tracker\n", @@ -244,7 +227,7 @@ "mode_probabilities = []\n", "Y = []\n", "for m in tqdm.tqdm(range(MC), desc=\"MC iteration: \"):\n", - " Y.append(generate_data(T, tracker.filt.sensor_model, tracker.clutter_model, rng))\n", + " Y.append(generate_data({'T': T}, tracker.filt.sensor_model, tracker.clutter_model, rng))\n", " xhat[:, :, m], Phat[:, :, :, m] = tracker.evaluate(Y[-1], xhat[:, :, m], Phat[:, :, :, m])\n", " mode_probabilities.append(imm.mode_probabilities)\n", " imm.mode_probabilities = [imm.mode_probabilities[0]]\n", @@ -304,16 +287,8 @@ "outputs": [], "source": [ "T = np.round(np.diff(dat24['T1'][0])[0], 2)\n", - "F = np.identity(4)\n", - "F[:2, 2:] = np.identity(2)*T\n", - "G = np.vstack([np.identity(2)*T, T**2/2*np.identity(2)])\n", - "q1 = 1*np.identity(2)\n", - "q2 = 1e4*np.identity(2)\n", - "Q1 = G@q1@G.T\n", - "Q2 = G@q2@G.T\n", - "f = lambda x: F@x\n", - "mm = dict(f=f, Q=Q1)\n", - "mm2 = dict(f=f, Q=Q2)\n", + "mm = models.cv_model(Q=np.identity(2), D=2, T=T)\n", + "mm2 = models.cv_model(Q=1e4*np.identity(2), D=2, T=T)\n", "\n", "sensor_model['R'] = np.diag([135, 1.5*np.pi/180])**2\n", "clutter_model['volume'] = dict(xmin=20000, xmax=32000, ymin=-10000, ymax=35000)\n", @@ -374,6 +349,14 @@ "$$N1=2,~M1=2,~N2=2,~M2=3,~N3=3$$" ] }, + { + "cell_type": "markdown", + "id": "0b81e293-471d-4eb5-87b3-c24b108b6e65", + "metadata": {}, + "source": [ + "#### Convenience functionality" + ] + }, { "cell_type": "code", "execution_count": null, @@ -382,34 +365,30 @@ "outputs": [], "source": [ "def update_track(meas_k, unused_meas, filt, track, associator, gater, logic, logic_params):\n", - " # Calculate prediction error of each measurement\n", + " # Calculate prediction error of each unused measurement\n", " yhat = filt.sensor_model['h'](track['x'][-1])\n", + " eps = meas_k[:, unused_meas]-yhat\n", " \n", - " eps = meas_k[unused_meas]-yhat\n", - " \n", + " # No unused measurements -> update track with false measurement\n", " if eps.size == 0:\n", " track = logic(np.array([]), filt, track, logic_params)\n", " return\n", + " # Necessary for broadcasting later on\n", " if eps.ndim < 2:\n", " eps = np.expand_dims(eps, 0)\n", " \n", " # Gating step\n", - " accepted_meas = gater.gate(track['x'][-1], track['P'][-1], eps)\n", + " accepted_meas = gater.gate(track['x'][-1], track['P'][-1], meas_k[:, unused_meas])\n", " # If any measurements are accepted, select the nearest one\n", " if accepted_meas.any():\n", " # Association step\n", " yind = associator.associate(eps[:, accepted_meas])\n", " # Update track information\n", - " track = logic(meas_k[unused_meas][accepted_meas][yind], filt, track, logic_params)\n", - " # Unfortunately, Python can't update an array value with multi-level indexing\n", - " tmp = unused_meas[unused_meas]\n", - " tmp2 = tmp[accepted_meas]\n", - " tmp2[yind] = 0\n", - " tmp[accepted_meas] = tmp2\n", - " unused_meas[unused_meas] = tmp\n", + " track = logic(meas_k[:, unused_meas][:, accepted_meas][:, yind], filt, track, logic_params)\n", + " # Remove this measurement from further consideration\n", + " unused_meas[meas_k.flatten()==meas_k[:, unused_meas][:, accepted_meas][:, yind]] = 0\n", " # Update\n", " track['x'][-1], track['P'][-1] = filt.update(track['x'][-1], track['P'][-1], eps[:, accepted_meas][:, yind])\n", - " \n", " else:\n", " track = logic(np.array([]), filt, track, logic_params)\n", "\n", @@ -419,7 +398,7 @@ "\n", " ids = 0\n", " for k, meas_k in tqdm.tqdm(enumerate(Y), desc=\"Evaluating observations: \"):\n", - " unused_meas = np.ones((1,meas_k.size), dtype=bool)\n", + " unused_meas = np.ones((meas_k.size,), dtype=bool)\n", "\n", " for track in confirmed_tracks:\n", " if track['stage'] == 'deleted':\n", @@ -443,9 +422,10 @@ " confirmed_tracks.append(track) # If a track has been confirmed, add it to confirmed tracks\n", "\n", " # Used the unused measurements to initiate new tracks\n", - " for meas in meas_k[unused_meas]:\n", - " tracks.append(init_track(meas, k, ids))\n", - " ids += 1\n", + " if unused_meas.any():\n", + " for meas in meas_k[:, unused_meas].T:\n", + " tracks.append(init_track(meas, k, ids))\n", + " ids += 1\n", "\n", " for track in tracks:\n", " if track['stage'] != 'deleted':\n", @@ -457,12 +437,12 @@ "\n", "def plot_logic_result(tracks, confirmed_tracks):\n", " fig, ax = plt.subplots(2, 1, figsize=(16, 16))\n", - "\n", + " confirmed_ids = [track['identity'] for track in confirmed_tracks]\n", " for track in tracks:\n", " x = np.vstack(track['x'])\n", " t = np.hstack(track['t']).flatten()\n", " assoc = np.hstack(track['associations']).flatten()\n", - " if track in confirmed_tracks:\n", + " if track['identity'] in confirmed_ids:\n", " ls = '-'\n", " else:\n", " ls = '--'\n", @@ -489,7 +469,7 @@ "import src.logic as logic\n", "F = np.array([[1, 1], [0, 1]])\n", "Q1 = 0.001*np.array([[1/4, 1/2], [1/2, 1]])\n", - "R = np.array([0.01])[:, None]\n", + "R = np.array([[0.01]])\n", "PD = 0.9\n", "clutter_model = dict(volume=dict(xmin=-10, xmax=10, ymin=0, ymax=0), lam=0.05*20)\n", "H = np.array([[1, 0]])\n", @@ -497,16 +477,15 @@ "h = lambda x: H@x\n", "motion_model = dict(f=f, Q=Q1)\n", "sensor_model = dict(h=h, R=R, PD=PD)\n", - "gater = gaters.MahalanobisGater(sensor_model, 4.7)\n", + "gater = gaters.MahalanobisGater(sensor_model, 9.2)#4.7)\n", "Y = list(dat24['Y2'].flatten())\n", - "N = len(Y)\n", "P0 = np.diag([R[0, 0], 0.1])\n", "logic_params = dict(N1=2, M1=2, N2=2, M2=3, N3=3)\n", "filt = filters.EKF(motion_model, sensor_model)\n", "init_track = lambda y, k, identity: dict(stage='tentative', \n", " nmeas=1, \n", " nass=1, \n", - " x=[np.array([y,0])], \n", + " x=[np.append(y,0)], \n", " P=[P0], \n", " t=[k], \n", " identity=identity,\n", @@ -559,7 +538,7 @@ "logic_params = dict(PD=sensor_model['PD'], PG=1, lam=lam, Ptm=Ptm, Pfc=Pfc, Bfa=0.05, Ldel=np.log(Ptm/(1-Pfc)))\n", "init_track = lambda y, k, identity: dict(stage='tentative',\n", " Lt=0,\n", - " x=[np.array([y,0])], \n", + " x=[np.append(y, 0)], \n", " P=[P0], \n", " t=[k], \n", " identity=identity,\n",