SHAC: layer norm and gradient clipping
Starting to see progress training the ant environment.
peabody124 committed Nov 20, 2022
1 parent 4000c95 commit 8ff7005
Showing 4 changed files with 25 additions and 13 deletions.
8 changes: 3 additions & 5 deletions brax/training/agents/shac/losses.py
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""Proximal policy optimization training.
+"""Short-Horizon Actor Critic.

-See: https://arxiv.org/pdf/1707.06347.pdf
+See: https://arxiv.org/pdf/2204.07137.pdf
 """

 from typing import Any, Tuple
@@ -46,8 +46,7 @@ def compute_shac_policy_loss(
     reward_scaling: float = 1.0) -> Tuple[jnp.ndarray, types.Metrics]:
   """Computes SHAC critic loss.

-  This implements Eq. 5 of 2204.07137. It needs to account for any episodes where
-  the episode terminates and include the terminal values appopriately.
+  This implements Eq. 5 of 2204.07137.

   Args:
     policy_params: Policy network parameters
@@ -129,7 +128,6 @@ def compute_shac_critic_loss(
     params: Params,
     normalizer_params: Any,
     data: types.Transition,
-    rng: jnp.ndarray,
     shac_network: shac_networks.SHACNetworks,
     discounting: float = 0.9,
     reward_scaling: float = 1.0,
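
For reference, Eq. 5 of arXiv:2204.07137 is the short-horizon policy objective: the discounted rewards accumulated along each length-h differentiable rollout plus a discounted terminal value from the critic, negated so it can be minimized. A minimal sketch of that computation under simplified assumptions (this is not brax's compute_shac_policy_loss; the function name, shapes, and termination handling are illustrative only):

import jax.numpy as jnp

def shac_policy_objective(rewards, terminal_values, done, gamma=0.99):
  """Illustrative short-horizon actor objective in the spirit of Eq. 5.

  rewards: [T, B] rewards along each length-T differentiable rollout.
  terminal_values: [B] critic estimates V(s_T) at the end of each rollout.
  done: [T, B] 1.0 where the episode terminated at that step, else 0.0.
  """
  horizon = rewards.shape[0]
  not_done = 1.0 - done
  # alive[t] is 1 only if no termination happened at or before step t.
  alive = jnp.cumprod(not_done, axis=0)
  # mask[t] is 1 while the episode is still running when step t is taken,
  # so the reward of the terminating step itself is still counted.
  mask = jnp.concatenate([jnp.ones_like(alive[:1]), alive[:-1]], axis=0)
  discounts = gamma ** jnp.arange(horizon)[:, None]
  returns = jnp.sum(discounts * mask * rewards, axis=0)
  # Bootstrap with the critic only for rollouts that did not terminate early.
  returns = returns + (gamma ** horizon) * alive[-1] * terminal_values
  return -jnp.mean(returns)
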
6 changes: 4 additions & 2 deletions brax/training/agents/shac/networks.py
@@ -76,12 +76,14 @@ def make_shac_networks(
       observation_size,
       preprocess_observations_fn=preprocess_observations_fn,
       hidden_layer_sizes=policy_hidden_layer_sizes,
-      activation=activation)
+      activation=activation,
+      layer_norm=True)
   value_network = networks.make_value_network(
       observation_size,
       preprocess_observations_fn=preprocess_observations_fn,
       hidden_layer_sizes=value_hidden_layer_sizes,
-      activation=activation)
+      activation=activation,
+      layer_norm=True)

   return SHACNetworks(
       policy_network=policy_network,
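
With this change every caller of make_shac_networks gets layer-normalized MLPs for both the policy and the value function without touching its own code. A usage sketch, assuming the factory keeps the same calling convention as brax's other agent network factories (the observation and action sizes are made up):

from brax.training.agents.shac import networks as shac_networks

# Hypothetical sizes for illustration.
shac_nets = shac_networks.make_shac_networks(
    observation_size=87,
    action_size=8)
# Both shac_nets.policy_network and shac_nets.value_network are now built
# from MLPs that apply linen.LayerNorm after each Dense layer.
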
13 changes: 9 additions & 4 deletions brax/training/agents/shac/train.py
@@ -130,8 +130,14 @@ def train(environment: envs.Env,
       preprocess_observations_fn=normalize)
   make_policy = shac_networks.make_inference_fn(shac_network)

-  policy_optimizer = optax.adam(learning_rate=actor_learning_rate)
-  value_optimizer = optax.adam(learning_rate=critic_learning_rate)
+  policy_optimizer = optax.chain(
+      optax.clip(1.0),
+      optax.adam(learning_rate=actor_learning_rate, b1=0.7, b2=0.95)
+  )
+  value_optimizer = optax.chain(
+      optax.clip(1.0),
+      optax.adam(learning_rate=critic_learning_rate, b1=0.7, b2=0.95)
+  )

   value_loss_fn = functools.partial(
       shac_losses.compute_shac_critic_loss,
@@ -184,6 +190,7 @@ def f(carry, unused_t):

   policy_gradient_update_fn = gradients.gradient_update_fn(
       rollout_loss_fn, policy_optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True)
+  policy_gradient_update_fn = jax.jit(policy_gradient_update_fn)

   def minibatch_step(
       carry, data: types.Transition,
@@ -194,7 +201,6 @@ def minibatch_step(
         params,
         normalizer_params,
         data,
-        key_loss,
         optimizer_state=optimizer_state)

     return (optimizer_state, params, key), metrics
@@ -317,7 +323,6 @@ def training_epoch_with_timing(
   key_envs = jnp.reshape(key_envs,
                          (local_devices_to_use, -1) + key_envs.shape[1:])
   env_state = reset_fn(key_envs)
-  print(f'env_state: {env_state.qp.pos.shape}')

   if not eval_env:
     eval_env = env
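
The optimizer change wraps Adam with elementwise gradient clipping via optax.chain, with b1=0.7 and b2=0.95 following the betas reported in the SHAC paper. A self-contained sketch of the same pattern on a toy parameter tree (the learning rate is a placeholder, and none of the brax TrainingState plumbing is shown):

import jax
import jax.numpy as jnp
import optax

# Clip every gradient element to [-1, 1], then apply Adam with the SHAC betas.
optimizer = optax.chain(
    optax.clip(1.0),
    optax.adam(learning_rate=3e-4, b1=0.7, b2=0.95))

params = {'w': jnp.zeros(3)}
opt_state = optimizer.init(params)

loss_fn = lambda p: jnp.sum((p['w'] - 1.0) ** 2)
grads = jax.grad(loss_fn)(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

Note that optax.clip bounds each gradient component individually; optax.clip_by_global_norm is the alternative if the whole gradient should instead be rescaled by its norm.
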
11 changes: 9 additions & 2 deletions brax/training/networks.py
@@ -41,6 +41,7 @@ class MLP(linen.Module):
   kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
   activate_final: bool = False
   bias: bool = True
+  layer_norm: bool = True

   @linen.compact
   def __call__(self, data: jnp.ndarray):
@@ -52,6 +53,8 @@ def __call__(self, data: jnp.ndarray):
           kernel_init=self.kernel_init,
           use_bias=self.bias)(
               hidden)
+      if self.layer_norm:
+        hidden = linen.LayerNorm()(hidden)
       if i != len(self.layer_sizes) - 1 or self.activate_final:
         hidden = self.activation(hidden)
     return hidden
@@ -86,11 +89,13 @@ def make_policy_network(
     preprocess_observations_fn: types.PreprocessObservationFn = types
     .identity_observation_preprocessor,
     hidden_layer_sizes: Sequence[int] = (256, 256),
-    activation: ActivationFn = linen.relu) -> FeedForwardNetwork:
+    activation: ActivationFn = linen.relu,
+    layer_norm: bool = False) -> FeedForwardNetwork:
   """Creates a policy network."""
   policy_module = MLP(
       layer_sizes=list(hidden_layer_sizes) + [param_size],
       activation=activation,
+      layer_norm=layer_norm,
       kernel_init=jax.nn.initializers.lecun_uniform())

   def apply(processor_params, policy_params, obs):
@@ -107,11 +112,13 @@ def make_value_network(
     preprocess_observations_fn: types.PreprocessObservationFn = types
     .identity_observation_preprocessor,
     hidden_layer_sizes: Sequence[int] = (256, 256),
-    activation: ActivationFn = linen.relu) -> FeedForwardNetwork:
+    activation: ActivationFn = linen.relu,
+    layer_norm: bool = False) -> FeedForwardNetwork:
   """Creates a policy network."""
   value_module = MLP(
       layer_sizes=list(hidden_layer_sizes) + [1],
       activation=activation,
+      layer_norm=layer_norm,
       kernel_init=jax.nn.initializers.lecun_uniform())

   def apply(processor_params, policy_params, obs):
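
Inside the MLP, the normalization sits between each Dense layer and its activation, and as written it is also applied to the final output layer. A minimal self-contained module with the same ordering, for illustration only (not the brax MLP class):

import jax
import jax.numpy as jnp
from flax import linen

class LayerNormMLP(linen.Module):
  """Toy MLP mirroring the Dense -> LayerNorm -> activation ordering above."""
  layer_sizes: tuple = (64, 64, 1)
  layer_norm: bool = True

  @linen.compact
  def __call__(self, x):
    for i, size in enumerate(self.layer_sizes):
      x = linen.Dense(size, name=f'hidden_{i}')(x)
      if self.layer_norm:
        x = linen.LayerNorm()(x)  # normalize pre-activations
      if i != len(self.layer_sizes) - 1:
        x = linen.relu(x)
    return x

module = LayerNormMLP()
params = module.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
out = module.apply(params, jnp.ones((1, 8)))  # shape (1, 1)

Because the LayerNorm call in the diff precedes the final-layer check, the last Dense output (value estimate or policy distribution parameters) is normalized as well; the sketch mirrors that behavior.
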
