diff --git a/gym_scenario.py b/gym_scenario.py
index 3cd504d42d6ebf0b2baa385ea97b8034bd6ac6c7..66004e08cf57bedebe8a8d6cdcda53a3bbe02955 100644
--- a/gym_scenario.py
+++ b/gym_scenario.py
@@ -1,7 +1,7 @@
 from dataclasses import replace
 import itertools
 import math
-from typing import Optional, Union, dict
+from typing import Optional, Union
 
 import dm_env
 import gym
@@ -10,11 +10,12 @@ import numpy as np
 import jax
 
 from data import ObservedState, State, Aircraft, RadarAction, AircraftObs, AircraftObj, AircraftPos, SaAircraft
+from support import get_radar_detection_obs, construct_tracks, trim_observations, get_radar_obs
 import constants
 
 class SensorControlEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
     def __init__(self, env_spec: dict):
-        self.scenario_counter = env_spec["scenario_counter"]
+        self.scenario_id = env_spec["scenario_id"]
         self.use_static = env_spec["use_static"]
         self._state: Optional[State] = None
 
@@ -23,11 +24,14 @@ class SensorControlEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
         #actions=tuple([a for a in itertools.chain(green_radar_actions, red_radar_actions)])
         self._state = replace(self._state, actions=actions)
         self._state = self.step_state(self._state, current_step_key)
+        self.scenario_step += 1
+        step_type = dm_env.StepType.MID if self.scenario_step < constants.SCENARIO_LENGTH else dm_env.StepType.LAST
+
+        return dm_env.TimeStep(step_type=step_type, reward=0.0, discount=1.0, observation=self._state)
 
     def calculate_dir_angle(self, own: AircraftObj) -> float:
         return math.atan2(own.aircraft.velocity_x, own.aircraft.velocity_y) * 180 / math.pi
 
-
     def is_observed_by_radar_probability(self, rng_key: jax.random.PRNGKey, intensity: float) -> bool:
         if intensity <= constants.RADAR_DETECTION_LENGTH:
             return True
@@ -75,37 +79,10 @@ class SensorControlEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
         angle = dir_angle - obs_angle
         if constants.RADAR_DETECTION_SPAN[0] <= angle <= constants.RADAR_DETECTION_SPAN[1]:
             # 95 % change to detect by radar detection
-            if jax.random.unifomr(rng_key) < 0.95:
+            if jax.random.uniform(rng_key) < 0.95:
                 return True
         return False
     
-    def trim_observations(self, own: AircraftObj, friends: list[AircraftObj], observations: tuple[AircraftObs, ...], time: float) -> tuple[AircraftObs, ...]:
-        """
-        TODO: Remove observations that has not yet been communicated or does not come from me or my friends
-        """
-        # Select observations from myself or friends
-        friends_names = {a.aircraft.name for a in friends}
-        trimmed_observations = tuple([o for o in observations if o.discover_name in friends_names])
-
-        # Remove observations that has not been communicated
-        trimmed_observations = tuple([o for o in trimmed_observations if o.discover_name == own.aircraft.name or o.time + constants.COMMUNICATION_DELAY <= time])
-        return trimmed_observations
-
-    def construct_tracks(own: AircraftObj, observations: tuple[AircraftObs, ...], time: float) -> tuple[SaAircraft, ...]:
-        prev_sa_obs = [AircraftObs(own.aircraft.name, a.observed_time, a.aircraft.name, a.aircraft.position) for a in own.sa]
-        combined_observations = list(observations) + prev_sa_obs
-        track_dict = dict()
-        for obs in combined_observations:
-            if obs.name in track_dict:
-                track = track_dict[obs.name]
-                if track.observed_time < obs.time:
-                    upd_track = SaAircraft(Aircraft(obs.name, obs.position, None, None), obs.time)
-                    track_dict[obs.name] = upd_track
-            else:
-                upd_track = SaAircraft(Aircraft(obs.name, obs.position, None, None), obs.time)
-                track_dict[obs.name] = upd_track
-        return tuple(track_dict.values())
-        
     def step_state(self, state: State, current_step_key: jax.random.PRNGKey) -> State:
         # Update positions
         new_green = list()
@@ -151,27 +128,27 @@ class SensorControlEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
             
             for other_aircraft, radar_key, radar_detection_key in zip(others, other_keys_radar, other_keys_radar_detection):
                 if self.is_observed_by_radar(radar_key, own_aircraft, other_aircraft, action):
