From dda71db978347f7d65349576b5e8bc2ce020f639 Mon Sep 17 00:00:00 2001
From: Emil Karlsson <emil.j.karlsson@gmail.com>
Date: Mon, 7 Mar 2022 14:25:33 +0100
Subject: [PATCH] RL WIP

---
 validation-env/action_selectors.py |  82 +++++++++++++++++++++++
 validation-env/behaviour.py        | 101 +++++++++++++++++++++++++++++
 validation-env/rnn_agent.py        |  31 +++++++++
 validation-env/rnn_ns_agent.py     |  50 ++++++++++++++
 4 files changed, 264 insertions(+)
 create mode 100644 validation-env/action_selectors.py
 create mode 100644 validation-env/rnn_agent.py
 create mode 100644 validation-env/rnn_ns_agent.py

diff --git a/validation-env/action_selectors.py b/validation-env/action_selectors.py
new file mode 100644
index 00000000..d9a02d26
--- /dev/null
+++ b/validation-env/action_selectors.py
@@ -0,0 +1,82 @@
+import torch as th
+from torch.distributions import Categorical
+# NOTE: MultinomialActionSelector below still references DecayThenFlatSchedule,
+# so it cannot be instantiated while this import stays commented out.
+#from .epsilon_schedules import DecayThenFlatSchedule
+REGISTRY = {}
+
+
+class MultinomialActionSelector():
+
+    def __init__(self, args):
+        self.args = args
+
+        self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time,
+                                              decay="linear")
+        self.epsilon = self.schedule.eval(0)
+        self.test_greedy = getattr(args, "test_greedy", True)
+
+    def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False):
+        masked_policies = agent_inputs.clone()
+        masked_policies[avail_actions == 0.0] = 0.0
+
+        self.epsilon = self.schedule.eval(t_env)
+
+        if test_mode and self.test_greedy:
+            picked_actions = masked_policies.max(dim=2)[1]
+        else:
+            picked_actions = Categorical(masked_policies).sample().long()
+
+        return picked_actions
+
+
+REGISTRY["multinomial"] = MultinomialActionSelector
+
+
+class EpsilonGreedyActionSelector():
+
+    def __init__(self, args=None):
+        pass
+        #self.args = args
+
+        #self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time,
+        #                                      decay="linear")
+        #self.epsilon = self.schedule.eval(0)
+
+    def _select_action(self, agent_inputs, avail_actions, t_env, test_mode=False):
+        # Disabled epsilon-greedy path: it relies on self.schedule and self.args,
+        # which are commented out in __init__ above.
+
+        # Assuming agent_inputs is a batch of Q-values per agent (batch, agents, actions)
+        self.epsilon = self.schedule.eval(t_env)
+
+        if test_mode:
+            # Greedy action selection only
+            self.epsilon = self.args.evaluation_epsilon
+
+        # mask actions that are excluded from selection
+        masked_q_values = agent_inputs.clone()
+        masked_q_values[avail_actions == 0.0] = -float("inf")  # should never be selected!
+
+        random_numbers = th.rand_like(agent_inputs[:, :, 0])
+        pick_random = (random_numbers < self.epsilon).long()
+        random_actions = Categorical(avail_actions.float()).sample().long()
+
+        picked_actions = pick_random * random_actions + (1 - pick_random) * masked_q_values.max(dim=2)[1]
+        return picked_actions
+
+    def select_action(self, agent_inputs):
+        # WIP: greedy selection only; epsilon and availability masking are ignored.
+        return agent_inputs.max(dim=2)[1]
+
+
+REGISTRY["epsilon_greedy"] = EpsilonGreedyActionSelector
+
+
+class SoftPoliciesSelector():
+
+    def __init__(self, args):
+        self.args = args
+
+    def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False):
+        m = Categorical(agent_inputs)
+        picked_actions = m.sample().long()
+        return picked_actions
+
+
+REGISTRY["soft_policies"] = SoftPoliciesSelector
\ No newline at end of file
diff --git a/validation-env/behaviour.py b/validation-env/behaviour.py
index d6324aac..d95e77f6 100644
--- a/validation-env/behaviour.py
+++ b/validation-env/behaviour.py
@@ -6,6 +6,9 @@ Contains definitions of behaviours that can be evaluated
 """
 from typing import Optional
 
+import numpy as np
+import torch as th
+
 from data import ObservedState, RadarAction, AircraftObj
 from mip_behaviour import empty_grid, state_to_actions_with_ip
 
@@ -102,6 +105,102 @@ class IpTwoRadar(Behaviour):
         return self.actions
 
 
+def convert_to_np(aircraft_name: str, state: ObservedState) -> th.Tensor:
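+    """Build the observation tensor for `aircraft_name` from `state`.
+
+    Despite the name, this returns a torch tensor of shape (4, 24): the same
+    24-dim feature vector repeated on each of the 4 rows (one row per agent of
+    RNNNSAgent). Columns 0-3 hold own velocity and position, columns 4-11 the
+    x/y positions of own-side aircraft, columns 12-19 the x/y positions of up
+    to four enemies from `a_obj.sa`, and the remaining columns stay zero.
+    """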
+    a_obj = None
+    friends = list()
+    for obj in state.own:
+        if obj.aircraft.name == aircraft_name:
+            a_obj = obj
+        friends.append(obj.aircraft.position.x)
+        friends.append(obj.aircraft.position.y)
+    assert a_obj is not None
+    array = th.zeros(4, 24)
+    array[:, 0] = a_obj.aircraft.velocity_x
+    array[:, 1] = a_obj.aircraft.velocity_y
+    array[:, 2] = a_obj.aircraft.position.x
+    array[:, 3] = a_obj.aircraft.position.y
+    for i in range(8):
+        if i < len(friends):
+            array[:, 4 + i] = friends[i]
+        else:
+            array[:, 4 + i] = 0
+    for enemies, enemy in enumerate(a_obj.sa):
+        if enemies >= 4:
+            break
+        array[:, 12 + 2 * enemies] = enemy.aircraft.position.x
+        array[:, 12 + 2 * enemies + 1] = enemy.aircraft.position.y
+
+    return array
+
+    #np.concatenate([agent.state.p_vel] + [agent.state.p_pos] + entity_pos + other_pos + comm)
+    #return th.from_numpy(np.concatenate([[np.array([0, 0,
+    #                0, 0,
+    #                0, 0,  # Enemy
+    #                0, 0,  # Enemy
+    #                0, 0,  # Enemy
+    #                0, 0,  # Enemy
+    #                0, 0,  # Friend
+    #                0, 0,  # Friend
+    #                0, 0,  # Friend  ( 18)
+    #                0, 0,  # Friend action
+    #                0, 0,  # Friend action
+    #                0, 0])]]))  # Friend action
+
+
+def nr_to_action(a_obj: AircraftObj, number: int) -> Optional[RadarAction]:
+    action = None
+    if number == 0:
+        action = RadarAction(a_obj.aircraft.name, (-55, 55))
+    elif number == 1:
+        action = RadarAction(a_obj.aircraft.name, (-22.5, 22.5))
+    elif number == 2:
+        action = RadarAction(a_obj.aircraft.name, (10, 55))
+    elif number == 3:
+        action = RadarAction(a_obj.aircraft.name, (-55, 10))
+    return action
+
+
+class RlRadar(Behaviour):
+    agent = None
+
+    def __init__(self, th_path: str = "/home/elanclar/work/epymarl_bak/results/models/iql_ns_seed4_mpe:SimpleSpread-v0_2022-01-18 12:54:20.877694/2500025/agent.th"):
+        from rnn_ns_agent import RNNNSAgent
+        # WIP: the default checkpoint path above is currently overridden here.
+        th_path = "/home/elanclar/work/epymarl_bak/src/results/models/iql_ns_seed11270277_mpe:SimpleSpread-v0_2022-02-01 17:44:49.518840/1025/agent.th"
+        self.agent = RNNNSAgent()
+        #print(th.load(th_path))
+        self.agent.load_state_dict(th.load(th_path))
+        self.agent.hidden_state = self.agent.init_hidden()
+        #print(self.agent)
+        self.agent.eval()
+        #from action_selectors import EpsilonGreedyActionSelector
+        #self.action_selector = EpsilonGreedyActionSelector()
+
+    def act(self, state: ObservedState) -> tuple[RadarAction, ...]:
+        actions = list()
+        for own in state.own:
+            data = convert_to_np(own.aircraft.name, state)
+            # Greedy action index for each row of the (4, n_actions) Q output.
+            action_tensor = self.agent.agents[0](data.float(), self.agent.hidden_state)[0].max(dim=1)[1]
+            # WIP sanity checks: re-run the forward pass and assert that the
+            # argmax indices along dims 0, -1 and -2 never contain 1.
+            assert 1 not in self.agent.agents[0](data.float(), self.agent.hidden_state)[0].max(dim=0)[1]
+            assert 1 not in self.agent.agents[0](data.float(), self.agent.hidden_state)[0].max(dim=-1)[1]
+            assert 1 not in self.agent.agents[0](data.float(), self.agent.hidden_state)[0].max(dim=-2)[1]
+            if action_tensor[0] == 1:
+                nr = 0
+            elif action_tensor[1] == 1:
+                nr = 1
+            elif action_tensor[2] == 1:
+                nr = 2
+            elif action_tensor[3] == 1:
+                nr = 3
+            else:
+                nr = None
+            if nr is not None:
+                action = nr_to_action(own, nr)
+                actions.append(action)
+        return tuple(actions)
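+
+# Example usage (illustrative sketch; `observed_state` stands in for a real
+# ObservedState instance):
+#   behaviour = get_behaviour("rl")
+#   radar_actions = behaviour.act(observed_state)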
+
+
 class StrmAllRadar(Behaviour):
     def act(self, state: ObservedState) -> tuple[RadarAction, ...]:
         no_obj = len(state.own)
@@ -143,6 +242,8 @@ def get_behaviour(behaviour_name: str) -> Behaviour:
         return StrmTwoHalfRadar()
     elif behaviour_name == "strmone":
         return StrmOneRadar()
+    elif behaviour_name == "rl":
+        return RlRadar()
     elif behaviour_name == "ipone":
         return IpOneRadar()
     elif behaviour_name == "iptwo":
diff --git a/validation-env/rnn_agent.py b/validation-env/rnn_agent.py
new file mode 100644
index 00000000..df739279
--- /dev/null
+++ b/validation-env/rnn_agent.py
@@ -0,0 +1,31 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class RNNAgent(nn.Module):
+    def __init__(self, input_shape, args):
+        super(RNNAgent, self).__init__()
+        self.args = args
+
+        self.fc1 = nn.Linear(input_shape, args.hidden_dim)
+        if self.args.use_rnn:
+            self.rnn = nn.GRUCell(args.hidden_dim, args.hidden_dim)
+        else:
+            self.rnn = nn.Linear(args.hidden_dim, args.hidden_dim)
+        self.fc2 = nn.Linear(args.hidden_dim, args.n_actions)
+
+    def init_hidden(self):
+        # make hidden states on same device as model
+        return self.fc1.weight.new(1, self.args.hidden_dim).zero_()
+
+    def forward(self, inputs, hidden_state):
+        x = F.relu(self.fc1(inputs))
+        h_in = hidden_state.reshape(-1, self.args.hidden_dim)
+        if self.args.use_rnn:
+            h = self.rnn(x, h_in)
+        else:
+            h = F.relu(self.rnn(x))
+        q = self.fc2(h)
+        return q, h
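+
+
+# Example usage (illustrative sketch; `args` must provide hidden_dim, use_rnn
+# and n_actions, e.g. the DefaultArgs dataclass in rnn_ns_agent.py):
+#   agent = RNNAgent(input_shape=24, args=args)
+#   h = agent.init_hidden()   # (1, hidden_dim)
+#   q, h = agent(obs, h)      # obs: (1, input_shape) -> q: (1, n_actions)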
diff --git a/validation-env/rnn_ns_agent.py b/validation-env/rnn_ns_agent.py
new file mode 100644
index 00000000..c38e0238
--- /dev/null
+++ b/validation-env/rnn_ns_agent.py
@@ -0,0 +1,50 @@
+from dataclasses import dataclass
+
+import torch.nn as nn
+from rnn_agent import RNNAgent
+import torch as th
+
+
+@dataclass
+class DefaultArgs:
+    hidden_dim: int = 128
+    n_actions: int = 2
+    use_rnn: bool = True
+    n_agents: int = 4
+
+
+INPUT_SHAPE = 24
+
+
+class RNNNSAgent(nn.Module):
+    def __init__(self, input_shape=INPUT_SHAPE, args=DefaultArgs()):
+        super(RNNNSAgent, self).__init__()
+        self.args = args
+        self.n_agents = args.n_agents
+        self.input_shape = input_shape
+        self.agents = th.nn.ModuleList([RNNAgent(input_shape, args) for _ in range(self.n_agents)])
+
+    def init_hidden(self):
+        # make hidden states on same device as model
+        return th.cat([a.init_hidden() for a in self.agents])
+
+    def forward(self, inputs, hidden_state):
+        hiddens = []
+        qs = []
+        if inputs.size(0) == self.n_agents:
+            for i in range(self.n_agents):
+                q, h = self.agents[i](inputs[i].unsqueeze(0), hidden_state[:, i])
+                hiddens.append(h)
+                qs.append(q)
+            return th.cat(qs), th.cat(hiddens).unsqueeze(0)
+        else:
+            for i in range(self.n_agents):
+                inputs = inputs.view(-1, self.n_agents, self.input_shape)
+                q, h = self.agents[i](inputs[:, i], hidden_state[:, i])
+                hiddens.append(h.unsqueeze(1))
+                qs.append(q.unsqueeze(1))
+            return th.cat(qs, dim=-1).view(-1, q.size(-1)), th.cat(hiddens, dim=1)
+
+    def cuda(self, device="cuda:0"):
+        for a in self.agents:
+            a.cuda(device=device)
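+
+
+# Example usage (illustrative sketch with the defaults above: 4 agents,
+# 24-dim observations, 128 hidden units):
+#   agent = RNNNSAgent()
+#   hidden = agent.init_hidden().unsqueeze(0)   # (1, n_agents, hidden_dim)
+#   obs = th.zeros(4, 24)                       # one observation row per agent
+#   q, hidden = agent(obs, hidden)              # q: (n_agents, n_actions)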
-- 
GitLab