From 849771dbc18996f038cb4841cac7a02acfe392e3 Mon Sep 17 00:00:00 2001 From: Akram Date: Thu, 8 Aug 2024 11:54:22 +0200 Subject: [PATCH] Add comments and fix import errors --- webots/controllers/RL_Supervisor/agent.py | 14 +++++++++++--- webots/controllers/RL_Supervisor/plotting.py | 4 ++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/webots/controllers/RL_Supervisor/agent.py b/webots/controllers/RL_Supervisor/agent.py index b3cdc92..2d071d3 100644 --- a/webots/controllers/RL_Supervisor/agent.py +++ b/webots/controllers/RL_Supervisor/agent.py @@ -160,7 +160,8 @@ def predict_action(self, state): probs = self.__neural_network.actor_network(state) if self.train_mode is True: - # Create a normal distribution with the calculated probabilities and the standard deviation + # Create a normal distribution with the calculated probabilities + # and the standard deviation dist = tfp.distributions.Normal(probs, self.__std_dev) # Sampling an action from the normal distribution @@ -425,10 +426,17 @@ def learn(self, states, actions, old_probs, values, rewards, dones): # optimize Critic Network weights with tf.GradientTape() as tape: + # The critic value represents the expected return from state 𝑠𝑡. + # It provides an estimate of how good it is to be in a given state. 
critic_value = self.__neural_network.critic_network(states) - returns = advantages + values + + # the total discounted reward accumulated from time step 𝑡 + estimate_returns = advantages + values + # Generate loss - critic_loss = tf.math.reduce_mean(tf.math.pow(returns - critic_value, 2)) + critic_loss = tf.math.reduce_mean( + tf.math.pow(estimate_returns - critic_value, 2) + ) # calculate gradient critic_params = self.__neural_network.critic_network.trainable_variables diff --git a/webots/controllers/RL_Supervisor/plotting.py b/webots/controllers/RL_Supervisor/plotting.py index a6afd7f..4b20bf6 100644 --- a/webots/controllers/RL_Supervisor/plotting.py +++ b/webots/controllers/RL_Supervisor/plotting.py @@ -1,8 +1,8 @@ """ Plotting script with Matplotlib """ # Imports -import matplotlib.pyplot as plt -import pandas as pd +import matplotlib.pyplot as plt # pylint: disable=import-error +import pandas as pd # pylint: disable=import-error # Define the path to the CSV file LOG_FILE = "logs/training_logs.csv"