networks.py
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np

ACTOR_FILE_PATH = "weights/actor_weights.h5"
CRITIC_FILE_PATH = "weights/critic_weights.h5"

TARGET_ACTOR_FILE_PATH = "weights/target_actor_weights.h5"
TARGET_CRITIC_FILE_PATH = "weights/target_critic_weights.h5"


class ActorCritic:
    """ Actor and Critic networks along with related functions """

    def __init__(self, state_shape, gamma, tau):
        self.state_shape = state_shape
        self.tau = tau
        self.gamma = gamma

        # Main actor/critic networks and their optimizers
        self.actor = self.build_actor()
        self.critic = self.build_critic()

        self.actor_optimizer = tf.keras.optimizers.Adam(0.001)
        self.critic_optimizer = tf.keras.optimizers.Adam(0.001)

        # Target networks start as exact copies of the main networks
        self.target_actor = self.build_actor()
        self.target_critic = self.build_critic()

        self.target_actor.set_weights(self.actor.get_weights())
        self.target_critic.set_weights(self.critic.get_weights())


    def build_actor(self):
        """ Creates the actor network and returns it """

        state_input = layers.Input(shape=self.state_shape)
        fc1 = layers.Dense(512, activation="relu")(state_input)
        fc2 = layers.Dense(512, activation="relu")(fc1)
        fc3 = layers.Dense(512, activation="relu")(fc2)

        # Linear velocity in [0, 1] (sigmoid), angular velocity in [-1, 1] (tanh)
        lin_vel = layers.Dense(1, activation="sigmoid")(fc3)
        ang_vel = layers.Dense(1, activation="tanh")(fc3)

        output = layers.Concatenate()([lin_vel, ang_vel])
        model = tf.keras.Model(state_input, output)
        return model


    def build_critic(self):
        """ Creates the critic network and returns it """

        state_input = layers.Input(shape=self.state_shape)
        fc1 = layers.Dense(512, activation="relu")(state_input)

        # Action input: [linear velocity, angular velocity]
        action_input = layers.Input(shape=(2,))
        action_output = layers.Dense(512, activation="relu")(action_input)

        # Merge the state and action branches before estimating the Q-value
        concat = layers.Concatenate()([fc1, action_output])
        fc2 = layers.Dense(512, activation="relu")(concat)
        fc3 = layers.Dense(512, activation="relu")(fc2)

        output = layers.Dense(1, activation="linear")(fc3)
        model = tf.keras.Model([state_input, action_input], output)
        return model


    # Update critic and actor networks
    # Source: https://keras.io/examples/rl/ddpg_pendulum/
    def update_networks(self, states, actions, rewards, new_states, dones):
        """ Performs one DDPG gradient step on the critic and the actor """

        # Update Critic
        with tf.GradientTape() as tape:

            # Forward propagate the target actor to predict an action for each new state in the batch
            target_actions = self.target_actor(new_states, training=True)

            # Bellman target: reward plus the discounted target-critic value of the next state-action pair
            y = rewards + self.gamma * self.target_critic([new_states, target_actions], training=True) * (1 - dones)

            # Get the critic value and compute the mean squared error against the target
            critic_value = self.critic([states, actions], training=True)
            critic_loss = tf.math.reduce_mean(tf.math.square(y - critic_value))

        critic_grad = tape.gradient(critic_loss, self.critic.trainable_variables)
        self.critic_optimizer.apply_gradients(zip(critic_grad, self.critic.trainable_variables))

        # Update Actor
        with tf.GradientTape() as tape:

            # Actor loss: maximise the critic's value of the actor's actions (minimise its negative)
            _actions = self.actor(states, training=True)
            _critic_value = self.critic([states, _actions], training=True)
            actor_loss = -tf.math.reduce_mean(_critic_value)

        actor_grad = tape.gradient(actor_loss, self.actor.trainable_variables)
        self.actor_optimizer.apply_gradients(zip(actor_grad, self.actor.trainable_variables))
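
    # Shape note: rewards and dones should arrive with shape (batch_size, 1) so
    # that the Bellman target above broadcasts element-wise against the
    # (batch_size, 1) critic output. A minimal conversion sketch, assuming the
    # replay buffer yields numpy arrays (variable names illustrative only):
    #
    #     rewards = tf.convert_to_tensor(reward_batch.reshape(-1, 1), dtype=tf.float32)
    #     dones = tf.convert_to_tensor(done_batch.reshape(-1, 1), dtype=tf.float32)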


    # Update target networks
    # Source: https://github.com/philtabor/Youtube-Code-Repository/tree/master/ReinforcementLearning/PolicyGradient/DDPG/tensorflow2/pendulum
    def update_target_networks(self):
        """ Soft-updates the target actor and critic: new_target = tau * online + (1 - tau) * old_target """

        weights = []
        targets = self.target_actor.weights
        for i, weight in enumerate(self.actor.weights):
            weights.append(weight * self.tau + targets[i] * (1.0 - self.tau))
        self.target_actor.set_weights(weights)

        weights = []
        targets = self.target_critic.weights
        for i, weight in enumerate(self.critic.weights):
            weights.append(weight * self.tau + targets[i] * (1.0 - self.tau))
        self.target_critic.set_weights(weights)
        

    def save_weights_to_file(self):
        """ Saves the actor, critic, and target network weights to disk """
        self.actor.save_weights(ACTOR_FILE_PATH)
        self.critic.save_weights(CRITIC_FILE_PATH)
        self.target_actor.save_weights(TARGET_ACTOR_FILE_PATH)
        self.target_critic.save_weights(TARGET_CRITIC_FILE_PATH)
        print("Weights saved.")


    def load_weights_from_file(self):
        """ Loads the actor, critic, and target network weights from disk """
        self.actor.load_weights(ACTOR_FILE_PATH)
        self.critic.load_weights(CRITIC_FILE_PATH)
        self.target_actor.load_weights(TARGET_ACTOR_FILE_PATH)
        self.target_critic.load_weights(TARGET_CRITIC_FILE_PATH)
        print("Weights loaded.")