diff --git a/qdax/core/neuroevolution/networks/dads_networks.py b/qdax/core/neuroevolution/networks/dads_networks.py
index 59abf810..863bdab5 100644
--- a/qdax/core/neuroevolution/networks/dads_networks.py
+++ b/qdax/core/neuroevolution/networks/dads_networks.py
@@ -5,6 +5,7 @@
 import tensorflow_probability.substrates.jax as tfp
 from jax.nn import initializers
 
+from qdax.core.neuroevolution.networks.networks import MLP
 from qdax.custom_types import Action, Observation, Skill, StateDescriptor
 
 
@@ -74,9 +75,12 @@ def __call__(
         obs = obs[:, self.omit_input_dynamics_dim :]
         obs = jnp.concatenate((obs, skill), axis=1)
 
-        x = obs
-        for features in self.hidden_layer_sizes:
-            x = nn.relu(nn.Dense(features, kernel_init=init)(x))
+        x = MLP(
+            layer_sizes=self.hidden_layer_sizes,
+            kernel_init=init,
+            activation=nn.relu,
+            final_activation=nn.relu,
+        )(obs)
 
         dist = distribution(x)
         return dist.log_prob(target)
@@ -89,10 +93,12 @@ class Actor(nn.Module):
     @nn.compact
     def __call__(self, obs: Observation) -> jnp.ndarray:
         init = initializers.variance_scaling(1.0, "fan_in", "uniform")
-        x = obs
-        for features in self.hidden_layer_sizes:
-            x = nn.relu(nn.Dense(features, kernel_init=init)(x))
-        return nn.Dense(2 * self.action_size, kernel_init=init)(x)
+
+        return MLP(
+            layer_sizes=self.hidden_layer_sizes + (2 * self.action_size,),
+            kernel_init=init,
+            activation=nn.relu,
+        )(obs)
 
 
 class Critic(nn.Module):
@@ -103,15 +109,19 @@ def __call__(self, obs: Observation, action: Action) -> jnp.ndarray:
         init = initializers.variance_scaling(1.0, "fan_in", "uniform")
         input_ = jnp.concatenate([obs, action], axis=-1)
 
-        def make_critic_network() -> jnp.ndarray:
-            x = input_
-            for features in self.hidden_layer_sizes:
-                x = nn.relu(nn.Dense(features, kernel_init=init)(x))
-            return nn.Dense(1, kernel_init=init)(x)
+        value_1 = MLP(
+            layer_sizes=self.hidden_layer_sizes + (1,),
+            kernel_init=init,
+            activation=nn.relu,
+        )(input_)
+
+        value_2 = MLP(
+            layer_sizes=self.hidden_layer_sizes + (1,),
+            kernel_init=init,
+            activation=nn.relu,
+        )(input_)
 
-        value1 = make_critic_network()
-        value2 = make_critic_network()
-        return jnp.concatenate([value1, value2], axis=-1)
+        return jnp.concatenate([value_1, value_2], axis=-1)
 
 
 def make_dads_networks(
diff --git a/qdax/core/neuroevolution/networks/diayn_networks.py b/qdax/core/neuroevolution/networks/diayn_networks.py
index 617d31a8..e292e131 100644
--- a/qdax/core/neuroevolution/networks/diayn_networks.py
+++ b/qdax/core/neuroevolution/networks/diayn_networks.py
@@ -3,36 +3,21 @@
 import flax.linen as nn
 import jax.numpy as jnp
 
+from qdax.core.neuroevolution.networks.networks import MLP
 from qdax.custom_types import Action, Observation
 
 
-class MLP(nn.Module):
-    features: Tuple[int, ...]
-
-    @nn.compact
-    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
-        for feat in self.features[:-1]:
-            x = nn.relu(
-                nn.Dense(
-                    feat,
-                    kernel_init=nn.initializers.variance_scaling(
-                        1.0, "fan_in", "uniform"
-                    ),
-                )(x)
-            )
-        return nn.Dense(
-            self.features[-1],
-            kernel_init=nn.initializers.variance_scaling(1.0, "fan_in", "uniform"),
-        )(x)
-
-
 class Actor(nn.Module):
     action_size: int
     hidden_layer_size: Tuple[int, ...]
 
     @nn.compact
     def __call__(self, obs: Observation) -> jnp.ndarray:
-        return MLP(self.hidden_layer_size + (2 * self.action_size,))(obs)
+        return MLP(
+            layer_sizes=self.hidden_layer_size + (2 * self.action_size,),
+            kernel_init=nn.initializers.variance_scaling(1.0, "fan_in", "uniform"),
+            activation=nn.relu,
+        )(obs)
 
 
 class Critic(nn.Module):
@@ -41,9 +26,22 @@ class Critic(nn.Module):
     @nn.compact
     def __call__(self, obs: Observation, action: Action) -> jnp.ndarray:
         input_ = jnp.concatenate([obs, action], axis=-1)
-        value1 = MLP(self.hidden_layer_size + (1,))(input_)
-        value2 = MLP(self.hidden_layer_size + (1,))(input_)
-        return jnp.concatenate([value1, value2], axis=-1)
+
+        kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "uniform")
+
+        value_1 = MLP(
+            layer_sizes=self.hidden_layer_size + (1,),
+            kernel_init=kernel_init,
+            activation=nn.relu,
+        )(input_)
+
+        value_2 = MLP(
+            layer_sizes=self.hidden_layer_size + (1,),
+            kernel_init=kernel_init,
+            activation=nn.relu,
+        )(input_)
+
+        return jnp.concatenate([value_1, value_2], axis=-1)
 
 
 class Discriminator(nn.Module):
@@ -52,7 +50,11 @@ class Discriminator(nn.Module):
 
     @nn.compact
     def __call__(self, obs: Observation) -> jnp.ndarray:
-        return MLP(self.hidden_layer_size + (self.num_skills,))(obs)
+        return MLP(
+            layer_sizes=self.hidden_layer_size + (self.num_skills,),
+            kernel_init=nn.initializers.variance_scaling(1.0, "fan_in", "uniform"),
+            activation=nn.relu,
+        )(obs)
 
 
 def make_diayn_networks(
diff --git a/qdax/core/neuroevolution/networks/sac_networks.py b/qdax/core/neuroevolution/networks/sac_networks.py
index d5d2003a..a236afd4 100644
--- a/qdax/core/neuroevolution/networks/sac_networks.py
+++ b/qdax/core/neuroevolution/networks/sac_networks.py
@@ -3,36 +3,20 @@
 import flax.linen as nn
 import jax.numpy as jnp
 
+from qdax.core.neuroevolution.networks.networks import MLP
 from qdax.custom_types import Action, Observation
 
 
-class MLP(nn.Module):
-    features: Tuple[int, ...]
-
-    @nn.compact
-    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
-        for feat in self.features[:-1]:
-            x = nn.relu(
-                nn.Dense(
-                    feat,
-                    kernel_init=nn.initializers.variance_scaling(
-                        1.0, "fan_in", "uniform"
-                    ),
-                )(x)
-            )
-        return nn.Dense(
-            self.features[-1],
-            kernel_init=nn.initializers.variance_scaling(1.0, "fan_in", "uniform"),
-        )(x)
-
-
 class Actor(nn.Module):
     action_size: int
     hidden_layer_size: Tuple[int, ...]
 
     @nn.compact
     def __call__(self, obs: Observation) -> jnp.ndarray:
-        return MLP(self.hidden_layer_size + (2 * self.action_size,))(obs)
+        return MLP(
+            layer_sizes=self.hidden_layer_size + (2 * self.action_size,),
+            kernel_init=nn.initializers.variance_scaling(1.0, "fan_in", "uniform"),
+        )(obs)
 
 
 class Critic(nn.Module):
@@ -41,9 +25,22 @@ class Critic(nn.Module):
     @nn.compact
     def __call__(self, obs: Observation, action: Action) -> jnp.ndarray:
         input_ = jnp.concatenate([obs, action], axis=-1)
-        value1 = MLP(self.hidden_layer_size + (1,))(input_)
-        value2 = MLP(self.hidden_layer_size + (1,))(input_)
-        return jnp.concatenate([value1, value2], axis=-1)
+
+        kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "uniform")
+
+        value_1 = MLP(
+            layer_sizes=self.hidden_layer_size + (1,),
+            kernel_init=kernel_init,
+            activation=nn.relu,
+        )(input_)
+
+        value_2 = MLP(
+            layer_sizes=self.hidden_layer_size + (1,),
+            kernel_init=kernel_init,
+            activation=nn.relu,
+        )(input_)
+
+        return jnp.concatenate([value_1, value_2], axis=-1)
 
 
 def make_sac_networks(
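Note: all three files now depend on the shared MLP module imported from qdax/core/neuroevolution/networks/networks.py, which is not part of this diff. Below is a minimal sketch of such a module, inferred only from how it is called above (layer_sizes, kernel_init, activation, final_activation); the defaults and docstring are assumptions, not the actual library code.

    from typing import Any, Callable, Optional, Tuple

    import flax.linen as nn
    import jax
    import jax.numpy as jnp


    class MLP(nn.Module):
        """Sketch of the shared MLP assumed by the refactor above."""

        layer_sizes: Tuple[int, ...]
        activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
        kernel_init: Callable[..., Any] = jax.nn.initializers.lecun_uniform()
        final_activation: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None

        @nn.compact
        def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
            # Hidden layers: Dense followed by the activation.
            for size in self.layer_sizes[:-1]:
                x = self.activation(nn.Dense(size, kernel_init=self.kernel_init)(x))
            # Output layer: Dense, optionally followed by final_activation
            # (used by the DADS dynamics network above to keep a trailing relu).
            x = nn.Dense(self.layer_sizes[-1], kernel_init=self.kernel_init)(x)
            if self.final_activation is not None:
                x = self.final_activation(x)
            return x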