-                    radar_obs = self.get_radar_obs(own_aircraft, other_aircraft, action, time)
+                    radar_obs = get_radar_obs(own_aircraft, other_aircraft, action, time)
                     logger.info(f"Radar observation {other_aircraft.aircraft.name}")
                     sa_updates.append(radar_obs)
-                if self.is_observed_by_radar_detection(own_aircraft, other_aircraft, action):
-                    radar_detection_obs = self.get_radar_detection_obs(own_aircraft, other_aircraft, action, time)
+                if self.is_observed_by_radar_detection(radar_detection_key, own_aircraft, other_aircraft, action):
+                    radar_detection_obs = get_radar_detection_obs(own_aircraft, other_aircraft, action, time)
                     logger.info(f"Radar detection observation of {own_aircraft.aircraft.name}")
                     sa_updates.append(radar_detection_obs)
 
         # Update Situational awareness
         new_green = list()
         for own in state.green:
-            own_observations = self.trim_observations(own, state.green, sa_updates, time)
-            own_sa = self.construct_tracks(own, own_observations, time)
+            own_observations = trim_observations(own, state.green, sa_updates, time)
+            own_sa = construct_tracks(own, own_observations, time)
             assert len(own_sa) <= len(state.red)
             aircraft_obj = replace(own, sa=tuple(own_sa))
             new_green.append(aircraft_obj)
         
         new_red = list()
         for own in state.red:
-            own_observations = self.trim_observations(own, state.red, sa_updates, time)
-            own_sa = self.construct_tracks(own, own_observations, time)
+            own_observations = trim_observations(own, state.red, sa_updates, time)
+            own_sa = construct_tracks(own, own_observations, time)
             assert len(own_sa) <= len(state.green)
             aircraft_obj = replace(own, sa=tuple(own_sa))
             new_red.append(aircraft_obj)
@@ -183,15 +160,17 @@ class SensorControlEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
 
         return new_state
         
-    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> State:
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> dm_env.TimeStep:
         rng_key = jax.random.PRNGKey(seed) if seed is not None else jax.random.PRNGKey(0)
+        self.scenario_step = 0
 
         self.step_key, scenario_key = jax.random.split(rng_key)
         if self.use_static:
             self._generate_static_scenario()
         else:
             self._generate_scenario(scenario_key)
-        return self._state
+        ts = dm_env.TimeStep(step_type=dm_env.StepType.FIRST, reward=None, discount=None, observation=self._state)
+        return ts
 
     def _generate_static_scenario(self):
         # Green team initialisation
@@ -216,22 +195,38 @@ class SensorControlEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
         green_seed, red_seed = jax.random.split(scenario_key)
         green_keys = jax.random.split(green_seed, num=6)
         red_keys = jax.random.split(red_seed, num=6)
-
-        no_greens = jax.random.randint(green_keys[0], 2, 7)
-        no_reds = jax.random.randint(red_keys[0], 2, 7)
+        
+        no_greens, no_reds = jax.random.randint(green_keys[0], shape=(), minval=2, maxval=7), jax.random.randint(red_keys[0], shape=(), minval=2, maxval=7)
+
+        green_x = jax.random.randint(green_keys[1], shape=(no_greens,), minval=0, maxval=20)
+        green_y = jax.random.randint(green_keys[2], shape=(no_greens,), minval=0, maxval=100)
+        green_vx = jax.random.uniform(green_keys[3], shape=(no_greens,), minval=0.15, maxval=0.3)
+        green_vy = jax.random.uniform(green_keys[4], shape=(no_greens,), minval=-0.1, maxval=0.1)
+        green_rcs = jax.random.uniform(green_keys[5], shape=(no_greens,), minval=constants.RCS_SPAN[0], maxval=constants.RCS_SPAN[1])
+        red_x = jax.random.randint(red_keys[1], shape=(no_reds,), minval=0, maxval=20)
+        red_y = jax.random.randint(red_keys[2], shape=(no_reds,), minval=0, maxval=100)
+        red_vx = -jax.random.uniform(red_keys[3], shape=(no_reds,), minval=0.15, maxval=0.3) #Note minus sign
+        red_vy = jax.random.uniform(red_keys[4], shape=(no_reds,), minval=-0.1, maxval=0.1)
+        red_rcs = jax.random.uniform(red_keys[5], shape=(no_reds,), minval=constants.RCS_SPAN[0], maxval=constants.RCS_SPAN[1])
 
         # Green team initialisation
