Change remaining jax.random.PRNGKey to jax.random.key
Lookatator committed Sep 13, 2024
1 parent 7ba39e3 commit dc31b0c
Showing 5 changed files with 22 additions and 22 deletions.
2 changes: 1 addition & 1 deletion examples/dcrlme.ipynb
@@ -154,7 +154,7 @@
"source": [
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.random.key(seed)\n",
"\n",
"# Init environment\n",
"env = environments.create(env_name, episode_length=episode_length)\n",
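Note (not part of the diff): the commit swaps the legacy key constructor for the typed-key constructor added around JAX 0.4.16. A minimal standalone sketch of the difference, with a placeholder seed of 0:

    import jax

    # Legacy constructor: a uint32 array of shape (2,)
    legacy_key = jax.random.PRNGKey(0)

    # New-style constructor: a scalar typed key array
    key = jax.random.key(0)

    # Both are consumed the same way by the rest of the random API
    key, subkey = jax.random.split(key)
    samples = jax.random.uniform(subkey, (3,))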
29 changes: 16 additions & 13 deletions qdax/core/neuroevolution/networks/seq2seq_networks.py
@@ -15,9 +15,6 @@
import numpy as np
from flax import linen as nn

-Array = Any
-PRNGKey = jax.Array


class EncoderLSTM(nn.Module):
"""EncoderLSTM Module wrapped in a lifted scan transform."""
@@ -31,14 +28,16 @@ class EncoderLSTM(nn.Module):
)
@nn.compact
def __call__(
-self, carry: Tuple[Array, Array], x: Array
-) -> Tuple[Tuple[Array, Array], Array]:
+self, carry: Tuple[jax.Array, jax.Array], x: jax.Array
+) -> Tuple[Tuple[jax.Array, jax.Array], jax.Array]:
"""Applies the module."""
lstm_state, is_eos = carry
features = lstm_state[0].shape[-1]
new_lstm_state, y = nn.LSTMCell(features)(lstm_state, x)

-def select_carried_state(new_state: Array, old_state: Array) -> Array:
+def select_carried_state(
+new_state: jax.Array, old_state: jax.Array
+) -> jax.Array:
return jnp.where(is_eos[:, np.newaxis], old_state, new_state)

# LSTM state is a tuple (c, h).
@@ -49,7 +48,9 @@ def select_carried_state(new_state: Array, old_state: Array) -> Array:
return (carried_lstm_state, is_eos), y

@staticmethod
-def initialize_carry(batch_size: int, hidden_size: int) -> Tuple[Array, Array]:
+def initialize_carry(
+batch_size: int, hidden_size: int
+) -> Tuple[jax.Array, jax.Array]:
# Use a dummy key since the default state init fn is just zeros.
return nn.LSTMCell(hidden_size, parent=None).initialize_carry( # type: ignore
jax.random.key(0), (batch_size, hidden_size)
@@ -62,7 +63,7 @@ class Encoder(nn.Module):
hidden_size: int

@nn.compact
-def __call__(self, inputs: Array) -> Array:
+def __call__(self, inputs: jax.Array) -> jax.Array:
batch_size = inputs.shape[0]
lstm = EncoderLSTM(name="encoder_lstm")
init_lstm_state = lstm.initialize_carry(batch_size, self.hidden_size)
@@ -95,7 +96,7 @@ class DecoderLSTM(nn.Module):
split_rngs={"params": False, "lstm": True},
)
@nn.compact
-def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array:
+def __call__(self, carry: Tuple[jax.Array, jax.Array], x: jax.Array) -> jax.Array:
"""Applies the DecoderLSTM model."""

lstm_state, last_prediction = carry
@@ -124,7 +125,9 @@ class Decoder(nn.Module):
obs_size: int

@nn.compact
-def __call__(self, inputs: Array, init_state: Any) -> Tuple[Array, Array]:
+def __call__(
+self, inputs: jax.Array, init_state: Any
+) -> Tuple[jax.Array, jax.Array]:
"""Applies the decoder model.
Args:
@@ -166,8 +169,8 @@ def setup(self) -> None:

@nn.compact
def __call__(
-self, encoder_inputs: Array, decoder_inputs: Array
-) -> Tuple[Array, Array]:
+self, encoder_inputs: jax.Array, decoder_inputs: jax.Array
+) -> Tuple[jax.Array, jax.Array]:
"""Applies the seq2seq model.
Args:
@@ -194,7 +197,7 @@ def __call__(

return logits, predictions

-def encode(self, encoder_inputs: Array) -> Array:
+def encode(self, encoder_inputs: jax.Array) -> jax.Array:
# encode inputs
init_decoder_state = self.encoder(encoder_inputs)
final_output, _hidden_state = init_decoder_state
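Aside (not code from the repository): with the module-level Array = Any and PRNGKey = jax.Array aliases removed, annotations reference jax.Array directly. A self-contained sketch of the same pattern as select_carried_state above, with is_eos passed explicitly rather than closed over:

    import jax
    import jax.numpy as jnp
    import numpy as np

    def select_carried_state(
        new_state: jax.Array, old_state: jax.Array, is_eos: jax.Array
    ) -> jax.Array:
        # Keep the previous carry for sequences that already emitted EOS.
        return jnp.where(is_eos[:, np.newaxis], old_state, new_state)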
9 changes: 3 additions & 6 deletions qdax/utils/train_seq2seq.py
@@ -19,9 +19,6 @@
from qdax.custom_types import Params, RNGKey
from qdax.environments.bd_extractors import AuroraExtraInfoNormalization

-Array = Any
-PRNGKey = Any


def get_model(
obs_size: int, teacher_force: bool = False, hidden_size: int = 10
@@ -40,7 +37,7 @@ def get_model(


def get_initial_params(
-model: Seq2seq, random_key: PRNGKey, encoder_input_shape: Tuple[int, ...]
+model: Seq2seq, random_key: RNGKey, encoder_input_shape: Tuple[int, ...]
) -> Dict[str, Any]:
"""
Returns the initial parameters of a seq2seq model.
@@ -62,8 +59,8 @@ def get_initial_params(
@jax.jit
def train_step(
state: train_state.TrainState,
-batch: Array,
-lstm_random_key: PRNGKey,
+batch: jax.Array,
+lstm_random_key: RNGKey,
) -> Tuple[train_state.TrainState, Dict[str, float]]:
"""
Trains for one step.
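Aside (a standalone sketch, not qdax code): a jitted step like train_step takes one PRNG key per call, so the usual pattern is to split a fresh subkey outside the jitted function on each iteration. Illustrated here with a toy update in place of the real training step:

    import jax

    @jax.jit
    def noisy_update(params: jax.Array, key: jax.Array) -> jax.Array:
        # Toy stand-in for a jitted training step that needs randomness.
        return params + 0.01 * jax.random.normal(key, params.shape)

    params = jax.numpy.zeros(4)
    key = jax.random.key(0)
    for _ in range(3):
        key, subkey = jax.random.split(key)
        params = noisy_update(params, subkey)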
2 changes: 1 addition & 1 deletion tests/baselines_test/dcrlme_test.py
@@ -61,7 +61,7 @@ def test_dcrlme() -> None:
policy_delay = 2

# Init a random key
-random_key = jax.random.PRNGKey(seed)
+random_key = jax.random.key(seed)

# Init environment
env = environments.create(env_name, episode_length=episode_length)
2 changes: 1 addition & 1 deletion tests/utils_test/uncertainty_metrics_test.py
@@ -25,7 +25,7 @@ def test_uncertainty_metrics() -> None:
genotype_dim = 8

# Init a random key
-random_key = jax.random.PRNGKey(seed)
+random_key = jax.random.key(seed)

# First, init a deterministic environment
init_policies = jax.random.uniform(

0 comments on commit dc31b0c
