diff --git a/README.md b/README.md index e7955450..052eb74a 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ min_bd = 0.0 max_bd = 1.0 # Init a random key -random_key = jax.random.PRNGKey(seed) +random_key = jax.random.key(seed) # Init population of controllers random_key, subkey = jax.random.split(random_key) diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb index 09072715..55a1db53 100644 --- a/examples/aurora.ipynb +++ b/examples/aurora.ipynb @@ -146,7 +146,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", diff --git a/examples/cmaes.ipynb b/examples/cmaes.ipynb index ba059cdc..98ccc5bc 100644 --- a/examples/cmaes.ipynb +++ b/examples/cmaes.ipynb @@ -178,7 +178,7 @@ "outputs": [], "source": [ "state = cmaes.init()\n", - "random_key = jax.random.PRNGKey(0)" + "random_key = jax.random.key(0)" ] }, { diff --git a/examples/cmame.ipynb b/examples/cmame.ipynb index f7aa235e..d42dadb8 100644 --- a/examples/cmame.ipynb +++ b/examples/cmame.ipynb @@ -205,7 +205,7 @@ "metadata": {}, "outputs": [], "source": [ - "random_key = jax.random.PRNGKey(0)\n", + "random_key = jax.random.key(0)\n", "# in CMA-ME settings (from the paper), there is no init population\n", "# we multipy by zero to reproduce this setting\n", "initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n", diff --git a/examples/cmamega.ipynb b/examples/cmamega.ipynb index d37bf80e..1a8eeafb 100644 --- a/examples/cmamega.ipynb +++ b/examples/cmamega.ipynb @@ -198,7 +198,7 @@ "metadata": {}, "outputs": [], "source": [ - "random_key = jax.random.PRNGKey(0)\n", + "random_key = jax.key(0)\n", "# no initial population - give all the same value as emitter init value\n", "initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n", "\n", diff --git a/examples/dads.ipynb b/examples/dads.ipynb index 57d1df05..50d99d56 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -163,7 +163,7 @@ " eval_metrics=True,\n", ")\n", "\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "env_state = jax.jit(env.reset)(rng=key)\n", "eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n", "\n", @@ -499,7 +499,7 @@ "outputs": [], "source": [ "rollout = []\n", - "random_key = jax.random.PRNGKey(seed=1)\n", + "random_key = jax.random.key(seed=1)\n", "state = jit_env_reset(rng=random_key)\n", "while not state.done:\n", " rollout.append(state)\n", diff --git a/examples/diayn.ipynb b/examples/diayn.ipynb index cdee8b4b..d13ccad7 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -162,7 +162,7 @@ " eval_metrics=True,\n", ")\n", "\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "env_state = jax.jit(env.reset)(rng=key)\n", "eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n", "\n", @@ -490,7 +490,7 @@ "outputs": [], "source": [ "rollout = []\n", - "random_key = jax.random.PRNGKey(seed=1)\n", + "random_key = jax.random.key(seed=1)\n", "state = jit_env_reset(rng=random_key)\n", "while not state.done:\n", " rollout.append(state)\n", diff --git a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb index 7a7b5296..d2b158da 100644 --- a/examples/distributed_mapelites.ipynb +++ b/examples/distributed_mapelites.ipynb @@ -176,7 +176,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", diff --git a/examples/jumanji_snake.ipynb b/examples/jumanji_snake.ipynb index 78ec01c8..078c1c65 100644 --- a/examples/jumanji_snake.ipynb +++ b/examples/jumanji_snake.ipynb @@ -113,7 +113,7 @@ "env = jumanji.make('Snake-v1')\n", "\n", "# Reset your (jit-able) environment\n", - "key = jax.random.PRNGKey(0)\n", + "key = jax.random.key(0)\n", "state, timestep = jax.jit(env.reset)(key)\n", "\n", "# Interact with the (jit-able) environment\n", @@ -137,7 +137,7 @@ "outputs": [], "source": [ "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# get number of actions\n", "num_actions = env.action_spec().maximum + 1\n", diff --git a/examples/mapelites.ipynb b/examples/mapelites.ipynb index b7a0a256..626fb5de 100644 --- a/examples/mapelites.ipynb +++ b/examples/mapelites.ipynb @@ -133,7 +133,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.Key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -494,7 +494,7 @@ "outputs": [], "source": [ "rollout = []\n", - "rng = jax.random.PRNGKey(seed=1)\n", + "rng = jax.random.key(seed=1)\n", "state = jit_env_reset(rng=rng)\n", "while not state.done:\n", " rollout.append(state)\n", diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index 86deebc4..068acab3 100644 --- a/examples/me_sac_pbt.ipynb +++ b/examples/me_sac_pbt.ipynb @@ -150,7 +150,7 @@ "outputs": [], "source": [ "%%time\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "key, subkey = jax.random.split(key)\n", "env_states = jax.jit(env.reset)(rng=subkey)\n", "eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey)" @@ -504,7 +504,7 @@ "%%time\n", "rollout = []\n", "\n", - "rng = jax.random.PRNGKey(seed=1)\n", + "rng = jax.random.key(seed=1)\n", "env_state = jax.jit(env.reset)(rng=rng)\n", "\n", "training_state, env_state = jax.tree_map(\n", diff --git a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index f72ccda1..71289a96 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -155,7 +155,7 @@ "outputs": [], "source": [ "%%time\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "key, subkey = jax.random.split(key)\n", "env_states = jax.jit(env.reset)(rng=subkey)\n", "eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey)" diff --git a/examples/mees.ipynb b/examples/mees.ipynb index 3cbf890f..c09c7132 100644 --- a/examples/mees.ipynb +++ b/examples/mees.ipynb @@ -151,7 +151,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", diff --git a/examples/mels.ipynb b/examples/mels.ipynb index dae02a95..ed5a7c7a 100644 --- a/examples/mels.ipynb +++ b/examples/mels.ipynb @@ -140,7 +140,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -509,7 +509,7 @@ "outputs": [], "source": [ "rollout = []\n", - "rng = jax.random.PRNGKey(seed=1)\n", + "rng = jax.random.key(seed=1)\n", "state = jit_env_reset(rng=rng)\n", "while not state.done:\n", " rollout.append(state)\n", diff --git a/examples/mome.ipynb b/examples/mome.ipynb index 0d005dbe..217f94be 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -205,7 +205,7 @@ "outputs": [], "source": [ "# initial population\n", - "random_key = jax.random.PRNGKey(42)\n", + "random_key = jax.random.key(42)\n", "random_key, subkey = jax.random.split(random_key)\n", "genotypes = jax.random.uniform(\n", " random_key, (batch_size, num_variables), minval=minval, maxval=maxval, dtype=jnp.float32\n", diff --git a/examples/nsga2_spea2.ipynb b/examples/nsga2_spea2.ipynb index b418bd31..2d157323 100644 --- a/examples/nsga2_spea2.ipynb +++ b/examples/nsga2_spea2.ipynb @@ -181,7 +181,7 @@ "outputs": [], "source": [ "# Initial population\n", - "random_key = jax.random.PRNGKey(0)\n", + "random_key = jax.random.key(0)\n", "random_key, subkey = jax.random.split(random_key)\n", "genotypes = jax.random.uniform(\n", " subkey, (batch_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32\n", diff --git a/examples/omgmega.ipynb b/examples/omgmega.ipynb index bde9f5ed..900fc812 100644 --- a/examples/omgmega.ipynb +++ b/examples/omgmega.ipynb @@ -184,7 +184,7 @@ "metadata": {}, "outputs": [], "source": [ - "random_key = jax.random.PRNGKey(0)\n", + "random_key = jax.random.key(0)\n", "\n", "# defines the population\n", "random_key, subkey = jax.random.split(random_key)\n", diff --git a/examples/pga_aurora.ipynb b/examples/pga_aurora.ipynb index 56ed4f01..c3c00ae5 100644 --- a/examples/pga_aurora.ipynb +++ b/examples/pga_aurora.ipynb @@ -164,7 +164,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", diff --git a/examples/pgame.ipynb b/examples/pgame.ipynb index 31e4f831..c5419a3f 100644 --- a/examples/pgame.ipynb +++ b/examples/pgame.ipynb @@ -143,7 +143,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", diff --git a/examples/qdpg.ipynb b/examples/qdpg.ipynb index 8c47ffe6..a30c3be3 100644 --- a/examples/qdpg.ipynb +++ b/examples/qdpg.ipynb @@ -157,7 +157,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", diff --git a/examples/sac_pbt.ipynb b/examples/sac_pbt.ipynb index 915cc272..b1ab220b 100644 --- a/examples/sac_pbt.ipynb +++ b/examples/sac_pbt.ipynb @@ -210,7 +210,7 @@ "outputs": [], "source": [ "# %%time\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "key, *keys = jax.random.split(key, num=1 + num_devices)\n", "keys = jnp.stack(keys)\n", "env_states, eval_env_first_states = jax.pmap(\n", @@ -518,7 +518,7 @@ "%%time\n", "rollout = []\n", "\n", - "rng = jax.random.PRNGKey(seed=1)\n", + "rng = jax.random.key(seed=1)\n", "env_state = jax.jit(env.reset)(rng=rng)\n", "\n", "training_state, env_state = jax.tree_util.tree_map(\n", diff --git a/examples/scripts/me_example.py b/examples/scripts/me_example.py index 433bc1d2..294cca8e 100644 --- a/examples/scripts/me_example.py +++ b/examples/scripts/me_example.py @@ -26,7 +26,7 @@ def run_me() -> None: max_bd = 1.0 # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init population of controllers random_key, subkey = jax.random.split(random_key) diff --git a/examples/smerl.ipynb b/examples/smerl.ipynb index 0e332192..ede905f9 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -168,7 +168,7 @@ " eval_metrics=True,\n", ")\n", "\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.Key(seed)\n", "env_state = jax.jit(env.reset)(rng=key)\n", "eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n", "\n", @@ -504,7 +504,7 @@ "outputs": [], "source": [ "rollout = []\n", - "random_key = jax.random.PRNGKey(seed=1)\n", + "random_key = jax.random.key(seed=1)\n", "state = jit_env_reset(rng=random_key)\n", "while not state.done:\n", " rollout.append(state)\n", diff --git a/examples/td3_pbt.ipynb b/examples/td3_pbt.ipynb index d2d98f85..5de9e24b 100644 --- a/examples/td3_pbt.ipynb +++ b/examples/td3_pbt.ipynb @@ -181,7 +181,7 @@ "outputs": [], "source": [ "%%time\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "key, *keys = jax.random.split(key, num=1 + num_devices)\n", "keys = jnp.stack(keys)\n", "env_states, eval_env_first_states = jax.pmap(\n", diff --git a/qdax/core/containers/mapelites_repertoire.py b/qdax/core/containers/mapelites_repertoire.py index b473d4b3..87584eb3 100644 --- a/qdax/core/containers/mapelites_repertoire.py +++ b/qdax/core/containers/mapelites_repertoire.py @@ -63,7 +63,7 @@ def compute_cvt_centroids( init="k-means++", n_clusters=num_centroids, n_init=1, - random_state=RandomState(subkey), + random_state=RandomState(jax.random.key_data(subkey)), ) k_means.fit(x) centroids = k_means.cluster_centers_ diff --git a/qdax/core/neuroevolution/networks/seq2seq_networks.py b/qdax/core/neuroevolution/networks/seq2seq_networks.py index 3cb52a3e..476705fa 100644 --- a/qdax/core/neuroevolution/networks/seq2seq_networks.py +++ b/qdax/core/neuroevolution/networks/seq2seq_networks.py @@ -16,7 +16,7 @@ from flax import linen as nn Array = Any -PRNGKey = Any +PRNGKey = jax.Array class EncoderLSTM(nn.Module): @@ -52,7 +52,7 @@ def select_carried_state(new_state: Array, old_state: Array) -> Array: def initialize_carry(batch_size: int, hidden_size: int) -> Tuple[Array, Array]: # Use a dummy key since the default state init fn is just zeros. return nn.LSTMCell(hidden_size, parent=None).initialize_carry( # type: ignore - jax.random.PRNGKey(0), (batch_size, hidden_size) + jax.random.key(0), (batch_size, hidden_size) ) diff --git a/qdax/tasks/README.md b/qdax/tasks/README.md index 56528323..bb3a09d8 100644 --- a/qdax/tasks/README.md +++ b/qdax/tasks/README.md @@ -19,7 +19,7 @@ Notes: import jax from qdax.tasks.arm import arm_scoring_function -random_key = jax.random.PRNGKey(0) +random_key = jax.random.key(0) # Get scoring function scoring_fn = arm_scoring_function @@ -56,7 +56,7 @@ desc_size = 2 import jax from qdax.tasks.standard_functions import sphere_scoring_function -random_key = jax.random.PRNGKey(0) +random_key = jax.random.key(0) # Get scoring function scoring_fn = sphere_scoring_function @@ -98,7 +98,7 @@ desc_size = 2 import jax from qdax.tasks.hypervolume_functions import square_scoring_function -random_key = jax.random.PRNGKey(0) +random_key = jax.random.key(0) # Get scoring function scoring_fn = square_scoring_function diff --git a/tests/baselines_test/cmame_test.py b/tests/baselines_test/cmame_test.py index 2dc6fa10..82d7e54a 100644 --- a/tests/baselines_test/cmame_test.py +++ b/tests/baselines_test/cmame_test.py @@ -81,7 +81,7 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: max_fitness = jnp.max(adjusted_fitness) return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} - random_key = jax.random.PRNGKey(0) + random_key = jax.random.key(0) initial_population = ( jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.0 ) diff --git a/tests/baselines_test/cmamega_test.py b/tests/baselines_test/cmamega_test.py index 5bfdfd58..02fdad13 100644 --- a/tests/baselines_test/cmamega_test.py +++ b/tests/baselines_test/cmamega_test.py @@ -95,7 +95,7 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: max_fitness = jnp.max(adjusted_fitness) return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} - random_key = jax.random.PRNGKey(0) + random_key = jax.random.key(0) initial_population = jax.random.uniform( random_key, shape=(batch_size, num_dimensions) ) diff --git a/tests/baselines_test/dads_smerl_test.py b/tests/baselines_test/dads_smerl_test.py index 2a8d3d1f..2ecc25f6 100644 --- a/tests/baselines_test/dads_smerl_test.py +++ b/tests/baselines_test/dads_smerl_test.py @@ -72,7 +72,7 @@ def test_dads_smerl() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/dads_test.py b/tests/baselines_test/dads_test.py index 77094ffd..76da834e 100644 --- a/tests/baselines_test/dads_test.py +++ b/tests/baselines_test/dads_test.py @@ -66,7 +66,7 @@ def test_dads() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/diayn_smerl_test.py b/tests/baselines_test/diayn_smerl_test.py index abd94b45..f06a4298 100644 --- a/tests/baselines_test/diayn_smerl_test.py +++ b/tests/baselines_test/diayn_smerl_test.py @@ -69,7 +69,7 @@ def test_diayn_smerl() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/diayn_test.py b/tests/baselines_test/diayn_test.py index 3492d9c1..856e0174 100644 --- a/tests/baselines_test/diayn_test.py +++ b/tests/baselines_test/diayn_test.py @@ -62,7 +62,7 @@ def test_diayn() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/ga_test.py b/tests/baselines_test/ga_test.py index a1eb1b51..619c76e2 100644 --- a/tests/baselines_test/ga_test.py +++ b/tests/baselines_test/ga_test.py @@ -71,7 +71,7 @@ def scoring_fn( return fitnesses, {}, random_key # initial population - random_key = jax.random.PRNGKey(42) + random_key = jax.random.key(42) random_key, subkey = jax.random.split(random_key) genotypes = jax.random.uniform( subkey, diff --git a/tests/baselines_test/me_pbt_sac_test.py b/tests/baselines_test/me_pbt_sac_test.py index 5058bad6..b2365382 100644 --- a/tests/baselines_test/me_pbt_sac_test.py +++ b/tests/baselines_test/me_pbt_sac_test.py @@ -69,7 +69,7 @@ def test_me_pbt_sac() -> None: ) min_bd, max_bd = env.behavior_descriptor_limits - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, subkey = jax.random.split(key) eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey) diff --git a/tests/baselines_test/me_pbt_td3_test.py b/tests/baselines_test/me_pbt_td3_test.py index 39c3e942..53e2802e 100644 --- a/tests/baselines_test/me_pbt_td3_test.py +++ b/tests/baselines_test/me_pbt_td3_test.py @@ -69,7 +69,7 @@ def test_me_pbt_td3() -> None: ) min_bd, max_bd = env.behavior_descriptor_limits - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, subkey = jax.random.split(key) eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey) diff --git a/tests/baselines_test/mees_test.py b/tests/baselines_test/mees_test.py index d1913b02..b5d56f1f 100644 --- a/tests/baselines_test/mees_test.py +++ b/tests/baselines_test/mees_test.py @@ -47,7 +47,7 @@ def test_mees() -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) diff --git a/tests/baselines_test/omgmega_test.py b/tests/baselines_test/omgmega_test.py index ad51c7ae..632dc993 100644 --- a/tests/baselines_test/omgmega_test.py +++ b/tests/baselines_test/omgmega_test.py @@ -82,7 +82,7 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: max_fitness = jnp.max(adjusted_fitness) return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} - random_key = jax.random.PRNGKey(0) + random_key = jax.random.key(0) # defines the population random_key, subkey = jax.random.split(random_key) diff --git a/tests/baselines_test/pbt_sac_test.py b/tests/baselines_test/pbt_sac_test.py index db7dc69e..537d4a26 100644 --- a/tests/baselines_test/pbt_sac_test.py +++ b/tests/baselines_test/pbt_sac_test.py @@ -76,7 +76,7 @@ def init_environments(random_key): # type: ignore return env_states, eval_env_first_states - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, *keys = jax.random.split(key, num=1 + num_devices) keys = jnp.stack(keys) env_states, eval_env_first_states = jax.pmap( diff --git a/tests/baselines_test/pbt_td3_test.py b/tests/baselines_test/pbt_td3_test.py index 0be68277..813deb60 100644 --- a/tests/baselines_test/pbt_td3_test.py +++ b/tests/baselines_test/pbt_td3_test.py @@ -74,7 +74,7 @@ def init_environments(random_key): # type: ignore return env_states, eval_env_first_states - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, *keys = jax.random.split(key, num=1 + num_devices) keys = jnp.stack(keys) env_states, eval_env_first_states = jax.pmap( diff --git a/tests/baselines_test/pgame_test.py b/tests/baselines_test/pgame_test.py index 0490a481..639f1a9d 100644 --- a/tests/baselines_test/pgame_test.py +++ b/tests/baselines_test/pgame_test.py @@ -54,7 +54,7 @@ def test_pgame() -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) diff --git a/tests/baselines_test/qdpg_test.py b/tests/baselines_test/qdpg_test.py index 704416a4..7f1868f6 100644 --- a/tests/baselines_test/qdpg_test.py +++ b/tests/baselines_test/qdpg_test.py @@ -69,7 +69,7 @@ def test_qdpg() -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) diff --git a/tests/baselines_test/sac_test.py b/tests/baselines_test/sac_test.py index 8c26b510..57554b92 100644 --- a/tests/baselines_test/sac_test.py +++ b/tests/baselines_test/sac_test.py @@ -53,7 +53,7 @@ def test_sac() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/td3_test.py b/tests/baselines_test/td3_test.py index ff5d338c..55c09811 100644 --- a/tests/baselines_test/td3_test.py +++ b/tests/baselines_test/td3_test.py @@ -49,7 +49,7 @@ def test_td3() -> None: auto_reset=True, eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, subkey = jax.random.split(key) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index 4bbb9d82..9ec55a78 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -75,7 +75,7 @@ def test_aurora(env_name: str, batch_size: int) -> None: log_freq = 5 # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init environment env, policy_network, scoring_fn, random_key = create_default_brax_task_components( diff --git a/tests/core_test/cmaes_test.py b/tests/core_test/cmaes_test.py index 16321fd4..d334fcad 100644 --- a/tests/core_test/cmaes_test.py +++ b/tests/core_test/cmaes_test.py @@ -32,7 +32,7 @@ def sphere_scoring(x: jnp.ndarray) -> jnp.ndarray: ) state = cmaes.init() - random_key = jax.random.PRNGKey(0) + random_key = jax.random.key(0) iteration_count = 0 for _ in range(num_iterations): diff --git a/tests/core_test/emitters_test/multi_emitter_test.py b/tests/core_test/emitters_test/multi_emitter_test.py index ebf712d5..cee20261 100644 --- a/tests/core_test/emitters_test/multi_emitter_test.py +++ b/tests/core_test/emitters_test/multi_emitter_test.py @@ -27,7 +27,7 @@ def test_multi_emitter() -> None: max_bd = max_param # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init population of controllers random_key, subkey = jax.random.split(random_key) diff --git a/tests/core_test/map_elites_test.py b/tests/core_test/map_elites_test.py index c89ce04f..61c90f06 100644 --- a/tests/core_test/map_elites_test.py +++ b/tests/core_test/map_elites_test.py @@ -51,7 +51,7 @@ def test_map_elites(env_name: str, batch_size: int) -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) diff --git a/tests/core_test/mels_test.py b/tests/core_test/mels_test.py index 66bcc05f..383ab55a 100644 --- a/tests/core_test/mels_test.py +++ b/tests/core_test/mels_test.py @@ -40,7 +40,7 @@ def test_mels(env_name: str, batch_size: int) -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) diff --git a/tests/core_test/mome_test.py b/tests/core_test/mome_test.py index 746b94a0..4e0eb574 100644 --- a/tests/core_test/mome_test.py +++ b/tests/core_test/mome_test.py @@ -79,7 +79,7 @@ def scoring_fn( metrics_function = partial(default_moqd_metrics, reference_point=reference_point) # initial population - random_key = jax.random.PRNGKey(42) + random_key = jax.random.key(42) random_key, subkey = jax.random.split(random_key) genotypes = jax.random.uniform( subkey, diff --git a/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py b/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py index 06e25fcd..07726b94 100644 --- a/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py +++ b/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py @@ -91,7 +91,7 @@ def test_sample() -> None: simple_transition = simple_transition.replace(rewards=jnp.arange(3)) replay_buffer = replay_buffer.insert(simple_transition) - random_key = jax.random.PRNGKey(0) + random_key = jax.random.key(0) samples, random_key = replay_buffer.sample(random_key, 3) diff --git a/tests/default_tasks_test/arm_test.py b/tests/default_tasks_test/arm_test.py index 98361b23..31e0e0c1 100644 --- a/tests/default_tasks_test/arm_test.py +++ b/tests/default_tasks_test/arm_test.py @@ -41,7 +41,7 @@ def test_arm(task_name: str, batch_size: int) -> None: max_bd = 1.0 # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init population of controllers random_key, subkey = jax.random.split(random_key) @@ -114,7 +114,7 @@ def test_arm_scoring_function() -> None: # Init a random key seed = 42 - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # arm has xy BD centered at 0.5 0.5 and min max range is [0,1] # 0 params of first genotype is horizontal and points towards negative x axis diff --git a/tests/default_tasks_test/brax_task_test.py b/tests/default_tasks_test/brax_task_test.py index c12518fb..55768171 100644 --- a/tests/default_tasks_test/brax_task_test.py +++ b/tests/default_tasks_test/brax_task_test.py @@ -34,7 +34,7 @@ def test_map_elites(env_name: str, batch_size: int, is_task_reset_based: bool) - max_bd = 1.0 # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) env, policy_network, scoring_fn, random_key = create_default_brax_task_components( env_name=env_name, diff --git a/tests/default_tasks_test/hypervolume_functions_test.py b/tests/default_tasks_test/hypervolume_functions_test.py index 3d619353..152c245d 100644 --- a/tests/default_tasks_test/hypervolume_functions_test.py +++ b/tests/default_tasks_test/hypervolume_functions_test.py @@ -50,7 +50,7 @@ def test_standard_functions(task_name: str, batch_size: int) -> None: max_bd = 1.0 # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init population of controllers random_key, subkey = jax.random.split(random_key) diff --git a/tests/default_tasks_test/jumanji_envs_test.py b/tests/default_tasks_test/jumanji_envs_test.py index 636a02cf..dba574ac 100644 --- a/tests/default_tasks_test/jumanji_envs_test.py +++ b/tests/default_tasks_test/jumanji_envs_test.py @@ -30,7 +30,7 @@ def test_jumanji_utils() -> None: env = jumanji.make("Snake-v1") # Reset your (jit-able) environment - key = jax.random.PRNGKey(0) + key = jax.random.key(0) state, _timestep = jax.jit(env.reset)(key) # Interact with the (jit-able) environment @@ -38,7 +38,7 @@ def test_jumanji_utils() -> None: state, _timestep = jax.jit(env.step)(state, action) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # get number of actions num_actions = env.action_spec().maximum + 1 diff --git a/tests/default_tasks_test/qd_suite_test.py b/tests/default_tasks_test/qd_suite_test.py index 46f6ce9b..9424cc63 100644 --- a/tests/default_tasks_test/qd_suite_test.py +++ b/tests/default_tasks_test/qd_suite_test.py @@ -68,7 +68,7 @@ def test_qd_suite(task_name: str, batch_size: int) -> None: grid_shape = tuple([resolution_per_axis for _ in range(bd_size)]) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init population of parameters init_variables = task.get_initial_parameters(init_batch_size) diff --git a/tests/default_tasks_test/standard_functions_test.py b/tests/default_tasks_test/standard_functions_test.py index 87913364..b30cd7cc 100644 --- a/tests/default_tasks_test/standard_functions_test.py +++ b/tests/default_tasks_test/standard_functions_test.py @@ -40,7 +40,7 @@ def test_standard_functions(task_name: str, batch_size: int) -> None: max_bd = max_param # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init population of controllers random_key, subkey = jax.random.split(random_key) diff --git a/tests/environments_test/wrapper_test.py b/tests/environments_test/wrapper_test.py index f5e035ea..b29d89e1 100644 --- a/tests/environments_test/wrapper_test.py +++ b/tests/environments_test/wrapper_test.py @@ -110,7 +110,7 @@ def test_wrapper(env_name: str) -> None: print("Observation size: ", env.observation_size) print("Action size: ", env.action_size) - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) init_state = env.reset(random_key) joint_angle = jp.concatenate( diff --git a/tests/utils_test/plotting_test.py b/tests/utils_test/plotting_test.py index 17b4a8ea..807760ce 100644 --- a/tests/utils_test/plotting_test.py +++ b/tests/utils_test/plotting_test.py @@ -38,7 +38,7 @@ def test_onion_grid(num_descriptors: int, grid_shape: Tuple[int, ...]) -> None: minval = jnp.array([0] * num_descriptors) maxval = jnp.array([1] * num_descriptors) - random_key = jax.random.PRNGKey(seed=0) + random_key = jax.random.key(seed=0) random_key, key_desc, key_fit = jax.random.split(random_key, num=3) number_samples_test = 300 diff --git a/tests/utils_test/sampling_test.py b/tests/utils_test/sampling_test.py index 8d19379e..981a546d 100644 --- a/tests/utils_test/sampling_test.py +++ b/tests/utils_test/sampling_test.py @@ -34,7 +34,7 @@ def test_sampling() -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)