From b1535eefa220068bf171bd71d77ade7c9f827d3f Mon Sep 17 00:00:00 2001
From: Maxence Faldor
Date: Tue, 3 Sep 2024 15:10:51 +0000
Subject: [PATCH] Stop returning random keys

---
 README.md | 10 +-
 docs/overview.md | 10 +-
 examples/aurora.ipynb | 38 ++---
 examples/cmaes.ipynb | 6 +-
 examples/cmame.ipynb | 14 +-
 examples/cmamega.ipynb | 18 +-
 examples/dads.ipynb | 14 +-
 examples/diayn.ipynb | 14 +-
 examples/distributed_mapelites.ipynb | 24 +--
 examples/jumanji_snake.ipynb | 30 ++--
 examples/mapelites.ipynb | 26 +--
 examples/me_sac_pbt.ipynb | 12 +-
 examples/me_td3_pbt.ipynb | 10 +-
 examples/mees.ipynb | 24 +--
 examples/mels.ipynb | 22 +--
 examples/mome.ipynb | 22 +--
 examples/nsga2_spea2.ipynb | 24 +--
 examples/omgmega.ipynb | 14 +-
 examples/pga_aurora.ipynb | 40 ++---
 examples/pgame.ipynb | 22 +--
 examples/qdpg.ipynb | 22 +--
 examples/sac_pbt.ipynb | 10 +-
 examples/scripts/me_example.py | 12 +-
 examples/smerl.ipynb | 14 +-
 examples/td3_pbt.ipynb | 8 +-
 qdax/baselines/dads.py | 57 +++----
 qdax/baselines/dads_smerl.py | 9 +-
 qdax/baselines/diayn.py | 55 +++---
 qdax/baselines/diayn_smerl.py | 11 +-
 qdax/baselines/genetic_algorithm.py | 49 +++---
 qdax/baselines/nsga2.py | 16 +-
 qdax/baselines/pbt.py | 15 +-
 qdax/baselines/sac.py | 97 +++++------
 qdax/baselines/sac_pbt.py | 59 +++----
 qdax/baselines/spea2.py | 25 +--
 qdax/baselines/td3.py | 47 +++---
 qdax/baselines/td3_pbt.py | 57 ++++---
 qdax/core/aurora.py | 47 +++---
 qdax/core/cmaes.py | 11 +-
 qdax/core/containers/ga_repertoire.py | 11 +-
 qdax/core/containers/mapelites_repertoire.py | 40 ++---
 qdax/core/containers/mome_repertoire.py | 25 ++-
 qdax/core/containers/repertoire.py | 4 +-
 .../containers/uniform_replacement_archive.py | 12 +-
 .../containers/unstructured_repertoire.py | 21 ++-
 qdax/core/distributed_map_elites.py | 49 +++---
 qdax/core/emitters/cma_emitter.py | 60 ++++---
 qdax/core/emitters/cma_mega_emitter.py | 47 +++---
 qdax/core/emitters/cma_pool_emitter.py | 38 ++---
 qdax/core/emitters/cma_rnd_emitter.py | 28 ++--
 qdax/core/emitters/dpg_emitter.py | 32 ++--
 qdax/core/emitters/emitter.py | 19 +--
 qdax/core/emitters/mees_emitter.py | 128 ++++++--------
 qdax/core/emitters/multi_emitter.py | 40 +++--
 qdax/core/emitters/mutation_operators.py | 67 ++++----
 qdax/core/emitters/omg_mega_emitter.py | 44 ++---
 qdax/core/emitters/pbt_me_emitter.py | 48 +++---
 qdax/core/emitters/pbt_variation_operators.py | 20 +--
 qdax/core/emitters/qdcg_emitter.py | 102 +++++------
 qdax/core/emitters/qpg_emitter.py | 71 ++++----
 qdax/core/emitters/standard_emitters.py | 27 ++-
 qdax/core/map_elites.py | 46 +++--
 qdax/core/mels.py | 18 +-
 qdax/core/mome.py | 18 +-
 qdax/core/neuroevolution/buffers/buffer.py | 7 +-
 .../buffers/trajectory_buffer.py | 22 ++-
 qdax/core/neuroevolution/losses/sac_loss.py | 22 +--
 qdax/core/neuroevolution/losses/td3_loss.py | 14 +-
 qdax/core/neuroevolution/mdp_utils.py | 32 ++--
 .../networks/seq2seq_networks.py | 2 +-
 qdax/tasks/README.md | 12 +-
 qdax/tasks/arm.py | 28 +---
 qdax/tasks/brax_envs.py | 158 ++++++++----------
 qdax/tasks/hypervolume_functions.py | 9 +-
 qdax/tasks/jumanji_envs.py | 24 ++-
 qdax/tasks/qd_suite/archimedean_spiral.py | 4 +-
 qdax/tasks/qd_suite/qd_suite_task.py | 6 +-
 qdax/tasks/standard_functions.py | 18 +-
 qdax/utils/sampling.py | 52 +++---
 qdax/utils/train_seq2seq.py | 22 +--
 tests/baselines_test/cmame_test.py | 18 +-
 tests/baselines_test/cmamega_test.py | 24 ++-
 tests/baselines_test/dads_smerl_test.py | 2 +-
 tests/baselines_test/dads_test.py | 2 +-
 tests/baselines_test/diayn_smerl_test.py | 2 +-
tests/baselines_test/diayn_test.py | 2 +- tests/baselines_test/ga_test.py | 22 ++- tests/baselines_test/me_pbt_sac_test.py | 18 +- tests/baselines_test/me_pbt_td3_test.py | 18 +- tests/baselines_test/mees_test.py | 33 ++-- tests/baselines_test/omgmega_test.py | 18 +- tests/baselines_test/pbt_sac_test.py | 8 +- tests/baselines_test/pbt_td3_test.py | 8 +- tests/baselines_test/pgame_test.py | 31 ++-- tests/baselines_test/qdpg_test.py | 31 ++-- tests/baselines_test/sac_test.py | 4 +- tests/baselines_test/td3_test.py | 2 +- tests/core_test/aurora_test.py | 26 +-- tests/core_test/cmaes_test.py | 4 +- .../emitters_test/multi_emitter_test.py | 12 +- tests/core_test/map_elites_test.py | 24 +-- tests/core_test/mels_test.py | 20 +-- tests/core_test/mome_test.py | 22 +-- .../buffers_test/buffer_test.py | 4 +- tests/default_tasks_test/arm_test.py | 42 ++--- tests/default_tasks_test/brax_task_test.py | 23 ++- .../hypervolume_functions_test.py | 12 +- tests/default_tasks_test/jumanji_envs_test.py | 14 +- tests/default_tasks_test/qd_suite_test.py | 10 +- .../standard_functions_test.py | 12 +- tests/environments_test/wrapper_test.py | 4 +- tests/utils_test/plotting_test.py | 4 +- tests/utils_test/sampling_test.py | 20 +-- 113 files changed, 1354 insertions(+), 1604 deletions(-) diff --git a/README.md b/README.md index 6999f58a..f1f37a3a 100644 --- a/README.md +++ b/README.md @@ -64,10 +64,10 @@ min_descriptor = 0.0 max_descriptor = 1.0 # Init a random key -random_key = jax.random.PRNGKey(seed) +key = jax.random.key(seed) # Init population of controllers -random_key, subkey = jax.random.split(random_key) +key, subkey = jax.random.split(key) init_variables = jax.random.uniform( subkey, shape=(init_batch_size, num_param_dimensions), @@ -111,14 +111,14 @@ centroids = compute_euclidean_centroids( ) # Initializes repertoire and emitter state -repertoire, emitter_state, random_key = map_elites.init(init_variables, centroids, random_key) +repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key) # Run MAP-Elites loop for i in range(num_iterations): - (repertoire, emitter_state, metrics, random_key,) = map_elites.update( + (repertoire, emitter_state, metrics, key,) = map_elites.update( repertoire, emitter_state, - random_key, + key, ) # Get contents of repertoire diff --git a/docs/overview.md b/docs/overview.md index a5b6def7..59f95517 100644 --- a/docs/overview.md +++ b/docs/overview.md @@ -22,18 +22,18 @@ More importantly, QDax handles the archive management which is the key idea of Q ## Code Example ```python # Initializes repertoire and emitter state -repertoire, emitter_state, random_key = map_elites.init(init_variables, centroids, random_key) +repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key) for i in range(num_iterations): # generate new population with the emitter - genotypes, random_key = map_elites._emitter.emit( - repertoire, emitter_state, random_key + genotypes, key = map_elites._emitter.emit( + repertoire, emitter_state, key ) # scores/evaluates the population - fitnesses, descriptors, extra_scores, random_key = map_elites._scoring_function( - genotypes, random_key + fitnesses, descriptors, extra_scores, key = map_elites._scoring_function( + genotypes, key ) # update repertoire diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb index c64e7757..9b515d83 100644 --- a/examples/aurora.ipynb +++ b/examples/aurora.ipynb @@ -159,7 +159,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - 
"random_key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -170,14 +170,14 @@ ")\n", "\n", "# Init population of controllers\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jax.random.split(subkey, num=batch_size)\n", "fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))\n", "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", "\n", "\n", "# Create the initial environment states\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)\n", "reset_fn = jax.jit(jax.vmap(env.reset))\n", "init_states = reset_fn(keys)" @@ -202,7 +202,7 @@ "def play_step_fn(\n", " env_state,\n", " policy_params,\n", - " random_key,\n", + " key,\n", "):\n", " \"\"\"\n", " Play an environment step and return the updated state and the transition.\n", @@ -224,7 +224,7 @@ " next_state_desc=next_state.info[\"state_descriptor\"],\n", " )\n", "\n", - " return next_state, policy_params, random_key, transition" + " return next_state, policy_params, key, transition" ] }, { @@ -243,9 +243,9 @@ "outputs": [], "source": [ "# Prepare the scoring function\n", - "env, policy_network, scoring_fn, random_key = create_default_brax_task_components(\n", + "env, policy_network, scoring_fn, key = create_default_brax_task_components(\n", " env_name=env_name,\n", - " random_key=random_key,\n", + " key=key,\n", ")\n", "\n", "def observation_extractor_fn(\n", @@ -339,20 +339,20 @@ " \"\"\"Scan the udpate function.\"\"\"\n", " (\n", " repertoire,\n", - " random_key,\n", + " key,\n", " aurora_extra_info\n", " ) = carry\n", "\n", " # update\n", - " (repertoire, _, metrics, random_key,) = aurora.update(\n", + " (repertoire, _, metrics, key,) = aurora.update(\n", " repertoire,\n", " None,\n", - " random_key,\n", + " key,\n", " aurora_extra_info=aurora_extra_info,\n", " )\n", "\n", " return (\n", - " (repertoire, random_key, aurora_extra_info),\n", + " (repertoire, key, aurora_extra_info),\n", " metrics,\n", " )\n", "\n", @@ -380,7 +380,7 @@ ")\n", "\n", "# Init the model params\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "model_params = train_seq2seq.get_initial_params(\n", " model, subkey, (1, *observations_dims)\n", ")\n", @@ -423,18 +423,18 @@ ")\n", "\n", "# init step of the aurora algorithm\n", - "repertoire, emitter_state, aurora_extra_info, random_key = aurora.init(\n", + "repertoire, emitter_state, aurora_extra_info, key = aurora.init(\n", " init_variables,\n", " aurora_extra_info,\n", " jnp.asarray(l_value_init),\n", " max_observation_size,\n", - " random_key,\n", + " key,\n", ")\n", "\n", "# initializing means and stds and AURORA\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "repertoire, aurora_extra_info = aurora.train(\n", - " repertoire, model_params, iteration=0, random_key=subkey\n", + " repertoire, model_params, iteration=0, key=subkey\n", ")\n", "\n", "# design aurora's schedule\n", @@ -468,11 +468,11 @@ "while iteration < max_iterations:\n", "\n", " (\n", - " (repertoire, random_key, aurora_extra_info),\n", + " (repertoire, key, aurora_extra_info),\n", " metrics,\n", " ) = jax.lax.scan(\n", " update_scan_fn,\n", - " (repertoire, random_key, aurora_extra_info),\n", + " 
(repertoire, key, aurora_extra_info),\n", " (),\n", " length=log_freq,\n", " )\n", @@ -485,7 +485,7 @@ " # autoencoder steps and CVC\n", " if (iteration + 1) in schedules:\n", " # train the autoencoder\n", - " random_key, subkey = jax.random.split(random_key)\n", + " key, subkey = jax.random.split(key)\n", " repertoire, aurora_extra_info = aurora.train(\n", " repertoire, model_params, iteration, subkey\n", " )\n", diff --git a/examples/cmaes.ipynb b/examples/cmaes.ipynb index d7b30b1d..f2cab036 100644 --- a/examples/cmaes.ipynb +++ b/examples/cmaes.ipynb @@ -189,7 +189,7 @@ "outputs": [], "source": [ "state = cmaes.init()\n", - "random_key = jax.random.PRNGKey(0)" + "key = jax.random.key(0)" ] }, { @@ -225,7 +225,7 @@ " iteration_count += 1\n", "\n", " # sample\n", - " samples, random_key = cmaes.sample(state, random_key)\n", + " samples, key = cmaes.sample(state, key)\n", "\n", " # udpate\n", " state = cmaes.update(state, samples)\n", @@ -296,7 +296,7 @@ "fig, ax = plt.subplots(figsize=(12, 6))\n", "\n", "# sample points to show fitness landscape\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "x = jax.random.uniform(subkey, minval=-4, maxval=8, shape=(100000, 2))\n", "f_x = fitness_fn(x)\n", "\n", diff --git a/examples/cmame.ipynb b/examples/cmame.ipynb index 1da3420d..c56202f5 100644 --- a/examples/cmame.ipynb +++ b/examples/cmame.ipynb @@ -166,9 +166,9 @@ " scores, descriptors = fitness_scoring(x), _descriptors(x)\n", " return scores, descriptors, {}\n", "\n", - "def scoring_fn(x, random_key):\n", + "def scoring_fn(x, key):\n", " fitnesses, descriptors, extra_scores = jax.vmap(scoring_function)(x)\n", - " return fitnesses, descriptors, extra_scores, random_key" + " return fitnesses, descriptors, extra_scores, key" ] }, { @@ -217,10 +217,10 @@ "metadata": {}, "outputs": [], "source": [ - "random_key = jax.random.PRNGKey(0)\n", + "key = jax.random.key(0)\n", "# in CMA-ME settings (from the paper), there is no init population\n", "# we multipy by zero to reproduce this setting\n", - "initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n", + "initial_population = jax.random.uniform(key, shape=(batch_size, num_dimensions)) * 0.\n", "\n", "centroids = compute_euclidean_centroids(\n", " grid_shape=grid_shape,\n", @@ -271,7 +271,7 @@ "metadata": {}, "outputs": [], "source": [ - "repertoire, emitter_state, random_key = map_elites.init(initial_population, centroids, random_key)" + "repertoire, emitter_state, key = map_elites.init(initial_population, centroids, key)" ] }, { @@ -289,9 +289,9 @@ "source": [ "%%time\n", "\n", - "(repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + "(repertoire, emitter_state, key,), metrics = jax.lax.scan(\n", " map_elites.scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=num_iterations,\n", ")" diff --git a/examples/cmamega.ipynb b/examples/cmamega.ipynb index cc55d3b1..2a6e56b5 100644 --- a/examples/cmamega.ipynb +++ b/examples/cmamega.ipynb @@ -160,9 +160,9 @@ "\n", " return scores, descriptors, extra_scores\n", "\n", - "def scoring_fn(x, random_key):\n", + "def scoring_fn(x, key):\n", " fitnesses, descriptors, extra_scores = jax.vmap(scoring_function)(x)\n", - " return fitnesses, descriptors, extra_scores, random_key" + " return fitnesses, descriptors, extra_scores, key" ] }, { @@ -210,17 +210,17 @@ "metadata": {}, "outputs": [], "source": [ - "random_key = 
jax.random.PRNGKey(0)\n", + "key = jax.random.key(0)\n", "# no initial population - give all the same value as emitter init value\n", - "initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n", + "initial_population = jax.random.uniform(key, shape=(batch_size, num_dimensions)) * 0.\n", "\n", - "centroids, random_key = compute_cvt_centroids(\n", + "centroids, key = compute_cvt_centroids(\n", " num_descriptors=2,\n", " num_init_cvt_samples=10000,\n", " num_centroids=num_centroids,\n", " minval=minval,\n", " maxval=maxval,\n", - " random_key=random_key,\n", + " key=key,\n", ")\n", "\n", "emitter = CMAMEGAEmitter(\n", @@ -245,7 +245,7 @@ "metadata": {}, "outputs": [], "source": [ - "repertoire, emitter_state, random_key = map_elites.init(initial_population, centroids, random_key)" + "repertoire, emitter_state, key = map_elites.init(initial_population, centroids, key)" ] }, { @@ -256,9 +256,9 @@ "source": [ "%%time\n", "\n", - "(repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + "(repertoire, emitter_state, key,), metrics = jax.lax.scan(\n", " map_elites.scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=num_iterations,\n", ")" diff --git a/examples/dads.ipynb b/examples/dads.ipynb index ad932719..ce4cb7c7 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -176,7 +176,7 @@ " eval_metrics=True,\n", ")\n", "\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "env_state = jax.jit(env.reset)(rng=key)\n", "eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n", "\n", @@ -492,10 +492,10 @@ "jit_env_step = jax.jit(visual_env.step)\n", "\n", "@jax.jit\n", - "def jit_inference_fn(params, observation, random_key):\n", + "def jit_inference_fn(params, observation, key):\n", " obs = jnp.concatenate([observation, skill], axis=0)\n", - " action, random_key = dads.select_action(obs, params, random_key, deterministic=True)\n", - " return action, random_key" + " action, key = dads.select_action(obs, params, key, deterministic=True)\n", + " return action, key" ] }, { @@ -512,11 +512,11 @@ "outputs": [], "source": [ "rollout = []\n", - "random_key = jax.random.PRNGKey(seed=1)\n", - "state = jit_env_reset(rng=random_key)\n", + "key = jax.random.key(seed=1)\n", + "state = jit_env_reset(rng=key)\n", "while not state.done:\n", " rollout.append(state)\n", - " action, random_key = jit_inference_fn(my_params, state.obs, random_key)\n", + " action, key = jit_inference_fn(my_params, state.obs, key)\n", " state = jit_env_step(state, action)\n", "\n", "print(f\"The trajectory of this individual contains {len(rollout)} transitions.\")" diff --git a/examples/diayn.ipynb b/examples/diayn.ipynb index e10824c5..f86809c9 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -174,7 +174,7 @@ " eval_metrics=True,\n", ")\n", "\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "env_state = jax.jit(env.reset)(rng=key)\n", "eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n", "\n", @@ -482,10 +482,10 @@ "jit_env_step = jax.jit(visual_env.step)\n", "\n", "@jax.jit\n", - "def jit_inference_fn(params, observation, random_key):\n", + "def jit_inference_fn(params, observation, key):\n", " obs = jnp.concatenate([observation, skill], axis=0)\n", - " action, random_key = diayn.select_action(obs, params, random_key, deterministic=True)\n", - " return action, random_key" + " action, key = diayn.select_action(obs, params, key, 
deterministic=True)\n", + " return action, key" ] }, { @@ -502,11 +502,11 @@ "outputs": [], "source": [ "rollout = []\n", - "random_key = jax.random.PRNGKey(seed=1)\n", - "state = jit_env_reset(rng=random_key)\n", + "key = jax.random.key(seed=1)\n", + "state = jit_env_reset(rng=key)\n", "while not state.done:\n", " rollout.append(state)\n", - " action, random_key = jit_inference_fn(my_params, state.obs, random_key)\n", + " action, key = jit_inference_fn(my_params, state.obs, key)\n", " state = jit_env_step(state, action)\n", "\n", "print(f\"The trajectory of this individual contains {len(rollout)} transitions.\")" diff --git a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb index 8d65bbf2..e8dde684 100644 --- a/examples/distributed_mapelites.ipynb +++ b/examples/distributed_mapelites.ipynb @@ -189,7 +189,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -200,14 +200,14 @@ ")\n", "\n", "# Init population of controllers (batch size controllers)\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jax.random.split(subkey, num=batch_size)\n", "fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))\n", "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", "\n", "\n", "# Create the initial environment states (batch_size_per_device environments)\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size_per_device, axis=0)\n", "reset_fn = jax.jit(jax.vmap(env.reset))\n", "init_states = reset_fn(keys)" @@ -232,7 +232,7 @@ "def play_step_fn(\n", " env_state,\n", " policy_params,\n", - " random_key,\n", + " key,\n", "):\n", " \"\"\"\n", " Play an environment step and return the updated state and the transition.\n", @@ -254,7 +254,7 @@ " next_state_desc=next_state.info[\"state_descriptor\"],\n", " )\n", "\n", - " return next_state, policy_params, random_key, transition" + " return next_state, policy_params, key, transition" ] }, { @@ -341,18 +341,18 @@ ")\n", "\n", "# Compute the centroids\n", - "centroids, random_key = compute_cvt_centroids(\n", + "centroids, key = compute_cvt_centroids(\n", " num_descriptors=env.descriptor_length,\n", " num_init_cvt_samples=num_init_cvt_samples,\n", " num_centroids=num_centroids,\n", " minval=min_descriptor,\n", " maxval=max_descriptor,\n", - " random_key=random_key,\n", + " key=key,\n", ")\n", "\n", "# Compute initial repertoire and emitter state\n", - "random_key = jax.random.split(random_key, num=num_devices)\n", - "random_key = jnp.stack(random_key)\n", + "key = jax.random.split(key, num=num_devices)\n", + "key = jnp.stack(key)\n", "\n", "# add a dimension for devices\n", "init_variables = jax.tree_util.tree_map(\n", @@ -361,10 +361,10 @@ ")\n", "\n", "# get initial elements\n", - "repertoire, emitter_state, random_key = map_elites.get_distributed_init_fn(\n", + "repertoire, emitter_state, key = map_elites.get_distributed_init_fn(\n", " centroids=centroids,\n", " devices=devices,\n", - ")(genotypes=init_variables, random_key=random_key)" + ")(genotypes=init_variables, key=key)" ] }, { @@ -398,7 +398,7 @@ " start_time = time.time()\n", "\n", " # main iterations\n", - " repertoire, emitter_state, 
random_key, metrics = update_fn(repertoire, emitter_state, random_key)\n", + " repertoire, emitter_state, key, metrics = update_fn(repertoire, emitter_state, key)\n", "\n", " # get metrics\n", " metrics = jax.tree_util.tree_map(lambda x: x[0], metrics)\n", diff --git a/examples/jumanji_snake.ipynb b/examples/jumanji_snake.ipynb index 8631492e..e26bdc1b 100644 --- a/examples/jumanji_snake.ipynb +++ b/examples/jumanji_snake.ipynb @@ -122,7 +122,7 @@ "env = jumanji.make('Snake-v1')\n", "\n", "# Reset your (jit-able) environment\n", - "key = jax.random.PRNGKey(0)\n", + "key = jax.random.key(0)\n", "state, timestep = jax.jit(env.reset)(key)\n", "\n", "# Interact with the (jit-able) environment\n", @@ -146,7 +146,7 @@ "outputs": [], "source": [ "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "\n", "# get number of actions\n", "num_actions = env.action_spec().maximum + 1\n", @@ -185,7 +185,7 @@ " env_state,\n", " timestep,\n", " policy_params,\n", - " random_key,\n", + " key,\n", "):\n", " \"\"\"Play an environment step and return the updated state and the transition.\n", " Everything is deterministic in this simple example.\n", @@ -214,7 +214,7 @@ " next_state_desc=next_state_desc,\n", " )\n", "\n", - " return next_state, next_timestep, policy_params, random_key, transition" + " return next_state, next_timestep, policy_params, key, transition" ] }, { @@ -235,7 +235,7 @@ "outputs": [], "source": [ "# Init population of controllers\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jax.random.split(subkey, num=batch_size)\n", "\n", "# compute observation size from observation spec\n", @@ -246,7 +246,7 @@ "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", "\n", "# Create the initial environment states\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)\n", "reset_fn = jax.jit(jax.vmap(env.reset))\n", "\n", @@ -288,7 +288,7 @@ " return descriptors\n", "\n", "# create a random projection to a two dim space\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "linear_projection = jax.random.uniform(\n", " subkey, (2, observation_size), minval=-1, maxval=1, dtype=jnp.float32\n", ")\n", @@ -325,10 +325,10 @@ "outputs": [], "source": [ "def scoring_function(\n", - " genotypes: jnp.ndarray, random_key: RNGKey\n", + " genotypes: jnp.ndarray, key: RNGKey\n", ") -> Tuple[Fitness, ExtraScores, RNGKey]:\n", - " fitnesses, _, extra_scores, random_key = scoring_fn(genotypes, random_key)\n", - " return fitnesses.reshape(-1, 1), extra_scores, random_key" + " fitnesses, _, extra_scores, key = scoring_fn(genotypes, key)\n", + " return fitnesses.reshape(-1, 1), extra_scores, key" ] }, { @@ -384,8 +384,8 @@ " metrics_function=default_ga_metrics,\n", " )\n", "\n", - " repertoire, emitter_state, random_key = algo_instance.init(\n", - " init_variables, population_size, random_key\n", + " repertoire, emitter_state, key = algo_instance.init(\n", + " init_variables, population_size, key\n", " )\n", "\n", "else:\n", @@ -410,7 +410,7 @@ " )\n", "\n", " # Compute initial repertoire and emitter state\n", - " repertoire, emitter_state, random_key = algo_instance.init(init_variables, centroids, random_key)" + " repertoire, emitter_state, key = algo_instance.init(init_variables, centroids, key)" ] }, { @@ -431,9 
+431,9 @@ "%%time\n", "\n", "# Run the algorithm\n", - "(repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + "(repertoire, emitter_state, key,), metrics = jax.lax.scan(\n", " algo_instance.scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=num_iterations,\n", ")" diff --git a/examples/mapelites.ipynb b/examples/mapelites.ipynb index 3f4f6bd8..cd9167ec 100644 --- a/examples/mapelites.ipynb +++ b/examples/mapelites.ipynb @@ -146,7 +146,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -157,14 +157,14 @@ ")\n", "\n", "# Init population of controllers\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jax.random.split(subkey, num=batch_size)\n", "fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))\n", "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", "\n", "\n", "# Create the initial environment states\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)\n", "reset_fn = jax.jit(jax.vmap(env.reset))\n", "init_states = reset_fn(keys)" @@ -189,7 +189,7 @@ "def play_step_fn(\n", " env_state,\n", " policy_params,\n", - " random_key,\n", + " key,\n", "):\n", " \"\"\"\n", " Play an environment step and return the updated state and the transition.\n", @@ -211,7 +211,7 @@ " next_state_desc=next_state.info[\"state_descriptor\"],\n", " )\n", "\n", - " return next_state, policy_params, random_key, transition" + " return next_state, policy_params, key, transition" ] }, { @@ -297,17 +297,17 @@ ")\n", "\n", "# Compute the centroids\n", - "centroids, random_key = compute_cvt_centroids(\n", + "centroids, key = compute_cvt_centroids(\n", " num_descriptors=env.descriptor_length,\n", " num_init_cvt_samples=num_init_cvt_samples,\n", " num_centroids=num_centroids,\n", " minval=min_descriptor,\n", " maxval=max_descriptor,\n", - " random_key=random_key,\n", + " key=key,\n", ")\n", "\n", "# Compute initial repertoire and emitter state\n", - "repertoire, emitter_state, random_key = map_elites.init(init_variables, centroids, random_key)" + "repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key)" ] }, { @@ -337,9 +337,9 @@ "for i in range(num_loops):\n", " start_time = time.time()\n", " # main iterations\n", - " (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + " (repertoire, emitter_state, key,), metrics = jax.lax.scan(\n", " map_elites_scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=log_period,\n", " )\n", @@ -416,7 +416,7 @@ "outputs": [], "source": [ "# Init population of policies\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "fake_batch = jnp.zeros(shape=(env.observation_size,))\n", "fake_params = policy_network.init(subkey, fake_batch)\n", "\n", @@ -507,7 +507,7 @@ "outputs": [], "source": [ "rollout = []\n", - "rng = jax.random.PRNGKey(seed=1)\n", + "rng = jax.random.key(seed=1)\n", "state = jit_env_reset(rng=rng)\n", "while not state.done:\n", " rollout.append(state)\n", @@ 
-546,7 +546,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index 0de5f29d..c667f5ae 100644 --- a/examples/me_sac_pbt.ipynb +++ b/examples/me_sac_pbt.ipynb @@ -162,7 +162,7 @@ "outputs": [], "source": [ "%%time\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "key, subkey = jax.random.split(key)\n", "env_states = jax.jit(env.reset)(rng=subkey)\n", "eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey)" @@ -247,7 +247,7 @@ "eval_policy = agent.get_eval_qd_fn(eval_env, descriptor_extraction_fn=descriptor_extraction_fn)\n", "\n", "\n", - "def scoring_function(genotypes, random_key):\n", + "def scoring_function(genotypes, key):\n", " population_size = jax.tree_util.tree_leaves(genotypes)[0].shape[0]\n", " first_states = jax.tree_map(\n", " lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states\n", @@ -256,7 +256,7 @@ " lambda x: jnp.repeat(x, population_size, axis=0), first_states\n", " )\n", " population_returns, population_descriptors, _, _ = eval_policy(genotypes, first_states)\n", - " return population_returns, population_descriptors, {}, random_key" + " return population_returns, population_descriptors, {}, key" ] }, { @@ -295,7 +295,7 @@ " num_centroids=num_centroids,\n", " minval=min_descriptor,\n", " maxval=max_descriptor,\n", - " random_key=key,\n", + " key=key,\n", ")" ] }, @@ -343,7 +343,7 @@ "# initialize map-elites\n", "repertoire, emitter_state, keys = map_elites.get_distributed_init_fn(\n", " devices=devices, centroids=centroids\n", - ")(genotypes=training_states, random_key=keys)" + ")(genotypes=training_states, key=keys)" ] }, { @@ -516,7 +516,7 @@ "%%time\n", "rollout = []\n", "\n", - "rng = jax.random.PRNGKey(seed=1)\n", + "rng = jax.random.key(seed=1)\n", "env_state = jax.jit(env.reset)(rng=rng)\n", "\n", "training_state, env_state = jax.tree_map(\n", diff --git a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index 9b06b402..80593a97 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -167,7 +167,7 @@ "outputs": [], "source": [ "%%time\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "key, subkey = jax.random.split(key)\n", "env_states = jax.jit(env.reset)(rng=subkey)\n", "eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey)" @@ -250,7 +250,7 @@ "eval_policy = agent.get_eval_qd_fn(eval_env, descriptor_extraction_fn=descriptor_extraction_fn)\n", "\n", "\n", - "def scoring_function(genotypes, random_key):\n", + "def scoring_function(genotypes, key):\n", " population_size = jax.tree_leaves(genotypes)[0].shape[0]\n", " first_states = jax.tree_map(\n", " lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states\n", @@ -259,7 +259,7 @@ " lambda x: jnp.repeat(x, population_size, axis=0), first_states\n", " )\n", " population_returns, population_descriptors, _, _ = eval_policy(genotypes, first_states)\n", - " return population_returns, population_descriptors, {}, random_key" + " return population_returns, population_descriptors, {}, key" ] }, { @@ -298,7 +298,7 @@ " num_centroids=num_centroids,\n", " minval=min_descriptor,\n", " maxval=max_descriptor,\n", - " random_key=key,\n", + " key=key,\n", ")" ] }, @@ -346,7 +346,7 @@ "# initialize map-elites\n", "repertoire, emitter_state, keys = map_elites.get_distributed_init_fn(\n", " devices=devices, centroids=centroids\n", - ")(genotypes=training_states, 
random_key=keys)" + ")(genotypes=training_states, key=keys)" ] }, { diff --git a/examples/mees.ipynb b/examples/mees.ipynb index b2677c0d..1cfeff32 100644 --- a/examples/mees.ipynb +++ b/examples/mees.ipynb @@ -163,7 +163,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -174,13 +174,13 @@ ")\n", "\n", "# Init population of controllers\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "fake_batch = jnp.zeros(shape=(1, env.observation_size))\n", "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=1, axis=0)\n", "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", "\n", "# Create the initial environment state\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "init_state = env.reset(subkey)" ] }, @@ -205,7 +205,7 @@ "def play_step_fn(\n", " env_state,\n", " policy_params,\n", - " random_key,\n", + " key,\n", "):\n", " \"\"\"\n", " Play an environment step and return the updated state and the transition.\n", @@ -227,7 +227,7 @@ " next_state_desc=next_state.info[\"state_descriptor\"],\n", " )\n", "\n", - " return next_state, policy_params, random_key, transition" + " return next_state, policy_params, key, transition" ] }, { @@ -254,7 +254,7 @@ "scoring_fn = functools.partial(\n", " reset_based_scoring_function_brax_envs,\n", " episode_length=episode_length,\n", - " play_reset_fn=lambda random_key: init_state,\n", + " play_reset_fn=lambda key: init_state,\n", " play_step_fn=play_step_fn,\n", " descriptor_extractor=descriptor_extraction_fn,\n", ")\n", @@ -361,18 +361,18 @@ ")\n", "\n", "# Compute the centroids\n", - "centroids, random_key = compute_cvt_centroids(\n", + "centroids, key = compute_cvt_centroids(\n", " num_descriptors=env.descriptor_length,\n", " num_init_cvt_samples=num_init_cvt_samples,\n", " num_centroids=num_centroids,\n", " minval=min_descriptor,\n", " maxval=max_descriptor,\n", - " random_key=random_key,\n", + " key=key,\n", ")\n", "\n", "# compute initial repertoire\n", - "repertoire, emitter_state, random_key = map_elites.init(\n", - " init_variables, centroids, random_key\n", + "repertoire, emitter_state, key = map_elites.init(\n", + " init_variables, centroids, key\n", ")" ] }, @@ -402,9 +402,9 @@ "for i in range(num_loops):\n", " start_time = time.time()\n", " # main iterations\n", - " (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + " (repertoire, emitter_state, key,), metrics = jax.lax.scan(\n", " map_elites.scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=log_period,\n", " )\n", diff --git a/examples/mels.ipynb b/examples/mels.ipynb index a02c9444..56a2c0ed 100644 --- a/examples/mels.ipynb +++ b/examples/mels.ipynb @@ -153,7 +153,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -165,7 +165,7 @@ "\n", "# Init population of controllers. 
There are batch_size controllers, and each\n", "# controller will be evaluated num_samples times.\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jax.random.split(subkey, num=batch_size)\n", "fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))\n", "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)" @@ -192,7 +192,7 @@ "def play_step_fn(\n", " env_state,\n", " policy_params,\n", - " random_key,\n", + " key,\n", "):\n", " \"\"\"Play an environment step and return the updated state and the\n", " transition.\"\"\"\n", @@ -213,7 +213,7 @@ " next_state_desc=next_state.info[\"state_descriptor\"],\n", " )\n", "\n", - " return next_state, policy_params, random_key, transition" + " return next_state, policy_params, key, transition" ] }, { @@ -306,17 +306,17 @@ ")\n", "\n", "# Compute the centroids\n", - "centroids, random_key = compute_cvt_centroids(\n", + "centroids, key = compute_cvt_centroids(\n", " num_descriptors=env.descriptor_length,\n", " num_init_cvt_samples=num_init_cvt_samples,\n", " num_centroids=num_centroids,\n", " minval=min_descriptor,\n", " maxval=max_descriptor,\n", - " random_key=random_key,\n", + " key=key,\n", ")\n", "\n", "# Compute initial repertoire and emitter state\n", - "repertoire, emitter_state, random_key = mels.init(init_variables, centroids, random_key)" + "repertoire, emitter_state, key = mels.init(init_variables, centroids, key)" ] }, { @@ -348,9 +348,9 @@ "for i in range(num_loops):\n", " start_time = time.time()\n", " # main iterations\n", - " (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + " (repertoire, emitter_state, key,), metrics = jax.lax.scan(\n", " mels_scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=log_period,\n", " )\n", @@ -427,7 +427,7 @@ "outputs": [], "source": [ "# Init population of policies\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "fake_batch = jnp.zeros(shape=(env.observation_size,))\n", "fake_params = policy_network.init(subkey, fake_batch)\n", "\n", @@ -522,7 +522,7 @@ "outputs": [], "source": [ "rollout = []\n", - "rng = jax.random.PRNGKey(seed=1)\n", + "rng = jax.random.key(seed=1)\n", "state = jit_env_reset(rng=rng)\n", "while not state.done:\n", " rollout.append(state)\n", diff --git a/examples/mome.ipynb b/examples/mome.ipynb index 7e28b608..d282c097 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -171,9 +171,9 @@ "source": [ "scoring_function = partial(rastrigin_scorer, base_lag=base_lag, lag=lag)\n", "\n", - "def scoring_fn(genotypes: jnp.ndarray, random_key: RNGKey) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]:\n", + "def scoring_fn(genotypes: jnp.ndarray, key: RNGKey) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]:\n", " fitnesses, descriptors = scoring_function(genotypes)\n", - " return fitnesses, descriptors, {}, random_key" + " return fitnesses, descriptors, {}, key" ] }, { @@ -216,10 +216,10 @@ "outputs": [], "source": [ "# initial population\n", - "random_key = jax.random.PRNGKey(42)\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key = jax.random.key(42)\n", + "key, subkey = jax.random.split(key)\n", "genotypes = jax.random.uniform(\n", - " random_key, (batch_size, num_variables), minval=minval, maxval=maxval, dtype=jnp.float32\n", + " key, (batch_size, num_variables), minval=minval, maxval=maxval, dtype=jnp.float32\n", ")\n", "\n", "# 
crossover function\n", @@ -261,13 +261,13 @@ "metadata": {}, "outputs": [], "source": [ - "centroids, random_key = compute_cvt_centroids(\n", + "centroids, key = compute_cvt_centroids(\n", " num_descriptors=2,\n", " num_init_cvt_samples=20000,\n", " num_centroids=num_centroids,\n", " minval=minval,\n", " maxval=maxval,\n", - " random_key=random_key,\n", + " key=key,\n", ")" ] }, @@ -308,11 +308,11 @@ "metadata": {}, "outputs": [], "source": [ - "repertoire, emitter_state, random_key = mome.init(\n", + "repertoire, emitter_state, key = mome.init(\n", " genotypes,\n", " centroids,\n", " pareto_front_max_length,\n", - " random_key\n", + " key\n", ")" ] }, @@ -334,9 +334,9 @@ "%%time\n", "\n", "# Run the algorithm\n", - "(repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + "(repertoire, emitter_state, key,), metrics = jax.lax.scan(\n", " mome.scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=num_iterations,\n", ")" diff --git a/examples/nsga2_spea2.ipynb b/examples/nsga2_spea2.ipynb index 4e9ab3b0..6b13e4f6 100644 --- a/examples/nsga2_spea2.ipynb +++ b/examples/nsga2_spea2.ipynb @@ -175,8 +175,8 @@ " base_lag=base_lag\n", ")\n", "\n", - "def scoring_fn(x, random_key):\n", - " return scoring_function(x)[0], {}, random_key" + "def scoring_fn(x, key):\n", + " return scoring_function(x)[0], {}, key" ] }, { @@ -193,8 +193,8 @@ "outputs": [], "source": [ "# Initial population\n", - "random_key = jax.random.PRNGKey(0)\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key = jax.random.key(0)\n", + "key, subkey = jax.random.split(key)\n", "genotypes = jax.random.uniform(\n", " subkey, (batch_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32\n", ")\n", @@ -243,10 +243,10 @@ ")\n", "\n", "# init nsga2\n", - "repertoire, emitter_state, random_key = nsga2.init(\n", + "repertoire, emitter_state, key = nsga2.init(\n", " genotypes,\n", " population_size,\n", - " random_key\n", + " key\n", ")" ] }, @@ -266,8 +266,8 @@ "%%time\n", "\n", "# run optimization loop\n", - "(repertoire, emitter_state, random_key), _ = jax.lax.scan(\n", - " nsga2.scan_update, (repertoire, emitter_state, random_key), (), length=num_iterations\n", + "(repertoire, emitter_state, key), _ = jax.lax.scan(\n", + " nsga2.scan_update, (repertoire, emitter_state, key), (), length=num_iterations\n", ")" ] }, @@ -308,11 +308,11 @@ ")\n", "\n", "# init spea2\n", - "repertoire, emitter_state, random_key = spea2.init(\n", + "repertoire, emitter_state, key = spea2.init(\n", " genotypes,\n", " population_size,\n", " num_neighbours,\n", - " random_key\n", + " key\n", ")" ] }, @@ -325,8 +325,8 @@ "%%time\n", "\n", "# run optimization loop\n", - "(repertoire, emitter_state, random_key), _ = jax.lax.scan(\n", - " spea2.scan_update, (repertoire, emitter_state, random_key), (), length=num_iterations\n", + "(repertoire, emitter_state, key), _ = jax.lax.scan(\n", + " spea2.scan_update, (repertoire, emitter_state, key), (), length=num_iterations\n", ")" ] }, diff --git a/examples/omgmega.ipynb b/examples/omgmega.ipynb index 1a294bf7..dc6a1414 100644 --- a/examples/omgmega.ipynb +++ b/examples/omgmega.ipynb @@ -146,9 +146,9 @@ " gradients = jnp.nan_to_num(gradients)\n", " return fitnesses, descriptors, {\"gradients\": gradients}\n", "\n", - "def scoring_fn(x, random_key):\n", + "def scoring_fn(x, key):\n", " fitnesses, descriptors, extra_scores = jax.vmap(scoring_function)(x)\n", - " return fitnesses, descriptors, extra_scores, 
random_key" + " return fitnesses, descriptors, extra_scores, key" ] }, { @@ -196,10 +196,10 @@ "metadata": {}, "outputs": [], "source": [ - "random_key = jax.random.PRNGKey(0)\n", + "key = jax.random.key(0)\n", "\n", "# defines the population\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "initial_population = jax.random.normal(subkey, shape=(init_population_size, num_dimensions))\n", "\n", "sqrt_centroids = int(math.sqrt(num_centroids)) # 2-D grid\n", @@ -239,7 +239,7 @@ "metadata": {}, "outputs": [], "source": [ - "repertoire, emitter_state, random_key = map_elites.init(initial_population, centroids, random_key)" + "repertoire, emitter_state, key = map_elites.init(initial_population, centroids, key)" ] }, { @@ -257,9 +257,9 @@ "source": [ "%%time\n", "\n", - "(repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + "(repertoire, emitter_state, key,), metrics = jax.lax.scan(\n", " map_elites.scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=num_iterations,\n", ")" diff --git a/examples/pga_aurora.ipynb b/examples/pga_aurora.ipynb index cc6fb12f..9d2861dd 100644 --- a/examples/pga_aurora.ipynb +++ b/examples/pga_aurora.ipynb @@ -177,7 +177,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -188,14 +188,14 @@ ")\n", "\n", "# Init population of controllers\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jax.random.split(subkey, num=env_batch_size)\n", "fake_batch = jnp.zeros(shape=(env_batch_size, env.observation_size))\n", "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", "\n", "\n", "# Create the initial environment states\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=env_batch_size, axis=0)\n", "reset_fn = jax.jit(jax.vmap(env.reset))\n", "init_states = reset_fn(keys)" @@ -220,7 +220,7 @@ "def play_step_fn(\n", " env_state,\n", " policy_params,\n", - " random_key,\n", + " key,\n", "):\n", " \"\"\"\n", " Play an environment step and return the updated state and the transition.\n", @@ -242,7 +242,7 @@ " next_state_desc=next_state.info[\"state_descriptor\"],\n", " )\n", "\n", - " return next_state, policy_params, random_key, transition" + " return next_state, policy_params, key, transition" ] }, { @@ -261,9 +261,9 @@ "outputs": [], "source": [ "# Prepare the scoring function\n", - "env, policy_network, scoring_fn, random_key = create_default_brax_task_components(\n", + "env, policy_network, scoring_fn, key = create_default_brax_task_components(\n", " env_name=env_name,\n", - " random_key=random_key,\n", + " key=key,\n", ")\n", "\n", "def observation_extractor_fn(\n", @@ -386,20 +386,20 @@ " (\n", " repertoire,\n", " emitter_state,\n", - " random_key,\n", + " key,\n", " aurora_extra_info\n", " ) = carry\n", "\n", " # update\n", - " (repertoire, emitter_state, metrics, random_key,) = aurora.update(\n", + " (repertoire, emitter_state, metrics, key,) = aurora.update(\n", " repertoire,\n", " emitter_state,\n", - " random_key,\n", + " key,\n", " aurora_extra_info=aurora_extra_info,\n", " )\n", "\n", " return 
(\n", - " (repertoire, emitter_state, random_key, aurora_extra_info),\n", + " (repertoire, emitter_state, key, aurora_extra_info),\n", " metrics,\n", " )\n", "\n", @@ -427,7 +427,7 @@ ")\n", "\n", "# Init the model params\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "model_params = train_seq2seq.get_initial_params(\n", " model, subkey, (1, *observations_dims)\n", ")\n", @@ -459,7 +459,7 @@ ")\n", "\n", "# init the model params\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "model_params = train_seq2seq.get_initial_params(\n", " model, subkey, (1, *observations_dims)\n", ")\n", @@ -478,18 +478,18 @@ ")\n", "\n", "# init step of the aurora algorithm\n", - "repertoire, emitter_state, aurora_extra_info, random_key = aurora.init(\n", + "repertoire, emitter_state, aurora_extra_info, key = aurora.init(\n", " init_variables,\n", " aurora_extra_info,\n", " jnp.asarray(l_value_init),\n", " max_observation_size,\n", - " random_key,\n", + " key,\n", ")\n", "\n", "# initializing means and stds and AURORA\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "repertoire, aurora_extra_info = aurora.train(\n", - " repertoire, model_params, iteration=0, random_key=subkey\n", + " repertoire, model_params, iteration=0, key=subkey\n", ")\n", "\n", "# design aurora's schedule\n", @@ -523,11 +523,11 @@ "while iteration < max_iterations:\n", "\n", " (\n", - " (repertoire, emitter_state, random_key, aurora_extra_info),\n", + " (repertoire, emitter_state, key, aurora_extra_info),\n", " metrics,\n", " ) = jax.lax.scan(\n", " update_scan_fn,\n", - " (repertoire, emitter_state, random_key, aurora_extra_info),\n", + " (repertoire, emitter_state, key, aurora_extra_info),\n", " (),\n", " length=log_freq,\n", " )\n", @@ -540,7 +540,7 @@ " # autoencoder steps and CVC\n", " if (iteration + 1) in schedules:\n", " # train the autoencoder\n", - " random_key, subkey = jax.random.split(random_key)\n", + " key, subkey = jax.random.split(key)\n", " repertoire, aurora_extra_info = aurora.train(\n", " repertoire, model_params, iteration, subkey\n", " )\n", diff --git a/examples/pgame.ipynb b/examples/pgame.ipynb index 78d4ade7..c8d52a43 100644 --- a/examples/pgame.ipynb +++ b/examples/pgame.ipynb @@ -156,7 +156,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -167,13 +167,13 @@ ")\n", "\n", "# Init population of controllers\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jax.random.split(subkey, num=env_batch_size)\n", "fake_batch = jnp.zeros(shape=(env_batch_size, env.observation_size))\n", "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", "\n", "# Create the initial environment states\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=env_batch_size, axis=0)\n", "reset_fn = jax.jit(jax.vmap(env.reset))\n", "init_states = reset_fn(keys)" @@ -196,7 +196,7 @@ "def play_step_fn(\n", " env_state,\n", " policy_params,\n", - " random_key,\n", + " key,\n", "):\n", " \"\"\"\n", " Play an environment step and return the updated state and the 
transition.\n", @@ -218,7 +218,7 @@ " next_state_desc=next_state.info[\"state_descriptor\"],\n", " )\n", "\n", - " return next_state, policy_params, random_key, transition" + " return next_state, policy_params, key, transition" ] }, { @@ -332,18 +332,18 @@ ")\n", "\n", "# Compute the centroids\n", - "centroids, random_key = compute_cvt_centroids(\n", + "centroids, key = compute_cvt_centroids(\n", " num_descriptors=env.descriptor_length,\n", " num_init_cvt_samples=num_init_cvt_samples,\n", " num_centroids=num_centroids,\n", " minval=min_descriptor,\n", " maxval=max_descriptor,\n", - " random_key=random_key,\n", + " key=key,\n", ")\n", "\n", "# compute initial repertoire\n", - "repertoire, emitter_state, random_key = map_elites.init(\n", - " init_variables, centroids, random_key\n", + "repertoire, emitter_state, key = map_elites.init(\n", + " init_variables, centroids, key\n", ")" ] }, @@ -367,9 +367,9 @@ "for i in range(num_loops):\n", " start_time = time.time()\n", " # main iterations\n", - " (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + " (repertoire, emitter_state, key,), metrics = jax.lax.scan(\n", " map_elites_scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=log_period,\n", " )\n", diff --git a/examples/qdpg.ipynb b/examples/qdpg.ipynb index 9fccd4a9..6ec0be80 100644 --- a/examples/qdpg.ipynb +++ b/examples/qdpg.ipynb @@ -169,7 +169,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -180,13 +180,13 @@ ")\n", "\n", "# Init population of controllers\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jax.random.split(subkey, num=env_batch_size)\n", "fake_batch = jnp.zeros(shape=(env_batch_size, env.observation_size))\n", "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", "\n", "# Create the initial environment states\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=env_batch_size, axis=0)\n", "reset_fn = jax.jit(jax.vmap(env.reset))\n", "init_states = reset_fn(keys)" @@ -209,7 +209,7 @@ "def play_step_fn(\n", " env_state,\n", " policy_params,\n", - " random_key,\n", + " key,\n", "):\n", " \"\"\"\n", " Play an environment step and return the updated state and the transition.\n", @@ -231,7 +231,7 @@ " next_state_desc=next_state.info[\"state_descriptor\"],\n", " )\n", "\n", - " return next_state, policy_params, random_key, transition" + " return next_state, policy_params, key, transition" ] }, { @@ -379,18 +379,18 @@ ")\n", "\n", "# Compute the centroids\n", - "centroids, random_key = compute_cvt_centroids(\n", + "centroids, key = compute_cvt_centroids(\n", " num_descriptors=env.descriptor_length,\n", " num_init_cvt_samples=num_init_cvt_samples,\n", " num_centroids=num_centroids,\n", " minval=min_descriptor,\n", " maxval=max_descriptor,\n", - " random_key=random_key,\n", + " key=key,\n", ")\n", "\n", "# compute initial repertoire\n", - "repertoire, emitter_state, random_key = map_elites.init(\n", - " init_variables, centroids, random_key\n", + "repertoire, emitter_state, key = map_elites.init(\n", + " init_variables, centroids, key\n", ")" ] }, @@ -414,9 
+414,9 @@ "for i in range(num_loops):\n", " start_time = time.time()\n", " # main iterations\n", - " (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + " (repertoire, emitter_state, key,), metrics = jax.lax.scan(\n", " map_elites_scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=log_period,\n", " )\n", diff --git a/examples/sac_pbt.ipynb b/examples/sac_pbt.ipynb index a484b035..f16e0378 100644 --- a/examples/sac_pbt.ipynb +++ b/examples/sac_pbt.ipynb @@ -182,10 +182,10 @@ "outputs": [], "source": [ "@jax.jit\n", - "def init_environments(random_key):\n", + "def init_environments(key):\n", "\n", - " env_states = jax.jit(env.reset)(rng=random_key)\n", - " eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key)\n", + " env_states = jax.jit(env.reset)(rng=key)\n", + " eval_env_first_states = jax.jit(eval_env.reset)(rng=key)\n", "\n", " reshape_fn = jax.jit(\n", " lambda tree: jax.tree_util.tree_map(\n", @@ -221,7 +221,7 @@ "outputs": [], "source": [ "# %%time\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "key, *keys = jax.random.split(key, num=1 + num_devices)\n", "keys = jnp.stack(keys)\n", "env_states, eval_env_first_states = jax.pmap(\n", @@ -529,7 +529,7 @@ "%%time\n", "rollout = []\n", "\n", - "rng = jax.random.PRNGKey(seed=1)\n", + "rng = jax.random.key(seed=1)\n", "env_state = jax.jit(env.reset)(rng=rng)\n", "\n", "training_state, env_state = jax.tree_util.tree_map(\n", diff --git a/examples/scripts/me_example.py b/examples/scripts/me_example.py index dd972858..7867befa 100644 --- a/examples/scripts/me_example.py +++ b/examples/scripts/me_example.py @@ -26,10 +26,10 @@ def run_me() -> None: max_descriptor = 1.0 # Init a random key - random_key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) init_variables = jax.random.uniform( subkey, shape=(init_batch_size, num_param_dimensions), @@ -73,9 +73,7 @@ def run_me() -> None: ) # Initializes repertoire and emitter state - repertoire, emitter_state, random_key = map_elites.init( - init_variables, centroids, random_key - ) + repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key) # Run MAP-Elites loop for _ in range(num_iterations): @@ -83,11 +81,11 @@ def run_me() -> None: repertoire, emitter_state, metrics, - random_key, + key, ) = map_elites.update( repertoire, emitter_state, - random_key, + key, ) # plot archive diff --git a/examples/smerl.ipynb b/examples/smerl.ipynb index 72bfa3e2..f014e83c 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -181,7 +181,7 @@ " eval_metrics=True,\n", ")\n", "\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "env_state = jax.jit(env.reset)(rng=key)\n", "eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n", "\n", @@ -497,10 +497,10 @@ "jit_env_step = jax.jit(visual_env.step)\n", "\n", "@jax.jit\n", - "def jit_inference_fn(params, observation, random_key):\n", + "def jit_inference_fn(params, observation, key):\n", " obs = jnp.concatenate([observation, skill], axis=0)\n", - " action, random_key = diayn_smerl.select_action(obs, params, random_key, deterministic=True)\n", - " return action, random_key" + " action, key = diayn_smerl.select_action(obs, params, key, deterministic=True)\n", + " return action, key" ] }, { @@ -517,11 +517,11 @@ "outputs": [], "source": [ "rollout = []\n", 
- "random_key = jax.random.PRNGKey(seed=1)\n", - "state = jit_env_reset(rng=random_key)\n", + "key = jax.random.key(seed=1)\n", + "state = jit_env_reset(rng=key)\n", "while not state.done:\n", " rollout.append(state)\n", - " action, random_key = jit_inference_fn(my_params, state.obs, random_key)\n", + " action, key = jit_inference_fn(my_params, state.obs, key)\n", " state = jit_env_step(state, action)\n", "\n", "print(f\"The trajectory of this individual contains {len(rollout)} transitions.\")" diff --git a/examples/td3_pbt.ipynb b/examples/td3_pbt.ipynb index 484f6d12..d5042307 100644 --- a/examples/td3_pbt.ipynb +++ b/examples/td3_pbt.ipynb @@ -161,10 +161,10 @@ "outputs": [], "source": [ "@jax.jit\n", - "def init_environments(random_key):\n", + "def init_environments(key):\n", "\n", - " env_states = jax.jit(env.reset)(rng=random_key)\n", - " eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key)\n", + " env_states = jax.jit(env.reset)(rng=key)\n", + " eval_env_first_states = jax.jit(eval_env.reset)(rng=key)\n", "\n", " reshape_fn = jax.jit(\n", " lambda tree: jax.tree_util.tree_map(\n", @@ -192,7 +192,7 @@ "outputs": [], "source": [ "%%time\n", - "key = jax.random.PRNGKey(seed)\n", + "key = jax.random.key(seed)\n", "key, *keys = jax.random.split(key, num=1 + num_devices)\n", "keys = jnp.stack(keys)\n", "env_states, eval_env_first_states = jax.pmap(\n", diff --git a/qdax/baselines/dads.py b/qdax/baselines/dads.py index bd4f4534..cc309cf7 100644 --- a/qdax/baselines/dads.py +++ b/qdax/baselines/dads.py @@ -40,7 +40,7 @@ class DadsTrainingState(TrainingState): target_critic_params: Params dynamics_optimizer_state: optax.OptState dynamics_params: Params - random_key: RNGKey + key: RNGKey steps: jnp.ndarray normalization_running_stats: RunningMeanStdState @@ -120,7 +120,7 @@ def __init__(self, config: DadsConfig, action_size: int, descriptor_size: int): def init( # type: ignore self, - random_key: RNGKey, + key: RNGKey, action_size: int, observation_size: int, descriptor_size: int, @@ -128,7 +128,7 @@ def init( # type: ignore """Initialise the training state of the algorithm. Args: - random_key: a jax random key + key: a jax random key action_size: the size of the environment's action space observation_size: the size of the environment's observation space descriptor_size: the size of the environment's descriptor space (i.e. 
the @@ -143,17 +143,17 @@ def init( # type: ignore dummy_dyn_obs = jnp.zeros((1, descriptor_size)) dummy_skill = jnp.zeros((1, self._config.num_skills)) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) policy_params = self._policy.init(subkey, dummy_obs) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) critic_params = self._critic.init(subkey, dummy_obs, dummy_action) target_critic_params = jax.tree_util.tree_map( lambda x: jnp.asarray(x.copy()), critic_params ) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) dynamics_params = self._dynamics.init( subkey, obs=dummy_dyn_obs, @@ -178,7 +178,7 @@ def init( # type: ignore target_critic_params=target_critic_params, dynamics_optimizer_state=dynamics_optimizer_state, dynamics_params=dynamics_params, - random_key=random_key, + key=key, normalization_running_stats=RunningMeanStdState( mean=jnp.zeros( descriptor_size, @@ -274,7 +274,7 @@ def play_step_fn( the played transition """ - random_key = training_state.random_key + key = training_state.key policy_params = training_state.policy_params obs = jnp.concatenate([env_state.obs, skills], axis=1) @@ -284,10 +284,10 @@ def play_step_fn( else: state_desc = jnp.zeros((env_state.obs.shape[0], 2)) - actions, random_key = self.select_action( + actions, key = self.select_action( obs=obs, policy_params=policy_params, - random_key=random_key, + key=key, deterministic=deterministic, ) @@ -326,24 +326,17 @@ def play_step_fn( ) if not evaluation: training_state = training_state.replace( - random_key=random_key, + key=key, normalization_running_stats=normalization_running_stats, ) else: training_state = training_state.replace( - random_key=random_key, + key=key, ) return next_env_state, training_state, transition - @partial( - jax.jit, - static_argnames=( - "self", - "play_step_fn", - "env_batch_size", - ), - ) + @partial(jax.jit, static_argnames=("self", "play_step_fn", "env_batch_size")) def eval_policy_fn( self, training_state: DadsTrainingState, @@ -479,13 +472,13 @@ def _update_networks( Args: training_state: the current training state of the algorithm. transitions: transitions sampled from a replay buffer. - random_key: a random key to handle stochasticity. + key: a random key to handle stochasticity. Returns: The updated training state and training metrics. 
""" - random_key = training_state.random_key + key = training_state.key # Update skill-dynamics ( @@ -504,12 +497,12 @@ def _update_networks( alpha_params, alpha_optimizer_state, alpha_loss, - random_key, + key, ) = self._update_alpha( alpha_lr=self._config.learning_rate, training_state=training_state, transitions=transitions, - random_key=random_key, + key=key, ) # update critic @@ -518,14 +511,14 @@ def _update_networks( target_critic_params, critic_optimizer_state, critic_loss, - random_key, + key, ) = self._update_critic( critic_lr=self._config.learning_rate, reward_scaling=self._config.reward_scaling, discount=self._config.discount, training_state=training_state, transitions=transitions, - random_key=random_key, + key=key, ) # update actor @@ -533,12 +526,12 @@ def _update_networks( policy_params, policy_optimizer_state, policy_loss, - random_key, + key, ) = self._update_actor( policy_lr=self._config.learning_rate, training_state=training_state, transitions=transitions, - random_key=random_key, + key=key, ) # Create new training state @@ -552,7 +545,7 @@ def _update_networks( target_critic_params=target_critic_params, dynamics_optimizer_state=dynamics_optimizer_state, dynamics_params=dynamics_params, - random_key=random_key, + key=key, normalization_running_stats=training_state.normalization_running_stats, steps=training_state.steps + 1, ) @@ -589,9 +582,9 @@ def update( """ # Sample a batch of transitions in the buffer - random_key = training_state.random_key - transitions, random_key = replay_buffer.sample( - random_key, + key = training_state.key + transitions, key = replay_buffer.sample( + key, sample_size=self._config.batch_size, ) diff --git a/qdax/baselines/dads_smerl.py b/qdax/baselines/dads_smerl.py index 5bd8274d..7bca44e8 100644 --- a/qdax/baselines/dads_smerl.py +++ b/qdax/baselines/dads_smerl.py @@ -92,14 +92,17 @@ def update( the replay buffer the training metrics """ + key = training_state.key # Sample a batch of transitions in the buffer - random_key = training_state.random_key - samples, returns, random_key = replay_buffer.sample_with_returns( - random_key, + key, subkey = jax.random.split(key) + samples, returns = replay_buffer.sample_with_returns( + subkey, sample_size=self._config.batch_size, ) + training_state = training_state.update(key=key) + # Optionally replace the state descriptor by the observation if self._config.descriptor_full_state: _state_desc = samples.obs[:, : -self._config.num_skills] diff --git a/qdax/baselines/diayn.py b/qdax/baselines/diayn.py index 0ebdfc32..84c18aad 100644 --- a/qdax/baselines/diayn.py +++ b/qdax/baselines/diayn.py @@ -35,7 +35,7 @@ class DiaynTrainingState(TrainingState): target_critic_params: Params discriminator_optimizer_state: optax.OptState discriminator_params: Params - random_key: RNGKey + key: RNGKey steps: jnp.ndarray @@ -116,7 +116,7 @@ def __init__(self, config: DiaynConfig, action_size: int): def init( # type: ignore self, - random_key: RNGKey, + key: RNGKey, action_size: int, observation_size: int, descriptor_size: int, @@ -124,7 +124,7 @@ def init( # type: ignore """Initialise the training state of the algorithm. Args: - random_key: a jax random key + key: a jax random key action_size: the size of the environment's action space observation_size: the size of the environment's observation space descriptor_size: the size of the environment's descriptor space (i.e. 
the @@ -139,17 +139,17 @@ def init( # type: ignore dummy_action = jnp.zeros((1, action_size)) dummy_discriminator_obs = jnp.zeros((1, descriptor_size)) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) policy_params = self._policy.init(subkey, dummy_obs) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) critic_params = self._critic.init(subkey, dummy_obs, dummy_action) target_critic_params = jax.tree_util.tree_map( lambda x: jnp.asarray(x.copy()), critic_params ) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) discriminator_params = self._discriminator.init( subkey, obs=dummy_discriminator_obs ) @@ -173,7 +173,7 @@ def init( # type: ignore target_critic_params=target_critic_params, discriminator_optimizer_state=discriminator_optimizer_state, discriminator_params=discriminator_params, - random_key=random_key, + key=key, steps=jnp.array(0), ) @@ -239,7 +239,7 @@ def play_step_fn( the played transition """ - random_key = training_state.random_key + key = training_state.key policy_params = training_state.policy_params obs = jnp.concatenate([env_state.obs, skills], axis=1) @@ -249,10 +249,10 @@ def play_step_fn( else: state_desc = jnp.zeros((env_state.obs.shape[0], 2)) - actions, random_key = self.select_action( + actions, key = self.select_action( obs=obs, policy_params=policy_params, - random_key=random_key, + key=key, deterministic=deterministic, ) @@ -273,18 +273,11 @@ def play_step_fn( actions=actions, truncations=truncations, ) - training_state = training_state.replace(random_key=random_key) + training_state = training_state.replace(key=key) return next_env_state, training_state, transition - @partial( - jax.jit, - static_argnames=( - "self", - "play_step_fn", - "env_batch_size", - ), - ) + @partial(jax.jit, static_argnames=("self", "play_step_fn", "env_batch_size")) def eval_policy_fn( self, training_state: DiaynTrainingState, @@ -382,12 +375,12 @@ def _update_networks( Args: training_state: the current training state. transitions: transitions sampled from the replay buffer. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: The update training state, metrics and a new random key. 
""" - random_key = training_state.random_key + key = training_state.key # Compute discriminator loss and gradients discriminator_loss, discriminator_gradient = jax.value_and_grad( @@ -413,12 +406,12 @@ def _update_networks( alpha_params, alpha_optimizer_state, alpha_loss, - random_key, + key, ) = self._update_alpha( alpha_lr=self._config.learning_rate, training_state=training_state, transitions=transitions, - random_key=random_key, + key=key, ) # update critic @@ -427,14 +420,14 @@ def _update_networks( target_critic_params, critic_optimizer_state, critic_loss, - random_key, + key, ) = self._update_critic( critic_lr=self._config.learning_rate, reward_scaling=self._config.reward_scaling, discount=self._config.discount, training_state=training_state, transitions=transitions, - random_key=random_key, + key=key, ) # update actor @@ -442,12 +435,12 @@ def _update_networks( policy_params, policy_optimizer_state, policy_loss, - random_key, + key, ) = self._update_actor( policy_lr=self._config.learning_rate, training_state=training_state, transitions=transitions, - random_key=random_key, + key=key, ) # Create new training state @@ -461,7 +454,7 @@ def _update_networks( target_critic_params=target_critic_params, discriminator_optimizer_state=discriminator_optimizer_state, discriminator_params=discriminator_params, - random_key=random_key, + key=key, steps=training_state.steps + 1, ) metrics = { @@ -492,9 +485,9 @@ def update( the training metrics """ # Sample a batch of transitions in the buffer - random_key = training_state.random_key - transitions, random_key = replay_buffer.sample( - random_key, + key = training_state.key + transitions, key = replay_buffer.sample( + key, sample_size=self._config.batch_size, ) diff --git a/qdax/baselines/diayn_smerl.py b/qdax/baselines/diayn_smerl.py index daacaa74..fb5651af 100644 --- a/qdax/baselines/diayn_smerl.py +++ b/qdax/baselines/diayn_smerl.py @@ -99,14 +99,17 @@ def update( the replay buffer the training metrics """ - # Sample a batch of transitions in the buffer - random_key = training_state.random_key + key = training_state.key - samples, returns, random_key = replay_buffer.sample_with_returns( - random_key, + # Sample a batch of transitions in the buffer + key, subkey = jax.random.split(key) + samples, returns = replay_buffer.sample_with_returns( + subkey, sample_size=self._config.batch_size, ) + training_state = training_state.update(key=key) + # Optionally replace the state descriptor by the observation if self._config.descriptor_full_state: state_desc = samples.obs[:, : -self._config.num_skills] diff --git a/qdax/baselines/genetic_algorithm.py b/qdax/baselines/genetic_algorithm.py index b4c6a32f..ef23483e 100644 --- a/qdax/baselines/genetic_algorithm.py +++ b/qdax/baselines/genetic_algorithm.py @@ -28,9 +28,7 @@ class GeneticAlgorithm: def __init__( self, - scoring_function: Callable[ - [Genotype, RNGKey], Tuple[Fitness, ExtraScores, RNGKey] - ], + scoring_function: Callable[[Genotype, RNGKey], Tuple[Fitness, ExtraScores]], emitter: Emitter, metrics_function: Callable[[GARepertoire], Metrics], ) -> None: @@ -40,23 +38,22 @@ def __init__( @partial(jax.jit, static_argnames=("self", "population_size")) def init( - self, genotypes: Genotype, population_size: int, random_key: RNGKey - ) -> Tuple[GARepertoire, Optional[EmitterState], RNGKey]: + self, genotypes: Genotype, population_size: int, key: RNGKey + ) -> Tuple[GARepertoire, Optional[EmitterState]]: """Initialize a GARepertoire with an initial population of genotypes. 
Args: genotypes: the initial population of genotypes population_size: the maximal size of the repertoire - random_key: a random key to handle stochastic operations + key: a random key to handle stochastic operations Returns: The initial repertoire, an initial emitter state and a new random key. """ # score initial genotypes - fitnesses, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, extra_scores = self._scoring_function(genotypes, subkey) # init the repertoire repertoire = GARepertoire.init( @@ -66,8 +63,9 @@ def init( ) # get initial state of the emitter - emitter_state, random_key = self._emitter.init( - random_key=random_key, + key, subkey = jax.random.split(key) + emitter_state = self._emitter.init( + key=subkey, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, @@ -75,15 +73,15 @@ def init( extra_scores=extra_scores, ) - return repertoire, emitter_state, random_key + return repertoire, emitter_state @partial(jax.jit, static_argnames=("self",)) def update( self, repertoire: GARepertoire, emitter_state: Optional[EmitterState], - random_key: RNGKey, - ) -> Tuple[GARepertoire, Optional[EmitterState], Metrics, RNGKey]: + key: RNGKey, + ) -> Tuple[GARepertoire, Optional[EmitterState], Metrics]: """ Performs one iteration of a Genetic algorithm. 1. A batch of genotypes is sampled in the repertoire and the genotypes @@ -94,7 +92,7 @@ def update( Args: repertoire: a repertoire emitter_state: state of the emitter - random_key: a jax PRNG random key + key: a jax PRNG random key Returns: the updated MAP-Elites repertoire @@ -104,14 +102,12 @@ def update( """ # generate offsprings - genotypes, extra_info, random_key = self._emitter.emit( - repertoire, emitter_state, random_key - ) + key, subkey = jax.random.split(key) + genotypes, extra_info = self._emitter.emit(repertoire, emitter_state, subkey) # score the offsprings - fitnesses, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, extra_scores = self._scoring_function(genotypes, subkey) # update the repertoire repertoire = repertoire.add(genotypes, fitnesses) @@ -129,7 +125,7 @@ def update( # update the metrics metrics = self._metrics_function(repertoire) - return repertoire, emitter_state, metrics, random_key + return repertoire, emitter_state, metrics @partial(jax.jit, static_argnames=("self",)) def scan_update( @@ -149,9 +145,10 @@ def scan_update( The updated repertoire and emitter state, with a new random key and metrics. 
""" # iterate over grid - repertoire, emitter_state, random_key = carry - repertoire, emitter_state, metrics, random_key = self.update( - repertoire, emitter_state, random_key + repertoire, emitter_state, key = carry + key, subkey = jax.random.split(key) + repertoire, emitter_state, metrics = self.update( + repertoire, emitter_state, subkey ) - return (repertoire, emitter_state, random_key), metrics + return (repertoire, emitter_state, key), metrics diff --git a/qdax/baselines/nsga2.py b/qdax/baselines/nsga2.py index afd587af..b802fa96 100644 --- a/qdax/baselines/nsga2.py +++ b/qdax/baselines/nsga2.py @@ -28,13 +28,12 @@ class NSGA2(GeneticAlgorithm): @partial(jax.jit, static_argnames=("self", "population_size")) def init( - self, genotypes: Genotype, population_size: int, random_key: RNGKey - ) -> Tuple[NSGA2Repertoire, Optional[EmitterState], RNGKey]: + self, genotypes: Genotype, population_size: int, key: RNGKey + ) -> Tuple[NSGA2Repertoire, Optional[EmitterState]]: # score initial genotypes - fitnesses, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, extra_scores = self._scoring_function(genotypes, subkey) # init the repertoire repertoire = NSGA2Repertoire.init( @@ -44,8 +43,9 @@ def init( ) # get initial state of the emitter - emitter_state, random_key = self._emitter.init( - random_key=random_key, + key, subkey = jax.random.split(key) + emitter_state = self._emitter.init( + key=subkey, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, @@ -62,4 +62,4 @@ def init( extra_scores=extra_scores, ) - return repertoire, emitter_state, random_key + return repertoire, emitter_state diff --git a/qdax/baselines/pbt.py b/qdax/baselines/pbt.py index 6555c537..1e8b5a89 100644 --- a/qdax/baselines/pbt.py +++ b/qdax/baselines/pbt.py @@ -88,29 +88,28 @@ def __init__( @partial(jax.jit, static_argnames=("self",)) def update_states_and_buffer( self, - random_key: RNGKey, + key: RNGKey, population_returns: jnp.ndarray, training_state: PBTTrainingState, replay_buffer: ReplayBuffer, - ) -> Tuple[RNGKey, PBTTrainingState, ReplayBuffer]: + ) -> Tuple[PBTTrainingState, ReplayBuffer]: """ Updates the agents of the population states as well as their shared replay buffer. Args: - random_key: Random RNG key. + key: Random key. population_returns: Returns of the agents in the populations. training_state: The training state of the PBT scheme. replay_buffer: Shared replay buffer by the agents. Returns: - Updated random key, updated PBT training state and updated replay buffer. + Updated PBT training state and updated replay buffer. 
""" indices_sorted = jax.numpy.argsort(-population_returns) best_indices = indices_sorted[: self._num_best_to_replace_from] indices_to_replace = indices_sorted[-self._num_worse_to_replace :] - random_key, key = jax.random.split(random_key) indices_used_to_replace = jax.random.choice( key, best_indices, shape=(self._num_worse_to_replace,), replace=True ) @@ -127,7 +126,7 @@ def update_states_and_buffer( replay_buffer, ) - return random_key, training_state, replay_buffer + return training_state, replay_buffer @partial(jax.jit, static_argnames=("self",)) def update_states_and_buffer_pmap( @@ -136,7 +135,7 @@ def update_states_and_buffer_pmap( population_returns: jnp.ndarray, training_state: PBTTrainingState, replay_buffer: ReplayBuffer, - ) -> Tuple[RNGKey, PBTTrainingState, ReplayBuffer]: + ) -> Tuple[PBTTrainingState, ReplayBuffer]: """ Updates the agents of the population states as well as their shared replay buffer. This is the version of the function to be @@ -190,4 +189,4 @@ def update_states_and_buffer_pmap( gathered_best_buffers, ) - return random_key, training_state, replay_buffer + return training_state, replay_buffer diff --git a/qdax/baselines/sac.py b/qdax/baselines/sac.py index 193defea..a80d5b7c 100644 --- a/qdax/baselines/sac.py +++ b/qdax/baselines/sac.py @@ -54,7 +54,7 @@ class SacTrainingState(TrainingState): alpha_optimizer_state: optax.OptState alpha_params: Params target_critic_params: Params - random_key: RNGKey + key: RNGKey steps: jnp.ndarray normalization_running_stats: RunningMeanStdState @@ -95,12 +95,12 @@ def __init__(self, config: SacConfig, action_size: int) -> None: self._sample_action_fn = self._parametric_action_distribution.sample def init( - self, random_key: RNGKey, action_size: int, observation_size: int + self, key: RNGKey, action_size: int, observation_size: int ) -> SacTrainingState: """Initialise the training state of the algorithm. Args: - random_key: a jax random key + key: a jax random key action_size: the size of the environment's action space observation_size: the size of the environment's observation space @@ -112,10 +112,10 @@ def init( dummy_obs = jnp.zeros((1, observation_size)) dummy_action = jnp.zeros((1, action_size)) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) policy_params = self._policy.init(subkey, dummy_obs) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) critic_params = self._critic.init(subkey, dummy_obs, dummy_action) target_critic_params = jax.tree_util.tree_map( @@ -148,7 +148,7 @@ def init( ), count=jnp.zeros(()), ), - random_key=random_key, + key=key, steps=jnp.array(0), ) @@ -159,15 +159,15 @@ def select_action( self, obs: Observation, policy_params: Params, - random_key: RNGKey, + key: RNGKey, deterministic: bool = False, - ) -> Tuple[Action, RNGKey]: + ) -> Action: """Selects an action acording to SAC policy. Args: obs: agent observation(s) policy_params: parameters of the agent's policy - random_key: jax random key + key: jax random key deterministic: whether to select action in a deterministic way. Defaults to False. 
@@ -177,14 +177,13 @@ def select_action( dist_params = self._policy.apply(policy_params, obs) if not deterministic: - random_key, key_sample = jax.random.split(random_key) - actions = self._sample_action_fn(dist_params, key_sample) + actions = self._sample_action_fn(dist_params, key) else: # The first half of parameters is for mean and the second half for variance actions = jax.nn.tanh(dist_params[..., : dist_params.shape[-1] // 2]) - return actions, random_key + return actions @partial(jax.jit, static_argnames=("self", "env", "deterministic", "evaluation")) def play_step_fn( @@ -212,7 +211,7 @@ def play_step_fn( the new SAC training state the played transition """ - random_key = training_state.random_key + key = training_state.key policy_params = training_state.policy_params obs = env_state.obs @@ -228,21 +227,21 @@ def play_step_fn( normalized_obs = obs normalization_running_stats = training_state.normalization_running_stats - actions, random_key = self.select_action( + actions, key = self.select_action( obs=normalized_obs, policy_params=policy_params, - random_key=random_key, + key=key, deterministic=deterministic, ) if not evaluation: training_state = training_state.replace( - random_key=random_key, + key=key, normalization_running_stats=normalization_running_stats, ) else: training_state = training_state.replace( - random_key=random_key, + key=key, ) next_env_state = env.step(env_state, actions) @@ -317,13 +316,7 @@ def play_qd_step_fn( transition, ) - @partial( - jax.jit, - static_argnames=( - "self", - "play_step_fn", - ), - ) + @partial(jax.jit, static_argnames=("self", "play_step_fn")) def eval_policy_fn( self, training_state: SacTrainingState, @@ -422,8 +415,8 @@ def _update_alpha( alpha_lr: float, training_state: SacTrainingState, transitions: Transition, - random_key: RNGKey, - ) -> Tuple[Params, optax.OptState, jnp.ndarray, RNGKey]: + key: RNGKey, + ) -> Tuple[Params, optax.OptState, jnp.ndarray]: """Updates the alpha parameter if necessary. Else, it keeps the current value. @@ -431,14 +424,13 @@ def _update_alpha( alpha_lr: alpha learning rate training_state: the current training state. transitions: a sample of transitions from the replay buffer. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: - New alpha params, optimizer state, loss and a new random key. + New alpha params, optimizer state, and loss. """ if not self._config.fix_alpha: # update alpha - random_key, subkey = jax.random.split(random_key) alpha_loss, alpha_gradient = jax.value_and_grad(sac_alpha_loss_fn)( training_state.alpha_params, policy_fn=self._policy.apply, @@ -446,7 +438,7 @@ def _update_alpha( action_size=self._action_size, policy_params=training_state.policy_params, transitions=transitions, - random_key=subkey, + key=key, ) alpha_optimizer = optax.adam(learning_rate=alpha_lr) ( @@ -463,7 +455,7 @@ def _update_alpha( alpha_optimizer_state = training_state.alpha_optimizer_state alpha_loss = jnp.array(0.0) - return alpha_params, alpha_optimizer_state, alpha_loss, random_key + return alpha_params, alpha_optimizer_state, alpha_loss @partial(jax.jit, static_argnames=("self",)) def _update_critic( @@ -473,7 +465,7 @@ def _update_critic( discount: float, training_state: SacTrainingState, transitions: Transition, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[Params, Params, optax.OptState, jnp.ndarray, RNGKey]: """Updates the critic following the method described in the Soft Actor Critic paper. 
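The same discipline applies inside `update`: each stochastic sub-update receives a freshly split subkey, and the loss function uses it directly instead of splitting again. A small sketch of that pattern with a toy loss standing in for `sac_critic_loss_fn` (names, shapes and the loss itself are made up for the example):

```python
import jax
import jax.numpy as jnp
import optax

def toy_critic_loss(params, transitions, key):
    # Stochastic loss: the key is needed, e.g. to sample next actions for the target.
    noise = jax.random.normal(key, transitions.shape)
    return jnp.mean((transitions + noise - params) ** 2)

params = jnp.zeros(4)
optimizer = optax.adam(learning_rate=3e-4)
opt_state = optimizer.init(params)

key = jax.random.key(42)
key, subkey = jax.random.split(key)   # one split per stochastic sub-update
loss, grads = jax.value_and_grad(toy_critic_loss)(params, jnp.ones((8, 4)), subkey)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
```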
@@ -484,14 +476,14 @@ def _update_critic( discount: discount factor training_state: the current training state. transitions: a batch of transitions sampled from the replay buffer. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: New parameters of the critic and its target. New optimizer state, loss and a new random key. """ # update critic - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) critic_loss, critic_gradient = jax.value_and_grad(sac_critic_loss_fn)( training_state.critic_params, policy_fn=self._policy.apply, @@ -503,7 +495,7 @@ def _update_critic( target_critic_params=training_state.target_critic_params, alpha=jnp.exp(training_state.alpha_params), transitions=transitions, - random_key=subkey, + key=subkey, ) critic_optimizer = optax.adam(learning_rate=critic_lr) ( @@ -526,7 +518,7 @@ def _update_critic( target_critic_params, critic_optimizer_state, critic_loss, - random_key, + key, ) @partial(jax.jit, static_argnames=("self",)) @@ -535,8 +527,8 @@ def _update_actor( policy_lr: float, training_state: SacTrainingState, transitions: Transition, - random_key: RNGKey, - ) -> Tuple[Params, optax.OptState, jnp.ndarray, RNGKey]: + key: RNGKey, + ) -> Tuple[Params, optax.OptState, jnp.ndarray]: """Updates the actor parameters following the stochastic policy gradient theorem with the method introduced in SAC. @@ -545,12 +537,11 @@ def _update_actor( training_state: the current training state. transitions: a batch of transitions sampled from the replay buffer. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: - New params and optimizer state. Current loss. New random key. + New params and optimizer state. Current loss. 
""" - random_key, subkey = jax.random.split(random_key) policy_loss, policy_gradient = jax.value_and_grad(sac_policy_loss_fn)( training_state.policy_params, policy_fn=self._policy.apply, @@ -559,7 +550,7 @@ def _update_actor( critic_params=training_state.critic_params, alpha=jnp.exp(training_state.alpha_params), transitions=transitions, - random_key=subkey, + key=key, ) policy_optimizer = optax.adam(learning_rate=policy_lr) ( @@ -572,7 +563,7 @@ def _update_actor( training_state.policy_params, policy_updates ) - return policy_params, policy_optimizer_state, policy_loss, random_key + return policy_params, policy_optimizer_state, policy_loss @partial(jax.jit, static_argnames=("self",)) def update( @@ -593,9 +584,9 @@ def update( """ # sample a batch of transitions in the buffer - random_key = training_state.random_key - transitions, random_key = replay_buffer.sample( - random_key, + key = training_state.key + transitions, key = replay_buffer.sample( + key, sample_size=self._config.batch_size, ) @@ -617,12 +608,12 @@ def update( alpha_params, alpha_optimizer_state, alpha_loss, - random_key, + key, ) = self._update_alpha( alpha_lr=self._config.learning_rate, training_state=training_state, transitions=transitions, - random_key=random_key, + key=key, ) # update critic @@ -631,14 +622,14 @@ def update( target_critic_params, critic_optimizer_state, critic_loss, - random_key, + key, ) = self._update_critic( critic_lr=self._config.learning_rate, reward_scaling=self._config.reward_scaling, discount=self._config.discount, training_state=training_state, transitions=transitions, - random_key=random_key, + key=key, ) # update actor @@ -646,12 +637,12 @@ def update( policy_params, policy_optimizer_state, policy_loss, - random_key, + key, ) = self._update_actor( policy_lr=self._config.learning_rate, training_state=training_state, transitions=transitions, - random_key=random_key, + key=key, ) # create new training state @@ -664,7 +655,7 @@ def update( alpha_params=alpha_params, normalization_running_stats=training_state.normalization_running_stats, target_critic_params=target_critic_params, - random_key=random_key, + key=key, steps=training_state.steps + 1, ) metrics = { diff --git a/qdax/baselines/sac_pbt.py b/qdax/baselines/sac_pbt.py index 03e0bd82..09a33fe9 100644 --- a/qdax/baselines/sac_pbt.py +++ b/qdax/baselines/sac_pbt.py @@ -75,20 +75,20 @@ def resample_hyperparams( training_state: "PBTSacTrainingState", ) -> "PBTSacTrainingState": - random_key = training_state.random_key - random_key, sub_key = jax.random.split(random_key) + key = training_state.key + key, sub_key = jax.random.split(key) discount = jax.random.uniform(sub_key, shape=(), minval=0.9, maxval=1.0) - random_key, sub_key = jax.random.split(random_key) + key, sub_key = jax.random.split(key) policy_lr = jax.random.uniform(sub_key, shape=(), minval=3e-5, maxval=3e-3) - random_key, sub_key = jax.random.split(random_key) + key, sub_key = jax.random.split(key) critic_lr = jax.random.uniform(sub_key, shape=(), minval=3e-5, maxval=3e-3) - random_key, sub_key = jax.random.split(random_key) + key, sub_key = jax.random.split(key) alpha_lr = jax.random.uniform(sub_key, shape=(), minval=3e-5, maxval=3e-3) - random_key, sub_key = jax.random.split(random_key) + key, sub_key = jax.random.split(key) reward_scaling = jax.random.uniform(sub_key, shape=(), minval=0.1, maxval=10.0) return training_state.replace( # type: ignore @@ -97,7 +97,7 @@ def resample_hyperparams( critic_lr=critic_lr, alpha_lr=alpha_lr, reward_scaling=reward_scaling, - 
random_key=random_key, + key=key, ) @@ -135,12 +135,12 @@ def __init__(self, config: PBTSacConfig, action_size: int) -> None: SAC.__init__(self, config=sac_config, action_size=action_size) def init( - self, random_key: RNGKey, action_size: int, observation_size: int + self, key: RNGKey, action_size: int, observation_size: int ) -> PBTSacTrainingState: """Initialise the training state of the algorithm. Args: - random_key: a jax random key + key: a jax random key action_size: the size of the environment's action space observation_size: the size of the environment's observation space @@ -148,7 +148,7 @@ def init( the initial training state of PBT-SAC """ - sac_training_state = SAC.init(self, random_key, action_size, observation_size) + sac_training_state = SAC.init(self, key, action_size, observation_size) training_state = PBTSacTrainingState( policy_optimizer_state=sac_training_state.policy_optimizer_state, @@ -159,7 +159,7 @@ def init( alpha_params=sac_training_state.alpha_params, target_critic_params=sac_training_state.target_critic_params, normalization_running_stats=sac_training_state.normalization_running_stats, - random_key=sac_training_state.random_key, + key=sac_training_state.key, steps=sac_training_state.steps, discount=None, policy_lr=None, @@ -192,9 +192,9 @@ def update( """ # sample a batch of transitions in the buffer - random_key = training_state.random_key - transitions, random_key = replay_buffer.sample( - random_key, + key = training_state.key + transitions, key = replay_buffer.sample( + key, sample_size=self._config.batch_size, ) @@ -216,12 +216,12 @@ def update( alpha_params, alpha_optimizer_state, alpha_loss, - random_key, + key, ) = self._update_alpha( alpha_lr=training_state.alpha_lr, training_state=training_state, transitions=transitions, - random_key=random_key, + key=key, ) # update critic @@ -230,14 +230,14 @@ def update( target_critic_params, critic_optimizer_state, critic_loss, - random_key, + key, ) = self._update_critic( critic_lr=training_state.critic_lr, reward_scaling=training_state.reward_scaling, discount=training_state.discount, training_state=training_state, transitions=transitions, - random_key=random_key, + key=key, ) # update actor @@ -245,12 +245,12 @@ def update( policy_params, policy_optimizer_state, policy_loss, - random_key, + key, ) = self._update_actor( policy_lr=training_state.policy_lr, training_state=training_state, transitions=transitions, - random_key=random_key, + key=key, ) # create new training state @@ -263,7 +263,7 @@ def update( alpha_params=alpha_params, normalization_running_stats=training_state.normalization_running_stats, target_critic_params=target_critic_params, - random_key=random_key, + key=key, steps=training_state.steps + 1, discount=training_state.discount, policy_lr=training_state.policy_lr, @@ -302,10 +302,10 @@ def get_init_fn( """ def _init_fn( - random_key: RNGKey, - ) -> Tuple[RNGKey, PBTSacTrainingState, ReplayBuffer]: + key: RNGKey, + ) -> Tuple[PBTSacTrainingState, ReplayBuffer]: - random_key, *keys = jax.random.split(random_key, num=1 + population_size) + key, *keys = jax.random.split(key, num=population_size + 1) keys = jnp.stack(keys) init_dummy_transition = partial( @@ -328,7 +328,7 @@ def _init_fn( self.init, action_size=action_size, observation_size=observation_size ) training_states = jax.vmap(agent_init)(keys) - return random_key, training_states, replay_buffers + return training_states, replay_buffers return _init_fn @@ -367,18 +367,19 @@ def get_eval_qd_fn( descriptor_extraction_fn: 
Callable[[QDTransition, Mask], Descriptor], ) -> Callable: """ - Returns the function the evaluation the PBT population. + Returns the evaluation function of the PBT population. Args: eval_env: evaluation environment. Might be different from training env if needed. - descriptor_extraction_fn: function to extract the descriptor from an episode. + descriptor_extraction_fn: function to extract the descriptor from an + episode. Returns: The function to evaluate the population. It takes as input the population training state as well as first eval environment states and returns the - population agents mean returns and mean descriptors over episodes as well as all - returns and descriptors from all agents over all episodes. + population agents mean returns and mean descriptors over episodes, + as well as allreturns and descriptors from all agents over all episodes. """ play_eval_step = partial( self.play_qd_step_fn, diff --git a/qdax/baselines/spea2.py b/qdax/baselines/spea2.py index 10d195ad..49552d9d 100644 --- a/qdax/baselines/spea2.py +++ b/qdax/baselines/spea2.py @@ -30,26 +30,18 @@ class SPEA2(GeneticAlgorithm): b13724cb54ae4171916f3f969d304b9e9752a57f" """ - @partial( - jax.jit, - static_argnames=( - "self", - "population_size", - "num_neighbours", - ), - ) + @partial(jax.jit, static_argnames=("self", "population_size", "num_neighbours")) def init( self, genotypes: Genotype, population_size: int, num_neighbours: int, - random_key: RNGKey, - ) -> Tuple[SPEA2Repertoire, Optional[EmitterState], RNGKey]: + key: RNGKey, + ) -> Tuple[SPEA2Repertoire, Optional[EmitterState]]: # score initial genotypes - fitnesses, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, extra_scores = self._scoring_function(genotypes, subkey) # init the repertoire repertoire = SPEA2Repertoire.init( @@ -60,8 +52,9 @@ def init( ) # get initial state of the emitter - emitter_state, random_key = self._emitter.init( - random_key=random_key, + key, subkey = jax.random.split(key) + emitter_state = self._emitter.init( + key=subkey, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, @@ -78,4 +71,4 @@ def init( extra_scores=extra_scores, ) - return repertoire, emitter_state, random_key + return repertoire, emitter_state diff --git a/qdax/baselines/td3.py b/qdax/baselines/td3.py index 32470ef6..ed109fbb 100644 --- a/qdax/baselines/td3.py +++ b/qdax/baselines/td3.py @@ -44,7 +44,7 @@ class TD3TrainingState(TrainingState): critic_params: Params target_critic_params: Params target_policy_params: Params - random_key: RNGKey + key: RNGKey steps: jnp.ndarray @@ -86,13 +86,13 @@ def __init__(self, config: TD3Config, action_size: int): ) def init( - self, random_key: RNGKey, action_size: int, observation_size: int + self, key: RNGKey, action_size: int, observation_size: int ) -> TD3TrainingState: """Initialise the training state of the TD3 algorithm, through creation of optimizer states and params. Args: - random_key: a random key used for random operations. + key: a random key used for random operations. action_size: the size of the action array needed to interact with the environment. 
observation_size: the size of the observation array retrieved from the @@ -105,7 +105,7 @@ def init( # Initialize critics and policy params fake_obs = jnp.zeros(shape=(observation_size,)) fake_action = jnp.zeros(shape=(action_size,)) - random_key, subkey_1, subkey_2 = jax.random.split(random_key, num=3) + key, subkey_1, subkey_2 = jax.random.split(key, num=3) critic_params = self._critic.init(subkey_1, obs=fake_obs, actions=fake_action) policy_params = self._policy.init(subkey_2, fake_obs) @@ -129,7 +129,7 @@ def init( critic_params=critic_params, target_policy_params=target_policy_params, target_critic_params=target_critic_params, - random_key=random_key, + key=key, steps=jnp.array(0), ) @@ -140,17 +140,17 @@ def select_action( self, obs: Observation, policy_params: Params, - random_key: RNGKey, + key: RNGKey, expl_noise: float, deterministic: bool = False, - ) -> Tuple[Action, RNGKey]: + ) -> Action: """Selects an action according to TD3 policy. The action can be deterministic or stochastic by adding exploration noise. Args: obs: agent observation(s) policy_params: parameters of the agent's policy - random_key: jax random key + key: jax random key expl_noise: exploration noise deterministic: whether to select action in a deterministic way. Defaults to False. @@ -161,11 +161,10 @@ def select_action( actions = self._policy.apply(policy_params, obs) if not deterministic: - random_key, subkey = jax.random.split(random_key) - noise = jax.random.normal(subkey, actions.shape) * expl_noise + noise = jax.random.normal(key, actions.shape) * expl_noise actions = actions + noise actions = jnp.clip(actions, -1.0, 1.0) - return actions, random_key + return actions @partial(jax.jit, static_argnames=("self", "env", "deterministic")) def play_step_fn( @@ -191,15 +190,15 @@ def play_step_fn( the played transition """ - actions, random_key = self.select_action( + actions, key = self.select_action( obs=env_state.obs, policy_params=training_state.policy_params, - random_key=training_state.random_key, + key=training_state.key, expl_noise=self._config.expl_noise, deterministic=deterministic, ) training_state = training_state.replace( - random_key=random_key, + key=key, ) next_env_state = env.step(env_state, actions) transition = Transition( @@ -258,13 +257,7 @@ def play_qd_step_fn( transition, ) - @partial( - jax.jit, - static_argnames=( - "self", - "play_step_fn", - ), - ) + @partial(jax.jit, static_argnames=("self", "play_step_fn")) def eval_policy_fn( self, training_state: TD3TrainingState, @@ -376,13 +369,11 @@ def update( """ # Sample a batch of transitions in the buffer - random_key = training_state.random_key - samples, random_key = replay_buffer.sample( - random_key, sample_size=self._config.batch_size - ) + key = training_state.key + samples, key = replay_buffer.sample(key, sample_size=self._config.batch_size) # Update Critic - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) critic_loss, critic_gradient = jax.value_and_grad(td3_critic_loss_fn)( training_state.critic_params, target_policy_params=training_state.target_policy_params, @@ -394,7 +385,7 @@ def update( reward_scaling=self._config.reward_scaling, discount=self._config.discount, transitions=samples, - random_key=subkey, + key=subkey, ) critic_optimizer = optax.adam(learning_rate=self._config.critic_learning_rate) critic_updates, critic_optimizer_state = critic_optimizer.update( @@ -463,7 +454,7 @@ def update_policy_step() -> Tuple[Params, Params, optax.OptState]: 
policy_optimizer_state=policy_optimizer_state, target_critic_params=target_critic_params, target_policy_params=target_policy_params, - random_key=random_key, + key=key, steps=training_state.steps + 1, ) diff --git a/qdax/baselines/td3_pbt.py b/qdax/baselines/td3_pbt.py index 7165d3fc..721f7f43 100644 --- a/qdax/baselines/td3_pbt.py +++ b/qdax/baselines/td3_pbt.py @@ -80,23 +80,23 @@ def resample_hyperparams( cls, training_state: "PBTTD3TrainingState" ) -> "PBTTD3TrainingState": - random_key = training_state.random_key - random_key, sub_key = jax.random.split(random_key) + key = training_state.key + key, sub_key = jax.random.split(key) discount = jax.random.uniform(sub_key, shape=(), minval=0.9, maxval=1.0) - random_key, sub_key = jax.random.split(random_key) + key, sub_key = jax.random.split(key) policy_lr = jax.random.uniform(sub_key, shape=(), minval=3e-5, maxval=3e-3) - random_key, sub_key = jax.random.split(random_key) + key, sub_key = jax.random.split(key) critic_lr = jax.random.uniform(sub_key, shape=(), minval=3e-5, maxval=3e-3) - random_key, sub_key = jax.random.split(random_key) + key, sub_key = jax.random.split(key) noise_clip = jax.random.uniform(sub_key, shape=(), minval=0.0, maxval=1.0) - random_key, sub_key = jax.random.split(random_key) + key, sub_key = jax.random.split(key) policy_noise = jax.random.uniform(sub_key, shape=(), minval=0.0, maxval=1.0) - random_key, sub_key = jax.random.split(random_key) + key, sub_key = jax.random.split(key) expl_noise = jax.random.uniform(sub_key, shape=(), minval=0.0, maxval=0.2) return training_state.replace( # type: ignore @@ -105,7 +105,7 @@ def resample_hyperparams( critic_lr=critic_lr, noise_clip=noise_clip, policy_noise=policy_noise, - random_key=random_key, + key=key, expl_noise=expl_noise, ) @@ -138,13 +138,13 @@ def __init__(self, config: PBTTD3Config, action_size: int): TD3.__init__(self, td3_config, action_size) def init( - self, random_key: RNGKey, action_size: int, observation_size: int + self, key: RNGKey, action_size: int, observation_size: int ) -> PBTTD3TrainingState: """Initialise the training state of the PBT-TD3 algorithm, through creation of optimizer states and params. Args: - random_key: a random key used for random operations. + key: a random key used for random operations. action_size: the size of the action array needed to interact with the environment. observation_size: the size of the observation array retrieved from the @@ -154,7 +154,7 @@ def init( the initial training state. 
""" - training_state = TD3.init(self, random_key, action_size, observation_size) + training_state = TD3.init(self, key, action_size, observation_size) # Initial training state training_state = PBTTD3TrainingState( @@ -164,7 +164,7 @@ def init( critic_params=training_state.critic_params, target_policy_params=training_state.target_policy_params, target_critic_params=training_state.target_critic_params, - random_key=training_state.random_key, + key=training_state.key, steps=training_state.steps, discount=None, policy_lr=None, @@ -203,15 +203,15 @@ def play_step_fn( the played transition """ - actions, random_key = self.select_action( + actions, key = self.select_action( obs=env_state.obs, policy_params=training_state.policy_params, - random_key=training_state.random_key, + key=training_state.key, expl_noise=training_state.expl_noise, deterministic=deterministic, ) training_state = training_state.replace( - random_key=random_key, + key=key, ) next_env_state = env.step(env_state, actions) transition = Transition( @@ -245,13 +245,11 @@ def update( """ # Sample a batch of transitions in the buffer - random_key = training_state.random_key - samples, random_key = replay_buffer.sample( - random_key, sample_size=self._config.batch_size - ) + key = training_state.key + samples, key = replay_buffer.sample(key, sample_size=self._config.batch_size) # Update Critic - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) critic_loss, critic_gradient = jax.value_and_grad(td3_critic_loss_fn)( training_state.critic_params, target_policy_params=training_state.target_policy_params, @@ -263,7 +261,7 @@ def update( reward_scaling=self._config.reward_scaling, discount=self._config.discount, transitions=samples, - random_key=subkey, + key=subkey, ) critic_optimizer = optax.adam(learning_rate=training_state.critic_lr) critic_updates, critic_optimizer_state = critic_optimizer.update( @@ -330,7 +328,7 @@ def update_policy_step() -> Tuple[Params, Params, optax.OptState]: policy_optimizer_state=policy_optimizer_state, target_critic_params=target_critic_params, target_policy_params=target_policy_params, - random_key=random_key, + key=key, steps=training_state.steps + 1, ) @@ -363,9 +361,9 @@ def get_init_fn( """ def _init_fn( - random_key: RNGKey, - ) -> Tuple[RNGKey, PBTTD3TrainingState, ReplayBuffer]: - random_key, *keys = jax.random.split(random_key, num=1 + population_size) + key: RNGKey, + ) -> Tuple[PBTTD3TrainingState, ReplayBuffer]: + key, *keys = jax.random.split(key, num=population_size + 1) keys = jnp.stack(keys) init_dummy_transition = partial( @@ -388,7 +386,7 @@ def _init_fn( self.init, action_size=action_size, observation_size=observation_size ) training_states = jax.vmap(agent_init)(keys) - return random_key, training_states, replay_buffers + return training_states, replay_buffers return _init_fn @@ -432,13 +430,14 @@ def get_eval_qd_fn( Args: eval_env: evaluation environment. Might be different from training env if needed. - descriptor_extraction_fn: function to extract the descriptor from an episode. + descriptor_extraction_fn: function to extract the descriptor from an + episode. Returns: The function to evaluate the population. It takes as input the population training state as well as first eval environment states and returns the - population agents mean returns and mean descriptors over episodes as well as all - returns and descriptors from all agents over all episodes. 
+ population agents mean returns and mean descriptors over episodes, + as well as all returns and descriptors from all agents over all episodes. """ play_eval_step = partial( self.play_qd_step_fn, diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index 5d532a5b..4a9ea5fc 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -42,7 +42,7 @@ def __init__( self, scoring_function: Callable[ [Genotype, RNGKey], - Tuple[Fitness, Descriptor, ArrayTree, RNGKey], + Tuple[Fitness, Descriptor, ArrayTree], ], emitter: Emitter, metrics_function: Callable[[MapElitesRepertoire], Metrics], @@ -62,11 +62,11 @@ def train( repertoire: UnstructuredRepertoire, model_params: Params, iteration: int, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[UnstructuredRepertoire, AuroraExtraInfo]: - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) aurora_extra_info = self._train_fn( - random_key, + key, repertoire, model_params, iteration, @@ -122,8 +122,8 @@ def init( aurora_extra_info: AuroraExtraInfo, l_value: jnp.ndarray, max_size: int, - random_key: RNGKey, - ) -> Tuple[UnstructuredRepertoire, Optional[EmitterState], AuroraExtraInfo, RNGKey]: + key: RNGKey, + ) -> Tuple[UnstructuredRepertoire, Optional[EmitterState], AuroraExtraInfo]: """Initialize an unstructured repertoire with an initial population of genotypes. Also performs the first training of the AURORA encoder. @@ -134,15 +134,16 @@ def init( such as the encoder parameters l_value: threshold distance for the unstructured repertoire max_size: maximum size of the repertoire - random_key: a random key used for stochastic operations. + key: a random key used for stochastic operations. Returns: an initialized unstructured repertoire, with the initial state of the emitter, and the updated information to perform AURORA encodings """ - fitnesses, descriptors, extra_scores, random_key = self._scoring_function( + key, subkey = jax.random.split(key) + fitnesses, descriptors, extra_scores = self._scoring_function( genotypes, - random_key, + subkey, ) observations = extra_scores["last_valid_observations"] @@ -159,8 +160,9 @@ def init( ) # get initial state of the emitter - emitter_state, random_key = self._emitter.init( - random_key=random_key, + key, subkey = jax.random.split(key) + emitter_state = self._emitter.init( + key=subkey, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, @@ -168,21 +170,20 @@ def init( extra_scores=extra_scores, ) - random_key, subkey = jax.random.split(random_key) repertoire, updated_aurora_extra_info = self.train( - repertoire, aurora_extra_info.model_params, iteration=0, random_key=subkey + repertoire, aurora_extra_info.model_params, iteration=0, key=key ) - return repertoire, emitter_state, updated_aurora_extra_info, random_key + return repertoire, emitter_state, updated_aurora_extra_info @partial(jax.jit, static_argnames=("self",)) def update( self, repertoire: MapElitesRepertoire, emitter_state: Optional[EmitterState], - random_key: RNGKey, + key: RNGKey, aurora_extra_info: AuroraExtraInfo, - ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]: + ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]: """Main step of the AURORA algorithm. 
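At the loop level, dropping returned keys means the caller keeps a single key in its carry, splits it every step, and passes only the subkey to `update`, as the `scan_update` methods in this patch do. A self-contained sketch of that pattern with a dummy update function (not the QDax API):

```python
import jax
import jax.numpy as jnp

def update(state, key):
    # Dummy update that consumes a key without returning it, mirroring the
    # convention introduced in this patch for `update` methods.
    metrics = jnp.max(state)
    return state + jax.random.uniform(key, state.shape), metrics

def scan_update(carry, _):
    state, key = carry
    key, subkey = jax.random.split(key)   # the caller keeps `key` in its carry
    state, metrics = update(state, subkey)
    return (state, key), metrics

(state, key), metrics = jax.lax.scan(
    scan_update, (jnp.zeros(3), jax.random.key(0)), None, length=10
)
```

Because JAX keys are ordinary arrays, carrying the key through `jax.lax.scan` in this way composes cleanly with `jit` and `pmap`.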
@@ -194,7 +195,7 @@ def update( Args: repertoire: unstructured repertoire emitter_state: state of the emitter - random_key: a jax PRNG random key + key: a jax PRNG random key aurora_extra_info: extra info for computing encodings Results: @@ -204,14 +205,14 @@ def update( a new key """ # generate offsprings with the emitter - genotypes, extra_info, random_key = self._emitter.emit( - repertoire, emitter_state, random_key - ) + key, subkey = jax.random.split(key) + genotypes, extra_info = self._emitter.emit(repertoire, emitter_state, subkey) # scores the offsprings - fitnesses, descriptors, extra_scores, random_key = self._scoring_function( + key, subkey = jax.random.split(key) + fitnesses, descriptors, extra_scores = self._scoring_function( genotypes, - random_key, + subkey, ) observations = extra_scores["last_valid_observations"] @@ -239,4 +240,4 @@ def update( # update the metrics metrics = self._metrics_function(repertoire) - return repertoire, emitter_state, metrics, random_key + return repertoire, emitter_state, metrics diff --git a/qdax/core/cmaes.py b/qdax/core/cmaes.py index 0e9b4084..ef3c12c7 100644 --- a/qdax/core/cmaes.py +++ b/qdax/core/cmaes.py @@ -166,28 +166,25 @@ def init(self) -> CMAESState: ) @partial(jax.jit, static_argnames=("self",)) - def sample( - self, cmaes_state: CMAESState, random_key: RNGKey - ) -> Tuple[Genotype, RNGKey]: + def sample(self, cmaes_state: CMAESState, key: RNGKey) -> Genotype: """ Sample a population. Args: cmaes_state: current state of the algorithm - random_key: jax random key + key: jax random key Returns: A tuple that contains a batch of population size genotypes and a new random key. """ - random_key, subkey = jax.random.split(random_key) samples = jax.random.multivariate_normal( - subkey, + key, shape=(self._population_size,), mean=cmaes_state.mean, cov=(cmaes_state.sigma**2) * cmaes_state.cov_matrix, ) - return samples, random_key + return samples @partial(jax.jit, static_argnames=("self",)) def update_state( diff --git a/qdax/core/containers/ga_repertoire.py b/qdax/core/containers/ga_repertoire.py index 403331ff..b1a6cc73 100644 --- a/qdax/core/containers/ga_repertoire.py +++ b/qdax/core/containers/ga_repertoire.py @@ -3,7 +3,7 @@ from __future__ import annotations from functools import partial -from typing import Callable, Tuple +from typing import Callable import jax import jax.numpy as jnp @@ -78,11 +78,11 @@ def load(cls, reconstruction_fn: Callable, path: str = "./") -> GARepertoire: ) @partial(jax.jit, static_argnames=("num_samples",)) - def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]: + def sample(self, key: RNGKey, num_samples: int) -> Genotype: """Sample genotypes from the repertoire. Args: - random_key: a random key to handle stochasticity. + key: a random key to handle stochasticity. num_samples: the number of genotypes to sample. 
Returns: @@ -94,15 +94,14 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey p = jnp.any(mask, axis=-1) / jnp.sum(jnp.any(mask, axis=-1)) # sample - random_key, subkey = jax.random.split(random_key) samples = jax.tree_util.tree_map( lambda x: jax.random.choice( - subkey, x, shape=(num_samples,), p=p, replace=False + key, x, shape=(num_samples,), p=p, replace=False ), self.genotypes, ) - return samples, random_key + return samples @jax.jit def add( diff --git a/qdax/core/containers/mapelites_repertoire.py b/qdax/core/containers/mapelites_repertoire.py index 74845746..848b07e0 100644 --- a/qdax/core/containers/mapelites_repertoire.py +++ b/qdax/core/containers/mapelites_repertoire.py @@ -31,8 +31,8 @@ def compute_cvt_centroids( num_centroids: int, minval: Union[float, List[float]], maxval: Union[float, List[float]], - random_key: RNGKey, -) -> Tuple[jnp.ndarray, RNGKey]: + key: RNGKey, +) -> Centroid: """Compute centroids for CVT tessellation. Args: @@ -44,31 +44,30 @@ def compute_cvt_centroids( num_centroids: number of centroids minval: minimum descriptors value maxval: maximum descriptors value - random_key: a jax PRNG random key + key: a jax PRNG random key Returns: the centroids with shape (num_centroids, num_descriptors) - random_key: an updated jax PRNG random key + key: an updated jax PRNG random key """ minval = jnp.array(minval) maxval = jnp.array(maxval) # assume here all values are in [0, 1] and rescale later - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) x = jax.random.uniform(key=subkey, shape=(num_init_cvt_samples, num_descriptors)) # compute k means - random_key, subkey = jax.random.split(random_key) k_means = KMeans( init="k-means++", n_clusters=num_centroids, n_init=1, - random_state=RandomState(subkey), + random_state=RandomState(jax.random.key_data(key)), ) k_means.fit(x) centroids = k_means.cluster_centers_ # rescale now - return jnp.asarray(centroids) * (maxval - minval) + minval, random_key + return jnp.asarray(centroids) * (maxval - minval) + minval def compute_euclidean_centroids( @@ -212,60 +211,57 @@ def load(cls, reconstruction_fn: Callable, path: str = "./") -> MapElitesReperto ) @partial(jax.jit, static_argnames=("num_samples",)) - def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]: + def sample(self, key: RNGKey, num_samples: int) -> Genotype: """Sample elements in the repertoire. Args: - random_key: a jax PRNG random key + key: a jax PRNG random key num_samples: the number of elements to be sampled Returns: samples: a batch of genotypes sampled in the repertoire - random_key: an updated jax PRNG random key + key: an updated jax PRNG random key """ repertoire_empty = self.fitnesses == -jnp.inf p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty) - random_key, subkey = jax.random.split(random_key) samples = jax.tree_util.tree_map( - lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p), + lambda x: jax.random.choice(key, x, shape=(num_samples,), p=p), self.genotypes, ) - return samples, random_key + return samples @partial(jax.jit, static_argnames=("num_samples",)) def sample_with_descs( self, - random_key: RNGKey, + key: RNGKey, num_samples: int, - ) -> Tuple[Genotype, Descriptor, RNGKey]: + ) -> Tuple[Genotype, Descriptor]: """Sample elements in the repertoire. 
Args: - random_key: a jax PRNG random key + key: a jax PRNG random key num_samples: the number of elements to be sampled Returns: samples: a batch of genotypes sampled in the repertoire - random_key: an updated jax PRNG random key """ repertoire_empty = self.fitnesses == -jnp.inf p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty) - random_key, subkey = jax.random.split(random_key) samples = jax.tree_util.tree_map( - lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p), + lambda x: jax.random.choice(key, x, shape=(num_samples,), p=p), self.genotypes, ) descs = jax.tree_util.tree_map( - lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p), + lambda x: jax.random.choice(key, x, shape=(num_samples,), p=p), self.descriptors, ) - return samples, descs, random_key + return samples, descs @jax.jit def add( diff --git a/qdax/core/containers/mome_repertoire.py b/qdax/core/containers/mome_repertoire.py index 43be3835..3ba09024 100644 --- a/qdax/core/containers/mome_repertoire.py +++ b/qdax/core/containers/mome_repertoire.py @@ -64,18 +64,14 @@ def _sample_in_masked_pareto_front( self, pareto_front_genotypes: ParetoFront[Genotype], mask: Mask, - random_key: RNGKey, + key: RNGKey, ) -> Genotype: """Sample one single genotype in masked pareto front. - Note: do not retrieve a random key because this function - is to be vmapped. The public method that uses this function - will return a random key - Args: pareto_front_genotypes: the genotypes of a pareto front mask: a mask associated to the front - random_key: a random key to handle stochastic operations + key: a random key to handle stochastic operations Returns: A single genotype among the pareto front. @@ -83,25 +79,25 @@ def _sample_in_masked_pareto_front( p = (1.0 - mask) / jnp.sum(1.0 - mask) genotype_sample = jax.tree_util.tree_map( - lambda x: jax.random.choice(random_key, x, shape=(1,), p=p), + lambda x: jax.random.choice(key, x, shape=(1,), p=p), pareto_front_genotypes, ) return genotype_sample @partial(jax.jit, static_argnames=("num_samples",)) - def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]: + def sample(self, key: RNGKey, num_samples: int) -> Genotype: """Sample elements in the repertoire. This method sample a non-empty pareto front, and then sample genotypes from this pareto front. Args: - random_key: a random key to handle stochasticity. + key: a random key to handle stochasticity. num_samples: number of samples to retrieve from the repertoire. Returns: - A sample of genotypes and a new random key. + A sample of genotypes. 
""" # create sampling probability for the cells @@ -114,7 +110,7 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey indices = jnp.arange(start=0, stop=repertoire_empty.shape[0]) # choose idx - among indices of cells that are not empty - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) cells_idx = jax.random.choice(subkey, indices, shape=(num_samples,), p=p) # get genotypes (front) from the chosen indices @@ -126,12 +122,11 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey sample_in_fronts = jax.vmap(self._sample_in_masked_pareto_front) # sample genotypes from the pareto front - random_key, subkey = jax.random.split(random_key) - subkeys = jax.random.split(subkey, num=num_samples) + subkeys = jax.random.split(key, num=num_samples) sampled_genotypes = sample_in_fronts( # type: ignore pareto_front_genotypes=pareto_front_genotypes, mask=repertoire_empty[cells_idx], - random_key=subkeys, + key=subkeys, ) # remove the dim coming from pareto front @@ -139,7 +134,7 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey lambda x: x.squeeze(axis=1), sampled_genotypes ) - return sampled_genotypes, random_key + return sampled_genotypes @jax.jit def _update_masked_pareto_front( diff --git a/qdax/core/containers/repertoire.py b/qdax/core/containers/repertoire.py index 77c91683..2ed39784 100644 --- a/qdax/core/containers/repertoire.py +++ b/qdax/core/containers/repertoire.py @@ -28,13 +28,13 @@ def init(cls) -> Repertoire: # noqa: N805 @abstractmethod def sample( self, - random_key: RNGKey, + key: RNGKey, num_samples: int, ) -> Genotype: """Sample genotypes from the repertoire. Args: - random_key: a random key to handle stochasticity. + key: a random key to handle stochasticity. num_samples: the number of genotypes to sample. Returns: diff --git a/qdax/core/containers/uniform_replacement_archive.py b/qdax/core/containers/uniform_replacement_archive.py index 830878cf..1eacb598 100644 --- a/qdax/core/containers/uniform_replacement_archive.py +++ b/qdax/core/containers/uniform_replacement_archive.py @@ -18,7 +18,7 @@ class UniformReplacementArchive(Archive): Most methods are inherited from Archive. """ - random_key: RNGKey + key: RNGKey @classmethod def create( # type: ignore @@ -26,7 +26,7 @@ def create( # type: ignore acceptance_threshold: float, state_descriptor_size: int, max_size: int, - random_key: RNGKey, + key: RNGKey, ) -> Archive: """Create an Archive instance. @@ -39,7 +39,7 @@ def create( # type: ignore state_descriptor_size: the number of elements in a state descriptor. max_size: the maximal size of the archive. In case of overflow, previous elements are replaced by new ones. Defaults to 80000. - random_key: a key to handle random operations. Defaults to key with + key: a key to handle random operations. Defaults to key with seed = 0. 
Returns: @@ -52,7 +52,7 @@ def create( # type: ignore max_size, ) - return archive.replace(random_key=random_key) # type: ignore + return archive.replace(key=key) # type: ignore @jax.jit def _single_insertion(self, state_descriptor: jnp.ndarray) -> Archive: @@ -69,7 +69,7 @@ def _single_insertion(self, state_descriptor: jnp.ndarray) -> Archive: new_current_position = self.current_position + 1 is_full = new_current_position >= self.max_size - random_key, subkey = jax.random.split(self.random_key) + key, subkey = jax.random.split(self.key) random_index = jax.random.randint( subkey, shape=(1,), minval=0, maxval=self.max_size ) @@ -79,5 +79,5 @@ def _single_insertion(self, state_descriptor: jnp.ndarray) -> Archive: new_data = self.data.at[index].set(state_descriptor) return self.replace( # type: ignore - current_position=new_current_position, data=new_data, random_key=random_key + current_position=new_current_position, data=new_data, key=key ) diff --git a/qdax/core/containers/unstructured_repertoire.py b/qdax/core/containers/unstructured_repertoire.py index 82e6a7a3..7fb8860a 100644 --- a/qdax/core/containers/unstructured_repertoire.py +++ b/qdax/core/containers/unstructured_repertoire.py @@ -124,8 +124,8 @@ def intra_batch_comp( fitness, ).any() - # Discard Individuals with Nans as their descriptor (mainly for the readdition where we - # have NaN descriptors) + # Discard individuals with nan as their descriptor (mainly for the readdition + # where we have nan descriptors) discard_indiv = jnp.logical_or(discard_indiv, not_existent) # Negate to know if we keep the individual @@ -298,7 +298,9 @@ def add( -1 * batch_of_indices.squeeze(), batch_of_indices.shape[0] )[1] batch_of_indices = jnp.where( - jnp.squeeze(batch_of_distances.at[sorted_descriptors].get() <= self.l_value), + jnp.squeeze( + batch_of_distances.at[sorted_descriptors].get() <= self.l_value + ), batch_of_indices.at[sorted_descriptors].get(), empty_indexes, ) @@ -308,7 +310,7 @@ def add( # ReIndexing of all the inputs to the correct sorted way batch_of_descriptors = batch_of_descriptors.at[sorted_descriptors].get() batch_of_genotypes = jax.tree_util.tree_map( - lambda x: x.at[sorted_bds].get(), batch_of_genotypes + lambda x: x.at[sorted_descriptors].get(), batch_of_genotypes ) batch_of_fitnesses = batch_of_fitnesses.at[sorted_descriptors].get() batch_of_observations = batch_of_observations.at[sorted_descriptors].get() @@ -389,28 +391,25 @@ def add( ) @partial(jax.jit, static_argnames=("num_samples",)) - def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]: + def sample(self, key: RNGKey, num_samples: int) -> Genotype: """Sample elements in the repertoire. 
Args: - random_key: a jax PRNG random key + key: a jax PRNG random key num_samples: the number of elements to be sampled Returns: samples: a batch of genotypes sampled in the repertoire - random_key: an updated jax PRNG random key """ - - random_key, sub_key = jax.random.split(random_key) grid_empty = self.fitnesses == -jnp.inf p = (1.0 - grid_empty) / jnp.sum(1.0 - grid_empty) samples = jax.tree_util.tree_map( - lambda x: jax.random.choice(sub_key, x, shape=(num_samples,), p=p), + lambda x: jax.random.choice(key, x, shape=(num_samples,), p=p), self.genotypes, ) - return samples, random_key + return samples @classmethod def init( diff --git a/qdax/core/distributed_map_elites.py b/qdax/core/distributed_map_elites.py index dbc6522b..d76087b7 100644 --- a/qdax/core/distributed_map_elites.py +++ b/qdax/core/distributed_map_elites.py @@ -20,7 +20,7 @@ def init( self, genotypes: Genotype, centroids: Centroid, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]: """ Initialize a Map-Elites repertoire with an initial population of genotypes. @@ -34,16 +34,14 @@ def init( genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) centroids: tessellation centroids of shape (batch_size, num_descriptors) - random_key: a random key used for stochastic operations. + key: a random key used for stochastic operations. Returns: An initialized MAP-Elite repertoire with the initial state of the emitter, and a random key. """ # score initial genotypes - fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + fitnesses, descriptors, extra_scores = self._scoring_function(genotypes, key) # gather across all devices ( @@ -64,8 +62,8 @@ def init( ) # get initial state of the emitter - emitter_state, random_key = self._emitter.init( - random_key=random_key, + emitter_state = self._emitter.init( + key=key, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, @@ -83,15 +81,15 @@ def init( extra_scores=extra_scores, ) - return repertoire, emitter_state, random_key + return repertoire, emitter_state, key @partial(jax.jit, static_argnames=("self",)) def update( self, repertoire: MapElitesRepertoire, emitter_state: Optional[EmitterState], - random_key: RNGKey, - ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]: + key: RNGKey, + ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]: """Performs one iteration of the MAP-Elites algorithm. 1. 
A batch of genotypes is sampled in the repertoire and the genotypes @@ -105,7 +103,7 @@ def update( Args: repertoire: the MAP-Elites repertoire emitter_state: state of the emitter - random_key: a jax PRNG random key + key: a jax PRNG random key Returns: the updated MAP-Elites repertoire @@ -114,13 +112,12 @@ def update( a new jax PRNG key """ # generate offsprings with the emitter - genotypes, extra_info, random_key = self._emitter.emit( - repertoire, emitter_state, random_key - ) + key, subkey = jax.random.split(key) + genotypes, extra_info = self._emitter.emit(repertoire, emitter_state, subkey) + # scores the offsprings - fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, descriptors, extra_scores = self._scoring_function(genotypes, subkey) # gather across all devices ( @@ -150,7 +147,7 @@ def update( # update the metrics metrics = self._metrics_function(repertoire) - return repertoire, emitter_state, metrics, random_key + return repertoire, emitter_state, metrics def get_distributed_init_fn( self, centroids: Centroid, devices: List[Any] @@ -198,38 +195,38 @@ def _scan_update( """Rewrites the update function in a way that makes it compatible with the jax.lax.scan primitive.""" # unwrap the input - repertoire, emitter_state, random_key = carry + repertoire, emitter_state, key = carry # apply one step of update ( repertoire, emitter_state, metrics, - random_key, + key, ) = self.update( repertoire, emitter_state, - random_key, + key, ) - return (repertoire, emitter_state, random_key), metrics + return (repertoire, emitter_state, key), metrics def update_fn( repertoire: MapElitesRepertoire, emitter_state: Optional[EmitterState], - random_key: RNGKey, + key: RNGKey, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey, Metrics]: """Apply num_iterations of update.""" ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( _scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) - return repertoire, emitter_state, random_key, metrics + return repertoire, emitter_state, key, metrics return jax.pmap(update_fn, devices=devices, axis_name="p") # type: ignore diff --git a/qdax/core/emitters/cma_emitter.py b/qdax/core/emitters/cma_emitter.py index 46d68193..6727cea0 100644 --- a/qdax/core/emitters/cma_emitter.py +++ b/qdax/core/emitters/cma_emitter.py @@ -28,7 +28,7 @@ class CMAEmitterState(EmitterState): Emitter state for the CMA-ME emitter. Args: - random_key: a random key to handle stochastic operations. Used for + key: a random key to handle stochastic operations. Used for state update only, another key is used to emit. This might be subject to refactoring discussions in the future. cmaes_state: state of the underlying CMA-ES algorithm @@ -37,7 +37,7 @@ class CMAEmitterState(EmitterState): emit_count: count the number of emission events. """ - random_key: RNGKey + key: RNGKey cmaes_state: CMAESState previous_fitnesses: Fitness emit_count: int @@ -107,7 +107,7 @@ def batch_size(self) -> int: @partial(jax.jit, static_argnames=("self",)) def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: MapElitesRepertoire, genotypes: Genotype, fitnesses: Fitness, @@ -120,7 +120,7 @@ def init( Args: genotypes: initial genotypes to add to the grid. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: The initial state of the emitter. 
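With this convention, the caller owns the key: `update` splits it once per stochastic call and passes a subkey to `emit` and to the scoring function, neither of which returns a key. Below is a minimal sketch of that consumption pattern; the `emit` and `score` functions are stand-ins for illustration, not the QDax APIs.

```python
import jax
import jax.numpy as jnp

# Stand-in emitter: perturbs the current genotypes with Gaussian noise.
def emit(genotypes, key):
    return genotypes + 0.1 * jax.random.normal(key, shape=genotypes.shape)

# Stand-in scoring function: deterministic fitness, slightly noisy descriptor.
def score(genotypes, key):
    fitnesses = -jnp.sum(genotypes**2, axis=-1)
    noise = 0.01 * jax.random.normal(key, shape=(genotypes.shape[0], 1))
    descriptors = genotypes[:, :1] + noise
    return fitnesses, descriptors

key = jax.random.key(0)
genotypes = jnp.zeros((8, 4))
for _ in range(3):
    # the caller splits once per stochastic call; callees never return a key
    key, emit_key, score_key = jax.random.split(key, 3)
    genotypes = emit(genotypes, emit_key)
    fitnesses, descriptors = score(genotypes, score_key)
```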
@@ -131,15 +131,15 @@ def init( default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) # return the initial state - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) return ( CMAEmitterState( - random_key=subkey, + key=subkey, cmaes_state=self._cma_initial_state, previous_fitnesses=default_fitnesses, emit_count=0, ), - random_key, + key, ) @partial(jax.jit, static_argnames=("self",)) @@ -147,8 +147,8 @@ def emit( self, repertoire: Optional[MapElitesRepertoire], emitter_state: CMAEmitterState, - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """ Emits new individuals. Interestingly, this method does not directly modifies individuals from the repertoire but sample from a distribution. Hence the @@ -157,22 +157,17 @@ def emit( Args: repertoire: a repertoire of genotypes (unused). emitter_state: the state of the CMA-MEGA emitter. - random_key: a random key to handle random operations. + key: a random key to handle random operations. Returns: New genotypes and a new random key. """ # emit from CMA-ES - offsprings, random_key = self._cmaes.sample( - cmaes_state=emitter_state.cmaes_state, random_key=random_key - ) + offsprings = self._cmaes.sample(cmaes_state=emitter_state.cmaes_state, key=key) - return offsprings, {}, random_key + return offsprings, {} - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def state_update( self, emitter_state: CMAEmitterState, @@ -181,7 +176,7 @@ def state_update( fitnesses: Fitness, descriptors: Descriptor, extra_scores: Optional[ExtraScores] = None, - ) -> Optional[EmitterState]: + ) -> CMAEmitterState: """ Updates the CMA-ME emitter state. @@ -250,14 +245,14 @@ def update_and_reinit( operand: Tuple[ CMAESState, CMAEmitterState, MapElitesRepertoire, int, RNGKey ], - ) -> Tuple[CMAEmitterState, RNGKey]: + ) -> CMAEmitterState: return self._update_and_init_emitter_state(*operand) def update_wo_reinit( operand: Tuple[ CMAESState, CMAEmitterState, MapElitesRepertoire, int, RNGKey ], - ) -> Tuple[CMAEmitterState, RNGKey]: + ) -> CMAEmitterState: """Update the emitter when no reinit event happened. Here lies a divergence compared to the original implementation. We @@ -284,7 +279,7 @@ def update_wo_reinit( instability. """ - (cmaes_state, emitter_state, repertoire, emit_count, random_key) = operand + cmaes_state, emitter_state, _, emit_count, _ = operand # Update CMA Parameters mask = jnp.ones_like(sorted_improvements) @@ -298,10 +293,12 @@ def update_wo_reinit( emit_count=emit_count, ) - return emitter_state, random_key + return emitter_state # type: ignore # Update CMA Parameters - emitter_state, random_key = jax.lax.cond( + key = emitter_state.key + key, subkey = jax.random.split(key) + emitter_state = jax.lax.cond( reinitialize, update_and_reinit, update_wo_reinit, @@ -310,13 +307,14 @@ def update_wo_reinit( emitter_state, repertoire, emit_count, - emitter_state.random_key, + subkey, ), ) # update the emitter state emitter_state = emitter_state.replace( - random_key=random_key, previous_fitnesses=repertoire.fitnesses + previous_fitnesses=repertoire.fitnesses, + key=key, ) return emitter_state @@ -327,8 +325,8 @@ def _update_and_init_emitter_state( emitter_state: CMAEmitterState, repertoire: MapElitesRepertoire, emit_count: int, - random_key: RNGKey, - ) -> Tuple[CMAEmitterState, RNGKey]: + key: RNGKey, + ) -> CMAEmitterState: """Update the emitter state in the case of a reinit event. 
Reinit the cmaes state and use an individual from the repertoire as the starting mean. @@ -338,14 +336,14 @@ def _update_and_init_emitter_state( emitter_state: current cmame state repertoire: most recent repertoire emit_count: counter of the emitter - random_key: key to handle stochastic events + key: key to handle stochastic events Returns: The updated emitter state. """ # re-sample - random_genotype, random_key = repertoire.sample(random_key, 1) + random_genotype = repertoire.sample(key, 1) # remove the batch dim new_mean = jax.tree_util.tree_map(lambda x: x.squeeze(0), random_genotype) @@ -356,7 +354,7 @@ def _update_and_init_emitter_state( cmaes_state=cmaes_init_state, emit_count=0 ) - return emitter_state, random_key + return emitter_state # type: ignore @abstractmethod def _ranking_criteria( diff --git a/qdax/core/emitters/cma_mega_emitter.py b/qdax/core/emitters/cma_mega_emitter.py index c3f87fed..a87230ab 100644 --- a/qdax/core/emitters/cma_mega_emitter.py +++ b/qdax/core/emitters/cma_mega_emitter.py @@ -30,7 +30,7 @@ class CMAMEGAState(EmitterState): Args: theta: current genotype from where candidates will be drawn. theta_grads: normalized fitness and descriptors gradients of theta. - random_key: a random key to handle stochastic operations. Used for + key: a random key to handle stochastic operations. Used for state update only, another key is used to emit. This might be subject to refactoring discussions in the future. cmaes_state: state of the underlying CMA-ES algorithm @@ -40,7 +40,7 @@ class CMAMEGAState(EmitterState): theta: Genotype theta_grads: Gradient - random_key: RNGKey + key: RNGKey cmaes_state: CMAESState previous_fitnesses: Fitness @@ -49,7 +49,7 @@ class CMAMEGAEmitter(Emitter): def __init__( self, scoring_function: Callable[ - [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey] + [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores] ], batch_size: int, learning_rate: float, @@ -101,7 +101,7 @@ def __init__( @partial(jax.jit, static_argnames=("self",)) def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: MapElitesRepertoire, genotypes: Genotype, fitnesses: Fitness, @@ -114,7 +114,7 @@ def init( Args: genotypes: initial genotypes to add to the grid. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: The initial state of the emitter. @@ -127,7 +127,7 @@ def init( ) # score it - _, _, extra_score, random_key = self._scoring_function(theta, random_key) + _, _, extra_score = self._scoring_function(theta, key) theta_grads = extra_score["normalized_grads"] # Initialize repertoire with default values @@ -135,16 +135,16 @@ def init( default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) # return the initial state - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) return ( CMAMEGAState( theta=theta, theta_grads=theta_grads, - random_key=subkey, + key=subkey, cmaes_state=self._cma_initial_state, previous_fitnesses=default_fitnesses, ), - random_key, + key, ) @partial(jax.jit, static_argnames=("self",)) @@ -152,8 +152,8 @@ def emit( self, repertoire: Optional[MapElitesRepertoire], emitter_state: CMAMEGAState, - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """ Emits new individuals. Interestingly, this method does not directly modifies individuals from the repertoire but sample from a distribution. 
Hence the @@ -162,10 +162,10 @@ def emit( Args: repertoire: a repertoire of genotypes (unused). emitter_state: the state of the CMA-MEGA emitter. - random_key: a random key to handle random operations. + key: a random key to handle random operations. Returns: - New genotypes and a new random key. + New genotypes. """ # retrieve elements from the emitter state @@ -176,9 +176,7 @@ def emit( grads = jnp.nan_to_num(emitter_state.theta_grads.squeeze(axis=0)) # Draw random coefficients - use the emitter state key - coeffs, random_key = self._cmaes.sample( - cmaes_state=cmaes_state, random_key=emitter_state.random_key - ) + coeffs, key = self._cmaes.sample(cmaes_state=cmaes_state, key=key) # make sure the fitness coefficient is positive coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0])) @@ -187,12 +185,9 @@ def emit( # Compute new candidates new_thetas = jax.tree_util.tree_map(lambda x, y: x + y, theta, update_grad) - return new_thetas, {}, random_key + return new_thetas, {} - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def state_update( self, emitter_state: CMAMEGAState, @@ -251,9 +246,7 @@ def state_update( sorted_indices = jnp.flip(jnp.argsort(ranking_criteria)) # Draw the coeffs - reuse the emitter state key to get same coeffs - coeffs, random_key = self._cmaes.sample( - cmaes_state=cmaes_state, random_key=emitter_state.random_key - ) + coeffs, key = self._cmaes.sample(cmaes_state=cmaes_state, key=emitter_state.key) # make sure the fitness coeff is positive coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0])) @@ -278,7 +271,7 @@ def state_update( ) # re-sample - random_theta, random_key = repertoire.sample(random_key, 1) + random_theta, key = repertoire.sample(key, 1) # update theta in case of reinit theta = jax.tree_util.tree_map( @@ -293,13 +286,13 @@ def state_update( ) # score theta - _, _, extra_score, random_key = self._scoring_function(theta, random_key) + _, _, extra_score = self._scoring_function(theta, key) # create new emitter state emitter_state = CMAMEGAState( theta=theta, theta_grads=extra_score["normalized_grads"], - random_key=random_key, + key=key, cmaes_state=cmaes_state, previous_fitnesses=repertoire.fitnesses, ) diff --git a/qdax/core/emitters/cma_pool_emitter.py b/qdax/core/emitters/cma_pool_emitter.py index 55ccaa4f..d105048d 100644 --- a/qdax/core/emitters/cma_pool_emitter.py +++ b/qdax/core/emitters/cma_pool_emitter.py @@ -50,7 +50,7 @@ def batch_size(self) -> int: @partial(jax.jit, static_argnames=("self",)) def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: MapElitesRepertoire, genotypes: Genotype, fitnesses: Fitness, @@ -63,7 +63,7 @@ def init( Args: genotypes: initial genotypes to add to the grid. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: The initial state of the emitter. 
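Under the new contract, `repertoire.sample` consumes a key and returns only the genotypes, so callers that need several independent draws split the key themselves. A self-contained sketch with a toy array-based repertoire standing in for `MapElitesRepertoire`:

```python
import jax
import jax.numpy as jnp

# Toy repertoire: 10 one-dimensional genotypes, odd-numbered cells empty.
genotypes = jnp.arange(10.0).reshape(10, 1)
fitnesses = jnp.where(jnp.arange(10) % 2 == 0, 1.0, -jnp.inf)

def sample(key, num_samples):
    # uniform over non-empty cells, mirroring the repertoire sampling rule
    empty = fitnesses == -jnp.inf
    p = (1.0 - empty) / jnp.sum(1.0 - empty)
    return jax.random.choice(key, genotypes, shape=(num_samples,), p=p)

key = jax.random.key(0)
key, subkey_1, subkey_2 = jax.random.split(key, 3)
parents_1 = sample(subkey_1, 4)  # two independent draws need two subkeys,
parents_2 = sample(subkey_2, 4)  # since sample no longer refreshes the key
```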
@@ -72,20 +72,21 @@ def init( def scan_emitter_init( carry: RNGKey, unused: Any ) -> Tuple[RNGKey, CMAEmitterState]: - random_key = carry - emitter_state, random_key = self._emitter.init( - random_key, + key = carry + key, subkey = jax.random.split(key) + emitter_state = self._emitter.init( + subkey, repertoire, genotypes, fitnesses, descriptors, extra_scores, ) - return random_key, emitter_state + return key, emitter_state # init all the emitter states - random_key, emitter_states = jax.lax.scan( - scan_emitter_init, random_key, (), length=self._num_states + key, emitter_states = jax.lax.scan( + scan_emitter_init, key, (), length=self._num_states ) # define the emitter state of the pool @@ -95,7 +96,7 @@ def scan_emitter_init( return ( emitter_state, - random_key, + key, ) @partial(jax.jit, static_argnames=("self",)) @@ -103,18 +104,18 @@ def emit( self, repertoire: Optional[MapElitesRepertoire], emitter_state: CMAPoolEmitterState, - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """ Emits new individuals. Args: repertoire: a repertoire of genotypes (unused). emitter_state: the state of the CMA-MEGA emitter. - random_key: a random key to handle random operations. + key: a random key to handle random operations. Returns: - New genotypes and a new random key. + New genotypes and extra infos. """ # retrieve the relevant emitter state @@ -124,16 +125,11 @@ def emit( ) # use it to emit offsprings - offsprings, extra_info, random_key = self._emitter.emit( - repertoire, used_emitter_state, random_key - ) + offsprings, extra_info = self._emitter.emit(repertoire, used_emitter_state, key) - return offsprings, extra_info, random_key + return offsprings, extra_info - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def state_update( self, emitter_state: CMAPoolEmitterState, diff --git a/qdax/core/emitters/cma_rnd_emitter.py b/qdax/core/emitters/cma_rnd_emitter.py index c70e4459..8eed255b 100644 --- a/qdax/core/emitters/cma_rnd_emitter.py +++ b/qdax/core/emitters/cma_rnd_emitter.py @@ -18,7 +18,7 @@ class CMARndEmitterState(CMAEmitterState): Args: - random_key: a random key to handle stochastic operations. Used for + key: a random key to handle stochastic operations. Used for state update only, another key is used to emit. This might be subject to refactoring discussions in the future. cmaes_state: state of the underlying CMA-ES algorithm @@ -36,7 +36,7 @@ class CMARndEmitter(CMAEmitter): @partial(jax.jit, static_argnames=("self",)) def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: MapElitesRepertoire, genotypes: Genotype, fitnesses: Fitness, @@ -49,7 +49,7 @@ def init( Args: genotypes: initial genotypes to add to the grid. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: The initial state of the emitter. 
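The pool initialisation above carries the key through `jax.lax.scan`, splitting it once per sub-emitter. A condensed sketch of that split-inside-scan pattern, with a random vector standing in for a real emitter state:

```python
import jax

# Dummy per-emitter state: a random 3-dimensional vector.
def init_one_state(key):
    return jax.random.normal(key, shape=(3,))

def scan_init(carry, _):
    key = carry
    key, subkey = jax.random.split(key)
    return key, init_one_state(subkey)

num_states = 5
key = jax.random.key(0)
key, states = jax.lax.scan(scan_init, key, (), length=num_states)
# `states` stacks the per-emitter states along a leading axis of size num_states
```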
@@ -60,24 +60,24 @@ def init( default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) # take a random direction - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) random_direction = jax.random.uniform( subkey, shape=(self._centroids.shape[-1],), ) # return the initial state - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) return ( CMARndEmitterState( - random_key=subkey, + key=subkey, cmaes_state=self._cma_initial_state, previous_fitnesses=default_fitnesses, emit_count=0, random_direction=random_direction, ), - random_key, + key, ) def _update_and_init_emitter_state( @@ -86,8 +86,8 @@ def _update_and_init_emitter_state( emitter_state: CMAEmitterState, repertoire: MapElitesRepertoire, emit_count: int, - random_key: RNGKey, - ) -> Tuple[CMAEmitterState, RNGKey]: + key: RNGKey, + ) -> CMAEmitterState: """Update the emitter state in the case of a reinit event. Reinit the cmaes state and use an individual from the repertoire as the starting mean. @@ -97,14 +97,15 @@ def _update_and_init_emitter_state( emitter_state: current cmame state repertoire: most recent repertoire emit_count: counter of the emitter - random_key: key to handle stochastic events + key: key to handle stochastic events Returns: The updated emitter state. """ # re-sample - random_genotype, random_key = repertoire.sample(random_key, 1) + key, subkey = jax.random.split(key) + random_genotype = repertoire.sample(subkey, 1) # get new mean - remove the batch dim new_mean = jax.tree_util.tree_map(lambda x: x.squeeze(0), random_genotype) @@ -113,9 +114,8 @@ def _update_and_init_emitter_state( cmaes_init_state = self._cma_initial_state.replace(mean=new_mean, num_updates=0) # take a new random direction - random_key, subkey = jax.random.split(random_key) random_direction = jax.random.uniform( - subkey, + key, shape=(self._centroids.shape[-1],), ) @@ -125,7 +125,7 @@ def _update_and_init_emitter_state( random_direction=random_direction, ) - return emitter_state, random_key + return emitter_state # type: ignore def _ranking_criteria( self, diff --git a/qdax/core/emitters/dpg_emitter.py b/qdax/core/emitters/dpg_emitter.py index ea921237..85ef83da 100644 --- a/qdax/core/emitters/dpg_emitter.py +++ b/qdax/core/emitters/dpg_emitter.py @@ -79,26 +79,30 @@ def __init__( def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: Repertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, extra_scores: ExtraScores, - ) -> Tuple[DiversityPGEmitterState, RNGKey]: + ) -> DiversityPGEmitterState: """Initializes the emitter state. Args: + key: A random key. + repertoire: The initial repertoire. genotypes: The initial population. - random_key: A random key. + fitnesses: The initial fitnesses of the population. + descriptors: The initial descriptors of the population. + extra_scores: Extra scores coming from the scoring function. Returns: - The initial state of the PGAMEEmitter, a new random key. + The initial state of the PGAMEEmitter. 
""" # init elements of diversity emitter state with QualityEmitterState.init() - diversity_emitter_state, random_key = super().init( - random_key, + diversity_emitter_state = super().init( + key, repertoire, genotypes, fitnesses, @@ -130,7 +134,7 @@ def init( archive=archive, ) - return emitter_state, random_key + return emitter_state @partial(jax.jit, static_argnames=("self",)) def state_update( @@ -183,9 +187,9 @@ def scan_train_critics( # sample transitions ( transitions, - random_key, + key, ) = emitter_state.replay_buffer.sample( - random_key=emitter_state.random_key, + key=emitter_state.key, sample_size=self._config.num_critic_training_steps * self._config.batch_size, ) @@ -242,14 +246,14 @@ def _train_critics( critic_optimizer_state, critic_params, target_critic_params, - random_key, + key, ) = self._update_critic( critic_params=emitter_state.critic_params, target_critic_params=emitter_state.target_critic_params, target_actor_params=emitter_state.target_actor_params, critic_optimizer_state=emitter_state.critic_optimizer_state, transitions=transitions, - random_key=emitter_state.random_key, + key=emitter_state.key, ) # Update greedy policy @@ -282,7 +286,7 @@ def _train_critics( actor_opt_state=policy_optimizer_state, target_critic_params=target_critic_params, target_actor_params=target_actor_params, - random_key=random_key, + key=key, steps=emitter_state.steps + 1, replay_buffer=emitter_state.replay_buffer, ) @@ -332,8 +336,8 @@ def scan_train_policy( ), () # sample transitions - transitions, _random_key = emitter_state.replay_buffer.sample( - random_key=emitter_state.random_key, + transitions, _ = emitter_state.replay_buffer.sample( + key=emitter_state.key, sample_size=self._config.num_pg_training_steps * self._config.batch_size, ) diff --git a/qdax/core/emitters/emitter.py b/qdax/core/emitters/emitter.py index 21139356..7822e78f 100644 --- a/qdax/core/emitters/emitter.py +++ b/qdax/core/emitters/emitter.py @@ -31,33 +31,33 @@ class EmitterState(PyTreeNode): class Emitter(ABC): def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: Repertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, extra_scores: ExtraScores, - ) -> Tuple[Optional[EmitterState], RNGKey]: + ) -> Optional[EmitterState]: """Initialises the state of the emitter. Some emitters do not need a state, in which case, the value None can be outputted. Args: genotypes: The genotypes of the initial population. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: The initial emitter state and a random key. """ - return None, random_key + return None @abstractmethod def emit( self, repertoire: Optional[Repertoire], emitter_state: Optional[EmitterState], - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """Function used to emit a population of offspring by any possible mean. New population can be sampled from a distribution or obtained through mutations of individuals sampled from the repertoire. @@ -66,17 +66,14 @@ def emit( Args: repertoire: a repertoire of genotypes. emitter_state: the state of the emitter. - random_key: a random key to handle random operations. + key: a random key to handle random operations. Returns: A batch of offspring, a new random key. 
""" pass - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def state_update( self, emitter_state: Optional[EmitterState], diff --git a/qdax/core/emitters/mees_emitter.py b/qdax/core/emitters/mees_emitter.py index 4d51326a..d90194a5 100644 --- a/qdax/core/emitters/mees_emitter.py +++ b/qdax/core/emitters/mees_emitter.py @@ -160,7 +160,7 @@ class MEESEmitterState(EmitterState): last_updated_genotypes: used to choose parents from repertoire last_updated_fitnesses: used to choose parents from repertoire last_updated_position: used to choose parents from repertoire - random_key: key to handle stochastic operations + key: key to handle stochastic operations """ initial_optimizer_state: optax.OptState @@ -171,7 +171,7 @@ class MEESEmitterState(EmitterState): last_updated_genotypes: Genotype last_updated_fitnesses: Fitness last_updated_position: jnp.ndarray - random_key: RNGKey + key: RNGKey class MEESEmitter(Emitter): @@ -198,7 +198,7 @@ def __init__( config: MEESConfig, total_generations: int, scoring_fn: Callable[ - [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey] + [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores] ], num_descriptors: int, ) -> None: @@ -232,13 +232,10 @@ def batch_size(self) -> int: """ return 1 - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: MapElitesRepertoire, genotypes: Genotype, fitnesses: Fitness, @@ -249,7 +246,7 @@ def init( Args: genotypes: The initial population. - random_key: A random key. + key: A random key. Returns: The initial state of the MEESEmitter, a new random key. @@ -293,45 +290,38 @@ def init( last_updated_genotypes=last_updated_genotypes, last_updated_fitnesses=last_updated_fitnesses, last_updated_position=0, - random_key=random_key, + key=key, ), - random_key, + key, ) - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit( self, repertoire: MapElitesRepertoire, emitter_state: MEESEmitterState, - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """Return the offspring generated through gradient update. 
Params: repertoire: the MAP-Elites repertoire to sample from emitter_state - random_key: a jax PRNG random key + key: a jax PRNG random key Returns: a new gradient offspring - a new jax PRNG key """ - return emitter_state.offspring, {}, random_key + return emitter_state.offspring, {} - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def _sample_exploit( self, emitter_state: MEESEmitterState, repertoire: MapElitesRepertoire, - random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + key: RNGKey, + ) -> Genotype: """Sample half of the time uniformly from the exploit_num_cell_sample highest-performing cells of the repertoire and half of the time uniformly from the exploit_num_cell_sample highest-performing cells among the @@ -340,18 +330,17 @@ def _sample_exploit( Args: emitter_state: current emitter_state repertoire: the current repertoire - random_key: a jax PRNG random key + key: a jax PRNG random key Returns: samples: a genotype sampled in the repertoire - random_key: an updated jax PRNG random key """ def _sample( - random_key: RNGKey, + key: RNGKey, genotypes: Genotype, fitnesses: Fitness, - ) -> Tuple[Genotype, RNGKey]: + ) -> Genotype: """Sample uniformly from the 2 highest fitness cells.""" max_fitnesses, _ = jax.lax.top_k( @@ -362,14 +351,13 @@ def _sample( ) genotypes_empty = fitnesses < min_fitness p = (1.0 - genotypes_empty) / jnp.sum(1.0 - genotypes_empty) - random_key, subkey = jax.random.split(random_key) samples = jax.tree_util.tree_map( - lambda x: jax.random.choice(subkey, x, shape=(1,), p=p), + lambda x: jax.random.choice(key, x, shape=(1,), p=p), genotypes, ) - return samples, random_key + return samples - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) # Sample p uniformly p = jax.random.uniform(subkey) @@ -383,35 +371,31 @@ def _sample( genotypes=emitter_state.last_updated_genotypes, fitnesses=emitter_state.last_updated_fitnesses, ) - samples, random_key = jax.lax.cond( + samples = jax.lax.cond( p < 0.5, repertoire_sample, last_updated_sample, - random_key, + key, ) - return samples, random_key + return samples - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def _sample_explore( self, emitter_state: MEESEmitterState, repertoire: MapElitesRepertoire, - random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + key: RNGKey, + ) -> Genotype: """Sample uniformly from the explore_num_cell_sample most-novel genotypes. 
Args: emitter_state: current emitter state repertoire: the current genotypes repertoire - random_key: a jax PRNG random key + key: a jax PRNG random key Returns: samples: a genotype sampled in the repertoire - random_key: an updated jax PRNG random key """ # Compute the novelty of all indivs in the archive @@ -429,25 +413,21 @@ def _sample_explore( ) repertoire_empty = novelties < min_novelty p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty) - random_key, subkey = jax.random.split(random_key) samples = jax.tree_util.tree_map( - lambda x: jax.random.choice(subkey, x, shape=(1,), p=p), + lambda x: jax.random.choice(key, x, shape=(1,), p=p), repertoire.genotypes, ) - return samples, random_key + return samples - @partial( - jax.jit, - static_argnames=("self", "scores_fn"), - ) + @partial(jax.jit, static_argnames=("self", "scores_fn")) def _es_emitter( self, parent: Genotype, optimizer_state: optax.OptState, - random_key: RNGKey, + key: RNGKey, scores_fn: Callable[[Fitness, Descriptor], jnp.ndarray], - ) -> Tuple[Genotype, optax.OptState, RNGKey]: + ) -> Tuple[Genotype, optax.OptState]: """Main es component, given a parent and a way to infer the score from the fitnesses and descriptors fo its es-samples, return its approximated-gradient-generated offspring. @@ -456,13 +436,13 @@ def _es_emitter( parent: the considered parent. scores_fn: a function to infer the score of its es-samples from their fitness and descriptors. - random_key + key Returns: - The approximated-gradients-generated offspring and a new random_key. + The approximated-gradients-generated offspring and a new key. """ - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) # Sampling mirror noise total_sample_number = self._config.sample_number @@ -508,9 +488,7 @@ def _es_emitter( ) # Evaluating samples - fitnesses, descriptors, extra_scores, random_key = self._scoring_fn( - samples, random_key - ) + fitnesses, descriptors, extra_scores = self._scoring_fn(samples, key) # Computing rank, with or without normalisation scores = scores_fn(fitnesses, descriptors) @@ -566,12 +544,9 @@ def _es_emitter( ) offspring = optax.apply_updates(parent, offspring_update) - return offspring, optimizer_state, random_key + return offspring, optimizer_state - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def _buffers_update( self, emitter_state: MEESEmitterState, @@ -646,10 +621,7 @@ def _buffers_update( last_updated_position=last_updated_position, ) - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def state_update( self, emitter_state: MEESEmitterState, @@ -698,23 +670,23 @@ def state_update( ) # Select parent and optimizer_state - parent, random_key = jax.lax.cond( + parent, key = jax.lax.cond( sample_new_parent, - lambda emitter_state, repertoire, random_key: jax.lax.cond( + lambda emitter_state, repertoire, key: jax.lax.cond( use_exploration, self._sample_explore, self._sample_exploit, emitter_state, repertoire, - random_key, + key, ), - lambda emitter_state, repertoire, random_key: ( + lambda emitter_state, repertoire, key: ( emitter_state.offspring, - random_key, + key, ), emitter_state, repertoire, - emitter_state.random_key, + emitter_state.key, ) optimizer_state = jax.lax.cond( sample_new_parent, @@ -739,10 +711,10 @@ def exploration_exploitation_scores( return scores # Run es process - offspring, optimizer_state, random_key = self._es_emitter( + offspring, optimizer_state, key = 
self._es_emitter( parent=parent, optimizer_state=optimizer_state, - random_key=random_key, + key=key, scores_fn=exploration_exploitation_scores, ) @@ -750,5 +722,5 @@ def exploration_exploitation_scores( optimizer_state=optimizer_state, offspring=offspring, generation_count=generation_count + 1, - random_key=random_key, + key=key, ) diff --git a/qdax/core/emitters/multi_emitter.py b/qdax/core/emitters/multi_emitter.py index 17cb8ace..ce14c8c1 100644 --- a/qdax/core/emitters/multi_emitter.py +++ b/qdax/core/emitters/multi_emitter.py @@ -57,33 +57,32 @@ def get_indexes_separation_batches( def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: Repertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, extra_scores: ExtraScores, - ) -> Tuple[Optional[EmitterState], RNGKey]: + ) -> Optional[EmitterState]: """ Initialize the state of the emitter. Args: genotypes: The genotypes of the initial population. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: - The initial emitter state and a random key. + The initial emitter state. """ # prepare keys for each emitter - random_key, subkey = jax.random.split(random_key) - subkeys = jax.random.split(subkey, len(self.emitters)) + keys = jax.random.split(key, len(self.emitters)) # init all emitter states - gather them emitter_states = [] - for emitter, subkey_emitter in zip(self.emitters, subkeys): - emitter_state, _ = emitter.init( - subkey_emitter, + for emitter, key_emitter in zip(self.emitters, keys): + emitter_state = emitter.init( + key_emitter, repertoire, genotypes, fitnesses, @@ -92,43 +91,42 @@ def init( ) emitter_states.append(emitter_state) - return MultiEmitterState(tuple(emitter_states)), random_key + return MultiEmitterState(tuple(emitter_states)) @partial(jax.jit, static_argnames=("self",)) def emit( self, repertoire: Optional[Repertoire], emitter_state: Optional[MultiEmitterState], - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """Emit new population. Use all the sub emitters to emit subpopulation and gather them. Args: repertoire: a repertoire of genotypes. emitter_state: the current state of the emitter. - random_key: key for random operations. + key: key for random operations. Returns: - Offsprings and a new random key. + Offsprings. 
""" assert emitter_state is not None assert len(emitter_state.emitter_states) == len(self.emitters) # prepare subkeys for each sub emitter - random_key, subkey = jax.random.split(random_key) - subkeys = jax.random.split(subkey, len(self.emitters)) + keys = jax.random.split(key, len(self.emitters)) # emit from all emitters and gather offsprings all_offsprings = [] all_extra_info: ExtraScores = {} - for emitter, sub_emitter_state, subkey_emitter in zip( + for emitter, sub_emitter_state, key_emitter in zip( self.emitters, emitter_state.emitter_states, - subkeys, + keys, ): - genotype, extra_info, _ = emitter.emit( - repertoire, sub_emitter_state, subkey_emitter + genotype, extra_info = emitter.emit( + repertoire, sub_emitter_state, key_emitter ) batch_size = jax.tree_util.tree_leaves(genotype)[0].shape[0] assert batch_size == emitter.batch_size @@ -139,7 +137,7 @@ def emit( offsprings = jax.tree_util.tree_map( lambda *x: jnp.concatenate(x, axis=0), *all_offsprings ) - return offsprings, all_extra_info, random_key + return offsprings, all_extra_info @partial(jax.jit, static_argnames=("self",)) def state_update( diff --git a/qdax/core/emitters/mutation_operators.py b/qdax/core/emitters/mutation_operators.py index bda2daca..90558a24 100644 --- a/qdax/core/emitters/mutation_operators.py +++ b/qdax/core/emitters/mutation_operators.py @@ -1,7 +1,7 @@ """File defining mutation and crossover functions.""" from functools import partial -from typing import Optional, Tuple +from typing import Optional import jax import jax.numpy as jnp @@ -11,7 +11,7 @@ def _polynomial_mutation( x: jnp.ndarray, - random_key: RNGKey, + key: RNGKey, proportion_to_mutate: float, eta: float, minval: float, @@ -24,7 +24,7 @@ def _polynomial_mutation( Args: x: parameters. - random_key: a random key + key: a random key proportion_to_mutate: the proportion of the given parameters that need to be mutated. eta: the inverse of the power of the mutation applied. @@ -39,7 +39,7 @@ def _polynomial_mutation( num_positions = x.shape[0] positions = jnp.arange(start=0, stop=num_positions) num_positions_to_mutate = int(proportion_to_mutate * num_positions) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) selected_positions = jax.random.choice( key=subkey, a=positions, shape=(num_positions_to_mutate,), replace=False ) @@ -51,7 +51,7 @@ def _polynomial_mutation( mutpow = 1.0 / (1.0 + eta) # Randomly select where to put delta_1 and delta_2 - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) rand = jax.random.uniform( key=subkey, shape=delta_1.shape, @@ -80,18 +80,18 @@ def _polynomial_mutation( def polynomial_mutation( x: Genotype, - random_key: RNGKey, + key: RNGKey, proportion_to_mutate: float, eta: float, minval: float, maxval: float, -) -> Tuple[Genotype, RNGKey]: +) -> Genotype: """ Polynomial mutation over several genotypes Parameters: x: array of genotypes to transform (real values only) - random_key: RNG key for reproducibility. + key: RNG key for reproducibility. Assumed to be of shape (batch_size, genotype_dim) proportion_to_mutate (float): proportion of variables to mutate in each genotype (must be in [0, 1]). @@ -101,11 +101,10 @@ def polynomial_mutation( maxval: maximum value to clip the genotypes. 
Returns: - New genotypes - same shape as input and a new RNG key + New genotypes - same shape as input """ - random_key, subkey = jax.random.split(random_key) batch_size = jax.tree_util.tree_leaves(x)[0].shape[0] - mutation_key = jax.random.split(subkey, num=batch_size) + mutation_keys = jax.random.split(key, num=batch_size) mutation_fn = partial( _polynomial_mutation, proportion_to_mutate=proportion_to_mutate, @@ -114,14 +113,14 @@ def polynomial_mutation( maxval=maxval, ) mutation_fn = jax.vmap(mutation_fn) - x = jax.tree_util.tree_map(lambda x_: mutation_fn(x_, mutation_key), x) - return x, random_key + x = jax.tree_util.tree_map(lambda x_: mutation_fn(x_, mutation_keys), x) + return x def _polynomial_crossover( x1: jnp.ndarray, x2: jnp.ndarray, - random_key: RNGKey, + key: RNGKey, proportion_var_to_change: float, ) -> jnp.ndarray: """ @@ -132,9 +131,7 @@ def _polynomial_crossover( """ num_var_to_change = int(proportion_var_to_change * x1.shape[0]) indices = jnp.arange(start=0, stop=x1.shape[0]) - selected_indices = jax.random.choice( - random_key, indices, shape=(num_var_to_change,) - ) + selected_indices = jax.random.choice(key, indices, shape=(num_var_to_change,)) x = x1.at[selected_indices].set(x2[selected_indices]) return x @@ -142,9 +139,9 @@ def _polynomial_crossover( def polynomial_crossover( x1: Genotype, x2: Genotype, - random_key: RNGKey, + key: RNGKey, proportion_var_to_change: float, -) -> Tuple[Genotype, RNGKey]: +) -> Genotype: """ Crossover over a set of pairs of genotypes. @@ -156,17 +153,15 @@ def polynomial_crossover( Parameters: x1: first batch of genotypes x2: second batch of genotypes - random_key: RNG key for reproducibility + key: RNG key for reproducibility proportion_var_to_change: proportion of variables to exchange between genotypes (must be [0, 1]) Returns: - New genotypes and a new RNG key + New genotypes """ - - random_key, subkey = jax.random.split(random_key) batch_size = jax.tree_util.tree_leaves(x2)[0].shape[0] - crossover_keys = jax.random.split(subkey, num=batch_size) + crossover_keys = jax.random.split(key, num=batch_size) crossover_fn = partial( _polynomial_crossover, proportion_var_to_change=proportion_var_to_change, @@ -176,25 +171,25 @@ def polynomial_crossover( x = jax.tree_util.tree_map( lambda x1_, x2_: crossover_fn(x1_, x2_, crossover_keys), x1, x2 ) - return x, random_key + return x def isoline_variation( x1: Genotype, x2: Genotype, - random_key: RNGKey, + key: RNGKey, iso_sigma: float, line_sigma: float, minval: Optional[float] = None, maxval: Optional[float] = None, -) -> Tuple[Genotype, RNGKey]: +) -> Genotype: """ Iso+Line-DD Variation Operator [1] over a set of pairs of genotypes Parameters: x1 (Genotypes): first batch of genotypes x2 (Genotypes): second batch of genotypes - random_key (RNGKey): RNG key for reproducibility + key (RNGKey): RNG key for reproducibility iso_sigma (float): spread parameter (noise) line_sigma (float): line parameter (direction of the new genotype) minval (float, Optional): minimum value to clip the genotypes @@ -202,7 +197,6 @@ def isoline_variation( Returns: x (Genotypes): new genotypes - random_key (RNGKey): new RNG key [1] Vassiliades, Vassilis, and Jean-Baptiste Mouret. "Discovering the elite hypervolume by leveraging interspecies correlation." 
Proceedings of the Genetic and @@ -210,14 +204,12 @@ def isoline_variation( """ # Computing line_noise - random_key, key_line_noise = jax.random.split(random_key) + key, key_line_noise = jax.random.split(key) batch_size = jax.tree_util.tree_leaves(x1)[0].shape[0] line_noise = jax.random.normal(key_line_noise, shape=(batch_size,)) * line_sigma - def _variation_fn( - x1: jnp.ndarray, x2: jnp.ndarray, random_key: RNGKey - ) -> jnp.ndarray: - iso_noise = jax.random.normal(random_key, shape=x1.shape) * iso_sigma + def _variation_fn(x1: jnp.ndarray, x2: jnp.ndarray, key: RNGKey) -> jnp.ndarray: + iso_noise = jax.random.normal(key, shape=x1.shape) * iso_sigma x = (x1 + iso_noise) + jax.vmap(jnp.multiply)((x2 - x1), line_noise) # Back in bounds if necessary (floating point issues) @@ -227,13 +219,12 @@ def _variation_fn( # create a tree with random keys nb_leaves = len(jax.tree_util.tree_leaves(x1)) - random_key, subkey = jax.random.split(random_key) - subkeys = jax.random.split(subkey, num=nb_leaves) - keys_tree = jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(x1), subkeys) + keys = jax.random.split(key, num=nb_leaves) + keys_tree = jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(x1), keys) # apply isolinedd to each branch of the tree x = jax.tree_util.tree_map( lambda y1, y2, key: _variation_fn(y1, y2, key), x1, x2, keys_tree ) - return x, random_key + return x diff --git a/qdax/core/emitters/omg_mega_emitter.py b/qdax/core/emitters/omg_mega_emitter.py index 580bd151..418a7121 100644 --- a/qdax/core/emitters/omg_mega_emitter.py +++ b/qdax/core/emitters/omg_mega_emitter.py @@ -92,19 +92,19 @@ def __init__( def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: MapElitesRepertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, extra_scores: ExtraScores, - ) -> Tuple[OMGMEGAEmitterState, RNGKey]: + ) -> OMGMEGAEmitterState: """Initialises the state of the emitter. Creates an empty repertoire that will later contain the gradients of the individuals. Args: genotypes: The genotypes of the initial population. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: The initial emitter state. @@ -137,21 +137,15 @@ def init( extra_scores, ) - return ( - OMGMEGAEmitterState(gradients_repertoire=gradients_repertoire), - random_key, - ) + return OMGMEGAEmitterState(gradients_repertoire=gradients_repertoire) - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit( self, repertoire: MapElitesRepertoire, emitter_state: OMGMEGAEmitterState, - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """ OMG emitter function that samples elements in the repertoire and does a gradient update with random coefficients to create new candidates. 
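The variation operators now take a key and return only the offspring. Below is a condensed, self-contained sketch of the Iso+LineDD rule for flat genotype arrays; `isoline_variation_flat` is an illustrative stand-in (the library version additionally handles pytrees and optional clipping to bounds).

```python
import jax
import jax.numpy as jnp

def isoline_variation_flat(x1, x2, key, iso_sigma=0.005, line_sigma=0.05):
    key_iso, key_line = jax.random.split(key)
    # isotropic noise around the first parent
    iso_noise = jax.random.normal(key_iso, shape=x1.shape) * iso_sigma
    # noise along the line joining the two parents
    line_noise = jax.random.normal(key_line, shape=(x1.shape[0], 1)) * line_sigma
    return x1 + iso_noise + line_noise * (x2 - x1)

key = jax.random.key(0)
key, subkey = jax.random.split(key)
parents_1 = jnp.zeros((8, 4))
parents_2 = jnp.ones((8, 4))
offspring = isoline_variation_flat(parents_1, parents_2, subkey)  # shape (8, 4)
```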
@@ -159,22 +153,20 @@ def emit( Args: repertoire: current repertoire emitter_state: current emitter state, contains the gradients - random_key: random key + key: random key Returns: new_genotypes: new candidates to be added to the grid - random_key: updated random key """ # sample genotypes - ( - genotypes, - _, - ) = repertoire.sample(random_key, num_samples=self._batch_size) + key, subkey = jax.random.split(key) + genotypes = repertoire.sample(subkey, num_samples=self._batch_size) # sample gradients - use the same random key for sampling # See class docstrings for discussion about this choice - gradients, random_key = emitter_state.gradients_repertoire.sample( - random_key, num_samples=self._batch_size + key, subkey = jax.random.split(key) + gradients = emitter_state.gradients_repertoire.sample( + subkey, num_samples=self._batch_size ) fitness_gradients = jax.tree_util.tree_map( @@ -195,9 +187,8 @@ def emit( descriptors_gradients = descriptors_gradients / norm_descriptors_gradients # Draw random coefficients - random_key, subkey = jax.random.split(random_key) coeffs = jax.random.multivariate_normal( - subkey, + key, shape=(self._batch_size,), mean=self._mu, cov=self._sigma, @@ -215,12 +206,9 @@ def emit( lambda x, y: x + y, genotypes, update_grad ) - return new_genotypes, {}, random_key + return new_genotypes, {} - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def state_update( self, emitter_state: OMGMEGAEmitterState, diff --git a/qdax/core/emitters/pbt_me_emitter.py b/qdax/core/emitters/pbt_me_emitter.py index 55bded4e..52e60c95 100644 --- a/qdax/core/emitters/pbt_me_emitter.py +++ b/qdax/core/emitters/pbt_me_emitter.py @@ -25,7 +25,7 @@ class PBTEmitterState(EmitterState): replay_buffers: ReplayBuffer env_states: EnvState training_states: PBTTrainingState - random_key: RNGKey + key: RNGKey class PBTEmitterConfig(PyTreeNode): @@ -92,21 +92,21 @@ def __init__( def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: Repertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, extra_scores: ExtraScores, - ) -> Tuple[PBTEmitterState, RNGKey]: + ) -> PBTEmitterState: """Initializes the emitter state. Args: genotypes: The initial population. - random_key: A random key. + key: A random key. Returns: - The initial state of the PGAMEEmitter, a new random key. + The initial state of the PGAMEEmitter. """ observation_size = self._env.observation_size @@ -131,8 +131,8 @@ def init( replay_buffers = replay_buffer_init(transition=dummy_transitions) # Initialise env states - (random_key, subkey1, subkey2) = jax.random.split(random_key, num=3) - env_states = jax.jit(self._env.reset)(rng=subkey1) + key, subkey = jax.random.split(key) + env_states = jax.jit(self._env.reset)(rng=subkey) reshape_fn = jax.jit( lambda tree: jax.tree_util.tree_map( @@ -158,21 +158,18 @@ def init( replay_buffers=replay_buffers, env_states=env_states, training_states=genotypes, - random_key=subkey2, + key=key, ) - return emitter_state, random_key + return emitter_state - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit( self, repertoire: Repertoire, emitter_state: PBTEmitterState, - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """Do a single PGA-ME iteration: train critics and greedy policy, make mutations (evo and pg), score solution, fill replay buffer and insert back in the MAP-Elites grid. 
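Emitters whose state carries its own key (as `PBTEmitterState` and `QualityDCGEmitterState` do) refresh it inside `state_update` and store the fresh key back with `replace`. A minimal sketch of that pattern with a toy flax `PyTreeNode` state; `ToyEmitterState` is illustrative only.

```python
import jax
import jax.numpy as jnp
from flax.struct import PyTreeNode

class ToyEmitterState(PyTreeNode):
    key: jax.Array
    steps: jnp.ndarray

def state_update(state: ToyEmitterState) -> ToyEmitterState:
    # split the stored key, use the subkey for stochastic work,
    # and keep the fresh key in the returned state
    key, subkey = jax.random.split(state.key)
    _ = jax.random.uniform(subkey)
    return state.replace(key=key, steps=state.steps + 1)

state = ToyEmitterState(key=jax.random.key(0), steps=jnp.array(0))
state = state_update(state)
```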
@@ -180,7 +177,7 @@ def emit( Args: repertoire: the current repertoire of genotypes emitter_state: the state of the emitter used - random_key: a random key + key: a random key Returns: A batch of offspring, the new emitter state and a new key. @@ -192,9 +189,10 @@ def emit( # Mutation evo if self._config.ga_population_size_per_device > 0: mutation_ga_batch_size = self._config.ga_population_size_per_device - x1, random_key = repertoire.sample(random_key, mutation_ga_batch_size) - x2, random_key = repertoire.sample(random_key, mutation_ga_batch_size) - x_mutation_ga, random_key = self._variation_fn(x1, x2, random_key) + sample_key_1, sample_key_2, variation_key = jax.random.split(key, 3) + x1 = repertoire.sample(sample_key_1, mutation_ga_batch_size) + x2 = repertoire.sample(sample_key_2, mutation_ga_batch_size) + x_mutation_ga = self._variation_fn(x1, x2, variation_key) # Gather offspring genotypes = jax.tree_util.tree_map( @@ -205,7 +203,7 @@ def emit( else: genotypes = x_mutation_pg - return genotypes, {}, random_key + return genotypes, {} @property def batch_size(self) -> int: @@ -322,8 +320,8 @@ def _loop_fn(i, val): # type: ignore ) all_fitnesses = jnp.ravel(all_fitnesses) all_fitnesses = -jnp.sort(-all_fitnesses) - random_key = emitter_state.random_key - random_key, sub_key = jax.random.split(random_key) + key = emitter_state.key + key, sub_key = jax.random.split(key) best_genotypes = jax.tree_util.tree_map( lambda x: jax.random.choice( sub_key, x, shape=(len(fitnesses),), replace=True @@ -366,8 +364,8 @@ def _loop_fn(i, val): # type: ignore # Replacing with samples from the ME repertoire if self._num_to_replace_from_samples > 0: - me_samples, random_key = repertoire.sample( - random_key, self._config.pg_population_size_per_device + me_samples, key = repertoire.sample( + key, self._config.pg_population_size_per_device ) # Resample hyper-params me_samples = jax.vmap(me_samples.__class__.resample_hyperparams)(me_samples) @@ -407,6 +405,6 @@ def _loop_fn(i, val): # type: ignore training_states=training_states, replay_buffers=replay_buffers, env_states=env_states, - random_key=random_key, + key=key, ) return emitter_state # type: ignore diff --git a/qdax/core/emitters/pbt_variation_operators.py b/qdax/core/emitters/pbt_variation_operators.py index c8537003..13972f0f 100644 --- a/qdax/core/emitters/pbt_variation_operators.py +++ b/qdax/core/emitters/pbt_variation_operators.py @@ -9,7 +9,7 @@ def sac_pbt_variation_fn( training_state1: PBTSacTrainingState, training_state2: PBTSacTrainingState, - random_key: RNGKey, + key: RNGKey, iso_sigma: float, line_sigma: float, ) -> Tuple[PBTSacTrainingState, RNGKey]: @@ -21,7 +21,7 @@ def sac_pbt_variation_fn( Args: training_state1: Training state of first SAC agent. training_state2: Training state of first SAC agent. - random_key: Random key. + key: Random key. iso_sigma: Spread parameter (noise). line_sigma: Line parameter (direction of the new genotype). 
@@ -42,10 +42,10 @@ def sac_pbt_variation_fn( training_state1.alpha_params, training_state2.alpha_params, ) - (policy_params, critic_params, alpha_params), random_key = isoline_variation( + (policy_params, critic_params, alpha_params), key = isoline_variation( x1=(policy_params1, critic_params1, alpha_params1), x2=(policy_params2, critic_params2, alpha_params2), - random_key=random_key, + key=key, iso_sigma=iso_sigma, line_sigma=line_sigma, ) @@ -58,14 +58,14 @@ def sac_pbt_variation_fn( return ( new_training_state, - random_key, + key, ) def td3_pbt_variation_fn( training_state1: PBTTD3TrainingState, training_state2: PBTTD3TrainingState, - random_key: RNGKey, + key: RNGKey, iso_sigma: float, line_sigma: float, ) -> Tuple[PBTTD3TrainingState, RNGKey]: @@ -77,7 +77,7 @@ def td3_pbt_variation_fn( Args: training_state1: Training state of first TD3 agent. training_state2: Training state of first TD3 agent. - random_key: Random key. + key: Random key. iso_sigma: Spread parameter (noise). line_sigma: Line parameter (direction of the new genotype). @@ -97,10 +97,10 @@ def td3_pbt_variation_fn( ( policy_params, critic_params, - ), random_key = isoline_variation( + ), key = isoline_variation( x1=(policy_params1, critic_params1), x2=(policy_params2, critic_params2), - random_key=random_key, + key=key, iso_sigma=iso_sigma, line_sigma=line_sigma, ) @@ -111,5 +111,5 @@ def td3_pbt_variation_fn( return ( new_training_state, - random_key, + key, ) diff --git a/qdax/core/emitters/qdcg_emitter.py b/qdax/core/emitters/qdcg_emitter.py index ab57a19e..e4941e66 100644 --- a/qdax/core/emitters/qdcg_emitter.py +++ b/qdax/core/emitters/qdcg_emitter.py @@ -54,7 +54,7 @@ class QualityDCGEmitterState(EmitterState): target_critic_params: Params target_actor_params: Params replay_buffer: ReplayBuffer - random_key: RNGKey + key: RNGKey steps: jnp.ndarray @@ -127,21 +127,21 @@ def use_all_data(self) -> bool: def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: Repertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, extra_scores: ExtraScores, - ) -> Tuple[QualityDCGEmitterState, RNGKey]: + ) -> QualityDCGEmitterState: """Initializes the emitter state. Args: genotypes: The initial population. - random_key: A random key. + key: A random key. Returns: - The initial state of the PGAMEEmitter, a new random key. + The initial state of the PGAMEEmitter. 
""" observation_size = jax.tree_util.tree_leaves(genotypes)[1].shape[1] @@ -149,7 +149,7 @@ def init( action_size = self._env.action_size # Initialise critic, greedy actor and population - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) fake_obs = jnp.zeros(shape=(observation_size,)) fake_desc = jnp.zeros(shape=(descriptor_size,)) fake_action = jnp.zeros(shape=(action_size,)) @@ -159,7 +159,7 @@ def init( ) target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) actor_params = self._actor_network.init(subkey, obs=fake_obs, desc=fake_desc) target_actor_params = jax.tree_util.tree_map(lambda x: x, actor_params) @@ -191,7 +191,6 @@ def init( replay_buffer = replay_buffer.insert(transitions) # Initial training state - random_key, subkey = jax.random.split(random_key) emitter_state = QualityDCGEmitterState( critic_params=critic_params, critic_opt_state=critic_opt_state, @@ -200,11 +199,11 @@ def init( target_critic_params=target_critic_params, target_actor_params=target_actor_params, replay_buffer=replay_buffer, - random_key=subkey, + key=key, steps=jnp.array(0), ) - return emitter_state, random_key + return emitter_state @partial(jax.jit, static_argnames=("self",)) def _similarity(self, descs_1: Descriptor, descs_2: Descriptor) -> jnp.array: @@ -224,21 +223,16 @@ def _normalize_desc(self, desc: Descriptor) -> Descriptor: return ( 2 * (desc - self._env.descriptor_limits[0]) - / ( - self._env.descriptor_limits[1] - - self._env.descriptor_limits[0] - ) + / (self._env.descriptor_limits[1] - self._env.descriptor_limits[0]) - 1 ) @partial(jax.jit, static_argnames=("self",)) def _unnormalize_desc(self, desc_normalized: Descriptor) -> Descriptor: return 0.5 * ( - self._env.descriptor_limits[1] - - self._env.descriptor_limits[0] + self._env.descriptor_limits[1] - self._env.descriptor_limits[0] ) * desc_normalized + 0.5 * ( - self._env.descriptor_limits[1] - + self._env.descriptor_limits[0] + self._env.descriptor_limits[1] + self._env.descriptor_limits[0] ) @partial(jax.jit, static_argnames=("self",)) @@ -274,39 +268,32 @@ def _compute_equivalent_params_with_desc( actor_dc_params["params"]["Dense_0"]["bias"] = equivalent_bias return actor_dc_params - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit( self, repertoire: Repertoire, emitter_state: QualityDCGEmitterState, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Do a step of PG emission. Args: repertoire: the current repertoire of genotypes emitter_state: the state of the emitter used - random_key: a random key + key: a random key Returns: A batch of offspring, the new emitter state and a new key. 
""" # PG emitter - parents_pg, descs_pg, random_key = repertoire.sample_with_descs( - random_key, self._config.qpg_batch_size + parents_pg, descs_pg, key = repertoire.sample_with_descs( + key, self._config.qpg_batch_size ) genotypes_pg = self.emit_pg(emitter_state, parents_pg, descs_pg) # Actor injection emitter - _, descs_ai, random_key = repertoire.sample_with_descs( - random_key, self._config.ai_batch_size - ) - descs_ai = descs_ai.reshape( - descs_ai.shape[0], self._env.descriptor_length - ) + _, descs_ai, key = repertoire.sample_with_descs(key, self._config.ai_batch_size) + descs_ai = descs_ai.reshape(descs_ai.shape[0], self._env.descriptor_length) genotypes_ai = self.emit_ai(emitter_state, descs_ai) # Concatenate PG and AI genotypes @@ -317,13 +304,10 @@ def emit( return ( genotypes, {"desc_prime": jnp.concatenate([descs_pg, descs_ai], axis=0)}, - random_key, + key, ) - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit_pg( self, emitter_state: QualityDCGEmitterState, @@ -350,10 +334,7 @@ def emit_pg( return offsprings - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit_ai( self, emitter_state: QualityDCGEmitterState, descs: Descriptor ) -> Genotype: @@ -390,10 +371,7 @@ def emit_actor(self, emitter_state: QualityDCGEmitterState) -> Genotype: """ return emitter_state.actor_params - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def state_update( self, emitter_state: QualityDCGEmitterState, @@ -449,8 +427,8 @@ def state_update( emitter_state = emitter_state.replace(replay_buffer=replay_buffer) # sample transitions from the replay buffer - random_key, subkey = jax.random.split(emitter_state.random_key) - transitions, random_key = replay_buffer.sample( + key, subkey = jax.random.split(emitter_state.key) + transitions, key = replay_buffer.sample( subkey, self._config.num_critic_training_steps * self._config.batch_size ) transitions = jax.tree_util.tree_map( @@ -468,7 +446,7 @@ def state_update( rewards=self._similarity(transitions.desc, transitions.desc_prime) * transitions.rewards ) - emitter_state = emitter_state.replace(random_key=random_key) + emitter_state = emitter_state.replace(key=key) def scan_train_critics( carry: QualityDCGEmitterState, @@ -510,14 +488,14 @@ def _train_critics( critic_opt_state, critic_params, target_critic_params, - random_key, + key, ) = self._update_critic( critic_params=emitter_state.critic_params, target_critic_params=emitter_state.target_critic_params, target_actor_params=emitter_state.target_actor_params, critic_opt_state=emitter_state.critic_opt_state, transitions=transitions, - random_key=emitter_state.random_key, + key=emitter_state.key, ) # Update greedy actor @@ -550,7 +528,7 @@ def _train_critics( actor_opt_state=actor_opt_state, target_critic_params=target_critic_params, target_actor_params=target_actor_params, - random_key=random_key, + key=key, steps=emitter_state.steps + 1, ) @@ -564,17 +542,16 @@ def _update_critic( target_actor_params: Params, critic_opt_state: Params, transitions: DCGTransition, - random_key: RNGKey, - ) -> Tuple[Params, Params, Params, RNGKey]: + key: RNGKey, + ) -> Tuple[Params, Params, Params]: # compute loss and gradients - random_key, subkey = jax.random.split(random_key) critic_loss, critic_gradient = jax.value_and_grad(self._critic_loss_fn)( critic_params, target_actor_params, target_critic_params, transitions, - subkey, + key, ) critic_updates, 
critic_opt_state = self._critic_optimizer.update( critic_gradient, critic_opt_state @@ -591,7 +568,7 @@ def _update_critic( critic_params, ) - return critic_opt_state, critic_params, target_critic_params, random_key + return critic_opt_state, critic_params, target_critic_params @partial(jax.jit, static_argnames=("self",)) def _update_actor( @@ -629,10 +606,7 @@ def _update_actor( target_actor_params, ) - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def _mutation_function_pg( self, policy_params: Genotype, @@ -653,8 +627,8 @@ def _mutation_function_pg( The updated params of the neural network. """ # Get transitions - transitions, random_key = emitter_state.replay_buffer.sample( - emitter_state.random_key, + transitions, key = emitter_state.replay_buffer.sample( + emitter_state.key, sample_size=self._config.num_pg_training_steps * self._config.batch_size, ) descs_prime = jnp.tile( @@ -678,8 +652,8 @@ def _mutation_function_pg( transitions, ) - # Replace random_key - emitter_state = emitter_state.replace(random_key=random_key) + # Replace key + emitter_state = emitter_state.replace(key=key) # Define new policy optimizer state policy_opt_state = self._policies_optimizer.init(policy_params) diff --git a/qdax/core/emitters/qpg_emitter.py b/qdax/core/emitters/qpg_emitter.py index c6e2df7e..c09b05eb 100644 --- a/qdax/core/emitters/qpg_emitter.py +++ b/qdax/core/emitters/qpg_emitter.py @@ -54,7 +54,7 @@ class QualityPGEmitterState(EmitterState): target_critic_params: Params target_actor_params: Params replay_buffer: ReplayBuffer - random_key: RNGKey + key: RNGKey steps: jnp.ndarray @@ -120,21 +120,21 @@ def use_all_data(self) -> bool: def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: Repertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, extra_scores: ExtraScores, - ) -> Tuple[QualityPGEmitterState, RNGKey]: + ) -> QualityPGEmitterState: """Initializes the emitter state. Args: genotypes: The initial population. - random_key: A random key. + key: A random key. Returns: - The initial state of the PGAMEEmitter, a new random key. + The initial state of the PGAMEEmitter. """ observation_size = self._env.observation_size @@ -142,7 +142,7 @@ def init( descriptor_size = self._env.state_descriptor_length # Initialise critic, greedy actor and population - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) fake_obs = jnp.zeros(shape=(observation_size,)) fake_action = jnp.zeros(shape=(action_size,)) critic_params = self._critic_network.init( @@ -176,7 +176,6 @@ def init( replay_buffer = replay_buffer.insert(transitions) # Initial training state - random_key, subkey = jax.random.split(random_key) emitter_state = QualityPGEmitterState( critic_params=critic_params, critic_optimizer_state=critic_optimizer_state, @@ -185,28 +184,25 @@ def init( target_critic_params=target_critic_params, target_actor_params=target_actor_params, replay_buffer=replay_buffer, - random_key=subkey, + key=key, steps=jnp.array(0), ) - return emitter_state, random_key + return emitter_state - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit( self, repertoire: Repertoire, emitter_state: QualityPGEmitterState, - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """Do a step of PG emission. 
Args: repertoire: the current repertoire of genotypes emitter_state: the state of the emitter used - random_key: a random key + key: a random key Returns: A batch of offspring, the new emitter state and a new key. @@ -216,7 +212,7 @@ def emit( # sample parents mutation_pg_batch_size = int(batch_size - 1) - parents, random_key = repertoire.sample(random_key, mutation_pg_batch_size) + parents = repertoire.sample(key, mutation_pg_batch_size) # apply the pg mutation offsprings_pg = self.emit_pg(emitter_state, parents) @@ -236,12 +232,9 @@ def emit( offspring_actor, ) - return genotypes, {}, random_key + return genotypes, {} - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit_pg( self, emitter_state: QualityPGEmitterState, parents: Genotype ) -> Genotype: @@ -264,10 +257,7 @@ def emit_pg( return offsprings - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit_actor(self, emitter_state: QualityPGEmitterState) -> Genotype: """Emit the greedy actor. @@ -357,10 +347,10 @@ def _train_critics( """ # Sample a batch of transitions in the buffer - random_key = emitter_state.random_key + key = emitter_state.key replay_buffer = emitter_state.replay_buffer - transitions, random_key = replay_buffer.sample( - random_key, sample_size=self._config.batch_size + transitions, key = replay_buffer.sample( + key, sample_size=self._config.batch_size ) # Update Critic @@ -368,14 +358,14 @@ def _train_critics( critic_optimizer_state, critic_params, target_critic_params, - random_key, + key, ) = self._update_critic( critic_params=emitter_state.critic_params, target_critic_params=emitter_state.target_critic_params, target_actor_params=emitter_state.target_actor_params, critic_optimizer_state=emitter_state.critic_optimizer_state, transitions=transitions, - random_key=random_key, + key=key, ) # Update greedy actor @@ -408,7 +398,7 @@ def _train_critics( actor_opt_state=actor_optimizer_state, target_critic_params=target_critic_params, target_actor_params=target_actor_params, - random_key=random_key, + key=key, steps=emitter_state.steps + 1, replay_buffer=replay_buffer, ) @@ -423,17 +413,16 @@ def _update_critic( target_actor_params: Params, critic_optimizer_state: Params, transitions: QDTransition, - random_key: RNGKey, - ) -> Tuple[Params, Params, Params, RNGKey]: + key: RNGKey, + ) -> Tuple[Params, Params, Params]: # compute loss and gradients - random_key, subkey = jax.random.split(random_key) critic_loss, critic_gradient = jax.value_and_grad(self._critic_loss_fn)( critic_params, target_actor_params, target_critic_params, transitions, - subkey, + key, ) critic_updates, critic_optimizer_state = self._critic_optimizer.update( critic_gradient, critic_optimizer_state @@ -450,7 +439,7 @@ def _update_critic( critic_params, ) - return critic_optimizer_state, critic_params, target_critic_params, random_key + return critic_optimizer_state, critic_params, target_critic_params @partial(jax.jit, static_argnames=("self",)) def _update_actor( @@ -563,10 +552,10 @@ def _train_policy( """ # Sample a batch of transitions in the buffer - random_key = emitter_state.random_key + key = emitter_state.key replay_buffer = emitter_state.replay_buffer - transitions, random_key = replay_buffer.sample( - random_key, sample_size=self._config.batch_size + transitions, key = replay_buffer.sample( + key, sample_size=self._config.batch_size ) # update policy @@ -579,7 +568,7 @@ def _train_policy( # Create new training state 
new_emitter_state = emitter_state.replace( - random_key=random_key, + key=key, replay_buffer=replay_buffer, ) diff --git a/qdax/core/emitters/standard_emitters.py b/qdax/core/emitters/standard_emitters.py index 1d949b2d..6f532c17 100644 --- a/qdax/core/emitters/standard_emitters.py +++ b/qdax/core/emitters/standard_emitters.py @@ -22,16 +22,13 @@ def __init__( self._variation_percentage = variation_percentage self._batch_size = batch_size - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit( self, repertoire: Repertoire, emitter_state: Optional[EmitterState], - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """ Emitter that performs both mutation and variation. Two batches of variation_percentage * batch_size genotypes are sampled in the repertoire, @@ -45,24 +42,24 @@ def emit( Params: repertoire: the MAP-Elites repertoire to sample from emitter_state: void - random_key: a jax PRNG random key + key: a jax PRNG random key Returns: a batch of offsprings - a new jax PRNG key """ n_variation = int(self._batch_size * self._variation_percentage) n_mutation = self._batch_size - n_variation if n_variation > 0: - x1, random_key = repertoire.sample(random_key, n_variation) - x2, random_key = repertoire.sample(random_key, n_variation) - - x_variation, random_key = self._variation_fn(x1, x2, random_key) + sample_key_1, sample_key_2, variation_key = jax.random.split(key, 3) + x1 = repertoire.sample(sample_key_1, n_variation) + x2 = repertoire.sample(sample_key_2, n_variation) + x_variation = self._variation_fn(x1, x2, variation_key) if n_mutation > 0: - x1, random_key = repertoire.sample(random_key, n_mutation) - x_mutation, random_key = self._mutation_fn(x1, random_key) + sample_key, mutation_key = jax.random.split(key) + x1 = repertoire.sample(sample_key, n_mutation) + x_mutation, key = self._mutation_fn(x1, mutation_key) if n_variation == 0: genotypes = x_mutation @@ -75,7 +72,7 @@ def emit( x_mutation, ) - return genotypes, {}, random_key + return genotypes, {} @property def batch_size(self) -> int: diff --git a/qdax/core/map_elites.py b/qdax/core/map_elites.py index d0b075a9..7489cb5e 100644 --- a/qdax/core/map_elites.py +++ b/qdax/core/map_elites.py @@ -41,7 +41,7 @@ class MAPElites: def __init__( self, scoring_function: Callable[ - [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey] + [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores] ], emitter: Emitter, metrics_function: Callable[[MapElitesRepertoire], Metrics], @@ -55,8 +55,8 @@ def init( self, genotypes: Genotype, centroids: Centroid, - random_key: RNGKey, - ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]: + key: RNGKey, + ) -> Tuple[MapElitesRepertoire, Optional[EmitterState]]: """ Initialize a Map-Elites repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method @@ -66,16 +66,15 @@ def init( genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) centroids: tesselation centroids of shape (batch_size, num_descriptors) - random_key: a random key used for stochastic operations. + key: a random key used for stochastic operations. Returns: An initialized MAP-Elite repertoire with the initial state of the emitter, and a random key. 
""" # score initial genotypes - fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, descriptors, extra_scores = self._scoring_function(genotypes, subkey) # init the repertoire repertoire = MapElitesRepertoire.init( @@ -87,8 +86,9 @@ def init( ) # get initial state of the emitter - emitter_state, random_key = self._emitter.init( - random_key=random_key, + key, subkey = jax.random.split(key) + emitter_state = self._emitter.init( + key=subkey, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, @@ -96,15 +96,15 @@ def init( extra_scores=extra_scores, ) - return repertoire, emitter_state, random_key + return repertoire, emitter_state @partial(jax.jit, static_argnames=("self",)) def update( self, repertoire: MapElitesRepertoire, emitter_state: Optional[EmitterState], - random_key: RNGKey, - ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]: + key: RNGKey, + ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]: """ Performs one iteration of the MAP-Elites algorithm. 1. A batch of genotypes is sampled in the repertoire and the genotypes @@ -116,7 +116,7 @@ def update( Args: repertoire: the MAP-Elites repertoire emitter_state: state of the emitter - random_key: a jax PRNG random key + key: a jax PRNG random key Returns: the updated MAP-Elites repertoire @@ -125,14 +125,12 @@ def update( a new jax PRNG key """ # generate offsprings with the emitter - genotypes, extra_info, random_key = self._emitter.emit( - repertoire, emitter_state, random_key - ) + key, subkey = jax.random.split(key) + genotypes, extra_info = self._emitter.emit(repertoire, emitter_state, subkey) # scores the offsprings - fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, descriptors, extra_scores = self._scoring_function(genotypes, subkey) # add genotypes in the repertoire repertoire = repertoire.add(genotypes, descriptors, fitnesses, extra_scores) @@ -150,7 +148,7 @@ def update( # update the metrics metrics = self._metrics_function(repertoire) - return repertoire, emitter_state, metrics, random_key + return repertoire, emitter_state, metrics @partial(jax.jit, static_argnames=("self",)) def scan_update( @@ -169,16 +167,16 @@ def scan_update( Returns: The updated repertoire and emitter state, with a new random key and metrics. """ - repertoire, emitter_state, random_key = carry + repertoire, emitter_state, key = carry + key, subkey = jax.random.split(key) ( repertoire, emitter_state, metrics, - random_key, ) = self.update( repertoire, emitter_state, - random_key, + subkey, ) - return (repertoire, emitter_state, random_key), metrics + return (repertoire, emitter_state, key), metrics diff --git a/qdax/core/mels.py b/qdax/core/mels.py index 8b0e7511..36fa1be3 100644 --- a/qdax/core/mels.py +++ b/qdax/core/mels.py @@ -58,8 +58,8 @@ def init( self, genotypes: Genotype, centroids: Centroid, - random_key: RNGKey, - ) -> Tuple[MELSRepertoire, Optional[EmitterState], RNGKey]: + key: RNGKey, + ) -> Tuple[MELSRepertoire, Optional[EmitterState]]: """Initialize a MAP-Elites Low-Spread repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping. 
@@ -68,16 +68,15 @@ def init( genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) centroids: tessellation centroids of shape (batch_size, num_descriptors) - random_key: a random key used for stochastic operations. + key: a random key used for stochastic operations. Returns: A tuple of (initialized MAP-Elites Low-Spread repertoire, initial emitter state, JAX random key). """ # score initial genotypes - fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, descriptors, extra_scores = self._scoring_function(genotypes, subkey) # init the repertoire repertoire = MELSRepertoire.init( @@ -89,8 +88,9 @@ def init( ) # get initial state of the emitter - emitter_state, random_key = self._emitter.init( - random_key=random_key, + key, subkey = jax.random.split(key) + emitter_state = self._emitter.init( + key=subkey, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, @@ -107,4 +107,4 @@ def init( descriptors=descriptors, extra_scores=extra_scores, ) - return repertoire, emitter_state, random_key + return repertoire, emitter_state diff --git a/qdax/core/mome.py b/qdax/core/mome.py index c239bd1f..bd81f36f 100644 --- a/qdax/core/mome.py +++ b/qdax/core/mome.py @@ -26,8 +26,8 @@ def init( genotypes: jnp.ndarray, centroids: Centroid, pareto_front_max_length: int, - random_key: RNGKey, - ) -> Tuple[MOMERepertoire, Optional[EmitterState], RNGKey]: + key: RNGKey, + ) -> Tuple[MOMERepertoire, Optional[EmitterState]]: """Initialize a MOME grid with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping. @@ -37,16 +37,15 @@ def init( centroids: centroids of the repertoire. pareto_front_max_length: maximum size of the pareto front. This is necessary to respect jax.jit fixed shape size constraint. - random_key: a random key to handle stochasticity. + key: a random key to handle stochasticity. Returns: The initial repertoire and emitter state, and a new random key. """ # first score - fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, descriptors, extra_scores = self._scoring_function(genotypes, subkey) # init the repertoire repertoire = MOMERepertoire.init( @@ -59,8 +58,9 @@ def init( ) # get initial state of the emitter - emitter_state, random_key = self._emitter.init( - random_key=random_key, + key, subkey = jax.random.split(key) + emitter_state = self._emitter.init( + key=subkey, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, @@ -68,4 +68,4 @@ def init( extra_scores=extra_scores, ) - return repertoire, emitter_state, random_key + return repertoire, emitter_state diff --git a/qdax/core/neuroevolution/buffers/buffer.py b/qdax/core/neuroevolution/buffers/buffer.py index 5057e5e2..8fa4af84 100644 --- a/qdax/core/neuroevolution/buffers/buffer.py +++ b/qdax/core/neuroevolution/buffers/buffer.py @@ -466,22 +466,21 @@ def init( @partial(jax.jit, static_argnames=("sample_size",)) def sample( self, - random_key: RNGKey, + key: RNGKey, sample_size: int, ) -> Tuple[Transition, RNGKey]: """ Sample a batch of transitions in the replay buffer. 
""" - random_key, subkey = jax.random.split(random_key) idx = jax.random.randint( - subkey, + key, shape=(sample_size,), minval=0, maxval=self.current_size, ) samples = jnp.take(self.data, idx, axis=0, mode="clip") transitions = self.transition.__class__.from_flatten(samples, self.transition) - return transitions, random_key + return transitions @jax.jit def insert(self, transitions: Transition) -> ReplayBuffer: diff --git a/qdax/core/neuroevolution/buffers/trajectory_buffer.py b/qdax/core/neuroevolution/buffers/trajectory_buffer.py index 93e1b2f9..aca96b56 100644 --- a/qdax/core/neuroevolution/buffers/trajectory_buffer.py +++ b/qdax/core/neuroevolution/buffers/trajectory_buffer.py @@ -119,9 +119,9 @@ def init( # type: ignore @partial(jax.jit, static_argnames=("sample_size")) def sample( self, - random_key: RNGKey, + key: RNGKey, sample_size: int, - ) -> Tuple[Transition, RNGKey]: + ) -> Transition: """ Sample transitions from the buffer. If sample_traj=False, returns stacked transitions in the shape (sample_size,), if sample_traj=True, return transitions @@ -130,9 +130,8 @@ def sample( # Here we want to sample single transitions # We sample uniformly at random the indexes of valid transitions - random_key, subkey = jax.random.split(random_key) idx = jax.random.randint( - subkey, + key, shape=(sample_size,), minval=0, maxval=self.current_size, @@ -142,28 +141,27 @@ def sample( # (sample_size, concat_dim) transitions = self.transition.__class__.from_flatten(samples, self.transition) - return transitions, random_key + return transitions def sample_with_returns( self, - random_key: RNGKey, + key: RNGKey, sample_size: int, - ) -> Tuple[Transition, Reward, RNGKey]: + ) -> Tuple[Transition, Reward]: """Sample transitions and the return corresponding to their episode. The returns are compute by the method `compute_returns`. Args: - random_key: a random key + key: a random key sample_size: the number of transitions Returns: - The transitions, the associated returns and a new random key. + The transitions, the associated returns. """ # Here we want to sample single transitions # We sample uniformly at random the indexes of valid transitions - random_key, subkey = jax.random.split(random_key) idx = jax.random.randint( - subkey, + key, shape=(sample_size,), minval=0, maxval=self.current_size, @@ -173,7 +171,7 @@ def sample_with_returns( returns = jnp.take(self.returns, idx, mode="clip") # (sample_size, concat_dim) transitions = self.transition.__class__.from_flatten(samples, self.transition) - return transitions, returns, random_key + return transitions, returns @jax.jit def insert(self, transitions: Transition) -> TrajectoryBuffer: diff --git a/qdax/core/neuroevolution/losses/sac_loss.py b/qdax/core/neuroevolution/losses/sac_loss.py index d7289292..6662eb7f 100644 --- a/qdax/core/neuroevolution/losses/sac_loss.py +++ b/qdax/core/neuroevolution/losses/sac_loss.py @@ -71,7 +71,7 @@ def sac_policy_loss_fn( critic_params: Params, alpha: jnp.ndarray, transitions: Transition, - random_key: RNGKey, + key: RNGKey, ) -> jnp.ndarray: """ Creates the policy loss used in SAC. 
@@ -84,16 +84,14 @@ def sac_policy_loss_fn( critic_params: parameters of the critic alpha: entropy coefficient value transitions: transitions collected by the agent - random_key: random key + key: random key Returns: the loss of the policy """ dist_params = policy_fn(policy_params, transitions.obs) - action = parametric_action_distribution.sample_no_postprocessing( - dist_params, random_key - ) + action = parametric_action_distribution.sample_no_postprocessing(dist_params, key) log_prob = parametric_action_distribution.log_prob(dist_params, action) action = parametric_action_distribution.postprocess(action) q_action = critic_fn(critic_params, transitions.obs, action) @@ -114,7 +112,7 @@ def sac_critic_loss_fn( target_critic_params: Params, alpha: jnp.ndarray, transitions: Transition, - random_key: RNGKey, + key: RNGKey, ) -> jnp.ndarray: """ Creates the critic loss used in SAC. @@ -128,7 +126,7 @@ def sac_critic_loss_fn( target_critic_params: parameters of the target critic alpha: entropy coefficient value transitions: transitions collected by the agent - random_key: random key + key: random key reward_scaling: a multiplicative factor to the reward discount: the discount factor @@ -139,7 +137,7 @@ def sac_critic_loss_fn( q_old_action = critic_fn(critic_params, transitions.obs, transitions.actions) next_dist_params = policy_fn(policy_params, transitions.next_obs) next_action = parametric_action_distribution.sample_no_postprocessing( - next_dist_params, random_key + next_dist_params, key ) next_log_prob = parametric_action_distribution.log_prob( next_dist_params, next_action @@ -168,7 +166,7 @@ def sac_alpha_loss_fn( action_size: int, policy_params: Params, transitions: Transition, - random_key: RNGKey, + key: RNGKey, ) -> jnp.ndarray: """ Creates the alpha loss used in SAC. 
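Because the SAC losses above now consume the key passed to them instead of splitting it internally, a training step typically splits its key once into one subkey per stochastic loss evaluation. A sketch of that split is below; the loss calls are kept as comments because they assume hypothetical closures over the network functions and hyperparameters.

```python
import jax

key = jax.random.key(0)

# One subkey per stochastic loss evaluation.
key, policy_key, critic_key, alpha_key = jax.random.split(key, 4)

# Each loss then receives exactly one of these keys, e.g. (hypothetical closures,
# with the key always passed as the last argument, as in the signatures above):
# policy_loss = policy_loss_fn(policy_params, ..., transitions, policy_key)
# critic_loss = critic_loss_fn(critic_params, ..., transitions, critic_key)
# alpha_loss  = alpha_loss_fn(log_alpha, ..., transitions, alpha_key)
```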
@@ -180,7 +178,7 @@ def sac_alpha_loss_fn( parametric_action_distribution: the distribution over actions policy_params: parameters of the policy transitions: transitions collected by the agent - random_key: random key + key: random key action_size: the size of the environment's action space Returns: @@ -190,9 +188,7 @@ def sac_alpha_loss_fn( target_entropy = -0.5 * action_size dist_params = policy_fn(policy_params, transitions.obs) - action = parametric_action_distribution.sample_no_postprocessing( - dist_params, random_key - ) + action = parametric_action_distribution.sample_no_postprocessing(dist_params, key) log_prob = parametric_action_distribution.log_prob(dist_params, action) alpha = jnp.exp(log_alpha) alpha_loss = alpha * jax.lax.stop_gradient(-log_prob - target_entropy) diff --git a/qdax/core/neuroevolution/losses/td3_loss.py b/qdax/core/neuroevolution/losses/td3_loss.py index 964c2c4f..28fd9eb2 100644 --- a/qdax/core/neuroevolution/losses/td3_loss.py +++ b/qdax/core/neuroevolution/losses/td3_loss.py @@ -56,12 +56,11 @@ def _critic_loss_fn( target_policy_params: Params, target_critic_params: Params, transitions: Transition, - random_key: RNGKey, + key: RNGKey, ) -> jnp.ndarray: """Critics loss function for TD3 agent""" noise = ( - jax.random.normal(random_key, shape=transitions.actions.shape) - * policy_noise + jax.random.normal(key, shape=transitions.actions.shape) * policy_noise ).clip(-noise_clip, noise_clip) next_action = ( @@ -158,12 +157,11 @@ def _critic_loss_fn( target_actor_params: Params, target_critic_params: Params, transitions: Transition, - random_key: RNGKey, + key: RNGKey, ) -> jnp.ndarray: """Descriptor-conditioned critic loss function for TD3 agent""" noise = ( - jax.random.normal(random_key, shape=transitions.actions.shape) - * policy_noise + jax.random.normal(key, shape=transitions.actions.shape) * policy_noise ).clip(-noise_clip, noise_clip) next_action = ( @@ -236,7 +234,7 @@ def td3_critic_loss_fn( reward_scaling: float, discount: float, transitions: Transition, - random_key: RNGKey, + key: RNGKey, ) -> jnp.ndarray: """Critics loss function for TD3 agent. @@ -256,7 +254,7 @@ def td3_critic_loss_fn( Return the loss function used to train the critic in TD3. """ noise = ( - jax.random.normal(random_key, shape=transitions.actions.shape) * policy_noise + jax.random.normal(key, shape=transitions.actions.shape) * policy_noise ).clip(-noise_clip, noise_clip) next_action = (policy_fn(target_policy_params, transitions.next_obs) + noise).clip( diff --git a/qdax/core/neuroevolution/mdp_utils.py b/qdax/core/neuroevolution/mdp_utils.py index f269a22b..9de23216 100644 --- a/qdax/core/neuroevolution/mdp_utils.py +++ b/qdax/core/neuroevolution/mdp_utils.py @@ -26,7 +26,7 @@ class TrainingState(PyTreeNode): def generate_unroll( init_state: EnvState, policy_params: Params, - random_key: RNGKey, + key: RNGKey, episode_length: int, play_step_fn: Callable[ [EnvState, Params, RNGKey], @@ -44,7 +44,7 @@ def generate_unroll( Args: init_state: first state of the rollout. policy_params: params of the individual. - random_key: random key for stochasiticity handling. + key: random key for stochasiticity handling. episode_length: length of the rollout. play_step_fn: function describing how a step need to be taken. 
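The TD3 critic losses above draw their target-policy smoothing noise directly from the key they receive. That clipped-Gaussian computation is small enough to reproduce standalone; the action array and hyperparameter values below are illustrative.

```python
import jax
import jax.numpy as jnp

actions = jnp.zeros((4, 2))          # stand-in for transitions.actions
policy_noise, noise_clip = 0.2, 0.5

key = jax.random.key(0)
# Clipped Gaussian noise added to the target policy's actions, as in the losses above.
noise = (
    jax.random.normal(key, shape=actions.shape) * policy_noise
).clip(-noise_clip, noise_clip)
```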
@@ -55,12 +55,12 @@ def generate_unroll( def _scan_play_step_fn( carry: Tuple[EnvState, Params, RNGKey], unused_arg: Any ) -> Tuple[Tuple[EnvState, Params, RNGKey], Transition]: - env_state, policy_params, random_key, transitions = play_step_fn(*carry) - return (env_state, policy_params, random_key), transitions + env_state, policy_params, key, transitions = play_step_fn(*carry) + return (env_state, policy_params, key), transitions (state, _, _), transitions = jax.lax.scan( _scan_play_step_fn, - (init_state, policy_params, random_key), + (init_state, policy_params, key), (), length=episode_length, ) @@ -72,7 +72,7 @@ def generate_unroll_actor_dc( init_state: EnvState, actor_dc_params: Params, desc: Descriptor, - random_key: RNGKey, + key: RNGKey, episode_length: int, play_step_actor_dc_fn: Callable[ [EnvState, Descriptor, Params, RNGKey], @@ -92,7 +92,7 @@ def generate_unroll_actor_dc( init_state: first state of the rollout. policy_dc_params: descriptor-conditioned policy params. desc: descriptor the policy attempts to achieve. - random_key: random key for stochasiticity handling. + key: random key for stochasiticity handling. episode_length: length of the rollout. play_step_fn: function describing how a step need to be taken. @@ -107,14 +107,14 @@ def _scan_play_step_fn( env_state, actor_dc_params, desc, - random_key, + key, transitions, ) = play_step_actor_dc_fn(*carry) - return (env_state, actor_dc_params, desc, random_key), transitions + return (env_state, actor_dc_params, desc, key), transitions (state, _, _, _), transitions = jax.lax.scan( _scan_play_step_fn, - (init_state, actor_dc_params, desc, random_key), + (init_state, actor_dc_params, desc, key), (), length=episode_length, ) @@ -141,8 +141,8 @@ def init_population_controllers( policy_network: nn.Module, env: brax.envs.Env, batch_size: int, - random_key: RNGKey, -) -> Tuple[Genotype, RNGKey]: + key: RNGKey, +) -> Genotype: """ Initializes the population of controllers using a policy_network. @@ -151,15 +151,13 @@ def init_population_controllers( controllers. env: the BRAX environment. batch_size: the number of environments we play simultaneously. - random_key: a JAX random key. + key: a JAX random key. Returns: A tuple of the initial population and the new random key. """ - random_key, subkey = jax.random.split(random_key) - - keys = jax.random.split(subkey, num=batch_size) + keys = jax.random.split(key, num=batch_size) fake_batch = jnp.zeros(shape=(batch_size, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) - return init_variables, random_key + return init_variables diff --git a/qdax/core/neuroevolution/networks/seq2seq_networks.py b/qdax/core/neuroevolution/networks/seq2seq_networks.py index 3cb52a3e..0fe121a1 100644 --- a/qdax/core/neuroevolution/networks/seq2seq_networks.py +++ b/qdax/core/neuroevolution/networks/seq2seq_networks.py @@ -52,7 +52,7 @@ def select_carried_state(new_state: Array, old_state: Array) -> Array: def initialize_carry(batch_size: int, hidden_size: int) -> Tuple[Array, Array]: # Use a dummy key since the default state init fn is just zeros. 
return nn.LSTMCell(hidden_size, parent=None).initialize_carry( # type: ignore - jax.random.PRNGKey(0), (batch_size, hidden_size) + jax.random.key(0), (batch_size, hidden_size) ) diff --git a/qdax/tasks/README.md b/qdax/tasks/README.md index 2d07cab7..826c3ad9 100644 --- a/qdax/tasks/README.md +++ b/qdax/tasks/README.md @@ -19,7 +19,7 @@ Notes: import jax from qdax.tasks.arm import arm_scoring_function -random_key = jax.random.PRNGKey(0) +key = jax.random.key(0) # Get scoring function scoring_fn = arm_scoring_function @@ -31,7 +31,7 @@ min_desc, max_desc = 0., 1. # Get initial batch of parameters num_param_dimensions = ... init_batch_size = ... -random_key, _subkey = jax.random.split(random_key) +key, _subkey = jax.random.split(key) initial_params = jax.random.uniform( _subkey, shape=(init_batch_size, num_param_dimensions), @@ -56,7 +56,7 @@ desc_size = 2 import jax from qdax.tasks.standard_functions import sphere_scoring_function -random_key = jax.random.PRNGKey(0) +key = jax.random.key(0) # Get scoring function scoring_fn = sphere_scoring_function @@ -68,7 +68,7 @@ min_desc, max_desc = 0., 1. # Get initial batch of parameters num_param_dimensions = ... init_batch_size = ... -random_key, _subkey = jax.random.split(random_key) +key, _subkey = jax.random.split(key) initial_params = jax.random.uniform( _subkey, shape=(init_batch_size, num_param_dimensions), @@ -98,7 +98,7 @@ desc_size = 2 import jax from qdax.tasks.hypervolume_functions import square_scoring_function -random_key = jax.random.PRNGKey(0) +key = jax.random.key(0) # Get scoring function scoring_fn = square_scoring_function @@ -110,7 +110,7 @@ min_desc, max_desc = 0., 1. # Get initial batch of parameters num_param_dimensions = ... init_batch_size = ... -random_key, _subkey = jax.random.split(random_key) +key, _subkey = jax.random.split(key) initial_params = jax.random.uniform( _subkey, shape=(init_batch_size, num_param_dimensions), diff --git a/qdax/tasks/arm.py b/qdax/tasks/arm.py index b6eb9e45..92614d74 100644 --- a/qdax/tasks/arm.py +++ b/qdax/tasks/arm.py @@ -19,8 +19,8 @@ def arm(params: Genotype) -> Tuple[Fitness, Descriptor]: Returns: f: the fitness of the individual, given as the variance of the angles. - descriptor: the descriptor of the individual, given as the [x, y] position of the - end-effector of the arm. + descriptor: the descriptor of the individual, given as the [x, y] position + of the end-effector of the arm. Descriptor is normalized to [0, 1] regardless of the DoF. Arm is centered at 0.5, 0.5. """ @@ -40,33 +40,28 @@ def arm(params: Genotype) -> Tuple[Fitness, Descriptor]: def arm_scoring_function( params: Genotype, - random_key: RNGKey, -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + key: RNGKey, +) -> Tuple[Fitness, Descriptor, ExtraScores]: """ Evaluate policies contained in params in parallel. """ fitnesses, descriptors = jax.vmap(arm)(params) - return ( - fitnesses, - descriptors, - {}, - random_key, - ) + return fitnesses, descriptors, {} def noisy_arm_scoring_function( params: Genotype, - random_key: RNGKey, + key: RNGKey, fit_variance: float, desc_variance: float, params_variance: float, -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: +) -> Tuple[Fitness, Descriptor, ExtraScores]: """ Evaluate policies contained in params in parallel. 
""" - random_key, f_subkey, d_subkey, p_subkey = jax.random.split(random_key, num=4) + key, f_subkey, d_subkey, p_subkey = jax.random.split(key, num=4) # Add noise to the parameters params = params + jax.random.normal(p_subkey, shape=params.shape) * params_variance @@ -83,9 +78,4 @@ def noisy_arm_scoring_function( + jax.random.normal(d_subkey, shape=descriptors.shape) * desc_variance ) - return ( - fitnesses, - descriptors, - {}, - random_key, - ) + return fitnesses, descriptors, {} diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index 8abf6b00..84edbb46 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -34,8 +34,8 @@ def make_policy_network_play_step_fn_brax( Creates a function that when called, plays a step of the environment. Args: - env: The BRAX environment. - policy_network: The policy network structure used for creating and evaluating + env: The Brax environment. + policy_network: The policy network structure used for creating and evaluating policy controllers. Returns: @@ -46,19 +46,20 @@ def make_policy_network_play_step_fn_brax( def default_play_step_fn( env_state: EnvState, policy_params: Params, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: """ Play an environment step and return the updated EnvState and the transition. - Args: env_state: The state of the environment (containing for instance the - actor joint positions and velocities, the reward...). policy_params: The - parameters of policies/controllers. random_key: JAX random key. + Args: + env_state: The state of the environment (containing for instance the + actor joint positions and velocities, the reward...). + policy_params: The parameters of policies/controllers. key: JAX random key. Returns: - next_state: The updated environment state. + next_env_state: The updated environment state. policy_params: The parameters of policies/controllers (unchanged). - random_key: The updated random key. + key: The updated random key. transition: containing some information about the transition: observation, reward, next observation, policy action... 
""" @@ -66,20 +67,20 @@ def default_play_step_fn( actions = policy_network.apply(policy_params, env_state.obs) state_desc = env_state.info["state_descriptor"] - next_state = env.step(env_state, actions) + next_env_state = env.step(env_state, actions) transition = QDTransition( obs=env_state.obs, - next_obs=next_state.obs, - rewards=next_state.reward, - dones=next_state.done, + next_obs=next_env_state.obs, + rewards=next_env_state.reward, + dones=next_env_state.done, actions=actions, - truncations=next_state.info["truncation"], + truncations=next_env_state.info["truncation"], state_desc=state_desc, - next_state_desc=next_state.info["state_descriptor"], + next_state_desc=next_env_state.info["state_descriptor"], ) - return next_state, policy_params, random_key, transition + return next_env_state, policy_params, key, transition return default_play_step_fn @@ -103,14 +104,14 @@ def get_mask_from_transitions( ) def scoring_function_brax_envs( policies_params: Genotype, - random_key: RNGKey, + key: RNGKey, init_states: EnvState, episode_length: int, play_step_fn: Callable[ [EnvState, Params, RNGKey], Tuple[EnvState, Params, RNGKey, QDTransition] ], descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: +) -> Tuple[Fitness, Descriptor, ExtraScores]: """Evaluates policies contained in policies_params in parallel in deterministic or pseudo-deterministic environments. @@ -121,7 +122,7 @@ def scoring_function_brax_envs( Args: policies_params: The parameters of closed-loop controllers/policies to evaluate. - random_key: A jax random key + key: A jax random key episode_length: The maximal rollout length. play_step_fn: The function to play a step of the environment. descriptor_extractor: The function to extract the descriptor. @@ -130,16 +131,16 @@ def scoring_function_brax_envs( fitness: Array of fitnesses of all evaluated policies descriptor: Behavioural descriptors of all evaluated policies extra_scores: Additional information resulting from evaluation - random_key: The updated random key. + key: The updated random key. """ # Perform rollouts with each policy - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) unroll_fn = partial( generate_unroll, episode_length=episode_length, play_step_fn=play_step_fn, - random_key=subkey, + key=subkey, ) _final_state, data = jax.vmap(unroll_fn)(init_states, policies_params) @@ -151,14 +152,7 @@ def scoring_function_brax_envs( fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) descriptors = descriptor_extractor(data, mask) - return ( - fitnesses, - descriptors, - { - "transitions": data, - }, - random_key, - ) + return fitnesses, descriptors, {"transitions": data} @partial( @@ -172,7 +166,7 @@ def scoring_function_brax_envs( def scoring_actor_dc_function_brax_envs( actors_dc_params: Genotype, descs: Descriptor, - random_key: RNGKey, + key: RNGKey, init_states: EnvState, episode_length: int, play_step_actor_dc_fn: Callable[ @@ -180,7 +174,7 @@ def scoring_actor_dc_function_brax_envs( Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition], ], descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: +) -> Tuple[Fitness, Descriptor, ExtraScores]: """Evaluates policies contained in policy_dc_params in parallel in deterministic or pseudo-deterministic environments. @@ -194,7 +188,7 @@ def scoring_actor_dc_function_brax_envs( descriptor-conditioned policy to evaluate. 
descriptors: The descriptors the descriptor-conditioned policy attempts to achieve. - random_key: A jax random key + key: A jax random key episode_length: The maximal rollout length. play_step_fn: The function to play a step of the environment. descriptor_extractor: The function to extract the descriptor. @@ -203,16 +197,16 @@ def scoring_actor_dc_function_brax_envs( fitness: Array of fitnesses of all evaluated policies descriptor: Behavioural descriptors of all evaluated policies extra_scores: Additional information resulting from evaluation - random_key: The updated random key. + key: The updated random key. """ # Perform rollouts with each policy - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) unroll_fn = partial( generate_unroll_actor_dc, episode_length=episode_length, play_step_actor_dc_fn=play_step_actor_dc_fn, - random_key=subkey, + key=subkey, ) _final_state, data = jax.vmap(unroll_fn)(init_states, actors_dc_params, descs) @@ -226,14 +220,7 @@ def scoring_actor_dc_function_brax_envs( fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) descriptors = descriptor_extractor(data, mask) - return ( - fitnesses, - descriptors, - { - "transitions": data, - }, - random_key, - ) + return fitnesses, descriptors, {"transitions": data} @partial( @@ -247,14 +234,14 @@ def scoring_actor_dc_function_brax_envs( ) def reset_based_scoring_function_brax_envs( policies_params: Genotype, - random_key: RNGKey, + key: RNGKey, episode_length: int, play_reset_fn: Callable[[RNGKey], EnvState], play_step_fn: Callable[ [EnvState, Params, RNGKey], Tuple[EnvState, Params, RNGKey, QDTransition] ], descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: +) -> Tuple[Fitness, Descriptor, ExtraScores]: """Evaluates policies contained in policies_params in parallel. The play_reset_fn function allows for a more general scoring_function that can be called with different batch-size and not only with a batch-size of the same @@ -264,12 +251,12 @@ def reset_based_scoring_function_brax_envs( environment, use "play_reset_fn = env.reset". To define purely deterministic environments, as in "scoring_function", generate - a single init_state using "init_state = env.reset(random_key)", then use - "play_reset_fn = lambda random_key: init_state". + a single init_state using "init_state = env.reset(key)", then use + "play_reset_fn = lambda key: init_state". Args: policies_params: The parameters of closed-loop controllers/policies to evaluate. - random_key: A jax random key + key: A jax random key episode_length: The maximal rollout length. play_reset_fn: The function to reset the environment and obtain initial states. play_step_fn: The function to play a step of the environment. @@ -279,26 +266,25 @@ def reset_based_scoring_function_brax_envs( fitness: Array of fitnesses of all evaluated policies descriptor: Behavioural descriptors of all evaluated policies extra_scores: Additional information resulting from the evaluation - random_key: The updated random key. 
""" - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jax.random.split( subkey, jax.tree_util.tree_leaves(policies_params)[0].shape[0] ) reset_fn = jax.vmap(play_reset_fn) init_states = reset_fn(keys) - fitnesses, descriptors, extra_scores, random_key = scoring_function_brax_envs( + fitnesses, descriptors, extra_scores, key = scoring_function_brax_envs( policies_params=policies_params, - random_key=random_key, + key=key, init_states=init_states, episode_length=episode_length, play_step_fn=play_step_fn, descriptor_extractor=descriptor_extractor, ) - return fitnesses, descriptors, extra_scores, random_key + return fitnesses, descriptors, extra_scores @partial( @@ -313,7 +299,7 @@ def reset_based_scoring_function_brax_envs( def reset_based_scoring_actor_dc_function_brax_envs( actors_dc_params: Genotype, descs: Descriptor, - random_key: RNGKey, + key: RNGKey, episode_length: int, play_reset_fn: Callable[[RNGKey], EnvState], play_step_actor_dc_fn: Callable[ @@ -321,7 +307,7 @@ def reset_based_scoring_actor_dc_function_brax_envs( Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition], ], descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: +) -> Tuple[Fitness, Descriptor, ExtraScores]: """Evaluates policies contained in policy_dc_params in parallel. The play_reset_fn function allows for a more general scoring_function that can be called with different batch-size and not only with a batch-size of the same @@ -331,15 +317,15 @@ def reset_based_scoring_actor_dc_function_brax_envs( environment, use "play_reset_fn = env.reset". To define purely deterministic environments, as in "scoring_function", generate - a single init_state using "init_state = env.reset(random_key)", then use - "play_reset_fn = lambda random_key: init_state". + a single init_state using "init_state = env.reset(key)", then use + "play_reset_fn = lambda key: init_state". Args: policy_dc_params: The parameters of closed-loop descriptor-conditioned policy to evaluate. descriptors: The descriptors the descriptor-conditioned policy attempts to achieve. - random_key: A jax random key + key: A jax random key episode_length: The maximal rollout length. play_reset_fn: The function to reset the environment and obtain initial states. @@ -350,10 +336,10 @@ def reset_based_scoring_actor_dc_function_brax_envs( fitness: Array of fitnesses of all evaluated policies descriptor: Behavioural descriptors of all evaluated policies extra_scores: Additional information resulting from the evaluation - random_key: The updated random key. + key: The updated random key. 
""" - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jax.random.split( subkey, jax.tree_util.tree_leaves(actors_dc_params)[0].shape[0] ) @@ -364,25 +350,25 @@ def reset_based_scoring_actor_dc_function_brax_envs( fitnesses, descriptors, extra_scores, - random_key, + key, ) = scoring_actor_dc_function_brax_envs( actors_dc_params=actors_dc_params, descs=descs, - random_key=random_key, + key=key, init_states=init_states, episode_length=episode_length, play_step_actor_dc_fn=play_step_actor_dc_fn, descriptor_extractor=descriptor_extractor, ) - return fitnesses, descriptors, extra_scores, random_key + return fitnesses, descriptors, extra_scores def create_brax_scoring_fn( env: brax.envs.Env, policy_network: nn.Module, descriptor_extraction_fn: Callable[[QDTransition, jnp.ndarray], Descriptor], - random_key: RNGKey, + key: RNGKey, play_step_fn: Optional[ Callable[ [EnvState, Params, RNGKey], Tuple[EnvState, Params, RNGKey, QDTransition] @@ -391,18 +377,15 @@ def create_brax_scoring_fn( episode_length: int = 100, deterministic: bool = True, play_reset_fn: Optional[Callable[[RNGKey], EnvState]] = None, -) -> Tuple[ - Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]], - RNGKey, -]: +) -> Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores]]: """ Creates a scoring function to evaluate a policy in a BRAX task. Args: - env: The BRAX environment. + env: The Brax environment. policy_network: The policy network controller. descriptor_extraction_fn: The behaviour descriptor extraction function. - random_key: a random key used for stochastic operations. + key: a random key used for stochastic operations. play_step_fn: the function used to perform environment rollouts and collect evaluation episodes. If None, we use make_policy_network_play_step_fn_brax to generate it. @@ -416,7 +399,6 @@ def create_brax_scoring_fn( Returns: The scoring function: a function that takes a batch of genotypes and compute their fitnesses and descriptors - The updated random key. """ if play_step_fn is None: play_step_fn = make_policy_network_play_step_fn_brax(env, policy_network) @@ -424,7 +406,7 @@ def create_brax_scoring_fn( # Deterministic case if deterministic: # Create the initial environment states - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) init_state = env.reset(subkey) # Define the function to deterministically reset the environment @@ -445,27 +427,26 @@ def deterministic_reset(_: RNGKey, _init_state: EnvState) -> EnvState: descriptor_extractor=descriptor_extraction_fn, ) - return scoring_fn, random_key + return scoring_fn def create_default_brax_task_components( env_name: str, - random_key: RNGKey, + key: RNGKey, episode_length: int = 100, mlp_policy_hidden_layer_sizes: Tuple[int, ...] = (64, 64), deterministic: bool = True, ) -> Tuple[ brax.envs.Env, MLP, - Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]], - RNGKey, + Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores]], ]: """ Creates default environment, policy network and scoring function for a BRAX task. Args: env_name: Name of the BRAX environment (e.g. "ant_omni", "walker2d_uni"...). - random_key: Jax random key + key: Jax random key episode_length: The maximal rollout length. mlp_policy_hidden_layer_sizes: Hidden layer sizes of the policy network. 
deterministic: Whether we reset the initial state of the robot to the same @@ -477,7 +458,6 @@ def create_default_brax_task_components( policy controllers. scoring_fn: a function that takes a batch of genotypes and compute their fitnesses and descriptors. - random_key: The updated random key. """ env = environments.create(env_name, episode_length=episode_length) @@ -491,26 +471,22 @@ def create_default_brax_task_components( descriptor_extraction_fn = qdax.environments.descriptor_extractor[env_name] - scoring_fn, random_key = create_brax_scoring_fn( + scoring_fn = create_brax_scoring_fn( env, policy_network, descriptor_extraction_fn, - random_key, + key, episode_length=episode_length, deterministic=deterministic, ) - return env, policy_network, scoring_fn, random_key + return env, policy_network, scoring_fn def get_aurora_scoring_fn( - scoring_fn: Callable[ - [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey] - ], + scoring_fn: Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores]], observation_extractor_fn: Callable[[Transition], Observation], -) -> Callable[ - [Genotype, RNGKey], Tuple[Fitness, Optional[Descriptor], ExtraScores, RNGKey] -]: +) -> Callable[[Genotype, RNGKey], Tuple[Fitness, Optional[Descriptor], ExtraScores]]: """Evaluates policies contained in flatten_variables in parallel This rollout is only deterministic when all the init states are the same. @@ -525,12 +501,12 @@ def get_aurora_scoring_fn( @functools.wraps(scoring_fn) def _wrapper( - params: Params, random_key: RNGKey # Perform rollouts with each policy - ) -> Tuple[Fitness, Optional[Descriptor], ExtraScores, RNGKey]: - fitnesses, _, extra_scores, random_key = scoring_fn(params, random_key) + params: Params, key: RNGKey # Perform rollouts with each policy + ) -> Tuple[Fitness, Optional[Descriptor], ExtraScores]: + fitnesses, _, extra_scores = scoring_fn(params, key) data = extra_scores["transitions"] observation = observation_extractor_fn(data) # type: ignore extra_scores["last_valid_observations"] = observation - return fitnesses, None, extra_scores, random_key + return fitnesses, None, extra_scores return _wrapper diff --git a/qdax/tasks/hypervolume_functions.py b/qdax/tasks/hypervolume_functions.py index e56cfdb7..0e883893 100644 --- a/qdax/tasks/hypervolume_functions.py +++ b/qdax/tasks/hypervolume_functions.py @@ -73,17 +73,18 @@ def continous_islands(params: Genotype) -> Tuple[Fitness, Descriptor]: def get_scoring_function( task_fn: Callable[[Genotype], Tuple[Fitness, Descriptor]] -) -> Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]]: +) -> Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores]]: + def scoring_function( params: Genotype, - random_key: RNGKey, - ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Fitness, Descriptor, ExtraScores]: """ Evaluate params in parallel """ fitnesses, descriptors = jax.vmap(task_fn)(params) - return (fitnesses, descriptors, {}, random_key) + return fitnesses, descriptors, {} return scoring_function diff --git a/qdax/tasks/jumanji_envs.py b/qdax/tasks/jumanji_envs.py index f003962d..42248d68 100644 --- a/qdax/tasks/jumanji_envs.py +++ b/qdax/tasks/jumanji_envs.py @@ -47,7 +47,7 @@ def default_play_step_fn( env_state: JumanjiState, timestep: JumanjiTimeStep, policy_params: Params, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[JumanjiState, JumanjiTimeStep, Params, RNGKey, QDTransition]: """Play an environment step and return the updated state and the transition. 
Everything is deterministic in this simple example. @@ -75,7 +75,7 @@ def default_play_step_fn( next_state_desc=next_state_desc, ) - return next_state, next_timestep, policy_params, random_key, transition + return next_state, next_timestep, policy_params, key, transition return default_play_step_fn @@ -85,7 +85,7 @@ def generate_jumanji_unroll( init_state: JumanjiState, init_timestep: JumanjiTimeStep, policy_params: Params, - random_key: RNGKey, + key: RNGKey, episode_length: int, play_step_fn: Callable[ [JumanjiState, JumanjiTimeStep, Params, RNGKey], @@ -104,7 +104,7 @@ def generate_jumanji_unroll( Args: init_state: first state of the rollout. policy_params: params of the individual. - random_key: random key for stochasiticity handling. + key: random key for stochasiticity handling. episode_length: length of the rollout. play_step_fn: function describing how a step need to be taken. @@ -115,14 +115,12 @@ def generate_jumanji_unroll( def _scan_play_step_fn( carry: Tuple[JumanjiState, JumanjiTimeStep, Params, RNGKey], unused_arg: Any ) -> Tuple[Tuple[JumanjiState, JumanjiTimeStep, Params, RNGKey], Transition]: - env_state, timestep, policy_params, random_key, transitions = play_step_fn( - *carry - ) - return (env_state, timestep, policy_params, random_key), transitions + env_state, timestep, policy_params, key, transitions = play_step_fn(*carry) + return (env_state, timestep, policy_params, key), transitions (state, timestep, _, _), transitions = jax.lax.scan( _scan_play_step_fn, - (init_state, init_timestep, policy_params, random_key), + (init_state, init_timestep, policy_params, key), (), length=episode_length, ) @@ -139,7 +137,7 @@ def _scan_play_step_fn( ) def jumanji_scoring_function( policies_params: Genotype, - random_key: RNGKey, + key: RNGKey, init_states: JumanjiState, init_timesteps: JumanjiTimeStep, episode_length: int, @@ -159,12 +157,12 @@ def jumanji_scoring_function( """ # Perform rollouts with each policy - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) unroll_fn = partial( generate_jumanji_unroll, episode_length=episode_length, play_step_fn=play_step_fn, - random_key=subkey, + key=subkey, ) _final_state, _final_timestep, data = jax.vmap(unroll_fn)( @@ -186,5 +184,5 @@ def jumanji_scoring_function( { "transitions": data, }, - random_key, + key, ) diff --git a/qdax/tasks/qd_suite/archimedean_spiral.py b/qdax/tasks/qd_suite/archimedean_spiral.py index 04aec1e8..7f1fba99 100644 --- a/qdax/tasks/qd_suite/archimedean_spiral.py +++ b/qdax/tasks/qd_suite/archimedean_spiral.py @@ -34,8 +34,8 @@ def __init__( Args: parameterization: The parameterization of the genotype, can be either angle or arc length. - archimedean_descriptor: The Archimedean Descriptor, can be either euclidean or - geodesic. + archimedean_descriptor: The Archimedean descriptor, can be either euclidean + or geodesic. amplitude: The amplitude of the Archimedean spiral. precision: The precision of the approximation of the angle from the arc length. 
diff --git a/qdax/tasks/qd_suite/qd_suite_task.py b/qdax/tasks/qd_suite/qd_suite_task.py index fc5fcfed..fde6fc4a 100644 --- a/qdax/tasks/qd_suite/qd_suite_task.py +++ b/qdax/tasks/qd_suite/qd_suite_task.py @@ -25,14 +25,14 @@ def evaluation(self, params: Genotype) -> Tuple[Fitness, Descriptor]: def scoring_function( self, params: Genotype, - random_key: RNGKey, - ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Fitness, Descriptor, ExtraScores]: """ Evaluate params in parallel """ fitnesses, descriptors = jax.vmap(self.evaluation)(params) - return fitnesses, descriptors, {}, random_key + return fitnesses, descriptors, {} @abc.abstractmethod def get_descriptor_size(self) -> int: diff --git a/qdax/tasks/standard_functions.py b/qdax/tasks/standard_functions.py index e089fe35..de850c1b 100644 --- a/qdax/tasks/standard_functions.py +++ b/qdax/tasks/standard_functions.py @@ -26,26 +26,26 @@ def sphere(params: Genotype) -> Tuple[Fitness, Descriptor]: def rastrigin_scoring_function( params: Genotype, - random_key: RNGKey, -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + key: RNGKey, +) -> Tuple[Fitness, Descriptor, ExtraScores]: """ Scoring function for the rastrigin function """ fitnesses, descriptors = jax.vmap(rastrigin)(params) - return fitnesses, descriptors, {}, random_key + return fitnesses, descriptors, {} def sphere_scoring_function( params: Genotype, - random_key: RNGKey, -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + key: RNGKey, +) -> Tuple[Fitness, Descriptor, ExtraScores]: """ Scoring function for the sphere function """ fitnesses, descriptors = jax.vmap(sphere)(params) - return fitnesses, descriptors, {}, random_key + return fitnesses, descriptors, {} def _rastrigin_proj_scoring( @@ -105,8 +105,8 @@ def rastrigin_descriptors(x: jnp.ndarray) -> jnp.ndarray: def rastrigin_proj_scoring_function( - params: Genotype, random_key: RNGKey, minval: float = -5.12, maxval: float = 5.12 -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + params: Genotype, key: RNGKey, minval: float = -5.12, maxval: float = 5.12 +) -> Tuple[Fitness, Descriptor, ExtraScores]: """ Scoring function for the rastrigin function with a folding of the behaviour space. @@ -117,4 +117,4 @@ def rastrigin_proj_scoring_function( _rastrigin_proj_scoring, in_axes=(0, None, None) )(params, minval, maxval) - return fitnesses, descriptors, extra_scores, random_key + return fitnesses, descriptors, extra_scores diff --git a/qdax/utils/sampling.py b/qdax/utils/sampling.py index be1d336d..b2a3cd0a 100644 --- a/qdax/utils/sampling.py +++ b/qdax/utils/sampling.py @@ -115,22 +115,16 @@ def dummy_extra_scores_extractor( return extra_scores -@partial( - jax.jit, - static_argnames=( - "scoring_fn", - "num_samples", - ), -) +@partial(jax.jit, static_argnames=("scoring_fn", "num_samples")) def multi_sample_scoring_function( policies_params: Genotype, - random_key: RNGKey, + key: RNGKey, scoring_fn: Callable[ [Genotype, RNGKey], - Tuple[Fitness, Descriptor, ExtraScores, RNGKey], + Tuple[Fitness, Descriptor, ExtraScores], ], num_samples: int, -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: +) -> Tuple[Fitness, Descriptor, ExtraScores]: """ Wrap scoring_function to perform sampling. 
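The deterministic tasks above (QD-suite, sphere, rastrigin) all follow the same recipe under the updated signature: vmap a per-genotype evaluation and return `(fitnesses, descriptors, extra_scores)` with no key in the output. A self-contained sketch with a hypothetical `my_task` evaluation (not part of QDax):

```python
from typing import Dict, Tuple

import jax
import jax.numpy as jnp


def my_task(params: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    # Hypothetical per-genotype evaluation: negative squared norm as fitness,
    # first two parameters as the descriptor.
    return -jnp.sum(params**2), params[:2]


def my_scoring_function(
    params: jnp.ndarray, key: jax.Array
) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
    # Deterministic task: the key is accepted for API compatibility but unused.
    fitnesses, descriptors = jax.vmap(my_task)(params)
    return fitnesses, descriptors, {}


key = jax.random.key(0)
genotypes = jax.random.uniform(key, shape=(8, 4))
fitnesses, descriptors, extra_scores = my_scoring_function(genotypes, key)
print(fitnesses.shape, descriptors.shape)  # (8,) (8, 2)
```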
@@ -139,19 +133,16 @@ def multi_sample_scoring_function( Args: policies_params: policies to evaluate - random_key: JAX random key + key: JAX random key scoring_fn: scoring function used for evaluation num_samples: number of samples to generate for each individual Returns: (n, num_samples) array of fitnesses, (n, num_samples, num_descriptors) array of descriptors, - dict with num_samples extra_scores per individual, - JAX random key + dict with num_samples extra_scores per individual """ - - random_key, subkey = jax.random.split(random_key) - keys = jax.random.split(subkey, num=num_samples) + keys = jax.random.split(key, num=num_samples) # evaluate sample_scoring_fn = jax.vmap( @@ -166,7 +157,7 @@ def multi_sample_scoring_function( policies_params, keys ) - return all_fitnesses, all_descriptors, all_extra_scores, random_key + return all_fitnesses, all_descriptors, all_extra_scores @partial( @@ -181,10 +172,10 @@ def multi_sample_scoring_function( ) def sampling( policies_params: Genotype, - random_key: RNGKey, + key: RNGKey, scoring_fn: Callable[ [Genotype, RNGKey], - Tuple[Fitness, Descriptor, ExtraScores, RNGKey], + Tuple[Fitness, Descriptor, ExtraScores], ], num_samples: int, extra_scores_extractor: Callable[ @@ -192,7 +183,7 @@ def sampling( ] = dummy_extra_scores_extractor, fitness_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average, descriptor_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average, -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: +) -> Tuple[Fitness, Descriptor, ExtraScores]: """Wrap scoring_function to perform sampling. This function return the expected fitnesses and descriptors for each @@ -201,7 +192,7 @@ def sampling( Args: policies_params: policies to evaluate - random_key: JAX random key + key: JAX random key scoring_fn: scoring function used for evaluation num_samples: number of samples to generate for each individual extra_scores_extractor: function to extract the extra_scores from @@ -221,17 +212,14 @@ def sampling( all_fitnesses, all_descriptors, all_extra_scores, - random_key, - ) = multi_sample_scoring_function( - policies_params, random_key, scoring_fn, num_samples - ) + ) = multi_sample_scoring_function(policies_params, key, scoring_fn, num_samples) # Extract final scores descriptors = descriptor_extractor(all_descriptors) fitnesses = fitness_extractor(all_fitnesses) extra_scores = extra_scores_extractor(all_extra_scores, num_samples) - return fitnesses, descriptors, extra_scores, random_key + return fitnesses, descriptors, extra_scores @partial( @@ -248,7 +236,7 @@ def sampling( ) def sampling_reproducibility( policies_params: Genotype, - random_key: RNGKey, + key: RNGKey, scoring_fn: Callable[ [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey], @@ -271,7 +259,7 @@ def sampling_reproducibility( Args: policies_params: policies to evaluate - random_key: JAX random key + key: JAX random key scoring_fn: scoring function used for evaluation num_samples: number of samples to generate for each individual extra_scores_extractor: function to extract the extra_scores from @@ -296,10 +284,8 @@ def sampling_reproducibility( all_fitnesses, all_descriptors, all_extra_scores, - random_key, - ) = multi_sample_scoring_function( - policies_params, random_key, scoring_fn, num_samples - ) + key, + ) = multi_sample_scoring_function(policies_params, key, scoring_fn, num_samples) # Extract final scores descriptors = descriptor_extractor(all_descriptors) @@ -316,5 +302,5 @@ def sampling_reproducibility( extra_scores, fitnesses_reproducibility, 
descriptors_reproducibility, - random_key, + key, ) diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py index 4b7189b7..27d8be4f 100644 --- a/qdax/utils/train_seq2seq.py +++ b/qdax/utils/train_seq2seq.py @@ -40,17 +40,17 @@ def get_model( def get_initial_params( - model: Seq2seq, random_key: PRNGKey, encoder_input_shape: Tuple[int, ...] + model: Seq2seq, key: PRNGKey, encoder_input_shape: Tuple[int, ...] ) -> Dict[str, Any]: """ Returns the initial parameters of a seq2seq model. Args: model: the seq2seq model. - random_key: the random number generator. + key: the random number generator. encoder_input_shape: the shape of the encoder input. """ - random_key, rng1, rng2, rng3 = jax.random.split(random_key, 4) + key, rng1, rng2, rng3 = jax.random.split(key, 4) variables = model.init( {"params": rng1, "lstm": rng2, "dropout": rng3}, jnp.ones(encoder_input_shape, jnp.float32), @@ -63,7 +63,7 @@ def get_initial_params( def train_step( state: train_state.TrainState, batch: Array, - lstm_random_key: PRNGKey, + key: PRNGKey, ) -> Tuple[train_state.TrainState, Dict[str, float]]: """ Trains for one step. @@ -71,12 +71,12 @@ def train_step( Args: state: the training state. batch: the batch of data. - lstm_random_key: the random number key. + key: a random key. """ """Trains one step.""" - lstm_key = jax.random.fold_in(lstm_random_key, state.step) - dropout_key, lstm_key = jax.random.split(lstm_key, 2) + key = jax.random.fold_in(key, state.step) + dropout_key, lstm_key = jax.random.split(key) # Shift input by one to avoid leakage batch_decoder = jnp.roll(batch, shift=1, axis=1) @@ -112,7 +112,7 @@ def mean_squared_error(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: def lstm_ae_train( - random_key: RNGKey, + key: RNGKey, repertoire: UnstructuredRepertoire, params: Params, epoch: int, @@ -153,7 +153,7 @@ def lstm_ae_train( # select repertoire_size indexes going from 0 to the total number of # valid individuals. Those indexes will be used to select the individuals # in the training dataset. - random_key, key_select_p1 = jax.random.split(random_key, 2) + key, key_select_p1 = jax.random.split(key, 2) idx_p1 = jax.random.randint( key_select_p1, shape=(repertoire_size,), minval=0, maxval=num_indivs ) @@ -177,7 +177,7 @@ def lstm_ae_train( loss_val = 0.0 for epoch in range(num_epochs): - random_key, shuffle_key = jax.random.split(random_key, 2) + key, shuffle_key = jax.random.split(key, 2) valid_indexes = jax.random.permutation(shuffle_key, valid_indexes, axis=0) # create dataset with the observation from the sample of valid indexes @@ -197,7 +197,7 @@ def lstm_ae_train( # print(batch.shape) continue - state, loss_val = train_step(state, batch, random_key) + state, loss_val = train_step(state, batch, key) # To see the actual value we cannot jit this function (i.e. 
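On top of such a scoring function, the `sampling` wrapper from `qdax/utils/sampling.py` (updated above) splits the key into `num_samples` sub-keys internally and averages the sampled fitnesses and descriptors. A usage sketch; `noisy_scoring_fn` is a hypothetical stochastic task, and the keyword-argument call style is chosen to match the jitted wrapper's static arguments:

```python
from typing import Dict, Tuple

import jax
import jax.numpy as jnp

from qdax.utils.sampling import sampling


def noisy_scoring_fn(
    params: jnp.ndarray, key: jax.Array
) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
    # Hypothetical stochastic task: noisy fitness, noise-free descriptor.
    noise = jax.random.normal(key, shape=(params.shape[0],))
    fitnesses = -jnp.sum(params**2, axis=-1) + 0.1 * noise
    descriptors = params[:, :2]
    return fitnesses, descriptors, {}


key = jax.random.key(0)
genotypes = jax.random.uniform(key, shape=(8, 4))

key, subkey = jax.random.split(key)
fitnesses, descriptors, extra_scores = sampling(
    genotypes,
    subkey,
    scoring_fn=noisy_scoring_fn,
    num_samples=16,
)
print(fitnesses.shape, descriptors.shape)  # expected-value estimates: (8,) (8, 2)
```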
the _one_es_epoch # function nor the train function) diff --git a/tests/baselines_test/cmame_test.py b/tests/baselines_test/cmame_test.py index 31176ed4..ecf7de5a 100644 --- a/tests/baselines_test/cmame_test.py +++ b/tests/baselines_test/cmame_test.py @@ -59,10 +59,10 @@ def scoring_function(x: jnp.ndarray) -> Tuple[Fitness, Descriptor, Dict]: return scores, descriptors, {} def scoring_fn( - x: jnp.ndarray, random_key: RNGKey - ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + x: jnp.ndarray, key: RNGKey + ) -> Tuple[Fitness, Descriptor, ExtraScores]: fitnesses, descriptors, extra_scores = jax.vmap(scoring_function)(x) - return fitnesses, descriptors, extra_scores, random_key + return fitnesses, descriptors, extra_scores worst_objective = fitness_scoring(-jnp.ones(num_dimensions) * 5.12) best_objective = fitness_scoring(jnp.ones(num_dimensions) * 5.12 * 0.4) @@ -81,9 +81,9 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: max_fitness = jnp.max(adjusted_fitness) return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} - random_key = jax.random.PRNGKey(0) + key = jax.random.key(0) initial_population = ( - jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.0 + jax.random.uniform(key, shape=(batch_size, num_dimensions)) * 0.0 ) centroids = compute_euclidean_centroids( @@ -109,17 +109,15 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: scoring_function=scoring_fn, emitter=emitter, metrics_function=metrics_fn ) - repertoire, emitter_state, random_key = map_elites.init( - initial_population, centroids, random_key - ) + repertoire, emitter_state, key = map_elites.init(initial_population, centroids, key) ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( map_elites.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/baselines_test/cmamega_test.py b/tests/baselines_test/cmamega_test.py index 5bfdfd58..0aee7a97 100644 --- a/tests/baselines_test/cmamega_test.py +++ b/tests/baselines_test/cmamega_test.py @@ -75,10 +75,10 @@ def scoring_function(x: jnp.ndarray) -> Tuple[Fitness, Descriptor, ExtraScores]: return scores, descriptors, extra_scores def scoring_fn( - x: jnp.ndarray, random_key: RNGKey - ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + x: jnp.ndarray, key: RNGKey + ) -> Tuple[Fitness, Descriptor, ExtraScores]: fitnesses, descriptors, extra_scores = jax.vmap(scoring_function)(x) - return fitnesses, descriptors, extra_scores, random_key + return fitnesses, descriptors, extra_scores worst_objective = rastrigin_scoring(-jnp.ones(num_dimensions) * 5.12) best_objective = rastrigin_scoring(jnp.ones(num_dimensions) * 5.12 * 0.4) @@ -95,18 +95,16 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: max_fitness = jnp.max(adjusted_fitness) return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} - random_key = jax.random.PRNGKey(0) - initial_population = jax.random.uniform( - random_key, shape=(batch_size, num_dimensions) - ) + key = jax.random.key(0) + initial_population = jax.random.uniform(key, shape=(batch_size, num_dimensions)) - centroids, random_key = compute_cvt_centroids( + centroids, key = compute_cvt_centroids( num_descriptors=2, num_init_cvt_samples=10000, num_centroids=num_centroids, minval=minval, maxval=maxval, - random_key=random_key, + key=key, ) emitter = CMAMEGAEmitter( @@ -121,17 +119,15 @@ def 
metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: map_elites = MAPElites( scoring_function=scoring_fn, emitter=emitter, metrics_function=metrics_fn ) - repertoire, emitter_state, random_key = map_elites.init( - initial_population, centroids, random_key - ) + repertoire, emitter_state, key = map_elites.init(initial_population, centroids, key) ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( map_elites.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/baselines_test/dads_smerl_test.py b/tests/baselines_test/dads_smerl_test.py index 62c8fefa..be616a92 100644 --- a/tests/baselines_test/dads_smerl_test.py +++ b/tests/baselines_test/dads_smerl_test.py @@ -72,7 +72,7 @@ def test_dads_smerl() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/dads_test.py b/tests/baselines_test/dads_test.py index 51cebded..403ab718 100644 --- a/tests/baselines_test/dads_test.py +++ b/tests/baselines_test/dads_test.py @@ -66,7 +66,7 @@ def test_dads() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/diayn_smerl_test.py b/tests/baselines_test/diayn_smerl_test.py index b5e2cb5a..5faccda0 100644 --- a/tests/baselines_test/diayn_smerl_test.py +++ b/tests/baselines_test/diayn_smerl_test.py @@ -69,7 +69,7 @@ def test_diayn_smerl() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/diayn_test.py b/tests/baselines_test/diayn_test.py index e7a3a8a2..b5660c5c 100644 --- a/tests/baselines_test/diayn_test.py +++ b/tests/baselines_test/diayn_test.py @@ -62,7 +62,7 @@ def test_diayn() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/baselines_test/ga_test.py b/tests/baselines_test/ga_test.py index a1eb1b51..e090fb98 100644 --- a/tests/baselines_test/ga_test.py +++ b/tests/baselines_test/ga_test.py @@ -64,15 +64,13 @@ def rastrigin_scorer( scoring_function = partial(rastrigin_scorer, base_lag=base_lag, lag=lag) - def scoring_fn( - genotypes: jnp.ndarray, random_key: RNGKey - ) -> Tuple[Fitness, ExtraScores, RNGKey]: + def scoring_fn(genotypes: jnp.ndarray, key: RNGKey) -> Tuple[Fitness, ExtraScores]: fitnesses, _ = scoring_function(genotypes) - return fitnesses, {}, random_key + return fitnesses, {} # initial population - random_key = jax.random.PRNGKey(42) - random_key, subkey = jax.random.split(random_key) + key = jax.random.key(42) + key, subkey = jax.random.split(key) genotypes = jax.random.uniform( subkey, (batch_size, genotype_dim), @@ -110,22 +108,22 @@ def scoring_fn( ) if isinstance(algo_instance, SPEA2): - repertoire, emitter_state, random_key = algo_instance.init( - genotypes, population_size, num_neighbours, random_key + repertoire, emitter_state, key = algo_instance.init( + genotypes, population_size, num_neighbours, key ) else: - repertoire, emitter_state, random_key = algo_instance.init( - genotypes, population_size, 
random_key + repertoire, emitter_state, key = algo_instance.init( + genotypes, population_size, key ) # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( algo_instance.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/baselines_test/me_pbt_sac_test.py b/tests/baselines_test/me_pbt_sac_test.py index 77c55385..04b622bb 100644 --- a/tests/baselines_test/me_pbt_sac_test.py +++ b/tests/baselines_test/me_pbt_sac_test.py @@ -69,7 +69,7 @@ def test_me_pbt_sac() -> None: ) min_descriptor, max_descriptor = env.descriptor_limits - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, subkey = jax.random.split(key) eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey) @@ -115,9 +115,11 @@ def test_me_pbt_sac() -> None: # get scoring function descriptor_extraction_fn = environments.descriptor_extractor[env_name] - eval_policy = agent.get_eval_qd_fn(eval_env, descriptor_extraction_fn=descriptor_extraction_fn) + eval_policy = agent.get_eval_qd_fn( + eval_env, descriptor_extraction_fn=descriptor_extraction_fn + ) - def scoring_function(genotypes, random_key): # type: ignore + def scoring_function(genotypes, key): # type: ignore population_size = jax.tree_util.tree_leaves(genotypes)[0].shape[0] first_states = jax.tree_util.tree_map( lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states @@ -125,8 +127,10 @@ def scoring_function(genotypes, random_key): # type: ignore first_states = jax.tree_util.tree_map( lambda x: jnp.repeat(x, population_size, axis=0), first_states ) - population_returns, population_descriptors, _, _ = eval_policy(genotypes, first_states) - return population_returns, population_descriptors, {}, random_key + population_returns, population_descriptors, _, _ = eval_policy( + genotypes, first_states + ) + return population_returns, population_descriptors, {} # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] @@ -150,7 +154,7 @@ def scoring_function(genotypes, random_key): # type: ignore num_centroids=num_centroids, minval=min_descriptor, maxval=max_descriptor, - random_key=key, + key=key, ) key, *keys = jax.random.split(key, num=1 + num_devices) @@ -178,7 +182,7 @@ def scoring_function(genotypes, random_key): # type: ignore repertoire, emitter_state, keys = map_elites.get_distributed_init_fn( devices=devices, centroids=centroids )( - genotypes=training_states, random_key=keys + genotypes=training_states, key=keys ) # type: ignore update_fn = map_elites.get_distributed_update_fn(num_iterations=1, devices=devices) diff --git a/tests/baselines_test/me_pbt_td3_test.py b/tests/baselines_test/me_pbt_td3_test.py index e7906883..8d77e0c7 100644 --- a/tests/baselines_test/me_pbt_td3_test.py +++ b/tests/baselines_test/me_pbt_td3_test.py @@ -69,7 +69,7 @@ def test_me_pbt_td3() -> None: ) min_descriptor, max_descriptor = env.descriptor_limits - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, subkey = jax.random.split(key) eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey) @@ -113,9 +113,11 @@ def test_me_pbt_td3() -> None: # get scoring function descriptor_extraction_fn = environments.descriptor_extractor[env_name] - eval_policy = agent.get_eval_qd_fn(eval_env, descriptor_extraction_fn=descriptor_extraction_fn) + eval_policy = agent.get_eval_qd_fn( + eval_env, descriptor_extraction_fn=descriptor_extraction_fn + ) - def scoring_function(genotypes, random_key): # 
type: ignore + def scoring_function(genotypes, key): # type: ignore population_size = jax.tree_util.tree_leaves(genotypes)[0].shape[0] first_states = jax.tree_util.tree_map( lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states @@ -123,8 +125,10 @@ def scoring_function(genotypes, random_key): # type: ignore first_states = jax.tree_util.tree_map( lambda x: jnp.repeat(x, population_size, axis=0), first_states ) - population_returns, population_descriptors, _, _ = eval_policy(genotypes, first_states) - return population_returns, population_descriptors, {}, random_key + population_returns, population_descriptors, _, _ = eval_policy( + genotypes, first_states + ) + return population_returns, population_descriptors, {} # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] @@ -148,7 +152,7 @@ def scoring_function(genotypes, random_key): # type: ignore num_centroids=num_centroids, minval=min_descriptor, maxval=max_descriptor, - random_key=key, + key=key, ) key, *keys = jax.random.split(key, num=1 + num_devices) @@ -176,7 +180,7 @@ def scoring_function(genotypes, random_key): # type: ignore repertoire, emitter_state, keys = map_elites.get_distributed_init_fn( devices=devices, centroids=centroids )( - genotypes=training_states, random_key=keys + genotypes=training_states, key=keys ) # type: ignore update_fn = map_elites.get_distributed_update_fn(num_iterations=1, devices=devices) diff --git a/tests/baselines_test/mees_test.py b/tests/baselines_test/mees_test.py index c867541a..831dab49 100644 --- a/tests/baselines_test/mees_test.py +++ b/tests/baselines_test/mees_test.py @@ -47,7 +47,7 @@ def test_mees() -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) @@ -58,7 +58,7 @@ def test_mees() -> None: ) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=1, axis=0) fake_batch = jnp.zeros(shape=(1, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) @@ -67,7 +67,7 @@ def test_mees() -> None: def play_step_fn( env_state: EnvState, policy_params: Params, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: """ Play an environment step and return the updated state and the transition. 
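The distributed PBT tests above split the top-level key into one sub-key per device and stack them, so the pmapped initialization receives a leading device axis of keys. A standalone sketch of that pattern; `init_on_device` is a stand-in for the pmapped functions used in the tests:

```python
import jax
import jax.numpy as jnp

num_devices = jax.local_device_count()

key = jax.random.key(0)
key, *device_keys = jax.random.split(key, num=1 + num_devices)
device_keys = jnp.stack(device_keys)  # leading axis of size num_devices


@jax.pmap
def init_on_device(device_key):
    # Stand-in for per-device initialization (e.g. resetting environments).
    return jax.random.uniform(device_key, shape=(4,))


per_device_values = init_on_device(device_keys)
print(per_device_values.shape)  # (num_devices, 4)
```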
@@ -89,14 +89,14 @@ def play_step_fn( next_state_desc=next_state.info["state_descriptor"], ) - return next_state, policy_params, random_key, transition + return next_state, policy_params, key, transition # Create the initial environment states for samples and final indivs reset_fn = jax.jit(jax.vmap(env.reset)) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=sample_number, axis=0) init_states_samples = reset_fn(keys) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=1, axis=0) init_states = reset_fn(keys) @@ -157,13 +157,13 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict: ) # Compute the centroids - centroids, random_key = compute_cvt_centroids( + centroids, key = compute_cvt_centroids( num_descriptors=env.descriptor_length, num_init_cvt_samples=num_init_cvt_samples, num_centroids=num_centroids, minval=min_descriptor, maxval=max_descriptor, - random_key=random_key, + key=key, ) # Instantiate MAP Elites @@ -173,25 +173,26 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict: metrics_function=metrics_function, ) - repertoire, emitter_state, random_key = map_elites.init( - init_variables, centroids, random_key - ) + repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key) @jax.jit def update_scan_fn(carry: Any, unused: Any) -> Any: # iterate over grid - repertoire, emitter_state, metrics, random_key = map_elites.update(*carry) - - return (repertoire, emitter_state, random_key), metrics + repertoire, emitter_state, key = carry + key, subkey = jax.random.split(key) + repertoire, emitter_state, metrics = map_elites.update( + repertoire, emitter_state, subkey + ) + return (repertoire, emitter_state, key), metrics # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( update_scan_fn, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/baselines_test/omgmega_test.py b/tests/baselines_test/omgmega_test.py index ad51c7ae..ea92b2bb 100644 --- a/tests/baselines_test/omgmega_test.py +++ b/tests/baselines_test/omgmega_test.py @@ -61,11 +61,9 @@ def scoring_function(x: jnp.ndarray) -> Tuple[Fitness, Descriptor, ExtraScores]: gradients = jnp.nan_to_num(gradients) return fitnesses, descriptors, {"gradients": gradients} - def scoring_fn( - x: Genotype, random_key: RNGKey - ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + def scoring_fn(x: Genotype, key: RNGKey) -> Tuple[Fitness, Descriptor, ExtraScores]: fitnesses, descriptors, extra_scores = jax.vmap(scoring_function)(x) - return fitnesses, descriptors, extra_scores, random_key + return fitnesses, descriptors, extra_scores worst_objective = rastrigin_scoring(-jnp.ones(num_dimensions) * maxval) best_objective = rastrigin_scoring(jnp.ones(num_dimensions) * maxval * 0.4) @@ -82,10 +80,10 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: max_fitness = jnp.max(adjusted_fitness) return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} - random_key = jax.random.PRNGKey(0) + key = jax.random.key(0) # defines the population - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) initial_population = jax.random.uniform(subkey, shape=(100, num_dimensions)) sqrt_centroids = int(math.sqrt(num_centroids)) # 2-D grid 
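Where `update` itself takes a sub-key (as in the MEES test above), the scan body keeps the key in the carry and splits it explicitly on every iteration. A condensed, self-contained sketch of that `update_scan_fn` shape, with a stand-in `update` so the snippet runs on its own:

```python
from typing import Any, Tuple

import jax
import jax.numpy as jnp


def update(state: jnp.ndarray, key: jax.Array) -> Tuple[jnp.ndarray, jnp.ndarray]:
    # Stand-in for map_elites.update: consumes a key, returns new state and metrics.
    new_state = state + jax.random.uniform(key, shape=state.shape)
    return new_state, jnp.mean(new_state)


def update_scan_fn(carry: Tuple[jnp.ndarray, jax.Array], _: Any):
    state, key = carry
    # Split inside the scan body: one sub-key per iteration, carry the rest forward.
    key, subkey = jax.random.split(key)
    state, metrics = update(state, subkey)
    return (state, key), metrics


key = jax.random.key(0)
state = jnp.zeros((3,))
(state, key), metrics = jax.lax.scan(update_scan_fn, (state, key), (), length=10)
print(metrics.shape)  # (10,)
```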
@@ -109,17 +107,15 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: scoring_function=scoring_fn, emitter=emitter, metrics_function=metrics_fn ) - repertoire, emitter_state, random_key = map_elites.init( - initial_population, centroids, random_key - ) + repertoire, emitter_state, key = map_elites.init(initial_population, centroids, key) ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( map_elites.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/baselines_test/pbt_sac_test.py b/tests/baselines_test/pbt_sac_test.py index db7dc69e..a6f62a07 100644 --- a/tests/baselines_test/pbt_sac_test.py +++ b/tests/baselines_test/pbt_sac_test.py @@ -54,9 +54,9 @@ def test_pbt_sac() -> None: ) @jax.jit - def init_environments(random_key): # type: ignore - env_states = jax.jit(env.reset)(rng=random_key) - eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key) + def init_environments(key): # type: ignore + env_states = jax.jit(env.reset)(rng=key) + eval_env_first_states = jax.jit(eval_env.reset)(rng=key) reshape_fn = jax.jit( lambda tree: jax.tree_util.tree_map( @@ -76,7 +76,7 @@ def init_environments(random_key): # type: ignore return env_states, eval_env_first_states - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, *keys = jax.random.split(key, num=1 + num_devices) keys = jnp.stack(keys) env_states, eval_env_first_states = jax.pmap( diff --git a/tests/baselines_test/pbt_td3_test.py b/tests/baselines_test/pbt_td3_test.py index 0be68277..245b66ef 100644 --- a/tests/baselines_test/pbt_td3_test.py +++ b/tests/baselines_test/pbt_td3_test.py @@ -52,9 +52,9 @@ def test_pbt_td3() -> None: ) @jax.jit - def init_environments(random_key): # type: ignore - env_states = jax.jit(env.reset)(rng=random_key) - eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key) + def init_environments(key): # type: ignore + env_states = jax.jit(env.reset)(rng=key) + eval_env_first_states = jax.jit(eval_env.reset)(rng=key) reshape_fn = jax.jit( lambda tree: jax.tree_util.tree_map( @@ -74,7 +74,7 @@ def init_environments(random_key): # type: ignore return env_states, eval_env_first_states - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, *keys = jax.random.split(key, num=1 + num_devices) keys = jnp.stack(keys) env_states, eval_env_first_states = jax.pmap( diff --git a/tests/baselines_test/pgame_test.py b/tests/baselines_test/pgame_test.py index 052d57dc..fe80339a 100644 --- a/tests/baselines_test/pgame_test.py +++ b/tests/baselines_test/pgame_test.py @@ -54,7 +54,7 @@ def test_pgame() -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) @@ -65,7 +65,7 @@ def test_pgame() -> None: ) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jax.random.split(subkey, num=env_batch_size) fake_batch = jnp.zeros(shape=(env_batch_size, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) @@ -74,7 +74,7 @@ def test_pgame() -> None: def play_step_fn( env_state: EnvState, policy_params: Params, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: """ Play an environment step and return the updated state and the 
transition. @@ -96,7 +96,7 @@ def play_step_fn( next_state_desc=next_state.info["state_descriptor"], ) - return next_state, policy_params, random_key, transition + return next_state, policy_params, key, transition # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] @@ -145,7 +145,7 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict: ) # Create the initial environment states - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=env_batch_size, axis=0) reset_fn = jax.jit(jax.vmap(env.reset)) init_states = reset_fn(keys) @@ -161,13 +161,13 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict: ) # Compute the centroids - centroids, random_key = compute_cvt_centroids( + centroids, key = compute_cvt_centroids( num_descriptors=env.descriptor_length, num_init_cvt_samples=num_init_cvt_samples, num_centroids=num_centroids, minval=min_descriptor, maxval=max_descriptor, - random_key=random_key, + key=key, ) # Instantiate MAP Elites @@ -177,25 +177,26 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict: metrics_function=metrics_function, ) - repertoire, emitter_state, random_key = map_elites.init( - init_variables, centroids, random_key - ) + repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key) @jax.jit def update_scan_fn(carry: Any, unused: Any) -> Any: # iterate over grid - repertoire, emitter_state, metrics, random_key = map_elites.update(*carry) - - return (repertoire, emitter_state, random_key), metrics + repertoire, emitter_state, key = carry + key, subkey = jax.random.split(key) + repertoire, emitter_state, metrics = map_elites.update( + repertoire, emitter_state, subkey + ) + return (repertoire, emitter_state, key), metrics # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( update_scan_fn, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/baselines_test/qdpg_test.py b/tests/baselines_test/qdpg_test.py index a1d88de9..f5ca41b2 100644 --- a/tests/baselines_test/qdpg_test.py +++ b/tests/baselines_test/qdpg_test.py @@ -69,7 +69,7 @@ def test_qdpg() -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) @@ -80,7 +80,7 @@ def test_qdpg() -> None: ) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jax.random.split(subkey, num=env_batch_size) fake_batch = jnp.zeros(shape=(env_batch_size, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) @@ -89,7 +89,7 @@ def test_qdpg() -> None: def play_step_fn( env_state: EnvState, policy_params: Params, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: """ Play an environment step and return the updated state and the transition. 
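These tests also seed randomness with `jax.random.key` rather than `jax.random.PRNGKey`. The former returns JAX's newer typed key array (a scalar with an extended key dtype) instead of a raw `uint32[2]` vector; both work with the same `split`/`uniform`/`fold_in` API. A quick illustration, with shapes and dtypes as of recent JAX releases:

```python
import jax

legacy_key = jax.random.PRNGKey(0)  # raw key: shape (2,), dtype uint32
typed_key = jax.random.key(0)       # typed key: shape (), extended key dtype

print(legacy_key.shape, legacy_key.dtype)
print(typed_key.shape, typed_key.dtype)

# Both styles support the same functional API:
k1, k2 = jax.random.split(typed_key)
sample = jax.random.uniform(k1, shape=(3,))
```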
@@ -111,7 +111,7 @@ def play_step_fn( next_state_desc=next_state.info["state_descriptor"], ) - return next_state, policy_params, random_key, transition + return next_state, policy_params, key, transition # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] @@ -195,7 +195,7 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict: ) # Create the initial environment states - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=env_batch_size, axis=0) reset_fn = jax.jit(jax.vmap(env.reset)) init_states = reset_fn(keys) @@ -211,13 +211,13 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict: ) # Compute the centroids - centroids, random_key = compute_cvt_centroids( + centroids, key = compute_cvt_centroids( num_descriptors=env.descriptor_length, num_init_cvt_samples=num_init_cvt_samples, num_centroids=num_centroids, minval=min_descriptor, maxval=max_descriptor, - random_key=random_key, + key=key, ) # Instantiate MAP Elites @@ -227,25 +227,26 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict: metrics_function=metrics_function, ) - repertoire, emitter_state, random_key = map_elites.init( - init_variables, centroids, random_key - ) + repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key) @jax.jit def update_scan_fn(carry: Any, unused: Any) -> Any: # iterate over grid - repertoire, emitter_state, metrics, random_key = map_elites.update(*carry) - - return (repertoire, emitter_state, random_key), metrics + repertoire, emitter_state, key = carry + key, subkey = jax.random.split(key) + repertoire, emitter_state, metrics = map_elites.update( + repertoire, emitter_state, subkey + ) + return (repertoire, emitter_state, key), metrics # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), _metrics = jax.lax.scan( update_scan_fn, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/baselines_test/sac_test.py b/tests/baselines_test/sac_test.py index 8c26b510..31510670 100644 --- a/tests/baselines_test/sac_test.py +++ b/tests/baselines_test/sac_test.py @@ -53,7 +53,7 @@ def test_sac() -> None: eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) @@ -82,7 +82,7 @@ def test_sac() -> None: sac = SAC(config=sac_config, action_size=env.action_size) key, subkey = jax.random.split(key) training_state = sac.init( - random_key=subkey, + key=subkey, action_size=env.action_size, observation_size=env.observation_size, ) diff --git a/tests/baselines_test/td3_test.py b/tests/baselines_test/td3_test.py index ff5d338c..55c09811 100644 --- a/tests/baselines_test/td3_test.py +++ b/tests/baselines_test/td3_test.py @@ -49,7 +49,7 @@ def test_td3() -> None: auto_reset=True, eval_metrics=True, ) - key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) key, subkey = jax.random.split(key) env_state = jax.jit(env.reset)(rng=key) eval_env_first_state = jax.jit(eval_env.reset)(rng=key) diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index b2fc20fd..344516d7 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -75,16 +75,17 @@ def test_aurora(env_name: str, batch_size: int) -> None: log_freq = 5 # Init a random key - random_key = 
jax.random.PRNGKey(seed) + key = jax.random.key(seed) # Init environment - env, policy_network, scoring_fn, random_key = create_default_brax_task_components( + key, subkey = jax.random.split(key) + env, policy_network, scoring_fn = create_default_brax_task_components( env_name=env_name, - random_key=random_key, + key=subkey, ) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jax.random.split(subkey, num=batch_size) fake_batch = jnp.zeros(shape=(batch_size, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) @@ -165,7 +166,7 @@ def observation_extractor_fn( ) # init the model params - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) model_params = train_seq2seq.get_initial_params( model, subkey, (1, *observations_dims) ) @@ -182,18 +183,19 @@ def observation_extractor_fn( ) # init step of the aurora algorithm - repertoire, emitter_state, aurora_extra_info, random_key = aurora.init( + key, subkey = jax.random.split(key) + repertoire, emitter_state, aurora_extra_info = aurora.init( init_variables, aurora_extra_info, jnp.asarray(l_value_init), max_size, - random_key, + subkey, ) # initializing means and stds and AURORA - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) repertoire, aurora_extra_info = aurora.train( - repertoire, model_params, iteration=0, random_key=subkey + repertoire, model_params, iteration=0, key=subkey ) # design aurora's schedule @@ -215,10 +217,10 @@ def observation_extractor_fn( while iteration < max_iterations: # standard MAP-Elites-like loop for _ in range(log_freq): - repertoire, emitter_state, _, random_key = aurora.update( + repertoire, emitter_state, _, key = aurora.update( repertoire, emitter_state, - random_key, + key, aurora_extra_info=aurora_extra_info, ) @@ -228,7 +230,7 @@ def observation_extractor_fn( # autoencoder steps and Container Size Control (CSC) if (iteration + 1) in schedules: # train the autoencoder (includes the CSC) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) repertoire, aurora_extra_info = aurora.train( repertoire, model_params, iteration, subkey ) diff --git a/tests/core_test/cmaes_test.py b/tests/core_test/cmaes_test.py index 16321fd4..cb63d17b 100644 --- a/tests/core_test/cmaes_test.py +++ b/tests/core_test/cmaes_test.py @@ -32,14 +32,14 @@ def sphere_scoring(x: jnp.ndarray) -> jnp.ndarray: ) state = cmaes.init() - random_key = jax.random.PRNGKey(0) + key = jax.random.key(0) iteration_count = 0 for _ in range(num_iterations): iteration_count += 1 # sample - samples, random_key = cmaes.sample(state, random_key) + samples, key = cmaes.sample(state, key) # udpate state = cmaes.update(state, samples) diff --git a/tests/core_test/emitters_test/multi_emitter_test.py b/tests/core_test/emitters_test/multi_emitter_test.py index 8bda398f..aab46e12 100644 --- a/tests/core_test/emitters_test/multi_emitter_test.py +++ b/tests/core_test/emitters_test/multi_emitter_test.py @@ -27,10 +27,10 @@ def test_multi_emitter() -> None: max_descriptor = max_param # Init a random key - random_key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) init_variables = jax.random.uniform( subkey, shape=(init_batch_size, num_param_dimensions) ) @@ -91,18 +91,16 @@ def test_multi_emitter() -> None: ) # Compute 
initial repertoire - repertoire, emitter_state, random_key = map_elites.init( - init_variables, centroids, random_key - ) + repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key) # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( map_elites.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/core_test/map_elites_test.py b/tests/core_test/map_elites_test.py index a3e9f421..3edf195a 100644 --- a/tests/core_test/map_elites_test.py +++ b/tests/core_test/map_elites_test.py @@ -51,7 +51,7 @@ def test_map_elites(env_name: str, batch_size: int) -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) @@ -62,13 +62,13 @@ def test_map_elites(env_name: str, batch_size: int) -> None: ) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jax.random.split(subkey, num=batch_size) fake_batch = jnp.zeros(shape=(batch_size, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) # Create the initial environment states - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0) reset_fn = jax.jit(jax.vmap(env.reset)) init_states = reset_fn(keys) @@ -77,7 +77,7 @@ def test_map_elites(env_name: str, batch_size: int) -> None: def play_step_fn( env_state: EnvState, policy_params: Params, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: """ Play an environment step and return the updated state and the transition. 
@@ -99,7 +99,7 @@ def play_step_fn( next_state_desc=next_state.info["state_descriptor"], ) - return next_state, policy_params, random_key, transition + return next_state, policy_params, key, transition # Prepare the scoring function descriptor_extraction_fn = environments.descriptor_extractor[env_name] @@ -128,28 +128,28 @@ def play_step_fn( ) # Compute the centroids - centroids, random_key = compute_cvt_centroids( + key, subkey = jax.random.split(key) + centroids = compute_cvt_centroids( num_descriptors=env.descriptor_length, num_init_cvt_samples=num_init_cvt_samples, num_centroids=num_centroids, minval=min_descriptor, maxval=max_descriptor, - random_key=random_key, + key=subkey, ) # Compute initial repertoire - repertoire, emitter_state, random_key = map_elites.init( - init_variables, centroids, random_key - ) + key, subkey = jax.random.split(key) + repertoire, emitter_state = map_elites.init(init_variables, centroids, subkey) # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( map_elites.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/core_test/mels_test.py b/tests/core_test/mels_test.py index 89f14a48..bcae5be0 100644 --- a/tests/core_test/mels_test.py +++ b/tests/core_test/mels_test.py @@ -40,7 +40,7 @@ def test_mels(env_name: str, batch_size: int) -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.PRNGKey(seed) + key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) @@ -52,7 +52,7 @@ def test_mels(env_name: str, batch_size: int) -> None: # Init population of controllers. There are batch_size controllers, and each # controller will be evaluated num_samples times. 
- random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jax.random.split(subkey, num=batch_size) fake_batch = jnp.zeros(shape=(batch_size, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) @@ -61,7 +61,7 @@ def test_mels(env_name: str, batch_size: int) -> None: def play_step_fn( env_state: EnvState, policy_params: Params, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: """Play an environment step and return the updated state and the transition.""" @@ -82,7 +82,7 @@ def play_step_fn( next_state_desc=next_state.info["state_descriptor"], ) - return next_state, policy_params, random_key, transition + return next_state, policy_params, key, transition # Prepare the scoring function descriptor_extraction_fn = environments.descriptor_extractor[env_name] @@ -127,28 +127,26 @@ def metrics_fn(repertoire: MELSRepertoire) -> Dict: ) # Compute the centroids - centroids, random_key = compute_cvt_centroids( + centroids, key = compute_cvt_centroids( num_descriptors=env.descriptor_length, num_init_cvt_samples=num_init_cvt_samples, num_centroids=num_centroids, minval=min_descriptor, maxval=max_descriptor, - random_key=random_key, + key=key, ) # Compute initial repertoire - repertoire, emitter_state, random_key = mels.init( - init_variables, centroids, random_key - ) + repertoire, emitter_state, key = mels.init(init_variables, centroids, key) # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( mels.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/core_test/mome_test.py b/tests/core_test/mome_test.py index 746b94a0..c7198256 100644 --- a/tests/core_test/mome_test.py +++ b/tests/core_test/mome_test.py @@ -68,10 +68,10 @@ def rastrigin_scorer( scoring_function = partial(rastrigin_scorer, base_lag=base_lag, lag=lag) def scoring_fn( - genotypes: jnp.ndarray, random_key: RNGKey - ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + genotypes: jnp.ndarray, key: RNGKey + ) -> Tuple[Fitness, Descriptor, ExtraScores]: fitnesses, descriptors = scoring_function(genotypes) - return fitnesses, descriptors, {}, random_key + return fitnesses, descriptors, {} reference_point = jnp.array([-150, -150]) @@ -79,8 +79,8 @@ def scoring_fn( metrics_function = partial(default_moqd_metrics, reference_point=reference_point) # initial population - random_key = jax.random.PRNGKey(42) - random_key, subkey = jax.random.split(random_key) + key = jax.random.key(42) + key, subkey = jax.random.split(key) genotypes = jax.random.uniform( subkey, (batch_size, num_variables), @@ -111,13 +111,13 @@ def scoring_fn( batch_size=batch_size, ) - centroids, random_key = compute_cvt_centroids( + centroids, key = compute_cvt_centroids( num_descriptors=num_descriptors, num_init_cvt_samples=20000, num_centroids=num_centroids, minval=minval, maxval=maxval, - random_key=random_key, + key=key, ) mome = MOME( @@ -126,18 +126,18 @@ def scoring_fn( metrics_function=metrics_function, ) - repertoire, emitter_state, random_key = mome.init( - genotypes, centroids, pareto_front_max_length, random_key + repertoire, emitter_state, key = mome.init( + genotypes, centroids, pareto_front_max_length, key ) # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( mome.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), 
         length=num_iterations,
     )
diff --git a/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py b/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py
index 06e25fcd..dbd0c495 100644
--- a/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py
+++ b/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py
@@ -91,9 +91,9 @@ def test_sample() -> None:
     simple_transition = simple_transition.replace(rewards=jnp.arange(3))
     replay_buffer = replay_buffer.insert(simple_transition)

-    random_key = jax.random.PRNGKey(0)
+    key = jax.random.key(0)

-    samples, random_key = replay_buffer.sample(random_key, 3)
+    samples, key = replay_buffer.sample(key, 3)

     samples_shapes = jax.tree_util.tree_map(lambda x: x.shape, samples)
     transition_shapes = jax.tree_util.tree_map(lambda x: x.shape, simple_transition)
diff --git a/tests/default_tasks_test/arm_test.py b/tests/default_tasks_test/arm_test.py
index 1e7d44bb..04f2a254 100644
--- a/tests/default_tasks_test/arm_test.py
+++ b/tests/default_tasks_test/arm_test.py
@@ -41,10 +41,10 @@ def test_arm(task_name: str, batch_size: int) -> None:
     max_descriptor = 1.0

     # Init a random key
-    random_key = jax.random.PRNGKey(seed)
+    key = jax.random.key(seed)

     # Init population of controllers
-    random_key, subkey = jax.random.split(random_key)
+    key, subkey = jax.random.split(key)
     init_variables = jax.random.uniform(
         subkey,
         shape=(init_batch_size, num_param_dimensions),
@@ -91,18 +91,16 @@ def test_arm(task_name: str, batch_size: int) -> None:
     )

     # Compute initial repertoire
-    repertoire, emitter_state, random_key = map_elites.init(
-        init_variables, centroids, random_key
-    )
+    repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key)

     # Run the algorithm
     (
         repertoire,
         emitter_state,
-        random_key,
+        key,
     ), metrics = jax.lax.scan(
         map_elites.scan_update,
-        (repertoire, emitter_state, random_key),
+        (repertoire, emitter_state, key),
         (),
         length=num_iterations,
     )
@@ -114,7 +112,7 @@ def test_arm_scoring_function() -> None:

     # Init a random key
     seed = 42
-    random_key = jax.random.PRNGKey(seed)
+    key = jax.random.key(seed)

     # arm has xy descriptor centered at 0.5 0.5 and min max range is [0,1]
     # 0 params of first genotype is horizontal and points towards negative x axis
@@ -131,27 +129,13 @@ def test_arm_scoring_function() -> None:
     genotypes_6 = jnp.array([[0.5, 0.5]])
     genotypes_7 = jnp.array([[0.75, 0.5]])

-    fitness_1, descriptors_1, _, random_key = arm_scoring_function(
-        genotypes_1, random_key
-    )
-    fitness_2, descriptors_2, _, random_key = arm_scoring_function(
-        genotypes_2, random_key
-    )
-    fitness_3, descriptors_3, _, random_key = arm_scoring_function(
-        genotypes_3, random_key
-    )
-    fitness_4, descriptors_4, _, random_key = arm_scoring_function(
-        genotypes_4, random_key
-    )
-    fitness_5, descriptors_5, _, random_key = arm_scoring_function(
-        genotypes_5, random_key
-    )
-    fitness_6, descriptors_6, _, random_key = arm_scoring_function(
-        genotypes_6, random_key
-    )
-    fitness_7, descriptors_7, _, random_key = arm_scoring_function(
-        genotypes_7, random_key
-    )
+    fitness_1, descriptors_1, _ = arm_scoring_function(genotypes_1, key)
+    fitness_2, descriptors_2, _ = arm_scoring_function(genotypes_2, key)
+    fitness_3, descriptors_3, _ = arm_scoring_function(genotypes_3, key)
+    fitness_4, descriptors_4, _ = arm_scoring_function(genotypes_4, key)
+    fitness_5, descriptors_5, _ = arm_scoring_function(genotypes_5, key)
+    fitness_6, descriptors_6, _ = arm_scoring_function(genotypes_6, key)
+    fitness_7, descriptors_7, _ = arm_scoring_function(genotypes_7, key)

     # use rounding to avoid some numerical floating point errors
     pytest.assume(
diff --git a/tests/default_tasks_test/brax_task_test.py b/tests/default_tasks_test/brax_task_test.py
index 77b18e1c..4b3b0ad3 100644
--- a/tests/default_tasks_test/brax_task_test.py
+++ b/tests/default_tasks_test/brax_task_test.py
@@ -34,11 +34,12 @@ def test_map_elites(env_name: str, batch_size: int, is_task_reset_based: bool) -
     max_descriptor = 1.0

     # Init a random key
-    random_key = jax.random.PRNGKey(seed)
+    key = jax.random.key(seed)

-    env, policy_network, scoring_fn, random_key = create_default_brax_task_components(
+    key, subkey = jax.random.split(key)
+    env, policy_network, scoring_fn = create_default_brax_task_components(
         env_name=env_name,
-        random_key=random_key,
+        key=subkey,
     )

     # Define emitter
@@ -64,33 +65,31 @@ def test_map_elites(env_name: str, batch_size: int, is_task_reset_based: bool) -
     )

     # Compute the centroids
-    centroids, random_key = compute_cvt_centroids(
+    centroids, key = compute_cvt_centroids(
         num_descriptors=env.descriptor_length,
         num_init_cvt_samples=num_init_cvt_samples,
         num_centroids=num_centroids,
         minval=min_descriptor,
         maxval=max_descriptor,
-        random_key=random_key,
+        key=key,
     )

     # Init population of controllers
-    init_variables, random_key = init_population_controllers(
-        policy_network, env, batch_size, random_key
+    init_variables, key = init_population_controllers(
+        policy_network, env, batch_size, key
     )

     # Compute initial repertoire
-    repertoire, emitter_state, random_key = map_elites.init(
-        init_variables, centroids, random_key
-    )
+    repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key)

     # Run the algorithm
     (
         repertoire,
         emitter_state,
-        random_key,
+        key,
     ), metrics = jax.lax.scan(
         map_elites.scan_update,
-        (repertoire, emitter_state, random_key),
+        (repertoire, emitter_state, key),
         (),
         length=num_iterations,
     )
diff --git a/tests/default_tasks_test/hypervolume_functions_test.py b/tests/default_tasks_test/hypervolume_functions_test.py
index 9b96b4b0..bb423267 100644
--- a/tests/default_tasks_test/hypervolume_functions_test.py
+++ b/tests/default_tasks_test/hypervolume_functions_test.py
@@ -50,10 +50,10 @@ def test_standard_functions(task_name: str, batch_size: int) -> None:
     max_descriptor = 1.0

     # Init a random key
-    random_key = jax.random.PRNGKey(seed)
+    key = jax.random.key(seed)

     # Init population of controllers
-    random_key, subkey = jax.random.split(random_key)
+    key, subkey = jax.random.split(key)
     init_variables = jax.random.uniform(
         subkey, shape=(init_batch_size, num_param_dimensions)
     )
@@ -97,18 +97,16 @@ def test_standard_functions(task_name: str, batch_size: int) -> None:
     )

     # Compute initial repertoire
-    repertoire, emitter_state, random_key = map_elites.init(
-        init_variables, centroids, random_key
-    )
+    repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key)

     # Run the algorithm
     (
         repertoire,
         emitter_state,
-        random_key,
+        key,
     ), metrics = jax.lax.scan(
         map_elites.scan_update,
-        (repertoire, emitter_state, random_key),
+        (repertoire, emitter_state, key),
         (),
         length=num_iterations,
     )
diff --git a/tests/default_tasks_test/jumanji_envs_test.py b/tests/default_tasks_test/jumanji_envs_test.py
index e3269c08..c4dd66c3 100644
--- a/tests/default_tasks_test/jumanji_envs_test.py
+++ b/tests/default_tasks_test/jumanji_envs_test.py
@@ -30,7 +30,7 @@ def test_jumanji_utils() -> None:
     env = jumanji.make("Snake-v1")

     # Reset your (jit-able) environment
-    key = jax.random.PRNGKey(0)
+    key = jax.random.key(0)
     state, _timestep = jax.jit(env.reset)(key)

     # Interact with the (jit-able) environment
@@ -38,7 +38,7 @@ def test_jumanji_utils() -> None:
     state, _timestep = jax.jit(env.step)(state, action)

     # Init a random key
-    random_key = jax.random.PRNGKey(seed)
+    key = jax.random.key(seed)

     # get number of actions
     num_actions = env.action_spec().maximum + 1
@@ -69,7 +69,7 @@ def observation_processing(
     )

     # Init population of controllers
-    random_key, subkey = jax.random.split(random_key)
+    key, subkey = jax.random.split(key)
     keys = jax.random.split(subkey, num=batch_size)

     # compute observation size from observation spec
@@ -84,7 +84,7 @@ def observation_processing(
     init_variables = jax.vmap(policy_network.init)(keys, fake_batch)

     # Create the initial environment states
-    random_key, subkey = jax.random.split(random_key)
+    key, subkey = jax.random.split(key)
     keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)
     reset_fn = jax.jit(jax.vmap(env.reset))

@@ -120,7 +120,7 @@ def descriptor_extraction(
         return descriptors

     # create a random projection to a two dim space
-    random_key, subkey = jax.random.split(random_key)
+    key, subkey = jax.random.split(key)
     linear_projection = jax.random.uniform(
         subkey, (2, observation_size), minval=-1, maxval=1, dtype=jnp.float32
     )
@@ -139,9 +139,7 @@ def descriptor_extraction(
         descriptor_extractor=descriptor_extraction_fn,
     )

-    fitnesses, descriptors, extra_scores, random_key = scoring_fn(
-        init_variables, random_key
-    )
+    fitnesses, descriptors, extra_scores, key = scoring_fn(init_variables, key)

     pytest.assume(fitnesses.shape == (population_size,))
     pytest.assume(jnp.sum(jnp.isnan(fitnesses)) == 0.0)
diff --git a/tests/default_tasks_test/qd_suite_test.py b/tests/default_tasks_test/qd_suite_test.py
index 74180d72..a5f82e22 100644
--- a/tests/default_tasks_test/qd_suite_test.py
+++ b/tests/default_tasks_test/qd_suite_test.py
@@ -68,7 +68,7 @@ def test_qd_suite(task_name: str, batch_size: int) -> None:
     grid_shape = tuple([resolution_per_axis for _ in range(descriptor_size)])

     # Init a random key
-    random_key = jax.random.PRNGKey(seed)
+    key = jax.random.key(seed)

     # Init population of parameters
     init_variables = task.get_initial_parameters(init_batch_size)
@@ -112,18 +112,16 @@ def test_qd_suite(task_name: str, batch_size: int) -> None:
     )

     # Compute initial repertoire
-    repertoire, emitter_state, random_key = map_elites.init(
-        init_variables, centroids, random_key
-    )
+    repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key)

     # Run the algorithm
     (
         repertoire,
         emitter_state,
-        random_key,
+        key,
     ), metrics = jax.lax.scan(
         map_elites.scan_update,
-        (repertoire, emitter_state, random_key),
+        (repertoire, emitter_state, key),
         (),
         length=num_iterations,
     )
diff --git a/tests/default_tasks_test/standard_functions_test.py b/tests/default_tasks_test/standard_functions_test.py
index b29622b8..eca7ea4f 100644
--- a/tests/default_tasks_test/standard_functions_test.py
+++ b/tests/default_tasks_test/standard_functions_test.py
@@ -40,10 +40,10 @@ def test_standard_functions(task_name: str, batch_size: int) -> None:
     max_descriptor = max_param

     # Init a random key
-    random_key = jax.random.PRNGKey(seed)
+    key = jax.random.key(seed)

     # Init population of controllers
-    random_key, subkey = jax.random.split(random_key)
+    key, subkey = jax.random.split(key)
     init_variables = jax.random.uniform(
         subkey, shape=(init_batch_size, num_param_dimensions)
     )
@@ -87,18 +87,16 @@ def test_standard_functions(task_name: str, batch_size: int) -> None:
     )

     # Compute initial repertoire
-    repertoire, emitter_state, random_key = map_elites.init(
-        init_variables, centroids, random_key
-    )
+    repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key)

     # Run the algorithm
     (
         repertoire,
         emitter_state,
-        random_key,
+        key,
     ), metrics = jax.lax.scan(
         map_elites.scan_update,
-        (repertoire, emitter_state, random_key),
+        (repertoire, emitter_state, key),
         (),
         length=num_iterations,
     )
diff --git a/tests/environments_test/wrapper_test.py b/tests/environments_test/wrapper_test.py
index f5e035ea..14190153 100644
--- a/tests/environments_test/wrapper_test.py
+++ b/tests/environments_test/wrapper_test.py
@@ -110,8 +110,8 @@ def test_wrapper(env_name: str) -> None:
     print("Observation size: ", env.observation_size)
     print("Action size: ", env.action_size)

-    random_key = jax.random.PRNGKey(seed)
-    init_state = env.reset(random_key)
+    key = jax.random.key(seed)
+    init_state = env.reset(key)

     joint_angle = jp.concatenate(
         [joint.angle_vel(init_state.qp)[0] for joint in env.sys.joints]
diff --git a/tests/utils_test/plotting_test.py b/tests/utils_test/plotting_test.py
index 17b4a8ea..234de9e8 100644
--- a/tests/utils_test/plotting_test.py
+++ b/tests/utils_test/plotting_test.py
@@ -38,8 +38,8 @@ def test_onion_grid(num_descriptors: int, grid_shape: Tuple[int, ...]) -> None:
     minval = jnp.array([0] * num_descriptors)
     maxval = jnp.array([1] * num_descriptors)

-    random_key = jax.random.PRNGKey(seed=0)
-    random_key, key_desc, key_fit = jax.random.split(random_key, num=3)
+    key = jax.random.key(seed=0)
+    key, key_desc, key_fit = jax.random.split(key, num=3)

     number_samples_test = 300
     descriptors = jax.random.uniform(
diff --git a/tests/utils_test/sampling_test.py b/tests/utils_test/sampling_test.py
index 4a50641a..db1f100f 100644
--- a/tests/utils_test/sampling_test.py
+++ b/tests/utils_test/sampling_test.py
@@ -34,7 +34,7 @@ def test_sampling() -> None:
     env = environments.create(env_name, episode_length=episode_length)

     # Init a random key
-    random_key = jax.random.PRNGKey(seed)
+    key = jax.random.key(seed)

     # Init policy network
     policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)
@@ -45,7 +45,7 @@ def test_sampling() -> None:
     )

     # Init population of controllers
-    random_key, subkey = jax.random.split(random_key)
+    key, subkey = jax.random.split(key)
     keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=1, axis=0)
     fake_batch = jnp.zeros(shape=(1, env.observation_size))
     init_variables = jax.vmap(policy_network.init)(keys, fake_batch)
@@ -54,7 +54,7 @@ def test_sampling() -> None:
     def play_step_fn(
         env_state: EnvState,
         policy_params: Params,
-        random_key: RNGKey,
+        key: RNGKey,
     ) -> Tuple[EnvState, Params, RNGKey, QDTransition]:
         """
        Play an environment step and return the updated state and the transition.
@@ -76,11 +76,11 @@ def play_step_fn(
             next_state_desc=next_state.info["state_descriptor"],
         )

-        return next_state, policy_params, random_key, transition
+        return next_state, policy_params, key, transition

     # Create the initial environment states for samples and final indivs
     reset_fn = jax.jit(jax.vmap(env.reset))
-    random_key, subkey = jax.random.split(random_key)
+    key, subkey = jax.random.split(key)
     keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=1, axis=0)
     init_states = reset_fn(keys)

@@ -110,9 +110,9 @@ def sampling_test(
         )

         # Evaluate individuals using the scoring functions
-        fitnesses, descriptors, _, _ = scoring_fn(init_variables, random_key)
+        fitnesses, descriptors, _, _ = scoring_fn(init_variables, key)
         sample_fitnesses, sample_descriptors, _, _ = scoring_1_sample_fn(
-            init_variables, random_key
+            init_variables, key
         )

         # Compare
@@ -132,7 +132,7 @@ def sampling_test(

         # Evaluate individuals using the scoring functions
         sample_fitnesses, sample_descriptors, _, _ = scoring_multi_sample_fn(
-            init_variables, random_key
+            init_variables, key
         )

         # Compare
@@ -170,7 +170,7 @@ def sampling_reproducibility_test(
            fitnesses_reproducibility,
            descriptors_reproducibility,
            _,
-        ) = scoring_1_sample_fn(init_variables, random_key)
+        ) = scoring_1_sample_fn(init_variables, key)

        # Compare - all reproducibility should be 0
        pytest.assume(
@@ -207,7 +207,7 @@ def sampling_reproducibility_test(
            fitnesses_reproducibility,
            descriptors_reproducibility,
            _,
-        ) = scoring_multi_sample_fn(init_variables, random_key)
+        ) = scoring_multi_sample_fn(init_variables, key)

        # Compare - all reproducibility should be 0
        pytest.assume(