Skip to content

Commit

Permalink
Add comments and fix import errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Akram authored and Akram committed Aug 8, 2024
1 parent 93d2b39 commit 849771d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
14 changes: 11 additions & 3 deletions webots/controllers/RL_Supervisor/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ def predict_action(self, state):
probs = self.__neural_network.actor_network(state)

if self.train_mode is True:
# Create a normal distribution with the calculated probabilities and the standard deviation
# Create a normal distribution with the calculated probabilities
# and the standard deviation
dist = tfp.distributions.Normal(probs, self.__std_dev)

# Sampling an action from the normal distribution
Expand Down Expand Up @@ -425,10 +426,17 @@ def learn(self, states, actions, old_probs, values, rewards, dones):
# optimize Critic Network weights
with tf.GradientTape() as tape:

# The critical value represents the expected return from state 𝑠𝑡.
# It provides an estimate of how good it is to be in a given state.
critic_value = self.__neural_network.critic_network(states)
returns = advantages + values

# the total discounted reward accumulated from time step 𝑡
estimate_returns = advantages + values

# Generate loss
critic_loss = tf.math.reduce_mean(tf.math.pow(returns - critic_value, 2))
critic_loss = tf.math.reduce_mean(
tf.math.pow(estimate_returns - critic_value, 2)
)

# calculate gradient
critic_params = self.__neural_network.critic_network.trainable_variables
Expand Down
4 changes: 2 additions & 2 deletions webots/controllers/RL_Supervisor/plotting.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
""" Plotting script with Matplotlib """

# Imports
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib.pyplot as plt # pylint: disable=import-error
import pandas as pd # pylint: disable=import-error

# Define the path to the CSV file
LOG_FILE = "logs/training_logs.csv"
Expand Down

0 comments on commit 849771d

Please sign in to comment.