diff --git a/exercise-session-1.ipynb b/exercise-session-1.ipynb
index 780e1c864d666adfa3f0a2b1bf43712164db4151..36b938ab0c9d3dba9a1982b0c4b34ef7c3093fd5 100644
--- a/exercise-session-1.ipynb
+++ b/exercise-session-1.ipynb
@@ -31,7 +31,13 @@
     "import scipy.stats as stats\n",
     "from src.sim import generate_data\n",
     "from src.trajectories import get_ex1_trajectories\n",
-    "ex1_trajectories = get_ex1_trajectories()"
+    "ex1_trajectories = get_ex1_trajectories()\n",
+    "import src.gaters as gaters\n",
+    "import src.filters as filters\n",
+    "import src.associators as associators\n",
+    "import src.trackers as trackers\n",
+    "import src.plotters as plotters\n",
+    "import src.models as models"
    ]
   },
   {
@@ -41,8 +47,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def evaluate_tracker(trajectory, tracker, MC):\n",
-    "    T = trajectory\n",
+    "def evaluate_tracker(T, tracker, MC):\n",
     "    # Initialize correctly\n",
     "    xhat = np.zeros((4, T.shape[1], MC))\n",
     "    xhat[:2, 0, :] = np.repeat(T[:, [0]], MC, axis=1)\n",
@@ -52,7 +57,7 @@
     "    Phat[:, :, 0, :] = np.repeat(np.diag([10, 10, 1, 1])[:, :, None], MC, axis=2)\n",
     "    Y = []\n",
     "    for m in tqdm.tqdm(range(MC), desc=\"MC iteration: \"):\n",
-    "        Y.append(generate_data(T, tracker.filt.sensor_model, tracker.clutter_model, rng))\n",
+    "        Y.append(generate_data({'T': T}, tracker.filt.sensor_model, tracker.clutter_model, rng))\n",
     "        xhat[:, :, m], Phat[:, :, :, m] = tracker.evaluate(Y[-1], xhat[:, :, m], Phat[:, :, :, m])\n",
     "    # Evaluate RMSE\n",
     "    rmse = np.sqrt(np.mean((T[:, :, None] - xhat[:2, :, :])**2, axis=1))\n",
@@ -67,36 +72,16 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import src.gaters as gaters\n",
-    "import src.filters as filters\n",
-    "import src.associators as associators\n",
-    "import src.trackers as trackers\n",
-    "import src.plotters as plotters\n",
-    "\n",
-    "# Setup sensor and clutter model\n",
-    "def h(x): # Assumes positional coordinates first\n",
-    "    if len(x.shape) < 2:\n",
-    "        xt = x.reshape(-1, 1)\n",
-    "    else:\n",
-    "        xt = x\n",
-    "    target_range = jnp.linalg.norm(xt[:2, :], axis=0) # Use JAX numpy to be able to auto-differentiate\n",
-    "    target_bearing = jnp.arctan2(xt[1, :], xt[0, :])\n",
-    "    return jnp.vstack([target_range, target_bearing]).squeeze()\n",
-    "    \n",
     "R = np.diag([10, 0.001])**2\n",
     "PD = 0.9\n",
     "lam = 2\n",
     "volume = dict(xmin=0, xmax=2500, ymin=0, ymax=1000)\n",
-    "sensor_model = dict(h=h, R=R, PD=PD)\n",
+    "sensor_model = models.radar_model(R=R, PD=PD)\n",
     "clutter_model = dict(volume=volume, lam=lam)\n",
     "\n",
     "# CV model\n",
-    "F = np.identity(4)\n",
-    "F[:2, 2:] = np.identity(2)\n",
-    "G = np.vstack([np.identity(2), 1/2*np.identity(2)])\n",
-    "Q = lambda q: G@(q*np.identity(2))@G.T\n",
-    "f = lambda x: F@x\n",
-    "motion_model = dict(f=f, Q=Q(1))\n",
+    "Q = np.identity(2)\n",
+    "motion_model = models.cv_model(Q=Q, D=2, T=1)\n",
     "\n",
     "# Setup Gater\n",
     "gamma = 4.7\n",
@@ -127,7 +112,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "motion_model['Q'] = Q(1)\n",
+    "motion_model = models.cv_model(Q=np.identity(2), D=2, T=1)\n",
     "# Setup filter\n",
     "filt = filters.EKF(motion_model, sensor_model)\n",
     "# Setup tracker\n",
@@ -168,7 +153,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "motion_model['Q'] = Q(100)\n",
+    "motion_model = models.cv_model(Q=np.identity(2)*100, D=2, T=1)\n",
     "# Setup filter\n",
     "filt = filters.EKF(motion_model, sensor_model)\n",
     "# Setup tracker\n",
@@ -211,23 +196,21 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "motion_model_two = motion_model.copy()\n",
-    "motion_model_two['Q'] = Q(100)\n",
+    "motion_model = models.cv_model(Q=np.identity(2), D=2, T=1) # Use Q=I\n",
+    "motion_model_two = models.cv_model(Q=100*np.identity(2), D=2, T=1) # Use Q=100I\n",
     "filtone = filters.EKF(motion_model, sensor_model)\n",
     "filttwo = filters.EKF(motion_model_two, sensor_model)\n",
-    "p = 0.8\n",
+    "p = 0.9\n",
     "trans_prob = np.array([[p, 1-p], [1-p, p]])\n",
     "imm = filters.IMM([filtone, filttwo], sensor_model, trans_prob)\n",
-    "gater.gamma = 10.6#9.2"
+    "gater.gamma = 12.43 # Corresponds to alpha=0.998. Tuned for performance."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "c74df784",
-   "metadata": {
-    "scrolled": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "# Setup tracker\n",
@@ -244,7 +227,7 @@
     "mode_probabilities = []\n",
     "Y = []\n",
     "for m in tqdm.tqdm(range(MC), desc=\"MC iteration: \"):\n",
-    "    Y.append(generate_data(T, tracker.filt.sensor_model, tracker.clutter_model, rng))\n",
+    "    Y.append(generate_data({'T': T}, tracker.filt.sensor_model, tracker.clutter_model, rng))\n",
     "    xhat[:, :, m], Phat[:, :, :, m] = tracker.evaluate(Y[-1], xhat[:, :, m], Phat[:, :, :, m])\n",
     "    mode_probabilities.append(imm.mode_probabilities)\n",
     "    imm.mode_probabilities = [imm.mode_probabilities[0]]\n",
@@ -304,16 +287,8 @@
    "outputs": [],
    "source": [
     "T = np.round(np.diff(dat24['T1'][0])[0], 2)\n",
-    "F = np.identity(4)\n",
-    "F[:2, 2:] = np.identity(2)*T\n",
-    "G = np.vstack([np.identity(2)*T, T**2/2*np.identity(2)])\n",
-    "q1 = 1*np.identity(2)\n",
-    "q2 = 1e4*np.identity(2)\n",
-    "Q1 = G@q1@G.T\n",
-    "Q2 = G@q2@G.T\n",
-    "f = lambda x: F@x\n",
-    "mm = dict(f=f, Q=Q1)\n",
-    "mm2 = dict(f=f, Q=Q2)\n",
+    "mm = models.cv_model(Q=np.identity(2), D=2, T=T)\n",
+    "mm2 = models.cv_model(Q=1e4*np.identity(2), D=2, T=T)\n",
     "\n",
     "sensor_model['R'] = np.diag([135, 1.5*np.pi/180])**2\n",
     "clutter_model['volume'] = dict(xmin=20000, xmax=32000, ymin=-10000, ymax=35000)\n",
@@ -374,6 +349,14 @@
     "$$N1=2,~M1=2,~N2=2,~M2=3,~N3=3$$"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "0b81e293-471d-4eb5-87b3-c24b108b6e65",
+   "metadata": {},
+   "source": [
+    "#### Convenience functionality"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -382,34 +365,30 @@
    "outputs": [],
    "source": [
     "def update_track(meas_k, unused_meas, filt, track, associator, gater, logic, logic_params):\n",
-    "    # Calculate prediction error of each measurement\n",
+    "    # Calculate prediction error of each unused measurement\n",
     "    yhat = filt.sensor_model['h'](track['x'][-1])\n",
+    "    eps = meas_k[:, unused_meas]-yhat\n",
     "    \n",
-    "    eps = meas_k[unused_meas]-yhat\n",
-    "    \n",
+    "    # No unused measurements -> update track with false measurement\n",
     "    if eps.size == 0:\n",
     "        track = logic(np.array([]), filt, track, logic_params)\n",
     "        return\n",
+    "    # Necessary for broadcasting later on\n",
     "    if eps.ndim < 2:\n",
     "        eps = np.expand_dims(eps, 0)\n",
     "    \n",
     "    # Gating step\n",
-    "    accepted_meas = gater.gate(track['x'][-1], track['P'][-1], eps)\n",
+    "    accepted_meas = gater.gate(track['x'][-1], track['P'][-1], meas_k[:, unused_meas])\n",
     "    # If any measurements are accepted, select the nearest one\n",
     "    if accepted_meas.any():\n",
     "        # Association step\n",
     "        yind = associator.associate(eps[:, accepted_meas])\n",
     "        # Update track information\n",
-    "        track = logic(meas_k[unused_meas][accepted_meas][yind], filt, track, logic_params)\n",
-    "        # Unfortunately, Python can't update an array value with multi-level indexing\n",
-    "        tmp = unused_meas[unused_meas]\n",
-    "        tmp2 = tmp[accepted_meas]\n",
-    "        tmp2[yind] = 0\n",
-    "        tmp[accepted_meas] = tmp2\n",
-    "        unused_meas[unused_meas] = tmp\n",
+    "        track = logic(meas_k[:, unused_meas][:, accepted_meas][:, yind], filt, track, logic_params)\n",
+    "        # Remove this measurement from further consideration\n",
+    "        unused_meas[meas_k.flatten()==meas_k[:, unused_meas][:, accepted_meas][:, yind]] = 0\n",
     "        # Update\n",
     "        track['x'][-1], track['P'][-1] = filt.update(track['x'][-1], track['P'][-1], eps[:, accepted_meas][:, yind])\n",
-    "        \n",
     "    else:\n",
     "        track = logic(np.array([]), filt, track, logic_params)\n",
     "\n",
@@ -419,7 +398,7 @@
     "\n",
     "    ids = 0\n",
     "    for k, meas_k in tqdm.tqdm(enumerate(Y), desc=\"Evaluating observations: \"):\n",
-    "        unused_meas = np.ones((1,meas_k.size), dtype=bool)\n",
+    "        unused_meas = np.ones((meas_k.size,), dtype=bool)\n",
     "\n",
     "        for track in confirmed_tracks:\n",
     "            if track['stage'] == 'deleted':\n",
@@ -443,9 +422,10 @@
     "                confirmed_tracks.append(track) # If a track has been confirmed, add it to confirmed tracks\n",
     "\n",
     "        # Used the unused measurements to initiate new tracks\n",
-    "        for meas in meas_k[unused_meas]:\n",
-    "            tracks.append(init_track(meas, k, ids))\n",
-    "            ids += 1\n",
+    "        if unused_meas.any():\n",
+    "            for meas in meas_k[:, unused_meas].T:\n",
+    "                tracks.append(init_track(meas, k, ids))\n",
+    "                ids += 1\n",
     "\n",
     "        for track in tracks:\n",
     "            if track['stage'] != 'deleted':\n",
@@ -457,12 +437,12 @@
     "\n",
     "def plot_logic_result(tracks, confirmed_tracks):\n",
     "    fig, ax = plt.subplots(2, 1, figsize=(16, 16))\n",
-    "\n",
+    "    confirmed_ids = [track['identity'] for track in confirmed_tracks]\n",
     "    for track in tracks:\n",
     "        x = np.vstack(track['x'])\n",
     "        t = np.hstack(track['t']).flatten()\n",
     "        assoc = np.hstack(track['associations']).flatten()\n",
-    "        if track in confirmed_tracks:\n",
+    "        if track['identity'] in confirmed_ids:\n",
     "            ls = '-'\n",
     "        else:\n",
     "            ls = '--'\n",
@@ -489,7 +469,7 @@
     "import src.logic as logic\n",
     "F = np.array([[1, 1], [0, 1]])\n",
     "Q1 = 0.001*np.array([[1/4, 1/2], [1/2, 1]])\n",
-    "R = np.array([0.01])[:, None]\n",
+    "R = np.array([[0.01]])\n",
     "PD = 0.9\n",
     "clutter_model = dict(volume=dict(xmin=-10, xmax=10, ymin=0, ymax=0), lam=0.05*20)\n",
     "H = np.array([[1, 0]])\n",
@@ -497,16 +477,15 @@
     "h = lambda x: H@x\n",
     "motion_model = dict(f=f, Q=Q1)\n",
     "sensor_model = dict(h=h, R=R, PD=PD)\n",
-    "gater = gaters.MahalanobisGater(sensor_model, 4.7)\n",
+    "gater = gaters.MahalanobisGater(sensor_model, 9.2) # widened gate; previously 4.7\n",
     "Y = list(dat24['Y2'].flatten())\n",
-    "N = len(Y)\n",
     "P0 = np.diag([R[0, 0], 0.1])\n",
     "logic_params = dict(N1=2, M1=2, N2=2, M2=3, N3=3)\n",
     "filt = filters.EKF(motion_model, sensor_model)\n",
     "init_track = lambda y, k, identity: dict(stage='tentative', \n",
     "                                         nmeas=1, \n",
     "                                         nass=1, \n",
-    "                                         x=[np.array([y,0])], \n",
+    "                                         x=[np.append(y,0)], \n",
     "                                         P=[P0], \n",
     "                                         t=[k], \n",
     "                                         identity=identity,\n",
@@ -559,7 +538,7 @@
     "logic_params = dict(PD=sensor_model['PD'], PG=1, lam=lam, Ptm=Ptm, Pfc=Pfc, Bfa=0.05, Ldel=np.log(Ptm/(1-Pfc)))\n",
     "init_track = lambda y, k, identity: dict(stage='tentative',\n",
     "                                        Lt=0,\n",
-    "                                        x=[np.array([y,0])], \n",
+    "                                        x=[np.append(y, 0)], \n",
     "                                        P=[P0], \n",
     "                                        t=[k], \n",
     "                                        identity=identity,\n",