From dda71db978347f7d65349576b5e8bc2ce020f639 Mon Sep 17 00:00:00 2001
From: Emil Karlsson <emil.j.karlsson@gmail.com>
Date: Mon, 7 Mar 2022 14:25:33 +0100
Subject: [PATCH] RL WIP

---
 validation-env/action_selectors.py |  82 +++++++++++++++++++++++
 validation-env/behaviour.py        | 101 +++++++++++++++++++++++++++++
 validation-env/rnn_agent.py        |  31 +++++++++
 validation-env/rnn_ns_agent.py     |  50 ++++++++++++++
 4 files changed, 264 insertions(+)
 create mode 100644 validation-env/action_selectors.py
 create mode 100644 validation-env/rnn_agent.py
 create mode 100644 validation-env/rnn_ns_agent.py

diff --git a/validation-env/action_selectors.py b/validation-env/action_selectors.py
new file mode 100644
index 00000000..d9a02d26
--- /dev/null
+++ b/validation-env/action_selectors.py
@@ -0,0 +1,82 @@
+import torch as th
+from torch.distributions import Categorical
+#from .epsilon_schedules import DecayThenFlatSchedule
+REGISTRY = {}
+
+
+class MultinomialActionSelector():
+
+    def __init__(self, args):
+        self.args = args
+
+        self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time,
+                                              decay="linear")
+        self.epsilon = self.schedule.eval(0)
+        self.test_greedy = getattr(args, "test_greedy", True)
+
+    def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False):
+        masked_policies = agent_inputs.clone()
+        masked_policies[avail_actions == 0.0] = 0.0
+
+        self.epsilon = self.schedule.eval(t_env)
+
+        if test_mode and self.test_greedy:
+            picked_actions = masked_policies.max(dim=2)[1]
+        else:
+            picked_actions = Categorical(masked_policies).sample().long()
+
+        return picked_actions
+
+
+REGISTRY["multinomial"] = MultinomialActionSelector
+
+
+class EpsilonGreedyActionSelector():
+
+    def __init__(self, args=None):
+        pass
+        #self.args = args
+
+        #self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time,
+        #                                      decay="linear")
+        #self.epsilon = self.schedule.eval(0)
+
+    def _select_action(self, agent_inputs, avail_actions, t_env, test_mode=False):
+
+        # Assuming agent_inputs is a batch of Q-values for each agent
+        self.epsilon = self.schedule.eval(t_env)
+
+        if test_mode:
+            # Greedy action selection only
+            self.epsilon = self.args.evaluation_epsilon
+
+        # mask actions that are excluded from selection
+        masked_q_values = agent_inputs.clone()
+        masked_q_values[avail_actions == 0.0] = -float("inf")  # should never be selected!
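+        # With unavailable actions masked out, pick a uniformly random available action with probability epsilon, otherwise the argmax over the masked Q-values.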
+
+        random_numbers = th.rand_like(agent_inputs[:, :, 0])
+        pick_random = (random_numbers < self.epsilon).long()
+        random_actions = Categorical(avail_actions.float()).sample().long()
+
+        picked_actions = pick_random * random_actions + (1 - pick_random) * masked_q_values.max(dim=2)[1]
+        return picked_actions
+
+    def select_action(self, agent_inputs):
+        return agent_inputs.max(dim=2)[1]
+
+
+REGISTRY["epsilon_greedy"] = EpsilonGreedyActionSelector
+
+
+class SoftPoliciesSelector():
+
+    def __init__(self, args):
+        self.args = args
+
+    def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False):
+        m = Categorical(agent_inputs)
+        picked_actions = m.sample().long()
+        return picked_actions
+
+
+REGISTRY["soft_policies"] = SoftPoliciesSelector
\ No newline at end of file
diff --git a/validation-env/behaviour.py b/validation-env/behaviour.py
index d6324aac..d95e77f6 100644
--- a/validation-env/behaviour.py
+++ b/validation-env/behaviour.py
@@ -6,6 +6,9 @@ Contains definitions of behaviours that can be evaluated
 """
 from typing import Optional
 
+import numpy as np
+import torch as th
+
 from data import ObservedState, RadarAction, AircraftObj
 from mip_behaviour import empty_grid, state_to_actions_with_ip
 
@@ -102,6 +105,102 @@ class IpTwoRadar(Behaviour):
         return self.actions
 
 
+def convert_to_np(aircraft_name: str, state: ObservedState) -> th.Tensor:
+    a_obj = None
+    friends = list()
+    for obj in state.own:
+        if obj.aircraft.name == aircraft_name:
+            a_obj = obj
+        friends.append(obj.aircraft.position.x)
+        friends.append(obj.aircraft.position.y)
+    assert a_obj is not None
+    array = th.zeros(4, 24)
+    array[:, 0] = a_obj.aircraft.velocity_x
+    array[:, 1] = a_obj.aircraft.velocity_y
+    array[:, 2] = a_obj.aircraft.position.x
+    array[:, 3] = a_obj.aircraft.position.y
+    for i in range(8):
+        if i < len(friends):
+            array[:, 4 + i] = friends[i]
+        else:
+            array[:, 4 + i] = 0
+    enemies = 0
+    for enemy in a_obj.sa:
+        if enemies >= 4:
+            break
+        array[:, 12 + 2 * enemies] = enemy.aircraft.position.x
+        array[:, 12 + 2 * enemies + 1] = enemy.aircraft.position.y
+        enemies += 1
+
+    return array
+
+    #np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + comm)
+    #return th.from_numpy(np.concatenate([[np.array([0, 0,
+    #                                                0, 0,
+    #                                                0, 0,  # Enemy
+    #                                                0, 0,  # Enemy
+    #                                                0, 0,  # Enemy
+    #                                                0, 0,  # Enemy
+    #                                                0, 0,  # Friend
+    #                                                0, 0,  # Friend
+    #                                                0, 0,  # Friend ( 18)
+    #                                                0, 0,  # Friend action
+    #                                                0, 0,  # Friend action
+    #                                                0, 0])]]))  # Friend action
+
+
+def nr_to_action(a_obj: AircraftObj, number: int) -> Optional[RadarAction]:
+    action = None
+    if number == 0:
+        action = RadarAction(a_obj.aircraft.name, (-55, 55))
+    elif number == 1:
+        action = RadarAction(a_obj.aircraft.name, (-22.5, 22.5))
+    elif number == 2:
+        action = RadarAction(a_obj.aircraft.name, (10, 55))
+    elif number == 3:
+        action = RadarAction(a_obj.aircraft.name, (-55, 10))
+    return action
+
+
+class RlRadar(Behaviour):
+    agent = None
+    def __init__(self, th_path: str = "/home/elanclar/work/epymarl_bak/results/models/iql_ns_seed4_mpe:SimpleSpread-v0_2022-01-18 12:54:20.877694/2500025/agent.th"):
+        import torch as th
+        th_path = "/home/elanclar/work/epymarl_bak/src/results/models/iql_ns_seed11270277_mpe:SimpleSpread-v0_2022-02-01 17:44:49.518840/1025/agent.th"
+        from rnn_ns_agent import RNNNSAgent
+        self.agent = RNNNSAgent()
+        #print(th.load(th_path))
+        self.agent.load_state_dict(th.load(th_path))
+        self.agent.hidden_state = self.agent.init_hidden()
+        #print(self.agent)
+        self.agent.eval()
+        #from action_selectors import EpsilonGreedyActionSelector
+        #self.action_selector = EpsilonGreedyActionSelector()
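+        # No action selector is wired up yet; act() below simply takes the argmax over the loaded agent's Q-values.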
+
+    def act(self, state: ObservedState) -> tuple[RadarAction, ...]:
+        actions = list()
+        for own in state.own:
+            data = convert_to_np(own.aircraft.name, state)
+            action_tensor = self.agent.agents[0](data.float(), self.agent.hidden_state)[0].max(dim=1)[1]
+            if action_tensor[0] == 1:
+                nr = 0
+            elif action_tensor[1] == 1:
+                nr = 1
+            elif action_tensor[2] == 1:
+                nr = 2
+            elif action_tensor[3] == 1:
+                nr = 3
+            else:
+                nr = None
+            if nr is not None:
+                action = nr_to_action(own, nr)
+                actions.append(action)
+        return tuple(actions)
+
+
 class StrmAllRadar(Behaviour):
     def act(self, state: ObservedState) -> tuple[RadarAction, ...]:
         no_obj = len(state.own)
@@ -143,6 +242,8 @@ def get_behaviour(behaviour_name: str) -> Behaviour:
         return StrmTwoHalfRadar()
     elif behaviour_name == "strmone":
        return StrmOneRadar()
+    elif behaviour_name == "rl":
+        return RlRadar()
     elif behaviour_name == "ipone":
         return IpOneRadar()
     elif behaviour_name == "iptwo":
diff --git a/validation-env/rnn_agent.py b/validation-env/rnn_agent.py
new file mode 100644
index 00000000..df739279
--- /dev/null
+++ b/validation-env/rnn_agent.py
@@ -0,0 +1,31 @@
+from dataclasses import dataclass
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class RNNAgent(nn.Module):
+    def __init__(self, input_shape, args):
+        super(RNNAgent, self).__init__()
+        self.args = args
+
+        self.fc1 = nn.Linear(input_shape, args.hidden_dim)
+        if self.args.use_rnn:
+            self.rnn = nn.GRUCell(args.hidden_dim, args.hidden_dim)
+        else:
+            self.rnn = nn.Linear(args.hidden_dim, args.hidden_dim)
+        self.fc2 = nn.Linear(args.hidden_dim, args.n_actions)
+
+    def init_hidden(self):
+        # make hidden states on same device as model
+        return self.fc1.weight.new(1, self.args.hidden_dim).zero_()
+
+    def forward(self, inputs, hidden_state):
+        x = F.relu(self.fc1(inputs))
+        h_in = hidden_state.reshape(-1, self.args.hidden_dim)
+        if self.args.use_rnn:
+            h = self.rnn(x, h_in)
+        else:
+            h = F.relu(self.rnn(x))
+        q = self.fc2(h)
+        return q, h
diff --git a/validation-env/rnn_ns_agent.py b/validation-env/rnn_ns_agent.py
new file mode 100644
index 00000000..c38e0238
--- /dev/null
+++ b/validation-env/rnn_ns_agent.py
@@ -0,0 +1,50 @@
+from dataclasses import dataclass
+
+import torch.nn as nn
+from rnn_agent import RNNAgent
+import torch as th
+
+
+@dataclass
+class DefaultArgs:
+    hidden_dim: int = 128
+    n_actions: int = 2
+    use_rnn: bool = True
+    n_agents: int = 4
+
+
+INPUT_SHAPE = 24
+
+
+class RNNNSAgent(nn.Module):
+    def __init__(self, input_shape=INPUT_SHAPE, args=DefaultArgs()):
+        super(RNNNSAgent, self).__init__()
+        self.args = args
+        self.n_agents = args.n_agents
+        self.input_shape = input_shape
+        self.agents = th.nn.ModuleList([RNNAgent(input_shape, args) for _ in range(self.n_agents)])
+
+    def init_hidden(self):
+        # make hidden states on same device as model
+        return th.cat([a.init_hidden() for a in self.agents])
+
+    def forward(self, inputs, hidden_state):
+        hiddens = []
+        qs = []
+        if inputs.size(0) == self.n_agents:
+            for i in range(self.n_agents):
+                q, h = self.agents[i](inputs[i].unsqueeze(0), hidden_state[:, i])
+                hiddens.append(h)
+                qs.append(q)
+            return th.cat(qs), th.cat(hiddens).unsqueeze(0)
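+        # Batched case: inputs arrive flattened, so reshape to (batch, n_agents, input_shape) and give each agent its own slice.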
+        else:
+            for i in range(self.n_agents):
+                inputs = inputs.view(-1, self.n_agents, self.input_shape)
+                q, h = self.agents[i](inputs[:, i], hidden_state[:, i])
+                hiddens.append(h.unsqueeze(1))
+                qs.append(q.unsqueeze(1))
+            return th.cat(qs, dim=-1).view(-1, q.size(-1)), th.cat(hiddens, dim=1)
+
+    def cuda(self, device="cuda:0"):
+        for a in self.agents:
+            a.cuda(device=device)
-- 
GitLab