Skip to content

Commit

Permalink
SHAC: tweak layer norm
Browse files Browse the repository at this point in the history
  • Loading branch information
peabody124 committed Nov 20, 2022
1 parent 8ff7005 commit f879469
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 12 deletions.
9 changes: 6 additions & 3 deletions brax/training/agents/ppo/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def make_ppo_networks(
.identity_observation_preprocessor,
policy_hidden_layer_sizes: Sequence[int] = (32,) * 4,
value_hidden_layer_sizes: Sequence[int] = (256,) * 5,
activation: networks.ActivationFn = linen.swish) -> PPONetworks:
activation: networks.ActivationFn = linen.swish,
layer_norm: bool = False) -> PPONetworks:
"""Make PPO networks with preprocessor."""
parametric_action_distribution = distribution.NormalTanhDistribution(
event_size=action_size)
Expand All @@ -75,12 +76,14 @@ def make_ppo_networks(
observation_size,
preprocess_observations_fn=preprocess_observations_fn,
hidden_layer_sizes=policy_hidden_layer_sizes,
activation=activation)
activation=activation,
layer_norm=layer_norm)
value_network = networks.make_value_network(
observation_size,
preprocess_observations_fn=preprocess_observations_fn,
hidden_layer_sizes=value_hidden_layer_sizes,
activation=activation)
activation=activation,
layer_norm=layer_norm)

return PPONetworks(
policy_network=policy_network,
Expand Down
5 changes: 2 additions & 3 deletions brax/training/agents/shac/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def compute_shac_policy_loss(
# jax implementation of https://github.com/NVlabs/DiffRL/blob/a4c0dd1696d3c3b885ce85a3cb64370b580cb913/algorithms/shac.py#L227
def sum_step(carry, target_t):
gam, rew_acc = carry
reward, v, termination = target_t
reward, termination = target_t

# clean up gamma and rew_acc for done envs, otherwise update
rew_acc = jnp.where(termination, 0, rew_acc + gam * reward)
Expand All @@ -100,7 +100,7 @@ def sum_step(carry, target_t):
rew_acc = jnp.zeros_like(terminal_values)
gam = jnp.ones_like(terminal_values)
(gam, last_rew_acc), (gam_acc, rew_acc) = jax.lax.scan(sum_step, (gam, rew_acc),
(rewards, values, termination))
(rewards, termination))

policy_loss = jnp.sum(-last_rew_acc - gam * terminal_values)
# for trials that are truncated (i.e. hit the episode length) include reward for
Expand All @@ -118,7 +118,6 @@ def sum_step(carry, target_t):
total_loss = policy_loss + entropy_loss

return total_loss, {
'total_loss': total_loss,
'policy_loss': policy_loss,
'entropy_loss': entropy_loss
}
Expand Down
7 changes: 4 additions & 3 deletions brax/training/agents/shac/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def make_shac_networks(
.identity_observation_preprocessor,
policy_hidden_layer_sizes: Sequence[int] = (32,) * 4,
value_hidden_layer_sizes: Sequence[int] = (256,) * 5,
activation: networks.ActivationFn = linen.swish) -> SHACNetworks:
activation: networks.ActivationFn = linen.elu,
layer_norm: bool = True) -> SHACNetworks:
"""Make SHAC networks with preprocessor."""
parametric_action_distribution = distribution.NormalTanhDistribution(
event_size=action_size)
Expand All @@ -77,13 +78,13 @@ def make_shac_networks(
preprocess_observations_fn=preprocess_observations_fn,
hidden_layer_sizes=policy_hidden_layer_sizes,
activation=activation,
layer_norm=True)
layer_norm=layer_norm)
value_network = networks.make_value_network(
observation_size,
preprocess_observations_fn=preprocess_observations_fn,
hidden_layer_sizes=value_hidden_layer_sizes,
activation=activation,
layer_norm=True)
layer_norm=layer_norm)

return SHACNetworks(
policy_network=policy_network,
Expand Down
4 changes: 3 additions & 1 deletion brax/training/agents/shac/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def train(environment: envs.Env,
reward_scaling: float = 1.,
tau: float = 0.005, # this is 1-alpha from the original paper
lambda_: float = .95,
td_lambda: bool = True,
deterministic_eval: bool = False,
network_factory: types.NetworkFactory[
shac_networks.SHACNetworks] = shac_networks.make_shac_networks,
Expand Down Expand Up @@ -144,7 +145,8 @@ def train(environment: envs.Env,
shac_network=shac_network,
discounting=discounting,
reward_scaling=reward_scaling,
lambda_=lambda_)
lambda_=lambda_,
td_lambda=td_lambda)

value_gradient_update_fn = gradients.gradient_update_fn(
value_loss_fn, value_optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True)
Expand Down
4 changes: 2 additions & 2 deletions brax/training/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ def __call__(self, data: jnp.ndarray):
kernel_init=self.kernel_init,
use_bias=self.bias)(
hidden)
if self.layer_norm:
hidden = linen.LayerNorm()(hidden)
if i != len(self.layer_sizes) - 1 or self.activate_final:
hidden = self.activation(hidden)
if self.layer_norm:
hidden = linen.LayerNorm(dtype=get_dtype(self.half_precision))(hidden)
return hidden


Expand Down

0 comments on commit f879469

Please sign in to comment.