diff --git a/brax/training/agents/shac/losses.py b/brax/training/agents/shac/losses.py
index 68ad34a3..c2290cc3 100644
--- a/brax/training/agents/shac/losses.py
+++ b/brax/training/agents/shac/losses.py
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Proximal policy optimization training.
+"""Short-Horizon Actor Critic.
 
-See: https://arxiv.org/pdf/1707.06347.pdf
+See: https://arxiv.org/pdf/2204.07137.pdf
 """
 
 from typing import Any, Tuple
@@ -46,8 +46,7 @@ def compute_shac_policy_loss(
     reward_scaling: float = 1.0) -> Tuple[jnp.ndarray, types.Metrics]:
   """Computes SHAC critic loss.
 
-  This implements Eq. 5 of 2204.07137. It needs to account for any episodes where
-  the episode terminates and include the terminal values appopriately.
+  This implements Eq. 5 of 2204.07137.
 
   Args:
     policy_params: Policy network parameters
@@ -129,7 +128,6 @@ def compute_shac_critic_loss(
     params: Params,
     normalizer_params: Any,
     data: types.Transition,
-    rng: jnp.ndarray,
     shac_network: shac_networks.SHACNetworks,
     discounting: float = 0.9,
     reward_scaling: float = 1.0,
diff --git a/brax/training/agents/shac/networks.py b/brax/training/agents/shac/networks.py
index c4e325af..47a4a0b1 100644
--- a/brax/training/agents/shac/networks.py
+++ b/brax/training/agents/shac/networks.py
@@ -76,12 +76,14 @@ def make_shac_networks(
       observation_size,
       preprocess_observations_fn=preprocess_observations_fn,
       hidden_layer_sizes=policy_hidden_layer_sizes,
-      activation=activation)
+      activation=activation,
+      layer_norm=True)
   value_network = networks.make_value_network(
       observation_size,
       preprocess_observations_fn=preprocess_observations_fn,
       hidden_layer_sizes=value_hidden_layer_sizes,
-      activation=activation)
+      activation=activation,
+      layer_norm=True)
 
   return SHACNetworks(
       policy_network=policy_network,
diff --git a/brax/training/agents/shac/train.py b/brax/training/agents/shac/train.py
index 9510afce..e4f621fe 100644
--- a/brax/training/agents/shac/train.py
+++ b/brax/training/agents/shac/train.py
@@ -130,8 +130,14 @@ def train(environment: envs.Env,
       preprocess_observations_fn=normalize)
   make_policy = shac_networks.make_inference_fn(shac_network)
 
-  policy_optimizer = optax.adam(learning_rate=actor_learning_rate)
-  value_optimizer = optax.adam(learning_rate=critic_learning_rate)
+  policy_optimizer = optax.chain(
+      optax.clip(1.0),
+      optax.adam(learning_rate=actor_learning_rate, b1=0.7, b2=0.95)
+  )
+  value_optimizer = optax.chain(
+      optax.clip(1.0),
+      optax.adam(learning_rate=critic_learning_rate, b1=0.7, b2=0.95)
+  )
 
   value_loss_fn = functools.partial(
       shac_losses.compute_shac_critic_loss,
@@ -184,6 +190,7 @@ def f(carry, unused_t):
   policy_gradient_update_fn = gradients.gradient_update_fn(
       rollout_loss_fn, policy_optimizer, pmap_axis_name=_PMAP_AXIS_NAME,
       has_aux=True)
+  policy_gradient_update_fn = jax.jit(policy_gradient_update_fn)
 
   def minibatch_step(
       carry, data: types.Transition,
@@ -194,7 +201,6 @@
         params,
         normalizer_params,
         data,
-        key_loss,
         optimizer_state=optimizer_state)
 
     return (optimizer_state, params, key), metrics
@@ -317,7 +323,6 @@ def training_epoch_with_timing(
   key_envs = jnp.reshape(key_envs,
                          (local_devices_to_use, -1) + key_envs.shape[1:])
   env_state = reset_fn(key_envs)
-  print(f'env_state: {env_state.qp.pos.shape}')
 
   if not eval_env:
     eval_env = env
diff --git a/brax/training/networks.py b/brax/training/networks.py
index 5856360a..903d1008 100644
--- a/brax/training/networks.py
+++ b/brax/training/networks.py
@@ -41,6 +41,7 @@ class MLP(linen.Module):
   kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
   activate_final: bool = False
   bias: bool = True
+  layer_norm: bool = False
 
   @linen.compact
   def __call__(self, data: jnp.ndarray):
@@ -52,6 +53,8 @@ def __call__(self, data: jnp.ndarray):
           kernel_init=self.kernel_init,
           use_bias=self.bias)(
               hidden)
+      if self.layer_norm and i != len(self.layer_sizes) - 1:
+        hidden = linen.LayerNorm()(hidden)
       if i != len(self.layer_sizes) - 1 or self.activate_final:
         hidden = self.activation(hidden)
     return hidden
@@ -86,11 +89,13 @@ def make_policy_network(
     preprocess_observations_fn: types.PreprocessObservationFn = types
     .identity_observation_preprocessor,
     hidden_layer_sizes: Sequence[int] = (256, 256),
-    activation: ActivationFn = linen.relu) -> FeedForwardNetwork:
+    activation: ActivationFn = linen.relu,
+    layer_norm: bool = False) -> FeedForwardNetwork:
   """Creates a policy network."""
   policy_module = MLP(
       layer_sizes=list(hidden_layer_sizes) + [param_size],
       activation=activation,
+      layer_norm=layer_norm,
       kernel_init=jax.nn.initializers.lecun_uniform())
 
   def apply(processor_params, policy_params, obs):
@@ -107,11 +112,13 @@ def make_value_network(
     preprocess_observations_fn: types.PreprocessObservationFn = types
     .identity_observation_preprocessor,
     hidden_layer_sizes: Sequence[int] = (256, 256),
-    activation: ActivationFn = linen.relu) -> FeedForwardNetwork:
+    activation: ActivationFn = linen.relu,
+    layer_norm: bool = False) -> FeedForwardNetwork:
   """Creates a policy network."""
   value_module = MLP(
       layer_sizes=list(hidden_layer_sizes) + [1],
       activation=activation,
+      layer_norm=layer_norm,
       kernel_init=jax.nn.initializers.lecun_uniform())
 
   def apply(processor_params, policy_params, obs):
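
The snippet below is an illustrative, self-contained sketch of the layer_norm option wired through MLP in this patch; it is not part of the patch itself, and the class name, layer sizes, and RNG seed are made up for the example. It shows LayerNorm applied after each hidden Dense layer and skipped on the output layer, since normalizing a single-unit value head would collapse it to its learned bias.

# Sketch only: standalone Flax module mirroring the layer_norm behavior above.
from typing import Sequence

import jax
import jax.numpy as jnp
from flax import linen


class LayerNormMLP(linen.Module):
  """MLP that optionally applies LayerNorm after each hidden Dense layer."""
  layer_sizes: Sequence[int]
  layer_norm: bool = False

  @linen.compact
  def __call__(self, data: jnp.ndarray) -> jnp.ndarray:
    hidden = data
    for i, size in enumerate(self.layer_sizes):
      hidden = linen.Dense(size, name=f'hidden_{i}')(hidden)
      is_last = i == len(self.layer_sizes) - 1
      # Skip LayerNorm on the output layer: normalizing a 1-unit value head
      # would reduce it to a constant output.
      if self.layer_norm and not is_last:
        hidden = linen.LayerNorm()(hidden)
      if not is_last:
        hidden = linen.relu(hidden)
    return hidden


# Minimal usage: a value-style network with a single output.
model = LayerNormMLP(layer_sizes=(64, 64, 1), layer_norm=True)
params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 8)))
values = model.apply(params, jnp.ones((4, 8)))  # shape (4, 1)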