Patrik #217 (Open)

wants to merge 57 commits into base: master
Changes from all commits (57 commits)
b3fd098
reset_last_layer in atari_lib.py
Mattia-Colbertaldo Nov 8, 2023
78a8670
Update atari_lib.py
Mattia-Colbertaldo Nov 8, 2023
222642e
ResetLastLayers call in run_experiment.py
Mattia-Colbertaldo Nov 8, 2023
0cc2a27
ResetLastLayers in dqn_agent.py
Mattia-Colbertaldo Nov 8, 2023
bc98e52
Update run_experiment.py
Mattia-Colbertaldo Nov 8, 2023
4295642
Update run_experiment.py
Mattia-Colbertaldo Nov 8, 2023
d587d56
Update run_experiment.py
Mattia-Colbertaldo Nov 8, 2023
e5634e4
Update run_experiment.py
Mattia-Colbertaldo Nov 8, 2023
4efe9d0
Update run_experiment.py
Mattia-Colbertaldo Nov 8, 2023
a2e5c17
Update run_experiment.py
Mattia-Colbertaldo Nov 8, 2023
b3f962c
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
a4ec827
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
50f2cfd
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
2e2cad5
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
d0f1bb5
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
e576019
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
bd8fd6a
Update run_experiment.py
Mattia-Colbertaldo Nov 8, 2023
dc11db3
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
dd1c694
Update dqn_agent.py
Mattia-Colbertaldo Nov 8, 2023
d5927b6
Update dqn_agent.py
Mattia-Colbertaldo Nov 8, 2023
5419474
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
cef3975
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
0d735ac
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
ff9b779
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
eba9aa2
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
ae778e1
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
63f46e3
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
43023e4
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
f649675
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
29a91fa
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
d80df95
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
bca846c
Update gym_lib.py
Mattia-Colbertaldo Nov 8, 2023
b6b3af5
Update gym_lib.py
patrikrac Nov 8, 2023
3f35f75
Update gym_lib.py
patrikrac Nov 8, 2023
c101f64
Added a reset weights method to the DQNAgent class
patrikrac Nov 11, 2023
6f42a10
Added a reset weights method to the DQNAgent class
patrikrac Nov 11, 2023
4835b95
Added a reset weights method to the DQNAgent class
patrikrac Nov 11, 2023
d51f3ee
removed target net update
patrikrac Nov 13, 2023
27693fc
Added some funky layer reset methodology
patrikrac Nov 13, 2023
1046cae
Some more tests
patrikrac Nov 13, 2023
772384a
Tests
patrikrac Nov 13, 2023
5970b1c
Tests
patrikrac Nov 13, 2023
8434aa0
Tests
patrikrac Nov 13, 2023
e9946a7
Tests
patrikrac Nov 13, 2023
017cf3f
Tests
patrikrac Nov 13, 2023
f364c28
Tests
patrikrac Nov 13, 2023
7e58add
Tests
patrikrac Nov 13, 2023
107d4a3
Tests
patrikrac Nov 13, 2023
97ee391
Tests
patrikrac Nov 13, 2023
6057221
Tests
patrikrac Nov 13, 2023
7d6a99b
Tests
patrikrac Nov 13, 2023
82f47ea
Added three new system variables to the agent
patrikrac Nov 13, 2023
bd3b1a9
Some minor adjustments on the branch
patrikrac Nov 14, 2023
2e2aa11
Stopped agent from reseting at itertaion 0
patrikrac Nov 15, 2023
012e1d5
Stopped agent from reseting at itertaion 0
patrikrac Nov 15, 2023
536e219
Added some much needed consistency to the reseting
patrikrac Nov 18, 2023
3a57732
Some more qol changes
patrikrac Nov 18, 2023
60 changes: 59 additions & 1 deletion dopamine/agents/dqn/dqn_agent.py
@@ -101,7 +101,12 @@ def __init__(self,
centered=True),
summary_writer=None,
summary_writing_frequency=500,
allow_partial_reload=False):
allow_partial_reload=False,
reset_period=None,
reset_dense1=False,
reset_dense2=False,
reset_last_layer=False,
reset_max=3):
"""Initializes the agent and constructs the components of its graph.

Args:
@@ -182,6 +187,16 @@ def __init__(self,
self.eval_mode = eval_mode
self.training_steps = 0
self.optimizer = optimizer
# Modified
self.optimizer_state = self.optimizer.variables()

self.reset_period = reset_period
self.reset_dense1 = reset_dense1
self.reset_dense2 = reset_dense2
self.reset_last_layer = reset_last_layer
self.reset_max = reset_max
self.reset_counter = 0

tf.compat.v1.disable_v2_behavior()
if isinstance(summary_writer, str): # If we're passing in directory name.
self.summary_writer = tf.compat.v1.summary.FileWriter(summary_writer)
@@ -210,6 +225,8 @@ def __init__(self,

self._build_networks()

self.online_convnet_state = self.online_convnet.get_weights()

self._train_op = self._build_train_op()
self._sync_qt_ops = self._build_sync_op()

@@ -231,6 +248,7 @@ def __init__(self,
self.summary_writer.add_graph(graph=tf.compat.v1.get_default_graph())
self._sess.run(tf.compat.v1.global_variables_initializer())


def _create_network(self, name):
"""Builds the convolutional network used to compute the agent's Q-values.

@@ -450,8 +468,48 @@ def _train_step(self):
if self.training_steps % self.target_update_period == 0:
self._sess.run(self._sync_qt_ops)

if (self.reset_period is not None and
self.training_steps % self.reset_period == 0 and
self.reset_counter < self.reset_max):
print("Resetting last layers...")
self.ResetWeights()

self.training_steps += 1

def ResetWeights(self):
# Reset the weights of the last layer
# self.online_convnet.set_weights(self.online_convnet_state)
# self.target_convnet.set_weights(self.online_convnet_state)
if self.reset_counter >= self.reset_max:
return

print("Resetting weights...")
if self.reset_last_layer:
print("Resetting last layer!")
self.online_convnet.layers[-1].last_layer.kernel.initializer.run(session=self._sess)
self.online_convnet.layers[-1].last_layer.bias.initializer.run(session=self._sess)

if self.reset_dense1:
print("Resetting dense1 layer!")
self.online_convnet.layers[-1].dense1.kernel.initializer.run(session=self._sess)
self.online_convnet.layers[-1].dense1.bias.initializer.run(session=self._sess)

if self.reset_dense2:
print("Resetting dense2 layer!")
self.online_convnet.layers[-1].dense2.kernel.initializer.run(session=self._sess)
self.online_convnet.layers[-1].dense2.bias.initializer.run(session=self._sess)

# Legacy code
# self.online_convnet.last_layer.kernel.initializer.run(session=self._sess)
# self.online_convnet.last_layer.bias.initializer.run(session=self._sess)

# self._sess.run(tf.compat.v1.global_variables_initializer())
# Reset the optimizer state
optimizer_reset = tf.compat.v1.variables_initializer(self.optimizer_state)
self._sess.run(optimizer_reset)

self.reset_counter += 1

def _record_observation(self, observation):
"""Records an observation and update state.

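For context, here is a minimal, self-contained sketch (illustrative names, not the PR's code) of the two reset mechanisms ResetWeights relies on in TF1 graph mode: re-running a layer variable's initializer op to re-sample its weights in place, and re-running the initializers of the optimizer's slot variables to clear the optimizer state.

import tensorflow as tf

tf.compat.v1.disable_v2_behavior()

x = tf.compat.v1.placeholder(tf.float32, [None, 8])
layer = tf.keras.layers.Dense(4, name='fully_connected')
loss = tf.reduce_mean(tf.square(layer(x)))

optimizer = tf.compat.v1.train.RMSPropOptimizer(learning_rate=0.001)
train_op = optimizer.minimize(loss)  # creates the optimizer's slot variables

sess = tf.compat.v1.Session()
sess.run(tf.compat.v1.global_variables_initializer())

# Re-running an initializer op re-samples the variable's initial value;
# this is how ResetWeights reinitializes an individual sublayer in place.
layer.kernel.initializer.run(session=sess)
layer.bias.initializer.run(session=sess)

# Clearing the optimizer state (here, RMSProp moving averages) the same way:
sess.run(tf.compat.v1.variables_initializer(optimizer.variables()))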
5 changes: 5 additions & 0 deletions dopamine/discrete_domains/atari_lib.py
@@ -158,6 +158,11 @@ def __init__(self, num_actions, name=None):
name='fully_connected')
self.dense2 = tf.keras.layers.Dense(num_actions, name='fully_connected')

# Modification
def reset_last_layer(self):
"""Reset the last layer of the network."""
self.dense2 = tf.keras.layers.Dense(self.num_actions, name='fully_connected')

def call(self, state):
"""Creates the output tensor/op given the state tensor as input.

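One caveat worth noting here: reassigning self.dense2 swaps in an unbuilt layer whose weights are only created (and freshly sampled) the next time the network is called, and in TF1 graph mode the Q-value ops already built from the old layer keep referencing the old variables, which is presumably why dqn_agent.py re-runs initializers in place instead. A short eager-mode sketch of the rebuild behavior (hypothetical shapes):

import tensorflow as tf

dense2 = tf.keras.layers.Dense(6, name='fully_connected')
y1 = dense2(tf.zeros([1, 512]))  # builds and initializes the weights
dense2 = tf.keras.layers.Dense(6, name='fully_connected')
y2 = dense2(tf.zeros([1, 512]))  # a fresh layer with newly sampled weights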
21 changes: 21 additions & 0 deletions dopamine/discrete_domains/gym_lib.py
@@ -114,6 +114,7 @@ def __init__(self, min_vals, max_vals, num_actions,
self.num_atoms = num_atoms
self.min_vals = min_vals
self.max_vals = max_vals
self.activation_fn = activation_fn
# Defining layers.
self.flatten = tf.keras.layers.Flatten()
self.dense1 = tf.keras.layers.Dense(512, activation=activation_fn,
@@ -127,6 +128,22 @@ def __init__(self, min_vals, max_vals, num_actions,
self.last_layer = tf.keras.layers.Dense(num_actions * num_atoms,
name='fully_connected')

# Modified: saving the initial weights to load them after
# model.save_weights('model.h5')

# Modified
def reset_layer(self, layer):
a, b = layer.get_weights()[0].shape
layer.set_weights([np.random.randn(a, b), np.ones(layer.get_weights()[1].shape)])

# Modified
def reset_last_layer(self):
"""Reset the last layer(s) of the network."""
self.reset_layer(self.dense1)
self.reset_layer(self.dense2)
self.reset_layer(self.last_layer)


def call(self, state):
"""Creates the output tensor/op given the state tensor as input."""
x = tf.cast(state, tf.float32)
@@ -158,6 +175,10 @@ def __init__(self, num_actions, name=None):
self.net = BasicDiscreteDomainNetwork(
CARTPOLE_MIN_VALS, CARTPOLE_MAX_VALS, num_actions)

# Modified
def reset_last_layer(self):
self.net.reset_last_layer()

def call(self, state):
"""Creates the output tensor/op given the state tensor as input."""
x = self.net(state)
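A small eager-mode sketch of the get_weights/set_weights pattern reset_layer uses (hypothetical layer size; note that the kernel is re-sampled from a standard normal and the bias is reset to ones rather than zeros):

import numpy as np
import tensorflow as tf

layer = tf.keras.layers.Dense(2)
layer.build((None, 512))  # creates kernel (512, 2) and bias (2,)

kernel, bias = layer.get_weights()
layer.set_weights([
    np.random.randn(*kernel.shape),  # fresh N(0, 1) kernel
    np.ones(bias.shape),             # bias reset to ones, as in reset_layer
])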
10 changes: 9 additions & 1 deletion dopamine/discrete_domains/run_experiment.py
@@ -173,7 +173,8 @@ def __init__(self,
max_steps_per_episode=27000,
clip_rewards=True,
use_legacy_logger=True,
fine_grained_print_to_console=True):
fine_grained_print_to_console=True,
reset_period=None):
"""Initialize the Runner object in charge of running a full experiment.

Args:
@@ -234,6 +235,8 @@ def __init__(self,

self._initialize_checkpointer_and_maybe_resume(checkpoint_file_prefix)

self._reset_period = reset_period

# Create a collector dispatcher for metrics reporting.
self._collector_dispatcher = collector_dispatcher.CollectorDispatcher(
self._base_dir)
@@ -603,6 +606,11 @@ def run_experiment(self):
return

for iteration in range(self._start_iteration, self._num_iterations):
# Modified: Check if the reset period is reached, and if so, reset the weights.
if (self._reset_period is not None and
iteration != 0 and iteration % self._reset_period == 0):
self._agent.ResetWeights()

statistics = self._run_one_iteration(iteration)
if self._use_legacy_logger:
self._log_experiment(iteration, statistics)
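A hedged usage sketch (hypothetical wiring and paths; in Dopamine these arguments would normally be bound through gin): with reset_period set on the Runner, run_experiment() calls the agent's ResetWeights() every reset_period iterations, skipping iteration 0, and the agent's reset_max cap still applies.

from dopamine.agents.dqn import dqn_agent
from dopamine.discrete_domains import run_experiment

def create_agent(sess, environment, summary_writer=None):
  # Keep the agent's own step-based reset (inside _train_step) disabled;
  # layer selection and the reset cap are still configured on the agent.
  return dqn_agent.DQNAgent(
      sess,
      num_actions=environment.action_space.n,
      reset_period=None,
      reset_last_layer=True,
      reset_max=3)

runner = run_experiment.Runner('/tmp/dqn_reset', create_agent, reset_period=10)
runner.run_experiment()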
6 changes: 5 additions & 1 deletion dopamine/jax/agents/sac/sac_agent.py
@@ -283,7 +283,8 @@ def __init__(self,
summary_writing_frequency=500,
allow_partial_reload=False,
seed=None,
collector_allowlist=('tensorboard')):
collector_allowlist=('tensorboard'),
reset_period=None):
r"""Initializes the agent and constructs the necessary components.

Args:
@@ -387,6 +388,9 @@ def __init__(self,
self.allow_partial_reload = allow_partial_reload
self._collector_allowlist = collector_allowlist

# Reset period is used to reset the agent's state every reset_period steps.
self.reset_period = reset_period

self._rng = jax.random.PRNGKey(seed)
state_shape = self.observation_shape + (stack_size,)
self.state = onp.zeros(state_shape)