diff --git a/gym_scenario.py b/gym_scenario.py index 3cd504d42d6ebf0b2baa385ea97b8034bd6ac6c7..66004e08cf57bedebe8a8d6cdcda53a3bbe02955 100644 --- a/gym_scenario.py +++ b/gym_scenario.py @@ -1,7 +1,7 @@ from dataclasses import replace import itertools import math -from typing import Optional, Union, dict +from typing import Optional, Union import dm_env import gym @@ -10,11 +10,12 @@ import numpy as np import jax from data import ObservedState, State, Aircraft, RadarAction, AircraftObs, AircraftObj, AircraftPos, SaAircraft +from support import get_radar_detection_obs, construct_tracks, trim_observations, get_radar_obs import constants class SensorControlEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]): def __init__(self, env_spec: dict): - self.scenario_counter = env_spec["scenario_counter"] + self.scenario_id = env_spec["scenario_id"] self.use_static = env_spec["use_static"] self._state: Optional[State] = None @@ -23,11 +24,14 @@ class SensorControlEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]): #actions=tuple([a for a in itertools.chain(green_radar_actions, red_radar_actions)]) self._state = replace(self._state, actions=actions) self._state = self.step_state(self._state, current_step_key) + self.scenario_step += 1 + step_type = dm_env.StepType.MID if self.scenario_step < constants.SCENARIO_LENGTH else dm_env.StepType.LAST + + return dm_env.TimeStep(step_type = step_type, reward = None, discount = 1.0, observation=self._state) def calculate_dir_angle(self, own: AircraftObj) -> float: return math.atan2(own.aircraft.velocity_x, own.aircraft.velocity_y) * 180 / math.pi - def is_observed_by_radar_probability(self, rng_key: jax.random.PRNGKey, intensity: float) -> bool: if intensity <= constants.RADAR_DETECTION_LENGTH: return True @@ -75,37 +79,10 @@ class SensorControlEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]): angle = dir_angle - obs_angle if constants.RADAR_DETECTION_SPAN[0] <= angle <= constants.RADAR_DETECTION_SPAN[1]: # 95 % change to detect by radar 
detection - if jax.random.unifomr(rng_key) < 0.95: + if jax.random.uniform(rng_key) < 0.95: return True return False - def trim_observations(self, own: AircraftObj, friends: list[AircraftObj], observations: tuple[AircraftObs, ...], time: float) -> tuple[AircraftObs, ...]: - """ - TODO: Remove observations that has not yet been communicated or does not come from me or my friends - """ - # Select observations from myself or friends - friends_names = {a.aircraft.name for a in friends} - trimmed_observations = tuple([o for o in observations if o.discover_name in friends_names]) - - # Remove observations that has not been communicated - trimmed_observations = tuple([o for o in trimmed_observations if o.discover_name == own.aircraft.name or o.time + constants.COMMUNICATION_DELAY <= time]) - return trimmed_observations - - def construct_tracks(own: AircraftObj, observations: tuple[AircraftObs, ...], time: float) -> tuple[SaAircraft, ...]: - prev_sa_obs = [AircraftObs(own.aircraft.name, a.observed_time, a.aircraft.name, a.aircraft.position) for a in own.sa] - combined_observations = list(observations) + prev_sa_obs - track_dict = dict() - for obs in combined_observations: - if obs.name in track_dict: - track = track_dict[obs.name] - if track.observed_time < obs.time: - upd_track = SaAircraft(Aircraft(obs.name, obs.position, None, None), obs.time) - track_dict[obs.name] = upd_track - else: - upd_track = SaAircraft(Aircraft(obs.name, obs.position, None, None), obs.time) - track_dict[obs.name] = upd_track - return tuple(track_dict.values()) - def step_state(self, state: State, current_step_key: jax.random.PRNGKey) -> State: # Update positions new_green = list() @@ -151,27 +128,27 @@ class SensorControlEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]): for other_aircraft, radar_key, radar_detection_key in zip(others, other_keys_radar, other_keys_radar_detection): if self.is_observed_by_radar(radar_key, own_aircraft, other_aircraft, action): - radar_obs = 
self.get_radar_obs(own_aircraft, other_aircraft, action, time) + radar_obs = get_radar_obs(own_aircraft, other_aircraft, action, time) logger.info(f"Radar observation {other_aircraft.aircraft.name}") sa_updates.append(radar_obs) - if self.is_observed_by_radar_detection(own_aircraft, other_aircraft, action): - radar_detection_obs = self.get_radar_detection_obs(own_aircraft, other_aircraft, action, time) + if self.is_observed_by_radar_detection(radar_detection_key, own_aircraft, other_aircraft, action): + radar_detection_obs = get_radar_detection_obs(own_aircraft, other_aircraft, action, time) logger.info(f"Radar detection observation of {own_aircraft.aircraft.name}") sa_updates.append(radar_detection_obs) # Update Situational awareness new_green = list() for own in state.green: - own_observations = self.trim_observations(own, state.green, sa_updates, time) - own_sa = self.construct_tracks(own, own_observations, time) + own_observations = trim_observations(own, state.green, sa_updates, time) + own_sa = construct_tracks(own, own_observations, time) assert len(own_sa) <= len(state.red) aircraft_obj = replace(own, sa=tuple(own_sa)) new_green.append(aircraft_obj) new_red = list() for own in state.red: - own_observations = self.trim_observations(own, state.red, sa_updates, time) - own_sa = self.construct_tracks(own, own_observations, time) + own_observations = trim_observations(own, state.red, sa_updates, time) + own_sa = construct_tracks(own, own_observations, time) assert len(own_sa) <= len(state.green) aircraft_obj = replace(own, sa=tuple(own_sa)) new_red.append(aircraft_obj) @@ -183,15 +160,17 @@ class SensorControlEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]): return new_state - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> State: + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> dm_env.TimeStep: rng_key = jax.random.PRNGKey(seed) if seed is not None else jax.random.PRNGKey(0) + 
self.scenario_step = 0 self.step_key, scenario_key = jax.random.split(rng_key) if self.use_static: self._generate_static_scenario() else: self._generate_scenario(scenario_key) - return self._state + ts = dm_env.TimeStep(step_type=dm_env.StepType.FIRST, reward=None, discount=None, observation=self._state) + return ts def _generate_static_scenario(self): # Green team initialisation @@ -216,22 +195,38 @@ class SensorControlEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]): green_seed, red_seed = jax.random.split(scenario_key) green_keys = jax.random.split(green_seed, num=6) red_keys = jax.random.split(red_seed, num=6) - - no_greens = jax.random.randint(green_keys[0], 2, 7) - no_reds = jax.random.randint(red_keys[0], 2, 7) + + no_greens, no_reds = jax.random.randint(green_keys[0], shape=(2,), minval=2, maxval=7) + + green_x = jax.random.randint(green_keys[1], shape=(no_greens,), minval=0, maxval=20) + green_y = jax.random.randint(green_keys[2], shape=(no_greens,), minval=0, maxval=100) + green_vx = jax.random.uniform(green_keys[3], shape=(no_greens,), minval=0.15, maxval=0.3) + green_vy = jax.random.uniform(green_keys[4], shape=(no_greens,), minval=-0.1, maxval=0.1) + green_rcs = jax.random.uniform(green_keys[5], shape=(no_greens,), minval=constants.RCS_SPAN[0], maxval=constants.RCS_SPAN[1]) + red_x = jax.random.randint(red_keys[1], shape=(no_reds,), minval=0, maxval=20) + red_y = jax.random.randint(red_keys[2], shape=(no_reds,), minval=0, maxval=100) + red_vx = -jax.random.uniform(red_keys[3], shape=(no_reds,), minval=0.15, maxval=0.3) #Note minus sign + red_vy = jax.random.uniform(red_keys[4], shape=(no_reds,), minval=-0.1, maxval=0.1) + red_rcs = jax.random.uniform(red_keys[5], shape=(no_reds,), minval=constants.RCS_SPAN[0], maxval=constants.RCS_SPAN[1]) # Green team initialisation - greens = list() - for i in range(no_greens): - aircraftPos = AircraftPos(offset_x + jax.random.randint(green_keys[1], 0, 20), offset_y + jax.random.randint(green_keys[2], 0, 100)) - obj 
= AircraftObj(Aircraft(f"g{i}", aircraftPos, jax.random.uniform(green_keys[3], minval=0.15, maxval=0.3), jax.random.uniform(green_keys[4], minval=-0.1, maxval=0.1)), tuple(), jax.random.uniform(green_keys[5], minval=constants.RCS_SPAN[0], maxval=constants.RCS_SPAN[1])) - greens.append(obj) - - reds = list() - for i in range(no_reds): - aircraftPos = AircraftPos(difference_x + offset_x + jax.random.randint(red_keys[1], 0, 20), offset_y + jax.random.randint(red_keys[2], 0, 100)) - obj = AircraftObj(Aircraft(f"r{i}", aircraftPos, -jax.random.uniform(red_keys[3], minval=0.15, maxval=0.3), jax.random.uniform(red_keys[4], minval=-0.1, maxval=0.1)), tuple(), jax.random.uniform(red_keys[5], minval=constants.RCS_SPAN[0], maxval=constants.RCS_SPAN[1])) - reds.append(obj) - - self._state = State(f"Scenario{self.scenario_counter}", 0, tuple(greens), tuple(reds), tuple(), tuple()) + green_aircraft_pos = [AircraftPos(offset_x + x, offset_y + y) for x, y in zip(green_x, green_y)] + greens = [AircraftObj(Aircraft(f"g{i}", aircraftPos, vx, vy), tuple(), rcs) for i, aircraftPos, vx, vy, rcs in zip(range(no_greens), green_aircraft_pos, green_vx, green_vy, green_rcs)] + # red team initialisation + red_aircraft_pos = [AircraftPos(difference_x + offset_x + x, offset_y + y) for x, y in zip(red_x, red_y)] + reds = [AircraftObj(Aircraft(f"r{i}", aircraftPos, vx, vy), tuple(), rcs) for i, aircraftPos, vx, vy, rcs in zip(range(no_reds), red_aircraft_pos, red_vx, red_vy, red_rcs)] + + # greens = list() + # for i, x, y, vx, vy, rcs in zip(range(no_greens), green_x, green_y, green_vx, green_vy, green_rcs): + # aircraftPos = AircraftPos(offset_x + x, offset_y + y) + # obj = AircraftObj(Aircraft(f"g{i}", aircraftPos, vx, vy), tuple(), rcs) + # greens.append(obj) + + # reds = list() + # for i, x, y, vx, vy, rcs in zip(range(no_reds), red_x, red_y, red_vx, red_vy, red_rcs): + # aircraftPos = AircraftPos(difference_x + offset_x + x, offset_y + y) + # obj = AircraftObj(Aircraft(f"r{i}", aircraftPos, 
vx, vy), tuple(), rcs) + # reds.append(obj) + + self._state = State(f"Scenario{self.scenario_id}", 0, tuple(greens), tuple(reds), tuple(), tuple()) diff --git a/main.py b/main.py index 4c9059b22fc9dda049c5533a3f03ca5bb3d2946b..31893d39fbfc625501b9eaf5ec30686f8907b854 100644 --- a/main.py +++ b/main.py @@ -16,6 +16,16 @@ from behaviour import behaviour_choices def main() -> None: + logger = logging.getLogger("sensor-control") + logger.setLevel(logging.DEBUG) + ch = logging.StreamHandler() + ch.setLevel(logging.DEBUG) + ch_formatter = logging.Formatter("%(message)s") + ch.setFormatter(ch_formatter) + logger.addHandler(ch) + folder_name = 'tmp' + logger.info(f"Build dir: {folder_name}") + # Create argument parse parser = argparse.ArgumentParser() parser.add_argument("--seed", default=0, type=int) diff --git a/mdp_main.py b/mdp_main.py new file mode 100644 index 0000000000000000000000000000000000000000..cd28222e49cc3645996e69fc929247b828ca5352 --- /dev/null +++ b/mdp_main.py @@ -0,0 +1,182 @@ +import argparse +import random +import logging +import itertools +import datetime +import os +from pathlib import Path +import sys +import time +import jax.random + +from behaviour import behaviour_choices, get_behaviour +import constants +from gym_scenario import SensorControlEnv +from scenario import get_green_observe_state, get_red_observe_state, evaluate_state, cumulative_evaluation, draw_state, eval_to_csv +from scenario_io import write_scenario_to_file +from visualise import create_canvas + +def main_run_scenario(scenario_seed: int, green_behaviour_name: str, red_behaviour_name: str, + visualise: bool, visualise_delay: int, folder_name: str, + logger: logging.Logger, csv_logger: logging.Logger): + env_config = { + "scenario_id": scenario_seed, + "use_static": False + } + + logger.info("Generate scenario") + env = SensorControlEnv(env_config) + scenario_path = Path(folder_name + "/scenario.txt") + + green_behaviour = get_behaviour(green_behaviour_name) + red_behaviour = 
get_behaviour(red_behaviour_name) + cumulative_evaluation_dict = dict() + + logger.info("Start scenario") + start = datetime.datetime.now() + #Timing starts at a different point compared to the original code, + #since we include the reset call in the timing. + timestep = env.reset(seed=scenario_seed) + scenario_name = timestep.observation.name + logger.info(f"Writing scenario to {scenario_path.name}") + write_scenario_to_file(timestep.observation, scenario_path) + if visualise: + tk, canvas = create_canvas() + draw_state(timestep.observation, canvas, tk) + time.sleep(visualise_delay / 1000) + + timestep_index = 0 + while not timestep.last(): + logger.info(f"Time step {timestep_index}") + green_obs = get_green_observe_state(timestep.observation) + green_radar_actions = green_behaviour.act(green_obs) + red_obs = get_red_observe_state(timestep.observation) + red_radar_actions = red_behaviour.act(red_obs) + actions=tuple([a for a in itertools.chain(green_radar_actions, red_radar_actions)]) + timestep = env.step(actions) + evaluation_dict = evaluate_state(timestep.observation) + #This evaluation differs from the original code, since + #the old evaluation was s_t + a_t, while this is s_{t+1}, a_t + #In MDP formalism, the action taken at time t isn't effective until t+1. + #Hence any form of radar state as a consequence of actions taken at time t shouldn't + #be evaluated until t+1. + #"Evaluation" could also be inspecting what actions are <taken> in what context. + #That would be more of an inspection than an evaluation though, but then s_t + a_t + #is more relevant. + logger.info(f"{evaluation_dict}") + cumulative_evaluation_dict = cumulative_evaluation(evaluation_dict, cumulative_evaluation_dict, constants.SCENARIO_LENGTH) + if visualise: + #The original code doesn't draw the final state. 
+ draw_state(timestep.observation, canvas, tk) + time.sleep(visualise_delay / 1000) + timestep_index += 1 + + logger.info("Cumulative results") + for k, v in cumulative_evaluation_dict.items(): + logger.info(f" {k}: {v}") + + cumulative_evaluation_dict["green__behaviour"] = green_behaviour_name + cumulative_evaluation_dict["red__behaviour"] = red_behaviour_name + cumulative_evaluation_dict["scenario_name"] = scenario_name + + csv_message = eval_to_csv(cumulative_evaluation_dict) + keys = eval_to_csv(cumulative_evaluation_dict, key_only=True) + csv_logger.info(csv_message) + logger.info(keys) + logger.info(csv_message) + + logger.info("Scenario finished") + + end = datetime.datetime.now() + time_spent = end - start + logger.info(f"Scenario took {time_spent}s to run") + logger.info(f"Build dir: {folder_name}") + +def main() -> None: + # Create argument parse + parser = argparse.ArgumentParser() + parser.add_argument("--seed", default=0, type=int) + parser.add_argument("--build_dir", help="The directory that the runs folder should be placed", default="build", type=str) + parser.add_argument("--tag", help="Tag of run", default="no_tag", type=str) + parser.add_argument("--green", default="simpleone", choices=behaviour_choices, type=str.lower) + parser.add_argument("--red", default="no", choices=behaviour_choices, type=str.lower) + parser.add_argument("--visualise", help="Should the run be visualised?", default=False, type=bool, action=argparse.BooleanOptionalAction) + parser.add_argument("--visualise_delay", help="Delay between each step in milliseconds", type=int, default=500) + parser.add_argument("--csvfile", help="Path to CSV-file for logging", type=str, default="") + parser.add_argument("--stream-log", help="Should the logging stream show?", default=True, type=bool, action=argparse.BooleanOptionalAction) + args = parser.parse_args() + + # Create directory for logging + folder_name = 
f"{args.build_dir}/{args.tag}/sensor-control_{args.green}_{args.red}_{args.seed}_{datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d_%H-%M-%S-%f')}" + if os.path.exists(folder_name): + raise RuntimeError(f"Build directory should not exist {folder_name}") + os.makedirs(folder_name) + + # Setup logging + logger = logging.getLogger("sensor-control") + logger.propagate = False #Jax adds handlers in parent logger. + logger.setLevel(logging.DEBUG) + if args.stream_log: + ch = logging.StreamHandler() + ch.setLevel(logging.DEBUG) + ch_formatter = logging.Formatter("%(message)s") + ch.setFormatter(ch_formatter) + logger.addHandler(ch) + fh = logging.FileHandler(folder_name + "/debug.log") + fh_formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + fh.setFormatter(fh_formatter) + logger.addHandler(fh) + + csv_logger = logging.getLogger("csv") + csv_logger.propagate = False #Jax adds handlers in parent logger. + csv_logger.setLevel(logging.DEBUG) + if args.csvfile: + csvh = logging.FileHandler(args.csvfile, delay=True) + else: + csvh = logging.FileHandler(folder_name + "/log.csv", delay=True) + csvh_formatter = logging.Formatter("%(message)s") + csvh.setFormatter(csvh_formatter) + csv_logger.addHandler(csvh) + + # Perform main program + try: + # General info + logger.info(f"Build dir: {folder_name}") + logger.info(f"Green behaviour: {args.green}") + logger.info(f"Red behaviour: {args.red}") + logger.info(f"Seed: {args.seed}") + logger.info(f"Tag: {args.tag}") + logger.info(f"Visualise: {args.visualise}") + logger.info(f"Visualise delay: {args.visualise_delay}") + logger.info(f"CSV-file: {args.csvfile}") + + # User / Host + logger.info(f"{os.environ.get('HOSTNAME')}") + logger.info(f"{os.environ.get('USER')}") + + try: + import git # GitPython + repo = git.Repo(os.getcwd() + "/..") + logger.info(f"Git commit: {repo.head.commit}") + logger.info(f"Git is_dirty: {repo.is_dirty()}") + except Exception as e: + logger.error(e) + 
logger.error("Git information retrieval unsuccessful") + + # Ignore git parts from main.py + main_run_scenario(args.seed, args.green, args.red, + args.visualise, args.visualise_delay, folder_name, logger, csv_logger) + + # If program fails catch exception + except Exception as e: + logger.error(f"Exception occured:") + logger.error(f"{e}") + raise + + # Finished + logger.info("Application finished") + sys.exit(0) + + +if __name__ == '__main__': + main() diff --git a/requirements.txt b/requirements.txt index de75544f44d02067d997f65d99cdd085d22bb23f..9d26c1013d9871df04e0aa3ec0ed86c1abc0e9ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ gym jax jaxlib +dm_env