From b1d0685d96321ca488f77076ce56831a733bb169 Mon Sep 17 00:00:00 2001 From: Milton Montero Date: Tue, 21 May 2024 17:58:14 +0200 Subject: [PATCH 1/4] fix: Change to new-style-jax-rng-keys. All calls to jax.random.PRNGKey have been changed to jax.random.key. When passing keys to non-jax methods (such as the ones in scikit-learn), we use jax.key_data to recover the underlying raw key information. --- README.md | 2 +- examples/aurora.ipynb | 2 +- examples/cmaes.ipynb | 2 +- examples/cmame.ipynb | 2 +- examples/cmamega.ipynb | 2 +- examples/dads.ipynb | 4 ++-- examples/diayn.ipynb | 4 ++-- examples/distributed_mapelites.ipynb | 2 +- examples/jumanji_snake.ipynb | 4 ++-- examples/mapelites.ipynb | 4 ++-- examples/me_sac_pbt.ipynb | 4 ++-- examples/me_td3_pbt.ipynb | 2 +- examples/mees.ipynb | 2 +- examples/mels.ipynb | 4 ++-- examples/mome.ipynb | 2 +- examples/nsga2_spea2.ipynb | 2 +- examples/omgmega.ipynb | 2 +- examples/pga_aurora.ipynb | 2 +- examples/pgame.ipynb | 2 +- examples/qdpg.ipynb | 2 +- examples/sac_pbt.ipynb | 4 ++-- examples/scripts/me_example.py | 2 +- examples/smerl.ipynb | 4 ++-- examples/td3_pbt.ipynb | 2 +- qdax/core/containers/mapelites_repertoire.py | 2 +- qdax/core/neuroevolution/networks/seq2seq_networks.py | 4 ++-- qdax/tasks/README.md | 6 +++--- tests/baselines_test/cmame_test.py | 2 +- tests/baselines_test/cmamega_test.py | 2 +- tests/baselines_test/dads_smerl_test.py | 2 +- tests/baselines_test/dads_test.py | 2 +- tests/baselines_test/diayn_smerl_test.py | 2 +- tests/baselines_test/diayn_test.py | 2 +- tests/baselines_test/ga_test.py | 2 +- tests/baselines_test/me_pbt_sac_test.py | 2 +- tests/baselines_test/me_pbt_td3_test.py | 2 +- tests/baselines_test/mees_test.py | 2 +- tests/baselines_test/omgmega_test.py | 2 +- tests/baselines_test/pbt_sac_test.py | 2 +- tests/baselines_test/pbt_td3_test.py | 2 +- tests/baselines_test/pgame_test.py | 2 +- tests/baselines_test/qdpg_test.py | 2 +- tests/baselines_test/sac_test.py | 2 +- tests/baselines_test/td3_test.py | 2 +- tests/core_test/aurora_test.py | 2 +- tests/core_test/cmaes_test.py | 2 +- tests/core_test/emitters_test/multi_emitter_test.py | 2 +- tests/core_test/map_elites_test.py | 2 +- tests/core_test/mels_test.py | 2 +- tests/core_test/mome_test.py | 2 +- .../neuroevolution_test/buffers_test/buffer_test.py | 2 +- tests/default_tasks_test/arm_test.py | 4 ++-- tests/default_tasks_test/brax_task_test.py | 2 +- tests/default_tasks_test/hypervolume_functions_test.py | 2 +- tests/default_tasks_test/jumanji_envs_test.py | 4 ++-- tests/default_tasks_test/qd_suite_test.py | 2 +- tests/default_tasks_test/standard_functions_test.py | 2 +- tests/environments_test/wrapper_test.py | 2 +- tests/utils_test/plotting_test.py | 2 +- tests/utils_test/sampling_test.py | 2 +- 60 files changed, 73 insertions(+), 73 deletions(-) diff --git a/README.md b/README.md index 551680eb..236da191 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ min_bd = 0.0 max_bd = 1.0 # Init a random key -random_key = jax.random.PRNGKey(seed) +random_key = jax.random.key(seed) # Init population of controllers random_key, subkey = jax.random.split(random_key) diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb index e4b86238..1be7f0fc 100644 --- a/examples/aurora.ipynb +++ b/examples/aurora.ipynb @@ -159,7 +159,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", diff --git a/examples/cmaes.ipynb b/examples/cmaes.ipynb index c8e2a9fe..9936faa6 100644 --- a/examples/cmaes.ipynb +++ b/examples/cmaes.ipynb @@ -189,7 +189,7 @@ "outputs": [], "source": [ "state = cmaes.init()\n", - "random_key = jax.random.PRNGKey(0)" + "random_key = jax.random.key(0)" ] }, { diff --git a/examples/cmame.ipynb b/examples/cmame.ipynb index 3c355eea..9d362615 100644 --- a/examples/cmame.ipynb +++ b/examples/cmame.ipynb @@ -217,7 +217,7 @@ "metadata": {}, "outputs": [], "source": [ - "random_key = jax.random.PRNGKey(0)\n", + "random_key = jax.random.key(0)\n", "# in CMA-ME settings (from the paper), there is no init population\n", "# we multipy by zero to reproduce this setting\n", "initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n", diff --git a/examples/cmamega.ipynb b/examples/cmamega.ipynb index e5749993..c3c62e69 100644 --- a/examples/cmamega.ipynb +++ b/examples/cmamega.ipynb @@ -210,7 +210,7 @@ "metadata": {}, "outputs": [], "source": [ - "random_key = jax.random.PRNGKey(0)\n", + "random_key = jax.key(0)\n", "# no initial population - give all the same value as emitter init value\n", "initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n", "\n", diff --git a/examples/dads.ipynb b/examples/dads.ipynb index b3cc43b5..cebc42d3 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -182,7 +182,7 @@ " eval_metrics=True,\n", ")\n", "\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "env_state = jax.jit(env.reset)(rng=key)\n", "eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n", "\n", @@ -518,7 +518,7 @@ "outputs": [], "source": [ "rollout = []\n", - "random_key = jax.random.PRNGKey(seed=1)\n", + "random_key = jax.random.key(seed=1)\n", "state = jit_env_reset(rng=random_key)\n", "while not state.done:\n", " rollout.append(state)\n", diff --git a/examples/diayn.ipynb b/examples/diayn.ipynb index 0562e7c2..1eeae301 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -180,7 +180,7 @@ " eval_metrics=True,\n", ")\n", "\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "env_state = jax.jit(env.reset)(rng=key)\n", "eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n", "\n", @@ -508,7 +508,7 @@ "outputs": [], "source": [ "rollout = []\n", - "random_key = jax.random.PRNGKey(seed=1)\n", + "random_key = jax.random.key(seed=1)\n", "state = jit_env_reset(rng=random_key)\n", "while not state.done:\n", " rollout.append(state)\n", diff --git a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb index 2e6fd991..a1829cc5 100644 --- a/examples/distributed_mapelites.ipynb +++ b/examples/distributed_mapelites.ipynb @@ -185,7 +185,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", diff --git a/examples/jumanji_snake.ipynb b/examples/jumanji_snake.ipynb index a6a140fd..f8107c16 100644 --- a/examples/jumanji_snake.ipynb +++ b/examples/jumanji_snake.ipynb @@ -122,7 +122,7 @@ "env = jumanji.make('Snake-v1')\n", "\n", "# Reset your (jit-able) environment\n", - "key = jax.random.PRNGKey(0)\n", + "key = jax.random.key(0)\n", "state, timestep = jax.jit(env.reset)(key)\n", "\n", "# Interact with the (jit-able) environment\n", @@ -146,7 +146,7 @@ "outputs": [], "source": [ "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# get number of actions\n", "num_actions = env.action_spec().maximum + 1\n", diff --git a/examples/mapelites.ipynb b/examples/mapelites.ipynb index b1fea651..ee6683ae 100644 --- a/examples/mapelites.ipynb +++ b/examples/mapelites.ipynb @@ -146,7 +146,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.Key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -507,7 +507,7 @@ "outputs": [], "source": [ "rollout = []\n", - "rng = jax.random.PRNGKey(seed=1)\n", + "rng = jax.random.key(seed=1)\n", "state = jit_env_reset(rng=rng)\n", "while not state.done:\n", " rollout.append(state)\n", diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index c387643a..a4f34089 100644 --- a/examples/me_sac_pbt.ipynb +++ b/examples/me_sac_pbt.ipynb @@ -167,7 +167,7 @@ "outputs": [], "source": [ "%%time\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "key, subkey = jax.random.split(key)\n", "env_states = jax.jit(env.reset)(rng=subkey)\n", "eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey)" @@ -521,7 +521,7 @@ "%%time\n", "rollout = []\n", "\n", - "rng = jax.random.PRNGKey(seed=1)\n", + "rng = jax.random.key(seed=1)\n", "env_state = jax.jit(env.reset)(rng=rng)\n", "\n", "training_state, env_state = jax.tree_map(\n", diff --git a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index 924d550f..6ad0a00d 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -172,7 +172,7 @@ "outputs": [], "source": [ "%%time\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "key, subkey = jax.random.split(key)\n", "env_states = jax.jit(env.reset)(rng=subkey)\n", "eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey)" diff --git a/examples/mees.ipynb b/examples/mees.ipynb index ad1a4740..3d97d5fc 100644 --- a/examples/mees.ipynb +++ b/examples/mees.ipynb @@ -163,7 +163,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", diff --git a/examples/mels.ipynb b/examples/mels.ipynb index bd489ca2..8f3b4457 100644 --- a/examples/mels.ipynb +++ b/examples/mels.ipynb @@ -153,7 +153,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -522,7 +522,7 @@ "outputs": [], "source": [ "rollout = []\n", - "rng = jax.random.PRNGKey(seed=1)\n", + "rng = jax.random.key(seed=1)\n", "state = jit_env_reset(rng=rng)\n", "while not state.done:\n", " rollout.append(state)\n", diff --git a/examples/mome.ipynb b/examples/mome.ipynb index 555381e6..d1e17aa8 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -216,7 +216,7 @@ "outputs": [], "source": [ "# initial population\n", - "random_key = jax.random.PRNGKey(42)\n", + "random_key = jax.random.key(42)\n", "random_key, subkey = jax.random.split(random_key)\n", "init_genotypes = jax.random.uniform(\n", " random_key, (batch_size, num_variables), minval=minval, maxval=maxval, dtype=jnp.float32\n", diff --git a/examples/nsga2_spea2.ipynb b/examples/nsga2_spea2.ipynb index be662981..8836c66c 100644 --- a/examples/nsga2_spea2.ipynb +++ b/examples/nsga2_spea2.ipynb @@ -193,7 +193,7 @@ "outputs": [], "source": [ "# Initial population\n", - "random_key = jax.random.PRNGKey(0)\n", + "random_key = jax.random.key(0)\n", "random_key, subkey = jax.random.split(random_key)\n", "init_genotypes = jax.random.uniform(\n", " subkey, (batch_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32\n", diff --git a/examples/omgmega.ipynb b/examples/omgmega.ipynb index 8d417cc0..884157d3 100644 --- a/examples/omgmega.ipynb +++ b/examples/omgmega.ipynb @@ -196,7 +196,7 @@ "metadata": {}, "outputs": [], "source": [ - "random_key = jax.random.PRNGKey(0)\n", + "random_key = jax.random.key(0)\n", "\n", "# defines the population\n", "random_key, subkey = jax.random.split(random_key)\n", diff --git a/examples/pga_aurora.ipynb b/examples/pga_aurora.ipynb index 6152ce63..5658caa5 100644 --- a/examples/pga_aurora.ipynb +++ b/examples/pga_aurora.ipynb @@ -177,7 +177,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", diff --git a/examples/pgame.ipynb b/examples/pgame.ipynb index 9b638b2d..cfcc8d3c 100644 --- a/examples/pgame.ipynb +++ b/examples/pgame.ipynb @@ -156,7 +156,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", diff --git a/examples/qdpg.ipynb b/examples/qdpg.ipynb index 102d5262..4f1691df 100644 --- a/examples/qdpg.ipynb +++ b/examples/qdpg.ipynb @@ -169,7 +169,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", diff --git a/examples/sac_pbt.ipynb b/examples/sac_pbt.ipynb index 7762083f..2df8c51f 100644 --- a/examples/sac_pbt.ipynb +++ b/examples/sac_pbt.ipynb @@ -226,7 +226,7 @@ "outputs": [], "source": [ "# %%time\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "key, *keys = jax.random.split(key, num=1 + num_devices)\n", "keys = jnp.stack(keys)\n", "env_states, eval_env_first_states = jax.pmap(\n", @@ -534,7 +534,7 @@ "%%time\n", "rollout = []\n", "\n", - "rng = jax.random.PRNGKey(seed=1)\n", + "rng = jax.random.key(seed=1)\n", "env_state = jax.jit(env.reset)(rng=rng)\n", "\n", "training_state, env_state = jax.tree_map(\n", diff --git a/examples/scripts/me_example.py b/examples/scripts/me_example.py index 699c6aba..013c5ca9 100644 --- a/examples/scripts/me_example.py +++ b/examples/scripts/me_example.py @@ -26,7 +26,7 @@ def run_me() -> None: max_bd = 1.0 # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init population of controllers random_key, subkey = jax.random.split(random_key) diff --git a/examples/smerl.ipynb b/examples/smerl.ipynb index fe655fe2..fc7c77fd 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -187,7 +187,7 @@ " eval_metrics=True,\n", ")\n", "\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.Key(seed)\n", "env_state = jax.jit(env.reset)(rng=key)\n", "eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n", "\n", @@ -523,7 +523,7 @@ "outputs": [], "source": [ "rollout = []\n", - "random_key = jax.random.PRNGKey(seed=1)\n", + "random_key = jax.random.key(seed=1)\n", "state = jit_env_reset(rng=random_key)\n", "while not state.done:\n", " rollout.append(state)\n", diff --git a/examples/td3_pbt.ipynb b/examples/td3_pbt.ipynb index ec98b9da..79cba034 100644 --- a/examples/td3_pbt.ipynb +++ b/examples/td3_pbt.ipynb @@ -197,7 +197,7 @@ "outputs": [], "source": [ "%%time\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "key, *keys = jax.random.split(key, num=1 + num_devices)\n", "keys = jnp.stack(keys)\n", "env_states, eval_env_first_states = jax.pmap(\n", diff --git a/qdax/core/containers/mapelites_repertoire.py b/qdax/core/containers/mapelites_repertoire.py index aed74c78..284835f6 100644 --- a/qdax/core/containers/mapelites_repertoire.py +++ b/qdax/core/containers/mapelites_repertoire.py @@ -56,7 +56,7 @@ def compute_cvt_centroids( init="k-means++", n_clusters=num_centroids, n_init=1, - random_state=RandomState(subkey), + random_state=RandomState(jax.random.key_data(subkey)), ) k_means.fit(x) centroids = k_means.cluster_centers_ diff --git a/qdax/core/neuroevolution/networks/seq2seq_networks.py b/qdax/core/neuroevolution/networks/seq2seq_networks.py index ea7618ba..b4ca103c 100644 --- a/qdax/core/neuroevolution/networks/seq2seq_networks.py +++ b/qdax/core/neuroevolution/networks/seq2seq_networks.py @@ -17,7 +17,7 @@ from flax import linen as nn Array = Any -PRNGKey = Any +PRNGKey = jax.Array class EncoderLSTM(nn.Module): @@ -53,7 +53,7 @@ def select_carried_state(new_state: Array, old_state: Array) -> Array: def initialize_carry(batch_size: int, hidden_size: int) -> Tuple[Array, Array]: # Use a dummy key since the default state init fn is just zeros. return nn.LSTMCell(hidden_size, parent=None).initialize_carry( # type: ignore - jax.random.PRNGKey(0), (batch_size, hidden_size) + jax.random.key(0), (batch_size, hidden_size) ) diff --git a/qdax/tasks/README.md b/qdax/tasks/README.md index 56528323..bb3a09d8 100644 --- a/qdax/tasks/README.md +++ b/qdax/tasks/README.md @@ -19,7 +19,7 @@ Notes: import jax from qdax.tasks.arm import arm_scoring_function -random_key = jax.random.PRNGKey(0) +random_key = jax.random.key(0) # Get scoring function scoring_fn = arm_scoring_function @@ -56,7 +56,7 @@ desc_size = 2 import jax from qdax.tasks.standard_functions import sphere_scoring_function -random_key = jax.random.PRNGKey(0) +random_key = jax.random.key(0) # Get scoring function scoring_fn = sphere_scoring_function @@ -98,7 +98,7 @@ desc_size = 2 import jax from qdax.tasks.hypervolume_functions import square_scoring_function -random_key = jax.random.PRNGKey(0) +random_key = jax.random.key(0) # Get scoring function scoring_fn = square_scoring_function diff --git a/tests/baselines_test/cmame_test.py b/tests/baselines_test/cmame_test.py index c86bd622..035cfa87 100644 --- a/tests/baselines_test/cmame_test.py +++ b/tests/baselines_test/cmame_test.py @@ -81,7 +81,7 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: max_fitness = jnp.max(adjusted_fitness) return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} - random_key = jax.random.PRNGKey(0) + random_key = jax.random.key(0) initial_population = ( jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.0 ) diff --git a/tests/baselines_test/cmamega_test.py b/tests/baselines_test/cmamega_test.py index fdd9330b..a6a9e25c 100644 --- a/tests/baselines_test/cmamega_test.py +++ b/tests/baselines_test/cmamega_test.py @@ -95,7 +95,7 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: max_fitness = jnp.max(adjusted_fitness) return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} - random_key = jax.random.PRNGKey(0) + random_key = jax.random.key(0) initial_population = jax.random.uniform( random_key, shape=(batch_size, num_dimensions) ) diff --git a/tests/baselines_test/dads_smerl_test.py b/tests/baselines_test/dads_smerl_test.py index 1e782f2a..6c57fea8 100644 --- a/tests/baselines_test/dads_smerl_test.py +++ b/tests/baselines_test/dads_smerl_test.py @@ -71,7 +71,7 @@ def test_dads_smerl() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/dads_test.py b/tests/baselines_test/dads_test.py index 0b9af46e..09f994d0 100644 --- a/tests/baselines_test/dads_test.py +++ b/tests/baselines_test/dads_test.py @@ -65,7 +65,7 @@ def test_dads() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/diayn_smerl_test.py b/tests/baselines_test/diayn_smerl_test.py index abd94b45..f06a4298 100644 --- a/tests/baselines_test/diayn_smerl_test.py +++ b/tests/baselines_test/diayn_smerl_test.py @@ -69,7 +69,7 @@ def test_diayn_smerl() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/diayn_test.py b/tests/baselines_test/diayn_test.py index 3492d9c1..856e0174 100644 --- a/tests/baselines_test/diayn_test.py +++ b/tests/baselines_test/diayn_test.py @@ -62,7 +62,7 @@ def test_diayn() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/ga_test.py b/tests/baselines_test/ga_test.py index 4e11370b..a84643f7 100644 --- a/tests/baselines_test/ga_test.py +++ b/tests/baselines_test/ga_test.py @@ -71,7 +71,7 @@ def scoring_fn( return fitnesses, {}, random_key # initial population - random_key = jax.random.PRNGKey(42) + random_key = jax.random.key(42) random_key, subkey = jax.random.split(random_key) init_genotypes = jax.random.uniform( subkey, diff --git a/tests/baselines_test/me_pbt_sac_test.py b/tests/baselines_test/me_pbt_sac_test.py index 079fde45..9724270b 100644 --- a/tests/baselines_test/me_pbt_sac_test.py +++ b/tests/baselines_test/me_pbt_sac_test.py @@ -69,7 +69,7 @@ def test_me_pbt_sac() -> None: ) min_bd, max_bd = env.behavior_descriptor_limits - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, subkey = jax.random.split(key) eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey) diff --git a/tests/baselines_test/me_pbt_td3_test.py b/tests/baselines_test/me_pbt_td3_test.py index 5c6fbb0a..ce79f44a 100644 --- a/tests/baselines_test/me_pbt_td3_test.py +++ b/tests/baselines_test/me_pbt_td3_test.py @@ -69,7 +69,7 @@ def test_me_pbt_td3() -> None: ) min_bd, max_bd = env.behavior_descriptor_limits - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, subkey = jax.random.split(key) eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey) diff --git a/tests/baselines_test/mees_test.py b/tests/baselines_test/mees_test.py index 3f3314fd..a56a247e 100644 --- a/tests/baselines_test/mees_test.py +++ b/tests/baselines_test/mees_test.py @@ -47,7 +47,7 @@ def test_mees() -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) diff --git a/tests/baselines_test/omgmega_test.py b/tests/baselines_test/omgmega_test.py index 7b0f0639..b341bab6 100644 --- a/tests/baselines_test/omgmega_test.py +++ b/tests/baselines_test/omgmega_test.py @@ -82,7 +82,7 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: max_fitness = jnp.max(adjusted_fitness) return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} - random_key = jax.random.PRNGKey(0) + random_key = jax.random.key(0) # defines the population random_key, subkey = jax.random.split(random_key) diff --git a/tests/baselines_test/pbt_sac_test.py b/tests/baselines_test/pbt_sac_test.py index c83f277c..5d1fcf47 100644 --- a/tests/baselines_test/pbt_sac_test.py +++ b/tests/baselines_test/pbt_sac_test.py @@ -76,7 +76,7 @@ def init_environments(random_key): # type: ignore return env_states, eval_env_first_states - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, *keys = jax.random.split(key, num=1 + num_devices) keys = jnp.stack(keys) env_states, eval_env_first_states = jax.pmap( diff --git a/tests/baselines_test/pbt_td3_test.py b/tests/baselines_test/pbt_td3_test.py index 9e6134c9..8680900d 100644 --- a/tests/baselines_test/pbt_td3_test.py +++ b/tests/baselines_test/pbt_td3_test.py @@ -74,7 +74,7 @@ def init_environments(random_key): # type: ignore return env_states, eval_env_first_states - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, *keys = jax.random.split(key, num=1 + num_devices) keys = jnp.stack(keys) env_states, eval_env_first_states = jax.pmap( diff --git a/tests/baselines_test/pgame_test.py b/tests/baselines_test/pgame_test.py index 9cb1b3fb..463c6236 100644 --- a/tests/baselines_test/pgame_test.py +++ b/tests/baselines_test/pgame_test.py @@ -54,7 +54,7 @@ def test_pgame() -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) diff --git a/tests/baselines_test/qdpg_test.py b/tests/baselines_test/qdpg_test.py index 1889f197..9b967f50 100644 --- a/tests/baselines_test/qdpg_test.py +++ b/tests/baselines_test/qdpg_test.py @@ -69,7 +69,7 @@ def test_qdpg() -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) diff --git a/tests/baselines_test/sac_test.py b/tests/baselines_test/sac_test.py index c667aa66..6f70a274 100644 --- a/tests/baselines_test/sac_test.py +++ b/tests/baselines_test/sac_test.py @@ -53,7 +53,7 @@ def test_sac() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/td3_test.py b/tests/baselines_test/td3_test.py index ff5d338c..55c09811 100644 --- a/tests/baselines_test/td3_test.py +++ b/tests/baselines_test/td3_test.py @@ -49,7 +49,7 @@ def test_td3() -> None: auto_reset=True, eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, subkey = jax.random.split(key) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index 2b238237..67c18d6d 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -75,7 +75,7 @@ def test_aurora(env_name: str, batch_size: int) -> None: log_freq = 5 # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init environment env, policy_network, scoring_fn, random_key = create_default_brax_task_components( diff --git a/tests/core_test/cmaes_test.py b/tests/core_test/cmaes_test.py index 16321fd4..d334fcad 100644 --- a/tests/core_test/cmaes_test.py +++ b/tests/core_test/cmaes_test.py @@ -32,7 +32,7 @@ def sphere_scoring(x: jnp.ndarray) -> jnp.ndarray: ) state = cmaes.init() - random_key = jax.random.PRNGKey(0) + random_key = jax.random.key(0) iteration_count = 0 for _ in range(num_iterations): diff --git a/tests/core_test/emitters_test/multi_emitter_test.py b/tests/core_test/emitters_test/multi_emitter_test.py index 93b3e081..c8c4f84b 100644 --- a/tests/core_test/emitters_test/multi_emitter_test.py +++ b/tests/core_test/emitters_test/multi_emitter_test.py @@ -27,7 +27,7 @@ def test_multi_emitter() -> None: max_bd = max_param # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init population of controllers random_key, subkey = jax.random.split(random_key) diff --git a/tests/core_test/map_elites_test.py b/tests/core_test/map_elites_test.py index b532aa65..a6a27370 100644 --- a/tests/core_test/map_elites_test.py +++ b/tests/core_test/map_elites_test.py @@ -51,7 +51,7 @@ def test_map_elites(env_name: str, batch_size: int) -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) diff --git a/tests/core_test/mels_test.py b/tests/core_test/mels_test.py index 21f90517..596d715c 100644 --- a/tests/core_test/mels_test.py +++ b/tests/core_test/mels_test.py @@ -40,7 +40,7 @@ def test_mels(env_name: str, batch_size: int) -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) diff --git a/tests/core_test/mome_test.py b/tests/core_test/mome_test.py index c70683ef..1a2c0b2f 100644 --- a/tests/core_test/mome_test.py +++ b/tests/core_test/mome_test.py @@ -79,7 +79,7 @@ def scoring_fn( metrics_function = partial(default_moqd_metrics, reference_point=reference_point) # initial population - random_key = jax.random.PRNGKey(42) + random_key = jax.random.key(42) random_key, subkey = jax.random.split(random_key) init_genotypes = jax.random.uniform( subkey, diff --git a/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py b/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py index e0e298c1..e14b4dfb 100644 --- a/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py +++ b/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py @@ -87,7 +87,7 @@ def test_sample() -> None: simple_transition = simple_transition.replace(rewards=jnp.arange(3)) replay_buffer = replay_buffer.insert(simple_transition) - random_key = jax.random.PRNGKey(0) + random_key = jax.random.key(0) samples, random_key = replay_buffer.sample(random_key, 3) diff --git a/tests/default_tasks_test/arm_test.py b/tests/default_tasks_test/arm_test.py index e71e761c..803d9aa0 100644 --- a/tests/default_tasks_test/arm_test.py +++ b/tests/default_tasks_test/arm_test.py @@ -41,7 +41,7 @@ def test_arm(task_name: str, batch_size: int) -> None: max_bd = 1.0 # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init population of controllers random_key, subkey = jax.random.split(random_key) @@ -110,7 +110,7 @@ def test_arm_scoring_function() -> None: # Init a random key seed = 42 - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # arm has xy BD centered at 0.5 0.5 and min max range is [0,1] # 0 params of first genotype is horizontal and points towards negative x axis diff --git a/tests/default_tasks_test/brax_task_test.py b/tests/default_tasks_test/brax_task_test.py index f8c63259..a3dd2bfc 100644 --- a/tests/default_tasks_test/brax_task_test.py +++ b/tests/default_tasks_test/brax_task_test.py @@ -34,7 +34,7 @@ def test_map_elites(env_name: str, batch_size: int, is_task_reset_based: bool) - max_bd = 1.0 # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) env, policy_network, scoring_fn, random_key = create_default_brax_task_components( env_name=env_name, diff --git a/tests/default_tasks_test/hypervolume_functions_test.py b/tests/default_tasks_test/hypervolume_functions_test.py index a390f709..0d709139 100644 --- a/tests/default_tasks_test/hypervolume_functions_test.py +++ b/tests/default_tasks_test/hypervolume_functions_test.py @@ -50,7 +50,7 @@ def test_standard_functions(task_name: str, batch_size: int) -> None: max_bd = 1.0 # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init population of controllers random_key, subkey = jax.random.split(random_key) diff --git a/tests/default_tasks_test/jumanji_envs_test.py b/tests/default_tasks_test/jumanji_envs_test.py index eed90127..0fbb42df 100644 --- a/tests/default_tasks_test/jumanji_envs_test.py +++ b/tests/default_tasks_test/jumanji_envs_test.py @@ -30,7 +30,7 @@ def test_jumanji_utils() -> None: env = jumanji.make("Snake-v1") # Reset your (jit-able) environment - key = jax.random.PRNGKey(0) + key = jax.random.key(0) state, _timestep = jax.jit(env.reset)(key) # Interact with the (jit-able) environment @@ -38,7 +38,7 @@ def test_jumanji_utils() -> None: state, _timestep = jax.jit(env.step)(state, action) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # get number of actions num_actions = env.action_spec().maximum + 1 diff --git a/tests/default_tasks_test/qd_suite_test.py b/tests/default_tasks_test/qd_suite_test.py index a0542e9b..78a4cb39 100644 --- a/tests/default_tasks_test/qd_suite_test.py +++ b/tests/default_tasks_test/qd_suite_test.py @@ -68,7 +68,7 @@ def test_qd_suite(task_name: str, batch_size: int) -> None: grid_shape = tuple([resolution_per_axis for _ in range(bd_size)]) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init population of parameters init_variables = task.get_initial_parameters(init_batch_size) diff --git a/tests/default_tasks_test/standard_functions_test.py b/tests/default_tasks_test/standard_functions_test.py index 7b310389..404e1e24 100644 --- a/tests/default_tasks_test/standard_functions_test.py +++ b/tests/default_tasks_test/standard_functions_test.py @@ -40,7 +40,7 @@ def test_standard_functions(task_name: str, batch_size: int) -> None: max_bd = max_param # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init population of controllers random_key, subkey = jax.random.split(random_key) diff --git a/tests/environments_test/wrapper_test.py b/tests/environments_test/wrapper_test.py index f5e035ea..b29d89e1 100644 --- a/tests/environments_test/wrapper_test.py +++ b/tests/environments_test/wrapper_test.py @@ -110,7 +110,7 @@ def test_wrapper(env_name: str) -> None: print("Observation size: ", env.observation_size) print("Action size: ", env.action_size) - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) init_state = env.reset(random_key) joint_angle = jp.concatenate( diff --git a/tests/utils_test/plotting_test.py b/tests/utils_test/plotting_test.py index 17b4a8ea..807760ce 100644 --- a/tests/utils_test/plotting_test.py +++ b/tests/utils_test/plotting_test.py @@ -38,7 +38,7 @@ def test_onion_grid(num_descriptors: int, grid_shape: Tuple[int, ...]) -> None: minval = jnp.array([0] * num_descriptors) maxval = jnp.array([1] * num_descriptors) - random_key = jax.random.PRNGKey(seed=0) + random_key = jax.random.key(seed=0) random_key, key_desc, key_fit = jax.random.split(random_key, num=3) number_samples_test = 300 diff --git a/tests/utils_test/sampling_test.py b/tests/utils_test/sampling_test.py index 6ce6cbe9..d55fe373 100644 --- a/tests/utils_test/sampling_test.py +++ b/tests/utils_test/sampling_test.py @@ -34,7 +34,7 @@ def test_sampling() -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) From dc31b0c9a967ddf92caedd91d93d0174bd084ef5 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Fri, 13 Sep 2024 23:03:41 +0100 Subject: [PATCH 2/4] Change remaining jax.random.PRNGKey to jax.random.key --- examples/dcrlme.ipynb | 2 +- .../networks/seq2seq_networks.py | 29 ++++++++++--------- qdax/utils/train_seq2seq.py | 9 ++---- tests/baselines_test/dcrlme_test.py | 2 +- tests/utils_test/uncertainty_metrics_test.py | 2 +- 5 files changed, 22 insertions(+), 22 deletions(-) diff --git a/examples/dcrlme.ipynb b/examples/dcrlme.ipynb index eae0e6b3..057ef0c4 100644 --- a/examples/dcrlme.ipynb +++ b/examples/dcrlme.ipynb @@ -154,7 +154,7 @@ "source": [ "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init environment\n", "env = environments.create(env_name, episode_length=episode_length)\n", diff --git a/qdax/core/neuroevolution/networks/seq2seq_networks.py b/qdax/core/neuroevolution/networks/seq2seq_networks.py index 476705fa..a4bb2272 100644 --- a/qdax/core/neuroevolution/networks/seq2seq_networks.py +++ b/qdax/core/neuroevolution/networks/seq2seq_networks.py @@ -15,9 +15,6 @@ import numpy as np from flax import linen as nn -Array = Any -PRNGKey = jax.Array - class EncoderLSTM(nn.Module): """EncoderLSTM Module wrapped in a lifted scan transform.""" @@ -31,14 +28,16 @@ class EncoderLSTM(nn.Module): ) @nn.compact def __call__( - self, carry: Tuple[Array, Array], x: Array - ) -> Tuple[Tuple[Array, Array], Array]: + self, carry: Tuple[jax.Array, jax.Array], x: jax.Array + ) -> Tuple[Tuple[jax.Array, jax.Array], jax.Array]: """Applies the module.""" lstm_state, is_eos = carry features = lstm_state[0].shape[-1] new_lstm_state, y = nn.LSTMCell(features)(lstm_state, x) - def select_carried_state(new_state: Array, old_state: Array) -> Array: + def select_carried_state( + new_state: jax.Array, old_state: jax.Array + ) -> jax.Array: return jnp.where(is_eos[:, np.newaxis], old_state, new_state) # LSTM state is a tuple (c, h). @@ -49,7 +48,9 @@ def select_carried_state(new_state: Array, old_state: Array) -> Array: return (carried_lstm_state, is_eos), y @staticmethod - def initialize_carry(batch_size: int, hidden_size: int) -> Tuple[Array, Array]: + def initialize_carry( + batch_size: int, hidden_size: int + ) -> Tuple[jax.Array, jax.Array]: # Use a dummy key since the default state init fn is just zeros. return nn.LSTMCell(hidden_size, parent=None).initialize_carry( # type: ignore jax.random.key(0), (batch_size, hidden_size) @@ -62,7 +63,7 @@ class Encoder(nn.Module): hidden_size: int @nn.compact - def __call__(self, inputs: Array) -> Array: + def __call__(self, inputs: jax.Array) -> jax.Array: batch_size = inputs.shape[0] lstm = EncoderLSTM(name="encoder_lstm") init_lstm_state = lstm.initialize_carry(batch_size, self.hidden_size) @@ -95,7 +96,7 @@ class DecoderLSTM(nn.Module): split_rngs={"params": False, "lstm": True}, ) @nn.compact - def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array: + def __call__(self, carry: Tuple[jax.Array, jax.Array], x: jax.Array) -> jax.Array: """Applies the DecoderLSTM model.""" lstm_state, last_prediction = carry @@ -124,7 +125,9 @@ class Decoder(nn.Module): obs_size: int @nn.compact - def __call__(self, inputs: Array, init_state: Any) -> Tuple[Array, Array]: + def __call__( + self, inputs: jax.Array, init_state: Any + ) -> Tuple[jax.Array, jax.Array]: """Applies the decoder model. Args: @@ -166,8 +169,8 @@ def setup(self) -> None: @nn.compact def __call__( - self, encoder_inputs: Array, decoder_inputs: Array - ) -> Tuple[Array, Array]: + self, encoder_inputs: jax.Array, decoder_inputs: jax.Array + ) -> Tuple[jax.Array, jax.Array]: """Applies the seq2seq model. Args: @@ -194,7 +197,7 @@ def __call__( return logits, predictions - def encode(self, encoder_inputs: Array) -> Array: + def encode(self, encoder_inputs: jax.Array) -> jax.Array: # encode inputs init_decoder_state = self.encoder(encoder_inputs) final_output, _hidden_state = init_decoder_state diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py index bd9570a9..fa7825b0 100644 --- a/qdax/utils/train_seq2seq.py +++ b/qdax/utils/train_seq2seq.py @@ -19,9 +19,6 @@ from qdax.custom_types import Params, RNGKey from qdax.environments.bd_extractors import AuroraExtraInfoNormalization -Array = Any -PRNGKey = Any - def get_model( obs_size: int, teacher_force: bool = False, hidden_size: int = 10 @@ -40,7 +37,7 @@ def get_model( def get_initial_params( - model: Seq2seq, random_key: PRNGKey, encoder_input_shape: Tuple[int, ...] + model: Seq2seq, random_key: RNGKey, encoder_input_shape: Tuple[int, ...] ) -> Dict[str, Any]: """ Returns the initial parameters of a seq2seq model. @@ -62,8 +59,8 @@ def get_initial_params( @jax.jit def train_step( state: train_state.TrainState, - batch: Array, - lstm_random_key: PRNGKey, + batch: jax.Array, + lstm_random_key: RNGKey, ) -> Tuple[train_state.TrainState, Dict[str, float]]: """ Trains for one step. diff --git a/tests/baselines_test/dcrlme_test.py b/tests/baselines_test/dcrlme_test.py index 05304944..1bc9688d 100644 --- a/tests/baselines_test/dcrlme_test.py +++ b/tests/baselines_test/dcrlme_test.py @@ -61,7 +61,7 @@ def test_dcrlme() -> None: policy_delay = 2 # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init environment env = environments.create(env_name, episode_length=episode_length) diff --git a/tests/utils_test/uncertainty_metrics_test.py b/tests/utils_test/uncertainty_metrics_test.py index d49e2527..3f2caea1 100644 --- a/tests/utils_test/uncertainty_metrics_test.py +++ b/tests/utils_test/uncertainty_metrics_test.py @@ -25,7 +25,7 @@ def test_uncertainty_metrics() -> None: genotype_dim = 8 # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # First, init a deterministic environment init_policies = jax.random.uniform( From 2040636e97b7e0b9283510b75e1d8695a212b6e4 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Fri, 20 Sep 2024 14:04:41 +0000 Subject: [PATCH 3/4] fixing bugs from new data structure of RNG keys --- qdax/utils/sampling.py | 4 ++-- tests/baselines_test/me_pbt_sac_test.py | 3 +++ tests/baselines_test/me_pbt_td3_test.py | 3 +++ tests/baselines_test/pbt_sac_test.py | 3 +++ tests/baselines_test/pbt_td3_test.py | 3 +++ 5 files changed, 14 insertions(+), 2 deletions(-) diff --git a/qdax/utils/sampling.py b/qdax/utils/sampling.py index be1d336d..94d4e160 100644 --- a/qdax/utils/sampling.py +++ b/qdax/utils/sampling.py @@ -159,8 +159,8 @@ def multi_sample_scoring_function( # vectorizing over axis 0 vectorizes over the num_samples random keys in_axes=(None, 0), # indicates that the vectorized axis will become axis 1, i.e., the final - # output is shape (batch_size, num_samples, ...) - out_axes=1, + # output is shape (batch_size, num_samples, ...) except for the random key + out_axes=(1, 1, 1, 0), ) all_fitnesses, all_descriptors, all_extra_scores, _ = sample_scoring_fn( policies_params, keys diff --git a/tests/baselines_test/me_pbt_sac_test.py b/tests/baselines_test/me_pbt_sac_test.py index b2365382..c8dcc5af 100644 --- a/tests/baselines_test/me_pbt_sac_test.py +++ b/tests/baselines_test/me_pbt_sac_test.py @@ -163,6 +163,9 @@ def scoring_function(genotypes, random_key): # type: ignore observation_size=env.observation_size, buffer_size=buffer_size, ) + + # Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647 + keys = jax.random.key_data(keys) keys, training_states, _ = jax.pmap(agent_init_fn, axis_name="p", devices=devices)( keys ) diff --git a/tests/baselines_test/me_pbt_td3_test.py b/tests/baselines_test/me_pbt_td3_test.py index 53e2802e..f243725e 100644 --- a/tests/baselines_test/me_pbt_td3_test.py +++ b/tests/baselines_test/me_pbt_td3_test.py @@ -161,6 +161,9 @@ def scoring_function(genotypes, random_key): # type: ignore observation_size=env.observation_size, buffer_size=buffer_size, ) + + # Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647 + keys = jax.random.key_data(keys) keys, training_states, _ = jax.pmap(agent_init_fn, axis_name="p", devices=devices)( keys ) diff --git a/tests/baselines_test/pbt_sac_test.py b/tests/baselines_test/pbt_sac_test.py index 537d4a26..9c4b2c83 100644 --- a/tests/baselines_test/pbt_sac_test.py +++ b/tests/baselines_test/pbt_sac_test.py @@ -104,6 +104,9 @@ def init_environments(random_key): # type: ignore observation_size=env.observation_size, buffer_size=buffer_size, ) + + # Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647 + keys = jax.random.key_data(keys) keys, training_states, replay_buffers = jax.pmap( agent_init_fn, axis_name="p", devices=devices )(keys) diff --git a/tests/baselines_test/pbt_td3_test.py b/tests/baselines_test/pbt_td3_test.py index 813deb60..e45a9701 100644 --- a/tests/baselines_test/pbt_td3_test.py +++ b/tests/baselines_test/pbt_td3_test.py @@ -100,6 +100,9 @@ def init_environments(random_key): # type: ignore observation_size=env.observation_size, buffer_size=buffer_size, ) + + # Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647 + keys = jax.random.key_data(keys) keys, training_states, replay_buffers = jax.pmap( agent_init_fn, axis_name="p", devices=devices )(keys) From 5cac247b7b8b8ce52e22cf06e585e9c3e86eb485 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Fri, 20 Sep 2024 14:19:16 +0000 Subject: [PATCH 4/4] fix key management in examples using pmap --- examples/me_sac_pbt.ipynb | 4 ++++ examples/me_td3_pbt.ipynb | 4 ++++ examples/sac_pbt.ipynb | 4 ++++ examples/td3_pbt.ipynb | 4 ++++ 4 files changed, 16 insertions(+) diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index 068acab3..42c46188 100644 --- a/examples/me_sac_pbt.ipynb +++ b/examples/me_sac_pbt.ipynb @@ -311,6 +311,10 @@ " observation_size=env.observation_size,\n", " buffer_size=buffer_size,\n", ")\n", + "\n", + "# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n", + "keys = jax.random.key_data(keys)\n", + "\n", "keys, training_states, _ = jax.pmap(agent_init_fn, axis_name=\"p\", devices=devices)(keys)" ] }, diff --git a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index 71289a96..8caca62f 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -314,6 +314,10 @@ " observation_size=env.observation_size,\n", " buffer_size=buffer_size,\n", ")\n", + "\n", + "# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n", + "keys = jax.random.key_data(keys)\n", + "\n", "keys, training_states, _ = jax.pmap(agent_init_fn, axis_name=\"p\", devices=devices)(keys)" ] }, diff --git a/examples/sac_pbt.ipynb b/examples/sac_pbt.ipynb index b1ab220b..53b526db 100644 --- a/examples/sac_pbt.ipynb +++ b/examples/sac_pbt.ipynb @@ -269,6 +269,10 @@ " observation_size=env.observation_size,\n", " buffer_size=buffer_size,\n", ")\n", + "\n", + "# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n", + "keys = jax.random.key_data(keys)\n", + "\n", "keys, training_states, replay_buffers = jax.pmap(\n", " agent_init_fn, axis_name=\"p\", devices=devices\n", ")(keys)" diff --git a/examples/td3_pbt.ipynb b/examples/td3_pbt.ipynb index 5de9e24b..3bbf237e 100644 --- a/examples/td3_pbt.ipynb +++ b/examples/td3_pbt.ipynb @@ -232,6 +232,10 @@ " observation_size=env.observation_size,\n", " buffer_size=buffer_size,\n", ")\n", + "\n", + "# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n", + "keys = jax.random.key_data(keys)\n", + "\n", "keys, training_states, replay_buffers = jax.pmap(\n", " agent_init_fn, axis_name=\"p\", devices=devices\n", ")(keys)"