From 14f30f34e59e37a6d03b7e7763e6ba271c2b019c Mon Sep 17 00:00:00 2001 From: LisaCoiffard <91796648+LisaCoiffard@users.noreply.github.com> Date: Fri, 20 Sep 2024 18:12:40 +0100 Subject: [PATCH 1/4] fix: Change indexing in UnstructuredRepertoire intra batch comp (#185) * change indexing to retrieve current fitness in intra batch comp Authored-by: Lisa --- qdax/core/containers/unstructured_repertoire.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/qdax/core/containers/unstructured_repertoire.py b/qdax/core/containers/unstructured_repertoire.py index 8512d3d6..4a1c0cdb 100644 --- a/qdax/core/containers/unstructured_repertoire.py +++ b/qdax/core/containers/unstructured_repertoire.py @@ -110,9 +110,7 @@ def intra_batch_comp( # We want to eliminate the same individual (distance 0) fitness = jnp.where(knn_relevant_indices == current_index, False, fitness) - current_fitness = jnp.squeeze( - eval_scores.at[knn_relevant_indices.at[0].get()].get() - ) + current_fitness = jnp.squeeze(eval_scores.at[current_index].get()) # Is the fitness of the other individual higher? 
# If both are True then we discard the current individual since this individual From b2ef13ef69044bb1371ee3eebb299c98bf5a6341 Mon Sep 17 00:00:00 2001 From: LisaCoiffard <91796648+LisaCoiffard@users.noreply.github.com> Date: Fri, 20 Sep 2024 18:24:34 +0100 Subject: [PATCH 2/4] chore: Move `cmaes.py` from `core` to `baselines` (#200) * Move cmaes.py from core to baselines * Change corresponding api docs --------- Authored-by: Lisa --- docs/api_documentation/core/cmaes.md | 2 +- examples/cmaes.ipynb | 2 +- mkdocs.yml | 2 +- qdax/{core => baselines}/cmaes.py | 0 qdax/core/emitters/cma_emitter.py | 2 +- qdax/core/emitters/cma_mega_emitter.py | 2 +- qdax/core/emitters/cma_rnd_emitter.py | 2 +- tests/core_test/cmaes_test.py | 2 +- 8 files changed, 7 insertions(+), 7 deletions(-) rename qdax/{core => baselines}/cmaes.py (100%) diff --git a/docs/api_documentation/core/cmaes.md b/docs/api_documentation/core/cmaes.md index 257bb89b..5c85704a 100644 --- a/docs/api_documentation/core/cmaes.md +++ b/docs/api_documentation/core/cmaes.md @@ -1,3 +1,3 @@ # CMAES class -::: qdax.core.cmaes.CMAES +::: qdax.baselines.cmaes.CMAES diff --git a/examples/cmaes.ipynb b/examples/cmaes.ipynb index ba059cdc..023b2e27 100644 --- a/examples/cmaes.ipynb +++ b/examples/cmaes.ipynb @@ -55,7 +55,7 @@ "import matplotlib.pyplot as plt\n", "from matplotlib.patches import Ellipse\n", "\n", - "from qdax.core.cmaes import CMAES" + "from qdax.baselines.cmaes import CMAES" ] }, { diff --git a/mkdocs.yml b/mkdocs.yml index 9207b4f2..b447e83f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -154,11 +154,11 @@ nav: - DADS: api_documentation/core/dads.md - SAC: api_documentation/core/sac.md - TD3: api_documentation/core/td3.md - - CMAES: api_documentation/core/cmaes.md - Genetic Algorithm: api_documentation/core/genetic_algorithm.md - NSGA2: api_documentation/core/nsga2.md - SPEA2: api_documentation/core/spea2.md - PBT: api_documentation/core/pbt.md + - CMAES: api_documentation/core/cmaes.md - Containers: 
api_documentation/core/containers.md - Emitters: api_documentation/core/emitters.md - Neuroevolution: api_documentation/core/neuroevolution.md diff --git a/qdax/core/cmaes.py b/qdax/baselines/cmaes.py similarity index 100% rename from qdax/core/cmaes.py rename to qdax/baselines/cmaes.py diff --git a/qdax/core/emitters/cma_emitter.py b/qdax/core/emitters/cma_emitter.py index 315dcd9b..e3b476dd 100644 --- a/qdax/core/emitters/cma_emitter.py +++ b/qdax/core/emitters/cma_emitter.py @@ -7,7 +7,7 @@ import jax import jax.numpy as jnp -from qdax.core.cmaes import CMAES, CMAESState +from qdax.baselines.cmaes import CMAES, CMAESState from qdax.core.containers.mapelites_repertoire import ( MapElitesRepertoire, get_cells_indices, diff --git a/qdax/core/emitters/cma_mega_emitter.py b/qdax/core/emitters/cma_mega_emitter.py index c3f87fed..976f528b 100644 --- a/qdax/core/emitters/cma_mega_emitter.py +++ b/qdax/core/emitters/cma_mega_emitter.py @@ -6,7 +6,7 @@ import jax import jax.numpy as jnp -from qdax.core.cmaes import CMAES, CMAESState +from qdax.baselines.cmaes import CMAES, CMAESState from qdax.core.containers.mapelites_repertoire import ( MapElitesRepertoire, get_cells_indices, diff --git a/qdax/core/emitters/cma_rnd_emitter.py b/qdax/core/emitters/cma_rnd_emitter.py index 27e4f0db..0715c437 100644 --- a/qdax/core/emitters/cma_rnd_emitter.py +++ b/qdax/core/emitters/cma_rnd_emitter.py @@ -6,7 +6,7 @@ import jax import jax.numpy as jnp -from qdax.core.cmaes import CMAESState +from qdax.baselines.cmaes import CMAESState from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.cma_emitter import CMAEmitter, CMAEmitterState from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey diff --git a/tests/core_test/cmaes_test.py b/tests/core_test/cmaes_test.py index 16321fd4..daa7ce9d 100644 --- a/tests/core_test/cmaes_test.py +++ b/tests/core_test/cmaes_test.py @@ -4,7 +4,7 @@ import jax.numpy as jnp import pytest 
-from qdax.core.cmaes import CMAES +from qdax.baselines.cmaes import CMAES def test_cmaes() -> None: From 96163f218f0ec1918aa237acefe3671f201c141f Mon Sep 17 00:00:00 2001 From: Hannah Janmohamed <49594227+hannah-jan@users.noreply.github.com> Date: Fri, 20 Sep 2024 18:33:24 +0100 Subject: [PATCH 3/4] fix: Fix #139 - ensure MOME works for genotypes that are not arrays (#199) * Fix #139: ensure MOME works for genotypes that are not arrays --- qdax/core/containers/mome_repertoire.py | 36 ++++++++++++------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/qdax/core/containers/mome_repertoire.py b/qdax/core/containers/mome_repertoire.py index 43be3835..58a089a6 100644 --- a/qdax/core/containers/mome_repertoire.py +++ b/qdax/core/containers/mome_repertoire.py @@ -178,9 +178,6 @@ def _update_masked_pareto_front( pareto_front_len = pareto_front_fitnesses.shape[0] # type: ignore - first_leaf = jax.tree_util.tree_leaves(new_batch_of_genotypes)[0] - genotypes_dim = first_leaf.shape[1] - descriptors_dim = new_batch_of_descriptors.shape[1] # gather all data @@ -235,14 +232,11 @@ def _update_masked_pareto_front( front_size = len(pareto_front_fitnesses) # type: ignore new_front_fitness = new_front_fitness[:front_size, :] - genotypes_mask = jnp.repeat( - jnp.expand_dims(new_mask, axis=-1), genotypes_dim, axis=-1 - ) new_front_genotypes = jax.tree_util.tree_map( - lambda x: x * genotypes_mask, new_front_genotypes + lambda x: x * new_mask_indices[0], new_front_genotypes ) new_front_genotypes = jax.tree_util.tree_map( - lambda x: x[:front_size, :], new_front_genotypes + lambda x: x[:front_size], new_front_genotypes ) descriptors_mask = jnp.repeat( @@ -297,25 +291,31 @@ def _add_one( index = index.astype(jnp.int32) - # get cell data - cell_genotype = jax.tree_util.tree_map(lambda x: x[index], carry.genotypes) - cell_fitness = carry.fitnesses[index] - cell_descriptor = carry.descriptors[index] + # get current repertoire cell data + cell_genotype = 
jax.tree_util.tree_map( + lambda x: x[index][0], carry.genotypes + ) + cell_fitness = carry.fitnesses[index][0] + cell_descriptor = carry.descriptors[index][0] cell_mask = jnp.any(cell_fitness == -jnp.inf, axis=-1) + new_genotypes = jax.tree_util.tree_map( + lambda x: jnp.expand_dims(x, axis=0), genotype + ) + # update pareto front ( cell_fitness, - cell_genotype, + cell_genotype, # new pf for cell cell_descriptor, cell_mask, ) = self._update_masked_pareto_front( - pareto_front_fitnesses=cell_fitness.squeeze(axis=0), - pareto_front_genotypes=cell_genotype.squeeze(axis=0), - pareto_front_descriptors=cell_descriptor.squeeze(axis=0), - mask=cell_mask.squeeze(axis=0), + pareto_front_fitnesses=cell_fitness, + pareto_front_genotypes=cell_genotype, + pareto_front_descriptors=cell_descriptor, + mask=cell_mask, new_batch_of_fitnesses=jnp.expand_dims(fitness, axis=0), - new_batch_of_genotypes=jnp.expand_dims(genotype, axis=0), + new_batch_of_genotypes=new_genotypes, new_batch_of_descriptors=jnp.expand_dims(descriptors, axis=0), new_mask=jnp.zeros(shape=(1,), dtype=bool), ) From 6656f5e31849f4adc19b6331a2cc00c9000e6f61 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Sat, 21 Sep 2024 12:57:25 +0100 Subject: [PATCH 4/4] fix: Change to new-style-jax-rng-keys. (#195) * fix: Change to new-style-jax-rng-keys. 
--------- Co-authored-by: Milton Montero --- README.md | 2 +- examples/aurora.ipynb | 2 +- examples/cmaes.ipynb | 2 +- examples/cmame.ipynb | 2 +- examples/cmamega.ipynb | 2 +- examples/dads.ipynb | 4 +-- examples/dcrlme.ipynb | 2 +- examples/diayn.ipynb | 4 +-- examples/distributed_mapelites.ipynb | 2 +- examples/jumanji_snake.ipynb | 4 +-- examples/mapelites.ipynb | 4 +-- examples/me_sac_pbt.ipynb | 8 +++-- examples/me_td3_pbt.ipynb | 6 +++- examples/mees.ipynb | 2 +- examples/mels.ipynb | 4 +-- examples/mome.ipynb | 2 +- examples/nsga2_spea2.ipynb | 2 +- examples/omgmega.ipynb | 2 +- examples/pga_aurora.ipynb | 2 +- examples/pgame.ipynb | 2 +- examples/qdpg.ipynb | 2 +- examples/sac_pbt.ipynb | 8 +++-- examples/scripts/me_example.py | 2 +- examples/smerl.ipynb | 4 +-- examples/td3_pbt.ipynb | 6 +++- qdax/core/containers/mapelites_repertoire.py | 2 +- .../networks/seq2seq_networks.py | 31 ++++++++++--------- qdax/tasks/README.md | 6 ++-- qdax/utils/sampling.py | 4 +-- qdax/utils/train_seq2seq.py | 9 ++---- tests/baselines_test/cmame_test.py | 2 +- tests/baselines_test/cmamega_test.py | 2 +- tests/baselines_test/dads_smerl_test.py | 2 +- tests/baselines_test/dads_test.py | 2 +- tests/baselines_test/dcrlme_test.py | 2 +- tests/baselines_test/diayn_smerl_test.py | 2 +- tests/baselines_test/diayn_test.py | 2 +- tests/baselines_test/ga_test.py | 2 +- tests/baselines_test/me_pbt_sac_test.py | 5 ++- tests/baselines_test/me_pbt_td3_test.py | 5 ++- tests/baselines_test/mees_test.py | 2 +- tests/baselines_test/omgmega_test.py | 2 +- tests/baselines_test/pbt_sac_test.py | 5 ++- tests/baselines_test/pbt_td3_test.py | 5 ++- tests/baselines_test/pgame_test.py | 2 +- tests/baselines_test/qdpg_test.py | 2 +- tests/baselines_test/sac_test.py | 2 +- tests/baselines_test/td3_test.py | 2 +- tests/core_test/aurora_test.py | 2 +- tests/core_test/cmaes_test.py | 2 +- .../emitters_test/multi_emitter_test.py | 2 +- tests/core_test/map_elites_test.py | 2 +- tests/core_test/mels_test.py | 
2 +- tests/core_test/mome_test.py | 2 +- .../buffers_test/buffer_test.py | 2 +- tests/default_tasks_test/arm_test.py | 4 +-- tests/default_tasks_test/brax_task_test.py | 2 +- .../hypervolume_functions_test.py | 2 +- tests/default_tasks_test/jumanji_envs_test.py | 4 +-- tests/default_tasks_test/qd_suite_test.py | 2 +- .../standard_functions_test.py | 2 +- tests/environments_test/wrapper_test.py | 2 +- tests/utils_test/plotting_test.py | 2 +- tests/utils_test/sampling_test.py | 2 +- tests/utils_test/uncertainty_metrics_test.py | 2 +- 65 files changed, 124 insertions(+), 96 deletions(-) diff --git a/README.md b/README.md index e7955450..052eb74a 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ min_bd = 0.0 max_bd = 1.0 # Init a random key -random_key = jax.random.PRNGKey(seed) +random_key = jax.random.key(seed) # Init population of controllers random_key, subkey = jax.random.split(random_key) diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb index 09072715..55a1db53 100644 --- a/examples/aurora.ipynb +++ b/examples/aurora.ipynb @@ -146,7 +146,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", diff --git a/examples/cmaes.ipynb b/examples/cmaes.ipynb index 023b2e27..a93326ba 100644 --- a/examples/cmaes.ipynb +++ b/examples/cmaes.ipynb @@ -178,7 +178,7 @@ "outputs": [], "source": [ "state = cmaes.init()\n", - "random_key = jax.random.PRNGKey(0)" + "random_key = jax.random.key(0)" ] }, { diff --git a/examples/cmame.ipynb b/examples/cmame.ipynb index f7aa235e..d42dadb8 100644 --- a/examples/cmame.ipynb +++ b/examples/cmame.ipynb @@ -205,7 +205,7 @@ "metadata": {}, "outputs": [], "source": [ - "random_key = jax.random.PRNGKey(0)\n", + "random_key = jax.random.key(0)\n", "# in CMA-ME settings (from the paper), 
there is no init population\n", "# we multipy by zero to reproduce this setting\n", "initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n", diff --git a/examples/cmamega.ipynb b/examples/cmamega.ipynb index d37bf80e..1a8eeafb 100644 --- a/examples/cmamega.ipynb +++ b/examples/cmamega.ipynb @@ -198,7 +198,7 @@ "metadata": {}, "outputs": [], "source": [ - "random_key = jax.random.PRNGKey(0)\n", + "random_key = jax.random.key(0)\n", "# no initial population - give all the same value as emitter init value\n", "initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n", "\n", diff --git a/examples/dads.ipynb b/examples/dads.ipynb index 57d1df05..50d99d56 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -163,7 +163,7 @@ " eval_metrics=True,\n", ")\n", "\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "env_state = jax.jit(env.reset)(rng=key)\n", "eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n", "\n", @@ -499,7 +499,7 @@ "outputs": [], "source": [ "rollout = []\n", - "random_key = jax.random.PRNGKey(seed=1)\n", + "random_key = jax.random.key(seed=1)\n", "state = jit_env_reset(rng=random_key)\n", "while not state.done:\n", " rollout.append(state)\n", diff --git a/examples/dcrlme.ipynb b/examples/dcrlme.ipynb index eae0e6b3..057ef0c4 100644 --- a/examples/dcrlme.ipynb +++ b/examples/dcrlme.ipynb @@ -154,7 +154,7 @@ "source": [ "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init environment\n", "env = environments.create(env_name, episode_length=episode_length)\n", diff --git a/examples/diayn.ipynb b/examples/diayn.ipynb index cdee8b4b..d13ccad7 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -162,7 +162,7 @@ " eval_metrics=True,\n", ")\n", "\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "env_state = jax.jit(env.reset)(rng=key)\n", 
"eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n", "\n", @@ -490,7 +490,7 @@ "outputs": [], "source": [ "rollout = []\n", - "random_key = jax.random.PRNGKey(seed=1)\n", + "random_key = jax.random.key(seed=1)\n", "state = jit_env_reset(rng=random_key)\n", "while not state.done:\n", " rollout.append(state)\n", diff --git a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb index 7a7b5296..d2b158da 100644 --- a/examples/distributed_mapelites.ipynb +++ b/examples/distributed_mapelites.ipynb @@ -176,7 +176,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", diff --git a/examples/jumanji_snake.ipynb b/examples/jumanji_snake.ipynb index 78ec01c8..078c1c65 100644 --- a/examples/jumanji_snake.ipynb +++ b/examples/jumanji_snake.ipynb @@ -113,7 +113,7 @@ "env = jumanji.make('Snake-v1')\n", "\n", "# Reset your (jit-able) environment\n", - "key = jax.random.PRNGKey(0)\n", + "key = jax.random.key(0)\n", "state, timestep = jax.jit(env.reset)(key)\n", "\n", "# Interact with the (jit-able) environment\n", @@ -137,7 +137,7 @@ "outputs": [], "source": [ "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# get number of actions\n", "num_actions = env.action_spec().maximum + 1\n", diff --git a/examples/mapelites.ipynb b/examples/mapelites.ipynb index b7a0a256..626fb5de 100644 --- a/examples/mapelites.ipynb +++ b/examples/mapelites.ipynb @@ -133,7 +133,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ 
-494,7 +494,7 @@ "outputs": [], "source": [ "rollout = []\n", - "rng = jax.random.PRNGKey(seed=1)\n", + "rng = jax.random.key(seed=1)\n", "state = jit_env_reset(rng=rng)\n", "while not state.done:\n", " rollout.append(state)\n", diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index 86deebc4..42c46188 100644 --- a/examples/me_sac_pbt.ipynb +++ b/examples/me_sac_pbt.ipynb @@ -150,7 +150,7 @@ "outputs": [], "source": [ "%%time\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "key, subkey = jax.random.split(key)\n", "env_states = jax.jit(env.reset)(rng=subkey)\n", "eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey)" @@ -311,6 +311,10 @@ " observation_size=env.observation_size,\n", " buffer_size=buffer_size,\n", ")\n", + "\n", + "# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n", + "keys = jax.random.key_data(keys)\n", + "\n", "keys, training_states, _ = jax.pmap(agent_init_fn, axis_name=\"p\", devices=devices)(keys)" ] }, @@ -504,7 +508,7 @@ "%%time\n", "rollout = []\n", "\n", - "rng = jax.random.PRNGKey(seed=1)\n", + "rng = jax.random.key(seed=1)\n", "env_state = jax.jit(env.reset)(rng=rng)\n", "\n", "training_state, env_state = jax.tree_map(\n", diff --git a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index f72ccda1..8caca62f 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -155,7 +155,7 @@ "outputs": [], "source": [ "%%time\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "key, subkey = jax.random.split(key)\n", "env_states = jax.jit(env.reset)(rng=subkey)\n", "eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey)" @@ -314,6 +314,10 @@ " observation_size=env.observation_size,\n", " buffer_size=buffer_size,\n", ")\n", + "\n", + "# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n", + "keys = jax.random.key_data(keys)\n", + "\n", "keys, training_states, _ = jax.pmap(agent_init_fn, 
axis_name=\"p\", devices=devices)(keys)" ] }, diff --git a/examples/mees.ipynb b/examples/mees.ipynb index 3cbf890f..c09c7132 100644 --- a/examples/mees.ipynb +++ b/examples/mees.ipynb @@ -151,7 +151,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", diff --git a/examples/mels.ipynb b/examples/mels.ipynb index dae02a95..ed5a7c7a 100644 --- a/examples/mels.ipynb +++ b/examples/mels.ipynb @@ -140,7 +140,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -509,7 +509,7 @@ "outputs": [], "source": [ "rollout = []\n", - "rng = jax.random.PRNGKey(seed=1)\n", + "rng = jax.random.key(seed=1)\n", "state = jit_env_reset(rng=rng)\n", "while not state.done:\n", " rollout.append(state)\n", diff --git a/examples/mome.ipynb b/examples/mome.ipynb index 0d005dbe..217f94be 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -205,7 +205,7 @@ "outputs": [], "source": [ "# initial population\n", - "random_key = jax.random.PRNGKey(42)\n", + "random_key = jax.random.key(42)\n", "random_key, subkey = jax.random.split(random_key)\n", "genotypes = jax.random.uniform(\n", " random_key, (batch_size, num_variables), minval=minval, maxval=maxval, dtype=jnp.float32\n", diff --git a/examples/nsga2_spea2.ipynb b/examples/nsga2_spea2.ipynb index b418bd31..2d157323 100644 --- a/examples/nsga2_spea2.ipynb +++ b/examples/nsga2_spea2.ipynb @@ -181,7 +181,7 @@ "outputs": [], "source": [ "# Initial population\n", - "random_key = jax.random.PRNGKey(0)\n", + "random_key = jax.random.key(0)\n", "random_key, 
subkey = jax.random.split(random_key)\n", "genotypes = jax.random.uniform(\n", " subkey, (batch_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32\n", diff --git a/examples/omgmega.ipynb b/examples/omgmega.ipynb index bde9f5ed..900fc812 100644 --- a/examples/omgmega.ipynb +++ b/examples/omgmega.ipynb @@ -184,7 +184,7 @@ "metadata": {}, "outputs": [], "source": [ - "random_key = jax.random.PRNGKey(0)\n", + "random_key = jax.random.key(0)\n", "\n", "# defines the population\n", "random_key, subkey = jax.random.split(random_key)\n", diff --git a/examples/pga_aurora.ipynb b/examples/pga_aurora.ipynb index 56ed4f01..c3c00ae5 100644 --- a/examples/pga_aurora.ipynb +++ b/examples/pga_aurora.ipynb @@ -164,7 +164,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", diff --git a/examples/pgame.ipynb b/examples/pgame.ipynb index 31e4f831..c5419a3f 100644 --- a/examples/pgame.ipynb +++ b/examples/pgame.ipynb @@ -143,7 +143,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", diff --git a/examples/qdpg.ipynb b/examples/qdpg.ipynb index 8c47ffe6..a30c3be3 100644 --- a/examples/qdpg.ipynb +++ b/examples/qdpg.ipynb @@ -157,7 +157,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "random_key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", diff --git a/examples/sac_pbt.ipynb 
b/examples/sac_pbt.ipynb index 915cc272..53b526db 100644 --- a/examples/sac_pbt.ipynb +++ b/examples/sac_pbt.ipynb @@ -210,7 +210,7 @@ "outputs": [], "source": [ "# %%time\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "key, *keys = jax.random.split(key, num=1 + num_devices)\n", "keys = jnp.stack(keys)\n", "env_states, eval_env_first_states = jax.pmap(\n", @@ -269,6 +269,10 @@ " observation_size=env.observation_size,\n", " buffer_size=buffer_size,\n", ")\n", + "\n", + "# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n", + "keys = jax.random.key_data(keys)\n", + "\n", "keys, training_states, replay_buffers = jax.pmap(\n", " agent_init_fn, axis_name=\"p\", devices=devices\n", ")(keys)" @@ -518,7 +522,7 @@ "%%time\n", "rollout = []\n", "\n", - "rng = jax.random.PRNGKey(seed=1)\n", + "rng = jax.random.key(seed=1)\n", "env_state = jax.jit(env.reset)(rng=rng)\n", "\n", "training_state, env_state = jax.tree_util.tree_map(\n", diff --git a/examples/scripts/me_example.py b/examples/scripts/me_example.py index 433bc1d2..294cca8e 100644 --- a/examples/scripts/me_example.py +++ b/examples/scripts/me_example.py @@ -26,7 +26,7 @@ def run_me() -> None: max_bd = 1.0 # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init population of controllers random_key, subkey = jax.random.split(random_key) diff --git a/examples/smerl.ipynb b/examples/smerl.ipynb index 0e332192..ede905f9 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -168,7 +168,7 @@ " eval_metrics=True,\n", ")\n", "\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "env_state = jax.jit(env.reset)(rng=key)\n", "eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n", "\n", @@ -504,7 +504,7 @@ "outputs": [], "source": [ "rollout = []\n", - "random_key = jax.random.PRNGKey(seed=1)\n", + "random_key = jax.random.key(seed=1)\n", "state = jit_env_reset(rng=random_key)\n", "while not 
state.done:\n", " rollout.append(state)\n", diff --git a/examples/td3_pbt.ipynb b/examples/td3_pbt.ipynb index d2d98f85..3bbf237e 100644 --- a/examples/td3_pbt.ipynb +++ b/examples/td3_pbt.ipynb @@ -181,7 +181,7 @@ "outputs": [], "source": [ "%%time\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "key, *keys = jax.random.split(key, num=1 + num_devices)\n", "keys = jnp.stack(keys)\n", "env_states, eval_env_first_states = jax.pmap(\n", @@ -232,6 +232,10 @@ " observation_size=env.observation_size,\n", " buffer_size=buffer_size,\n", ")\n", + "\n", + "# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n", + "keys = jax.random.key_data(keys)\n", + "\n", "keys, training_states, replay_buffers = jax.pmap(\n", " agent_init_fn, axis_name=\"p\", devices=devices\n", ")(keys)" diff --git a/qdax/core/containers/mapelites_repertoire.py b/qdax/core/containers/mapelites_repertoire.py index b473d4b3..87584eb3 100644 --- a/qdax/core/containers/mapelites_repertoire.py +++ b/qdax/core/containers/mapelites_repertoire.py @@ -63,7 +63,7 @@ def compute_cvt_centroids( init="k-means++", n_clusters=num_centroids, n_init=1, - random_state=RandomState(subkey), + random_state=RandomState(jax.random.key_data(subkey)), ) k_means.fit(x) centroids = k_means.cluster_centers_ diff --git a/qdax/core/neuroevolution/networks/seq2seq_networks.py b/qdax/core/neuroevolution/networks/seq2seq_networks.py index 3cb52a3e..a4bb2272 100644 --- a/qdax/core/neuroevolution/networks/seq2seq_networks.py +++ b/qdax/core/neuroevolution/networks/seq2seq_networks.py @@ -15,9 +15,6 @@ import numpy as np from flax import linen as nn -Array = Any -PRNGKey = Any - class EncoderLSTM(nn.Module): """EncoderLSTM Module wrapped in a lifted scan transform.""" @@ -31,14 +28,16 @@ class EncoderLSTM(nn.Module): ) @nn.compact def __call__( - self, carry: Tuple[Array, Array], x: Array - ) -> Tuple[Tuple[Array, Array], Array]: + self, carry: Tuple[jax.Array, jax.Array], x: jax.Array + ) 
-> Tuple[Tuple[jax.Array, jax.Array], jax.Array]: """Applies the module.""" lstm_state, is_eos = carry features = lstm_state[0].shape[-1] new_lstm_state, y = nn.LSTMCell(features)(lstm_state, x) - def select_carried_state(new_state: Array, old_state: Array) -> Array: + def select_carried_state( + new_state: jax.Array, old_state: jax.Array + ) -> jax.Array: return jnp.where(is_eos[:, np.newaxis], old_state, new_state) # LSTM state is a tuple (c, h). @@ -49,10 +48,12 @@ def select_carried_state(new_state: Array, old_state: Array) -> Array: return (carried_lstm_state, is_eos), y @staticmethod - def initialize_carry(batch_size: int, hidden_size: int) -> Tuple[Array, Array]: + def initialize_carry( + batch_size: int, hidden_size: int + ) -> Tuple[jax.Array, jax.Array]: # Use a dummy key since the default state init fn is just zeros. return nn.LSTMCell(hidden_size, parent=None).initialize_carry( # type: ignore - jax.random.PRNGKey(0), (batch_size, hidden_size) + jax.random.key(0), (batch_size, hidden_size) ) @@ -62,7 +63,7 @@ class Encoder(nn.Module): hidden_size: int @nn.compact - def __call__(self, inputs: Array) -> Array: + def __call__(self, inputs: jax.Array) -> jax.Array: batch_size = inputs.shape[0] lstm = EncoderLSTM(name="encoder_lstm") init_lstm_state = lstm.initialize_carry(batch_size, self.hidden_size) @@ -95,7 +96,7 @@ class DecoderLSTM(nn.Module): split_rngs={"params": False, "lstm": True}, ) @nn.compact - def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array: + def __call__(self, carry: Tuple[jax.Array, jax.Array], x: jax.Array) -> jax.Array: """Applies the DecoderLSTM model.""" lstm_state, last_prediction = carry @@ -124,7 +125,9 @@ class Decoder(nn.Module): obs_size: int @nn.compact - def __call__(self, inputs: Array, init_state: Any) -> Tuple[Array, Array]: + def __call__( + self, inputs: jax.Array, init_state: Any + ) -> Tuple[jax.Array, jax.Array]: """Applies the decoder model. 
Args: @@ -166,8 +169,8 @@ def setup(self) -> None: @nn.compact def __call__( - self, encoder_inputs: Array, decoder_inputs: Array - ) -> Tuple[Array, Array]: + self, encoder_inputs: jax.Array, decoder_inputs: jax.Array + ) -> Tuple[jax.Array, jax.Array]: """Applies the seq2seq model. Args: @@ -194,7 +197,7 @@ def __call__( return logits, predictions - def encode(self, encoder_inputs: Array) -> Array: + def encode(self, encoder_inputs: jax.Array) -> jax.Array: # encode inputs init_decoder_state = self.encoder(encoder_inputs) final_output, _hidden_state = init_decoder_state diff --git a/qdax/tasks/README.md b/qdax/tasks/README.md index 56528323..bb3a09d8 100644 --- a/qdax/tasks/README.md +++ b/qdax/tasks/README.md @@ -19,7 +19,7 @@ Notes: import jax from qdax.tasks.arm import arm_scoring_function -random_key = jax.random.PRNGKey(0) +random_key = jax.random.key(0) # Get scoring function scoring_fn = arm_scoring_function @@ -56,7 +56,7 @@ desc_size = 2 import jax from qdax.tasks.standard_functions import sphere_scoring_function -random_key = jax.random.PRNGKey(0) +random_key = jax.random.key(0) # Get scoring function scoring_fn = sphere_scoring_function @@ -98,7 +98,7 @@ desc_size = 2 import jax from qdax.tasks.hypervolume_functions import square_scoring_function -random_key = jax.random.PRNGKey(0) +random_key = jax.random.key(0) # Get scoring function scoring_fn = square_scoring_function diff --git a/qdax/utils/sampling.py b/qdax/utils/sampling.py index be1d336d..94d4e160 100644 --- a/qdax/utils/sampling.py +++ b/qdax/utils/sampling.py @@ -159,8 +159,8 @@ def multi_sample_scoring_function( # vectorizing over axis 0 vectorizes over the num_samples random keys in_axes=(None, 0), # indicates that the vectorized axis will become axis 1, i.e., the final - # output is shape (batch_size, num_samples, ...) - out_axes=1, + # output is shape (batch_size, num_samples, ...) 
except for the random key + out_axes=(1, 1, 1, 0), ) all_fitnesses, all_descriptors, all_extra_scores, _ = sample_scoring_fn( policies_params, keys diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py index bd9570a9..fa7825b0 100644 --- a/qdax/utils/train_seq2seq.py +++ b/qdax/utils/train_seq2seq.py @@ -19,9 +19,6 @@ from qdax.custom_types import Params, RNGKey from qdax.environments.bd_extractors import AuroraExtraInfoNormalization -Array = Any -PRNGKey = Any - def get_model( obs_size: int, teacher_force: bool = False, hidden_size: int = 10 @@ -40,7 +37,7 @@ def get_model( def get_initial_params( - model: Seq2seq, random_key: PRNGKey, encoder_input_shape: Tuple[int, ...] + model: Seq2seq, random_key: RNGKey, encoder_input_shape: Tuple[int, ...] ) -> Dict[str, Any]: """ Returns the initial parameters of a seq2seq model. @@ -62,8 +59,8 @@ def get_initial_params( @jax.jit def train_step( state: train_state.TrainState, - batch: Array, - lstm_random_key: PRNGKey, + batch: jax.Array, + lstm_random_key: RNGKey, ) -> Tuple[train_state.TrainState, Dict[str, float]]: """ Trains for one step. 
diff --git a/tests/baselines_test/cmame_test.py b/tests/baselines_test/cmame_test.py index 2dc6fa10..82d7e54a 100644 --- a/tests/baselines_test/cmame_test.py +++ b/tests/baselines_test/cmame_test.py @@ -81,7 +81,7 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: max_fitness = jnp.max(adjusted_fitness) return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} - random_key = jax.random.PRNGKey(0) + random_key = jax.random.key(0) initial_population = ( jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.0 ) diff --git a/tests/baselines_test/cmamega_test.py b/tests/baselines_test/cmamega_test.py index 5bfdfd58..02fdad13 100644 --- a/tests/baselines_test/cmamega_test.py +++ b/tests/baselines_test/cmamega_test.py @@ -95,7 +95,7 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: max_fitness = jnp.max(adjusted_fitness) return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} - random_key = jax.random.PRNGKey(0) + random_key = jax.random.key(0) initial_population = jax.random.uniform( random_key, shape=(batch_size, num_dimensions) ) diff --git a/tests/baselines_test/dads_smerl_test.py b/tests/baselines_test/dads_smerl_test.py index 2a8d3d1f..2ecc25f6 100644 --- a/tests/baselines_test/dads_smerl_test.py +++ b/tests/baselines_test/dads_smerl_test.py @@ -72,7 +72,7 @@ def test_dads_smerl() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/dads_test.py b/tests/baselines_test/dads_test.py index 77094ffd..76da834e 100644 --- a/tests/baselines_test/dads_test.py +++ b/tests/baselines_test/dads_test.py @@ -66,7 +66,7 @@ def test_dads() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = 
jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/dcrlme_test.py b/tests/baselines_test/dcrlme_test.py index 05304944..1bc9688d 100644 --- a/tests/baselines_test/dcrlme_test.py +++ b/tests/baselines_test/dcrlme_test.py @@ -61,7 +61,7 @@ def test_dcrlme() -> None: policy_delay = 2 # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init environment env = environments.create(env_name, episode_length=episode_length) diff --git a/tests/baselines_test/diayn_smerl_test.py b/tests/baselines_test/diayn_smerl_test.py index abd94b45..f06a4298 100644 --- a/tests/baselines_test/diayn_smerl_test.py +++ b/tests/baselines_test/diayn_smerl_test.py @@ -69,7 +69,7 @@ def test_diayn_smerl() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/diayn_test.py b/tests/baselines_test/diayn_test.py index 3492d9c1..856e0174 100644 --- a/tests/baselines_test/diayn_test.py +++ b/tests/baselines_test/diayn_test.py @@ -62,7 +62,7 @@ def test_diayn() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/ga_test.py b/tests/baselines_test/ga_test.py index a1eb1b51..619c76e2 100644 --- a/tests/baselines_test/ga_test.py +++ b/tests/baselines_test/ga_test.py @@ -71,7 +71,7 @@ def scoring_fn( return fitnesses, {}, random_key # initial population - random_key = jax.random.PRNGKey(42) + random_key = jax.random.key(42) random_key, subkey = jax.random.split(random_key) genotypes = jax.random.uniform( subkey, diff --git a/tests/baselines_test/me_pbt_sac_test.py b/tests/baselines_test/me_pbt_sac_test.py index 5058bad6..c8dcc5af 100644 --- a/tests/baselines_test/me_pbt_sac_test.py +++ 
b/tests/baselines_test/me_pbt_sac_test.py @@ -69,7 +69,7 @@ def test_me_pbt_sac() -> None: ) min_bd, max_bd = env.behavior_descriptor_limits - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, subkey = jax.random.split(key) eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey) @@ -163,6 +163,9 @@ def scoring_function(genotypes, random_key): # type: ignore observation_size=env.observation_size, buffer_size=buffer_size, ) + + # Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647 + keys = jax.random.key_data(keys) keys, training_states, _ = jax.pmap(agent_init_fn, axis_name="p", devices=devices)( keys ) diff --git a/tests/baselines_test/me_pbt_td3_test.py b/tests/baselines_test/me_pbt_td3_test.py index 39c3e942..f243725e 100644 --- a/tests/baselines_test/me_pbt_td3_test.py +++ b/tests/baselines_test/me_pbt_td3_test.py @@ -69,7 +69,7 @@ def test_me_pbt_td3() -> None: ) min_bd, max_bd = env.behavior_descriptor_limits - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, subkey = jax.random.split(key) eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey) @@ -161,6 +161,9 @@ def scoring_function(genotypes, random_key): # type: ignore observation_size=env.observation_size, buffer_size=buffer_size, ) + + # Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647 + keys = jax.random.key_data(keys) keys, training_states, _ = jax.pmap(agent_init_fn, axis_name="p", devices=devices)( keys ) diff --git a/tests/baselines_test/mees_test.py b/tests/baselines_test/mees_test.py index d1913b02..b5d56f1f 100644 --- a/tests/baselines_test/mees_test.py +++ b/tests/baselines_test/mees_test.py @@ -47,7 +47,7 @@ def test_mees() -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) diff --git 
a/tests/baselines_test/omgmega_test.py b/tests/baselines_test/omgmega_test.py index ad51c7ae..632dc993 100644 --- a/tests/baselines_test/omgmega_test.py +++ b/tests/baselines_test/omgmega_test.py @@ -82,7 +82,7 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: max_fitness = jnp.max(adjusted_fitness) return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} - random_key = jax.random.PRNGKey(0) + random_key = jax.random.key(0) # defines the population random_key, subkey = jax.random.split(random_key) diff --git a/tests/baselines_test/pbt_sac_test.py b/tests/baselines_test/pbt_sac_test.py index db7dc69e..9c4b2c83 100644 --- a/tests/baselines_test/pbt_sac_test.py +++ b/tests/baselines_test/pbt_sac_test.py @@ -76,7 +76,7 @@ def init_environments(random_key): # type: ignore return env_states, eval_env_first_states - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, *keys = jax.random.split(key, num=1 + num_devices) keys = jnp.stack(keys) env_states, eval_env_first_states = jax.pmap( @@ -104,6 +104,9 @@ def init_environments(random_key): # type: ignore observation_size=env.observation_size, buffer_size=buffer_size, ) + + # Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647 + keys = jax.random.key_data(keys) keys, training_states, replay_buffers = jax.pmap( agent_init_fn, axis_name="p", devices=devices )(keys) diff --git a/tests/baselines_test/pbt_td3_test.py b/tests/baselines_test/pbt_td3_test.py index 0be68277..e45a9701 100644 --- a/tests/baselines_test/pbt_td3_test.py +++ b/tests/baselines_test/pbt_td3_test.py @@ -74,7 +74,7 @@ def init_environments(random_key): # type: ignore return env_states, eval_env_first_states - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, *keys = jax.random.split(key, num=1 + num_devices) keys = jnp.stack(keys) env_states, eval_env_first_states = jax.pmap( @@ -100,6 +100,9 @@ def init_environments(random_key): # type: ignore 
observation_size=env.observation_size, buffer_size=buffer_size, ) + + # Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647 + keys = jax.random.key_data(keys) keys, training_states, replay_buffers = jax.pmap( agent_init_fn, axis_name="p", devices=devices )(keys) diff --git a/tests/baselines_test/pgame_test.py b/tests/baselines_test/pgame_test.py index 0490a481..639f1a9d 100644 --- a/tests/baselines_test/pgame_test.py +++ b/tests/baselines_test/pgame_test.py @@ -54,7 +54,7 @@ def test_pgame() -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) diff --git a/tests/baselines_test/qdpg_test.py b/tests/baselines_test/qdpg_test.py index 704416a4..7f1868f6 100644 --- a/tests/baselines_test/qdpg_test.py +++ b/tests/baselines_test/qdpg_test.py @@ -69,7 +69,7 @@ def test_qdpg() -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) diff --git a/tests/baselines_test/sac_test.py b/tests/baselines_test/sac_test.py index 8c26b510..57554b92 100644 --- a/tests/baselines_test/sac_test.py +++ b/tests/baselines_test/sac_test.py @@ -53,7 +53,7 @@ def test_sac() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/td3_test.py b/tests/baselines_test/td3_test.py index ff5d338c..55c09811 100644 --- a/tests/baselines_test/td3_test.py +++ b/tests/baselines_test/td3_test.py @@ -49,7 +49,7 @@ def test_td3() -> None: auto_reset=True, eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = 
jax.random.key(seed) key, subkey = jax.random.split(key) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index 4bbb9d82..9ec55a78 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -75,7 +75,7 @@ def test_aurora(env_name: str, batch_size: int) -> None: log_freq = 5 # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init environment env, policy_network, scoring_fn, random_key = create_default_brax_task_components( diff --git a/tests/core_test/cmaes_test.py b/tests/core_test/cmaes_test.py index daa7ce9d..dc6078d1 100644 --- a/tests/core_test/cmaes_test.py +++ b/tests/core_test/cmaes_test.py @@ -32,7 +32,7 @@ def sphere_scoring(x: jnp.ndarray) -> jnp.ndarray: ) state = cmaes.init() - random_key = jax.random.PRNGKey(0) + random_key = jax.random.key(0) iteration_count = 0 for _ in range(num_iterations): diff --git a/tests/core_test/emitters_test/multi_emitter_test.py b/tests/core_test/emitters_test/multi_emitter_test.py index ebf712d5..cee20261 100644 --- a/tests/core_test/emitters_test/multi_emitter_test.py +++ b/tests/core_test/emitters_test/multi_emitter_test.py @@ -27,7 +27,7 @@ def test_multi_emitter() -> None: max_bd = max_param # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init population of controllers random_key, subkey = jax.random.split(random_key) diff --git a/tests/core_test/map_elites_test.py b/tests/core_test/map_elites_test.py index c89ce04f..61c90f06 100644 --- a/tests/core_test/map_elites_test.py +++ b/tests/core_test/map_elites_test.py @@ -51,7 +51,7 @@ def test_map_elites(env_name: str, batch_size: int) -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init policy network 
policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) diff --git a/tests/core_test/mels_test.py b/tests/core_test/mels_test.py index 66bcc05f..383ab55a 100644 --- a/tests/core_test/mels_test.py +++ b/tests/core_test/mels_test.py @@ -40,7 +40,7 @@ def test_mels(env_name: str, batch_size: int) -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) diff --git a/tests/core_test/mome_test.py b/tests/core_test/mome_test.py index 746b94a0..4e0eb574 100644 --- a/tests/core_test/mome_test.py +++ b/tests/core_test/mome_test.py @@ -79,7 +79,7 @@ def scoring_fn( metrics_function = partial(default_moqd_metrics, reference_point=reference_point) # initial population - random_key = jax.random.PRNGKey(42) + random_key = jax.random.key(42) random_key, subkey = jax.random.split(random_key) genotypes = jax.random.uniform( subkey, diff --git a/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py b/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py index 06e25fcd..07726b94 100644 --- a/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py +++ b/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py @@ -91,7 +91,7 @@ def test_sample() -> None: simple_transition = simple_transition.replace(rewards=jnp.arange(3)) replay_buffer = replay_buffer.insert(simple_transition) - random_key = jax.random.PRNGKey(0) + random_key = jax.random.key(0) samples, random_key = replay_buffer.sample(random_key, 3) diff --git a/tests/default_tasks_test/arm_test.py b/tests/default_tasks_test/arm_test.py index 98361b23..31e0e0c1 100644 --- a/tests/default_tasks_test/arm_test.py +++ b/tests/default_tasks_test/arm_test.py @@ -41,7 +41,7 @@ def test_arm(task_name: str, batch_size: int) -> None: max_bd = 1.0 # Init a random key - random_key = 
jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init population of controllers random_key, subkey = jax.random.split(random_key) @@ -114,7 +114,7 @@ def test_arm_scoring_function() -> None: # Init a random key seed = 42 - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # arm has xy BD centered at 0.5 0.5 and min max range is [0,1] # 0 params of first genotype is horizontal and points towards negative x axis diff --git a/tests/default_tasks_test/brax_task_test.py b/tests/default_tasks_test/brax_task_test.py index c12518fb..55768171 100644 --- a/tests/default_tasks_test/brax_task_test.py +++ b/tests/default_tasks_test/brax_task_test.py @@ -34,7 +34,7 @@ def test_map_elites(env_name: str, batch_size: int, is_task_reset_based: bool) - max_bd = 1.0 # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) env, policy_network, scoring_fn, random_key = create_default_brax_task_components( env_name=env_name, diff --git a/tests/default_tasks_test/hypervolume_functions_test.py b/tests/default_tasks_test/hypervolume_functions_test.py index 3d619353..152c245d 100644 --- a/tests/default_tasks_test/hypervolume_functions_test.py +++ b/tests/default_tasks_test/hypervolume_functions_test.py @@ -50,7 +50,7 @@ def test_standard_functions(task_name: str, batch_size: int) -> None: max_bd = 1.0 # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init population of controllers random_key, subkey = jax.random.split(random_key) diff --git a/tests/default_tasks_test/jumanji_envs_test.py b/tests/default_tasks_test/jumanji_envs_test.py index 636a02cf..dba574ac 100644 --- a/tests/default_tasks_test/jumanji_envs_test.py +++ b/tests/default_tasks_test/jumanji_envs_test.py @@ -30,7 +30,7 @@ def test_jumanji_utils() -> None: env = jumanji.make("Snake-v1") # Reset your (jit-able) environment - key = jax.random.PRNGKey(0) + key = jax.random.key(0) state, _timestep = 
jax.jit(env.reset)(key) # Interact with the (jit-able) environment @@ -38,7 +38,7 @@ def test_jumanji_utils() -> None: state, _timestep = jax.jit(env.step)(state, action) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # get number of actions num_actions = env.action_spec().maximum + 1 diff --git a/tests/default_tasks_test/qd_suite_test.py b/tests/default_tasks_test/qd_suite_test.py index 46f6ce9b..9424cc63 100644 --- a/tests/default_tasks_test/qd_suite_test.py +++ b/tests/default_tasks_test/qd_suite_test.py @@ -68,7 +68,7 @@ def test_qd_suite(task_name: str, batch_size: int) -> None: grid_shape = tuple([resolution_per_axis for _ in range(bd_size)]) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init population of parameters init_variables = task.get_initial_parameters(init_batch_size) diff --git a/tests/default_tasks_test/standard_functions_test.py b/tests/default_tasks_test/standard_functions_test.py index 87913364..b30cd7cc 100644 --- a/tests/default_tasks_test/standard_functions_test.py +++ b/tests/default_tasks_test/standard_functions_test.py @@ -40,7 +40,7 @@ def test_standard_functions(task_name: str, batch_size: int) -> None: max_bd = max_param # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init population of controllers random_key, subkey = jax.random.split(random_key) diff --git a/tests/environments_test/wrapper_test.py b/tests/environments_test/wrapper_test.py index f5e035ea..b29d89e1 100644 --- a/tests/environments_test/wrapper_test.py +++ b/tests/environments_test/wrapper_test.py @@ -110,7 +110,7 @@ def test_wrapper(env_name: str) -> None: print("Observation size: ", env.observation_size) print("Action size: ", env.action_size) - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) init_state = env.reset(random_key) joint_angle = jp.concatenate( diff --git 
a/tests/utils_test/plotting_test.py b/tests/utils_test/plotting_test.py index 17b4a8ea..807760ce 100644 --- a/tests/utils_test/plotting_test.py +++ b/tests/utils_test/plotting_test.py @@ -38,7 +38,7 @@ def test_onion_grid(num_descriptors: int, grid_shape: Tuple[int, ...]) -> None: minval = jnp.array([0] * num_descriptors) maxval = jnp.array([1] * num_descriptors) - random_key = jax.random.PRNGKey(seed=0) + random_key = jax.random.key(seed=0) random_key, key_desc, key_fit = jax.random.split(random_key, num=3) number_samples_test = 300 diff --git a/tests/utils_test/sampling_test.py b/tests/utils_test/sampling_test.py index 8d19379e..981a546d 100644 --- a/tests/utils_test/sampling_test.py +++ b/tests/utils_test/sampling_test.py @@ -34,7 +34,7 @@ def test_sampling() -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) diff --git a/tests/utils_test/uncertainty_metrics_test.py b/tests/utils_test/uncertainty_metrics_test.py index d49e2527..3f2caea1 100644 --- a/tests/utils_test/uncertainty_metrics_test.py +++ b/tests/utils_test/uncertainty_metrics_test.py @@ -25,7 +25,7 @@ def test_uncertainty_metrics() -> None: genotype_dim = 8 # Init a random key - random_key = jax.random.PRNGKey(seed) + random_key = jax.random.key(seed) # First, init a deterministic environment init_policies = jax.random.uniform(