diff --git a/examples/dcrlme.ipynb b/examples/dcrlme.ipynb
index eae0e6b3..057ef0c4 100644
--- a/examples/dcrlme.ipynb
+++ b/examples/dcrlme.ipynb
@@ -154,7 +154,7 @@
    "source": [
     "\n",
     "# Init a random key\n",
-    "random_key = jax.random.PRNGKey(seed)\n",
+    "random_key = jax.random.key(seed)\n",
     "\n",
     "# Init environment\n",
     "env = environments.create(env_name, episode_length=episode_length)\n",
diff --git a/qdax/core/neuroevolution/networks/seq2seq_networks.py b/qdax/core/neuroevolution/networks/seq2seq_networks.py
index 476705fa..a4bb2272 100644
--- a/qdax/core/neuroevolution/networks/seq2seq_networks.py
+++ b/qdax/core/neuroevolution/networks/seq2seq_networks.py
@@ -15,9 +15,6 @@ import numpy as np
 from flax import linen as nn
 
-Array = Any
-PRNGKey = jax.Array
-
 
 class EncoderLSTM(nn.Module):
     """EncoderLSTM Module wrapped in a lifted scan transform."""
 
@@ -31,14 +28,16 @@ class EncoderLSTM(nn.Module):
     )
     @nn.compact
     def __call__(
-        self, carry: Tuple[Array, Array], x: Array
-    ) -> Tuple[Tuple[Array, Array], Array]:
+        self, carry: Tuple[jax.Array, jax.Array], x: jax.Array
+    ) -> Tuple[Tuple[jax.Array, jax.Array], jax.Array]:
         """Applies the module."""
         lstm_state, is_eos = carry
         features = lstm_state[0].shape[-1]
         new_lstm_state, y = nn.LSTMCell(features)(lstm_state, x)
 
-        def select_carried_state(new_state: Array, old_state: Array) -> Array:
+        def select_carried_state(
+            new_state: jax.Array, old_state: jax.Array
+        ) -> jax.Array:
             return jnp.where(is_eos[:, np.newaxis], old_state, new_state)
 
         # LSTM state is a tuple (c, h).
@@ -49,7 +48,9 @@ def select_carried_state(new_state: Array, old_state: Array) -> Array:
         return (carried_lstm_state, is_eos), y
 
     @staticmethod
-    def initialize_carry(batch_size: int, hidden_size: int) -> Tuple[Array, Array]:
+    def initialize_carry(
+        batch_size: int, hidden_size: int
+    ) -> Tuple[jax.Array, jax.Array]:
         # Use a dummy key since the default state init fn is just zeros.
         return nn.LSTMCell(hidden_size, parent=None).initialize_carry(  # type: ignore
             jax.random.key(0), (batch_size, hidden_size)
@@ -62,7 +63,7 @@ class Encoder(nn.Module):
     hidden_size: int
 
     @nn.compact
-    def __call__(self, inputs: Array) -> Array:
+    def __call__(self, inputs: jax.Array) -> jax.Array:
         batch_size = inputs.shape[0]
         lstm = EncoderLSTM(name="encoder_lstm")
         init_lstm_state = lstm.initialize_carry(batch_size, self.hidden_size)
@@ -95,7 +96,7 @@ class DecoderLSTM(nn.Module):
         split_rngs={"params": False, "lstm": True},
     )
     @nn.compact
-    def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array:
+    def __call__(self, carry: Tuple[jax.Array, jax.Array], x: jax.Array) -> jax.Array:
         """Applies the DecoderLSTM model."""
         lstm_state, last_prediction = carry
 
@@ -124,7 +125,9 @@ class Decoder(nn.Module):
     obs_size: int
 
     @nn.compact
-    def __call__(self, inputs: Array, init_state: Any) -> Tuple[Array, Array]:
+    def __call__(
+        self, inputs: jax.Array, init_state: Any
+    ) -> Tuple[jax.Array, jax.Array]:
         """Applies the decoder model.
 
         Args:
@@ -166,8 +169,8 @@ def setup(self) -> None:
 
     @nn.compact
     def __call__(
-        self, encoder_inputs: Array, decoder_inputs: Array
-    ) -> Tuple[Array, Array]:
+        self, encoder_inputs: jax.Array, decoder_inputs: jax.Array
+    ) -> Tuple[jax.Array, jax.Array]:
         """Applies the seq2seq model.
 
         Args:
@@ -194,7 +197,7 @@ def __call__(
 
         return logits, predictions
 
-    def encode(self, encoder_inputs: Array) -> Array:
+    def encode(self, encoder_inputs: jax.Array) -> jax.Array:
         # encode inputs
         init_decoder_state = self.encoder(encoder_inputs)
         final_output, _hidden_state = init_decoder_state
diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py
index bd9570a9..fa7825b0 100644
--- a/qdax/utils/train_seq2seq.py
+++ b/qdax/utils/train_seq2seq.py
@@ -19,9 +19,6 @@ from qdax.custom_types import Params, RNGKey
 from qdax.environments.bd_extractors import AuroraExtraInfoNormalization
 
-Array = Any
-PRNGKey = Any
-
 
 def get_model(
     obs_size: int, teacher_force: bool = False, hidden_size: int = 10
 ) -> Seq2seq:
@@ -40,7 +37,7 @@ def get_model(
 
 
 def get_initial_params(
-    model: Seq2seq, random_key: PRNGKey, encoder_input_shape: Tuple[int, ...]
+    model: Seq2seq, random_key: RNGKey, encoder_input_shape: Tuple[int, ...]
 ) -> Dict[str, Any]:
     """
     Returns the initial parameters of a seq2seq model.
@@ -62,8 +59,8 @@ def get_initial_params(
 @jax.jit
 def train_step(
     state: train_state.TrainState,
-    batch: Array,
-    lstm_random_key: PRNGKey,
+    batch: jax.Array,
+    lstm_random_key: RNGKey,
 ) -> Tuple[train_state.TrainState, Dict[str, float]]:
     """
     Trains for one step.
diff --git a/tests/baselines_test/dcrlme_test.py b/tests/baselines_test/dcrlme_test.py
index 05304944..1bc9688d 100644
--- a/tests/baselines_test/dcrlme_test.py
+++ b/tests/baselines_test/dcrlme_test.py
@@ -61,7 +61,7 @@ def test_dcrlme() -> None:
     policy_delay = 2
 
     # Init a random key
-    random_key = jax.random.PRNGKey(seed)
+    random_key = jax.random.key(seed)
 
     # Init environment
     env = environments.create(env_name, episode_length=episode_length)
diff --git a/tests/utils_test/uncertainty_metrics_test.py b/tests/utils_test/uncertainty_metrics_test.py
index d49e2527..3f2caea1 100644
--- a/tests/utils_test/uncertainty_metrics_test.py
+++ b/tests/utils_test/uncertainty_metrics_test.py
@@ -25,7 +25,7 @@ def test_uncertainty_metrics() -> None:
     genotype_dim = 8
 
     # Init a random key
-    random_key = jax.random.PRNGKey(seed)
+    random_key = jax.random.key(seed)
 
     # First, init a deterministic environment
     init_policies = jax.random.uniform(
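# --- Illustrative sketch of the API migration above (not part of the patch).
# The diff replaces the legacy `jax.random.PRNGKey` constructor with
# `jax.random.key`, which produces a typed scalar key instead of a raw
# uint32 array of shape (2,); it assumes a JAX version that ships
# `jax.random.key` (0.4.16 or later). Variable names below are made up
# for the example. Downstream `jax.random` calls accept either key style.
import jax

legacy_key = jax.random.PRNGKey(0)  # uint32 array, shape (2,)
typed_key = jax.random.key(0)       # scalar with a dedicated key dtype

# Usage is unchanged for either key style:
typed_key, subkey = jax.random.split(typed_key)
sample = jax.random.uniform(subkey, (3,))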