-        greens = list()
-        for i in range(no_greens):
-            aircraftPos = AircraftPos(offset_x + jax.random.randint(green_keys[1], 0, 20), offset_y + jax.random.randint(green_keys[2], 0, 100))
-            obj = AircraftObj(Aircraft(f"g{i}", aircraftPos, jax.random.uniform(green_keys[3], minval=0.15, maxval=0.3), jax.random.uniform(green_keys[4], minval=-0.1, maxval=0.1)), tuple(), jax.random.uniform(green_keys[5], minval=constants.RCS_SPAN[0], maxval=constants.RCS_SPAN[1])) 
-            greens.append(obj)
-
-        reds = list()
-        for i in range(no_reds):
-            aircraftPos = AircraftPos(difference_x + offset_x + jax.random.randint(red_keys[1], 0, 20), offset_y + jax.random.randint(red_keys[2], 0, 100))
-            obj = AircraftObj(Aircraft(f"r{i}", aircraftPos, -jax.random.uniform(red_keys[3], minval=0.15, maxval=0.3), jax.random.uniform(red_keys[4], minval=-0.1, maxval=0.1)), tuple(), jax.random.uniform(red_keys[5], minval=constants.RCS_SPAN[0], maxval=constants.RCS_SPAN[1])) 
-            reds.append(obj)
-
-        self._state = State(f"Scenario{self.scenario_counter}", 0, tuple(greens), tuple(reds), tuple(), tuple())
+        green_aircraft_pos = [AircraftPos(offset_x + x, offset_y + y) for x, y in zip(green_x, green_y)]
+        greens = [AircraftObj(Aircraft(f"g{i}", aircraftPos, vx, vy), tuple(), rcs) for i, aircraftPos, vx, vy, rcs in zip(range(no_greens), green_aircraft_pos, green_vx, green_vy, green_rcs)]
+        # red team initialisation
+        red_aircraft_pos = [AircraftPos(difference_x + offset_x + x, offset_y + y) for x, y in zip(red_x, red_y)]
+        reds = [AircraftObj(Aircraft(f"r{i}", aircraftPos, vx, vy), tuple(), rcs) for i, aircraftPos, vx, vy, rcs in zip(range(no_reds), red_aircraft_pos, red_vx, red_vy, red_rcs)]
+
+        # greens = list()
+        # for i, x, y, vx, vy, rcs in zip(range(no_greens), green_x, green_y, green_vx, green_vy, green_rcs):
+        #     aircraftPos = AircraftPos(offset_x + x, offset_y + y)
+        #     obj = AircraftObj(Aircraft(f"g{i}", aircraftPos, vx, vy), tuple(), rcs) 
+        #     greens.append(obj)
+
+        # reds = list()
+        # for i, x, y, vx, vy, rcs in zip(range(no_reds), red_x, red_y, red_vx, red_vy, red_rcs): 
+        #     aircraftPos = AircraftPos(difference_x + offset_x + x, offset_y + y)
+        #     obj = AircraftObj(Aircraft(f"r{i}", aircraftPos, vx, vy), tuple(), rcs) 
+        #     reds.append(obj)
+
+        self._state = State(f"Scenario{self.scenario_id}", 0, tuple(greens), tuple(reds), tuple(), tuple())
 
diff --git a/main.py b/main.py
index 4c9059b22fc9dda049c5533a3f03ca5bb3d2946b..31893d39fbfc625501b9eaf5ec30686f8907b854 100644
--- a/main.py
+++ b/main.py
@@ -16,6 +16,16 @@ from behaviour import behaviour_choices
 
 
 def main() -> None:
+    logger = logging.getLogger("sensor-control")
+    logger.setLevel(logging.DEBUG)
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.DEBUG)
+    ch_formatter = logging.Formatter("%(message)s")
+    ch.setFormatter(ch_formatter)
+    logger.addHandler(ch)
+    folder_name = 'tmp'  # FIXME(review): hardcoded debug value duplicating the args-derived setup in mdp_main.py — confirm this block isn't leftover scaffolding
+    logger.info(f"Build dir: {folder_name}")
+
     # Create argument parse
     parser = argparse.ArgumentParser()
     parser.add_argument("--seed", default=0, type=int)
