# networks.py
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
ACTOR_FILE_PATH = "weights/actor_weights.h5"
CRITIC_FILE_PATH = "weights/critic_weights.h5"
TARGET_ACTOR_FILE_PATH = "weights/target_actor_weights.h5"
TARGET_CRITIC_FILE_PATH = "weights/target_critic_weights.h5"
class ActorCritic:
    """ Actor and Critic networks along with related functions """

    def __init__(self, state_shape, gamma, tau):
        self.state_shape = state_shape
        self.tau = tau      # soft-update rate for the target networks
        self.gamma = gamma  # discount factor
        self.actor = self.build_actor()
        self.critic = self.build_critic()
        self.actor_optimizer = tf.keras.optimizers.Adam(0.001)
        self.critic_optimizer = tf.keras.optimizers.Adam(0.001)
        # Target networks start as exact copies of the online networks
        self.target_actor = self.build_actor()
        self.target_critic = self.build_critic()
        self.target_actor.set_weights(self.actor.get_weights())
        self.target_critic.set_weights(self.critic.get_weights())
    def build_actor(self):
        """ Creates the actor network and returns it """
        state_input = layers.Input(shape=self.state_shape)
        fc1 = layers.Dense(512, activation="relu")(state_input)
        fc2 = layers.Dense(512, activation="relu")(fc1)
        fc3 = layers.Dense(512, activation="relu")(fc2)
        # Linear velocity in [0, 1], angular velocity in [-1, 1]
        lin_vel = layers.Dense(1, activation="sigmoid")(fc3)
        ang_vel = layers.Dense(1, activation="tanh")(fc3)
        output = layers.Concatenate()([lin_vel, ang_vel])
        model = tf.keras.Model(state_input, output)
        return model
    def build_critic(self):
        """ Creates the critic network and returns it """
        state_input = layers.Input(shape=self.state_shape)
        fc1 = layers.Dense(512, activation="relu")(state_input)
        action_input = layers.Input(shape=(2,))
        action_output = layers.Dense(512, activation="relu")(action_input)
        # Merge the state and action branches before estimating Q(s, a)
        concat = layers.Concatenate()([fc1, action_output])
        fc2 = layers.Dense(512, activation="relu")(concat)
        fc3 = layers.Dense(512, activation="relu")(fc2)
        output = layers.Dense(1, activation="linear")(fc3)
        model = tf.keras.Model([state_input, action_input], output)
        return model
    # Update critic and actor networks
    # Source: https://keras.io/examples/rl/ddpg_pendulum/
    def update_networks(self, states, actions, rewards, new_states, dones):
        """ Performs one DDPG update of the critic and actor on a batch of transitions """
        # Update Critic
        with tf.GradientTape() as tape:
            # Forward propagate target actor to predict an action for each state in the batch
            target_actions = self.target_actor(new_states, training=True)
            # Bellman target: r + gamma * Q'(s', mu'(s')), masked out for terminal transitions
            y = rewards + self.gamma * self.target_critic([new_states, target_actions], training=True) * (1 - dones)
            # Get critic value and compute loss between critic value and target
            critic_value = self.critic([states, actions], training=True)
            critic_loss = tf.math.reduce_mean(tf.math.square(y - critic_value))
        critic_grad = tape.gradient(critic_loss, self.critic.trainable_variables)
        self.critic_optimizer.apply_gradients(zip(critic_grad, self.critic.trainable_variables))
        # Update Actor
        with tf.GradientTape() as tape:
            # Gradient ascent on Q(s, mu(s)): minimise the negative mean critic value
            _actions = self.actor(states, training=True)
            _critic_value = self.critic([states, _actions], training=True)
            actor_loss = -tf.math.reduce_mean(_critic_value)
        actor_grad = tape.gradient(actor_loss, self.actor.trainable_variables)
        self.actor_optimizer.apply_gradients(zip(actor_grad, self.actor.trainable_variables))
    # Update target networks
    # Source: https://github.com/philtabor/Youtube-Code-Repository/tree/master/ReinforcementLearning/PolicyGradient/DDPG/tensorflow2/pendulum
    def update_target_networks(self):
        """ Soft-updates the target actor and critic: theta_target = tau * theta + (1 - tau) * theta_target """
        weights = []
        targets = self.target_actor.weights
        for i, weight in enumerate(self.actor.weights):
            weights.append(weight * self.tau + targets[i] * (1.0 - self.tau))
        self.target_actor.set_weights(weights)

        weights = []
        targets = self.target_critic.weights
        for i, weight in enumerate(self.critic.weights):
            weights.append(weight * self.tau + targets[i] * (1.0 - self.tau))
        self.target_critic.set_weights(weights)
    def save_weights_to_file(self):
        """ Saves actor, critic and target network weights to disk """
        self.actor.save_weights(ACTOR_FILE_PATH)
        self.critic.save_weights(CRITIC_FILE_PATH)
        self.target_actor.save_weights(TARGET_ACTOR_FILE_PATH)
        self.target_critic.save_weights(TARGET_CRITIC_FILE_PATH)
        print("Weights saved.")

    def load_weights_from_file(self):
        """ Loads actor, critic and target network weights from disk """
        self.actor.load_weights(ACTOR_FILE_PATH)
        self.critic.load_weights(CRITIC_FILE_PATH)
        self.target_actor.load_weights(TARGET_ACTOR_FILE_PATH)
        self.target_critic.load_weights(TARGET_CRITIC_FILE_PATH)
        print("Weights loaded.")
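# A minimal usage sketch, not part of the original module: it assumes a hypothetical
# 14-dimensional state vector, a batch size of 64, and illustrative hyperparameters
# (gamma=0.99, tau=0.005) purely to show how a training loop might drive this class
# with a batch sampled from a replay buffer.
if __name__ == "__main__":
    STATE_SHAPE = (14,)  # hypothetical state dimensionality for illustration only
    BATCH_SIZE = 64

    ac = ActorCritic(state_shape=STATE_SHAPE, gamma=0.99, tau=0.005)

    # Fake batch of transitions standing in for a replay-buffer sample
    states = tf.random.uniform((BATCH_SIZE,) + STATE_SHAPE)
    actions = ac.actor(states)  # shape (batch, 2): [lin_vel, ang_vel]
    rewards = tf.random.uniform((BATCH_SIZE, 1))
    new_states = tf.random.uniform((BATCH_SIZE,) + STATE_SHAPE)
    dones = tf.zeros((BATCH_SIZE, 1))

    ac.update_networks(states, actions, rewards, new_states, dones)
    ac.update_target_networks()
    print("One DDPG update completed.")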