
Commit a6f414c

Project import generated by Copybara.
PiperOrigin-RevId: 539243423
psc-g committed Jun 12, 2023
1 parent d41ed0f commit a6f414c
Showing 10 changed files with 3,177 additions and 9 deletions.
79 changes: 79 additions & 0 deletions dopamine/jax/networks.py
@@ -50,6 +50,85 @@ def preprocess_atari_inputs(x):
identity_preprocess_fn = lambda x: x


@gin.configurable
class Stack(nn.Module):
  """Stack of pooling and convolutional blocks with residual connections."""
  num_ch: int
  num_blocks: int
  use_max_pooling: bool = True

  @nn.compact
  def __call__(self, x):
    initializer = nn.initializers.xavier_uniform()
    conv_out = nn.Conv(
        features=self.num_ch,
        kernel_size=(3, 3),
        strides=1,
        kernel_init=initializer,
        padding='SAME')(x)
    if self.use_max_pooling:
      conv_out = nn.max_pool(
          conv_out, window_shape=(3, 3), padding='SAME', strides=(2, 2))

    for _ in range(self.num_blocks):
      block_input = conv_out
      conv_out = nn.relu(conv_out)
      conv_out = nn.Conv(features=self.num_ch, kernel_size=(3, 3),
                         strides=1, padding='SAME')(conv_out)
      conv_out = nn.relu(conv_out)
      conv_out = nn.Conv(features=self.num_ch, kernel_size=(3, 3),
                         strides=1, padding='SAME')(conv_out)
      conv_out += block_input

    return conv_out


@gin.configurable
class ImpalaEncoder(nn.Module):
  """Impala Network which also outputs penultimate representation layers."""
  nn_scale: int = 1
  stack_sizes: Tuple[int, ...] = (16, 32, 32)
  num_blocks: int = 2

  def setup(self):
    self._stacks = [
        Stack(num_ch=stack_size * self.nn_scale,
              num_blocks=self.num_blocks) for stack_size in self.stack_sizes
    ]

  @nn.compact
  def __call__(self, x):
    for stack in self._stacks:
      x = stack(x)
    return nn.relu(x)


### DQN Network with ImpalaEncoder ###
@gin.configurable
class ImpalaDQNNetwork(nn.Module):
  """The convolutional network used to compute the agent's Q-values."""
  num_actions: int
  inputs_preprocessed: bool = False
  nn_scale: int = 1

  def setup(self):
    self.encoder = ImpalaEncoder(nn_scale=self.nn_scale)

  @nn.compact
  def __call__(self, x):
    initializer = nn.initializers.xavier_uniform()
    if not self.inputs_preprocessed:
      x = preprocess_atari_inputs(x)
    x = self.encoder(x)
    x = x.reshape((-1))  # flatten
    x = nn.Dense(features=512, kernel_init=initializer)(x)
    x = nn.relu(x)
    q_values = nn.Dense(features=self.num_actions,
                        kernel_init=initializer)(x)
    return atari_lib.DQNNetworkType(q_values)


### DQN Networks ###
@gin.configurable
class NatureDQNNetwork(nn.Module):
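
A minimal usage sketch for the ImpalaDQNNetwork added above (not part of this commit). It assumes the standard Dopamine Atari setup of 84x84 frames with a stack size of 4; the action count, nn_scale, and PRNG seed are arbitrary illustration values:

# Illustrative only -- not Dopamine code.
import jax
import jax.numpy as jnp
from dopamine.jax import networks

network_def = networks.ImpalaDQNNetwork(num_actions=6, nn_scale=2)
dummy_state = jnp.zeros((1, 84, 84, 4), dtype=jnp.uint8)  # stacked Atari frames
params = network_def.init(jax.random.PRNGKey(0), dummy_state)
q_values = network_def.apply(params, dummy_state).q_values  # shape: (6,)
greedy_action = int(jnp.argmax(q_values))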
189 changes: 189 additions & 0 deletions dopamine/labs/atari_100k/atari_100k_rainbow_agent.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Atari 100k rainbow agent with support for data augmentation."""

import copy
import functools

from absl import logging
@@ -22,9 +23,55 @@
import gin
import jax
import jax.numpy as jnp
import numpy as onp
import tensorflow as tf


@functools.partial(jax.jit, static_argnums=(0, 4, 5, 6, 7, 8, 10, 11))
def select_action(
    network_def,
    params,
    state,
    rng,
    num_actions,
    eval_mode,
    epsilon_eval,
    epsilon_train,
    epsilon_decay_period,
    training_steps,
    min_replay_history,
    epsilon_fn,
    support,
):
  """Select an action from the set of available actions."""
  epsilon = jnp.where(
      eval_mode,
      epsilon_eval,
      epsilon_fn(
          epsilon_decay_period,
          training_steps,
          min_replay_history,
          epsilon_train,
      ),
  )

  rng, rng1, rng2, rng3 = jax.random.split(rng, num=4)

  # Evaluate the Q-network once per environment: `state` has shape
  # (n_envs, *observation_shape, stack_size) and is vmapped over its leading
  # axis, with an independent PRNG key per environment.
  @functools.partial(jax.vmap, in_axes=(0, 0), axis_name='batch')
  def q_function(state, key):
    q_values = network_def.apply(
        params, state, key=key, eval_mode=eval_mode, support=support
    ).q_values
    return q_values

  q_values = q_function(state, jax.random.split(rng2, state.shape[0]))

  # Epsilon-greedy: with probability `epsilon`, replace the greedy action with
  # a uniformly random one, independently for each environment.
  best_actions = jnp.argmax(q_values, axis=-1)
  random_actions = jax.random.randint(rng3, (state.shape[0],), 0, num_actions)
  p = jax.random.uniform(rng1, shape=(state.shape[0],))
  return rng, jnp.where(p <= epsilon, random_actions, best_actions)
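
# Note (added for clarity, not part of the diff): every argument listed in
# `static_argnums` above -- network_def, num_actions, eval_mode, the epsilon
# schedule constants (epsilon_eval, epsilon_train, epsilon_decay_period,
# min_replay_history) and epsilon_fn -- is treated as a compile-time constant
# by jax.jit, so changing any of them triggers recompilation. params, state,
# rng, training_steps, and support remain ordinary traced arguments.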


############################ Data Augmentation ############################


@@ -118,6 +165,7 @@ def __init__(self,
    self.train_preprocess_fn = functools.partial(
        preprocess_inputs_with_augmentation,
        data_augmentation=data_augmentation)
    self.state_shape = self.state.shape

  def _training_step_update(self):
    """Gradient update during every training step."""
@@ -156,3 +204,144 @@ def _training_step_update(self):
            step=self.training_steps)
      self.summary_writer.flush()

  def step(self, reward=None, observation=None):
    """Selects an action, and optionally records a transition and trains.

    If `reward` or `observation` is None, the agent's state will _not_ be
    updated and nothing will be written to the buffer. The user must call
    `log_transition` themselves in this case.

    Args:
      reward: Optional reward to log.
      observation: Optional observation to log. Must call `log_transition`
        later if not passed here.

    Returns:
      Selected action.
    """
    if reward is not None and observation is not None:
      self._last_observation = self._observation
      self._record_observation(observation)
      if not self.eval_mode:
        self._store_transition(
            self._last_observation, self.action, reward, False
        )

    if not self.eval_mode:
      self._train_step()

    state = self.preprocess_fn(self.state)
    self._rng, action = select_action(
        self.network_def,
        self.online_params,
        state,
        self._rng,
        self.num_actions,
        self.eval_mode,
        self.epsilon_eval,
        self.epsilon_train,
        self.epsilon_decay_period,
        self.training_steps,
        self.min_replay_history,
        self.epsilon_fn,
        self._support,
    )
    self.action = onp.asarray(action)
    return self.action
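
  # Usage note (illustrative, not from the diff): calling `step()` with no
  # arguments only trains and selects an action; the caller is then expected
  # to invoke `log_transition(observation, action, reward, terminal,
  # episode_end)` so the transition is recorded before the next call.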

  def _reset_state(self, n_envs=None):
    """Resets the agent state by filling it with zeros."""
    if n_envs is None:
      self.state = onp.zeros((1, *self.state_shape))
    else:
      self.state = onp.zeros((n_envs, *self.state_shape))

  def _record_observation(self, observation):
    """Records an observation and updates the agent's state.

    Extracts a frame from the observation vector and overwrites the oldest
    frame in the state buffer.

    Args:
      observation: numpy array, an observation from the environment.
    """
    # Set current observation. We do the reshaping to handle environments
    # without frame stacking.
    observation = observation.squeeze(-1)
    if len(observation.shape) == len(self.observation_shape):
      self._observation = onp.reshape(observation, self.observation_shape)
    else:
      self._observation = onp.reshape(
          observation, (observation.shape[0], *self.observation_shape)
      )
    # Swap out the oldest frame with the current frame.
    self.state = onp.roll(self.state, -1, axis=-1)
    self.state[..., -1] = self._observation

  def reset_all(self, new_obs):
    """Resets the state to zeros for all environments and records new_obs."""
    n_envs = new_obs.shape[0]
    self.state = onp.zeros((n_envs, *self.state_shape))
    self._record_observation(new_obs)

  def reset_one(self, env_id):
    """Zeroes the stacked state of a single environment."""
    self.state[env_id].fill(0)

  def delete_one(self, env_id):
    """Removes a single environment's slot from the batched state."""
    self.state = onp.concatenate(
        [self.state[:env_id], self.state[env_id + 1 :]], 0
    )

  def cache_train_state(self):
    """Caches the current state and observations for later restoration."""
    self.training_state = (
        copy.deepcopy(self.state),
        copy.deepcopy(self._last_observation),
        copy.deepcopy(self._observation),
    )

  def restore_train_state(self):
    """Restores the state and observations saved by cache_train_state."""
    (self.state, self._last_observation, self._observation) = (
        self.training_state
    )

  def log_transition(self, observation, action, reward, terminal, episode_end):
    """Records the observation and, outside eval mode, stores the transition."""
    self._last_observation = self._observation
    self._record_observation(observation)

    if not self.eval_mode:
      self._store_transition(
          self._last_observation,
          action,
          reward,
          terminal,
          episode_end=episode_end,
      )

  def _store_transition(
      self,
      last_observation,
      action,
      reward,
      is_terminal,
      *args,
      priority=None,
      episode_end=False
  ):
    """Stores a transition when in training mode."""
    is_prioritized = hasattr(self._replay, 'sum_tree')
    if is_prioritized and priority is None:
      priority = onp.ones_like(reward)
      if self._replay_scheme == 'prioritized':
        priority *= self._replay.sum_tree.max_recorded_priority

    to_store = (last_observation, action, reward, is_terminal, *args)
    to_store = (onp.asarray(x) for x in to_store)
    if not hasattr(self._replay, '_n_envs'):
      # Single-environment replay buffer: drop any leading batch dimension.
      to_store = (onp.squeeze(x) for x in to_store)
      priority = onp.squeeze(priority)
    elif hasattr(self._replay, '_n_envs') and not reward.shape:
      # Batched replay buffer fed with un-batched scalars: add a leading batch
      # dimension of size 1.
      to_store = (onp.expand_dims(x, 0) if not (x.shape and x.shape[0] == 1)
                  else x for x in to_store)
      priority = onp.expand_dims(priority, 0)
    if not self.eval_mode:
      self._replay.add(*to_store, priority=priority, episode_end=episode_end)
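
A sketch (not part of this commit) of how the vectorized methods above might be driven by a batched environment loop. The `envs` object, its reset/step API, and the step budget are hypothetical placeholders; `agent` is assumed to be an already constructed Atari100kRainbowAgent configured with a batched (`_n_envs`) replay buffer:

# Hypothetical driver loop -- the environment API and shapes are assumptions.
import numpy as onp

observations = envs.reset()                    # (n_envs, 84, 84, 1)
agent.reset_all(observations)                  # zero the state, record first frames
actions = agent.step()                         # one action per environment

for _ in range(10_000):
  observations, rewards, terminals = envs.step(actions)
  agent.log_transition(observations, actions, rewards, terminals,
                       episode_end=terminals)
  for env_id in onp.where(terminals)[0]:
    agent.reset_one(env_id)                    # clear that environment's frame stack
  actions = agent.step()                       # also trains when not in eval mode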