diff --git a/mdp_main.py b/mdp_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd28222e49cc3645996e69fc929247b828ca5352
--- /dev/null
+++ b/mdp_main.py
@@ -0,0 +1,182 @@
+import argparse
+import random
+import logging
+import itertools
+import datetime
+import os
+from pathlib import Path
+import sys
+import time
+import jax.random
+
+from behaviour import behaviour_choices, get_behaviour
+import constants
+from gym_scenario import SensorControlEnv
+from scenario import get_green_observe_state, get_red_observe_state, evaluate_state, cumulative_evaluation, draw_state, eval_to_csv
+from scenario_io import write_scenario_to_file
+from visualise import create_canvas
+
+def main_run_scenario(scenario_seed: int, green_behaviour_name: str, red_behaviour_name: str,
+                        visualise: bool, visualise_delay: int, folder_name: str, 
+                        logger: logging.Logger, csv_logger: logging.Logger):
+    env_config = {
+        "scenario_id": scenario_seed,
+        "use_static": False
+    }
+
+    logger.info("Generate scenario")
+    env = SensorControlEnv(env_config)
+    scenario_path = Path(folder_name + "/scenario.txt")
+
+    green_behaviour = get_behaviour(green_behaviour_name)
+    red_behaviour = get_behaviour(red_behaviour_name)
+    cumulative_evaluation_dict = dict()
+
+    logger.info("Start scenario")
+    start = datetime.datetime.now()
+    #Timing starts at a different point compared to the original code,
+    #since we include the reset call in the timing.
+    timestep = env.reset(seed=scenario_seed)
+    scenario_name = timestep.observation.name
+    logger.info(f"Writing scenario to {scenario_path.name}")
+    write_scenario_to_file(timestep.observation, scenario_path)
+    if visualise:
+        tk, canvas = create_canvas()
+        draw_state(timestep.observation, canvas, tk)
+        time.sleep(visualise_delay / 1000)
+
+    timestep_index = 0
+    while not timestep.last():
+        logger.info(f"Time step {timestep_index}")
+        green_obs = get_green_observe_state(timestep.observation)
+        green_radar_actions = green_behaviour.act(green_obs)
+        red_obs = get_red_observe_state(timestep.observation)
+        red_radar_actions = red_behaviour.act(red_obs)
+        actions=tuple([a for a in itertools.chain(green_radar_actions, red_radar_actions)])
+        timestep = env.step(actions)
+        evaluation_dict = evaluate_state(timestep.observation) 
+        #This evaluation differs from the original code, since
+        #the old evaluation was s_t + a_t, while this is s_{t+1}, a_t
+        #In MDP formalism, the action taken at time t isn't effective until t+1.
+        #Hence any form of radar state as a consequence of actions taken at time t shouldn't
+        #be evaluated until t+1.
+        #"Evaluation" could also be inspecting what actions are <taken> in what context.
+        #That would be more of an inspection than an evaluation though, but then s_t + a_t 
+        #is more relevant.
+        logger.info(f"{evaluation_dict}")
+        cumulative_evaluation_dict = cumulative_evaluation(evaluation_dict, cumulative_evaluation_dict, constants.SCENARIO_LENGTH)
+        if visualise:
+            #The original code doesn't draw the final state.
+            draw_state(timestep.observation, canvas, tk)
+            time.sleep(visualise_delay / 1000)
+        timestep_index += 1
+
+    logger.info("Cumulative results")
+    for k, v in cumulative_evaluation_dict.items():
+        logger.info(f"    {k}: {v}")
+
+    cumulative_evaluation_dict["green__behaviour"] = green_behaviour_name
+    cumulative_evaluation_dict["red__behaviour"] = red_behaviour_name
+    cumulative_evaluation_dict["scenario_name"] = scenario_name
+
+    csv_message = eval_to_csv(cumulative_evaluation_dict)
+    keys = eval_to_csv(cumulative_evaluation_dict, key_only=True)
+    csv_logger.info(csv_message)
+    logger.info(keys)
+    logger.info(csv_message)
+        
+    logger.info("Scenario finished")
+
+    end = datetime.datetime.now()
+    time_spent = end - start
+    logger.info(f"Scenario took {time_spent}s to run")
+    logger.info(f"Build dir: {folder_name}")
+
+def main() -> None:
+    # Create argument parse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--seed", default=0, type=int)
+    parser.add_argument("--build_dir", help="The directory that the runs folder should be placed", default="build", type=str)
+    parser.add_argument("--tag", help="Tag of run", default="no_tag", type=str)
+    parser.add_argument("--green", default="simpleone", choices=behaviour_choices, type=str.lower)
+    parser.add_argument("--red", default="no", choices=behaviour_choices, type=str.lower)
+    parser.add_argument("--visualise", help="Should the run be visualised?", default=False, type=bool, action=argparse.BooleanOptionalAction)
+    parser.add_argument("--visualise_delay", help="Delay between each step in milliseconds", type=int, default=500)
+    parser.add_argument("--csvfile", help="Path to CSV-file for logging", type=str, default="")
+    parser.add_argument("--stream-log", help="Should the logging stream show?", default=True, type=bool, action=argparse.BooleanOptionalAction)
+    args = parser.parse_args()
+
+    # Create directory for logging
+    folder_name = f"{args.build_dir}/{args.tag}/sensor-control_{args.green}_{args.red}_{args.seed}_{datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d_%H-%M-%S-%f')}"
+    if os.path.exists(folder_name):
+        raise RuntimeError(f"Build directory should not exist {folder_name}")
+    os.makedirs(folder_name)
+
+    # Setup logging
+    logger = logging.getLogger("sensor-control")
+    logger.propagate = False #Jax adds handlers in parent logger.
+    logger.setLevel(logging.DEBUG)
+    if args.stream_log:
+        ch = logging.StreamHandler()
+        ch.setLevel(logging.DEBUG)
+        ch_formatter = logging.Formatter("%(message)s")
+        ch.setFormatter(ch_formatter)
+        logger.addHandler(ch)
+    fh = logging.FileHandler(folder_name + "/debug.log")
+    fh_formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+    fh.setFormatter(fh_formatter)
+    logger.addHandler(fh)
+
+    csv_logger = logging.getLogger("csv")
+    csv_logger.propagate = False #Jax adds handlers in parent logger.
+    csv_logger.setLevel(logging.DEBUG)
+    if args.csvfile:
+        csvh = logging.FileHandler(args.csvfile, delay=True)
+    else:
+        csvh = logging.FileHandler(folder_name + "/log.csv", delay=True)
+    csvh_formatter = logging.Formatter("%(message)s")
+    csvh.setFormatter(csvh_formatter)
+    csv_logger.addHandler(csvh)
+
+    # Perform main program
+    try:
+        # General info
+        logger.info(f"Build dir: {folder_name}")
+        logger.info(f"Green behaviour: {args.green}")
+        logger.info(f"Red behaviour: {args.red}")
+        logger.info(f"Seed: {args.seed}")
+        logger.info(f"Tag: {args.tag}")
+        logger.info(f"Visualise: {args.visualise}")
+        logger.info(f"Visualise delay: {args.visualise_delay}")
+        logger.info(f"CSV-file: {args.csvfile}")
+
+        # User / Host
+        logger.info(f"{os.environ.get('HOSTNAME')}")
+        logger.info(f"{os.environ.get('USER')}")
+
+        try:
+            import git  # GitPython
+            repo = git.Repo(os.getcwd() + "/..")
+            logger.info(f"Git commit: {repo.head.commit}")
+            logger.info(f"Git is_dirty: {repo.is_dirty()}")
+        except Exception as e:
+            logger.error(e)
+            logger.error("Git information retrieval unsuccessful")
+
+        # Ignore git parts from main.py
+        main_run_scenario(args.seed, args.green, args.red, 
+                        args.visualise, args.visualise_delay, folder_name, logger, csv_logger)
+
+    # If program fails catch exception
+    except Exception as e:
+        logger.error("Exception occurred:")
+        logger.error(f"{e}")
+        raise
+
+    # Finished
+    logger.info("Application finished")
+    sys.exit(0)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/requirements.txt b/requirements.txt
index de75544f44d02067d997f65d99cdd085d22bb23f..9d26c1013d9871df04e0aa3ec0ed86c1abc0e9ca 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,4 @@ gym
 
 jax
 jaxlib
+dm_env