diff --git a/README.md b/README.md index 052eb74a..aa57b3ba 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ [![codecov](https://codecov.io/gh/adaptive-intelligent-robotics/QDax/branch/feat/add-codecov/graph/badge.svg)](https://codecov.io/gh/adaptive-intelligent-robotics/QDax) -QDax is a tool to accelerate Quality-Diversity (QD) and neuro-evolution algorithms through hardware accelerators and massive parallelization. QD algorithms usually take days/weeks to run on large CPU clusters. With QDax, QD algorithms can now be run in minutes! ⏩ ⏩ 🕛 +QDax is a tool to accelerate Quality-Diversity (QD) and neuroevolution algorithms through hardware accelerators and massive parallelization. QD algorithms usually take days/weeks to run on large CPU clusters. With QDax, QD algorithms can now be run in minutes! ⏩ ⏩ 🕛 QDax has been developed as a research framework: it is flexible and easy to extend and build on and can be used for any problem setting. Get started with simple example and run a QD algorithm in minutes here! [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb) @@ -60,14 +60,14 @@ num_iterations = 50 grid_shape = (100, 100) min_param = 0.0 max_param = 1.0 -min_bd = 0.0 -max_bd = 1.0 +min_descriptor = 0.0 +max_descriptor = 1.0 # Init a random key -random_key = jax.random.key(seed) +key = jax.random.key(seed) # Init population of controllers -random_key, subkey = jax.random.split(random_key) +key, subkey = jax.random.split(key) init_variables = jax.random.uniform( subkey, shape=(init_batch_size, num_param_dimensions), @@ -106,19 +106,19 @@ map_elites = MAPElites( # Compute the centroids centroids = compute_euclidean_centroids( grid_shape=grid_shape, - minval=min_bd, - maxval=max_bd, + minval=min_descriptor, + maxval=max_descriptor, ) # Initializes repertoire and emitter state -repertoire, emitter_state, random_key = map_elites.init(init_variables, centroids, random_key) +repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key) # Run MAP-Elites loop for i in range(num_iterations): - (repertoire, emitter_state, metrics, random_key,) = map_elites.update( + (repertoire, emitter_state, metrics, key,) = map_elites.update( repertoire, emitter_state, - random_key, + key, ) # Get contents of repertoire @@ -133,6 +133,7 @@ QDax currently supports the following algorithms: | Algorithm | Example | |-------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | [MAP-Elites](https://arxiv.org/abs/1504.04909) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb) | +| [AURORA](https://arxiv.org/abs/2106.05648) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/aurora.ipynb) | | [CVT MAP-Elites](https://arxiv.org/abs/1610.05729) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb) | | 
[Policy Gradient Assisted MAP-Elites (PGA-ME)](https://hal.archives-ouvertes.fr/hal-03135723v2/file/PGA_MAP_Elites_GECCO.pdf) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/pgame.ipynb) | | [DCRL-ME](https://arxiv.org/abs/2401.08632) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/dcrlme.ipynb) | diff --git a/docs/api_documentation/core/mels.md b/docs/api_documentation/core/mels.md index 3aa212b5..4fc3aa51 100644 --- a/docs/api_documentation/core/mels.md +++ b/docs/api_documentation/core/mels.md @@ -2,6 +2,6 @@ [ME-LS](https://dl.acm.org/doi/abs/10.1145/3583131.3590433) is a variant of MAP-Elites that thrives the search process towards solutions that are consistent -in the behavior space for uncertain domains. +in the descriptor space for uncertain domains. ::: qdax.core.mels.MELS diff --git a/docs/overview.md b/docs/overview.md index 00de8b20..6af6576e 100644 --- a/docs/overview.md +++ b/docs/overview.md @@ -22,18 +22,18 @@ More importantly, QDax handles the archive management which is the key idea of Q ## Code Example ```python # Initializes repertoire and emitter state -repertoire, emitter_state, random_key = map_elites.init(init_variables, centroids, random_key) +repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key) for i in range(num_iterations): # generate new population with the emitter - genotypes, random_key = map_elites._emitter.emit( - repertoire, emitter_state, random_key + genotypes, key = map_elites._emitter.emit( + repertoire, emitter_state, key ) # scores/evaluates the population - fitnesses, descriptors, extra_scores, random_key = map_elites._scoring_function( - genotypes, random_key + fitnesses, descriptors, extra_scores, key = map_elites._scoring_function( + genotypes, key ) # update repertoire diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb index d9c01a1a..bb0f403a 100644 --- a/examples/aurora.ipynb +++ b/examples/aurora.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Optimizing with AURORA in Jax\n", + "# Optimizing with AURORA in JAX\n", "\n", "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [AURORA](https://arxiv.org/pdf/1905.11874.pdf).\n", "It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", @@ -49,8 +49,7 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "!pip install ipympl |tail -n 1\n", + "!pip install ipympl | tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", "# output.enable_custom_widget_manager()\n", @@ -71,7 +70,7 @@ " create_default_brax_task_components,\n", " get_aurora_scoring_fn,\n", ")\n", - "from qdax.environments.bd_extractors import (\n", + "from qdax.environments.descriptor_extractors import (\n", " AuroraExtraInfoNormalization,\n", " get_aurora_encoding,\n", ")\n", @@ -85,8 +84,8 @@ "\n", "\n", "if \"COLAB_TPU_ADDR\" in os.environ:\n", - " from jax.tools import colab_tpu\n", - " colab_tpu.setup_tpu()\n", + " from jax.tools import colab_tpu\n", + " colab_tpu.setup_tpu()\n", "\n", "\n", "clear_output()" @@ -110,8 +109,8 @@ "line_sigma = 0.05 #@param {type:\"number\"}\n", "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", "num_centroids = 1024 #@param {type:\"integer\"}\n", - "min_bd = 0. 
#@param {type:\"number\"}\n", - "max_bd = 1.0 #@param {type:\"number\"}\n", + "min_descriptor = 0. #@param {type:\"number\"}\n", + "max_descriptor = 1.0 #@param {type:\"number\"}\n", "\n", "lstm_batch_size = 128 #@param {type:\"integer\"}\n", "\n", @@ -146,7 +145,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.key(seed)\n", + "key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -157,14 +156,14 @@ ")\n", "\n", "# Init population of controllers\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jax.random.split(subkey, num=batch_size)\n", "fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))\n", "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", "\n", "\n", "# Create the initial environment states\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)\n", "reset_fn = jax.jit(jax.vmap(env.reset))\n", "init_states = reset_fn(keys)" @@ -187,9 +186,9 @@ "source": [ "# Define the function to play a step with the policy in the environment\n", "def play_step_fn(\n", - " env_state,\n", - " policy_params,\n", - " random_key,\n", + " env_state,\n", + " policy_params,\n", + " key,\n", "):\n", " \"\"\"\n", " Play an environment step and return the updated state and the transition.\n", @@ -211,7 +210,7 @@ " next_state_desc=next_state.info[\"state_descriptor\"],\n", " )\n", "\n", - " return next_state, policy_params, random_key, transition" + " return next_state, policy_params, key, transition" ] }, { @@ -220,7 +219,7 @@ "source": [ "## Define the scoring function and the way metrics are computed\n", "\n", - "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual." + "The scoring function is used in the evaluation step to determine the fitness and descriptor of each individual." 
] }, { @@ -230,9 +229,10 @@ "outputs": [], "source": [ "# Prepare the scoring function\n", - "env, policy_network, scoring_fn, random_key = create_default_brax_task_components(\n", + "key, subkey = jax.random.split(key)\n", + "env, policy_network, scoring_fn = create_default_brax_task_components(\n", " env_name=env_name,\n", - " random_key=random_key,\n", + " key=subkey,\n", ")\n", "\n", "def observation_extractor_fn(\n", @@ -322,26 +322,20 @@ "centroids = jnp.zeros(shape=(num_centroids, aurora_dims))\n", "\n", "@jax.jit\n", - "def update_scan_fn(carry: Any, unused: Any) -> Any:\n", + "def update_scan_fn(carry: Any, _: Any) -> Any:\n", " \"\"\"Scan the update function.\"\"\"\n", - " (\n", - " repertoire,\n", - " random_key,\n", - " aurora_extra_info\n", - " ) = carry\n", + " repertoire, key, aurora_extra_info = carry\n", "\n", " # update\n", - " (repertoire, _, metrics, random_key,) = aurora.update(\n", + " key, subkey = jax.random.split(key)\n", + " repertoire, _, metrics = aurora.update(\n", " repertoire,\n", " None,\n", - " random_key,\n", + " subkey,\n", " aurora_extra_info=aurora_extra_info,\n", " )\n", "\n", - " return (\n", - " (repertoire, random_key, aurora_extra_info),\n", - " metrics,\n", - " )\n", + " return (repertoire, key, aurora_extra_info), metrics\n", "\n", "# Init algorithm\n", "# AutoEncoder Params and INIT\n", @@ -367,7 +361,7 @@ ")\n", "\n", "# Init the model params\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "model_params = train_seq2seq.get_initial_params(\n", " model, subkey, (1, *observations_dims)\n", ")\n", @@ -410,18 +404,19 @@ ")\n", "\n", "# init step of the aurora algorithm\n", - "repertoire, emitter_state, aurora_extra_info, random_key = aurora.init(\n", + "key, subkey = jax.random.split(key)\n", + "repertoire, emitter_state, aurora_extra_info = aurora.init(\n", " init_variables,\n", " aurora_extra_info,\n", " jnp.asarray(l_value_init),\n", " max_observation_size,\n", - " random_key,\n", + " subkey,\n", ")\n", "\n", "# initializing means and stds and AURORA\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "repertoire, aurora_extra_info = aurora.train(\n", - " repertoire, model_params, iteration=0, random_key=subkey\n", + " repertoire, model_params, iteration=0, key=subkey\n", ")\n", "\n", "# design aurora's schedule\n", @@ -455,11 +450,11 @@ "while iteration < max_iterations:\n", "\n", " (\n", - " (repertoire, random_key, aurora_extra_info),\n", + " (repertoire, key, aurora_extra_info),\n", " metrics,\n", " ) = jax.lax.scan(\n", " update_scan_fn,\n", - " (repertoire, random_key, aurora_extra_info),\n", + " (repertoire, key, aurora_extra_info),\n", " (),\n", " length=log_freq,\n", " )\n", @@ -472,7 +467,7 @@ " # autoencoder steps and CVC\n", " if (iteration + 1) in schedules:\n", " # train the autoencoder\n", - " random_key, subkey = jax.random.split(random_key)\n", + " key, subkey = jax.random.split(key)\n", " repertoire, aurora_extra_info = aurora.train(\n", " repertoire, model_params, iteration, subkey\n", " )\n", diff --git a/examples/cmaes.ipynb b/examples/cmaes.ipynb index 6e81b989..fd0dc57e 100644 --- a/examples/cmaes.ipynb +++ b/examples/cmaes.ipynb @@ -13,7 +13,7 @@ "id": "1", "metadata": {}, "source": [ - "# Optimizing with CMA-ES in Jax\n", + "# Optimizing with CMA-ES in JAX\n", "\n", "This notebook shows how to use QDax to find performing parameters on Rastrigin and Sphere problems with [CMA-ES](https://arxiv.org/pdf/1604.00772.pdf). 
It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", "\n", @@ -178,7 +178,7 @@ "outputs": [], "source": [ "state = cmaes.init()\n", - "random_key = jax.random.key(0)" + "key = jax.random.key(0)" ] }, { @@ -204,8 +204,6 @@ }, "outputs": [], "source": [ - "%%time\n", - "\n", "means = [state.mean]\n", "covs = [(state.sigma**2) * state.cov_matrix]\n", "\n", @@ -214,7 +212,8 @@ " iteration_count += 1\n", "\n", " # sample\n", - " samples, random_key = cmaes.sample(state, random_key)\n", + " key, subkey = jax.random.split(key)\n", + " samples = cmaes.sample(state, subkey)\n", "\n", " # update\n", " state = cmaes.update(state, samples)\n", @@ -285,7 +284,7 @@ "fig, ax = plt.subplots(figsize=(12, 6))\n", "\n", "# sample points to show fitness landscape\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "x = jax.random.uniform(subkey, minval=-4, maxval=8, shape=(100000, 2))\n", "f_x = fitness_fn(x)\n", "\n", diff --git a/examples/cmame.ipynb b/examples/cmame.ipynb index 186c30ee..b00947d8 100644 --- a/examples/cmame.ipynb +++ b/examples/cmame.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Optimizing with CMA-ME in Jax\n", + "# Optimizing with CMA-ME in JAX\n", "\n", "This notebook shows how to use QDax to find diverse and performing parameters on Rastrigin or Sphere problem with [CMA-ME](https://arxiv.org/pdf/1912.02400.pdf). It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", "\n", @@ -90,8 +90,8 @@ "sigma_g = .5 #@param {type:\"number\"}\n", "minval = -5.12 #@param {type:\"number\"}\n", "maxval = 5.12 #@param {type:\"number\"}\n", - "min_bd = -5.12 * 0.5 * num_dimensions #@param {type:\"number\"}\n", - "max_bd = 5.12 * 0.5 * num_dimensions #@param {type:\"number\"}\n", + "min_descriptor = -5.12 * 0.5 * num_dimensions #@param {type:\"number\"}\n", + "max_descriptor = 5.12 * 0.5 * num_dimensions #@param {type:\"number\"}\n", "emitter_type = \"imp\" #@param[\"opt\", \"imp\", \"rnd\"]\n", "pool_size = 15 #@param {type:\"integer\"}\n", "optim_problem = \"rastrigin\" #@param[\"rastrigin\", \"sphere\"]\n", @@ -134,14 +134,14 @@ " (maxval / x)\n", " )\n", "\n", - "def _behavior_descriptor_1(x: jnp.ndarray):\n", + "def _descriptor_1(x: jnp.ndarray):\n", " return jnp.sum(clip(x[:x.shape[-1]//2]))\n", "\n", - "def _behavior_descriptor_2(x: jnp.ndarray):\n", + "def _descriptor_2(x: jnp.ndarray):\n", " return jnp.sum(clip(x[x.shape[-1]//2:]))\n", "\n", - "def _behavior_descriptors(x: jnp.ndarray):\n", - " return jnp.array([_behavior_descriptor_1(x), _behavior_descriptor_2(x)])" + "def _descriptors(x: jnp.ndarray):\n", + " return jnp.array([_descriptor_1(x), _descriptor_2(x)])" ] }, { @@ -151,12 +151,12 @@ "outputs": [], "source": [ "def scoring_function(x):\n", - " scores, descriptors = fitness_scoring(x), _behavior_descriptors(x)\n", + " scores, descriptors = fitness_scoring(x), _descriptors(x)\n", " return scores, descriptors, {}\n", "\n", - "def scoring_fn(x, random_key):\n", + "def scoring_fn(x, key):\n", " fitnesses, descriptors, extra_scores = jax.vmap(scoring_function)(x)\n", - " return fitnesses, descriptors, extra_scores, random_key" + " return fitnesses, descriptors, extra_scores" ] }, { @@ -205,15 +205,15 @@ "metadata": {}, "outputs": [], "source": [ - "random_key = jax.random.key(0)\n", + "key = jax.random.key(0)\n", "# in CMA-ME settings (from the paper), there is no init population\n", "# we multiply by zero to reproduce 
this setting\n", - "initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n", + "initial_population = jax.random.uniform(key, shape=(batch_size, num_dimensions)) * 0.\n", "\n", "centroids = compute_euclidean_centroids(\n", " grid_shape=grid_shape,\n", - " minval=min_bd,\n", - " maxval=max_bd,\n", + " minval=min_descriptor,\n", + " maxval=max_descriptor,\n", ")\n", "\n", "emitter_kwargs = {\n", @@ -259,7 +259,8 @@ "metadata": {}, "outputs": [], "source": [ - "repertoire, emitter_state, random_key = map_elites.init(initial_population, centroids, random_key)" + "key, subkey = jax.random.split(key)\n", + "repertoire, emitter_state = map_elites.init(initial_population, centroids, subkey)" ] }, { @@ -275,11 +276,9 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", - "\n", - "(repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + "(repertoire, emitter_state, key,), metrics = jax.lax.scan(\n", " map_elites.scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=num_iterations,\n", ")" diff --git a/examples/cmamega.ipynb b/examples/cmamega.ipynb index 8674a0ef..a4af8543 100644 --- a/examples/cmamega.ipynb +++ b/examples/cmamega.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Optimizing with CMA-MEGA in Jax\n", + "# Optimizing with CMA-MEGA in JAX\n", "\n", "This notebook shows how to use QDax to find diverse and performing parameters on the Rastrigin problem with [CMA-MEGA](https://arxiv.org/pdf/2106.03894.pdf). It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", "\n", @@ -133,11 +133,11 @@ " gradients = jnp.nan_to_num(gradients)\n", "\n", " # Compute normalized gradients\n", - " norm_gradients = jax.tree_util.tree_map(\n", + " norm_gradients = jax.tree.map(\n", " lambda x: jnp.linalg.norm(x, axis=1, keepdims=True),\n", " gradients,\n", " )\n", - " grads = jax.tree_util.tree_map(\n", + " grads = jax.tree.map(\n", " lambda x, y: x / y, gradients, norm_gradients\n", " )\n", " grads = jnp.nan_to_num(grads)\n", @@ -148,9 +148,9 @@ "\n", " return scores, descriptors, extra_scores\n", "\n", - "def scoring_fn(x, random_key):\n", + "def scoring_fn(x, key):\n", " fitnesses, descriptors, extra_scores = jax.vmap(scoring_function)(x)\n", - " return fitnesses, descriptors, extra_scores, random_key" + " return fitnesses, descriptors, extra_scores" ] }, { @@ -198,17 +198,18 @@ "metadata": {}, "outputs": [], "source": [ - "random_key = jax.key(0)\n", + "key = jax.random.key(0)\n", "# no initial population - give all the same value as emitter init value\n", - "initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n", + "initial_population = jax.random.uniform(key, shape=(batch_size, num_dimensions)) * 0.\n", "\n", - "centroids, random_key = compute_cvt_centroids(\n", + "key, subkey = jax.random.split(key)\n", + "centroids = compute_cvt_centroids(\n", " num_descriptors=2,\n", " num_init_cvt_samples=10000,\n", " num_centroids=num_centroids,\n", " minval=minval,\n", " maxval=maxval,\n", - " random_key=random_key,\n", + " key=subkey,\n", ")\n", "\n", "emitter = CMAMEGAEmitter(\n", @@ -233,7 +234,8 @@ "metadata": {}, "outputs": [], "source": [ - "repertoire, emitter_state, random_key = map_elites.init(initial_population, centroids, random_key)" + "key, subkey = jax.random.split(key)\n", + "repertoire, emitter_state = map_elites.init(initial_population, centroids, 
subkey)" ] }, { @@ -242,11 +244,9 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", - "\n", - "(repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + "(repertoire, emitter_state, key,), metrics = jax.lax.scan(\n", " map_elites.scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=num_iterations,\n", ")" @@ -282,7 +282,7 @@ "\n", "# create the plots and the grid\n", "fig, axes = plot_map_elites_results(\n", - " env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_bd=minval, max_bd=maxval\n", + " env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_descriptor=minval, max_descriptor=maxval\n", ")" ] } diff --git a/examples/dads.ipynb b/examples/dads.ipynb index f7b417ae..c8c8ef5e 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Training DADS with Jax\n", + "# Training DADS with JAX\n", "\n", "This notebook shows how to use QDax to train [DADS](https://arxiv.org/abs/1907.01657) on a Brax environment. It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", "- how to define an environment\n", @@ -46,8 +46,7 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "!pip install ipympl |tail -n 1\n", + "!pip install ipympl | tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", "# output.enable_custom_widget_manager()\n", @@ -164,14 +163,16 @@ ")\n", "\n", "key = jax.random.key(seed)\n", - "env_state = jax.jit(env.reset)(rng=key)\n", - "eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n", + "\n", + "key, subkey_1, subkey_2 = jax.random.split(key, 3)\n", + "env_state = jax.jit(env.reset)(rng=subkey_1)\n", + "eval_env_first_state = jax.jit(eval_env.reset)(rng=subkey_2)\n", "\n", "# Initialize buffer\n", "dummy_transition = QDTransition.init_dummy(\n", " observation_dim=env.observation_size + num_skills,\n", " action_dim=env.action_size,\n", - " descriptor_dim=env.behavior_descriptor_length,\n", + " descriptor_dim=env.descriptor_length,\n", ")\n", "replay_buffer = ReplayBuffer.init(\n", " buffer_size=buffer_size, transition=dummy_transition\n", @@ -207,7 +208,7 @@ " # DADS config\n", " num_skills=num_skills,\n", " descriptor_full_state=descriptor_full_state,\n", - " omit_input_dynamics_dim=env.behavior_descriptor_length,\n", + " omit_input_dynamics_dim=env.descriptor_length,\n", " dynamics_update_freq=dynamics_update_freq,\n", " normalize_target=normalize_target,\n", ")\n", @@ -215,7 +216,7 @@ "if descriptor_full_state:\n", " descriptor_size = env.observation_size\n", "else:\n", - " descriptor_size = env.behavior_descriptor_length\n", + " descriptor_size = env.descriptor_length\n", "\n", "# define an instance of DADS\n", "dads = DADS(\n", @@ -225,8 +226,9 @@ ")\n", "\n", "# get the initial training state\n", + "key, subkey = jax.random.split(key)\n", "training_state = dads.init(\n", - " key,\n", + " subkey,\n", " action_size=env.action_size,\n", " observation_size=env.observation_size,\n", " descriptor_size=descriptor_size,\n", @@ -361,8 +363,6 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", - "\n", "# Main loop\n", "(training_state, env_state, replay_buffer), metrics = jax.lax.scan(\n", " _scan_do_iteration,\n", @@ -479,10 +479,10 @@ "jit_env_step = jax.jit(visual_env.step)\n", "\n", "@jax.jit\n", - "def jit_inference_fn(params, observation, random_key):\n", + "def jit_inference_fn(params, observation, 
key):\n", " obs = jnp.concatenate([observation, skill], axis=0)\n", - " action, random_key = dads.select_action(obs, params, random_key, deterministic=True)\n", - " return action, random_key" + " action = dads.select_action(obs, params, key, deterministic=True)\n", + " return action" ] }, { @@ -499,11 +499,12 @@ "outputs": [], "source": [ "rollout = []\n", - "random_key = jax.random.key(seed=1)\n", - "state = jit_env_reset(rng=random_key)\n", + "key = jax.random.key(seed=1)\n", + "state = jit_env_reset(rng=key)\n", "while not state.done:\n", " rollout.append(state)\n", - " action, random_key = jit_inference_fn(my_params, state.obs, random_key)\n", + " key, subkey = jax.random.split(key)\n", + " action = jit_inference_fn(my_params, state.obs, subkey)\n", " state = jit_env_step(state, action)\n", "\n", "print(f\"The trajectory of this individual contains {len(rollout)} transitions.\")" diff --git a/examples/dcrlme.ipynb b/examples/dcrlme.ipynb index ff09dc5f..172eef41 100644 --- a/examples/dcrlme.ipynb +++ b/examples/dcrlme.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Optimizing with DCRL-ME in Jax\n", + "# Optimizing with DCRL-ME in JAX\n", "\n", "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [Descriptor-Conditioned Reinforcement Learning MAP-Elites (DCRL-ME)](https://arxiv.org/abs/2401.08632).\n", "This algorithm extends and improves upon [Descriptor-Conditioned Gradients MAP-Elites (DCG-ME)](https://dl.acm.org/doi/abs/10.1145/3583131.3590503)\n", @@ -49,8 +49,7 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "!pip install ipympl |tail -n 1\n", + "!pip install ipympl | tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", "# output.enable_custom_widget_manager()\n", @@ -73,9 +72,9 @@ "from qdax.core.neuroevolution.buffers.buffer import DCRLTransition\n", "from qdax.core.neuroevolution.networks.networks import MLP, MLPDC\n", "from qdax.custom_types import EnvState, Params, RNGKey\n", - "from qdax.environments import behavior_descriptor_extractor\n", + "from qdax.environments import descriptor_extractor\n", "from qdax.environments.wrappers import OffsetRewardWrapper, ClipRewardWrapper\n", - "from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs\n", + "from qdax.tasks.brax_envs import scoring_function_brax_envs\n", "from qdax.utils.plotting import plot_map_elites_results\n", "\n", "from qdax.utils.metrics import CSVLogger, default_qd_metrics\n", @@ -99,8 +98,8 @@ "\n", "env_name = \"ant_omni\" #@param['ant_uni', 'hopper_uni', 'walker_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']\n", "episode_length = 250 #@param {type:\"integer\"}\n", - "min_bd = -30.0 #@param {type:\"number\"}\n", - "max_bd = 30.0 #@param {type:\"number\"}\n", + "min_descriptor = -30.0 #@param {type:\"number\"}\n", + "max_descriptor = 30.0 #@param {type:\"number\"}\n", "\n", "num_iterations = 1000 #@param {type:\"integer\"}\n", "batch_size = 256 #@param {type:\"integer\"}\n", @@ -152,9 +151,8 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "# Init a random key\n", - "random_key = jax.random.key(seed)\n", + "key = jax.random.key(seed)\n", "\n", "# Init environment\n", "env = environments.create(env_name, episode_length=episode_length)\n", @@ -164,7 +162,6 @@ "env = ClipRewardWrapper(\n", " env, clip_min=0.,\n", ") # apply reward clip as DCRL needs positive rewards\n", - "\n", "reset_fn = jax.jit(env.reset)\n", "\n", "# Init policy network\n", @@ -181,7 +178,7 @@ 
")\n", "\n", "# Init population of controllers\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jax.random.split(subkey, num=batch_size)\n", "fake_batch_obs = jnp.zeros(shape=(batch_size, env.observation_size))\n", "init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs)" @@ -202,7 +199,7 @@ "source": [ "# Define the function to play a step with the policy in the environment\n", "def play_step_fn(\n", - " env_state: EnvState, policy_params: Params, random_key: RNGKey\n", + " env_state: EnvState, policy_params: Params, key: RNGKey\n", ") -> Tuple[EnvState, Params, RNGKey, DCRLTransition]:\n", " actions = policy_network.apply(policy_params, env_state.obs)\n", " state_desc = env_state.info[\"state_descriptor\"]\n", @@ -218,16 +215,16 @@ " state_desc=state_desc,\n", " next_state_desc=next_state.info[\"state_descriptor\"],\n", " desc=jnp.zeros(\n", - " env.behavior_descriptor_length,\n", + " env.descriptor_length,\n", " )\n", " * jnp.nan,\n", " desc_prime=jnp.zeros(\n", - " env.behavior_descriptor_length,\n", + " env.descriptor_length,\n", " )\n", " * jnp.nan,\n", " )\n", "\n", - " return next_state, policy_params, random_key, transition" + " return next_state, policy_params, key, transition" ] }, { @@ -246,13 +243,13 @@ "outputs": [], "source": [ "# Prepare the scoring function\n", - "bd_extraction_fn = behavior_descriptor_extractor[env_name]\n", + "descriptor_extraction_fn = descriptor_extractor[env_name]\n", "scoring_fn = functools.partial(\n", - " reset_based_scoring_function_brax_envs,\n", + " scoring_function_brax_envs,\n", " episode_length=episode_length,\n", " play_reset_fn=reset_fn,\n", " play_step_fn=play_step_fn,\n", - " behavior_descriptor_extractor=bd_extraction_fn,\n", + " descriptor_extractor=descriptor_extraction_fn,\n", ")\n", "\n", "# Get minimum reward value to make sure qd_score are positive\n", @@ -343,35 +340,23 @@ ")\n", "\n", "# Compute the centroids\n", - "centroids, random_key = compute_cvt_centroids(\n", - " num_descriptors=env.behavior_descriptor_length,\n", + "key, subkey = jax.random.split(key)\n", + "centroids = compute_cvt_centroids(\n", + " num_descriptors=env.descriptor_length,\n", " num_init_cvt_samples=num_init_cvt_samples,\n", " num_centroids=num_centroids,\n", - " minval=min_bd,\n", - " maxval=max_bd,\n", - " random_key=random_key,\n", + " minval=min_descriptor,\n", + " maxval=max_descriptor,\n", + " key=subkey,\n", ")\n", "\n", "# compute initial repertoire\n", - "repertoire, emitter_state, random_key = map_elites.init(\n", - " init_params, centroids, random_key\n", + "key, subkey = jax.random.split(key)\n", + "repertoire, emitter_state = map_elites.init(\n", + " init_params, centroids, subkey\n", ")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "@jax.jit\n", - "def update_scan_fn(carry: Any, unused: Any) -> Any:\n", - " # iterate over grid\n", - " repertoire, emitter_state, metrics, random_key = map_elites.update(*carry)\n", - "\n", - " return (repertoire, emitter_state, random_key), metrics" - ] - }, { "cell_type": "code", "execution_count": null, @@ -379,44 +364,37 @@ "outputs": [], "source": [ "log_period = 10\n", - "num_loops = int(num_iterations / log_period)\n", + "num_loops = num_iterations // log_period\n", "\n", + "metrics = dict.fromkeys([\"iteration\", \"qd_score\", \"coverage\", \"max_fitness\", \"time\"], jnp.array([]))\n", "csv_logger = CSVLogger(\n", " \"dcrlme-logs.csv\",\n", - " header=[\"loop\", 
\"iteration\", \"qd_score\", \"max_fitness\", \"coverage\", \"time\"]\n", + " header=list(metrics.keys())\n", ")\n", - "all_metrics = {}\n", "\n", - "# main loop\n", + "# Main loop\n", "map_elites_scan_update = map_elites.scan_update\n", "for i in range(num_loops):\n", " start_time = time.time()\n", - " # main iterations\n", " (\n", " repertoire,\n", " emitter_state,\n", - " random_key,\n", - " ), metrics = jax.lax.scan(\n", - " update_scan_fn,\n", - " (repertoire, emitter_state, random_key),\n", + " key,\n", + " ), current_metrics = jax.lax.scan(\n", + " map_elites_scan_update,\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=log_period,\n", " )\n", " timelapse = time.time() - start_time\n", "\n", - " # log metrics\n", - " logged_metrics = {\"time\": timelapse, \"loop\": 1+i, \"iteration\": 1 + i*log_period}\n", - " for key, value in metrics.items():\n", - " # take last value\n", - " logged_metrics[key] = value[-1]\n", - "\n", - " # take all values\n", - " if key in all_metrics.keys():\n", - " all_metrics[key] = jnp.concatenate([all_metrics[key], value])\n", - " else:\n", - " all_metrics[key] = value\n", + " # Metrics\n", + " current_metrics[\"iteration\"] = jnp.arange(1+log_period*i, 1+log_period*(i+1), dtype=jnp.int32)\n", + " current_metrics[\"time\"] = jnp.repeat(timelapse, log_period)\n", + " metrics = jax.tree.map(lambda metric, current_metric: jnp.concatenate([metric, current_metric], axis=0), metrics, current_metrics)\n", "\n", - " csv_logger.log(logged_metrics)" + " # Log\n", + " csv_logger.log(jax.tree.map(lambda x: x[-1], metrics))" ] }, { @@ -427,12 +405,12 @@ "source": [ "#@title Visualization\n", "\n", - "# create the x-axis array\n", - "env_steps = jnp.arange(740) * episode_length * batch_size\n", + "# Create the x-axis array\n", + "env_steps = metrics[\"iteration\"]\n", "\n", "%matplotlib inline\n", - "# create the plots and the grid\n", - "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=all_metrics, repertoire=repertoire, min_bd=min_bd, max_bd=max_bd)" + "# Create the plots and the grid\n", + "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_descriptor=min_descriptor, max_descriptor=max_descriptor)" ] } ], diff --git a/examples/diayn.ipynb b/examples/diayn.ipynb index b58a0af0..db8074c0 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Training DIAYN with Jax\n", + "# Training DIAYN with JAX\n", "\n", "This notebook shows how to use QDax to train [DIAYN](https://arxiv.org/abs/1802.06070) on a Brax environment. It can be run locally or on Google Colab. We recommend to use a GPU. 
This notebook will show:\n", "- how to define an environment\n", @@ -46,8 +46,7 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "!pip install ipympl |tail -n 1\n", + "!pip install ipympl | tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", "# output.enable_custom_widget_manager()\n", @@ -60,7 +59,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "\n", "from qdax import environments\n", "from qdax.baselines.diayn import DIAYN, DiaynConfig, DiaynTrainingState\n", "from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer\n", @@ -77,7 +75,6 @@ " from jax.tools import colab_tpu\n", " colab_tpu.setup_tpu()\n", "\n", - "\n", "clear_output()" ] }, @@ -163,14 +160,16 @@ ")\n", "\n", "key = jax.random.key(seed)\n", - "env_state = jax.jit(env.reset)(rng=key)\n", - "eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n", + "\n", + "key, subkey_1, subkey_2 = jax.random.split(key, 3)\n", + "env_state = jax.jit(env.reset)(rng=subkey_1)\n", + "eval_env_first_state = jax.jit(eval_env.reset)(rng=subkey_2)\n", "\n", "# Initialize buffer\n", "dummy_transition = QDTransition.init_dummy(\n", " observation_dim=env.observation_size + num_skills,\n", " action_dim=env.action_size,\n", - " descriptor_dim=env.behavior_descriptor_length,\n", + " descriptor_dim=env.descriptor_length,\n", ")\n", "replay_buffer = ReplayBuffer.init(\n", " buffer_size=buffer_size, transition=dummy_transition\n", @@ -214,7 +213,7 @@ "if descriptor_full_state:\n", " descriptor_size = env.observation_size\n", "else:\n", - " descriptor_size = env.behavior_descriptor_length\n", + " descriptor_size = env.descriptor_length\n", "\n", "# get the initial training state\n", "training_state = diayn.init(\n", @@ -352,8 +351,6 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", - "\n", "# Main loop\n", "(training_state, env_state, replay_buffer), metrics = jax.lax.scan(\n", " _scan_do_iteration,\n", @@ -470,10 +467,10 @@ "jit_env_step = jax.jit(visual_env.step)\n", "\n", "@jax.jit\n", - "def jit_inference_fn(params, observation, random_key):\n", + "def jit_inference_fn(params, observation, key):\n", " obs = jnp.concatenate([observation, skill], axis=0)\n", - " action, random_key = diayn.select_action(obs, params, random_key, deterministic=True)\n", - " return action, random_key" + " action = diayn.select_action(obs, params, key, deterministic=True)\n", + " return action" ] }, { @@ -490,11 +487,12 @@ "outputs": [], "source": [ "rollout = []\n", - "random_key = jax.random.key(seed=1)\n", - "state = jit_env_reset(rng=random_key)\n", + "key, subkey = jax.random.split(key)\n", + "state = jit_env_reset(rng=subkey)\n", "while not state.done:\n", " rollout.append(state)\n", - " action, random_key = jit_inference_fn(my_params, state.obs, random_key)\n", + " key, subkey = jax.random.split(key)\n", + " action = jit_inference_fn(my_params, state.obs, subkey)\n", " state = jit_env_step(state, action)\n", "\n", "print(f\"The trajectory of this individual contains {len(rollout)} transitions.\")" diff --git a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb index 517b55db..cb5d69f9 100644 --- a/examples/distributed_mapelites.ipynb +++ b/examples/distributed_mapelites.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Optimizing with MAP-Elites in Jax (multi-devices example)\n", + "# Optimizing with MAP-Elites in JAX (multi-devices example)\n", "\n", "This notebook shows how to use QDax to find diverse and performing controllers in MDPs 
with [MAP-Elites](https://arxiv.org/abs/1504.04909).\n", "It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", @@ -47,8 +47,7 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "!pip install ipympl |tail -n 1\n", + "!pip install ipympl | tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", "# output.enable_custom_widget_manager()\n", @@ -151,8 +150,8 @@ "line_sigma = 0.05 #@param {type:\"number\"}\n", "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", "num_centroids = 1024 #@param {type:\"integer\"}\n", - "min_bd = 0. #@param {type:\"number\"}\n", - "max_bd = 1.0 #@param {type:\"number\"}\n", + "min_descriptor = 0. #@param {type:\"number\"}\n", + "max_descriptor = 1.0 #@param {type:\"number\"}\n", "#@markdown ---" ] }, @@ -171,12 +170,12 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", "# Init environment\n", "env = environments.create(env_name, episode_length=episode_length)\n", + "reset_fn = jax.jit(env.reset)\n", "\n", "# Init a random key\n", - "random_key = jax.random.key(seed)\n", + "key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -187,17 +186,10 @@ ")\n", "\n", "# Init population of controllers (batch size controllers)\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jax.random.split(subkey, num=batch_size)\n", "fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))\n", - "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", - "\n", - "\n", - "# Create the initial environment states (batch_size_per_device environments)\n", - "random_key, subkey = jax.random.split(random_key)\n", - "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size_per_device, axis=0)\n", - "reset_fn = jax.jit(jax.vmap(env.reset))\n", - "init_states = reset_fn(keys)" + "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)" ] }, { @@ -217,9 +209,9 @@ "source": [ "# Define the function to play a step with the policy in the environment\n", "def play_step_fn(\n", - " env_state,\n", - " policy_params,\n", - " random_key,\n", + " env_state,\n", + " policy_params,\n", + " key,\n", "):\n", " \"\"\"\n", " Play an environment step and return the updated state and the transition.\n", @@ -241,7 +233,7 @@ " next_state_desc=next_state.info[\"state_descriptor\"],\n", " )\n", "\n", - " return next_state, policy_params, random_key, transition" + " return next_state, policy_params, key, transition" ] }, { @@ -250,7 +242,7 @@ "source": [ "## Define the scoring function and the way metrics are computed\n", "\n", - "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual." + "The scoring function is used in the evaluation step to determine the fitness and descriptor of each individual." 
] }, { @@ -260,13 +252,13 @@ "outputs": [], "source": [ "# Prepare the scoring function\n", - "bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]\n", + "descriptor_extraction_fn = environments.descriptor_extractor[env_name]\n", "scoring_fn = functools.partial(\n", " scoring_function,\n", - " init_states=init_states,\n", " episode_length=episode_length,\n", + " play_reset_fn=reset_fn,\n", " play_step_fn=play_step_fn,\n", - " behavior_descriptor_extractor=bd_extraction_fn,\n", + " descriptor_extractor=descriptor_extraction_fn,\n", ")\n", "\n", "# Get minimum reward value to make sure qd_score are positive\n", @@ -319,7 +311,6 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", "# Instantiate MAP-Elites\n", "map_elites = DistributedMAPElites(\n", " scoring_function=scoring_fn,\n", @@ -328,30 +319,31 @@ ")\n", "\n", "# Compute the centroids\n", - "centroids, random_key = compute_cvt_centroids(\n", - " num_descriptors=env.behavior_descriptor_length,\n", + "key, subkey = jax.random.split(key)\n", + "centroids = compute_cvt_centroids(\n", + " num_descriptors=env.descriptor_length,\n", " num_init_cvt_samples=num_init_cvt_samples,\n", " num_centroids=num_centroids,\n", - " minval=min_bd,\n", - " maxval=max_bd,\n", - " random_key=random_key,\n", + " minval=min_descriptor,\n", + " maxval=max_descriptor,\n", + " key=subkey,\n", ")\n", "\n", "# Compute initial repertoire and emitter state\n", - "random_key = jax.random.split(random_key, num=num_devices)\n", - "random_key = jnp.stack(random_key)\n", + "keys = jax.random.split(key, num=num_devices)\n", + "keys = jnp.stack(keys)\n", "\n", "# add a dimension for devices\n", - "init_variables = jax.tree_util.tree_map(\n", + "init_variables = jax.tree.map(\n", " lambda x: jnp.reshape(x, (num_devices, batch_size_per_device,) + x.shape[1:]),\n", " init_variables\n", ")\n", "\n", "# get initial elements\n", - "repertoire, emitter_state, random_key = map_elites.get_distributed_init_fn(\n", + "repertoire, emitter_state = map_elites.get_distributed_init_fn(\n", " centroids=centroids,\n", " devices=devices,\n", - ")(genotypes=init_variables, random_key=random_key)" + ")(genotypes=init_variables, key=keys)" ] }, { @@ -368,7 +360,7 @@ "outputs": [], "source": [ "log_period = 10\n", - "num_loops = int(num_iterations / log_period)\n", + "num_loops = num_iterations // log_period\n", "\n", "csv_logger = CSVLogger(\n", " \"mapelites-logs.csv\",\n", @@ -385,10 +377,10 @@ " start_time = time.time()\n", "\n", " # main iterations\n", - " repertoire, emitter_state, random_key, metrics = update_fn(repertoire, emitter_state, random_key)\n", + " repertoire, emitter_state, metrics = update_fn(repertoire, emitter_state, keys)\n", "\n", " # get metrics\n", - " metrics = jax.tree_util.tree_map(lambda x: x[0], metrics)\n", + " metrics = jax.tree.map(lambda x: x[0], metrics)\n", " timelapse = time.time() - start_time\n", "\n", " # log metrics\n", @@ -422,7 +414,7 @@ "outputs": [], "source": [ "# Get the repertoire from the first device\n", - "repertoire = jax.tree_util.tree_map(lambda x: x[0], repertoire)" + "repertoire = jax.tree.map(lambda x: x[0], repertoire)" ] } ], diff --git a/examples/jumanji_snake.ipynb b/examples/jumanji_snake.ipynb index 078c1c65..2d7a4250 100644 --- a/examples/jumanji_snake.ipynb +++ b/examples/jumanji_snake.ipynb @@ -113,8 +113,10 @@ "env = jumanji.make('Snake-v1')\n", "\n", "# Reset your (jit-able) environment\n", - "key = jax.random.key(0)\n", - "state, timestep = jax.jit(env.reset)(key)\n", + "key = jax.random.key(seed)\n", + 
"\n", + "key, subkey = jax.random.split(key)\n", + "state, timestep = jax.jit(env.reset)(subkey)\n", "\n", "# Interact with the (jit-able) environment\n", "action = env.action_spec().generate_value() # Action selection (dummy value here)\n", @@ -136,10 +138,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Init a random key\n", - "random_key = jax.random.key(seed)\n", - "\n", - "# get number of actions\n", + "# Get number of actions\n", "num_actions = env.action_spec().maximum + 1\n", "\n", "policy_layer_sizes = policy_hidden_layer_sizes + (num_actions,)\n", @@ -176,7 +175,7 @@ " env_state,\n", " timestep,\n", " policy_params,\n", - " random_key,\n", + " key,\n", "):\n", " \"\"\"Play an environment step and return the updated state and the transition.\n", " Everything is deterministic in this simple example.\n", @@ -205,7 +204,7 @@ " next_state_desc=next_state_desc,\n", " )\n", "\n", - " return next_state, next_timestep, policy_params, random_key, transition" + " return next_state, next_timestep, policy_params, key, transition" ] }, { @@ -226,7 +225,7 @@ "outputs": [], "source": [ "# Init population of controllers\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jax.random.split(subkey, num=batch_size)\n", "\n", "# compute observation size from observation spec\n", @@ -237,7 +236,7 @@ "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", "\n", "# Create the initial environment states\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)\n", "reset_fn = jax.jit(jax.vmap(env.reset))\n", "\n", @@ -249,7 +248,7 @@ "id": "13", "metadata": {}, "source": [ - "## Define a method to extract behavior descriptor when relevant" + "## Define a method to extract descriptor when relevant" ] }, { @@ -260,7 +259,7 @@ "outputs": [], "source": [ "# Prepare the scoring function\n", - "def bd_extraction(data: QDTransition, mask: jnp.ndarray, linear_projection: jnp.ndarray) -> Descriptor:\n", + "def descriptor_extraction(data: QDTransition, mask: jnp.ndarray, linear_projection: jnp.ndarray) -> Descriptor:\n", " \"\"\"Compute feet contact time proportion.\n", "\n", " This function suppose that state descriptor is the feet contact, as it\n", @@ -279,13 +278,13 @@ " return descriptors\n", "\n", "# create a random projection to a two dim space\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "linear_projection = jax.random.uniform(\n", " subkey, (2, observation_size), minval=-1, maxval=1, dtype=jnp.float32\n", ")\n", "\n", - "bd_extraction_fn = functools.partial(\n", - " bd_extraction,\n", + "descriptor_extraction_fn = functools.partial(\n", + " descriptor_extraction,\n", " linear_projection=linear_projection\n", ")\n", "\n", @@ -296,7 +295,7 @@ " init_timesteps=init_timesteps,\n", " episode_length=episode_length,\n", " play_step_fn=play_step_fn,\n", - " behavior_descriptor_extractor=bd_extraction_fn,\n", + " descriptor_extractor=descriptor_extraction_fn,\n", ")" ] }, @@ -316,10 +315,10 @@ "outputs": [], "source": [ "def scoring_function(\n", - " genotypes: jnp.ndarray, random_key: RNGKey\n", + " genotypes: jnp.ndarray, key: RNGKey\n", ") -> Tuple[Fitness, ExtraScores, RNGKey]:\n", - " fitnesses, _, extra_scores, random_key = scoring_fn(genotypes, random_key)\n", - " return fitnesses.reshape(-1, 1), extra_scores, random_key" + " fitnesses, _, 
extra_scores = scoring_fn(genotypes, key)\n", + " return fitnesses.reshape(-1, 1), extra_scores" ] }, { @@ -375,8 +374,9 @@ " metrics_function=default_ga_metrics,\n", " )\n", "\n", - " repertoire, emitter_state, random_key = algo_instance.init(\n", - " init_variables, population_size, random_key\n", + " key, subkey = jax.random.split(key)\n", + " repertoire, emitter_state = algo_instance.init(\n", + " init_variables, population_size, subkey\n", " )\n", "\n", "else:\n", @@ -401,7 +401,8 @@ " )\n", "\n", " # Compute initial repertoire and emitter state\n", - " repertoire, emitter_state, random_key = algo_instance.init(init_variables, centroids, random_key)" + " key, subkey = jax.random.split(key)\n", + " repertoire, emitter_state = algo_instance.init(init_variables, centroids, subkey)" ] }, { @@ -419,12 +420,10 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", - "\n", "# Run the algorithm\n", - "(repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + "(repertoire, emitter_state, key,), metrics = jax.lax.scan(\n", " algo_instance.scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=num_iterations,\n", ")" @@ -474,7 +473,7 @@ "\n", "# create the plots and the grid\n", "fig, axes = plot_map_elites_results(\n", - " env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_bd=-1., max_bd=1.\n", + " env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_descriptor=-1., max_descriptor=1.\n", ")" ] }, @@ -519,7 +518,7 @@ "metadata": {}, "outputs": [], "source": [ - "my_params = jax.tree_util.tree_map(\n", + "my_params = jax.tree.map(\n", " lambda x: x[best_idx],\n", " repertoire.genotypes\n", ")" @@ -532,7 +531,7 @@ "metadata": {}, "outputs": [], "source": [ - "init_state = jax.tree_util.tree_map(\n", + "init_state = jax.tree.map(\n", " lambda x: x[0],\n", " init_states\n", ")" @@ -545,7 +544,7 @@ "metadata": {}, "outputs": [], "source": [ - "init_timestep = jax.tree_util.tree_map(\n", + "init_timestep = jax.tree.map(\n", " lambda x: x[0],\n", " init_timesteps\n", ")" @@ -558,8 +557,8 @@ "metadata": {}, "outputs": [], "source": [ - "state = jax.tree_util.tree_map(lambda x: x.copy(), init_state)\n", - "timestep = jax.tree_util.tree_map(lambda x: x.copy(), init_timestep)\n", + "state = jax.tree.map(lambda x: x.copy(), init_state)\n", + "timestep = jax.tree.map(lambda x: x.copy(), init_timestep)\n", "\n", "for _ in range(100):\n", " # (Optional) Render the env state\n", @@ -592,7 +591,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/mapelites.ipynb b/examples/mapelites.ipynb index b6b4652a..602b2801 100644 --- a/examples/mapelites.ipynb +++ b/examples/mapelites.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Optimizing with MAP-Elites in Jax\n", + "# Optimizing with MAP-Elites in JAX\n", "\n", "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [MAP-Elites](https://arxiv.org/abs/1504.04909).\n", "It can be run locally or on Google Colab. We recommend to use a GPU. 
This notebook will show:\n", @@ -49,8 +49,7 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "!pip install ipympl |tail -n 1\n", + "!pip install ipympl | tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", "# output.enable_custom_widget_manager()\n", @@ -82,12 +81,10 @@ "from brax.v1.io import html\n", "\n", "\n", - "\n", "if \"COLAB_TPU_ADDR\" in os.environ:\n", " from jax.tools import colab_tpu\n", " colab_tpu.setup_tpu()\n", "\n", - "\n", "clear_output()" ] }, @@ -109,8 +106,8 @@ "line_sigma = 0.05 #@param {type:\"number\"}\n", "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", "num_centroids = 1024 #@param {type:\"integer\"}\n", - "min_bd = 0. #@param {type:\"number\"}\n", - "max_bd = 1.0 #@param {type:\"number\"}\n", + "min_descriptor = 0. #@param {type:\"number\"}\n", + "max_descriptor = 1.0 #@param {type:\"number\"}\n", "#@markdown ---" ] }, @@ -131,9 +128,10 @@ "source": [ "# Init environment\n", "env = environments.create(env_name, episode_length=episode_length)\n", + "reset_fn = jax.jit(env.reset)\n", "\n", "# Init a random key\n", - "random_key = jax.Key(seed)\n", + "key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -144,17 +142,10 @@ ")\n", "\n", "# Init population of controllers\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jax.random.split(subkey, num=batch_size)\n", "fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))\n", - "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", - "\n", - "\n", - "# Create the initial environment states\n", - "random_key, subkey = jax.random.split(random_key)\n", - "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)\n", - "reset_fn = jax.jit(jax.vmap(env.reset))\n", - "init_states = reset_fn(keys)" + "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)" ] }, { @@ -174,9 +165,9 @@ "source": [ "# Define the function to play a step with the policy in the environment\n", "def play_step_fn(\n", - " env_state,\n", - " policy_params,\n", - " random_key,\n", + " env_state,\n", + " policy_params,\n", + " key,\n", "):\n", " \"\"\"\n", " Play an environment step and return the updated state and the transition.\n", @@ -198,7 +189,7 @@ " next_state_desc=next_state.info[\"state_descriptor\"],\n", " )\n", "\n", - " return next_state, policy_params, random_key, transition" + " return next_state, policy_params, key, transition" ] }, { @@ -207,7 +198,7 @@ "source": [ "## Define the scoring function and the way metrics are computed\n", "\n", - "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual." + "The scoring function is used in the evaluation step to determine the fitness and descriptor of each individual." 
] }, { @@ -217,13 +208,13 @@ "outputs": [], "source": [ "# Prepare the scoring function\n", - "bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]\n", + "descriptor_extraction_fn = environments.descriptor_extractor[env_name]\n", "scoring_fn = functools.partial(\n", " scoring_function,\n", - " init_states=init_states,\n", " episode_length=episode_length,\n", + " play_reset_fn=reset_fn,\n", " play_step_fn=play_step_fn,\n", - " behavior_descriptor_extractor=bd_extraction_fn,\n", + " descriptor_extractor=descriptor_extraction_fn,\n", ")\n", "\n", "# Get minimum reward value to make sure qd_score are positive\n", @@ -255,6 +246,7 @@ "variation_fn = functools.partial(\n", " isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma\n", ")\n", + "\n", "mixing_emitter = MixingEmitter(\n", " mutation_fn=None,\n", " variation_fn=variation_fn,\n", @@ -284,17 +276,19 @@ ")\n", "\n", "# Compute the centroids\n", - "centroids, random_key = compute_cvt_centroids(\n", - " num_descriptors=env.behavior_descriptor_length,\n", + "key, subkey = jax.random.split(key)\n", + "centroids = compute_cvt_centroids(\n", + " num_descriptors=env.descriptor_length,\n", " num_init_cvt_samples=num_init_cvt_samples,\n", " num_centroids=num_centroids,\n", - " minval=min_bd,\n", - " maxval=max_bd,\n", - " random_key=random_key,\n", + " minval=min_descriptor,\n", + " maxval=max_descriptor,\n", + " key=subkey,\n", ")\n", "\n", "# Compute initial repertoire and emitter state\n", - "repertoire, emitter_state, random_key = map_elites.init(init_variables, centroids, random_key)" + "key, subkey = jax.random.split(key)\n", + "repertoire, emitter_state = map_elites.init(init_variables, centroids, subkey)" ] }, { @@ -311,40 +305,37 @@ "outputs": [], "source": [ "log_period = 10\n", - "num_loops = int(num_iterations / log_period)\n", + "num_loops = num_iterations // log_period\n", "\n", + "metrics = dict.fromkeys([\"iteration\", \"qd_score\", \"coverage\", \"max_fitness\", \"time\"], jnp.array([]))\n", "csv_logger = CSVLogger(\n", " \"mapelites-logs.csv\",\n", - " header=[\"loop\", \"iteration\", \"qd_score\", \"max_fitness\", \"coverage\", \"time\"]\n", + " header=list(metrics.keys())\n", ")\n", - "all_metrics = {}\n", "\n", - "# main loop\n", + "# Main loop\n", "map_elites_scan_update = map_elites.scan_update\n", "for i in range(num_loops):\n", " start_time = time.time()\n", - " # main iterations\n", - " (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + " (\n", + " repertoire,\n", + " emitter_state,\n", + " key,\n", + " ), current_metrics = jax.lax.scan(\n", " map_elites_scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=log_period,\n", " )\n", " timelapse = time.time() - start_time\n", "\n", - " # log metrics\n", - " logged_metrics = {\"time\": timelapse, \"loop\": 1+i, \"iteration\": 1 + i*log_period}\n", - " for key, value in metrics.items():\n", - " # take last value\n", - " logged_metrics[key] = value[-1]\n", - "\n", - " # take all values\n", - " if key in all_metrics.keys():\n", - " all_metrics[key] = jnp.concatenate([all_metrics[key], value])\n", - " else:\n", - " all_metrics[key] = value\n", + " # Metrics\n", + " current_metrics[\"iteration\"] = jnp.arange(1+log_period*i, 1+log_period*(i+1), dtype=jnp.int32)\n", + " current_metrics[\"time\"] = jnp.repeat(timelapse, log_period)\n", + " metrics = jax.tree.map(lambda metric, current_metric: jnp.concatenate([metric, current_metric], axis=0), metrics, 
current_metrics)\n", "\n", - " csv_logger.log(logged_metrics)" + " # Log\n", + " csv_logger.log(jax.tree.map(lambda x: x[-1], metrics))" ] }, { @@ -355,11 +346,12 @@ "source": [ "#@title Visualization\n", "\n", - "# create the x-axis array\n", - "env_steps = jnp.arange(num_iterations) * episode_length * batch_size\n", + "# Create the x-axis array\n", + "env_steps = metrics[\"iteration\"]\n", "\n", - "# create the plots and the grid\n", - "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=all_metrics, repertoire=repertoire, min_bd=min_bd, max_bd=max_bd)" + "%matplotlib inline\n", + "# Create the plots and the grid\n", + "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_descriptor=min_descriptor, max_descriptor=max_descriptor)" ] }, { @@ -403,7 +395,7 @@ "outputs": [], "source": [ "# Init population of policies\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "fake_batch = jnp.zeros(shape=(env.observation_size,))\n", "fake_params = policy_network.init(subkey, fake_batch)\n", "\n", @@ -441,7 +433,7 @@ "source": [ "best_idx = jnp.argmax(repertoire.fitnesses)\n", "best_fitness = jnp.max(repertoire.fitnesses)\n", - "best_bd = repertoire.descriptors[best_idx]" + "best_descriptor = repertoire.descriptors[best_idx]" ] }, { @@ -452,7 +444,7 @@ "source": [ "print(\n", " f\"Best fitness in the repertoire: {best_fitness:.2f}\\n\",\n", - " f\"Behavior descriptor of the best individual in the repertoire: {best_bd}\\n\",\n", + " f\"Descriptor of the best individual in the repertoire: {best_descriptor}\\n\",\n", " f\"Index in the repertoire of this individual: {best_idx}\\n\"\n", ")" ] @@ -463,7 +455,7 @@ "metadata": {}, "outputs": [], "source": [ - "my_params = jax.tree_util.tree_map(\n", + "my_params = jax.tree.map(\n", " lambda x: x[best_idx],\n", " repertoire.genotypes\n", ")" @@ -494,8 +486,8 @@ "outputs": [], "source": [ "rollout = []\n", - "rng = jax.random.key(seed=1)\n", - "state = jit_env_reset(rng=rng)\n", + "key, subkey = jax.random.split(key)\n", + "state = jit_env_reset(rng=subkey)\n", "while not state.done:\n", " rollout.append(state)\n", " action = jit_inference_fn(my_params, state.obs)\n", @@ -533,7 +525,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index 42c46188..00c7316a 100644 --- a/examples/me_sac_pbt.ipynb +++ b/examples/me_sac_pbt.ipynb @@ -117,7 +117,6 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", "# Initialize environments\n", "env = environments.create(\n", " env_name=env_name,\n", @@ -140,7 +139,7 @@ "metadata": {}, "outputs": [], "source": [ - "min_bd, max_bd = env.behavior_descriptor_limits" + "min_descriptor, max_descriptor = env.descriptor_limits" ] }, { @@ -149,11 +148,11 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", "key = jax.random.key(seed)\n", - "key, subkey = jax.random.split(key)\n", - "env_states = jax.jit(env.reset)(rng=subkey)\n", - "eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey)" + "\n", + "key, subkey_1, subkey_2 = jax.random.split(key, 3)\n", + "env_states = jax.jit(env.reset)(rng=subkey_1)\n", + "eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey_2)" ] }, { @@ -231,20 +230,20 @@ "outputs": [], "source": [ "# get scoring function\n", - "bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]\n", - "eval_policy 
= agent.get_eval_qd_fn(eval_env, bd_extraction_fn=bd_extraction_fn)\n", + "descriptor_extraction_fn = environments.descriptor_extractor[env_name]\n", + "eval_policy = agent.get_eval_qd_fn(eval_env, descriptor_extraction_fn=descriptor_extraction_fn)\n", "\n", "\n", - "def scoring_function(genotypes, random_key):\n", - " population_size = jax.tree_util.tree_leaves(genotypes)[0].shape[0]\n", + "def scoring_function(genotypes, key):\n", + " population_size = jax.tree.leaves(genotypes)[0].shape[0]\n", " first_states = jax.tree_map(\n", " lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states\n", " )\n", " first_states = jax.tree_map(\n", " lambda x: jnp.repeat(x, population_size, axis=0), first_states\n", " )\n", - " population_returns, population_bds, _, _ = eval_policy(genotypes, first_states)\n", - " return population_returns, population_bds, {}, random_key" + " population_returns, population_descriptors, _, _ = eval_policy(genotypes, first_states)\n", + " return population_returns, population_descriptors, {}" ] }, { @@ -276,14 +275,14 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", - "centroids, key = compute_cvt_centroids(\n", - " num_descriptors=env.behavior_descriptor_length,\n", + "key, subkey = jax.random.split(key)\n", + "centroids = compute_cvt_centroids(\n", + " num_descriptors=env.descriptor_length,\n", " num_init_cvt_samples=num_init_cvt_samples,\n", " num_centroids=num_centroids,\n", - " minval=min_bd,\n", - " maxval=max_bd,\n", - " random_key=key,\n", + " minval=min_descriptor,\n", + " maxval=max_descriptor,\n", + " key=subkey,\n", ")" ] }, @@ -303,7 +302,6 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", "# get the initial training states and replay buffers\n", "agent_init_fn = agent.get_init_fn(\n", " population_size=pg_population_size_per_device + ga_population_size_per_device,\n", @@ -315,7 +313,7 @@ "# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n", "keys = jax.random.key_data(keys)\n", "\n", - "keys, training_states, _ = jax.pmap(agent_init_fn, axis_name=\"p\", devices=devices)(keys)" + "training_states, _ = jax.pmap(agent_init_fn, axis_name=\"p\", devices=devices)(keys)" ] }, { @@ -324,7 +322,6 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", "# empty optimizers states to avoid too heavy repertories\n", "training_states = jax.pmap(\n", " jax.vmap(training_states.__class__.empty_optimizers_states),\n", @@ -333,9 +330,9 @@ ")(training_states)\n", "\n", "# initialize map-elites\n", - "repertoire, emitter_state, keys = map_elites.get_distributed_init_fn(\n", + "repertoire, emitter_state = map_elites.get_distributed_init_fn(\n", " devices=devices, centroids=centroids\n", - ")(genotypes=training_states, random_key=keys)" + ")(genotypes=training_states, key=keys)" ] }, { @@ -369,13 +366,12 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", "all_metrics = {}\n", "\n", "for i in tqdm(range(num_loops // log_period), total=num_loops // log_period):\n", " start_time = time.time()\n", "\n", - " repertoire, emitter_state, keys, metrics = update_fn(\n", + " repertoire, emitter_state, metrics = update_fn(\n", " repertoire, emitter_state, keys\n", " )\n", " metrics_cpu = jax.tree_map(\n", @@ -384,12 +380,12 @@ " timelapse = time.time() - start_time\n", "\n", " # log metrics\n", - " for key, value in metrics_cpu.items():\n", + " for k, v in metrics_cpu.items():\n", " # take all values\n", - " if key in all_metrics.keys():\n", - " all_metrics[key] = jnp.concatenate([all_metrics[key], value])\n", + " if k in 
all_metrics.keys():\n", + " all_metrics[k] = jnp.concatenate([all_metrics[k], v])\n", " else:\n", - " all_metrics[key] = value" + " all_metrics[k] = v" ] }, { @@ -410,8 +406,8 @@ " env_steps=env_steps,\n", " metrics=all_metrics,\n", " repertoire=repertoire_cpu,\n", - " min_bd=min_bd,\n", - " max_bd=max_bd,\n", + " min_descriptor=min_descriptor,\n", + " max_descriptor=max_descriptor,\n", ")" ] }, @@ -431,7 +427,7 @@ "# Evaluate best individual of the repertoire\n", "best_idx = jnp.argmax(repertoire_cpu.fitnesses)\n", "best_fitness = jnp.max(repertoire_cpu.fitnesses)\n", - "best_bd = repertoire_cpu.descriptors[best_idx]" + "best_descriptor = repertoire_cpu.descriptors[best_idx]" ] }, { @@ -452,7 +448,7 @@ "# Evaluate agent that goes the further on the y-axis\n", "# best_idx = jnp.argmax(repertoire.descriptors[:, 0])\n", "# best_fitness = repertoire.fitnesses[best_idx]\n", - "# best_bd = repertoire.descriptors[best_idx]" + "# best_descriptor = repertoire.descriptors[best_idx]" ] }, { @@ -463,7 +459,7 @@ "source": [ "print(\n", " f\"Fitness of the selected agent: {best_fitness:.2f}\\n\",\n", - " f\"Behavior descriptor of the selected agent: {best_bd}\\n\",\n", + " f\"Descriptor of the selected agent: {best_descriptor}\\n\",\n", " f\"Index in the repertoire of this individual: {best_idx}\\n\",\n", ")" ] @@ -496,7 +492,7 @@ "metadata": {}, "outputs": [], "source": [ - "training_state = jax.tree_util.tree_map(lambda x: x[best_idx], repertoire_cpu.genotypes)" + "training_state = jax.tree.map(lambda x: x[best_idx], repertoire_cpu.genotypes)" ] }, { @@ -505,11 +501,10 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", "rollout = []\n", "\n", - "rng = jax.random.key(seed=1)\n", - "env_state = jax.jit(env.reset)(rng=rng)\n", + "key, subkey = jax.random.split(key)\n", + "env_state = jax.jit(env.reset)(rng=subkey)\n", "\n", "training_state, env_state = jax.tree_map(\n", " lambda x: jnp.expand_dims(x, axis=0), (training_state, env_state)\n", @@ -543,13 +538,6 @@ "source": [ "HTML(html.render(env.sys, [s.qp for s in rollout[:episode_length]]))" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -568,7 +556,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.2" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index 8caca62f..0fa59cde 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -122,7 +122,6 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", "# Initialize environments\n", "env = environments.create(\n", " env_name=env_name,\n", @@ -145,7 +144,7 @@ "metadata": {}, "outputs": [], "source": [ - "min_bd, max_bd = env.behavior_descriptor_limits" + "min_descriptor, max_descriptor = env.descriptor_limits" ] }, { @@ -154,11 +153,10 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", "key = jax.random.key(seed)\n", - "key, subkey = jax.random.split(key)\n", - "env_states = jax.jit(env.reset)(rng=subkey)\n", - "eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey)" + "key, subkey_1, subkey_2 = jax.random.split(key, 3)\n", + "env_states = jax.jit(env.reset)(rng=subkey_1)\n", + "eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey_2)" ] }, { @@ -234,11 +232,11 @@ "outputs": [], "source": [ "# get scoring function\n", - "bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]\n", - "eval_policy = agent.get_eval_qd_fn(eval_env, 
bd_extraction_fn=bd_extraction_fn)\n", + "descriptor_extraction_fn = environments.descriptor_extractor[env_name]\n", + "eval_policy = agent.get_eval_qd_fn(eval_env, descriptor_extraction_fn=descriptor_extraction_fn)\n", "\n", "\n", - "def scoring_function(genotypes, random_key):\n", + "def scoring_function(genotypes, key):\n", " population_size = jax.tree_leaves(genotypes)[0].shape[0]\n", " first_states = jax.tree_map(\n", " lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states\n", @@ -246,8 +244,8 @@ " first_states = jax.tree_map(\n", " lambda x: jnp.repeat(x, population_size, axis=0), first_states\n", " )\n", - " population_returns, population_bds, _, _ = eval_policy(genotypes, first_states)\n", - " return population_returns, population_bds, {}, random_key" + " population_returns, population_descriptors, _, _ = eval_policy(genotypes, first_states)\n", + " return population_returns, population_descriptors, {}" ] }, { @@ -279,14 +277,14 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", - "centroids, key = compute_cvt_centroids(\n", - " num_descriptors=env.behavior_descriptor_length,\n", + "key, subkey = jax.random.split(key)\n", + "centroids = compute_cvt_centroids(\n", + " num_descriptors=env.descriptor_length,\n", " num_init_cvt_samples=num_init_cvt_samples,\n", " num_centroids=num_centroids,\n", - " minval=min_bd,\n", - " maxval=max_bd,\n", - " random_key=key,\n", + " minval=min_descriptor,\n", + " maxval=max_descriptor,\n", + " key=subkey,\n", ")" ] }, @@ -306,7 +304,6 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", "# get the initial training states and replay buffers\n", "agent_init_fn = agent.get_init_fn(\n", " population_size=pg_population_size_per_device + ga_population_size_per_device,\n", @@ -318,7 +315,7 @@ "# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n", "keys = jax.random.key_data(keys)\n", "\n", - "keys, training_states, _ = jax.pmap(agent_init_fn, axis_name=\"p\", devices=devices)(keys)" + "training_states, _ = jax.pmap(agent_init_fn, axis_name=\"p\", devices=devices)(keys)" ] }, { @@ -327,7 +324,6 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", "# empty optimizers states to avoid too heavy repertories\n", "training_states = jax.pmap(\n", " jax.vmap(training_states.__class__.empty_optimizers_states),\n", @@ -336,9 +332,9 @@ ")(training_states)\n", "\n", "# initialize map-elites\n", - "repertoire, emitter_state, keys = map_elites.get_distributed_init_fn(\n", + "repertoire, emitter_state = map_elites.get_distributed_init_fn(\n", " devices=devices, centroids=centroids\n", - ")(genotypes=training_states, random_key=keys)" + ")(genotypes=training_states, key=keys)" ] }, { @@ -372,26 +368,27 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", "all_metrics = {}\n", "repertoires = []\n", "\n", "for i in tqdm(range(num_iterations // log_period), total=num_iterations // log_period):\n", " start_time = time.time()\n", "\n", - " repertoire, emitter_state, keys, metrics = update_fn(\n", + " key, *keys = jax.random.split(key, num=1 + num_devices)\n", + " keys = jnp.stack(keys)\n", + " repertoire, emitter_state, metrics = update_fn(\n", " repertoire, emitter_state, keys\n", " )\n", " metrics_cpu = jax.tree_map(lambda x: jax.device_get(x)[0], metrics)\n", " timelapse = time.time() - start_time\n", "\n", " # log metrics\n", - " for key, value in metrics_cpu.items():\n", + " for k, v in metrics_cpu.items():\n", " # take all values\n", - " if key in all_metrics.keys():\n", - " all_metrics[key] = 
jnp.concatenate([all_metrics[key], value])\n", + " if k in all_metrics.keys():\n", + " all_metrics[k] = jnp.concatenate([all_metrics[k], v])\n", " else:\n", - " all_metrics[key] = value\n", + " all_metrics[k] = v\n", "\n", " if i % save_repertoire_freq == 0:\n", " repertoires.append(jax.tree_map(lambda x: jax.device_get(x)[0], repertoire))" @@ -410,8 +407,8 @@ " env_steps=env_steps,\n", " metrics=all_metrics,\n", " repertoire=repertoires[-1],\n", - " min_bd=min_bd,\n", - " max_bd=max_bd,\n", + " min_descriptor=min_descriptor,\n", + " max_descriptor=max_descriptor,\n", ")" ] }, @@ -421,7 +418,6 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", "import math\n", "\n", "import matplotlib.pyplot as plt\n", @@ -445,19 +441,12 @@ " repertoire.genotypes.expl_noise,\n", " -jnp.inf * jnp.ones_like(repertoire.fitnesses),\n", " ),\n", - " minval=min_bd,\n", - " maxval=max_bd,\n", + " minval=min_descriptor,\n", + " maxval=max_descriptor,\n", " ax=axes[row_i, col_i],\n", " )\n", " axes[row_i, col_i].set_title(f\"Grid after {env_step_multiplier * i} steps\")" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/examples/mees.ipynb b/examples/mees.ipynb index 765a5986..39f9b1a8 100644 --- a/examples/mees.ipynb +++ b/examples/mees.ipynb @@ -15,7 +15,7 @@ "id": "b4mIajPjoQvB" }, "source": [ - "# Optimizing with MEES in Jax\n", + "# Optimizing with MEES in JAX\n", "\n", "This notebook shows how to use QDax to find diverse and performing controllers with MAP-Elites-ES introduced in [Scaling MAP-Elites to Deep Neuroevolution](https://dl.acm.org/doi/pdf/10.1145/3377930.3390217).\n", "It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", @@ -54,8 +54,7 @@ }, "outputs": [], "source": [ - "\n", - "!pip install ipympl |tail -n 1\n", + "!pip install ipympl | tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", "# output.enable_custom_widget_manager()\n", @@ -72,7 +71,7 @@ "from qdax.core.map_elites import MAPElites\n", "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n", "from qdax import environments\n", - "from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs\n", + "from qdax.tasks.brax_envs import scoring_function_brax_envs\n", "from qdax.core.neuroevolution.buffers.buffer import QDTransition\n", "from qdax.core.neuroevolution.networks.networks import MLP\n", "from qdax.core.emitters.mutation_operators import isoline_variation\n", @@ -107,8 +106,8 @@ "policy_hidden_layer_sizes = (64, 64) #@param {type:\"raw\"}\n", "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", "num_centroids = 1024 #@param {type:\"integer\"}\n", - "min_bd = 0. #@param {type:\"number\"}\n", - "max_bd = 1.0 #@param {type:\"number\"}\n", + "min_descriptor = 0. 
#@param {type:\"number\"}\n", + "max_descriptor = 1.0 #@param {type:\"number\"}\n", "\n", "#@title MEES Emitter Definitions Fields\n", "sample_number = 1000 #@param {type:\"integer\"}\n", @@ -151,7 +150,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.key(seed)\n", + "key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -162,13 +161,13 @@ ")\n", "\n", "# Init population of controllers\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "fake_batch = jnp.zeros(shape=(1, env.observation_size))\n", "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=1, axis=0)\n", "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", "\n", "# Create the initial environment state\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "init_state = env.reset(subkey)" ] }, @@ -191,9 +190,9 @@ "source": [ "# Define the function to play a step with the policy in the environment\n", "def play_step_fn(\n", - " env_state,\n", - " policy_params,\n", - " random_key,\n", + " env_state,\n", + " policy_params,\n", + " key,\n", "):\n", " \"\"\"\n", " Play an environment step and return the updated state and the transition.\n", @@ -215,7 +214,7 @@ " next_state_desc=next_state.info[\"state_descriptor\"],\n", " )\n", "\n", - " return next_state, policy_params, random_key, transition" + " return next_state, policy_params, key, transition" ] }, { @@ -226,7 +225,7 @@ "source": [ "## Define the scoring function and the way metrics are computed\n", "\n", - "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual." + "The scoring function is used in the evaluation step to determine the fitness and descriptor of each individual." 
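For reference, a minimal sketch of the scoring-function interface these notebooks assume after this change: it takes `(genotypes, key)` and returns `(fitnesses, descriptors, extra_scores)` without threading the key back. The toy problem below is purely illustrative and not part of the notebook.

```python
import jax.numpy as jnp

def toy_scoring_fn(genotypes, key):
    # Fitness: negative squared norm of each genotype (to be maximised).
    fitnesses = -jnp.sum(genotypes ** 2, axis=-1)
    # Descriptor: the first two genotype dimensions.
    descriptors = genotypes[:, :2]
    # No extra scores, and the key is no longer returned.
    return fitnesses, descriptors, {}
```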
] }, { @@ -238,13 +237,13 @@ "outputs": [], "source": [ "# Prepare the scoring function\n", - "bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]\n", + "descriptor_extraction_fn = environments.descriptor_extractor[env_name]\n", "scoring_fn = functools.partial(\n", - " reset_based_scoring_function_brax_envs,\n", + " scoring_function_brax_envs,\n", " episode_length=episode_length,\n", - " play_reset_fn=lambda random_key: init_state,\n", + " play_reset_fn=lambda key: init_state,\n", " play_step_fn=play_step_fn,\n", - " behavior_descriptor_extractor=bd_extraction_fn,\n", + " descriptor_extractor=descriptor_extraction_fn,\n", ")\n", "\n", "# Prepare the scoring functions for the offspring generated following\n", @@ -316,7 +315,7 @@ " config=mees_emitter_config,\n", " total_generations=num_iterations,\n", " scoring_fn=scoring_fn,\n", - " num_descriptors=env.behavior_descriptor_length,\n", + " num_descriptors=env.descriptor_length,\n", ")" ] }, @@ -349,18 +348,20 @@ ")\n", "\n", "# Compute the centroids\n", - "centroids, random_key = compute_cvt_centroids(\n", - " num_descriptors=env.behavior_descriptor_length,\n", + "key, subkey = jax.random.split(key)\n", + "centroids = compute_cvt_centroids(\n", + " num_descriptors=env.descriptor_length,\n", " num_init_cvt_samples=num_init_cvt_samples,\n", " num_centroids=num_centroids,\n", - " minval=min_bd,\n", - " maxval=max_bd,\n", - " random_key=random_key,\n", + " minval=min_descriptor,\n", + " maxval=max_descriptor,\n", + " key=subkey,\n", ")\n", "\n", "# compute initial repertoire\n", - "repertoire, emitter_state, random_key = map_elites.init(\n", - " init_variables, centroids, random_key\n", + "key, subkey = jax.random.split(key)\n", + "repertoire, emitter_state = map_elites.init(\n", + " init_variables, centroids, subkey\n", ")" ] }, @@ -378,39 +379,37 @@ "outputs": [], "source": [ "log_period = 10\n", - "num_loops = int(num_iterations / log_period)\n", + "num_loops = num_iterations // log_period\n", "\n", + "metrics = dict.fromkeys([\"iteration\", \"qd_score\", \"coverage\", \"max_fitness\", \"time\"], jnp.array([]))\n", "csv_logger = CSVLogger(\n", " \"mees-logs.csv\",\n", - " header=[\"loop\", \"iteration\", \"qd_score\", \"max_fitness\", \"coverage\", \"time\"]\n", + " header=list(metrics.keys())\n", ")\n", - "all_metrics = {}\n", "\n", - "# main loop\n", + "# Main loop\n", + "map_elites_scan_update = map_elites.scan_update\n", "for i in range(num_loops):\n", " start_time = time.time()\n", - " # main iterations\n", - " (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", - " map_elites.scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (\n", + " repertoire,\n", + " emitter_state,\n", + " key,\n", + " ), current_metrics = jax.lax.scan(\n", + " map_elites_scan_update,\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=log_period,\n", " )\n", " timelapse = time.time() - start_time\n", "\n", - " # log metrics\n", - " logged_metrics = {\"time\": timelapse, \"loop\": 1+i, \"iteration\": 1 + i*log_period}\n", - " for key, value in metrics.items():\n", - " # take last value\n", - " logged_metrics[key] = value[-1]\n", - "\n", - " # take all values\n", - " if key in all_metrics.keys():\n", - " all_metrics[key] = jnp.concatenate([all_metrics[key], value])\n", - " else:\n", - " all_metrics[key] = value\n", + " # Metrics\n", + " current_metrics[\"iteration\"] = jnp.arange(1+log_period*i, 1+log_period*(i+1), dtype=jnp.int32)\n", + " current_metrics[\"time\"] = jnp.repeat(timelapse, 
log_period)\n", + " metrics = jax.tree.map(lambda metric, current_metric: jnp.concatenate([metric, current_metric], axis=0), metrics, current_metrics)\n", "\n", - " csv_logger.log(logged_metrics)" + " # Log\n", + " csv_logger.log(jax.tree.map(lambda x: x[-1], metrics))" ] }, { @@ -423,11 +422,12 @@ "source": [ "#@title Visualization\n", "\n", - "# create the x-axis array\n", - "env_steps = jnp.arange(num_iterations) * episode_length\n", + "# Create the x-axis array\n", + "env_steps = metrics[\"iteration\"]\n", "\n", - "# create the plots and the grid\n", - "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=all_metrics, repertoire=repertoire, min_bd=min_bd, max_bd=max_bd)" + "%matplotlib inline\n", + "# Create the plots and the grid\n", + "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_descriptor=min_descriptor, max_descriptor=max_descriptor)" ] } ], @@ -455,7 +455,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.10" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/mels.ipynb b/examples/mels.ipynb index ed5a7c7a..f9b93d57 100644 --- a/examples/mels.ipynb +++ b/examples/mels.ipynb @@ -50,8 +50,7 @@ }, "outputs": [], "source": [ - "\n", - "!pip install ipympl |tail -n 1\n", + "!pip install ipympl | tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", "# output.enable_custom_widget_manager()\n", @@ -69,7 +68,7 @@ "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n", "from qdax.core.containers.mels_repertoire import MELSRepertoire\n", "from qdax import environments\n", - "from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs\n", + "from qdax.tasks.brax_envs import scoring_function_brax_envs\n", "from qdax.core.neuroevolution.buffers.buffer import QDTransition\n", "from qdax.core.neuroevolution.networks.networks import MLP\n", "from qdax.core.emitters.mutation_operators import isoline_variation\n", @@ -114,8 +113,8 @@ "line_sigma = 0.05 #@param {type:\"number\"}\n", "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", "num_centroids = 1024 #@param {type:\"integer\"}\n", - "min_bd = 0. #@param {type:\"number\"}\n", - "max_bd = 1.0 #@param {type:\"number\"}\n", + "min_descriptor = 0. #@param {type:\"number\"}\n", + "max_descriptor = 1.0 #@param {type:\"number\"}\n", "#@markdown ---" ] }, @@ -140,7 +139,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.key(seed)\n", + "key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -152,7 +151,7 @@ "\n", "# Init population of controllers. 
There are batch_size controllers, and each\n", "# controller will be evaluated num_samples times.\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jax.random.split(subkey, num=batch_size)\n", "fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))\n", "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)" @@ -179,7 +178,7 @@ "def play_step_fn(\n", "    env_state,\n", "    policy_params,\n", - "    random_key,\n", + "    key,\n", "):\n", "    \"\"\"Play an environment step and return the updated state and the\n", "    transition.\"\"\"\n", @@ -200,7 +199,7 @@ "        next_state_desc=next_state.info[\"state_descriptor\"],\n", "    )\n", "\n", - "    return next_state, policy_params, random_key, transition" + "    return next_state, policy_params, key, transition" ] }, { @@ -209,7 +208,7 @@ "source": [ "## Define the scoring function and the way metrics are computed\n", "\n", - "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual. Note that while the MAP-Elites tutorial uses `scoring_function_brax_envs` as the basis for the scoring function, we use `reset_based_scoring_function_brax_envs`. The difference is that `reset_based_scoring_function_brax_envs` generates initial states randomly instead of taking in a fixed set of initial states. This is necessary since we are evaluating each controller across sampled initial states. If the initial states were kept the same for all evaluations, there would be no stochasticity in the behavior." + "The scoring function is used in the evaluation step to determine the fitness and descriptor of each individual. Note that the scoring function here is built from `scoring_function_brax_envs` with `play_reset_fn=env.reset`, so initial states are generated randomly instead of being taken from a fixed set. This is necessary since we are evaluating each controller across sampled initial states. If the initial states were kept the same for all evaluations, there would be no stochasticity."
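The stochastic evaluation described above comes entirely from the choice of `play_reset_fn`. A short sketch of the two options, assuming `env` is a Brax-style environment with a `reset(key)` method and `init_state` is a previously computed state:

```python
fixed_reset_fn = lambda key: init_state  # every evaluation starts from the same state
random_reset_fn = env.reset              # a fresh initial state is sampled from the key
```

Only the second option produces the evaluation stochasticity that ME-LS is designed to handle.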
] }, { @@ -221,13 +220,13 @@ "outputs": [], "source": [ "# Prepare the scoring function\n", - "bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]\n", + "descriptor_extraction_fn = environments.descriptor_extractor[env_name]\n", "scoring_fn = functools.partial(\n", - " reset_based_scoring_function_brax_envs,\n", + " scoring_function_brax_envs,\n", " episode_length=episode_length,\n", " play_reset_fn=env.reset,\n", " play_step_fn=play_step_fn,\n", - " behavior_descriptor_extractor=bd_extraction_fn,\n", + " descriptor_extractor=descriptor_extraction_fn,\n", ")\n", "\n", "# Get minimum reward value to make sure qd_score are positive\n", @@ -293,17 +292,19 @@ ")\n", "\n", "# Compute the centroids\n", - "centroids, random_key = compute_cvt_centroids(\n", - " num_descriptors=env.behavior_descriptor_length,\n", + "key, subkey = jax.random.split(key)\n", + "centroids = compute_cvt_centroids(\n", + " num_descriptors=env.descriptor_length,\n", " num_init_cvt_samples=num_init_cvt_samples,\n", " num_centroids=num_centroids,\n", - " minval=min_bd,\n", - " maxval=max_bd,\n", - " random_key=random_key,\n", + " minval=min_descriptor,\n", + " maxval=max_descriptor,\n", + " key=subkey,\n", ")\n", "\n", "# Compute initial repertoire and emitter state\n", - "repertoire, emitter_state, random_key = mels.init(init_variables, centroids, random_key)" + "key, subkey = jax.random.split(key)\n", + "repertoire, emitter_state = mels.init(init_variables, centroids, subkey)" ] }, { @@ -322,40 +323,37 @@ "outputs": [], "source": [ "log_period = 10\n", - "num_loops = int(num_iterations / log_period)\n", + "num_loops = num_iterations // log_period\n", "\n", + "metrics = dict.fromkeys([\"iteration\", \"qd_score\", \"coverage\", \"max_fitness\", \"time\"], jnp.array([]))\n", "csv_logger = CSVLogger(\n", - " \"mapelites-logs.csv\",\n", - " header=[\"loop\", \"iteration\", \"qd_score\", \"max_fitness\", \"coverage\", \"time\"]\n", + " \"mels-logs.csv\",\n", + " header=list(metrics.keys())\n", ")\n", - "all_metrics = {}\n", "\n", - "# main loop\n", + "# Main loop\n", "mels_scan_update = mels.scan_update\n", "for i in range(num_loops):\n", " start_time = time.time()\n", - " # main iterations\n", - " (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + " (\n", + " repertoire,\n", + " emitter_state,\n", + " key,\n", + " ), current_metrics = jax.lax.scan(\n", " mels_scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=log_period,\n", " )\n", " timelapse = time.time() - start_time\n", "\n", - " # log metrics\n", - " logged_metrics = {\"time\": timelapse, \"loop\": 1+i, \"iteration\": 1 + i*log_period}\n", - " for key, value in metrics.items():\n", - " # take last value\n", - " logged_metrics[key] = value[-1]\n", - "\n", - " # take all values\n", - " if key in all_metrics.keys():\n", - " all_metrics[key] = jnp.concatenate([all_metrics[key], value])\n", - " else:\n", - " all_metrics[key] = value\n", + " # Metrics\n", + " current_metrics[\"iteration\"] = jnp.arange(1+log_period*i, 1+log_period*(i+1), dtype=jnp.int32)\n", + " current_metrics[\"time\"] = jnp.repeat(timelapse, log_period)\n", + " metrics = jax.tree.map(lambda metric, current_metric: jnp.concatenate([metric, current_metric], axis=0), metrics, current_metrics)\n", "\n", - " csv_logger.log(logged_metrics)" + " # Log\n", + " csv_logger.log(jax.tree.map(lambda x: x[-1], metrics))" ] }, { @@ -366,11 +364,12 @@ "source": [ "#@title Visualization\n", "\n", - "# create 
the x-axis array\n", - "env_steps = jnp.arange(num_iterations) * episode_length * batch_size\n", + "# Create the x-axis array\n", + "env_steps = metrics[\"iteration\"]\n", "\n", - "# create the plots and the grid\n", - "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=all_metrics, repertoire=repertoire, min_bd=min_bd, max_bd=max_bd)" + "%matplotlib inline\n", + "# Create the plots and the grid\n", + "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_descriptor=min_descriptor, max_descriptor=max_descriptor)" ] }, { @@ -414,7 +413,7 @@ "outputs": [], "source": [ "# Init population of policies\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "fake_batch = jnp.zeros(shape=(env.observation_size,))\n", "fake_params = policy_network.init(subkey, fake_batch)\n", "\n", @@ -443,7 +442,7 @@ "source": [ "## Get the best individual of the repertoire\n", "\n", - "Note that in ME-LS, the individual's cell is computed by finding its most frequent archive cell among its `num_samples` behavior descriptors. Thus, the descriptor associated with each individual in the archive is not its mean descriptor. Rather, we set the descriptor in the archive to be the centroid of the individual's most frequent archive cell." + "Note that in ME-LS, the individual's cell is computed by finding its most frequent archive cell among its `num_samples` descriptors. Thus, the descriptor associated with each individual in the archive is not its mean descriptor. Rather, we set the descriptor in the archive to be the centroid of the individual's most frequent archive cell." ] }, { @@ -454,7 +453,7 @@ "source": [ "best_idx = jnp.argmax(repertoire.fitnesses)\n", "best_fitness = jnp.max(repertoire.fitnesses)\n", - "best_bd = repertoire.descriptors[best_idx]\n", + "best_descriptor = repertoire.descriptors[best_idx]\n", "best_spread = repertoire.spreads[best_idx]" ] }, @@ -466,7 +465,7 @@ "source": [ "print(\n", " f\"Best fitness in the repertoire: {best_fitness:.2f}\\n\"\n", - " f\"Behavior descriptor of the best individual in the repertoire: {best_bd}\\n\"\n", + " f\"Descriptor of the best individual in the repertoire: {best_descriptor}\\n\"\n", " f\"Spread of the best individual in the repertoire: {best_spread}\\n\"\n", " f\"Index in the repertoire of this individual: {best_idx}\\n\"\n", ")" @@ -478,7 +477,7 @@ "metadata": {}, "outputs": [], "source": [ - "my_params = jax.tree_util.tree_map(\n", + "my_params = jax.tree.map(\n", " lambda x: x[best_idx],\n", " repertoire.genotypes\n", ")" @@ -509,8 +508,8 @@ "outputs": [], "source": [ "rollout = []\n", - "rng = jax.random.key(seed=1)\n", - "state = jit_env_reset(rng=rng)\n", + "key, subkey = jax.random.split(key)\n", + "state = jit_env_reset(rng=subkey)\n", "while not state.done:\n", " rollout.append(state)\n", " action = jit_inference_fn(my_params, state.obs)\n", @@ -550,7 +549,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/mome.ipynb b/examples/mome.ipynb index 8840b9e8..29606f1b 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -13,7 +13,7 @@ "id": "1", "metadata": {}, "source": [ - "# Optimizing multiple objectives with MOME in Jax\n", + "# Optimizing multiple objectives with MOME in JAX\n", "\n", "This notebook shows how to use QDax to find diverse and performing parameters on a multi-objectives Rastrigin problem, using 
[Multi-Objective MAP-Elites](https://arxiv.org/pdf/2202.03057.pdf) (MOME) algorithm. It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", "\n", @@ -160,9 +160,9 @@ "source": [ "scoring_function = partial(rastrigin_scorer, base_lag=base_lag, lag=lag)\n", "\n", - "def scoring_fn(genotypes: jnp.ndarray, random_key: RNGKey) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]:\n", + "def scoring_fn(genotypes: jnp.ndarray, key: RNGKey) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]:\n", " fitnesses, descriptors = scoring_function(genotypes)\n", - " return fitnesses, descriptors, {}, random_key" + " return fitnesses, descriptors, {}" ] }, { @@ -205,10 +205,10 @@ "outputs": [], "source": [ "# initial population\n", - "random_key = jax.random.key(42)\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key = jax.random.key(42)\n", + "key, subkey = jax.random.split(key)\n", "genotypes = jax.random.uniform(\n", - " random_key, (batch_size, num_variables), minval=minval, maxval=maxval, dtype=jnp.float32\n", + " subkey, (batch_size, num_variables), minval=minval, maxval=maxval, dtype=jnp.float32\n", ")\n", "\n", "# crossover function\n", @@ -250,13 +250,14 @@ "metadata": {}, "outputs": [], "source": [ - "centroids, random_key = compute_cvt_centroids(\n", + "key, subkey = jax.random.split(key)\n", + "centroids = compute_cvt_centroids(\n", " num_descriptors=2,\n", " num_init_cvt_samples=20000,\n", " num_centroids=num_centroids,\n", " minval=minval,\n", " maxval=maxval,\n", - " random_key=random_key,\n", + " key=subkey,\n", ")" ] }, @@ -297,11 +298,12 @@ "metadata": {}, "outputs": [], "source": [ - "repertoire, emitter_state, random_key = mome.init(\n", + "key, subkey = jax.random.split(key)\n", + "repertoire, emitter_state = mome.init(\n", " genotypes,\n", " centroids,\n", " pareto_front_max_length,\n", - " random_key\n", + " subkey\n", ")" ] }, @@ -320,12 +322,10 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", - "\n", "# Run the algorithm\n", - "(repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + "(repertoire, emitter_state, key,), metrics = jax.lax.scan(\n", " mome.scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=num_iterations,\n", ")" @@ -425,7 +425,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/nsga2_spea2.ipynb b/examples/nsga2_spea2.ipynb index a7347130..6c90aca7 100644 --- a/examples/nsga2_spea2.ipynb +++ b/examples/nsga2_spea2.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Optimizing multiple objectives with NSGA2 & SPEA2 in Jax\n", + "# Optimizing multiple objectives with NSGA2 & SPEA2 in JAX\n", "\n", "This notebook shows how to use QDax to find diverse and performing parameters on a multi-objectives Rastrigin problem, using [NSGA2](https://ieeexplore.ieee.org/document/996017) and [SPEA2](https://www.semanticscholar.org/paper/SPEA2%3A-Improving-the-strength-pareto-evolutionary-Zitzler-Laumanns/b13724cb54ae4171916f3f969d304b9e9752a57f) algorithms. It can be run locally or on Google Colab. We recommend to use a GPU. 
This notebook will show:\n", "\n", @@ -163,8 +163,8 @@ " base_lag=base_lag\n", ")\n", "\n", - "def scoring_fn(x, random_key):\n", - " return scoring_function(x)[0], {}, random_key" + "def scoring_fn(x, key):\n", + " return scoring_function(x)[0], {}" ] }, { @@ -181,8 +181,8 @@ "outputs": [], "source": [ "# Initial population\n", - "random_key = jax.random.key(0)\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key = jax.random.key(0)\n", + "key, subkey = jax.random.split(key)\n", "genotypes = jax.random.uniform(\n", " subkey, (batch_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32\n", ")\n", @@ -231,10 +231,11 @@ ")\n", "\n", "# init nsga2\n", - "repertoire, emitter_state, random_key = nsga2.init(\n", + "key, subkey = jax.random.split(key)\n", + "repertoire, emitter_state = nsga2.init(\n", " genotypes,\n", " population_size,\n", - " random_key\n", + " subkey\n", ")" ] }, @@ -251,11 +252,9 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", - "\n", - "# run optimization loop\n", - "(repertoire, emitter_state, random_key), _ = jax.lax.scan(\n", - " nsga2.scan_update, (repertoire, emitter_state, random_key), (), length=num_iterations\n", + "# Run optimization loop\n", + "(repertoire, emitter_state, key), _ = jax.lax.scan(\n", + " nsga2.scan_update, (repertoire, emitter_state, key), (), length=num_iterations\n", ")" ] }, @@ -296,11 +295,12 @@ ")\n", "\n", "# init spea2\n", - "repertoire, emitter_state, random_key = spea2.init(\n", + "key, subkey = jax.random.split(key)\n", + "repertoire, emitter_state = spea2.init(\n", " genotypes,\n", " population_size,\n", " num_neighbours,\n", - " random_key\n", + " subkey,\n", ")" ] }, @@ -310,11 +310,9 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", - "\n", "# run optimization loop\n", - "(repertoire, emitter_state, random_key), _ = jax.lax.scan(\n", - " spea2.scan_update, (repertoire, emitter_state, random_key), (), length=num_iterations\n", + "(repertoire, emitter_state, key), _ = jax.lax.scan(\n", + " spea2.scan_update, (repertoire, emitter_state, key), (), length=num_iterations\n", ")" ] }, @@ -358,7 +356,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/omgmega.ipynb b/examples/omgmega.ipynb index 5f3c69eb..e250fee1 100644 --- a/examples/omgmega.ipynb +++ b/examples/omgmega.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Optimizing with OMG-MEGA in Jax\n", + "# Optimizing with OMG-MEGA in JAX\n", "\n", "This notebook shows how to use QDax to find diverse and performing parameters on the Rastrigin problem with [OMG-MEGA](https://arxiv.org/pdf/2106.03894.pdf).\n", "It can be run locally or on Google Colab. We recommend to use a GPU. 
This notebook will show:\n", @@ -134,9 +134,9 @@ " gradients = jnp.nan_to_num(gradients)\n", " return fitnesses, descriptors, {\"gradients\": gradients}\n", "\n", - "def scoring_fn(x, random_key):\n", + "def scoring_fn(x, key):\n", " fitnesses, descriptors, extra_scores = jax.vmap(scoring_function)(x)\n", - " return fitnesses, descriptors, extra_scores, random_key" + " return fitnesses, descriptors, extra_scores" ] }, { @@ -184,10 +184,10 @@ "metadata": {}, "outputs": [], "source": [ - "random_key = jax.random.key(0)\n", + "key = jax.random.key(0)\n", "\n", "# defines the population\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "initial_population = jax.random.normal(subkey, shape=(init_population_size, num_dimensions))\n", "\n", "sqrt_centroids = int(math.sqrt(num_centroids)) # 2-D grid\n", @@ -227,7 +227,8 @@ "metadata": {}, "outputs": [], "source": [ - "repertoire, emitter_state, random_key = map_elites.init(initial_population, centroids, random_key)" + "key, subkey = jax.random.split(key)\n", + "repertoire, emitter_state = map_elites.init(initial_population, centroids, subkey)" ] }, { @@ -243,11 +244,9 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", - "\n", - "(repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + "(repertoire, emitter_state, key,), metrics = jax.lax.scan(\n", " map_elites.scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=num_iterations,\n", ")" @@ -273,7 +272,7 @@ "\n", "# create the plots and the grid\n", "fig, axes = plot_map_elites_results(\n", - " env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_bd=minval, max_bd=maxval\n", + " env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_descriptor=minval, max_descriptor=maxval\n", ")" ] } @@ -294,7 +293,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/pga_aurora.ipynb b/examples/pga_aurora.ipynb index 330e82ed..02239498 100644 --- a/examples/pga_aurora.ipynb +++ b/examples/pga_aurora.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Optimizing with PGA-AURORA in Jax\n", + "# Optimizing with PGA-AURORA in JAX\n", "\n", "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [PGA-AURORA](https://arxiv.org/abs/2210.03516).\n", "It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", @@ -49,8 +49,7 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "!pip install ipympl |tail -n 1\n", + "!pip install ipympl | tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", "# output.enable_custom_widget_manager()\n", @@ -71,7 +70,7 @@ " create_default_brax_task_components,\n", " get_aurora_scoring_fn,\n", ")\n", - "from qdax.environments.bd_extractors import (\n", + "from qdax.environments.descriptor_extractors import (\n", " AuroraExtraInfoNormalization,\n", " get_aurora_encoding,\n", ")\n", @@ -110,8 +109,8 @@ "line_sigma = 0.05 #@param {type:\"number\"}\n", "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", "num_centroids = 1024 #@param {type:\"integer\"}\n", - "min_bd = 0. #@param {type:\"number\"}\n", - "max_bd = 1.0 #@param {type:\"number\"}\n", + "min_descriptor = 0. 
#@param {type:\"number\"}\n", + "max_descriptor = 1.0 #@param {type:\"number\"}\n", "\n", "lstm_batch_size = 128 #@param {type:\"integer\"}\n", "\n", @@ -164,7 +163,7 @@ "env = environments.create(env_name, episode_length=episode_length)\n", "\n", "# Init a random key\n", - "random_key = jax.random.key(seed)\n", + "key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -175,14 +174,14 @@ ")\n", "\n", "# Init population of controllers\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jax.random.split(subkey, num=env_batch_size)\n", "fake_batch = jnp.zeros(shape=(env_batch_size, env.observation_size))\n", "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", "\n", "\n", "# Create the initial environment states\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=env_batch_size, axis=0)\n", "reset_fn = jax.jit(jax.vmap(env.reset))\n", "init_states = reset_fn(keys)" @@ -205,9 +204,9 @@ "source": [ "# Define the function to play a step with the policy in the environment\n", "def play_step_fn(\n", - " env_state,\n", - " policy_params,\n", - " random_key,\n", + " env_state,\n", + " policy_params,\n", + " key,\n", "):\n", " \"\"\"\n", " Play an environment step and return the updated state and the transition.\n", @@ -229,7 +228,7 @@ " next_state_desc=next_state.info[\"state_descriptor\"],\n", " )\n", "\n", - " return next_state, policy_params, random_key, transition" + " return next_state, policy_params, key, transition" ] }, { @@ -238,7 +237,7 @@ "source": [ "## Define the scoring function and the way metrics are computed\n", "\n", - "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual." + "The scoring function is used in the evaluation step to determine the fitness and descriptor of each individual." 
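The updated notebooks all follow the same key-handling convention: split the key before every stochastic call and pass a fresh subkey, instead of having the callee return a new key. A small, self-contained illustration (not taken from the notebook):

```python
import jax

key = jax.random.key(0)

# One subkey for the next stochastic call.
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, shape=(3,))

# Several independent subkeys at once.
key, subkey_1, subkey_2 = jax.random.split(key, 3)
a = jax.random.uniform(subkey_1, shape=(2,))
b = jax.random.uniform(subkey_2, shape=(2,))
```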
] }, { @@ -248,9 +247,10 @@ "outputs": [], "source": [ "# Prepare the scoring function\n", - "env, policy_network, scoring_fn, random_key = create_default_brax_task_components(\n", + "key, subkey = jax.random.split(key)\n", + "env, policy_network, scoring_fn = create_default_brax_task_components(\n", " env_name=env_name,\n", - " random_key=random_key,\n", + " key=subkey,\n", ")\n", "\n", "def observation_extractor_fn(\n", @@ -293,7 +293,7 @@ " coverage = 100 * jnp.mean(1.0 - grid_empty)\n", " max_fitness = jnp.max(repertoire.fitnesses)\n", "\n", - " return {\"qd_score\": qd_score, \"max_fitness\": max_fitness, \"coverage\": coverage}\n" + " return {\"qd_score\": qd_score, \"max_fitness\": max_fitness, \"coverage\": coverage}" ] }, { @@ -368,25 +368,26 @@ "centroids = jnp.zeros(shape=(num_centroids, aurora_dims))\n", "\n", "@jax.jit\n", - "def update_scan_fn(carry: Any, unused: Any) -> Any:\n", + "def update_scan_fn(carry: Any, _: Any) -> Any:\n", " \"\"\"Scan the update function.\"\"\"\n", " (\n", " repertoire,\n", " emitter_state,\n", - " random_key,\n", + " key,\n", " aurora_extra_info\n", " ) = carry\n", "\n", " # update\n", - " (repertoire, emitter_state, metrics, random_key,) = aurora.update(\n", + " key, subkey = jax.random.split(key)\n", + " repertoire, emitter_state, metrics = aurora.update(\n", " repertoire,\n", " emitter_state,\n", - " random_key,\n", + " subkey,\n", " aurora_extra_info=aurora_extra_info,\n", " )\n", "\n", " return (\n", - " (repertoire, emitter_state, random_key, aurora_extra_info),\n", + " (repertoire, emitter_state, key, aurora_extra_info),\n", " metrics,\n", " )\n", "\n", @@ -414,7 +415,7 @@ ")\n", "\n", "# Init the model params\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "model_params = train_seq2seq.get_initial_params(\n", " model, subkey, (1, *observations_dims)\n", ")\n", @@ -446,7 +447,7 @@ ")\n", "\n", "# init the model params\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "model_params = train_seq2seq.get_initial_params(\n", " model, subkey, (1, *observations_dims)\n", ")\n", @@ -465,18 +466,19 @@ ")\n", "\n", "# init step of the aurora algorithm\n", - "repertoire, emitter_state, aurora_extra_info, random_key = aurora.init(\n", + "key, subkey = jax.random.split(key)\n", + "repertoire, emitter_state, aurora_extra_info = aurora.init(\n", " init_variables,\n", " aurora_extra_info,\n", " jnp.asarray(l_value_init),\n", " max_observation_size,\n", - " random_key,\n", + " subkey,\n", ")\n", "\n", "# initializing means and stds and AURORA\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "repertoire, aurora_extra_info = aurora.train(\n", - " repertoire, model_params, iteration=0, random_key=subkey\n", + " repertoire, model_params, iteration=0, key=subkey\n", ")\n", "\n", "# design aurora's schedule\n", @@ -510,11 +512,11 @@ "while iteration < max_iterations:\n", "\n", " (\n", - " (repertoire, emitter_state, random_key, aurora_extra_info),\n", + " (repertoire, emitter_state, key, aurora_extra_info),\n", " metrics,\n", " ) = jax.lax.scan(\n", " update_scan_fn,\n", - " (repertoire, emitter_state, random_key, aurora_extra_info),\n", + " (repertoire, emitter_state, key, aurora_extra_info),\n", " (),\n", " length=log_freq,\n", " )\n", @@ -527,7 +529,7 @@ " # autoencoder steps and CVC\n", " if (iteration + 1) in schedules:\n", " # train the autoencoder\n", - " random_key, subkey = 
jax.random.split(random_key)\n", + " key, subkey = jax.random.split(key)\n", " repertoire, aurora_extra_info = aurora.train(\n", " repertoire, model_params, iteration, subkey\n", " )\n", @@ -573,7 +575,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/pgame.ipynb b/examples/pgame.ipynb index bd83b0e8..9bdbb2a2 100644 --- a/examples/pgame.ipynb +++ b/examples/pgame.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Optimizing with PGAME in Jax\n", + "# Optimizing with PGAME in JAX\n", "\n", "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [Policy Gradient Assisted MAP-Elites](https://hal.archives-ouvertes.fr/hal-03135723v2/file/PGA_MAP_Elites_GECCO.pdf).\n", "It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", @@ -48,8 +48,7 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "!pip install ipympl |tail -n 1\n", + "!pip install ipympl | tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", "# output.enable_custom_widget_manager()\n", @@ -100,8 +99,8 @@ "line_sigma = 0.05 #@param {type:\"number\"}\n", "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", "num_centroids = 1024 #@param {type:\"integer\"}\n", - "min_bd = 0. #@param {type:\"number\"}\n", - "max_bd = 1.0 #@param {type:\"number\"}\n", + "min_descriptor = 0. #@param {type:\"number\"}\n", + "max_descriptor = 1.0 #@param {type:\"number\"}\n", "\n", "proportion_mutation_ga = 0.5 #@param {type:\"number\"}\n", "\n", @@ -141,9 +140,10 @@ "source": [ "# Init environment\n", "env = environments.create(env_name, episode_length=episode_length)\n", + "reset_fn = jax.jit(env.reset)\n", "\n", "# Init a random key\n", - "random_key = jax.random.key(seed)\n", + "key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -154,16 +154,10 @@ ")\n", "\n", "# Init population of controllers\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jax.random.split(subkey, num=env_batch_size)\n", "fake_batch = jnp.zeros(shape=(env_batch_size, env.observation_size))\n", - "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", - "\n", - "# Create the initial environment states\n", - "random_key, subkey = jax.random.split(random_key)\n", - "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=env_batch_size, axis=0)\n", - "reset_fn = jax.jit(jax.vmap(env.reset))\n", - "init_states = reset_fn(keys)" + "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)" ] }, { @@ -181,9 +175,9 @@ "source": [ "# Define the function to play a step with the policy in the environment\n", "def play_step_fn(\n", - " env_state,\n", - " policy_params,\n", - " random_key,\n", + " env_state,\n", + " policy_params,\n", + " key,\n", "):\n", " \"\"\"\n", " Play an environment step and return the updated state and the transition.\n", @@ -205,7 +199,7 @@ " next_state_desc=next_state.info[\"state_descriptor\"],\n", " )\n", "\n", - " return next_state, policy_params, random_key, transition" + " return next_state, policy_params, key, transition" ] }, { @@ -214,7 +208,7 @@ "source": [ "## Define the scoring function and the way metrics are computed\n", "\n", - "The scoring function is used in the evaluation step to determine the fitness and behavior 
descriptor of each individual." + "The scoring function is used in the evaluation step to determine the fitness and descriptor of each individual." ] }, { @@ -224,13 +218,13 @@ "outputs": [], "source": [ "# Prepare the scoring function\n", - "bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]\n", + "descriptor_extraction_fn = environments.descriptor_extractor[env_name]\n", "scoring_fn = functools.partial(\n", " scoring_function,\n", - " init_states=init_states,\n", " episode_length=episode_length,\n", + " play_reset_fn=reset_fn,\n", " play_step_fn=play_step_fn,\n", - " behavior_descriptor_extractor=bd_extraction_fn,\n", + " descriptor_extractor=descriptor_extraction_fn,\n", ")\n", "\n", "# Get minimum reward value to make sure qd_score are positive\n", @@ -319,18 +313,20 @@ ")\n", "\n", "# Compute the centroids\n", - "centroids, random_key = compute_cvt_centroids(\n", - " num_descriptors=env.behavior_descriptor_length,\n", + "key, subkey = jax.random.split(key)\n", + "centroids = compute_cvt_centroids(\n", + " num_descriptors=env.descriptor_length,\n", " num_init_cvt_samples=num_init_cvt_samples,\n", " num_centroids=num_centroids,\n", - " minval=min_bd,\n", - " maxval=max_bd,\n", - " random_key=random_key,\n", + " minval=min_descriptor,\n", + " maxval=max_descriptor,\n", + " key=subkey,\n", ")\n", "\n", "# compute initial repertoire\n", - "repertoire, emitter_state, random_key = map_elites.init(\n", - " init_variables, centroids, random_key\n", + "key, subkey = jax.random.split(key)\n", + "repertoire, emitter_state = map_elites.init(\n", + " init_variables, centroids, subkey\n", ")" ] }, @@ -341,40 +337,37 @@ "outputs": [], "source": [ "log_period = 10\n", - "num_loops = int(num_iterations / log_period)\n", + "num_loops = num_iterations // log_period\n", "\n", + "metrics = dict.fromkeys([\"iteration\", \"qd_score\", \"coverage\", \"max_fitness\", \"time\"], jnp.array([]))\n", "csv_logger = CSVLogger(\n", " \"pgame-logs.csv\",\n", - " header=[\"loop\", \"iteration\", \"qd_score\", \"max_fitness\", \"coverage\", \"time\"]\n", + " header=list(metrics.keys())\n", ")\n", - "all_metrics = {}\n", "\n", - "# main loop\n", + "# Main loop\n", "map_elites_scan_update = map_elites.scan_update\n", "for i in range(num_loops):\n", " start_time = time.time()\n", - " # main iterations\n", - " (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + " (\n", + " repertoire,\n", + " emitter_state,\n", + " key,\n", + " ), current_metrics = jax.lax.scan(\n", " map_elites_scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=log_period,\n", " )\n", " timelapse = time.time() - start_time\n", "\n", - " # log metrics\n", - " logged_metrics = {\"time\": timelapse, \"loop\": 1+i, \"iteration\": 1 + i*log_period}\n", - " for key, value in metrics.items():\n", - " # take last value\n", - " logged_metrics[key] = value[-1]\n", - "\n", - " # take all values\n", - " if key in all_metrics.keys():\n", - " all_metrics[key] = jnp.concatenate([all_metrics[key], value])\n", - " else:\n", - " all_metrics[key] = value\n", + " # Metrics\n", + " current_metrics[\"iteration\"] = jnp.arange(1+log_period*i, 1+log_period*(i+1), dtype=jnp.int32)\n", + " current_metrics[\"time\"] = jnp.repeat(timelapse, log_period)\n", + " metrics = jax.tree.map(lambda metric, current_metric: jnp.concatenate([metric, current_metric], axis=0), metrics, current_metrics)\n", "\n", - " csv_logger.log(logged_metrics)" + " # Log\n", + " 
csv_logger.log(jax.tree.map(lambda x: x[-1], metrics))" ] }, { @@ -385,33 +378,13 @@ "source": [ "#@title Visualization\n", "\n", - "# create the x-axis array\n", - "env_steps = jnp.arange(num_iterations) * episode_length * env_batch_size\n", + "# Create the x-axis array\n", + "env_steps = metrics[\"iteration\"]\n", "\n", - "# create the plots and the grid\n", - "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=all_metrics, repertoire=repertoire, min_bd=min_bd, max_bd=max_bd)" + "%matplotlib inline\n", + "# Create the plots and the grid\n", + "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_descriptor=min_descriptor, max_descriptor=max_descriptor)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -430,7 +403,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/qdpg.ipynb b/examples/qdpg.ipynb index b2e68e35..48d1baf8 100644 --- a/examples/qdpg.ipynb +++ b/examples/qdpg.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Optimizing with QDPG in Jax\n", + "# Optimizing with QDPG in JAX\n", "\n", "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [QDPG - Quality Diversity Policy Gradient in MAP-Elites](https://arxiv.org/abs/2006.08505).\n", "It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", @@ -48,8 +48,7 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "!pip install ipympl |tail -n 1\n", + "!pip install ipympl | tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", "# output.enable_custom_widget_manager()\n", @@ -104,8 +103,8 @@ "line_sigma = 0.05 #@param {type:\"number\"}\n", "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", "num_centroids = 1024 #@param {type:\"integer\"}\n", - "min_bd = 0. #@param {type:\"number\"}\n", - "max_bd = 1.0 #@param {type:\"number\"}\n", + "min_descriptor = 0. 
#@param {type:\"number\"}\n", + "max_descriptor = 1.0 #@param {type:\"number\"}\n", "\n", "# mutations size\n", "quality_pg_batch_size = 34 #@param {type:\"integer\"}\n", @@ -155,9 +154,10 @@ "source": [ "# Init environment\n", "env = environments.create(env_name, episode_length=episode_length)\n", + "reset_fn = jax.jit(env.reset)\n", "\n", "# Init a random key\n", - "random_key = jax.random.key(seed)\n", + "key = jax.random.key(seed)\n", "\n", "# Init policy network\n", "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", @@ -168,16 +168,10 @@ ")\n", "\n", "# Init population of controllers\n", - "random_key, subkey = jax.random.split(random_key)\n", + "key, subkey = jax.random.split(key)\n", "keys = jax.random.split(subkey, num=env_batch_size)\n", "fake_batch = jnp.zeros(shape=(env_batch_size, env.observation_size))\n", - "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", - "\n", - "# Create the initial environment states\n", - "random_key, subkey = jax.random.split(random_key)\n", - "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=env_batch_size, axis=0)\n", - "reset_fn = jax.jit(jax.vmap(env.reset))\n", - "init_states = reset_fn(keys)" + "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)" ] }, { @@ -195,9 +189,9 @@ "source": [ "# Define the function to play a step with the policy in the environment\n", "def play_step_fn(\n", - " env_state,\n", - " policy_params,\n", - " random_key,\n", + " env_state,\n", + " policy_params,\n", + " key,\n", "):\n", " \"\"\"\n", " Play an environment step and return the updated state and the transition.\n", @@ -219,7 +213,7 @@ " next_state_desc=next_state.info[\"state_descriptor\"],\n", " )\n", "\n", - " return next_state, policy_params, random_key, transition" + " return next_state, policy_params, key, transition" ] }, { @@ -228,7 +222,7 @@ "source": [ "## Define the scoring function and the way metrics are computed\n", "\n", - "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual." + "The scoring function is used in the evaluation step to determine the fitness and descriptor of each individual." 
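The main loop further down accumulates metrics with `jax.tree.map` and logs only the most recent row. A minimal sketch of that pattern with hypothetical values (CSV logging omitted):

```python
import jax
import jax.numpy as jnp

# Start from empty arrays, one per metric.
metrics = dict.fromkeys(["iteration", "max_fitness"], jnp.array([]))

# Metrics returned by one scanned block of iterations (hypothetical values).
current_metrics = {
    "iteration": jnp.arange(1, 11, dtype=jnp.int32),
    "max_fitness": jnp.linspace(0.0, 1.0, 10),
}

# Append the new block to the running history.
metrics = jax.tree.map(
    lambda metric, current: jnp.concatenate([metric, current], axis=0),
    metrics,
    current_metrics,
)

# Only the latest row is handed to the CSV logger.
last_row = jax.tree.map(lambda x: x[-1], metrics)
```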
] }, { @@ -238,13 +232,13 @@ "outputs": [], "source": [ "# Prepare the scoring function\n", - "bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]\n", + "descriptor_extraction_fn = environments.descriptor_extractor[env_name]\n", "scoring_fn = functools.partial(\n", " scoring_function,\n", - " init_states=init_states,\n", " episode_length=episode_length,\n", + " play_reset_fn=reset_fn,\n", " play_step_fn=play_step_fn,\n", - " behavior_descriptor_extractor=bd_extraction_fn,\n", + " descriptor_extractor=descriptor_extraction_fn,\n", ")\n", "\n", "# Get minimum reward value to make sure qd_score are positive\n", @@ -367,18 +361,20 @@ ")\n", "\n", "# Compute the centroids\n", - "centroids, random_key = compute_cvt_centroids(\n", - " num_descriptors=env.behavior_descriptor_length,\n", + "key, subkey = jax.random.split(key)\n", + "centroids = compute_cvt_centroids(\n", + " num_descriptors=env.descriptor_length,\n", " num_init_cvt_samples=num_init_cvt_samples,\n", " num_centroids=num_centroids,\n", - " minval=min_bd,\n", - " maxval=max_bd,\n", - " random_key=random_key,\n", + " minval=min_descriptor,\n", + " maxval=max_descriptor,\n", + " key=subkey,\n", ")\n", "\n", "# compute initial repertoire\n", - "repertoire, emitter_state, random_key = map_elites.init(\n", - " init_variables, centroids, random_key\n", + "key, subkey = jax.random.split(key)\n", + "repertoire, emitter_state = map_elites.init(\n", + " init_variables, centroids, subkey\n", ")" ] }, @@ -389,40 +385,37 @@ "outputs": [], "source": [ "log_period = 10\n", - "num_loops = int(num_iterations / log_period)\n", + "num_loops = num_iterations // log_period\n", "\n", + "metrics = dict.fromkeys([\"iteration\", \"qd_score\", \"coverage\", \"max_fitness\", \"time\"], jnp.array([]))\n", "csv_logger = CSVLogger(\n", " \"qdpg-logs.csv\",\n", - " header=[\"loop\", \"iteration\", \"qd_score\", \"max_fitness\", \"coverage\", \"time\"]\n", + " header=list(metrics.keys())\n", ")\n", - "all_metrics = {}\n", "\n", - "# main loop\n", + "# Main loop\n", "map_elites_scan_update = map_elites.scan_update\n", "for i in range(num_loops):\n", " start_time = time.time()\n", - " # main iterations\n", - " (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + " (\n", + " repertoire,\n", + " emitter_state,\n", + " key,\n", + " ), current_metrics = jax.lax.scan(\n", " map_elites_scan_update,\n", - " (repertoire, emitter_state, random_key),\n", + " (repertoire, emitter_state, key),\n", " (),\n", " length=log_period,\n", " )\n", " timelapse = time.time() - start_time\n", "\n", - " # log metrics\n", - " logged_metrics = {\"time\": timelapse, \"loop\": 1+i, \"iteration\": 1 + i*log_period}\n", - " for key, value in metrics.items():\n", - " # take last value\n", - " logged_metrics[key] = value[-1]\n", - "\n", - " # take all values\n", - " if key in all_metrics.keys():\n", - " all_metrics[key] = jnp.concatenate([all_metrics[key], value])\n", - " else:\n", - " all_metrics[key] = value\n", + " # Metrics\n", + " current_metrics[\"iteration\"] = jnp.arange(1+log_period*i, 1+log_period*(i+1), dtype=jnp.int32)\n", + " current_metrics[\"time\"] = jnp.repeat(timelapse, log_period)\n", + " metrics = jax.tree.map(lambda metric, current_metric: jnp.concatenate([metric, current_metric], axis=0), metrics, current_metrics)\n", "\n", - " csv_logger.log(logged_metrics)" + " # Log\n", + " csv_logger.log(jax.tree.map(lambda x: x[-1], metrics))" ] }, { @@ -433,33 +426,13 @@ "source": [ "#@title Visualization\n", "\n", - "# create the x-axis array\n", - 
"env_steps = jnp.arange(num_iterations) * episode_length * env_batch_size\n", + "# Create the x-axis array\n", + "env_steps = metrics[\"iteration\"]\n", "\n", - "# create the plots and the grid\n", - "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=all_metrics, repertoire=repertoire, min_bd=min_bd, max_bd=max_bd)" + "%matplotlib inline\n", + "# Create the plots and the grid\n", + "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_descriptor=min_descriptor, max_descriptor=max_descriptor)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -481,7 +454,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/sac_pbt.ipynb b/examples/sac_pbt.ipynb index b4847bd8..d3572974 100644 --- a/examples/sac_pbt.ipynb +++ b/examples/sac_pbt.ipynb @@ -139,7 +139,6 @@ }, "outputs": [], "source": [ - "%%time\n", "# Initialize environments\n", "env = environments.create(\n", " env_name=env_name,\n", @@ -171,13 +170,13 @@ "outputs": [], "source": [ "@jax.jit\n", - "def init_environments(random_key):\n", + "def init_environments(key):\n", "\n", - " env_states = jax.jit(env.reset)(rng=random_key)\n", - " eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key)\n", + " env_states = jax.jit(env.reset)(rng=key)\n", + " eval_env_first_states = jax.jit(eval_env.reset)(rng=key)\n", "\n", " reshape_fn = jax.jit(\n", - " lambda tree: jax.tree_util.tree_map(\n", + " lambda tree: jax.tree.map(\n", " lambda x: jnp.reshape(\n", " x,\n", " (\n", @@ -209,7 +208,6 @@ }, "outputs": [], "source": [ - "# %%time\n", "key = jax.random.key(seed)\n", "key, *keys = jax.random.split(key, num=1 + num_devices)\n", "keys = jnp.stack(keys)\n", @@ -261,7 +259,6 @@ }, "outputs": [], "source": [ - "%%time\n", "# get the initial training states and replay buffers\n", "agent_init_fn = agent.get_init_fn(\n", " population_size=population_size_per_device,\n", @@ -273,7 +270,7 @@ "# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n", "keys = jax.random.key_data(keys)\n", "\n", - "keys, training_states, replay_buffers = jax.pmap(\n", + "training_states, replay_buffers = jax.pmap(\n", " agent_init_fn, axis_name=\"p\", devices=devices\n", ")(keys)" ] @@ -310,7 +307,6 @@ }, "outputs": [], "source": [ - "%%time\n", "# eval policy before training\n", "population_returns, _ = eval_policy(training_states, eval_env_first_states)\n", "population_returns = jnp.reshape(population_returns, (population_size,))\n", @@ -385,8 +381,8 @@ "source": [ "@jax.jit\n", "def unshard_fn(sharded_tree):\n", - " tree = jax.tree_util.tree_map(lambda x: jax.device_put(x, \"cpu\"), sharded_tree)\n", - " tree = jax.tree_util.tree_map(\n", + " tree = jax.tree.map(lambda x: jax.device_put(x, \"cpu\"), sharded_tree)\n", + " tree = jax.tree.map(\n", " lambda x: jnp.reshape(x, (population_size,) + x.shape[2:]), tree\n", " )\n", " return tree" @@ -406,7 +402,6 @@ }, "outputs": [], "source": [ - "%%time\n", "for i in tqdm(range(num_loops), total=num_loops):\n", "\n", " # Update for num_steps\n", @@ -427,21 +422,15 @@ "\n", " # PBT selection\n", " if i < (num_loops-1):\n", - " keys, 
training_states, replay_buffers = select_fn(\n", + " training_states, replay_buffers = select_fn(\n", " keys, population_returns, training_states, replay_buffers\n", " )" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "id": "16", - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], + "metadata": {}, "source": [ "### Visualize learnt behaviors" ] @@ -459,7 +448,7 @@ "source": [ "training_states = unshard_fn(training_states)\n", "best_idx = jnp.argmax(population_returns)\n", - "best_training_state = jax.tree_util.tree_map(lambda x: x[best_idx], training_states)" + "best_training_state = jax.tree.map(lambda x: x[best_idx], training_states)" ] }, { @@ -519,13 +508,12 @@ }, "outputs": [], "source": [ - "%%time\n", "rollout = []\n", "\n", - "rng = jax.random.key(seed=1)\n", - "env_state = jax.jit(env.reset)(rng=rng)\n", + "key, subkey = jax.random.split(key)\n", + "env_state = jax.jit(env.reset)(rng=subkey)\n", "\n", - "training_state, env_state = jax.tree_util.tree_map(\n", + "training_state, env_state = jax.tree.map(\n", " lambda x: jnp.expand_dims(x, axis=0), (training_state, env_state)\n", ")\n", "\n", @@ -549,7 +537,7 @@ "outputs": [], "source": [ "rollout = [\n", - " jax.tree_util.tree_map(lambda x: jax.device_put(x[0], jax.devices(\"cpu\")[0]), env_state)\n", + " jax.tree.map(lambda x: jax.device_put(x[0], jax.devices(\"cpu\")[0]), env_state)\n", " for env_state in rollout\n", "]" ] @@ -585,7 +573,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.2" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/scripts/me_example.py b/examples/scripts/me_example.py deleted file mode 100644 index 294cca8e..00000000 --- a/examples/scripts/me_example.py +++ /dev/null @@ -1,108 +0,0 @@ -import functools - -import jax -import matplotlib.pyplot as plt - -from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids -from qdax.core.emitters.mutation_operators import isoline_variation -from qdax.core.emitters.standard_emitters import MixingEmitter -from qdax.core.map_elites import MAPElites -from qdax.tasks.arm import arm_scoring_function -from qdax.utils.metrics import default_qd_metrics -from qdax.utils.plotting import plot_2d_map_elites_repertoire - - -def run_me() -> None: - seed = 42 - num_param_dimensions = 8 # num DoF arm - init_batch_size = 100 - batch_size = 2048 - num_evaluations = int(1e6) - num_iterations = num_evaluations // batch_size - grid_shape = (100, 100) - min_param = 0.0 - max_param = 1.0 - min_bd = 0.0 - max_bd = 1.0 - - # Init a random key - random_key = jax.random.key(seed) - - # Init population of controllers - random_key, subkey = jax.random.split(random_key) - init_variables = jax.random.uniform( - subkey, - shape=(init_batch_size, num_param_dimensions), - minval=min_param, - maxval=max_param, - ) - - # Define emitter - variation_fn = functools.partial( - isoline_variation, - iso_sigma=0.005, - line_sigma=0, - minval=min_param, - maxval=max_param, - ) - mixing_emitter = MixingEmitter( - mutation_fn=lambda x, y: (x, y), - variation_fn=variation_fn, - variation_percentage=1.0, - batch_size=batch_size, - ) - - # Define a metrics function - metrics_fn = functools.partial( - default_qd_metrics, - qd_offset=0.0, - ) - - # Instantiate MAP-Elites - map_elites = MAPElites( - scoring_function=arm_scoring_function, - emitter=mixing_emitter, - metrics_function=metrics_fn, - ) - - # Compute the centroids - centroids = 
compute_euclidean_centroids( - grid_shape=grid_shape, - minval=min_bd, - maxval=max_bd, - ) - - # Initializes repertoire and emitter state - repertoire, emitter_state, random_key = map_elites.init( - init_variables, centroids, random_key - ) - - # Run MAP-Elites loop - for _ in range(num_iterations): - ( - repertoire, - emitter_state, - metrics, - random_key, - ) = map_elites.update( - repertoire, - emitter_state, - random_key, - ) - - # plot archive - fig, axes = plot_2d_map_elites_repertoire( - centroids=repertoire.centroids, - repertoire_fitnesses=repertoire.fitnesses, - minval=min_bd, - maxval=max_bd, - repertoire_descriptors=repertoire.descriptors, - # vmin=-0.2, - # vmax=0.0, - ) - - plt.show() - - -if __name__ == "__main__": - run_me() diff --git a/examples/smerl.ipynb b/examples/smerl.ipynb index 08ee56e6..6b290b49 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Training DIAYN SMERL with Jax\n", + "# Training DIAYN SMERL with JAX\n", "\n", "This notebook shows how to use QDax to train DIAYN SMERL on a Brax environment. It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", "- how to define an environment\n", @@ -46,8 +46,7 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "!pip install ipympl |tail -n 1\n", + "!pip install ipympl | tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", "# output.enable_custom_widget_manager()\n", @@ -78,7 +77,7 @@ " colab_tpu.setup_tpu()\n", "\n", "\n", - "clear_output()\n" + "clear_output()" ] }, { @@ -168,15 +167,17 @@ " eval_metrics=True,\n", ")\n", "\n", - "key = jax.Key(seed)\n", - "env_state = jax.jit(env.reset)(rng=key)\n", - "eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n", + "key = jax.random.key(seed)\n", + "\n", + "key, subkey_1, subkey_2 = jax.random.split(key, 3)\n", + "env_state = jax.jit(env.reset)(rng=subkey_1)\n", + "eval_env_first_state = jax.jit(eval_env.reset)(rng=subkey_2)\n", "\n", "# Initialize buffer\n", "dummy_transition = QDTransition.init_dummy(\n", " observation_dim=env.observation_size + num_skills,\n", " action_dim=env.action_size,\n", - " descriptor_dim=env.behavior_descriptor_length,\n", + " descriptor_dim=env.descriptor_length,\n", ")\n", "\n", "# Use a trajectory replay buffer\n", @@ -228,11 +229,12 @@ "if descriptor_full_state:\n", " descriptor_size = env.observation_size\n", "else:\n", - " descriptor_size = env.behavior_descriptor_length\n", + " descriptor_size = env.descriptor_length\n", "\n", "# get the initial training state\n", + "key, subkey = jax.random.split(key)\n", "training_state = diayn_smerl.init(\n", - " key,\n", + " subkey,\n", " action_size=env.action_size,\n", " observation_size=env.observation_size,\n", " descriptor_size=descriptor_size,\n", @@ -366,8 +368,6 @@ "metadata": {}, "outputs": [], "source": [ - "%%time\n", - "\n", "# Main loop\n", "(training_state, env_state, replay_buffer), metrics = jax.lax.scan(\n", " _scan_do_iteration,\n", @@ -484,10 +484,11 @@ "jit_env_step = jax.jit(visual_env.step)\n", "\n", "@jax.jit\n", - "def jit_inference_fn(params, observation, random_key):\n", + "def jit_inference_fn(params, observation, key):\n", " obs = jnp.concatenate([observation, skill], axis=0)\n", - " action, random_key = diayn_smerl.select_action(obs, params, random_key, deterministic=True)\n", - " return action, random_key" + " key, subkey = jax.random.split(key)\n", + " action = diayn_smerl.select_action(obs, params, subkey, 
deterministic=True)\n", + " return action" ] }, { @@ -504,11 +505,12 @@ "outputs": [], "source": [ "rollout = []\n", - "random_key = jax.random.key(seed=1)\n", - "state = jit_env_reset(rng=random_key)\n", + "key, subkey = jax.random.split(key)\n", + "state = jit_env_reset(rng=subkey)\n", "while not state.done:\n", " rollout.append(state)\n", - " action, random_key = jit_inference_fn(my_params, state.obs, random_key)\n", + " key, subkey = jax.random.split(key)\n", + " action = jit_inference_fn(my_params, state.obs, subkey)\n", " state = jit_env_step(state, action)\n", "\n", "print(f\"The trajectory of this individual contains {len(rollout)} transitions.\")" @@ -540,7 +542,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.2" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/td3_pbt.ipynb b/examples/td3_pbt.ipynb index 90254907..1a0851e5 100644 --- a/examples/td3_pbt.ipynb +++ b/examples/td3_pbt.ipynb @@ -121,7 +121,6 @@ }, "outputs": [], "source": [ - "%%time\n", "# Initialize environments\n", "env = environments.create(\n", " env_name=env_name,\n", @@ -150,13 +149,13 @@ "outputs": [], "source": [ "@jax.jit\n", - "def init_environments(random_key):\n", + "def init_environments(key):\n", "\n", - " env_states = jax.jit(env.reset)(rng=random_key)\n", - " eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key)\n", + " env_states = jax.jit(env.reset)(rng=key)\n", + " eval_env_first_states = jax.jit(eval_env.reset)(rng=key)\n", "\n", " reshape_fn = jax.jit(\n", - " lambda tree: jax.tree_util.tree_map(\n", + " lambda tree: jax.tree.map(\n", " lambda x: jnp.reshape(\n", " x, (population_size_per_device, env_batch_size,) + x.shape[1:]\n", " ),\n", @@ -180,7 +179,6 @@ }, "outputs": [], "source": [ - "%%time\n", "key = jax.random.key(seed)\n", "key, *keys = jax.random.split(key, num=1 + num_devices)\n", "keys = jnp.stack(keys)\n", @@ -224,7 +222,6 @@ }, "outputs": [], "source": [ - "%%time\n", "# get the initial training states and replay buffers\n", "agent_init_fn = agent.get_init_fn(\n", " population_size=population_size_per_device,\n", @@ -236,7 +233,7 @@ "# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n", "keys = jax.random.key_data(keys)\n", "\n", - "keys, training_states, replay_buffers = jax.pmap(\n", + "training_states, replay_buffers = jax.pmap(\n", " agent_init_fn, axis_name=\"p\", devices=devices\n", ")(keys)" ] @@ -267,7 +264,6 @@ }, "outputs": [], "source": [ - "%%time\n", "# eval policy before training\n", "population_returns, _ = eval_policy(training_states, eval_env_first_states)\n", "population_returns = jnp.reshape(population_returns, (population_size,))\n", @@ -333,8 +329,8 @@ "source": [ "@jax.jit\n", "def unshard_fn(sharded_tree):\n", - " tree = jax.tree_util.tree_map(lambda x: jax.device_put(x, \"cpu\"), sharded_tree)\n", - " tree = jax.tree_util.tree_map(\n", + " tree = jax.tree.map(lambda x: jax.device_put(x, \"cpu\"), sharded_tree)\n", + " tree = jax.tree.map(\n", " lambda x: jnp.reshape(x, (population_size,) + x.shape[2:]), tree\n", " )\n", " return tree" @@ -351,7 +347,6 @@ }, "outputs": [], "source": [ - "%%time\n", "for i in tqdm(range(num_loops), total=num_loops):\n", "\n", " # Update for num_steps\n", @@ -372,22 +367,10 @@ "\n", " # PBT selection\n", " if i < (num_loops-1):\n", - " keys, training_states, replay_buffers = select_fn(\n", + " training_states, replay_buffers = select_fn(\n", " keys, population_returns, training_states, replay_buffers\n", " )" ] 
- }, - { - "cell_type": "code", - "execution_count": null, - "id": "16", - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [] } ], "metadata": { @@ -406,7 +389,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/qdax/baselines/cmaes.py b/qdax/baselines/cmaes.py index ffe2c811..2e6303ee 100644 --- a/qdax/baselines/cmaes.py +++ b/qdax/baselines/cmaes.py @@ -166,28 +166,25 @@ def init(self) -> CMAESState: ) @partial(jax.jit, static_argnames=("self",)) - def sample( - self, cmaes_state: CMAESState, random_key: RNGKey - ) -> Tuple[Genotype, RNGKey]: + def sample(self, cmaes_state: CMAESState, key: RNGKey) -> Genotype: """ Sample a population. Args: cmaes_state: current state of the algorithm - random_key: jax random key + key: jax random key Returns: - A tuple that contains a batch of population size genotypes and a new random key. + A batch of population size genotypes. """ - random_key, subkey = jax.random.split(random_key) samples = jax.random.multivariate_normal( - subkey, + key, shape=(self._population_size,), mean=cmaes_state.mean, cov=(cmaes_state.sigma**2) * cmaes_state.cov_matrix, ) - return samples, random_key + return samples @partial(jax.jit, static_argnames=("self",)) def update_state( diff --git a/qdax/baselines/dads.py b/qdax/baselines/dads.py index be79286f..7e2969c7 100644 --- a/qdax/baselines/dads.py +++ b/qdax/baselines/dads.py @@ -40,7 +40,7 @@ class DadsTrainingState(TrainingState): target_critic_params: Params dynamics_optimizer_state: optax.OptState dynamics_params: Params - random_key: RNGKey + key: RNGKey steps: jnp.ndarray normalization_running_stats: RunningMeanStdState @@ -120,7 +120,7 @@ def __init__(self, config: DadsConfig, action_size: int, descriptor_size: int): def init( # type: ignore self, - random_key: RNGKey, + key: RNGKey, action_size: int, observation_size: int, descriptor_size: int, @@ -128,7 +128,7 @@ def init( # type: ignore """Initialise the training state of the algorithm. Args: - random_key: a jax random key + key: a jax random key action_size: the size of the environment's action space observation_size: the size of the environment's observation space descriptor_size: the size of the environment's descriptor space (i.e. 
the @@ -143,17 +143,17 @@ def init( # type: ignore dummy_dyn_obs = jnp.zeros((1, descriptor_size)) dummy_skill = jnp.zeros((1, self._config.num_skills)) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) policy_params = self._policy.init(subkey, dummy_obs) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) critic_params = self._critic.init(subkey, dummy_obs, dummy_action) - target_critic_params = jax.tree_util.tree_map( + target_critic_params = jax.tree.map( lambda x: jnp.asarray(x.copy()), critic_params ) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) dynamics_params = self._dynamics.init( subkey, obs=dummy_dyn_obs, @@ -178,7 +178,7 @@ def init( # type: ignore target_critic_params=target_critic_params, dynamics_optimizer_state=dynamics_optimizer_state, dynamics_params=dynamics_params, - random_key=random_key, + key=key, normalization_running_stats=RunningMeanStdState( mean=jnp.zeros( descriptor_size, @@ -274,7 +274,7 @@ def play_step_fn( the played transition """ - random_key = training_state.random_key + key = training_state.key policy_params = training_state.policy_params obs = jnp.concatenate([env_state.obs, skills], axis=1) @@ -284,10 +284,11 @@ def play_step_fn( else: state_desc = jnp.zeros((env_state.obs.shape[0], 2)) - actions, random_key = self.select_action( + key, subkey = jax.random.split(key) + actions = self.select_action( obs=obs, policy_params=policy_params, - random_key=random_key, + key=subkey, deterministic=deterministic, ) @@ -324,26 +325,21 @@ def play_step_fn( actions=actions, truncations=truncations, ) + + key, subkey = jax.random.split(key) if not evaluation: training_state = training_state.replace( - random_key=random_key, + key=subkey, normalization_running_stats=normalization_running_stats, ) else: training_state = training_state.replace( - random_key=random_key, + key=subkey, ) return next_env_state, training_state, transition - @partial( - jax.jit, - static_argnames=( - "self", - "play_step_fn", - "env_batch_size", - ), - ) + @partial(jax.jit, static_argnames=("self", "play_step_fn", "env_batch_size")) def eval_policy_fn( self, training_state: DadsTrainingState, @@ -382,7 +378,7 @@ def eval_policy_fn( true_returns = jnp.nansum(transitions.rewards, axis=0) true_return = jnp.mean(true_returns, axis=-1) - reshaped_transitions = jax.tree_util.tree_map( + reshaped_transitions = jax.tree.map( lambda x: x.reshape((self._config.episode_length * env_batch_size, -1)), transitions, ) @@ -479,13 +475,13 @@ def _update_networks( Args: training_state: the current training state of the algorithm. transitions: transitions sampled from a replay buffer. - random_key: a random key to handle stochasticity. + key: a random key to handle stochasticity. Returns: The updated training state and training metrics. 
""" - random_key = training_state.random_key + key = training_state.key # Update skill-dynamics ( @@ -500,48 +496,49 @@ def _update_networks( ) # update alpha + key, subkey = jax.random.split(key) ( alpha_params, alpha_optimizer_state, alpha_loss, - random_key, ) = self._update_alpha( alpha_lr=self._config.learning_rate, training_state=training_state, transitions=transitions, - random_key=random_key, + key=subkey, ) # update critic + key, subkey = jax.random.split(key) ( critic_params, target_critic_params, critic_optimizer_state, critic_loss, - random_key, ) = self._update_critic( critic_lr=self._config.learning_rate, reward_scaling=self._config.reward_scaling, discount=self._config.discount, training_state=training_state, transitions=transitions, - random_key=random_key, + key=subkey, ) # update actor + key, subkey = jax.random.split(key) ( policy_params, policy_optimizer_state, policy_loss, - random_key, ) = self._update_actor( policy_lr=self._config.learning_rate, training_state=training_state, transitions=transitions, - random_key=random_key, + key=subkey, ) # Create new training state + key, subkey = jax.random.split(key) new_training_state = DadsTrainingState( policy_optimizer_state=policy_optimizer_state, policy_params=policy_params, @@ -552,7 +549,7 @@ def _update_networks( target_critic_params=target_critic_params, dynamics_optimizer_state=dynamics_optimizer_state, dynamics_params=dynamics_params, - random_key=random_key, + key=subkey, normalization_running_stats=training_state.normalization_running_stats, steps=training_state.steps + 1, ) @@ -589,9 +586,11 @@ def update( """ # Sample a batch of transitions in the buffer - random_key = training_state.random_key - transitions, random_key = replay_buffer.sample( - random_key, + key = training_state.key + + key, subkey = jax.random.split(key) + transitions = replay_buffer.sample( + subkey, sample_size=self._config.batch_size, ) diff --git a/qdax/baselines/dads_smerl.py b/qdax/baselines/dads_smerl.py index 5bd8274d..5b91fb8e 100644 --- a/qdax/baselines/dads_smerl.py +++ b/qdax/baselines/dads_smerl.py @@ -92,14 +92,17 @@ def update( the replay buffer the training metrics """ + key = training_state.key # Sample a batch of transitions in the buffer - random_key = training_state.random_key - samples, returns, random_key = replay_buffer.sample_with_returns( - random_key, + key, subkey = jax.random.split(key) + samples, returns = replay_buffer.sample_with_returns( + subkey, sample_size=self._config.batch_size, ) + training_state = training_state.replace(key=key) + # Optionally replace the state descriptor by the observation if self._config.descriptor_full_state: _state_desc = samples.obs[:, : -self._config.num_skills] diff --git a/qdax/baselines/diayn.py b/qdax/baselines/diayn.py index 5f0d9a73..57ddda0b 100644 --- a/qdax/baselines/diayn.py +++ b/qdax/baselines/diayn.py @@ -35,7 +35,7 @@ class DiaynTrainingState(TrainingState): target_critic_params: Params discriminator_optimizer_state: optax.OptState discriminator_params: Params - random_key: RNGKey + key: RNGKey steps: jnp.ndarray @@ -116,7 +116,7 @@ def __init__(self, config: DiaynConfig, action_size: int): def init( # type: ignore self, - random_key: RNGKey, + key: RNGKey, action_size: int, observation_size: int, descriptor_size: int, @@ -124,7 +124,7 @@ def init( # type: ignore """Initialise the training state of the algorithm. 
Args: - random_key: a jax random key + key: a jax random key action_size: the size of the environment's action space observation_size: the size of the environment's observation space descriptor_size: the size of the environment's descriptor space (i.e. the @@ -139,17 +139,17 @@ def init( # type: ignore dummy_action = jnp.zeros((1, action_size)) dummy_discriminator_obs = jnp.zeros((1, descriptor_size)) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) policy_params = self._policy.init(subkey, dummy_obs) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) critic_params = self._critic.init(subkey, dummy_obs, dummy_action) - target_critic_params = jax.tree_util.tree_map( + target_critic_params = jax.tree.map( lambda x: jnp.asarray(x.copy()), critic_params ) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) discriminator_params = self._discriminator.init( subkey, obs=dummy_discriminator_obs ) @@ -173,7 +173,7 @@ def init( # type: ignore target_critic_params=target_critic_params, discriminator_optimizer_state=discriminator_optimizer_state, discriminator_params=discriminator_params, - random_key=random_key, + key=key, steps=jnp.array(0), ) @@ -239,7 +239,8 @@ def play_step_fn( the played transition """ - random_key = training_state.random_key + key = training_state.key + policy_params = training_state.policy_params obs = jnp.concatenate([env_state.obs, skills], axis=1) @@ -249,10 +250,11 @@ def play_step_fn( else: state_desc = jnp.zeros((env_state.obs.shape[0], 2)) - actions, random_key = self.select_action( + key, subkey = jax.random.split(key) + actions = self.select_action( obs=obs, policy_params=policy_params, - random_key=random_key, + key=subkey, deterministic=deterministic, ) @@ -273,18 +275,11 @@ def play_step_fn( actions=actions, truncations=truncations, ) - training_state = training_state.replace(random_key=random_key) + training_state = training_state.replace(key=key) return next_env_state, training_state, transition - @partial( - jax.jit, - static_argnames=( - "self", - "play_step_fn", - "env_batch_size", - ), - ) + @partial(jax.jit, static_argnames=("self", "play_step_fn", "env_batch_size")) def eval_policy_fn( self, training_state: DiaynTrainingState, @@ -324,7 +319,7 @@ def eval_policy_fn( true_returns = jnp.nansum(transitions.rewards, axis=0) true_return = jnp.mean(true_returns, axis=-1) - reshaped_transitions = jax.tree_util.tree_map( + reshaped_transitions = jax.tree.map( lambda x: x.reshape((self._config.episode_length * env_batch_size, -1)), transitions, ) @@ -382,12 +377,12 @@ def _update_networks( Args: training_state: the current training state. transitions: transitions sampled from the replay buffer. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: The update training state, metrics and a new random key. 
""" - random_key = training_state.random_key + key = training_state.key # Compute discriminator loss and gradients discriminator_loss, discriminator_gradient = jax.value_and_grad( @@ -409,48 +404,49 @@ def _update_networks( ) # update alpha + key, subkey = jax.random.split(key) ( alpha_params, alpha_optimizer_state, alpha_loss, - random_key, ) = self._update_alpha( alpha_lr=self._config.learning_rate, training_state=training_state, transitions=transitions, - random_key=random_key, + key=subkey, ) # update critic + key, subkey = jax.random.split(key) ( critic_params, target_critic_params, critic_optimizer_state, critic_loss, - random_key, ) = self._update_critic( critic_lr=self._config.learning_rate, reward_scaling=self._config.reward_scaling, discount=self._config.discount, training_state=training_state, transitions=transitions, - random_key=random_key, + key=subkey, ) # update actor + key, subkey = jax.random.split(key) ( policy_params, policy_optimizer_state, policy_loss, - random_key, ) = self._update_actor( policy_lr=self._config.learning_rate, training_state=training_state, transitions=transitions, - random_key=random_key, + key=subkey, ) # Create new training state + key, subkey = jax.random.split(key) new_training_state = DiaynTrainingState( policy_optimizer_state=policy_optimizer_state, policy_params=policy_params, @@ -461,7 +457,7 @@ def _update_networks( target_critic_params=target_critic_params, discriminator_optimizer_state=discriminator_optimizer_state, discriminator_params=discriminator_params, - random_key=random_key, + key=subkey, steps=training_state.steps + 1, ) metrics = { @@ -492,9 +488,11 @@ def update( the training metrics """ # Sample a batch of transitions in the buffer - random_key = training_state.random_key - transitions, random_key = replay_buffer.sample( - random_key, + key = training_state.key + + key, subkey = jax.random.split(key) + transitions = replay_buffer.sample( + subkey, sample_size=self._config.batch_size, ) diff --git a/qdax/baselines/diayn_smerl.py b/qdax/baselines/diayn_smerl.py index daacaa74..503f8e9c 100644 --- a/qdax/baselines/diayn_smerl.py +++ b/qdax/baselines/diayn_smerl.py @@ -99,14 +99,17 @@ def update( the replay buffer the training metrics """ - # Sample a batch of transitions in the buffer - random_key = training_state.random_key + key = training_state.key - samples, returns, random_key = replay_buffer.sample_with_returns( - random_key, + # Sample a batch of transitions in the buffer + key, subkey = jax.random.split(key) + samples, returns = replay_buffer.sample_with_returns( + subkey, sample_size=self._config.batch_size, ) + training_state = training_state.replace(key=key) + # Optionally replace the state descriptor by the observation if self._config.descriptor_full_state: state_desc = samples.obs[:, : -self._config.num_skills] diff --git a/qdax/baselines/genetic_algorithm.py b/qdax/baselines/genetic_algorithm.py index b4c6a32f..1cef4410 100644 --- a/qdax/baselines/genetic_algorithm.py +++ b/qdax/baselines/genetic_algorithm.py @@ -28,9 +28,7 @@ class GeneticAlgorithm: def __init__( self, - scoring_function: Callable[ - [Genotype, RNGKey], Tuple[Fitness, ExtraScores, RNGKey] - ], + scoring_function: Callable[[Genotype, RNGKey], Tuple[Fitness, ExtraScores]], emitter: Emitter, metrics_function: Callable[[GARepertoire], Metrics], ) -> None: @@ -40,23 +38,22 @@ def __init__( @partial(jax.jit, static_argnames=("self", "population_size")) def init( - self, genotypes: Genotype, population_size: int, random_key: RNGKey - ) -> 
Tuple[GARepertoire, Optional[EmitterState], RNGKey]: + self, genotypes: Genotype, population_size: int, key: RNGKey + ) -> Tuple[GARepertoire, Optional[EmitterState]]: """Initialize a GARepertoire with an initial population of genotypes. Args: genotypes: the initial population of genotypes population_size: the maximal size of the repertoire - random_key: a random key to handle stochastic operations + key: a random key to handle stochastic operations Returns: The initial repertoire, an initial emitter state and a new random key. """ # score initial genotypes - fitnesses, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, extra_scores = self._scoring_function(genotypes, subkey) # init the repertoire repertoire = GARepertoire.init( @@ -66,8 +63,9 @@ def init( ) # get initial state of the emitter - emitter_state, random_key = self._emitter.init( - random_key=random_key, + key, subkey = jax.random.split(key) + emitter_state = self._emitter.init( + key=subkey, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, @@ -75,15 +73,15 @@ def init( extra_scores=extra_scores, ) - return repertoire, emitter_state, random_key + return repertoire, emitter_state @partial(jax.jit, static_argnames=("self",)) def update( self, repertoire: GARepertoire, emitter_state: Optional[EmitterState], - random_key: RNGKey, - ) -> Tuple[GARepertoire, Optional[EmitterState], Metrics, RNGKey]: + key: RNGKey, + ) -> Tuple[GARepertoire, Optional[EmitterState], Metrics]: """ Performs one iteration of a Genetic algorithm. 1. A batch of genotypes is sampled in the repertoire and the genotypes @@ -94,7 +92,7 @@ def update( Args: repertoire: a repertoire emitter_state: state of the emitter - random_key: a jax PRNG random key + key: a jax PRNG random key Returns: the updated MAP-Elites repertoire @@ -104,14 +102,12 @@ def update( """ # generate offsprings - genotypes, extra_info, random_key = self._emitter.emit( - repertoire, emitter_state, random_key - ) + key, subkey = jax.random.split(key) + genotypes, extra_info = self._emitter.emit(repertoire, emitter_state, subkey) # score the offsprings - fitnesses, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, extra_scores = self._scoring_function(genotypes, subkey) # update the repertoire repertoire = repertoire.add(genotypes, fitnesses) @@ -129,13 +125,13 @@ def update( # update the metrics metrics = self._metrics_function(repertoire) - return repertoire, emitter_state, metrics, random_key + return repertoire, emitter_state, metrics @partial(jax.jit, static_argnames=("self",)) def scan_update( self, carry: Tuple[GARepertoire, Optional[EmitterState], RNGKey], - unused: Any, + _: Any, ) -> Tuple[Tuple[GARepertoire, Optional[EmitterState], RNGKey], Metrics]: """Rewrites the update function in a way that makes it compatible with the jax.lax.scan primitive. @@ -143,15 +139,16 @@ def scan_update( Args: carry: a tuple containing the repertoire, the emitter state and a random key. - unused: unused element, necessary to respect jax.lax.scan API. + _: unused element, necessary to respect jax.lax.scan API. Returns: The updated repertoire and emitter state, with a new random key and metrics. 
""" # iterate over grid - repertoire, emitter_state, random_key = carry - repertoire, emitter_state, metrics, random_key = self.update( - repertoire, emitter_state, random_key + repertoire, emitter_state, key = carry + key, subkey = jax.random.split(key) + repertoire, emitter_state, metrics = self.update( + repertoire, emitter_state, subkey ) - return (repertoire, emitter_state, random_key), metrics + return (repertoire, emitter_state, key), metrics diff --git a/qdax/baselines/nsga2.py b/qdax/baselines/nsga2.py index afd587af..b802fa96 100644 --- a/qdax/baselines/nsga2.py +++ b/qdax/baselines/nsga2.py @@ -28,13 +28,12 @@ class NSGA2(GeneticAlgorithm): @partial(jax.jit, static_argnames=("self", "population_size")) def init( - self, genotypes: Genotype, population_size: int, random_key: RNGKey - ) -> Tuple[NSGA2Repertoire, Optional[EmitterState], RNGKey]: + self, genotypes: Genotype, population_size: int, key: RNGKey + ) -> Tuple[NSGA2Repertoire, Optional[EmitterState]]: # score initial genotypes - fitnesses, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, extra_scores = self._scoring_function(genotypes, subkey) # init the repertoire repertoire = NSGA2Repertoire.init( @@ -44,8 +43,9 @@ def init( ) # get initial state of the emitter - emitter_state, random_key = self._emitter.init( - random_key=random_key, + key, subkey = jax.random.split(key) + emitter_state = self._emitter.init( + key=subkey, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, @@ -62,4 +62,4 @@ def init( extra_scores=extra_scores, ) - return repertoire, emitter_state, random_key + return repertoire, emitter_state diff --git a/qdax/baselines/pbt.py b/qdax/baselines/pbt.py index 6555c537..7f7bb610 100644 --- a/qdax/baselines/pbt.py +++ b/qdax/baselines/pbt.py @@ -88,55 +88,54 @@ def __init__( @partial(jax.jit, static_argnames=("self",)) def update_states_and_buffer( self, - random_key: RNGKey, + key: RNGKey, population_returns: jnp.ndarray, training_state: PBTTrainingState, replay_buffer: ReplayBuffer, - ) -> Tuple[RNGKey, PBTTrainingState, ReplayBuffer]: + ) -> Tuple[PBTTrainingState, ReplayBuffer]: """ Updates the agents of the population states as well as their shared replay buffer. Args: - random_key: Random RNG key. + key: Random key. population_returns: Returns of the agents in the populations. training_state: The training state of the PBT scheme. replay_buffer: Shared replay buffer by the agents. Returns: - Updated random key, updated PBT training state and updated replay buffer. + Updated PBT training state and updated replay buffer. 
""" indices_sorted = jax.numpy.argsort(-population_returns) best_indices = indices_sorted[: self._num_best_to_replace_from] indices_to_replace = indices_sorted[-self._num_worse_to_replace :] - random_key, key = jax.random.split(random_key) indices_used_to_replace = jax.random.choice( key, best_indices, shape=(self._num_worse_to_replace,), replace=True ) - training_state = jax.tree_util.tree_map( + training_state = jax.tree.map( lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]), training_state, jax.vmap(training_state.__class__.resample_hyperparams)(training_state), ) - replay_buffer = jax.tree_util.tree_map( + replay_buffer = jax.tree.map( lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]), replay_buffer, replay_buffer, ) - return random_key, training_state, replay_buffer + return training_state, replay_buffer @partial(jax.jit, static_argnames=("self",)) def update_states_and_buffer_pmap( self, - random_key: RNGKey, + key: RNGKey, population_returns: jnp.ndarray, training_state: PBTTrainingState, replay_buffer: ReplayBuffer, - ) -> Tuple[RNGKey, PBTTrainingState, ReplayBuffer]: + ) -> Tuple[PBTTrainingState, ReplayBuffer]: """ Updates the agents of the population states as well as their shared replay buffer. This is the version of the function to be @@ -144,7 +143,7 @@ def update_states_and_buffer_pmap( and implement a parallel update through communication between the devices. Args: - random_key: Random RNG key. + key: Random RNG key. population_returns: Returns of the agents in the populations. training_state: The training state of the PBT scheme. replay_buffer: Shared replay buffer by the agents. @@ -156,7 +155,7 @@ def update_states_and_buffer_pmap( best_indices = indices_sorted[: self._num_best_to_replace_from] indices_to_replace = indices_sorted[-self._num_worse_to_replace :] - best_individuals, best_buffers, best_returns = jax.tree_util.tree_map( + best_individuals, best_buffers, best_returns = jax.tree.map( lambda x: x[best_indices], (training_state, replay_buffer, population_returns), ) @@ -164,19 +163,19 @@ def update_states_and_buffer_pmap( gathered_best_individuals, gathered_best_buffers, gathered_best_returns, - ) = jax.tree_util.tree_map( + ) = jax.tree.map( lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0), (best_individuals, best_buffers, best_returns), ) pop_indices_sorted = jax.numpy.argsort(-gathered_best_returns) best_pop_indices = pop_indices_sorted[: self._num_best_to_replace_from] - random_key, key = jax.random.split(random_key) + key, subkey = jax.random.split(key) indices_used_to_replace = jax.random.choice( - key, best_pop_indices, shape=(self._num_worse_to_replace,), replace=True + subkey, best_pop_indices, shape=(self._num_worse_to_replace,), replace=True ) - training_state = jax.tree_util.tree_map( + training_state = jax.tree.map( lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]), training_state, jax.vmap(gathered_best_individuals.__class__.resample_hyperparams)( @@ -184,10 +183,10 @@ def update_states_and_buffer_pmap( ), ) - replay_buffer = jax.tree_util.tree_map( + replay_buffer = jax.tree.map( lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]), replay_buffer, gathered_best_buffers, ) - return random_key, training_state, replay_buffer + return training_state, replay_buffer diff --git a/qdax/baselines/sac.py b/qdax/baselines/sac.py index 793cdd3f..7cf48ccf 100644 --- a/qdax/baselines/sac.py +++ b/qdax/baselines/sac.py @@ -54,7 +54,7 @@ class 
SacTrainingState(TrainingState): alpha_optimizer_state: optax.OptState alpha_params: Params target_critic_params: Params - random_key: RNGKey + key: RNGKey steps: jnp.ndarray normalization_running_stats: RunningMeanStdState @@ -95,12 +95,12 @@ def __init__(self, config: SacConfig, action_size: int) -> None: self._sample_action_fn = self._parametric_action_distribution.sample def init( - self, random_key: RNGKey, action_size: int, observation_size: int + self, key: RNGKey, action_size: int, observation_size: int ) -> SacTrainingState: """Initialise the training state of the algorithm. Args: - random_key: a jax random key + key: a jax random key action_size: the size of the environment's action space observation_size: the size of the environment's observation space @@ -112,13 +112,13 @@ def init( dummy_obs = jnp.zeros((1, observation_size)) dummy_action = jnp.zeros((1, action_size)) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) policy_params = self._policy.init(subkey, dummy_obs) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) critic_params = self._critic.init(subkey, dummy_obs, dummy_action) - target_critic_params = jax.tree_util.tree_map( + target_critic_params = jax.tree.map( lambda x: jnp.asarray(x.copy()), critic_params ) @@ -148,7 +148,7 @@ def init( ), count=jnp.zeros(()), ), - random_key=random_key, + key=key, steps=jnp.array(0), ) @@ -159,15 +159,15 @@ def select_action( self, obs: Observation, policy_params: Params, - random_key: RNGKey, + key: RNGKey, deterministic: bool = False, - ) -> Tuple[Action, RNGKey]: + ) -> Action: """Selects an action according to SAC policy. Args: obs: agent observation(s) policy_params: parameters of the agent's policy - random_key: jax random key + key: jax random key deterministic: whether to select action in a deterministic way. Defaults to False. 
@@ -177,14 +177,13 @@ def select_action( dist_params = self._policy.apply(policy_params, obs) if not deterministic: - random_key, key_sample = jax.random.split(random_key) - actions = self._sample_action_fn(dist_params, key_sample) + actions = self._sample_action_fn(dist_params, key) else: # The first half of parameters is for mean and the second half for variance actions = jax.nn.tanh(dist_params[..., : dist_params.shape[-1] // 2]) - return actions, random_key + return actions @partial(jax.jit, static_argnames=("self", "env", "deterministic", "evaluation")) def play_step_fn( @@ -212,7 +211,7 @@ def play_step_fn( the new SAC training state the played transition """ - random_key = training_state.random_key + key = training_state.key policy_params = training_state.policy_params obs = env_state.obs @@ -228,21 +227,23 @@ def play_step_fn( normalized_obs = obs normalization_running_stats = training_state.normalization_running_stats - actions, random_key = self.select_action( + key, subkey = jax.random.split(key) + actions = self.select_action( obs=normalized_obs, policy_params=policy_params, - random_key=random_key, + key=subkey, deterministic=deterministic, ) + key, subkey = jax.random.split(key) if not evaluation: training_state = training_state.replace( - random_key=random_key, + key=subkey, normalization_running_stats=normalization_running_stats, ) else: training_state = training_state.replace( - random_key=random_key, + key=subkey, ) next_env_state = env.step(env_state, actions) @@ -317,13 +318,7 @@ def play_qd_step_fn( transition, ) - @partial( - jax.jit, - static_argnames=( - "self", - "play_step_fn", - ), - ) + @partial(jax.jit, static_argnames=("self", "play_step_fn")) def eval_policy_fn( self, training_state: SacTrainingState, @@ -366,7 +361,7 @@ def eval_policy_fn( static_argnames=( "self", "play_step_fn", - "bd_extraction_fn", + "descriptor_extraction_fn", ), ) def eval_qd_policy_fn( @@ -377,11 +372,11 @@ def eval_qd_policy_fn( [EnvState, Params, RNGKey], Tuple[EnvState, SacTrainingState, QDTransition], ], - bd_extraction_fn: Callable[[QDTransition, Mask], Descriptor], + descriptor_extraction_fn: Callable[[QDTransition, Mask], Descriptor], ) -> Tuple[Reward, Descriptor, Reward, Descriptor]: """ Evaluates the agent's policy over an entire episode, across all batched - environments for QD environments. Averaged BDs are returned as well. + environments for QD environments. Averaged descriptors are returned as well. Args: @@ -407,14 +402,12 @@ def eval_qd_policy_fn( true_returns = jnp.nansum(transitions.rewards, axis=0) true_return = jnp.mean(true_returns, axis=-1) - transitions = jax.tree_util.tree_map( - lambda x: jnp.swapaxes(x, 0, 1), transitions - ) + transitions = jax.tree.map(lambda x: jnp.swapaxes(x, 0, 1), transitions) masks = jnp.isnan(transitions.rewards) - bds = bd_extraction_fn(transitions, masks) + descriptors = descriptor_extraction_fn(transitions, masks) - mean_bd = jnp.mean(bds, axis=0) - return true_return, mean_bd, true_returns, bds + mean_descriptor = jnp.mean(descriptors, axis=0) + return true_return, mean_descriptor, true_returns, descriptors @partial(jax.jit, static_argnames=("self",)) def _update_alpha( @@ -422,8 +415,8 @@ def _update_alpha( alpha_lr: float, training_state: SacTrainingState, transitions: Transition, - random_key: RNGKey, - ) -> Tuple[Params, optax.OptState, jnp.ndarray, RNGKey]: + key: RNGKey, + ) -> Tuple[Params, optax.OptState, jnp.ndarray]: """Updates the alpha parameter if necessary. Else, it keeps the current value. 
@@ -431,14 +424,13 @@ def _update_alpha( alpha_lr: alpha learning rate training_state: the current training state. transitions: a sample of transitions from the replay buffer. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: - New alpha params, optimizer state, loss and a new random key. + New alpha params, optimizer state, and loss. """ if not self._config.fix_alpha: # update alpha - random_key, subkey = jax.random.split(random_key) alpha_loss, alpha_gradient = jax.value_and_grad(sac_alpha_loss_fn)( training_state.alpha_params, policy_fn=self._policy.apply, @@ -446,7 +438,7 @@ def _update_alpha( action_size=self._action_size, policy_params=training_state.policy_params, transitions=transitions, - random_key=subkey, + key=key, ) alpha_optimizer = optax.adam(learning_rate=alpha_lr) ( @@ -463,7 +455,7 @@ def _update_alpha( alpha_optimizer_state = training_state.alpha_optimizer_state alpha_loss = jnp.array(0.0) - return alpha_params, alpha_optimizer_state, alpha_loss, random_key + return alpha_params, alpha_optimizer_state, alpha_loss @partial(jax.jit, static_argnames=("self",)) def _update_critic( @@ -473,8 +465,8 @@ def _update_critic( discount: float, training_state: SacTrainingState, transitions: Transition, - random_key: RNGKey, - ) -> Tuple[Params, Params, optax.OptState, jnp.ndarray, RNGKey]: + key: RNGKey, + ) -> Tuple[Params, Params, optax.OptState, jnp.ndarray]: """Updates the critic following the method described in the Soft Actor Critic paper. @@ -484,14 +476,14 @@ def _update_critic( discount: discount factor training_state: the current training state. transitions: a batch of transitions sampled from the replay buffer. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: - New parameters of the critic and its target. New optimizer state, loss and a new random key. + New parameters of the critic and its target, the new optimizer state and the loss. """ # update critic - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) critic_loss, critic_gradient = jax.value_and_grad(sac_critic_loss_fn)( training_state.critic_params, policy_fn=self._policy.apply, @@ -503,7 +495,7 @@ def _update_critic( target_critic_params=training_state.target_critic_params, alpha=jnp.exp(training_state.alpha_params), transitions=transitions, - random_key=subkey, + key=subkey, ) critic_optimizer = optax.adam(learning_rate=critic_lr) ( @@ -515,7 +507,7 @@ def _update_critic( critic_params = optax.apply_updates( training_state.critic_params, critic_updates ) - target_critic_params = jax.tree_util.tree_map( + target_critic_params = jax.tree.map( lambda x1, x2: (1.0 - self._config.tau) * x1 + self._config.tau * x2, training_state.target_critic_params, critic_params, @@ -526,7 +518,6 @@ def _update_critic( target_critic_params, critic_optimizer_state, critic_loss, - random_key, ) @partial(jax.jit, static_argnames=("self",)) @@ -535,8 +526,8 @@ def _update_actor( policy_lr: float, training_state: SacTrainingState, transitions: Transition, - random_key: RNGKey, - ) -> Tuple[Params, optax.OptState, jnp.ndarray, RNGKey]: + key: RNGKey, + ) -> Tuple[Params, optax.OptState, jnp.ndarray]: """Updates the actor parameters following the stochastic policy gradient theorem with the method introduced in SAC. @@ -545,12 +536,11 @@ def _update_actor( policy_lr: policy learning rate. training_state: the current training state. transitions: a batch of transitions sampled from the replay buffer. - random_key: a random key to handle stochastic operations. 
+ key: a random key to handle stochastic operations. Returns: - New params and optimizer state. Current loss. New random key. + New params and optimizer state. Current loss. """ - random_key, subkey = jax.random.split(random_key) policy_loss, policy_gradient = jax.value_and_grad(sac_policy_loss_fn)( training_state.policy_params, policy_fn=self._policy.apply, @@ -559,7 +549,7 @@ def _update_actor( critic_params=training_state.critic_params, alpha=jnp.exp(training_state.alpha_params), transitions=transitions, - random_key=subkey, + key=key, ) policy_optimizer = optax.adam(learning_rate=policy_lr) ( @@ -572,7 +562,7 @@ def _update_actor( training_state.policy_params, policy_updates ) - return policy_params, policy_optimizer_state, policy_loss, random_key + return policy_params, policy_optimizer_state, policy_loss @partial(jax.jit, static_argnames=("self",)) def update( @@ -593,9 +583,11 @@ def update( """ # sample a batch of transitions in the buffer - random_key = training_state.random_key - transitions, random_key = replay_buffer.sample( - random_key, + key = training_state.key + + key, subkey = jax.random.split(key) + transitions = replay_buffer.sample( + subkey, sample_size=self._config.batch_size, ) @@ -613,48 +605,49 @@ def update( ) # update alpha + key, subkey = jax.random.split(key) ( alpha_params, alpha_optimizer_state, alpha_loss, - random_key, ) = self._update_alpha( alpha_lr=self._config.learning_rate, training_state=training_state, transitions=transitions, - random_key=random_key, + key=subkey, ) # update critic + key, subkey = jax.random.split(key) ( critic_params, target_critic_params, critic_optimizer_state, critic_loss, - random_key, ) = self._update_critic( critic_lr=self._config.learning_rate, reward_scaling=self._config.reward_scaling, discount=self._config.discount, training_state=training_state, transitions=transitions, - random_key=random_key, + key=subkey, ) # update actor + key, subkey = jax.random.split(key) ( policy_params, policy_optimizer_state, policy_loss, - random_key, ) = self._update_actor( policy_lr=self._config.learning_rate, training_state=training_state, transitions=transitions, - random_key=random_key, + key=subkey, ) # create new training state + key, subkey = jax.random.split(key) new_training_state = SacTrainingState( policy_optimizer_state=policy_optimizer_state, policy_params=policy_params, @@ -664,7 +657,7 @@ def update( alpha_params=alpha_params, normalization_running_stats=training_state.normalization_running_stats, target_critic_params=target_critic_params, - random_key=random_key, + key=subkey, steps=training_state.steps + 1, ) metrics = { diff --git a/qdax/baselines/sac_pbt.py b/qdax/baselines/sac_pbt.py index 947a7183..f05bf226 100644 --- a/qdax/baselines/sac_pbt.py +++ b/qdax/baselines/sac_pbt.py @@ -45,7 +45,7 @@ def init_optimizers_states( policy_params = training_state.policy_params critic_params = training_state.critic_params alpha_params = training_state.alpha_params - target_critic_params = jax.tree_util.tree_map( + target_critic_params = jax.tree.map( lambda x: jnp.asarray(x.copy()), critic_params ) return training_state.replace( # type: ignore @@ -75,20 +75,20 @@ def resample_hyperparams( training_state: "PBTSacTrainingState", ) -> "PBTSacTrainingState": - random_key = training_state.random_key - random_key, sub_key = jax.random.split(random_key) + key = training_state.key + key, sub_key = jax.random.split(key) discount = jax.random.uniform(sub_key, shape=(), minval=0.9, maxval=1.0) - random_key, sub_key = 
jax.random.split(random_key) + key, sub_key = jax.random.split(key) policy_lr = jax.random.uniform(sub_key, shape=(), minval=3e-5, maxval=3e-3) - random_key, sub_key = jax.random.split(random_key) + key, sub_key = jax.random.split(key) critic_lr = jax.random.uniform(sub_key, shape=(), minval=3e-5, maxval=3e-3) - random_key, sub_key = jax.random.split(random_key) + key, sub_key = jax.random.split(key) alpha_lr = jax.random.uniform(sub_key, shape=(), minval=3e-5, maxval=3e-3) - random_key, sub_key = jax.random.split(random_key) + key, sub_key = jax.random.split(key) reward_scaling = jax.random.uniform(sub_key, shape=(), minval=0.1, maxval=10.0) return training_state.replace( # type: ignore @@ -97,7 +97,7 @@ def resample_hyperparams( critic_lr=critic_lr, alpha_lr=alpha_lr, reward_scaling=reward_scaling, - random_key=random_key, + key=key, ) @@ -135,12 +135,12 @@ def __init__(self, config: PBTSacConfig, action_size: int) -> None: SAC.__init__(self, config=sac_config, action_size=action_size) def init( - self, random_key: RNGKey, action_size: int, observation_size: int + self, key: RNGKey, action_size: int, observation_size: int ) -> PBTSacTrainingState: """Initialise the training state of the algorithm. Args: - random_key: a jax random key + key: a jax random key action_size: the size of the environment's action space observation_size: the size of the environment's observation space @@ -148,7 +148,7 @@ def init( the initial training state of PBT-SAC """ - sac_training_state = SAC.init(self, random_key, action_size, observation_size) + sac_training_state = SAC.init(self, key, action_size, observation_size) training_state = PBTSacTrainingState( policy_optimizer_state=sac_training_state.policy_optimizer_state, @@ -159,7 +159,7 @@ def init( alpha_params=sac_training_state.alpha_params, target_critic_params=sac_training_state.target_critic_params, normalization_running_stats=sac_training_state.normalization_running_stats, - random_key=sac_training_state.random_key, + key=sac_training_state.key, steps=sac_training_state.steps, discount=None, policy_lr=None, @@ -192,9 +192,11 @@ def update( """ # sample a batch of transitions in the buffer - random_key = training_state.random_key - transitions, random_key = replay_buffer.sample( - random_key, + key = training_state.key + + key, subkey = jax.random.split(key) + transitions = replay_buffer.sample( + subkey, sample_size=self._config.batch_size, ) @@ -212,48 +214,49 @@ def update( ) # update alpha + key, subkey = jax.random.split(key) ( alpha_params, alpha_optimizer_state, alpha_loss, - random_key, ) = self._update_alpha( alpha_lr=training_state.alpha_lr, training_state=training_state, transitions=transitions, - random_key=random_key, + key=subkey, ) # update critic + key, subkey = jax.random.split(key) ( critic_params, target_critic_params, critic_optimizer_state, critic_loss, - random_key, ) = self._update_critic( critic_lr=training_state.critic_lr, reward_scaling=training_state.reward_scaling, discount=training_state.discount, training_state=training_state, transitions=transitions, - random_key=random_key, + key=subkey, ) # update actor + key, subkey = jax.random.split(key) ( policy_params, policy_optimizer_state, policy_loss, - random_key, ) = self._update_actor( policy_lr=training_state.policy_lr, training_state=training_state, transitions=transitions, - random_key=random_key, + key=subkey, ) # create new training state + key, subkey = jax.random.split(key) new_training_state = PBTSacTrainingState( policy_optimizer_state=policy_optimizer_state, 
policy_params=policy_params, @@ -263,7 +266,7 @@ def update( alpha_params=alpha_params, normalization_running_stats=training_state.normalization_running_stats, target_critic_params=target_critic_params, - random_key=random_key, + key=subkey, steps=training_state.steps + 1, discount=training_state.discount, policy_lr=training_state.policy_lr, @@ -302,10 +305,10 @@ def get_init_fn( """ def _init_fn( - random_key: RNGKey, - ) -> Tuple[RNGKey, PBTSacTrainingState, ReplayBuffer]: + key: RNGKey, + ) -> Tuple[PBTSacTrainingState, ReplayBuffer]: - random_key, *keys = jax.random.split(random_key, num=1 + population_size) + key, *keys = jax.random.split(key, num=population_size + 1) keys = jnp.stack(keys) init_dummy_transition = partial( @@ -328,7 +331,7 @@ def _init_fn( self.init, action_size=action_size, observation_size=observation_size ) training_states = jax.vmap(agent_init)(keys) - return random_key, training_states, replay_buffers + return training_states, replay_buffers return _init_fn @@ -364,21 +367,22 @@ def get_eval_fn( def get_eval_qd_fn( self, eval_env: Env, - bd_extraction_fn: Callable[[QDTransition, Mask], Descriptor], + descriptor_extraction_fn: Callable[[QDTransition, Mask], Descriptor], ) -> Callable: """ - Returns the function the evaluation the PBT population. + Returns the evaluation function of the PBT population. Args: eval_env: evaluation environment. Might be different from training env if needed. - bd_extraction_fn: function to extract the bd from an episode. + descriptor_extraction_fn: function to extract the descriptor from an + episode. Returns: The function to evaluate the population. It takes as input the population training state as well as first eval environment states and returns the - population agents mean returns and mean bds over episodes as well as all - returns and bds from all agents over all episodes. + population agents mean returns and mean descriptors over episodes, + as well as all returns and descriptors from all agents over all episodes.
""" play_eval_step = partial( self.play_qd_step_fn, @@ -389,7 +393,7 @@ def get_eval_qd_fn( eval_policy = partial( self.eval_qd_policy_fn, play_step_fn=play_eval_step, - bd_extraction_fn=bd_extraction_fn, + descriptor_extraction_fn=descriptor_extraction_fn, ) return jax.vmap(eval_policy) # type: ignore diff --git a/qdax/baselines/spea2.py b/qdax/baselines/spea2.py index 10d195ad..49552d9d 100644 --- a/qdax/baselines/spea2.py +++ b/qdax/baselines/spea2.py @@ -30,26 +30,18 @@ class SPEA2(GeneticAlgorithm): b13724cb54ae4171916f3f969d304b9e9752a57f" """ - @partial( - jax.jit, - static_argnames=( - "self", - "population_size", - "num_neighbours", - ), - ) + @partial(jax.jit, static_argnames=("self", "population_size", "num_neighbours")) def init( self, genotypes: Genotype, population_size: int, num_neighbours: int, - random_key: RNGKey, - ) -> Tuple[SPEA2Repertoire, Optional[EmitterState], RNGKey]: + key: RNGKey, + ) -> Tuple[SPEA2Repertoire, Optional[EmitterState]]: # score initial genotypes - fitnesses, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, extra_scores = self._scoring_function(genotypes, subkey) # init the repertoire repertoire = SPEA2Repertoire.init( @@ -60,8 +52,9 @@ def init( ) # get initial state of the emitter - emitter_state, random_key = self._emitter.init( - random_key=random_key, + key, subkey = jax.random.split(key) + emitter_state = self._emitter.init( + key=subkey, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, @@ -78,4 +71,4 @@ def init( extra_scores=extra_scores, ) - return repertoire, emitter_state, random_key + return repertoire, emitter_state diff --git a/qdax/baselines/td3.py b/qdax/baselines/td3.py index 97f37893..8ef49cfc 100644 --- a/qdax/baselines/td3.py +++ b/qdax/baselines/td3.py @@ -44,7 +44,7 @@ class TD3TrainingState(TrainingState): critic_params: Params target_critic_params: Params target_policy_params: Params - random_key: RNGKey + key: RNGKey steps: jnp.ndarray @@ -86,13 +86,13 @@ def __init__(self, config: TD3Config, action_size: int): ) def init( - self, random_key: RNGKey, action_size: int, observation_size: int + self, key: RNGKey, action_size: int, observation_size: int ) -> TD3TrainingState: """Initialise the training state of the TD3 algorithm, through creation of optimizer states and params. Args: - random_key: a random key used for random operations. + key: a random key used for random operations. action_size: the size of the action array needed to interact with the environment. 
observation_size: the size of the observation array retrieved from the @@ -105,15 +105,15 @@ def init( # Initialize critics and policy params fake_obs = jnp.zeros(shape=(observation_size,)) fake_action = jnp.zeros(shape=(action_size,)) - random_key, subkey_1, subkey_2 = jax.random.split(random_key, num=3) + key, subkey_1, subkey_2 = jax.random.split(key, num=3) critic_params = self._critic.init(subkey_1, obs=fake_obs, actions=fake_action) policy_params = self._policy.init(subkey_2, fake_obs) # Initialize target networks - target_critic_params = jax.tree_util.tree_map( + target_critic_params = jax.tree.map( lambda x: jnp.asarray(x.copy()), critic_params ) - target_policy_params = jax.tree_util.tree_map( + target_policy_params = jax.tree.map( lambda x: jnp.asarray(x.copy()), policy_params ) @@ -129,7 +129,7 @@ def init( critic_params=critic_params, target_policy_params=target_policy_params, target_critic_params=target_critic_params, - random_key=random_key, + key=key, steps=jnp.array(0), ) @@ -140,17 +140,17 @@ def select_action( self, obs: Observation, policy_params: Params, - random_key: RNGKey, + key: RNGKey, expl_noise: float, deterministic: bool = False, - ) -> Tuple[Action, RNGKey]: + ) -> Action: """Selects an action according to TD3 policy. The action can be deterministic or stochastic by adding exploration noise. Args: obs: agent observation(s) policy_params: parameters of the agent's policy - random_key: jax random key + key: jax random key expl_noise: exploration noise deterministic: whether to select action in a deterministic way. Defaults to False. @@ -161,11 +161,10 @@ def select_action( actions = self._policy.apply(policy_params, obs) if not deterministic: - random_key, subkey = jax.random.split(random_key) - noise = jax.random.normal(subkey, actions.shape) * expl_noise + noise = jax.random.normal(key, actions.shape) * expl_noise actions = actions + noise actions = jnp.clip(actions, -1.0, 1.0) - return actions, random_key + return actions @partial(jax.jit, static_argnames=("self", "env", "deterministic")) def play_step_fn( @@ -190,16 +189,18 @@ def play_step_fn( the new TD3 training state the played transition """ + key = training_state.key - actions, random_key = self.select_action( + key, subkey = jax.random.split(key) + actions = self.select_action( obs=env_state.obs, policy_params=training_state.policy_params, - random_key=training_state.random_key, + key=subkey, expl_noise=self._config.expl_noise, deterministic=deterministic, ) training_state = training_state.replace( - random_key=random_key, + key=key, ) next_env_state = env.step(env_state, actions) transition = Transition( @@ -258,13 +259,7 @@ def play_qd_step_fn( transition, ) - @partial( - jax.jit, - static_argnames=( - "self", - "play_step_fn", - ), - ) + @partial(jax.jit, static_argnames=("self", "play_step_fn")) def eval_policy_fn( self, training_state: TD3TrainingState, @@ -306,7 +301,7 @@ def eval_policy_fn( static_argnames=( "self", "play_step_fn", - "bd_extraction_fn", + "descriptor_extraction_fn", ), ) def eval_qd_policy_fn( @@ -317,10 +312,10 @@ def eval_qd_policy_fn( [EnvState, Params, RNGKey], Tuple[EnvState, TD3TrainingState, QDTransition], ], - bd_extraction_fn: Callable[[QDTransition, Mask], Descriptor], + descriptor_extraction_fn: Callable[[QDTransition, Mask], Descriptor], ) -> Tuple[Reward, Descriptor, Reward, Descriptor]: """Evaluates the agent's policy over an entire episode, across all batched - environments for QD environments. Averaged BDs are returned as well. + environments for QD environments. 
Averaged descriptors are returned as well. Args: @@ -346,14 +341,12 @@ def eval_qd_policy_fn( true_returns = jnp.nansum(transitions.rewards, axis=0) true_return = jnp.mean(true_returns, axis=-1) - transitions = jax.tree_util.tree_map( - lambda x: jnp.swapaxes(x, 0, 1), transitions - ) + transitions = jax.tree.map(lambda x: jnp.swapaxes(x, 0, 1), transitions) masks = jnp.isnan(transitions.rewards) - bds = bd_extraction_fn(transitions, masks) + descriptors = descriptor_extraction_fn(transitions, masks) - mean_bd = jnp.mean(bds, axis=0) - return true_return, mean_bd, true_returns, bds + mean_descriptor = jnp.mean(descriptors, axis=0) + return true_return, mean_descriptor, true_returns, descriptors @partial(jax.jit, static_argnames=("self",)) def update( @@ -376,13 +369,13 @@ def update( """ # Sample a batch of transitions in the buffer - random_key = training_state.random_key - samples, random_key = replay_buffer.sample( - random_key, sample_size=self._config.batch_size - ) + key = training_state.key + + key, subkey = jax.random.split(key) + samples = replay_buffer.sample(subkey, sample_size=self._config.batch_size) # Update Critic - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) critic_loss, critic_gradient = jax.value_and_grad(td3_critic_loss_fn)( training_state.critic_params, target_policy_params=training_state.target_policy_params, @@ -394,7 +387,7 @@ def update( reward_scaling=self._config.reward_scaling, discount=self._config.discount, transitions=samples, - random_key=subkey, + key=subkey, ) critic_optimizer = optax.adam(learning_rate=self._config.critic_learning_rate) critic_updates, critic_optimizer_state = critic_optimizer.update( @@ -404,7 +397,7 @@ def update( training_state.critic_params, critic_updates ) # Soft update of target critic network - target_critic_params = jax.tree_util.tree_map( + target_critic_params = jax.tree.map( lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + self._config.soft_tau_update * x2, training_state.target_critic_params, @@ -434,7 +427,7 @@ def update_policy_step() -> Tuple[Params, Params, optax.OptState]: training_state.policy_params, policy_updates ) # Soft update of target policy - target_policy_params = jax.tree_util.tree_map( + target_policy_params = jax.tree.map( lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + self._config.soft_tau_update * x2, training_state.target_policy_params, @@ -463,7 +456,7 @@ def update_policy_step() -> Tuple[Params, Params, optax.OptState]: policy_optimizer_state=policy_optimizer_state, target_critic_params=target_critic_params, target_policy_params=target_policy_params, - random_key=random_key, + key=key, steps=training_state.steps + 1, ) diff --git a/qdax/baselines/td3_pbt.py b/qdax/baselines/td3_pbt.py index 5762956d..aad90855 100644 --- a/qdax/baselines/td3_pbt.py +++ b/qdax/baselines/td3_pbt.py @@ -48,10 +48,10 @@ def init_optimizers_states( optimizer_init = optax.adam(learning_rate=1.0).init policy_params = training_state.policy_params critic_params = training_state.critic_params - target_critic_params = jax.tree_util.tree_map( + target_critic_params = jax.tree.map( lambda x: jnp.asarray(x.copy()), critic_params ) - target_policy_params = jax.tree_util.tree_map( + target_policy_params = jax.tree.map( lambda x: jnp.asarray(x.copy()), policy_params ) return training_state.replace( # type: ignore @@ -80,23 +80,23 @@ def resample_hyperparams( cls, training_state: "PBTTD3TrainingState" ) -> "PBTTD3TrainingState": - random_key = training_state.random_key - 
random_key, sub_key = jax.random.split(random_key) + key = training_state.key + key, sub_key = jax.random.split(key) discount = jax.random.uniform(sub_key, shape=(), minval=0.9, maxval=1.0) - random_key, sub_key = jax.random.split(random_key) + key, sub_key = jax.random.split(key) policy_lr = jax.random.uniform(sub_key, shape=(), minval=3e-5, maxval=3e-3) - random_key, sub_key = jax.random.split(random_key) + key, sub_key = jax.random.split(key) critic_lr = jax.random.uniform(sub_key, shape=(), minval=3e-5, maxval=3e-3) - random_key, sub_key = jax.random.split(random_key) + key, sub_key = jax.random.split(key) noise_clip = jax.random.uniform(sub_key, shape=(), minval=0.0, maxval=1.0) - random_key, sub_key = jax.random.split(random_key) + key, sub_key = jax.random.split(key) policy_noise = jax.random.uniform(sub_key, shape=(), minval=0.0, maxval=1.0) - random_key, sub_key = jax.random.split(random_key) + key, sub_key = jax.random.split(key) expl_noise = jax.random.uniform(sub_key, shape=(), minval=0.0, maxval=0.2) return training_state.replace( # type: ignore @@ -105,7 +105,7 @@ def resample_hyperparams( critic_lr=critic_lr, noise_clip=noise_clip, policy_noise=policy_noise, - random_key=random_key, + key=key, expl_noise=expl_noise, ) @@ -138,13 +138,13 @@ def __init__(self, config: PBTTD3Config, action_size: int): TD3.__init__(self, td3_config, action_size) def init( - self, random_key: RNGKey, action_size: int, observation_size: int + self, key: RNGKey, action_size: int, observation_size: int ) -> PBTTD3TrainingState: """Initialise the training state of the PBT-TD3 algorithm, through creation of optimizer states and params. Args: - random_key: a random key used for random operations. + key: a random key used for random operations. action_size: the size of the action array needed to interact with the environment. observation_size: the size of the observation array retrieved from the @@ -154,7 +154,7 @@ def init( the initial training state. 
""" - training_state = TD3.init(self, random_key, action_size, observation_size) + training_state = TD3.init(self, key, action_size, observation_size) # Initial training state training_state = PBTTD3TrainingState( @@ -164,7 +164,7 @@ def init( critic_params=training_state.critic_params, target_policy_params=training_state.target_policy_params, target_critic_params=training_state.target_critic_params, - random_key=training_state.random_key, + key=training_state.key, steps=training_state.steps, discount=None, policy_lr=None, @@ -202,16 +202,17 @@ def play_step_fn( the new PBT-TD3 training state the played transition """ + key, subkey = jax.random.split(training_state.key) - actions, random_key = self.select_action( + actions = self.select_action( obs=env_state.obs, policy_params=training_state.policy_params, - random_key=training_state.random_key, + key=subkey, expl_noise=training_state.expl_noise, deterministic=deterministic, ) training_state = training_state.replace( - random_key=random_key, + key=key, ) next_env_state = env.step(env_state, actions) transition = Transition( @@ -245,13 +246,13 @@ def update( """ # Sample a batch of transitions in the buffer - random_key = training_state.random_key - samples, random_key = replay_buffer.sample( - random_key, sample_size=self._config.batch_size - ) + key = training_state.key + + key, subkey = jax.random.split(key) + samples = replay_buffer.sample(subkey, sample_size=self._config.batch_size) # Update Critic - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) critic_loss, critic_gradient = jax.value_and_grad(td3_critic_loss_fn)( training_state.critic_params, target_policy_params=training_state.target_policy_params, @@ -263,7 +264,7 @@ def update( reward_scaling=self._config.reward_scaling, discount=self._config.discount, transitions=samples, - random_key=subkey, + key=subkey, ) critic_optimizer = optax.adam(learning_rate=training_state.critic_lr) critic_updates, critic_optimizer_state = critic_optimizer.update( @@ -273,7 +274,7 @@ def update( training_state.critic_params, critic_updates ) # Soft update of target critic network - target_critic_params = jax.tree_util.tree_map( + target_critic_params = jax.tree.map( lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + self._config.soft_tau_update * x2, training_state.target_critic_params, @@ -301,7 +302,7 @@ def update_policy_step() -> Tuple[Params, Params, optax.OptState]: training_state.policy_params, policy_updates ) # Soft update of target policy - target_policy_params = jax.tree_util.tree_map( + target_policy_params = jax.tree.map( lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + self._config.soft_tau_update * x2, training_state.target_policy_params, @@ -330,7 +331,7 @@ def update_policy_step() -> Tuple[Params, Params, optax.OptState]: policy_optimizer_state=policy_optimizer_state, target_critic_params=target_critic_params, target_policy_params=target_policy_params, - random_key=random_key, + key=key, steps=training_state.steps + 1, ) @@ -363,9 +364,9 @@ def get_init_fn( """ def _init_fn( - random_key: RNGKey, - ) -> Tuple[RNGKey, PBTTD3TrainingState, ReplayBuffer]: - random_key, *keys = jax.random.split(random_key, num=1 + population_size) + key: RNGKey, + ) -> Tuple[PBTTD3TrainingState, ReplayBuffer]: + key, *keys = jax.random.split(key, num=population_size + 1) keys = jnp.stack(keys) init_dummy_transition = partial( @@ -388,7 +389,7 @@ def _init_fn( self.init, action_size=action_size, observation_size=observation_size ) training_states = 
jax.vmap(agent_init)(keys) - return random_key, training_states, replay_buffers + return training_states, replay_buffers return _init_fn @@ -424,7 +425,7 @@ def get_eval_fn( def get_eval_qd_fn( self, eval_env: Env, - bd_extraction_fn: Callable[[QDTransition, Mask], Descriptor], + descriptor_extraction_fn: Callable[[QDTransition, Mask], Descriptor], ) -> Callable: """ Returns the function the evaluation the PBT population. @@ -432,13 +433,14 @@ def get_eval_qd_fn( Args: eval_env: evaluation environment. Might be different from training env if needed. - bd_extraction_fn: function to extract the bd from an episode. + descriptor_extraction_fn: function to extract the descriptor from an + episode. Returns: The function to evaluate the population. It takes as input the population training state as well as first eval environment states and returns the - population agents mean returns and mean bds over episodes as well as all - returns and bds from all agents over all episodes. + population agents mean returns and mean descriptors over episodes, + as well as all returns and descriptors from all agents over all episodes. """ play_eval_step = partial( self.play_qd_step_fn, @@ -449,7 +451,7 @@ def get_eval_qd_fn( eval_policy = partial( self.eval_qd_policy_fn, play_step_fn=play_eval_step, - bd_extraction_fn=bd_extraction_fn, + descriptor_extraction_fn=descriptor_extraction_fn, ) return jax.vmap(eval_policy) # type: ignore diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index f67d7b4f..4a9ea5fc 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -21,7 +21,7 @@ Params, RNGKey, ) -from qdax.environments.bd_extractors import AuroraExtraInfo +from qdax.environments.descriptor_extractors import AuroraExtraInfo class AURORA: @@ -42,7 +42,7 @@ def __init__( self, scoring_function: Callable[ [Genotype, RNGKey], - Tuple[Fitness, Descriptor, ArrayTree, RNGKey], + Tuple[Fitness, Descriptor, ArrayTree], ], emitter: Emitter, metrics_function: Callable[[MapElitesRepertoire], Metrics], @@ -62,11 +62,11 @@ def train( repertoire: UnstructuredRepertoire, model_params: Params, iteration: int, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[UnstructuredRepertoire, AuroraExtraInfo]: - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) aurora_extra_info = self._train_fn( - random_key, + key, repertoire, model_params, iteration, @@ -122,8 +122,8 @@ def init( aurora_extra_info: AuroraExtraInfo, l_value: jnp.ndarray, max_size: int, - random_key: RNGKey, - ) -> Tuple[UnstructuredRepertoire, Optional[EmitterState], AuroraExtraInfo, RNGKey]: + key: RNGKey, + ) -> Tuple[UnstructuredRepertoire, Optional[EmitterState], AuroraExtraInfo]: """Initialize an unstructured repertoire with an initial population of genotypes. Also performs the first training of the AURORA encoder. @@ -134,15 +134,16 @@ def init( such as the encoder parameters l_value: threshold distance for the unstructured repertoire max_size: maximum size of the repertoire - random_key: a random key used for stochastic operations. + key: a random key used for stochastic operations. 
Returns: an initialized unstructured repertoire, with the initial state of the emitter, and the updated information to perform AURORA encodings """ - fitnesses, descriptors, extra_scores, random_key = self._scoring_function( + key, subkey = jax.random.split(key) + fitnesses, descriptors, extra_scores = self._scoring_function( genotypes, - random_key, + subkey, ) observations = extra_scores["last_valid_observations"] @@ -159,8 +160,9 @@ def init( ) # get initial state of the emitter - emitter_state, random_key = self._emitter.init( - random_key=random_key, + key, subkey = jax.random.split(key) + emitter_state = self._emitter.init( + key=subkey, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, @@ -168,21 +170,20 @@ def init( extra_scores=extra_scores, ) - random_key, subkey = jax.random.split(random_key) repertoire, updated_aurora_extra_info = self.train( - repertoire, aurora_extra_info.model_params, iteration=0, random_key=subkey + repertoire, aurora_extra_info.model_params, iteration=0, key=key ) - return repertoire, emitter_state, updated_aurora_extra_info, random_key + return repertoire, emitter_state, updated_aurora_extra_info @partial(jax.jit, static_argnames=("self",)) def update( self, repertoire: MapElitesRepertoire, emitter_state: Optional[EmitterState], - random_key: RNGKey, + key: RNGKey, aurora_extra_info: AuroraExtraInfo, - ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]: + ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]: """Main step of the AURORA algorithm. @@ -194,7 +195,7 @@ def update( Args: repertoire: unstructured repertoire emitter_state: state of the emitter - random_key: a jax PRNG random key + key: a jax PRNG random key aurora_extra_info: extra info for computing encodings Results: @@ -204,14 +205,14 @@ def update( a new key """ # generate offsprings with the emitter - genotypes, extra_info, random_key = self._emitter.emit( - repertoire, emitter_state, random_key - ) + key, subkey = jax.random.split(key) + genotypes, extra_info = self._emitter.emit(repertoire, emitter_state, subkey) # scores the offsprings - fitnesses, descriptors, extra_scores, random_key = self._scoring_function( + key, subkey = jax.random.split(key) + fitnesses, descriptors, extra_scores = self._scoring_function( genotypes, - random_key, + subkey, ) observations = extra_scores["last_valid_observations"] @@ -239,4 +240,4 @@ def update( # update the metrics metrics = self._metrics_function(repertoire) - return repertoire, emitter_state, metrics, random_key + return repertoire, emitter_state, metrics diff --git a/qdax/core/containers/archive.py b/qdax/core/containers/archive.py index d2e1f812..1590d9eb 100644 --- a/qdax/core/containers/archive.py +++ b/qdax/core/containers/archive.py @@ -311,7 +311,7 @@ def top_1(data: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: return data, value, indice def scannable_top_1( - carry: jnp.ndarray, unused: Any + carry: jnp.ndarray, _: Any ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: data = carry data, value, indice = top_1(data) diff --git a/qdax/core/containers/ga_repertoire.py b/qdax/core/containers/ga_repertoire.py index 403331ff..a0be8ec7 100644 --- a/qdax/core/containers/ga_repertoire.py +++ b/qdax/core/containers/ga_repertoire.py @@ -3,7 +3,7 @@ from __future__ import annotations from functools import partial -from typing import Callable, Tuple +from typing import Callable import jax import jax.numpy as jnp @@ -34,7 +34,7 @@ class GARepertoire(Repertoire): @property def 
size(self) -> int: """Gives the size of the population.""" - first_leaf = jax.tree_util.tree_leaves(self.genotypes)[0] + first_leaf = jax.tree.leaves(self.genotypes)[0] return int(first_leaf.shape[0]) def save(self, path: str = "./") -> None: @@ -78,11 +78,11 @@ def load(cls, reconstruction_fn: Callable, path: str = "./") -> GARepertoire: ) @partial(jax.jit, static_argnames=("num_samples",)) - def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]: + def sample(self, key: RNGKey, num_samples: int) -> Genotype: """Sample genotypes from the repertoire. Args: - random_key: a random key to handle stochasticity. + key: a random key to handle stochasticity. num_samples: the number of genotypes to sample. Returns: @@ -94,15 +94,14 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey p = jnp.any(mask, axis=-1) / jnp.sum(jnp.any(mask, axis=-1)) # sample - random_key, subkey = jax.random.split(random_key) - samples = jax.tree_util.tree_map( + samples = jax.tree.map( lambda x: jax.random.choice( - subkey, x, shape=(num_samples,), p=p, replace=False + key, x, shape=(num_samples,), p=p, replace=False ), self.genotypes, ) - return samples, random_key + return samples @jax.jit def add( @@ -122,7 +121,7 @@ def add( """ # gather individuals and fitnesses - candidates = jax.tree_util.tree_map( + candidates = jax.tree.map( lambda x, y: jnp.concatenate((x, y), axis=0), self.genotypes, batch_of_genotypes, @@ -138,9 +137,7 @@ def add( survivor_indices = indices[: self.size] # keep only the best ones - new_candidates = jax.tree_util.tree_map( - lambda x: x[survivor_indices], candidates - ) + new_candidates = jax.tree.map(lambda x: x[survivor_indices], candidates) new_repertoire = self.replace( genotypes=new_candidates, fitnesses=candidates_fitnesses[survivor_indices] @@ -174,7 +171,7 @@ def init( # type: ignore ) # create default genotypes - default_genotypes = jax.tree_util.tree_map( + default_genotypes = jax.tree.map( lambda x: jnp.zeros(shape=(population_size,) + x.shape[1:]), genotypes ) diff --git a/qdax/core/containers/mapelites_repertoire.py b/qdax/core/containers/mapelites_repertoire.py index 2556470b..6c23037e 100644 --- a/qdax/core/containers/mapelites_repertoire.py +++ b/qdax/core/containers/mapelites_repertoire.py @@ -31,8 +31,8 @@ def compute_cvt_centroids( num_centroids: int, minval: Union[float, List[float]], maxval: Union[float, List[float]], - random_key: RNGKey, -) -> Tuple[jnp.ndarray, RNGKey]: + key: RNGKey, +) -> Centroid: """Compute centroids for CVT tessellation. 
Args: @@ -44,31 +44,30 @@ def compute_cvt_centroids( num_centroids: number of centroids minval: minimum descriptors value maxval: maximum descriptors value - random_key: a jax PRNG random key + key: a jax PRNG random key Returns: the centroids with shape (num_centroids, num_descriptors) - random_key: an updated jax PRNG random key + key: an updated jax PRNG random key """ minval = jnp.array(minval) maxval = jnp.array(maxval) # assume here all values are in [0, 1] and rescale later - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) x = jax.random.uniform(key=subkey, shape=(num_init_cvt_samples, num_descriptors)) # compute k means - random_key, subkey = jax.random.split(random_key) k_means = KMeans( init="k-means++", n_clusters=num_centroids, n_init=1, - random_state=RandomState(jax.random.key_data(subkey)), + random_state=RandomState(jax.random.key_data(key)), ) k_means.fit(x) centroids = k_means.cluster_centers_ # rescale now - return jnp.asarray(centroids) * (maxval - minval) + minval, random_key + return jnp.asarray(centroids) * (maxval - minval) + minval def compute_euclidean_centroids( @@ -79,7 +78,7 @@ def compute_euclidean_centroids( """Compute centroids for square Euclidean tessellation. Args: - grid_shape: number of centroids per BD dimension + grid_shape: number of centroids per descriptor dimension minval: minimum descriptors value maxval: maximum descriptors value @@ -144,7 +143,7 @@ class MapElitesRepertoire(flax.struct.PyTreeNode): Args: genotypes: a PyTree containing all the genotypes in the repertoire ordered by the centroids. Each leaf has a shape (num_centroids, num_features). The - PyTree can be a simple Jax array or a more complex nested structure such + PyTree can be a simple JAX array or a more complex nested structure such as to represent parameters of neural network in Flax. fitnesses: an array that contains the fitness of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids,). @@ -212,60 +211,57 @@ def load(cls, reconstruction_fn: Callable, path: str = "./") -> MapElitesReperto ) @partial(jax.jit, static_argnames=("num_samples",)) - def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]: + def sample(self, key: RNGKey, num_samples: int) -> Genotype: """Sample elements in the repertoire. Args: - random_key: a jax PRNG random key + key: a jax PRNG random key num_samples: the number of elements to be sampled Returns: samples: a batch of genotypes sampled in the repertoire - random_key: an updated jax PRNG random key + key: an updated jax PRNG random key """ repertoire_empty = self.fitnesses == -jnp.inf p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty) - random_key, subkey = jax.random.split(random_key) - samples = jax.tree_util.tree_map( - lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p), + samples = jax.tree.map( + lambda x: jax.random.choice(key, x, shape=(num_samples,), p=p), self.genotypes, ) - return samples, random_key + return samples @partial(jax.jit, static_argnames=("num_samples",)) def sample_with_descs( self, - random_key: RNGKey, + key: RNGKey, num_samples: int, - ) -> Tuple[Genotype, Descriptor, RNGKey]: + ) -> Tuple[Genotype, Descriptor]: """Sample elements in the repertoire. 
Args: - random_key: a jax PRNG random key + key: a jax PRNG random key num_samples: the number of elements to be sampled Returns: samples: a batch of genotypes sampled in the repertoire - random_key: an updated jax PRNG random key """ repertoire_empty = self.fitnesses == -jnp.inf p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty) - random_key, subkey = jax.random.split(random_key) - samples = jax.tree_util.tree_map( - lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p), + samples = jax.tree.map( + lambda x: jax.random.choice(key, x, shape=(num_samples,), p=p), self.genotypes, ) - descs = jax.tree_util.tree_map( - lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p), + descs = jax.tree.map( + lambda x: jax.random.choice(key, x, shape=(num_samples,), p=p), self.descriptors, ) - return samples, descs, random_key + return samples, descs @jax.jit def add( @@ -326,7 +322,7 @@ def add( ) # create new repertoire - new_repertoire_genotypes = jax.tree_util.tree_map( + new_repertoire_genotypes = jax.tree.map( lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[ batch_of_indices.squeeze(axis=-1) ].set(new_genotypes), @@ -387,7 +383,7 @@ def init( ) # retrieve one genotype from the population - first_genotype = jax.tree_util.tree_map(lambda x: x[0], genotypes) + first_genotype = jax.tree.map(lambda x: x[0], genotypes) # create a repertoire with default values repertoire = cls.init_default(genotype=first_genotype, centroids=centroids) @@ -425,7 +421,7 @@ def init_default( default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) # default genotypes is all 0 - default_genotypes = jax.tree_util.tree_map( + default_genotypes = jax.tree.map( lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype), genotype, ) diff --git a/qdax/core/containers/mels_repertoire.py b/qdax/core/containers/mels_repertoire.py index 7ef57bb9..8fa82f39 100644 --- a/qdax/core/containers/mels_repertoire.py +++ b/qdax/core/containers/mels_repertoire.py @@ -72,7 +72,7 @@ class MELSRepertoire(MapElitesRepertoire): Args: genotypes: a PyTree containing all the genotypes in the repertoire ordered by the centroids. Each leaf has a shape (num_centroids, num_features). The - PyTree can be a simple Jax array or a more complex nested structure such + PyTree can be a simple JAX array or a more complex nested structure such as to represent parameters of neural network in Flax. fitnesses: an array that contains the fitness of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids,). @@ -243,7 +243,7 @@ def add( ) # create new repertoire - new_repertoire_genotypes = jax.tree_util.tree_map( + new_repertoire_genotypes = jax.tree.map( lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[ batch_of_indices.squeeze(axis=-1) ].set(new_genotypes), @@ -298,7 +298,7 @@ def init_default( default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) # default genotypes is all 0 - default_genotypes = jax.tree_util.tree_map( + default_genotypes = jax.tree.map( lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype), genotype, ) diff --git a/qdax/core/containers/mome_repertoire.py b/qdax/core/containers/mome_repertoire.py index 58a089a6..405eb3f4 100644 --- a/qdax/core/containers/mome_repertoire.py +++ b/qdax/core/containers/mome_repertoire.py @@ -56,7 +56,7 @@ def repertoire_capacity(self) -> int: Returns: The repertoire capacity. 
""" - first_leaf = jax.tree_util.tree_leaves(self.genotypes)[0] + first_leaf = jax.tree.leaves(self.genotypes)[0] return int(first_leaf.shape[0] * first_leaf.shape[1]) @jax.jit @@ -64,44 +64,40 @@ def _sample_in_masked_pareto_front( self, pareto_front_genotypes: ParetoFront[Genotype], mask: Mask, - random_key: RNGKey, + key: RNGKey, ) -> Genotype: """Sample one single genotype in masked pareto front. - Note: do not retrieve a random key because this function - is to be vmapped. The public method that uses this function - will return a random key - Args: pareto_front_genotypes: the genotypes of a pareto front mask: a mask associated to the front - random_key: a random key to handle stochastic operations + key: a random key to handle stochastic operations Returns: A single genotype among the pareto front. """ p = (1.0 - mask) / jnp.sum(1.0 - mask) - genotype_sample = jax.tree_util.tree_map( - lambda x: jax.random.choice(random_key, x, shape=(1,), p=p), + genotype_sample = jax.tree.map( + lambda x: jax.random.choice(key, x, shape=(1,), p=p), pareto_front_genotypes, ) return genotype_sample @partial(jax.jit, static_argnames=("num_samples",)) - def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]: + def sample(self, key: RNGKey, num_samples: int) -> Genotype: """Sample elements in the repertoire. This method sample a non-empty pareto front, and then sample genotypes from this pareto front. Args: - random_key: a random key to handle stochasticity. + key: a random key to handle stochasticity. num_samples: number of samples to retrieve from the repertoire. Returns: - A sample of genotypes and a new random key. + A sample of genotypes. """ # create sampling probability for the cells @@ -114,32 +110,27 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey indices = jnp.arange(start=0, stop=repertoire_empty.shape[0]) # choose idx - among indices of cells that are not empty - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) cells_idx = jax.random.choice(subkey, indices, shape=(num_samples,), p=p) # get genotypes (front) from the chosen indices - pareto_front_genotypes = jax.tree_util.tree_map( - lambda x: x[cells_idx], self.genotypes - ) + pareto_front_genotypes = jax.tree.map(lambda x: x[cells_idx], self.genotypes) # prepare second sampling function sample_in_fronts = jax.vmap(self._sample_in_masked_pareto_front) # sample genotypes from the pareto front - random_key, subkey = jax.random.split(random_key) - subkeys = jax.random.split(subkey, num=num_samples) + subkeys = jax.random.split(key, num=num_samples) sampled_genotypes = sample_in_fronts( # type: ignore pareto_front_genotypes=pareto_front_genotypes, mask=repertoire_empty[cells_idx], - random_key=subkeys, + key=subkeys, ) # remove the dim coming from pareto front - sampled_genotypes = jax.tree_util.tree_map( - lambda x: x.squeeze(axis=1), sampled_genotypes - ) + sampled_genotypes = jax.tree.map(lambda x: x.squeeze(axis=1), sampled_genotypes) - return sampled_genotypes, random_key + return sampled_genotypes @jax.jit def _update_masked_pareto_front( @@ -185,7 +176,7 @@ def _update_masked_pareto_front( cat_fitnesses = jnp.concatenate( [pareto_front_fitnesses, new_batch_of_fitnesses], axis=0 ) - cat_genotypes = jax.tree_util.tree_map( + cat_genotypes = jax.tree.map( lambda x, y: jnp.concatenate([x, y], axis=0), pareto_front_genotypes, new_batch_of_genotypes, @@ -208,7 +199,7 @@ def _update_masked_pareto_front( # get new fitness, genotypes and 
descriptors new_front_fitness = jnp.take(cat_fitnesses, indices, axis=0) - new_front_genotypes = jax.tree_util.tree_map( + new_front_genotypes = jax.tree.map( lambda x: jnp.take(x, indices, axis=0), cat_genotypes ) new_front_descriptors = jnp.take(cat_descriptors, indices, axis=0) @@ -232,10 +223,10 @@ def _update_masked_pareto_front( front_size = len(pareto_front_fitnesses) # type: ignore new_front_fitness = new_front_fitness[:front_size, :] - new_front_genotypes = jax.tree_util.tree_map( + new_front_genotypes = jax.tree.map( lambda x: x * new_mask_indices[0], new_front_genotypes ) - new_front_genotypes = jax.tree_util.tree_map( + new_front_genotypes = jax.tree.map( lambda x: x[:front_size], new_front_genotypes ) @@ -292,16 +283,12 @@ def _add_one( index = index.astype(jnp.int32) # get current repertoire cell data - cell_genotype = jax.tree_util.tree_map( - lambda x: x[index][0], carry.genotypes - ) + cell_genotype = jax.tree.map(lambda x: x[index][0], carry.genotypes) cell_fitness = carry.fitnesses[index][0] cell_descriptor = carry.descriptors[index][0] cell_mask = jnp.any(cell_fitness == -jnp.inf, axis=-1) - new_genotypes = jax.tree_util.tree_map( - lambda x: jnp.expand_dims(x, axis=0), genotype - ) + new_genotypes = jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), genotype) # update pareto front ( @@ -324,7 +311,7 @@ def _add_one( cell_fitness = cell_fitness - jnp.inf * jnp.expand_dims(cell_mask, axis=-1) # update grid - new_genotypes = jax.tree_util.tree_map( + new_genotypes = jax.tree.map( lambda x, y: x.at[index].set(y), carry.genotypes, cell_genotype ) new_fitnesses = carry.fitnesses.at[index].set(cell_fitness) @@ -402,7 +389,7 @@ def init( # type: ignore default_fitnesses = -jnp.inf * jnp.ones( shape=(num_centroids, pareto_front_max_length, num_criteria) ) - default_genotypes = jax.tree_util.tree_map( + default_genotypes = jax.tree.map( lambda x: jnp.zeros( shape=( num_centroids, diff --git a/qdax/core/containers/nsga2_repertoire.py b/qdax/core/containers/nsga2_repertoire.py index 331ef153..13f2ded3 100644 --- a/qdax/core/containers/nsga2_repertoire.py +++ b/qdax/core/containers/nsga2_repertoire.py @@ -106,7 +106,7 @@ def add( The updated repertoire. """ # All the candidates - candidates = jax.tree_util.tree_map( + candidates = jax.tree.map( lambda x, y: jnp.concatenate((x, y), axis=0), self.genotypes, batch_of_genotypes, @@ -114,7 +114,7 @@ def add( candidate_fitnesses = jnp.concatenate((self.fitnesses, batch_of_fitnesses)) - first_leaf = jax.tree_util.tree_leaves(candidates)[0] + first_leaf = jax.tree.leaves(candidates)[0] num_candidates = first_leaf.shape[0] def compute_current_front( @@ -237,7 +237,7 @@ def condition_fn_2(val: Tuple[jnp.ndarray, jnp.ndarray]) -> bool: indices = indices - 1 # keep only the survivors - new_candidates = jax.tree_util.tree_map(lambda x: x[indices], candidates) + new_candidates = jax.tree.map(lambda x: x[indices], candidates) new_scores = candidate_fitnesses[indices] new_repertoire = self.replace(genotypes=new_candidates, fitnesses=new_scores) diff --git a/qdax/core/containers/repertoire.py b/qdax/core/containers/repertoire.py index 24c9fbf9..ebeaeb73 100644 --- a/qdax/core/containers/repertoire.py +++ b/qdax/core/containers/repertoire.py @@ -28,13 +28,13 @@ def init(cls) -> Repertoire: # noqa: N805 @abstractmethod def sample( self, - random_key: RNGKey, + key: RNGKey, num_samples: int, ) -> Genotype: """Sample genotypes from the repertoire. Args: - random_key: a random key to handle stochasticity. + key: a random key to handle stochasticity. 
num_samples: the number of genotypes to sample. Returns: diff --git a/qdax/core/containers/spea2_repertoire.py b/qdax/core/containers/spea2_repertoire.py index e93fba85..2236662f 100644 --- a/qdax/core/containers/spea2_repertoire.py +++ b/qdax/core/containers/spea2_repertoire.py @@ -72,7 +72,7 @@ def add( Updated repertoire. """ # All the candidates - candidates = jax.tree_util.tree_map( + candidates = jax.tree.map( lambda x, y: jnp.concatenate((x, y), axis=0), self.genotypes, batch_of_genotypes, @@ -87,7 +87,7 @@ def add( indices = jnp.argsort(strength_scores)[: self.size] # keep the survivors - new_candidates = jax.tree_util.tree_map(lambda x: x[indices], candidates) + new_candidates = jax.tree.map(lambda x: x[indices], candidates) new_fitnesses = candidates_fitnesses[indices] new_repertoire = self.replace(genotypes=new_candidates, fitnesses=new_fitnesses) @@ -121,7 +121,7 @@ def init( # type: ignore ) # create default genotypes - default_genotypes = jax.tree_util.tree_map( + default_genotypes = jax.tree.map( lambda x: jnp.zeros(shape=(population_size,) + x.shape[1:]), genotypes ) diff --git a/qdax/core/containers/uniform_replacement_archive.py b/qdax/core/containers/uniform_replacement_archive.py index 830878cf..1eacb598 100644 --- a/qdax/core/containers/uniform_replacement_archive.py +++ b/qdax/core/containers/uniform_replacement_archive.py @@ -18,7 +18,7 @@ class UniformReplacementArchive(Archive): Most methods are inherited from Archive. """ - random_key: RNGKey + key: RNGKey @classmethod def create( # type: ignore @@ -26,7 +26,7 @@ def create( # type: ignore acceptance_threshold: float, state_descriptor_size: int, max_size: int, - random_key: RNGKey, + key: RNGKey, ) -> Archive: """Create an Archive instance. @@ -39,7 +39,7 @@ def create( # type: ignore state_descriptor_size: the number of elements in a state descriptor. max_size: the maximal size of the archive. In case of overflow, previous elements are replaced by new ones. Defaults to 80000. - random_key: a key to handle random operations. Defaults to key with + key: a key to handle random operations. Defaults to key with seed = 0. 
Returns: @@ -52,7 +52,7 @@ def create( # type: ignore max_size, ) - return archive.replace(random_key=random_key) # type: ignore + return archive.replace(key=key) # type: ignore @jax.jit def _single_insertion(self, state_descriptor: jnp.ndarray) -> Archive: @@ -69,7 +69,7 @@ def _single_insertion(self, state_descriptor: jnp.ndarray) -> Archive: new_current_position = self.current_position + 1 is_full = new_current_position >= self.max_size - random_key, subkey = jax.random.split(self.random_key) + key, subkey = jax.random.split(self.key) random_index = jax.random.randint( subkey, shape=(1,), minval=0, maxval=self.max_size ) @@ -79,5 +79,5 @@ def _single_insertion(self, state_descriptor: jnp.ndarray) -> Archive: new_data = self.data.at[index].set(state_descriptor) return self.replace( # type: ignore - current_position=new_current_position, data=new_data, random_key=random_key + current_position=new_current_position, data=new_data, key=key ) diff --git a/qdax/core/containers/unstructured_repertoire.py b/qdax/core/containers/unstructured_repertoire.py index 32ac5689..2d371d4f 100644 --- a/qdax/core/containers/unstructured_repertoire.py +++ b/qdax/core/containers/unstructured_repertoire.py @@ -85,7 +85,7 @@ def intra_batch_comp( ) # If we do not use a fitness (i.e same fitness everywhere), we create a virtual - # fitness function to add individuals with the same bd + # fitness function to add individuals with the same descriptor additional_score = jnp.where( jnp.nanmax(eval_scores) == jnp.nanmin(eval_scores), 1.0, 0.0 ) @@ -122,8 +122,8 @@ def intra_batch_comp( fitness, ).any() - # Discard Individuals with Nans as their BD (mainly for the readdition where we - # have NaN bds) + # Discard individuals with nan as their descriptor (mainly for the readdition + # where we have nan descriptors) discard_indiv = jnp.logical_or(discard_indiv, not_existent) # Negate to know if we keep the individual @@ -137,7 +137,7 @@ class UnstructuredRepertoire(flax.struct.PyTreeNode): Args: genotypes: a PyTree containing all the genotypes in the repertoire ordered by the centroids. Each leaf has a shape (num_centroids, num_features). The - PyTree can be a simple Jax array or a more complex nested structure such + PyTree can be a simple JAX array or a more complex nested structure such as to represent parameters of neural network in Flax. fitnesses: an array that contains the fitness of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids,). 
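To make the renamed keyword concrete, here is a minimal usage sketch (not part of the patch; the threshold and sizes are placeholder values): containers such as `UniformReplacementArchive` are now created with `key` rather than `random_key`.

```python
import jax

from qdax.core.containers.uniform_replacement_archive import UniformReplacementArchive

# Placeholder values, shown only to illustrate the renamed `key` argument.
archive = UniformReplacementArchive.create(
    acceptance_threshold=0.1,
    state_descriptor_size=2,
    max_size=80000,
    key=jax.random.key(0),
)
```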
@@ -290,29 +290,31 @@ def add( -1, ) - # We get all the indices of the empty bds first and then the filled ones + # We get all the indices of the empty descriptors first and then the filled ones # (because of -1) - sorted_bds = jax.lax.top_k( + sorted_descriptors = jax.lax.top_k( -1 * batch_of_indices.squeeze(), batch_of_indices.shape[0] )[1] batch_of_indices = jnp.where( - jnp.squeeze(batch_of_distances.at[sorted_bds].get() <= self.l_value), - batch_of_indices.at[sorted_bds].get(), + jnp.squeeze( + batch_of_distances.at[sorted_descriptors].get() <= self.l_value + ), + batch_of_indices.at[sorted_descriptors].get(), empty_indexes, ) batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1) # ReIndexing of all the inputs to the correct sorted way - batch_of_descriptors = batch_of_descriptors.at[sorted_bds].get() - batch_of_genotypes = jax.tree_util.tree_map( - lambda x: x.at[sorted_bds].get(), batch_of_genotypes + batch_of_descriptors = batch_of_descriptors.at[sorted_descriptors].get() + batch_of_genotypes = jax.tree.map( + lambda x: x.at[sorted_descriptors].get(), batch_of_genotypes ) - batch_of_fitnesses = batch_of_fitnesses.at[sorted_bds].get() - batch_of_observations = batch_of_observations.at[sorted_bds].get() - not_novel_enough = not_novel_enough.at[sorted_bds].get() + batch_of_fitnesses = batch_of_fitnesses.at[sorted_descriptors].get() + batch_of_observations = batch_of_observations.at[sorted_descriptors].get() + not_novel_enough = not_novel_enough.at[sorted_descriptors].get() - # Check to find Individuals with same BD within the Batch + # Check to find Individuals with same descriptor within the Batch keep_indiv = jax.jit( jax.vmap(intra_batch_comp, in_axes=(0, 0, None, None, None), out_axes=(0)) )( @@ -357,7 +359,7 @@ def add( ) # create new grid - new_grid_genotypes = jax.tree_util.tree_map( + new_grid_genotypes = jax.tree.map( lambda grid_genotypes, new_genotypes: grid_genotypes.at[ batch_of_indices.squeeze() ].set(new_genotypes), @@ -387,28 +389,25 @@ def add( ) @partial(jax.jit, static_argnames=("num_samples",)) - def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]: + def sample(self, key: RNGKey, num_samples: int) -> Genotype: """Sample elements in the repertoire. 
Args: - random_key: a jax PRNG random key + key: a jax PRNG random key num_samples: the number of elements to be sampled Returns: samples: a batch of genotypes sampled in the repertoire - random_key: an updated jax PRNG random key """ - - random_key, sub_key = jax.random.split(random_key) grid_empty = self.fitnesses == -jnp.inf p = (1.0 - grid_empty) / jnp.sum(1.0 - grid_empty) - samples = jax.tree_util.tree_map( - lambda x: jax.random.choice(sub_key, x, shape=(num_samples,), p=p), + samples = jax.tree.map( + lambda x: jax.random.choice(key, x, shape=(num_samples,), p=p), self.genotypes, ) - return samples, random_key + return samples @classmethod def init( @@ -440,7 +439,7 @@ def init( # Initialize grid with default values default_fitnesses = -jnp.inf * jnp.ones(shape=max_size) - default_genotypes = jax.tree_util.tree_map( + default_genotypes = jax.tree.map( lambda x: jnp.full(shape=(max_size,) + x.shape[1:], fill_value=jnp.nan), genotypes, ) diff --git a/qdax/core/distributed_map_elites.py b/qdax/core/distributed_map_elites.py index e8549005..23024d4c 100644 --- a/qdax/core/distributed_map_elites.py +++ b/qdax/core/distributed_map_elites.py @@ -20,8 +20,8 @@ def init( self, genotypes: Genotype, centroids: Centroid, - random_key: RNGKey, - ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]: + key: RNGKey, + ) -> Tuple[MapElitesRepertoire, Optional[EmitterState]]: """ Initialize a Map-Elites repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method @@ -34,23 +34,21 @@ def init( genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) centroids: tessellation centroids of shape (batch_size, num_descriptors) - random_key: a random key used for stochastic operations. + key: a random key used for stochastic operations. Returns: An initialized MAP-Elite repertoire with the initial state of the emitter, and a random key. """ # score initial genotypes - fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + fitnesses, descriptors, extra_scores = self._scoring_function(genotypes, key) # gather across all devices ( gathered_genotypes, gathered_fitnesses, gathered_descriptors, - ) = jax.tree_util.tree_map( + ) = jax.tree.map( lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0), (genotypes, fitnesses, descriptors), ) @@ -64,8 +62,8 @@ def init( ) # get initial state of the emitter - emitter_state, random_key = self._emitter.init( - random_key=random_key, + emitter_state = self._emitter.init( + key=key, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, @@ -83,15 +81,15 @@ def init( extra_scores=extra_scores, ) - return repertoire, emitter_state, random_key + return repertoire, emitter_state @partial(jax.jit, static_argnames=("self",)) def update( self, repertoire: MapElitesRepertoire, emitter_state: Optional[EmitterState], - random_key: RNGKey, - ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]: + key: RNGKey, + ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]: """Performs one iteration of the MAP-Elites algorithm. 1. 
A batch of genotypes is sampled in the repertoire and the genotypes @@ -105,7 +103,7 @@ def update( Args: repertoire: the MAP-Elites repertoire emitter_state: state of the emitter - random_key: a jax PRNG random key + key: a jax PRNG random key Returns: the updated MAP-Elites repertoire @@ -114,20 +112,19 @@ def update( a new jax PRNG key """ # generate offsprings with the emitter - genotypes, extra_info, random_key = self._emitter.emit( - repertoire, emitter_state, random_key - ) + key, subkey = jax.random.split(key) + genotypes, extra_info = self._emitter.emit(repertoire, emitter_state, subkey) + # scores the offsprings - fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, descriptors, extra_scores = self._scoring_function(genotypes, subkey) # gather across all devices ( gathered_genotypes, gathered_fitnesses, gathered_descriptors, - ) = jax.tree_util.tree_map( + ) = jax.tree.map( lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0), (genotypes, fitnesses, descriptors), ) @@ -150,12 +147,12 @@ def update( # update the metrics metrics = self._metrics_function(repertoire) - return repertoire, emitter_state, metrics, random_key + return repertoire, emitter_state, metrics def get_distributed_init_fn( self, centroids: Centroid, devices: List[Any] ) -> Callable[ - [Genotype, RNGKey], Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey] + [Genotype, RNGKey], Tuple[MapElitesRepertoire, Optional[EmitterState]] ]: """Create a function that init MAP-Elites in a distributed way. @@ -177,7 +174,7 @@ def get_distributed_update_fn( self, num_iterations: int, devices: List[Any] ) -> Callable[ [MapElitesRepertoire, Optional[EmitterState], RNGKey], - Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey, Metrics], + Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics], ]: """Create a function that can do a certain number of updates of MAP-Elites in a way that is distributed on several devices. 
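As an illustrative sketch of the new calling convention (names such as `map_elites`, `repertoire` and `emitter_state` are assumed to exist, and each pmapped argument is assumed to already carry a leading device axis): the caller keeps ownership of the key, and the function returned by `get_distributed_update_fn` no longer hands a key back.

```python
import jax

devices = jax.devices()
update_fn = map_elites.get_distributed_update_fn(num_iterations=100, devices=devices)

# One key per device; the scan threads its own key internally, so only the
# repertoire, emitter state and metrics are returned.
keys = jax.random.split(jax.random.key(0), num=len(devices))
repertoire, emitter_state, metrics = update_fn(repertoire, emitter_state, keys)
```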
@@ -194,43 +191,43 @@ def get_distributed_update_fn( @jax.jit def _scan_update( carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], - unused: Any, + _: Any, ) -> Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics]: """Rewrites the update function in a way that makes it compatible with the jax.lax.scan primitive.""" # unwrap the input - repertoire, emitter_state, random_key = carry + repertoire, emitter_state, key = carry # apply one step of update + key, subkey = jax.random.split(key) ( repertoire, emitter_state, metrics, - random_key, ) = self.update( repertoire, emitter_state, - random_key, + subkey, ) - return (repertoire, emitter_state, random_key), metrics + return (repertoire, emitter_state, key), metrics def update_fn( repertoire: MapElitesRepertoire, emitter_state: Optional[EmitterState], - random_key: RNGKey, - ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey, Metrics]: + key: RNGKey, + ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]: """Apply num_iterations of update.""" ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( _scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) - return repertoire, emitter_state, random_key, metrics + return repertoire, emitter_state, metrics return jax.pmap(update_fn, devices=devices, axis_name="p") # type: ignore diff --git a/qdax/core/emitters/cma_emitter.py b/qdax/core/emitters/cma_emitter.py index 9ac4eda1..1183a217 100644 --- a/qdax/core/emitters/cma_emitter.py +++ b/qdax/core/emitters/cma_emitter.py @@ -28,7 +28,7 @@ class CMAEmitterState(EmitterState): Emitter state for the CMA-ME emitter. Args: - random_key: a random key to handle stochastic operations. Used for + key: a random key to handle stochastic operations. Used for state update only, another key is used to emit. This might be subject to refactoring discussions in the future. cmaes_state: state of the underlying CMA-ES algorithm @@ -37,7 +37,7 @@ class CMAEmitterState(EmitterState): emit_count: count the number of emission events. """ - random_key: RNGKey + key: RNGKey cmaes_state: CMAESState previous_fitnesses: Fitness emit_count: int @@ -55,7 +55,7 @@ def __init__( ): """ Class for the emitter of CMA ME from "Covariance Matrix Adaptation for the - Rapid Illumination of Behavior Space" by Fontaine et al. + Rapid Illumination of Descriptor Space" by Fontaine et al. Args: batch_size: number of solutions sampled at each iteration @@ -107,7 +107,7 @@ def batch_size(self) -> int: @partial(jax.jit, static_argnames=("self",)) def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: MapElitesRepertoire, genotypes: Genotype, fitnesses: Fitness, @@ -120,7 +120,7 @@ def init( Args: genotypes: initial genotypes to add to the grid. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: The initial state of the emitter. 
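For clarity, a hedged sketch of the emitter-side convention after this change (the surrounding variables are assumed to be defined; argument names follow the calls shown above): `init` consumes a key and returns only the emitter state, and `emit` returns the offspring and extra info without threading a key back.

```python
# Hedged sketch: `emitter`, `repertoire`, `genotypes`, `fitnesses`,
# `descriptors` and `extra_scores` are assumed to already exist.
key, subkey = jax.random.split(key)
emitter_state = emitter.init(
    key=subkey,
    repertoire=repertoire,
    genotypes=genotypes,
    fitnesses=fitnesses,
    descriptors=descriptors,
    extra_scores=extra_scores,
)

key, subkey = jax.random.split(key)
offsprings, extra_info = emitter.emit(repertoire, emitter_state, subkey)
```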
@@ -131,24 +131,22 @@ def init( default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) # return the initial state - random_key, subkey = jax.random.split(random_key) - return ( - CMAEmitterState( - random_key=subkey, - cmaes_state=self._cma_initial_state, - previous_fitnesses=default_fitnesses, - emit_count=0, - ), - random_key, + key, subkey = jax.random.split(key) + emitter_state = CMAEmitterState( + key=subkey, + cmaes_state=self._cma_initial_state, + previous_fitnesses=default_fitnesses, + emit_count=0, ) + return emitter_state @partial(jax.jit, static_argnames=("self",)) def emit( self, repertoire: Optional[MapElitesRepertoire], emitter_state: CMAEmitterState, - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """ Emits new individuals. Interestingly, this method does not directly modifies individuals from the repertoire but sample from a distribution. Hence the @@ -157,22 +155,17 @@ def emit( Args: repertoire: a repertoire of genotypes (unused). emitter_state: the state of the CMA-MEGA emitter. - random_key: a random key to handle random operations. + key: a random key to handle random operations. Returns: New genotypes and a new random key. """ # emit from CMA-ES - offsprings, random_key = self._cmaes.sample( - cmaes_state=emitter_state.cmaes_state, random_key=random_key - ) + offsprings = self._cmaes.sample(cmaes_state=emitter_state.cmaes_state, key=key) - return offsprings, {}, random_key + return offsprings, {} - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def state_update( self, emitter_state: CMAEmitterState, @@ -181,7 +174,7 @@ def state_update( fitnesses: Fitness, descriptors: Descriptor, extra_scores: Optional[ExtraScores] = None, - ) -> Optional[EmitterState]: + ) -> CMAEmitterState: """ Updates the CMA-ME emitter state. @@ -223,9 +216,7 @@ def state_update( sorted_indices = jnp.flip(jnp.argsort(ranking_criteria)) # sort the candidates - sorted_candidates = jax.tree_util.tree_map( - lambda x: x[sorted_indices], genotypes - ) + sorted_candidates = jax.tree.map(lambda x: x[sorted_indices], genotypes) sorted_improvements = improvements[sorted_indices] # compute reinitialize condition @@ -250,14 +241,14 @@ def update_and_reinit( operand: Tuple[ CMAESState, CMAEmitterState, MapElitesRepertoire, int, RNGKey ], - ) -> Tuple[CMAEmitterState, RNGKey]: + ) -> CMAEmitterState: return self._update_and_init_emitter_state(*operand) def update_wo_reinit( operand: Tuple[ CMAESState, CMAEmitterState, MapElitesRepertoire, int, RNGKey ], - ) -> Tuple[CMAEmitterState, RNGKey]: + ) -> CMAEmitterState: """Update the emitter when no reinit event happened. Here lies a divergence compared to the original implementation. We @@ -284,7 +275,7 @@ def update_wo_reinit( instability. 
""" - (cmaes_state, emitter_state, repertoire, emit_count, random_key) = operand + cmaes_state, emitter_state, _, emit_count, _ = operand # Update CMA Parameters mask = jnp.ones_like(sorted_improvements) @@ -298,10 +289,12 @@ def update_wo_reinit( emit_count=emit_count, ) - return emitter_state, random_key + return emitter_state # type: ignore # Update CMA Parameters - emitter_state, random_key = jax.lax.cond( + key = emitter_state.key + key, subkey = jax.random.split(key) + emitter_state = jax.lax.cond( reinitialize, update_and_reinit, update_wo_reinit, @@ -310,13 +303,14 @@ def update_wo_reinit( emitter_state, repertoire, emit_count, - emitter_state.random_key, + subkey, ), ) # update the emitter state emitter_state = emitter_state.replace( - random_key=random_key, previous_fitnesses=repertoire.fitnesses + previous_fitnesses=repertoire.fitnesses, + key=key, ) return emitter_state @@ -327,8 +321,8 @@ def _update_and_init_emitter_state( emitter_state: CMAEmitterState, repertoire: MapElitesRepertoire, emit_count: int, - random_key: RNGKey, - ) -> Tuple[CMAEmitterState, RNGKey]: + key: RNGKey, + ) -> CMAEmitterState: """Update the emitter state in the case of a reinit event. Reinit the cmaes state and use an individual from the repertoire as the starting mean. @@ -338,17 +332,17 @@ def _update_and_init_emitter_state( emitter_state: current cmame state repertoire: most recent repertoire emit_count: counter of the emitter - random_key: key to handle stochastic events + key: key to handle stochastic events Returns: The updated emitter state. """ # re-sample - random_genotype, random_key = repertoire.sample(random_key, 1) + random_genotype = repertoire.sample(key, 1) # remove the batch dim - new_mean = jax.tree_util.tree_map(lambda x: x.squeeze(0), random_genotype) + new_mean = jax.tree.map(lambda x: x.squeeze(0), random_genotype) cmaes_init_state = self._cma_initial_state.replace(mean=new_mean, num_updates=0) @@ -356,7 +350,7 @@ def _update_and_init_emitter_state( cmaes_state=cmaes_init_state, emit_count=0 ) - return emitter_state, random_key + return emitter_state # type: ignore @abstractmethod def _ranking_criteria( diff --git a/qdax/core/emitters/cma_improvement_emitter.py b/qdax/core/emitters/cma_improvement_emitter.py index fd84bf17..f1d5a9e0 100644 --- a/qdax/core/emitters/cma_improvement_emitter.py +++ b/qdax/core/emitters/cma_improvement_emitter.py @@ -11,7 +11,7 @@ class CMAImprovementEmitter(CMAEmitter): """Class for the emitter of CMA ME from "Covariance Matrix Adaptation - for the Rapid Illumination of Behavior Space" by Fontaine et al. + for the Rapid Illumination of Descriptor Space" by Fontaine et al. This class implements the improvement emitter, where the update of the distribution is biased towards solution that improve the QD score. diff --git a/qdax/core/emitters/cma_mega_emitter.py b/qdax/core/emitters/cma_mega_emitter.py index b4e83b96..dc420b1a 100644 --- a/qdax/core/emitters/cma_mega_emitter.py +++ b/qdax/core/emitters/cma_mega_emitter.py @@ -30,7 +30,7 @@ class CMAMEGAState(EmitterState): Args: theta: current genotype from where candidates will be drawn. theta_grads: normalized fitness and descriptors gradients of theta. - random_key: a random key to handle stochastic operations. Used for + key: a random key to handle stochastic operations. Used for state update only, another key is used to emit. This might be subject to refactoring discussions in the future. 
cmaes_state: state of the underlying CMA-ES algorithm @@ -40,7 +40,7 @@ class CMAMEGAState(EmitterState): theta: Genotype theta_grads: Gradient - random_key: RNGKey + key: RNGKey cmaes_state: CMAESState previous_fitnesses: Fitness @@ -49,7 +49,7 @@ class CMAMEGAEmitter(Emitter): def __init__( self, scoring_function: Callable[ - [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey] + [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores] ], batch_size: int, learning_rate: float, @@ -101,7 +101,7 @@ def __init__( @partial(jax.jit, static_argnames=("self",)) def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: MapElitesRepertoire, genotypes: Genotype, fitnesses: Fitness, @@ -114,20 +114,20 @@ def init( Args: genotypes: initial genotypes to add to the grid. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: The initial state of the emitter. """ # define init theta as 0 - theta = jax.tree_util.tree_map( + theta = jax.tree.map( lambda x: jnp.zeros_like(x[:1, ...]), genotypes, ) # score it - _, _, extra_score, random_key = self._scoring_function(theta, random_key) + _, _, extra_score = self._scoring_function(theta, key) theta_grads = extra_score["normalized_grads"] # Initialize repertoire with default values @@ -135,25 +135,23 @@ def init( default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) # return the initial state - random_key, subkey = jax.random.split(random_key) - return ( - CMAMEGAState( - theta=theta, - theta_grads=theta_grads, - random_key=subkey, - cmaes_state=self._cma_initial_state, - previous_fitnesses=default_fitnesses, - ), - random_key, + key, subkey = jax.random.split(key) + emitter_state = CMAMEGAState( + theta=theta, + theta_grads=theta_grads, + key=subkey, + cmaes_state=self._cma_initial_state, + previous_fitnesses=default_fitnesses, ) + return emitter_state @partial(jax.jit, static_argnames=("self",)) def emit( self, repertoire: Optional[MapElitesRepertoire], emitter_state: CMAMEGAState, - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """ Emits new individuals. Interestingly, this method does not directly modifies individuals from the repertoire but sample from a distribution. Hence the @@ -162,10 +160,10 @@ def emit( Args: repertoire: a repertoire of genotypes (unused). emitter_state: the state of the CMA-MEGA emitter. - random_key: a random key to handle random operations. + key: a random key to handle random operations. Returns: - New genotypes and a new random key. + New genotypes. 
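Throughout the patch, `jax.tree_util.tree_map` becomes `jax.tree.map`, the newer spelling available in recent JAX releases; both apply a function leaf-wise across pytrees with matching structure. A tiny illustration on toy pytrees:

```python
import jax
import jax.numpy as jnp

# jax.tree.map applies the function to every leaf of the given pytrees.
theta = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
step = {"w": jnp.full((3,), 0.1), "b": jnp.array(0.5)}
new_theta = jax.tree.map(lambda x, y: x + y, theta, step)
```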
""" # retrieve elements from the emitter state @@ -176,23 +174,18 @@ def emit( grads = jnp.nan_to_num(emitter_state.theta_grads.squeeze(axis=0)) # Draw random coefficients - use the emitter state key - coeffs, random_key = self._cmaes.sample( - cmaes_state=cmaes_state, random_key=emitter_state.random_key - ) + coeffs = self._cmaes.sample(cmaes_state=cmaes_state, key=key) # make sure the fitness coefficient is positive coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0])) update_grad = coeffs @ grads.T # Compute new candidates - new_thetas = jax.tree_util.tree_map(lambda x, y: x + y, theta, update_grad) + new_thetas = jax.tree.map(lambda x, y: x + y, theta, update_grad) - return new_thetas, {}, random_key + return new_thetas, {} - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def state_update( self, emitter_state: CMAMEGAState, @@ -224,6 +217,7 @@ def state_update( Returns: The updated emitter state. """ + key = emitter_state.key # retrieve elements from the emitter state cmaes_state = emitter_state.cmaes_state @@ -251,9 +245,8 @@ def state_update( sorted_indices = jnp.flip(jnp.argsort(ranking_criteria)) # Draw the coeffs - reuse the emitter state key to get same coeffs - coeffs, random_key = self._cmaes.sample( - cmaes_state=cmaes_state, random_key=emitter_state.random_key - ) + key, subkey = jax.random.split(key) + coeffs = self._cmaes.sample(cmaes_state=cmaes_state, key=subkey) # make sure the fitness coeff is positive coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0])) @@ -264,7 +257,7 @@ def state_update( gradient_step = jnp.sum(self._weights[sorted_indices] * update_grad, axis=0) # update theta - theta = jax.tree_util.tree_map( + theta = jax.tree.map( lambda x, y: x + self._learning_rate * y, theta, gradient_step ) @@ -278,28 +271,30 @@ def state_update( ) # re-sample - random_theta, random_key = repertoire.sample(random_key, 1) + key, subkey = jax.random.split(key) + random_theta = repertoire.sample(subkey, 1) # update theta in case of reinit - theta = jax.tree_util.tree_map( + theta = jax.tree.map( lambda x, y: jnp.where(reinitialize, x, y), random_theta, theta ) # update cmaes state in case of reinit - cmaes_state = jax.tree_util.tree_map( + cmaes_state = jax.tree.map( lambda x, y: jnp.where(reinitialize, x, y), self._cma_initial_state, cmaes_state, ) # score theta - _, _, extra_score, random_key = self._scoring_function(theta, random_key) + key, subkey = jax.random.split(key) + _, _, extra_score = self._scoring_function(theta, subkey) # create new emitter state emitter_state = CMAMEGAState( theta=theta, theta_grads=extra_score["normalized_grads"], - random_key=random_key, + key=key, cmaes_state=cmaes_state, previous_fitnesses=repertoire.fitnesses, ) diff --git a/qdax/core/emitters/cma_pool_emitter.py b/qdax/core/emitters/cma_pool_emitter.py index 55ccaa4f..14d58e14 100644 --- a/qdax/core/emitters/cma_pool_emitter.py +++ b/qdax/core/emitters/cma_pool_emitter.py @@ -45,12 +45,12 @@ def batch_size(self) -> int: Returns: the batch size emitted by the emitter. """ - return self._emitter.batch_size + return self._emitter.batch_size # type: ignore @partial(jax.jit, static_argnames=("self",)) def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: MapElitesRepertoire, genotypes: Genotype, fitnesses: Fitness, @@ -63,29 +63,28 @@ def init( Args: genotypes: initial genotypes to add to the grid. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. 
Returns: The initial state of the emitter. """ - def scan_emitter_init( - carry: RNGKey, unused: Any - ) -> Tuple[RNGKey, CMAEmitterState]: - random_key = carry - emitter_state, random_key = self._emitter.init( - random_key, + def scan_emitter_init(carry: RNGKey, _: Any) -> Tuple[RNGKey, CMAEmitterState]: + key = carry + key, subkey = jax.random.split(key) + emitter_state = self._emitter.init( + subkey, repertoire, genotypes, fitnesses, descriptors, extra_scores, ) - return random_key, emitter_state + return key, emitter_state # init all the emitter states - random_key, emitter_states = jax.lax.scan( - scan_emitter_init, random_key, (), length=self._num_states + key, emitter_states = jax.lax.scan( + scan_emitter_init, key, (), length=self._num_states ) # define the emitter state of the pool @@ -93,47 +92,39 @@ def scan_emitter_init( current_index=0, emitter_states=emitter_states ) - return ( - emitter_state, - random_key, - ) + return emitter_state @partial(jax.jit, static_argnames=("self",)) def emit( self, repertoire: Optional[MapElitesRepertoire], emitter_state: CMAPoolEmitterState, - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """ Emits new individuals. Args: repertoire: a repertoire of genotypes (unused). emitter_state: the state of the CMA-MEGA emitter. - random_key: a random key to handle random operations. + key: a random key to handle random operations. Returns: - New genotypes and a new random key. + New genotypes and extra infos. """ # retrieve the relevant emitter state current_index = emitter_state.current_index - used_emitter_state = jax.tree_util.tree_map( + used_emitter_state = jax.tree.map( lambda x: x[current_index], emitter_state.emitter_states ) # use it to emit offsprings - offsprings, extra_info, random_key = self._emitter.emit( - repertoire, used_emitter_state, random_key - ) + offsprings, extra_info = self._emitter.emit(repertoire, used_emitter_state, key) - return offsprings, extra_info, random_key + return offsprings, extra_info - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def state_update( self, emitter_state: CMAPoolEmitterState, @@ -162,9 +153,7 @@ def state_update( current_index = emitter_state.current_index emitter_states = emitter_state.emitter_states - used_emitter_state = jax.tree_util.tree_map( - lambda x: x[current_index], emitter_states - ) + used_emitter_state = jax.tree.map(lambda x: x[current_index], emitter_states) # update the used emitter state used_emitter_state = self._emitter.state_update( @@ -177,7 +166,7 @@ def state_update( ) # update the emitter state - emitter_states = jax.tree_util.tree_map( + emitter_states = jax.tree.map( lambda x, y: x.at[current_index].set(y), emitter_states, used_emitter_state ) diff --git a/qdax/core/emitters/cma_rnd_emitter.py b/qdax/core/emitters/cma_rnd_emitter.py index c015922c..8ed164fe 100644 --- a/qdax/core/emitters/cma_rnd_emitter.py +++ b/qdax/core/emitters/cma_rnd_emitter.py @@ -18,14 +18,14 @@ class CMARndEmitterState(CMAEmitterState): Args: - random_key: a random key to handle stochastic operations. Used for + key: a random key to handle stochastic operations. Used for state update only, another key is used to emit. This might be subject to refactoring discussions in the future. cmaes_state: state of the underlying CMA-ES algorithm previous_fitnesses: store last fitnesses of the repertoire. Used to compute the improvement. emit_count: count the number of emission events. 
- random_direction: direction of the behavior space we are trying to + random_direction: direction of the descriptor space we are trying to explore. """ @@ -36,7 +36,7 @@ class CMARndEmitter(CMAEmitter): @partial(jax.jit, static_argnames=("self",)) def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: MapElitesRepertoire, genotypes: Genotype, fitnesses: Fitness, @@ -49,7 +49,7 @@ def init( Args: genotypes: initial genotypes to add to the grid. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: The initial state of the emitter. @@ -60,34 +60,33 @@ def init( default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) # take a random direction - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) random_direction = jax.random.uniform( subkey, shape=(self._centroids.shape[-1],), ) # return the initial state - random_key, subkey = jax.random.split(random_key) - - return ( - CMARndEmitterState( - random_key=subkey, - cmaes_state=self._cma_initial_state, - previous_fitnesses=default_fitnesses, - emit_count=0, - random_direction=random_direction, - ), - random_key, + key, subkey = jax.random.split(key) + + emitter_state = CMARndEmitterState( + key=subkey, + cmaes_state=self._cma_initial_state, + previous_fitnesses=default_fitnesses, + emit_count=0, + random_direction=random_direction, ) + return emitter_state + def _update_and_init_emitter_state( self, cmaes_state: CMAESState, emitter_state: CMAEmitterState, repertoire: MapElitesRepertoire, emit_count: int, - random_key: RNGKey, - ) -> Tuple[CMAEmitterState, RNGKey]: + key: RNGKey, + ) -> CMAEmitterState: """Update the emitter state in the case of a reinit event. Reinit the cmaes state and use an individual from the repertoire as the starting mean. @@ -97,25 +96,25 @@ def _update_and_init_emitter_state( emitter_state: current cmame state repertoire: most recent repertoire emit_count: counter of the emitter - random_key: key to handle stochastic events + key: key to handle stochastic events Returns: The updated emitter state. """ # re-sample - random_genotype, random_key = repertoire.sample(random_key, 1) + key, subkey = jax.random.split(key) + random_genotype = repertoire.sample(subkey, 1) # get new mean - remove the batch dim - new_mean = jax.tree_util.tree_map(lambda x: x.squeeze(0), random_genotype) + new_mean = jax.tree.map(lambda x: x.squeeze(0), random_genotype) # define the corresponding cmaes init state cmaes_init_state = self._cma_initial_state.replace(mean=new_mean, num_updates=0) # take a new random direction - random_key, subkey = jax.random.split(random_key) random_direction = jax.random.uniform( - subkey, + key, shape=(self._centroids.shape[-1],), ) @@ -125,7 +124,7 @@ def _update_and_init_emitter_state( random_direction=random_direction, ) - return emitter_state, random_key + return emitter_state # type: ignore def _ranking_criteria( self, diff --git a/qdax/core/emitters/dcrl_emitter.py b/qdax/core/emitters/dcrl_emitter.py index b353a22f..3332d5f7 100644 --- a/qdax/core/emitters/dcrl_emitter.py +++ b/qdax/core/emitters/dcrl_emitter.py @@ -133,7 +133,7 @@ def init( fitnesses: Fitness, descriptors: Descriptor, extra_scores: ExtraScores, - ) -> Tuple[DCRLEmitterState, RNGKey]: + ) -> DCRLEmitterState: """Initializes the emitter state. Args: @@ -141,11 +141,11 @@ def init( key: A random key. Returns: - The initial state of the PGAMEEmitter, a new random key. + The initial state of the PGAMEEmitter. 
""" - observation_size = jax.tree_util.tree_leaves(genotypes)[1].shape[1] - descriptor_size = self._env.behavior_descriptor_length + observation_size = jax.tree.leaves(genotypes)[1].shape[1] + descriptor_size = self._env.descriptor_length action_size = self._env.action_size # Initialise critic, greedy actor and population @@ -157,11 +157,11 @@ def init( critic_params = self._critic_network.init( subkey, obs=fake_obs, actions=fake_action, desc=fake_desc ) - target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params) + target_critic_params = jax.tree.map(lambda x: x, critic_params) key, subkey = jax.random.split(key) actor_params = self._actor_network.init(subkey, obs=fake_obs, desc=fake_desc) - target_actor_params = jax.tree_util.tree_map(lambda x: x, actor_params) + target_actor_params = jax.tree.map(lambda x: x, actor_params) # Prepare init optimizer states critic_opt_state = self._critic_optimizer.init(critic_params) @@ -204,7 +204,7 @@ def init( steps=jnp.array(0), ) - return emitter_state, key + return emitter_state @partial(jax.jit, static_argnames=("self",)) def _similarity(self, descs_1: Descriptor, descs_2: Descriptor) -> jnp.array: @@ -223,22 +223,17 @@ def _similarity(self, descs_1: Descriptor, descs_2: Descriptor) -> jnp.array: def _normalize_desc(self, desc: Descriptor) -> Descriptor: return ( 2 - * (desc - self._env.behavior_descriptor_limits[0]) - / ( - self._env.behavior_descriptor_limits[1] - - self._env.behavior_descriptor_limits[0] - ) + * (desc - self._env.descriptor_limits[0]) + / (self._env.descriptor_limits[1] - self._env.descriptor_limits[0]) - 1 ) @partial(jax.jit, static_argnames=("self",)) def _unnormalize_desc(self, desc_normalized: Descriptor) -> Descriptor: return 0.5 * ( - self._env.behavior_descriptor_limits[1] - - self._env.behavior_descriptor_limits[0] + self._env.descriptor_limits[1] - self._env.descriptor_limits[0] ) * desc_normalized + 0.5 * ( - self._env.behavior_descriptor_limits[1] - + self._env.behavior_descriptor_limits[0] + self._env.descriptor_limits[1] + self._env.descriptor_limits[0] ) @partial(jax.jit, static_argnames=("self",)) @@ -274,16 +269,13 @@ def _compute_equivalent_params_with_desc( actor_dc_params["params"]["Dense_0"]["bias"] = equivalent_bias return actor_dc_params - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit( self, repertoire: Repertoire, emitter_state: DCRLEmitterState, key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + ) -> Tuple[Genotype, ExtraScores]: """Do a step of PG emission. Args: @@ -295,33 +287,28 @@ def emit( A batch of offspring, the new emitter state and a new key. 
""" # PG emitter - parents_pg, descs_pg, key = repertoire.sample_with_descs( - key, self._config.dcrl_batch_size + key, subkey = jax.random.split(key) + parents_pg, descs_pg = repertoire.sample_with_descs( + subkey, self._config.dcrl_batch_size ) genotypes_pg = self.emit_pg(emitter_state, parents_pg, descs_pg) # Actor injection emitter - _, descs_ai, key = repertoire.sample_with_descs(key, self._config.ai_batch_size) - descs_ai = descs_ai.reshape( - descs_ai.shape[0], self._env.behavior_descriptor_length - ) + _, descs_ai = repertoire.sample_with_descs(key, self._config.ai_batch_size) + descs_ai = descs_ai.reshape(descs_ai.shape[0], self._env.descriptor_length) genotypes_ai = self.emit_ai(emitter_state, descs_ai) # Concatenate PG and AI genotypes - genotypes = jax.tree_util.tree_map( + genotypes = jax.tree.map( lambda x1, x2: jnp.concatenate((x1, x2), axis=0), genotypes_pg, genotypes_ai ) return ( genotypes, {"desc_prime": jnp.concatenate([descs_pg, descs_ai], axis=0)}, - key, ) - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit_pg( self, emitter_state: DCRLEmitterState, @@ -348,10 +335,7 @@ def emit_pg( return offsprings - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit_ai(self, emitter_state: DCRLEmitterState, descs: Descriptor) -> Genotype: """Emit the offsprings generated through pg mutation. @@ -386,10 +370,7 @@ def emit_actor(self, emitter_state: DCRLEmitterState) -> Genotype: """ return emitter_state.actor_params - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def state_update( self, emitter_state: DCRLEmitterState, @@ -448,10 +429,10 @@ def state_update( # sample transitions from the replay buffer key, subkey = jax.random.split(emitter_state.key) - transitions, key = replay_buffer.sample( + transitions = replay_buffer.sample( subkey, self._config.num_critic_training_steps * self._config.batch_size ) - transitions = jax.tree_util.tree_map( + transitions = jax.tree.map( lambda x: jnp.reshape( x, ( @@ -503,19 +484,20 @@ def _train_critics( New emitter state where the critic and the greedy actor have been updated. Optimizer states have also been updated in the process. 
""" + key, subkey = jax.random.split(emitter_state.key) + # Update Critic ( critic_opt_state, critic_params, target_critic_params, - key, ) = self._update_critic( critic_params=emitter_state.critic_params, target_critic_params=emitter_state.target_critic_params, target_actor_params=emitter_state.target_actor_params, critic_opt_state=emitter_state.critic_opt_state, transitions=transitions, - key=emitter_state.key, + key=subkey, ) # Update greedy actor @@ -563,7 +545,7 @@ def _update_critic( critic_opt_state: Params, transitions: DCRLTransition, key: RNGKey, - ) -> Tuple[Params, Params, Params, RNGKey]: + ) -> Tuple[Params, Params, Params]: # compute loss and gradients key, subkey = jax.random.split(key) @@ -572,7 +554,7 @@ def _update_critic( target_actor_params, target_critic_params, transitions, - subkey, + key, ) critic_updates, critic_opt_state = self._critic_optimizer.update( critic_gradient, critic_opt_state @@ -582,14 +564,14 @@ def _update_critic( critic_params = optax.apply_updates(critic_params, critic_updates) # Soft update of target critic network - target_critic_params = jax.tree_util.tree_map( + target_critic_params = jax.tree.map( lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + self._config.soft_tau_update * x2, target_critic_params, critic_params, ) - return critic_opt_state, critic_params, target_critic_params, key + return critic_opt_state, critic_params, target_critic_params @partial(jax.jit, static_argnames=("self",)) def _update_actor( @@ -614,7 +596,7 @@ def _update_actor( actor_params = optax.apply_updates(actor_params, policy_updates) # Soft update of target greedy actor - target_actor_params = jax.tree_util.tree_map( + target_actor_params = jax.tree.map( lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + self._config.soft_tau_update * x2, target_actor_params, @@ -627,10 +609,7 @@ def _update_actor( target_actor_params, ) - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def _mutation_function_pg( self, policy_params: Genotype, @@ -650,9 +629,10 @@ def _mutation_function_pg( Returns: The updated params of the neural network. """ + key, subkey = jax.random.split(emitter_state.key) # Get transitions - transitions, key = emitter_state.replay_buffer.sample( - emitter_state.key, + transitions = emitter_state.replay_buffer.sample( + subkey, sample_size=self._config.num_pg_training_steps * self._config.batch_size, ) descs_prime = jnp.tile( @@ -664,7 +644,7 @@ def _mutation_function_pg( * transitions.rewards, desc_prime=descs_prime_normalized, ) - transitions = jax.tree_util.tree_map( + transitions = jax.tree.map( lambda x: jnp.reshape( x, ( diff --git a/qdax/core/emitters/dpg_emitter.py b/qdax/core/emitters/dpg_emitter.py index a30e87af..51e3f6d8 100644 --- a/qdax/core/emitters/dpg_emitter.py +++ b/qdax/core/emitters/dpg_emitter.py @@ -79,26 +79,30 @@ def __init__( def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: Repertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, extra_scores: ExtraScores, - ) -> Tuple[DiversityPGEmitterState, RNGKey]: + ) -> DiversityPGEmitterState: """Initializes the emitter state. Args: + key: A random key. + repertoire: The initial repertoire. genotypes: The initial population. - random_key: A random key. + fitnesses: The initial fitnesses of the population. + descriptors: The initial descriptors of the population. + extra_scores: Extra scores coming from the scoring function. 
Returns: - The initial state of the PGAMEEmitter, a new random key. + The initial state of the PGAMEEmitter. """ # init elements of diversity emitter state with QualityEmitterState.init() - diversity_emitter_state, random_key = super().init( - random_key, + diversity_emitter_state = super().init( + key, repertoire, genotypes, fitnesses, @@ -130,7 +134,7 @@ def init( archive=archive, ) - return emitter_state, random_key + return emitter_state @partial(jax.jit, static_argnames=("self",)) def state_update( @@ -163,6 +167,8 @@ def state_update( New emitter state where the replay buffer has been filled with the new experienced transitions. """ + key = emitter_state.key + # get the transitions out of the dictionary assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" transitions = extra_scores["transitions"] @@ -181,11 +187,9 @@ def scan_train_critics( return new_emitter_state, () # sample transitions - ( - transitions, - random_key, - ) = emitter_state.replay_buffer.sample( - random_key=emitter_state.random_key, + key, subkey = jax.random.split(key) + transitions = emitter_state.replay_buffer.sample( + key=subkey, sample_size=self._config.num_critic_training_steps * self._config.batch_size, ) @@ -196,7 +200,7 @@ def scan_train_critics( transitions = transitions.replace(rewards=diversity_rewards) # reshape the transitions - transitions = jax.tree_util.tree_map( + transitions = jax.tree.map( lambda x: x.reshape( ( self._config.num_critic_training_steps, @@ -215,7 +219,7 @@ def scan_train_critics( length=self._config.num_critic_training_steps, ) - emitter_state = emitter_state.replace(archive=archive) + emitter_state = emitter_state.replace(archive=archive, key=key) return emitter_state # type: ignore @@ -236,20 +240,21 @@ def _train_critics( New emitter state where the critic and the greedy actor have been updated. Optimizer states have also been updated in the process. """ + key = emitter_state.key # Update Critic + key, subkey = jax.random.split(key) ( critic_optimizer_state, critic_params, target_critic_params, - random_key, ) = self._update_critic( critic_params=emitter_state.critic_params, target_critic_params=emitter_state.target_critic_params, target_actor_params=emitter_state.target_actor_params, critic_optimizer_state=emitter_state.critic_optimizer_state, transitions=transitions, - random_key=emitter_state.random_key, + key=subkey, ) # Update greedy policy @@ -282,7 +287,7 @@ def _train_critics( actor_opt_state=policy_optimizer_state, target_critic_params=target_critic_params, target_actor_params=target_actor_params, - random_key=random_key, + key=key, steps=emitter_state.steps + 1, replay_buffer=emitter_state.replay_buffer, ) @@ -296,6 +301,7 @@ def _mutation_function_pg( emitter_state: DiversityPGEmitterState, ) -> Genotype: """Apply pg mutation to a policy via multiple steps of gradient descent. + TODO: random key not properly handled. 
Args: policy_params: a policy, supposed to be a differentiable neural @@ -332,8 +338,8 @@ def scan_train_policy( ), () # sample transitions - transitions, _random_key = emitter_state.replay_buffer.sample( - random_key=emitter_state.random_key, + transitions = emitter_state.replay_buffer.sample( + key=emitter_state.key, sample_size=self._config.num_pg_training_steps * self._config.batch_size, ) @@ -345,7 +351,7 @@ def scan_train_policy( transitions = transitions.replace(rewards=diversity_rewards) # reshape the transitions - transitions = jax.tree_util.tree_map( + transitions = jax.tree.map( lambda x: x.reshape( ( self._config.num_pg_training_steps, diff --git a/qdax/core/emitters/emitter.py b/qdax/core/emitters/emitter.py index d2a477a8..d3733ddb 100644 --- a/qdax/core/emitters/emitter.py +++ b/qdax/core/emitters/emitter.py @@ -31,33 +31,33 @@ class EmitterState(PyTreeNode): class Emitter(ABC): def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: Repertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, extra_scores: ExtraScores, - ) -> Tuple[Optional[EmitterState], RNGKey]: + ) -> Optional[EmitterState]: """Initialises the state of the emitter. Some emitters do not need a state, in which case, the value None can be outputted. Args: genotypes: The genotypes of the initial population. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: The initial emitter state and a random key. """ - return None, random_key + return None @abstractmethod def emit( self, repertoire: Optional[Repertoire], emitter_state: Optional[EmitterState], - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """Function used to emit a population of offspring by any possible mean. New population can be sampled from a distribution or obtained through mutations of individuals sampled from the repertoire. @@ -66,17 +66,14 @@ def emit( Args: repertoire: a repertoire of genotypes. emitter_state: the state of the emitter. - random_key: a random key to handle random operations. + key: a random key to handle random operations. Returns: A batch of offspring, a new random key. 
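Under the refactored `Emitter` interface, `init` returns just the (optional) emitter state and `emit` returns a `(genotypes, extra_scores)` pair; the caller owns the key. A hypothetical sketch of that contract, where the toy repertoire and batch sizes are made up and not QDax code:

```python
import jax
import jax.numpy as jnp
from typing import Any, Dict, Tuple

# Illustration only: a minimal repertoire whose sample() follows the
# refactored convention of returning genotypes without a key.
class ToyRepertoire:
    def __init__(self, genotypes: jnp.ndarray):
        self.genotypes = genotypes

    def sample(self, key: jax.Array, num_samples: int) -> jnp.ndarray:
        idx = jax.random.randint(key, (num_samples,), 0, self.genotypes.shape[0])
        return self.genotypes[idx]

def emit(
    repertoire: ToyRepertoire, emitter_state: Any, key: jax.Array
) -> Tuple[jnp.ndarray, Dict[str, Any]]:
    # take a key, return genotypes plus extra scores, never return the key
    return repertoire.sample(key, 8), {}

key = jax.random.key(0)
key, subkey = jax.random.split(key)  # the caller splits and keeps its key
offspring, extra_scores = emit(ToyRepertoire(jnp.zeros((16, 4))), None, subkey)
```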
""" pass - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def state_update( self, emitter_state: Optional[EmitterState], diff --git a/qdax/core/emitters/mees_emitter.py b/qdax/core/emitters/mees_emitter.py index 821b24da..eb9d3db9 100644 --- a/qdax/core/emitters/mees_emitter.py +++ b/qdax/core/emitters/mees_emitter.py @@ -160,7 +160,7 @@ class MEESEmitterState(EmitterState): last_updated_genotypes: used to choose parents from repertoire last_updated_fitnesses: used to choose parents from repertoire last_updated_position: used to choose parents from repertoire - random_key: key to handle stochastic operations + key: key to handle stochastic operations """ initial_optimizer_state: optax.OptState @@ -171,7 +171,7 @@ class MEESEmitterState(EmitterState): last_updated_genotypes: Genotype last_updated_fitnesses: Fitness last_updated_position: jnp.ndarray - random_key: RNGKey + key: RNGKey class MEESEmitter(Emitter): @@ -198,7 +198,7 @@ def __init__( config: MEESConfig, total_generations: int, scoring_fn: Callable[ - [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey] + [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores] ], num_descriptors: int, ) -> None: @@ -232,13 +232,10 @@ def batch_size(self) -> int: """ return 1 - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: MapElitesRepertoire, genotypes: Genotype, fitnesses: Fitness, @@ -249,14 +246,14 @@ def init( Args: genotypes: The initial population. - random_key: A random key. + key: A random key. Returns: The initial state of the MEESEmitter, a new random key. """ # Initialisation requires one initial genotype - if jax.tree_util.tree_leaves(genotypes)[0].shape[0] > 1: - genotypes = jax.tree_util.tree_map( + if jax.tree.leaves(genotypes)[0].shape[0] > 1: + genotypes = jax.tree.map( lambda x: x[0], genotypes, ) @@ -275,7 +272,7 @@ def init( ) # Create empty updated genotypes and fitness - last_updated_genotypes = jax.tree_util.tree_map( + last_updated_genotypes = jax.tree.map( lambda x: jnp.zeros(shape=(self._config.last_updated_size,) + x.shape[1:]), genotypes, ) @@ -283,55 +280,46 @@ def init( shape=self._config.last_updated_size ) - return ( - MEESEmitterState( - initial_optimizer_state=initial_optimizer_state, - optimizer_state=initial_optimizer_state, - offspring=genotypes, - generation_count=0, - novelty_archive=novelty_archive, - last_updated_genotypes=last_updated_genotypes, - last_updated_fitnesses=last_updated_fitnesses, - last_updated_position=0, - random_key=random_key, - ), - random_key, + emitter_state = MEESEmitterState( + initial_optimizer_state=initial_optimizer_state, + optimizer_state=initial_optimizer_state, + offspring=genotypes, + generation_count=0, + novelty_archive=novelty_archive, + last_updated_genotypes=last_updated_genotypes, + last_updated_fitnesses=last_updated_fitnesses, + last_updated_position=0, + key=key, ) + return emitter_state - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit( self, repertoire: MapElitesRepertoire, emitter_state: MEESEmitterState, - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """Return the offspring generated through gradient update. 
Params: repertoire: the MAP-Elites repertoire to sample from emitter_state - random_key: a jax PRNG random key + key: a jax PRNG random key Returns: a new gradient offspring - a new jax PRNG key """ - return emitter_state.offspring, {}, random_key + return emitter_state.offspring, {} - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def _sample_exploit( self, emitter_state: MEESEmitterState, repertoire: MapElitesRepertoire, - random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + key: RNGKey, + ) -> Genotype: """Sample half of the time uniformly from the exploit_num_cell_sample highest-performing cells of the repertoire and half of the time uniformly from the exploit_num_cell_sample highest-performing cells among the @@ -340,18 +328,17 @@ def _sample_exploit( Args: emitter_state: current emitter_state repertoire: the current repertoire - random_key: a jax PRNG random key + key: a jax PRNG random key Returns: samples: a genotype sampled in the repertoire - random_key: an updated jax PRNG random key """ def _sample( - random_key: RNGKey, + key: RNGKey, genotypes: Genotype, fitnesses: Fitness, - ) -> Tuple[Genotype, RNGKey]: + ) -> Genotype: """Sample uniformly from the 2 highest fitness cells.""" max_fitnesses, _ = jax.lax.top_k( @@ -362,14 +349,13 @@ def _sample( ) genotypes_empty = fitnesses < min_fitness p = (1.0 - genotypes_empty) / jnp.sum(1.0 - genotypes_empty) - random_key, subkey = jax.random.split(random_key) - samples = jax.tree_util.tree_map( - lambda x: jax.random.choice(subkey, x, shape=(1,), p=p), + samples = jax.tree.map( + lambda x: jax.random.choice(key, x, shape=(1,), p=p), genotypes, ) - return samples, random_key + return samples - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) # Sample p uniformly p = jax.random.uniform(subkey) @@ -383,35 +369,31 @@ def _sample( genotypes=emitter_state.last_updated_genotypes, fitnesses=emitter_state.last_updated_fitnesses, ) - samples, random_key = jax.lax.cond( + samples = jax.lax.cond( p < 0.5, repertoire_sample, last_updated_sample, - random_key, + key, ) - return samples, random_key + return samples - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def _sample_explore( self, emitter_state: MEESEmitterState, repertoire: MapElitesRepertoire, - random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + key: RNGKey, + ) -> Genotype: """Sample uniformly from the explore_num_cell_sample most-novel genotypes. 
Args: emitter_state: current emitter state repertoire: the current genotypes repertoire - random_key: a jax PRNG random key + key: a jax PRNG random key Returns: samples: a genotype sampled in the repertoire - random_key: an updated jax PRNG random key """ # Compute the novelty of all indivs in the archive @@ -429,25 +411,21 @@ def _sample_explore( ) repertoire_empty = novelties < min_novelty p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty) - random_key, subkey = jax.random.split(random_key) - samples = jax.tree_util.tree_map( - lambda x: jax.random.choice(subkey, x, shape=(1,), p=p), + samples = jax.tree.map( + lambda x: jax.random.choice(key, x, shape=(1,), p=p), repertoire.genotypes, ) - return samples, random_key + return samples - @partial( - jax.jit, - static_argnames=("self", "scores_fn"), - ) + @partial(jax.jit, static_argnames=("self", "scores_fn")) def _es_emitter( self, parent: Genotype, optimizer_state: optax.OptState, - random_key: RNGKey, + key: RNGKey, scores_fn: Callable[[Fitness, Descriptor], jnp.ndarray], - ) -> Tuple[Genotype, optax.OptState, RNGKey]: + ) -> Tuple[Genotype, optax.OptState]: """Main es component, given a parent and a way to infer the score from the fitnesses and descriptors of its es-samples, return its approximated-gradient-generated offspring. @@ -456,27 +434,27 @@ def _es_emitter( parent: the considered parent. scores_fn: a function to infer the score of its es-samples from their fitness and descriptors. - random_key + key Returns: - The approximated-gradients-generated offspring and a new random_key. + The approximated-gradients-generated offspring and a new key. """ - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) # Sampling mirror noise total_sample_number = self._config.sample_number if self._config.sample_mirror: sample_number = total_sample_number // 2 - half_sample_noise = jax.tree_util.tree_map( + half_sample_noise = jax.tree.map( lambda x: jax.random.normal( key=subkey, shape=jnp.repeat(x, sample_number, axis=0).shape, ), parent, ) - sample_noise = jax.tree_util.tree_map( + sample_noise = jax.tree.map( lambda x: jnp.concatenate( [jnp.expand_dims(x, axis=1), jnp.expand_dims(-x, axis=1)], axis=1 ).reshape(jnp.repeat(x, 2, axis=0).shape), @@ -487,7 +465,7 @@ def _es_emitter( # Sampling non-mirror noise else: sample_number = total_sample_number - sample_noise = jax.tree_util.tree_map( + sample_noise = jax.tree.map( lambda x: jax.random.normal( key=subkey, shape=jnp.repeat(x, sample_number, axis=0).shape, @@ -497,20 +475,18 @@ def _es_emitter( gradient_noise = sample_noise # Applying noise - samples = jax.tree_util.tree_map( + samples = jax.tree.map( lambda x: jnp.repeat(x, total_sample_number, axis=0), parent, ) - samples = jax.tree_util.tree_map( + samples = jax.tree.map( lambda mean, noise: mean + self._config.sample_sigma * noise, samples, sample_noise, ) # Evaluating samples - fitnesses, descriptors, extra_scores, random_key = self._scoring_fn( - samples, random_key - ) + fitnesses, descriptors, extra_scores = self._scoring_fn(samples, key) # Computing rank, with or without normalisation scores = scores_fn(fitnesses, descriptors) @@ -527,7 +503,7 @@ def _es_emitter( if self._config.sample_mirror: ranks = jnp.reshape(ranks, (sample_number, 2)) ranks = jnp.apply_along_axis(lambda rank: rank[0] - rank[1], 1, ranks) - ranks = jax.tree_util.tree_map( + ranks = jax.tree.map( lambda x: jnp.reshape( jnp.repeat(ranks.ravel(), x[0].ravel().shape[0], axis=0), x.shape ), @@ -535,16 +511,16 @@ def 
_es_emitter( ) # Computing the gradients - gradient = jax.tree_util.tree_map( + gradient = jax.tree.map( lambda noise, rank: jnp.multiply(noise, rank), gradient_noise, ranks, ) - gradient = jax.tree_util.tree_map( + gradient = jax.tree.map( lambda x: jnp.reshape(x, (sample_number, -1)), gradient, ) - gradient = jax.tree_util.tree_map( + gradient = jax.tree.map( lambda g, p: jnp.reshape( -jnp.sum(g, axis=0) / (total_sample_number * self._config.sample_sigma), p.shape, @@ -554,7 +530,7 @@ def _es_emitter( ) # Adding regularisation - gradient = jax.tree_util.tree_map( + gradient = jax.tree.map( lambda g, p: g + self._config.l2_coefficient * p, gradient, parent, @@ -566,12 +542,9 @@ def _es_emitter( ) offspring = optax.apply_updates(parent, offspring_update) - return offspring, optimizer_state, random_key + return offspring, optimizer_state - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def _buffers_update( self, emitter_state: MEESEmitterState, @@ -601,8 +574,8 @@ def _buffers_update( indice = get_cells_indices(descriptors, repertoire.centroids) added_genotype = jnp.all( jnp.asarray( - jax.tree_util.tree_leaves( - jax.tree_util.tree_map( + jax.tree.leaves( + jax.tree.map( lambda new_gen, rep_gen: jnp.all( jnp.equal( jnp.ravel(new_gen), jnp.ravel(rep_gen.at[indice].get()) @@ -627,7 +600,7 @@ def _buffers_update( last_updated_fitnesses = last_updated_fitnesses.at[last_updated_position].set( fitnesses[0] ) - last_updated_genotypes = jax.tree_util.tree_map( + last_updated_genotypes = jax.tree.map( lambda last_gen, gen: last_gen.at[ jnp.expand_dims(last_updated_position, axis=0) ].set(gen), @@ -646,10 +619,7 @@ def _buffers_update( last_updated_position=last_updated_position, ) - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def state_update( self, emitter_state: MEESEmitterState, @@ -675,10 +645,10 @@ def state_update( The modified emitter state. 
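`_es_emitter` above relies on mirrored sampling: each noise vector is paired with its negation before the perturbed samples are scored. A stripped-down sketch of that sampling step on a flat genotype, with sizes and sigma chosen only for illustration:

```python
import jax
import jax.numpy as jnp

# Mirrored (antithetic) sampling around a parent genotype.
key = jax.random.key(0)
parent = jnp.zeros((5,))
sample_number, sample_sigma = 4, 0.02

half_noise = jax.random.normal(key, shape=(sample_number, parent.shape[0]))
# interleave each noise vector with its negation: [n0, -n0, n1, -n1, ...]
noise = jnp.stack([half_noise, -half_noise], axis=1).reshape(-1, parent.shape[0])
samples = parent + sample_sigma * noise  # 2 * sample_number perturbed copies
```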
""" - assert jax.tree_util.tree_leaves(genotypes)[0].shape[0] == 1, ( + assert jax.tree.leaves(genotypes)[0].shape[0] == 1, ( "ERROR: MAP-Elites-ES generates 1 offspring per generation, " + "batch_size should be 1, the inputted batch has size:" - + str(jax.tree_util.tree_leaves(genotypes)[0].shape[0]) + + str(jax.tree.leaves(genotypes)[0].shape[0]) ) # Update all the buffers and archives of the emitter_state @@ -698,23 +668,21 @@ def state_update( ) # Select parent and optimizer_state - parent, random_key = jax.lax.cond( + key, subkey = jax.random.split(emitter_state.key) + parent = jax.lax.cond( sample_new_parent, - lambda emitter_state, repertoire, random_key: jax.lax.cond( + lambda emitter_state, repertoire, key: jax.lax.cond( use_exploration, self._sample_explore, self._sample_exploit, emitter_state, repertoire, - random_key, - ), - lambda emitter_state, repertoire, random_key: ( - emitter_state.offspring, - random_key, + key, ), + lambda emitter_state, repertoire, key: emitter_state.offspring, emitter_state, repertoire, - emitter_state.random_key, + subkey, ) optimizer_state = jax.lax.cond( sample_new_parent, @@ -739,10 +707,11 @@ def exploration_exploitation_scores( return scores # Run es process - offspring, optimizer_state, random_key = self._es_emitter( + key, subkey = jax.random.split(key) + offspring, optimizer_state = self._es_emitter( parent=parent, optimizer_state=optimizer_state, - random_key=random_key, + key=subkey, scores_fn=exploration_exploitation_scores, ) @@ -750,5 +719,5 @@ def exploration_exploitation_scores( optimizer_state=optimizer_state, offspring=offspring, generation_count=generation_count + 1, - random_key=random_key, + key=key, ) diff --git a/qdax/core/emitters/multi_emitter.py b/qdax/core/emitters/multi_emitter.py index 17cb8ace..a0d309ca 100644 --- a/qdax/core/emitters/multi_emitter.py +++ b/qdax/core/emitters/multi_emitter.py @@ -57,33 +57,32 @@ def get_indexes_separation_batches( def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: Repertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, extra_scores: ExtraScores, - ) -> Tuple[Optional[EmitterState], RNGKey]: + ) -> Optional[EmitterState]: """ Initialize the state of the emitter. Args: genotypes: The genotypes of the initial population. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: - The initial emitter state and a random key. + The initial emitter state. """ # prepare keys for each emitter - random_key, subkey = jax.random.split(random_key) - subkeys = jax.random.split(subkey, len(self.emitters)) + keys = jax.random.split(key, len(self.emitters)) # init all emitter states - gather them emitter_states = [] - for emitter, subkey_emitter in zip(self.emitters, subkeys): - emitter_state, _ = emitter.init( - subkey_emitter, + for emitter, key_emitter in zip(self.emitters, keys): + emitter_state = emitter.init( + key_emitter, repertoire, genotypes, fitnesses, @@ -92,54 +91,53 @@ def init( ) emitter_states.append(emitter_state) - return MultiEmitterState(tuple(emitter_states)), random_key + return MultiEmitterState(tuple(emitter_states)) @partial(jax.jit, static_argnames=("self",)) def emit( self, repertoire: Optional[Repertoire], emitter_state: Optional[MultiEmitterState], - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """Emit new population. Use all the sub emitters to emit subpopulation and gather them. 
Args: repertoire: a repertoire of genotypes. emitter_state: the current state of the emitter. - random_key: key for random operations. + key: key for random operations. Returns: - Offsprings and a new random key. + Offsprings. """ assert emitter_state is not None assert len(emitter_state.emitter_states) == len(self.emitters) # prepare subkeys for each sub emitter - random_key, subkey = jax.random.split(random_key) - subkeys = jax.random.split(subkey, len(self.emitters)) + keys = jax.random.split(key, len(self.emitters)) # emit from all emitters and gather offsprings all_offsprings = [] all_extra_info: ExtraScores = {} - for emitter, sub_emitter_state, subkey_emitter in zip( + for emitter, sub_emitter_state, key_emitter in zip( self.emitters, emitter_state.emitter_states, - subkeys, + keys, ): - genotype, extra_info, _ = emitter.emit( - repertoire, sub_emitter_state, subkey_emitter + genotype, extra_info = emitter.emit( + repertoire, sub_emitter_state, key_emitter ) - batch_size = jax.tree_util.tree_leaves(genotype)[0].shape[0] + batch_size = jax.tree.leaves(genotype)[0].shape[0] assert batch_size == emitter.batch_size all_offsprings.append(genotype) all_extra_info = {**all_extra_info, **extra_info} # concatenate offsprings together - offsprings = jax.tree_util.tree_map( + offsprings = jax.tree.map( lambda *x: jnp.concatenate(x, axis=0), *all_offsprings ) - return offsprings, all_extra_info, random_key + return offsprings, all_extra_info @partial(jax.jit, static_argnames=("self",)) def state_update( @@ -171,7 +169,7 @@ def state_update( emitter_states = [] def _get_sub_pytree(pytree: ArrayTree, start: int, end: int) -> ArrayTree: - return jax.tree_util.tree_map(lambda x: x[start:end], pytree) + return jax.tree.map(lambda x: x[start:end], pytree) for emitter, sub_emitter_state, index_start, index_end in zip( self.emitters, @@ -193,7 +191,7 @@ def _get_sub_pytree(pytree: ArrayTree, start: int, end: int) -> ArrayTree: # update only with the data of the emitted genotypes else: # extract relevant data - sub_gen, sub_fit, sub_desc, sub_extra_scores = jax.tree_util.tree_map( + sub_gen, sub_fit, sub_desc, sub_extra_scores = jax.tree.map( partial(_get_sub_pytree, start=index_start, end=index_end), ( genotypes, diff --git a/qdax/core/emitters/mutation_operators.py b/qdax/core/emitters/mutation_operators.py index bda2daca..0a4ab8a3 100644 --- a/qdax/core/emitters/mutation_operators.py +++ b/qdax/core/emitters/mutation_operators.py @@ -1,7 +1,7 @@ """File defining mutation and crossover functions.""" from functools import partial -from typing import Optional, Tuple +from typing import Optional import jax import jax.numpy as jnp @@ -11,7 +11,7 @@ def _polynomial_mutation( x: jnp.ndarray, - random_key: RNGKey, + key: RNGKey, proportion_to_mutate: float, eta: float, minval: float, @@ -24,7 +24,7 @@ def _polynomial_mutation( Args: x: parameters. - random_key: a random key + key: a random key proportion_to_mutate: the proportion of the given parameters that need to be mutated. eta: the inverse of the power of the mutation applied. 
@@ -39,7 +39,7 @@ def _polynomial_mutation( num_positions = x.shape[0] positions = jnp.arange(start=0, stop=num_positions) num_positions_to_mutate = int(proportion_to_mutate * num_positions) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) selected_positions = jax.random.choice( key=subkey, a=positions, shape=(num_positions_to_mutate,), replace=False ) @@ -51,7 +51,7 @@ def _polynomial_mutation( mutpow = 1.0 / (1.0 + eta) # Randomly select where to put delta_1 and delta_2 - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) rand = jax.random.uniform( key=subkey, shape=delta_1.shape, @@ -80,18 +80,18 @@ def _polynomial_mutation( def polynomial_mutation( x: Genotype, - random_key: RNGKey, + key: RNGKey, proportion_to_mutate: float, eta: float, minval: float, maxval: float, -) -> Tuple[Genotype, RNGKey]: +) -> Genotype: """ Polynomial mutation over several genotypes Parameters: x: array of genotypes to transform (real values only) - random_key: RNG key for reproducibility. + key: RNG key for reproducibility. Assumed to be of shape (batch_size, genotype_dim) proportion_to_mutate (float): proportion of variables to mutate in each genotype (must be in [0, 1]). @@ -101,11 +101,10 @@ def polynomial_mutation( maxval: maximum value to clip the genotypes. Returns: - New genotypes - same shape as input and a new RNG key + New genotypes - same shape as input """ - random_key, subkey = jax.random.split(random_key) - batch_size = jax.tree_util.tree_leaves(x)[0].shape[0] - mutation_key = jax.random.split(subkey, num=batch_size) + batch_size = jax.tree.leaves(x)[0].shape[0] + mutation_keys = jax.random.split(key, num=batch_size) mutation_fn = partial( _polynomial_mutation, proportion_to_mutate=proportion_to_mutate, @@ -114,14 +113,14 @@ def polynomial_mutation( maxval=maxval, ) mutation_fn = jax.vmap(mutation_fn) - x = jax.tree_util.tree_map(lambda x_: mutation_fn(x_, mutation_key), x) - return x, random_key + x = jax.tree.map(lambda x_: mutation_fn(x_, mutation_keys), x) + return x def _polynomial_crossover( x1: jnp.ndarray, x2: jnp.ndarray, - random_key: RNGKey, + key: RNGKey, proportion_var_to_change: float, ) -> jnp.ndarray: """ @@ -132,9 +131,7 @@ def _polynomial_crossover( """ num_var_to_change = int(proportion_var_to_change * x1.shape[0]) indices = jnp.arange(start=0, stop=x1.shape[0]) - selected_indices = jax.random.choice( - random_key, indices, shape=(num_var_to_change,) - ) + selected_indices = jax.random.choice(key, indices, shape=(num_var_to_change,)) x = x1.at[selected_indices].set(x2[selected_indices]) return x @@ -142,9 +139,9 @@ def _polynomial_crossover( def polynomial_crossover( x1: Genotype, x2: Genotype, - random_key: RNGKey, + key: RNGKey, proportion_var_to_change: float, -) -> Tuple[Genotype, RNGKey]: +) -> Genotype: """ Crossover over a set of pairs of genotypes. 
@@ -156,45 +153,41 @@ def polynomial_crossover( Parameters: x1: first batch of genotypes x2: second batch of genotypes - random_key: RNG key for reproducibility + key: RNG key for reproducibility proportion_var_to_change: proportion of variables to exchange between genotypes (must be [0, 1]) Returns: - New genotypes and a new RNG key + New genotypes """ - - random_key, subkey = jax.random.split(random_key) - batch_size = jax.tree_util.tree_leaves(x2)[0].shape[0] - crossover_keys = jax.random.split(subkey, num=batch_size) + batch_size = jax.tree.leaves(x2)[0].shape[0] + crossover_keys = jax.random.split(key, num=batch_size) crossover_fn = partial( _polynomial_crossover, proportion_var_to_change=proportion_var_to_change, ) crossover_fn = jax.vmap(crossover_fn) # TODO: check that key usage is correct - x = jax.tree_util.tree_map( - lambda x1_, x2_: crossover_fn(x1_, x2_, crossover_keys), x1, x2 - ) - return x, random_key + x = jax.tree.map(lambda x1_, x2_: crossover_fn(x1_, x2_, crossover_keys), x1, x2) + return x def isoline_variation( x1: Genotype, x2: Genotype, - random_key: RNGKey, + key: RNGKey, iso_sigma: float, line_sigma: float, minval: Optional[float] = None, maxval: Optional[float] = None, -) -> Tuple[Genotype, RNGKey]: +) -> Genotype: """ Iso+Line-DD Variation Operator [1] over a set of pairs of genotypes Parameters: x1 (Genotypes): first batch of genotypes x2 (Genotypes): second batch of genotypes - random_key (RNGKey): RNG key for reproducibility + key (RNGKey): RNG key for reproducibility iso_sigma (float): spread parameter (noise) line_sigma (float): line parameter (direction of the new genotype) minval (float, Optional): minimum value to clip the genotypes @@ -202,7 +195,6 @@ def isoline_variation( Returns: x (Genotypes): new genotypes - random_key (RNGKey): new RNG key [1] Vassiliades, Vassilis, and Jean-Baptiste Mouret. "Discovering the elite hypervolume by leveraging interspecies correlation." 
Proceedings of the Genetic and @@ -210,14 +202,12 @@ def isoline_variation( """ # Computing line_noise - random_key, key_line_noise = jax.random.split(random_key) - batch_size = jax.tree_util.tree_leaves(x1)[0].shape[0] + key, key_line_noise = jax.random.split(key) + batch_size = jax.tree.leaves(x1)[0].shape[0] line_noise = jax.random.normal(key_line_noise, shape=(batch_size,)) * line_sigma - def _variation_fn( - x1: jnp.ndarray, x2: jnp.ndarray, random_key: RNGKey - ) -> jnp.ndarray: - iso_noise = jax.random.normal(random_key, shape=x1.shape) * iso_sigma + def _variation_fn(x1: jnp.ndarray, x2: jnp.ndarray, key: RNGKey) -> jnp.ndarray: + iso_noise = jax.random.normal(key, shape=x1.shape) * iso_sigma x = (x1 + iso_noise) + jax.vmap(jnp.multiply)((x2 - x1), line_noise) # Back in bounds if necessary (floating point issues) @@ -226,14 +216,11 @@ def _variation_fn( return x # create a tree with random keys - nb_leaves = len(jax.tree_util.tree_leaves(x1)) - random_key, subkey = jax.random.split(random_key) - subkeys = jax.random.split(subkey, num=nb_leaves) - keys_tree = jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(x1), subkeys) + nb_leaves = len(jax.tree.leaves(x1)) + keys = jax.random.split(key, num=nb_leaves) + keys_tree = jax.tree.unflatten(jax.tree.structure(x1), keys) # apply isolinedd to each branch of the tree - x = jax.tree_util.tree_map( - lambda y1, y2, key: _variation_fn(y1, y2, key), x1, x2, keys_tree - ) + x = jax.tree.map(lambda y1, y2, key: _variation_fn(y1, y2, key), x1, x2, keys_tree) - return x, random_key + return x diff --git a/qdax/core/emitters/omg_mega_emitter.py b/qdax/core/emitters/omg_mega_emitter.py index c7aac8c9..372db531 100644 --- a/qdax/core/emitters/omg_mega_emitter.py +++ b/qdax/core/emitters/omg_mega_emitter.py @@ -92,28 +92,28 @@ def __init__( def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: MapElitesRepertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, extra_scores: ExtraScores, - ) -> Tuple[OMGMEGAEmitterState, RNGKey]: + ) -> OMGMEGAEmitterState: """Initialises the state of the emitter. Creates an empty repertoire that will later contain the gradients of the individuals. Args: genotypes: The genotypes of the initial population. - random_key: a random key to handle stochastic operations. + key: a random key to handle stochastic operations. Returns: The initial emitter state. """ # retrieve one genotype from the population - first_genotype = jax.tree_util.tree_map(lambda x: x[0], genotypes) + first_genotype = jax.tree.map(lambda x: x[0], genotypes) # add a dimension of size num descriptors + 1 - gradient_genotype = jax.tree_util.tree_map( + gradient_genotype = jax.tree.map( lambda x: jnp.repeat( jnp.expand_dims(x, axis=-1), repeats=self._num_descriptors + 1, axis=-1 ), @@ -137,21 +137,15 @@ def init( extra_scores, ) - return ( - OMGMEGAEmitterState(gradients_repertoire=gradients_repertoire), - random_key, - ) + return OMGMEGAEmitterState(gradients_repertoire=gradients_repertoire) - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit( self, repertoire: MapElitesRepertoire, emitter_state: OMGMEGAEmitterState, - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """ OMG emitter function that samples elements in the repertoire and does a gradient update with random coefficients to create new candidates. 
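The variation operators rewritten above now take the key as an argument and return only the new genotypes. A usage sketch for `isoline_variation` under that convention, where the batch shape and sigma values are placeholders:

```python
import jax
import jax.numpy as jnp

from qdax.core.emitters.mutation_operators import isoline_variation

key = jax.random.key(0)
x1 = jnp.zeros((8, 4))  # batch of 8 parent genotypes, 4 parameters each
x2 = jnp.ones((8, 4))

key, subkey = jax.random.split(key)  # the caller splits; no key is returned
offspring = isoline_variation(x1, x2, subkey, iso_sigma=0.005, line_sigma=0.05)
```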
@@ -159,28 +153,26 @@ def emit( Args: repertoire: current repertoire emitter_state: current emitter state, contains the gradients - random_key: random key + key: random key Returns: new_genotypes: new candidates to be added to the grid - random_key: updated random key """ # sample genotypes - ( - genotypes, - _, - ) = repertoire.sample(random_key, num_samples=self._batch_size) + key, subkey = jax.random.split(key) + genotypes = repertoire.sample(subkey, num_samples=self._batch_size) # sample gradients - use the same random key for sampling # See class docstrings for discussion about this choice - gradients, random_key = emitter_state.gradients_repertoire.sample( - random_key, num_samples=self._batch_size + key, subkey = jax.random.split(key) + gradients = emitter_state.gradients_repertoire.sample( + subkey, num_samples=self._batch_size ) - fitness_gradients = jax.tree_util.tree_map( + fitness_gradients = jax.tree.map( lambda x: jnp.expand_dims(x[:, :, 0], axis=-1), gradients ) - descriptors_gradients = jax.tree_util.tree_map(lambda x: x[:, :, 1:], gradients) + descriptors_gradients = jax.tree.map(lambda x: x[:, :, 1:], gradients) # Normalize the gradients norm_fitness_gradients = jnp.linalg.norm( @@ -195,15 +187,14 @@ def emit( descriptors_gradients = descriptors_gradients / norm_descriptors_gradients # Draw random coefficients - random_key, subkey = jax.random.split(random_key) coeffs = jax.random.multivariate_normal( - subkey, + key, shape=(self._batch_size,), mean=self._mu, cov=self._sigma, ) coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0])) - grads = jax.tree_util.tree_map( + grads = jax.tree.map( lambda x, y: jnp.concatenate((x, y), axis=-1), fitness_gradients, descriptors_gradients, @@ -211,16 +202,11 @@ def emit( update_grad = jnp.sum(jax.vmap(lambda x, y: x * y)(coeffs, grads), axis=-1) # update the genotypes - new_genotypes = jax.tree_util.tree_map( - lambda x, y: x + y, genotypes, update_grad - ) + new_genotypes = jax.tree.map(lambda x, y: x + y, genotypes, update_grad) - return new_genotypes, {}, random_key + return new_genotypes, {} - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def state_update( self, emitter_state: OMGMEGAEmitterState, diff --git a/qdax/core/emitters/pbt_me_emitter.py b/qdax/core/emitters/pbt_me_emitter.py index 55bded4e..1a0f2edf 100644 --- a/qdax/core/emitters/pbt_me_emitter.py +++ b/qdax/core/emitters/pbt_me_emitter.py @@ -25,7 +25,7 @@ class PBTEmitterState(EmitterState): replay_buffers: ReplayBuffer env_states: EnvState training_states: PBTTrainingState - random_key: RNGKey + key: RNGKey class PBTEmitterConfig(PyTreeNode): @@ -92,21 +92,21 @@ def __init__( def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: Repertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, extra_scores: ExtraScores, - ) -> Tuple[PBTEmitterState, RNGKey]: + ) -> PBTEmitterState: """Initializes the emitter state. Args: genotypes: The initial population. - random_key: A random key. + key: A random key. Returns: - The initial state of the PGAMEEmitter, a new random key. + The initial state of the PGAMEEmitter. 
""" observation_size = self._env.observation_size @@ -131,11 +131,11 @@ def init( replay_buffers = replay_buffer_init(transition=dummy_transitions) # Initialise env states - (random_key, subkey1, subkey2) = jax.random.split(random_key, num=3) - env_states = jax.jit(self._env.reset)(rng=subkey1) + key, subkey = jax.random.split(key) + env_states = jax.jit(self._env.reset)(rng=subkey) reshape_fn = jax.jit( - lambda tree: jax.tree_util.tree_map( + lambda tree: jax.tree.map( lambda x: jnp.reshape( x, ( @@ -151,28 +151,25 @@ def init( # Create emitter state # keep only pg population size training states if more are provided - genotypes = jax.tree_util.tree_map( + genotypes = jax.tree.map( lambda x: x[: self._config.pg_population_size_per_device], genotypes ) emitter_state = PBTEmitterState( replay_buffers=replay_buffers, env_states=env_states, training_states=genotypes, - random_key=subkey2, + key=key, ) - return emitter_state, random_key + return emitter_state - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit( self, repertoire: Repertoire, emitter_state: PBTEmitterState, - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """Do a single PGA-ME iteration: train critics and greedy policy, make mutations (evo and pg), score solution, fill replay buffer and insert back in the MAP-Elites grid. @@ -180,7 +177,7 @@ def emit( Args: repertoire: the current repertoire of genotypes emitter_state: the state of the emitter used - random_key: a random key + key: a random key Returns: A batch of offspring, the new emitter state and a new key. @@ -192,12 +189,13 @@ def emit( # Mutation evo if self._config.ga_population_size_per_device > 0: mutation_ga_batch_size = self._config.ga_population_size_per_device - x1, random_key = repertoire.sample(random_key, mutation_ga_batch_size) - x2, random_key = repertoire.sample(random_key, mutation_ga_batch_size) - x_mutation_ga, random_key = self._variation_fn(x1, x2, random_key) + sample_key_1, sample_key_2, variation_key = jax.random.split(key, 3) + x1 = repertoire.sample(sample_key_1, mutation_ga_batch_size) + x2 = repertoire.sample(sample_key_2, mutation_ga_batch_size) + x_mutation_ga = self._variation_fn(x1, x2, variation_key) # Gather offspring - genotypes = jax.tree_util.tree_map( + genotypes = jax.tree.map( lambda x, y: jnp.concatenate([x, y], axis=0), x_mutation_ga, x_mutation_pg, @@ -205,7 +203,7 @@ def emit( else: genotypes = x_mutation_pg - return genotypes, {}, random_key + return genotypes, {} @property def batch_size(self) -> int: @@ -259,10 +257,10 @@ def state_update( * self._config.fraction_best_to_replace_from ) indices_to_share = indices_to_share[:num_best_local] - genotypes_to_share, fitnesses_to_share = jax.tree_util.tree_map( + genotypes_to_share, fitnesses_to_share = jax.tree.map( lambda x: x[indices_to_share], (genotypes, fitnesses) ) - gathered_genotypes, gathered_fitnesses = jax.tree_util.tree_map( + gathered_genotypes, gathered_fitnesses = jax.tree.map( lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0), (genotypes_to_share, fitnesses_to_share), ) @@ -270,7 +268,7 @@ def state_update( genotypes_stacked, fitnesses_stacked = gathered_genotypes, gathered_fitnesses best_indices_stacked = jnp.argsort(-fitnesses_stacked) best_indices_stacked = best_indices_stacked[: self._num_best_to_replace_from] - best_genotypes_local, best_fitnesses_local = jax.tree_util.tree_map( + best_genotypes_local, 
best_fitnesses_local = jax.tree.map( lambda x: x[best_indices_stacked], (genotypes_stacked, fitnesses_stacked) ) @@ -282,15 +280,15 @@ def _loop_fn(i, val): # type: ignore [i * self._num_to_exchange], [self._num_to_exchange], ) - genotypes_to_share, fitnesses_to_share = jax.tree_util.tree_map( + genotypes_to_share, fitnesses_to_share = jax.tree.map( lambda x: x[indices_to_share], (genotypes, fitnesses) ) - gathered_genotypes, gathered_fitnesses = jax.tree_util.tree_map( + gathered_genotypes, gathered_fitnesses = jax.tree.map( lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0), (genotypes_to_share, fitnesses_to_share), ) - genotypes_stacked, fitnesses_stacked = jax.tree_util.tree_map( + genotypes_stacked, fitnesses_stacked = jax.tree.map( lambda x, y: jnp.concatenate([x, y], axis=0), (gathered_genotypes, gathered_fitnesses), (best_genotypes_local, best_fitnesses_local), @@ -300,7 +298,7 @@ def _loop_fn(i, val): # type: ignore best_indices_stacked = best_indices_stacked[ : self._num_best_to_replace_from ] - best_genotypes_local, best_fitnesses_local = jax.tree_util.tree_map( + best_genotypes_local, best_fitnesses_local = jax.tree.map( lambda x: x[best_indices_stacked], (genotypes_stacked, fitnesses_stacked), ) @@ -316,17 +314,17 @@ def _loop_fn(i, val): # type: ignore ) # Gather fitnesses from all devices to rank locally against it - all_fitnesses = jax.tree_util.tree_map( + all_fitnesses = jax.tree.map( lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0), fitnesses, ) all_fitnesses = jnp.ravel(all_fitnesses) all_fitnesses = -jnp.sort(-all_fitnesses) - random_key = emitter_state.random_key - random_key, sub_key = jax.random.split(random_key) - best_genotypes = jax.tree_util.tree_map( + key = emitter_state.key + key, subkey = jax.random.split(key) + best_genotypes = jax.tree.map( lambda x: jax.random.choice( - sub_key, x, shape=(len(fitnesses),), replace=True + subkey, x, shape=(len(fitnesses),), replace=True ), best_genotypes_local, ) @@ -341,7 +339,7 @@ def _loop_fn(i, val): # type: ignore lower_bound = all_fitnesses[-self._num_to_replace_from_best] cond = fitnesses <= lower_bound - training_states = jax.tree_util.tree_map( + training_states = jax.tree.map( lambda x, y: jnp.where( jnp.expand_dims( cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)]) @@ -352,7 +350,7 @@ def _loop_fn(i, val): # type: ignore best_training_states, training_states, ) - replay_buffers = jax.tree_util.tree_map( + replay_buffers = jax.tree.map( lambda x, y: jnp.where( jnp.expand_dims( cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)]) @@ -366,8 +364,9 @@ def _loop_fn(i, val): # type: ignore # Replacing with samples from the ME repertoire if self._num_to_replace_from_samples > 0: - me_samples, random_key = repertoire.sample( - random_key, self._config.pg_population_size_per_device + key, subkey = jax.random.split(key) + me_samples = repertoire.sample( + subkey, self._config.pg_population_size_per_device ) # Resample hyper-params me_samples = jax.vmap(me_samples.__class__.resample_hyperparams)(me_samples) @@ -375,7 +374,7 @@ def _loop_fn(i, val): # type: ignore -self._num_to_replace_from_best - self._num_to_replace_from_samples ] cond = jnp.logical_and(fitnesses <= upper_bound, fitnesses >= lower_bound) - training_states = jax.tree_util.tree_map( + training_states = jax.tree.map( lambda x, y: jnp.where( jnp.expand_dims( cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)]) @@ -407,6 +406,6 @@ def _loop_fn(i, val): # type: ignore training_states=training_states, 
replay_buffers=replay_buffers, env_states=env_states, - random_key=random_key, + key=key, ) return emitter_state # type: ignore diff --git a/qdax/core/emitters/pbt_variation_operators.py b/qdax/core/emitters/pbt_variation_operators.py index c8537003..54a02e8f 100644 --- a/qdax/core/emitters/pbt_variation_operators.py +++ b/qdax/core/emitters/pbt_variation_operators.py @@ -1,5 +1,3 @@ -from typing import Tuple - from qdax.baselines.sac_pbt import PBTSacTrainingState from qdax.baselines.td3_pbt import PBTTD3TrainingState from qdax.core.emitters.mutation_operators import isoline_variation @@ -9,10 +7,10 @@ def sac_pbt_variation_fn( training_state1: PBTSacTrainingState, training_state2: PBTSacTrainingState, - random_key: RNGKey, + key: RNGKey, iso_sigma: float, line_sigma: float, -) -> Tuple[PBTSacTrainingState, RNGKey]: +) -> PBTSacTrainingState: """ This operator runs a cross-over between two SAC agents. It is used as variation operator in the SAC-PBT-Map-Elites algorithm. An isoline-dd variation is applied @@ -21,12 +19,12 @@ def sac_pbt_variation_fn( Args: training_state1: Training state of first SAC agent. training_state2: Training state of first SAC agent. - random_key: Random key. + key: Random key. iso_sigma: Spread parameter (noise). line_sigma: Line parameter (direction of the new genotype). Returns: - A new SAC training state obtained from cross-over and an updated random key. + A new SAC training state obtained from cross-over. """ @@ -42,10 +40,10 @@ def sac_pbt_variation_fn( training_state1.alpha_params, training_state2.alpha_params, ) - (policy_params, critic_params, alpha_params), random_key = isoline_variation( + policy_params, critic_params, alpha_params = isoline_variation( x1=(policy_params1, critic_params1, alpha_params1), x2=(policy_params2, critic_params2, alpha_params2), - random_key=random_key, + key=key, iso_sigma=iso_sigma, line_sigma=line_sigma, ) @@ -56,19 +54,16 @@ def sac_pbt_variation_fn( alpha_params=alpha_params, ) - return ( - new_training_state, - random_key, - ) + return new_training_state # type: ignore def td3_pbt_variation_fn( training_state1: PBTTD3TrainingState, training_state2: PBTTD3TrainingState, - random_key: RNGKey, + key: RNGKey, iso_sigma: float, line_sigma: float, -) -> Tuple[PBTTD3TrainingState, RNGKey]: +) -> PBTTD3TrainingState: """ This operator runs a cross-over between two TD3 agents. It is used as variation operator in the TD3-PBT-Map-Elites algorithm. An isoline-dd variation is applied @@ -77,12 +72,12 @@ def td3_pbt_variation_fn( Args: training_state1: Training state of first TD3 agent. training_state2: Training state of first TD3 agent. - random_key: Random key. + key: Random key. iso_sigma: Spread parameter (noise). line_sigma: Line parameter (direction of the new genotype). Returns: - A new TD3 training state obtained from cross-over and an updated random key. + A new TD3 training state obtained from cross-over. 
""" @@ -94,13 +89,10 @@ def td3_pbt_variation_fn( training_state1.critic_params, training_state2.critic_params, ) - ( - policy_params, - critic_params, - ), random_key = isoline_variation( + policy_params, critic_params = isoline_variation( x1=(policy_params1, critic_params1), x2=(policy_params2, critic_params2), - random_key=random_key, + key=key, iso_sigma=iso_sigma, line_sigma=line_sigma, ) @@ -109,7 +101,4 @@ def td3_pbt_variation_fn( critic_params=critic_params, ) - return ( - new_training_state, - random_key, - ) + return new_training_state # type: ignore diff --git a/qdax/core/emitters/qdpg_emitter.py b/qdax/core/emitters/qdpg_emitter.py index 3616a4b9..1ba8788f 100644 --- a/qdax/core/emitters/qdpg_emitter.py +++ b/qdax/core/emitters/qdpg_emitter.py @@ -2,7 +2,7 @@ paper https://arxiv.org/abs/2006.08505. QDPG has been updated to enter in the container+emitter framework of QD. Furthermore, -it has been updated to work better with Jax in term of time cost. Those changes have +it has been updated to work better with JAX in term of time cost. Those changes have been made in accordance with the authors of this algorithm. """ diff --git a/qdax/core/emitters/qpg_emitter.py b/qdax/core/emitters/qpg_emitter.py index 63373494..87ec3e4e 100644 --- a/qdax/core/emitters/qpg_emitter.py +++ b/qdax/core/emitters/qpg_emitter.py @@ -54,7 +54,7 @@ class QualityPGEmitterState(EmitterState): target_critic_params: Params target_actor_params: Params replay_buffer: ReplayBuffer - random_key: RNGKey + key: RNGKey steps: jnp.ndarray @@ -120,21 +120,21 @@ def use_all_data(self) -> bool: def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: Repertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, extra_scores: ExtraScores, - ) -> Tuple[QualityPGEmitterState, RNGKey]: + ) -> QualityPGEmitterState: """Initializes the emitter state. Args: genotypes: The initial population. - random_key: A random key. + key: A random key. Returns: - The initial state of the PGAMEEmitter, a new random key. + The initial state of the PGAMEEmitter. 
""" observation_size = self._env.observation_size @@ -142,16 +142,16 @@ def init( descriptor_size = self._env.state_descriptor_length # Initialise critic, greedy actor and population - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) fake_obs = jnp.zeros(shape=(observation_size,)) fake_action = jnp.zeros(shape=(action_size,)) critic_params = self._critic_network.init( subkey, obs=fake_obs, actions=fake_action ) - target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params) + target_critic_params = jax.tree.map(lambda x: x, critic_params) - actor_params = jax.tree_util.tree_map(lambda x: x[0], genotypes) - target_actor_params = jax.tree_util.tree_map(lambda x: x[0], genotypes) + actor_params = jax.tree.map(lambda x: x[0], genotypes) + target_actor_params = jax.tree.map(lambda x: x[0], genotypes) # Prepare init optimizer states critic_optimizer_state = self._critic_optimizer.init(critic_params) @@ -176,7 +176,6 @@ def init( replay_buffer = replay_buffer.insert(transitions) # Initial training state - random_key, subkey = jax.random.split(random_key) emitter_state = QualityPGEmitterState( critic_params=critic_params, critic_optimizer_state=critic_optimizer_state, @@ -185,38 +184,35 @@ def init( target_critic_params=target_critic_params, target_actor_params=target_actor_params, replay_buffer=replay_buffer, - random_key=subkey, + key=key, steps=jnp.array(0), ) - return emitter_state, random_key + return emitter_state - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit( self, repertoire: Repertoire, emitter_state: QualityPGEmitterState, - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """Do a step of PG emission. Args: repertoire: the current repertoire of genotypes emitter_state: the state of the emitter used - random_key: a random key + key: a random key Returns: - A batch of offspring, the new emitter state and a new key. + A batch of offspring, the new emitter state. """ batch_size = self._config.env_batch_size # sample parents mutation_pg_batch_size = int(batch_size - 1) - parents, random_key = repertoire.sample(random_key, mutation_pg_batch_size) + parents = repertoire.sample(key, mutation_pg_batch_size) # apply the pg mutation offsprings_pg = self.emit_pg(emitter_state, parents) @@ -225,23 +221,20 @@ def emit( offspring_actor = self.emit_actor(emitter_state) # add dimension for concatenation - offspring_actor = jax.tree_util.tree_map( + offspring_actor = jax.tree.map( lambda x: jnp.expand_dims(x, axis=0), offspring_actor ) # gather offspring - genotypes = jax.tree_util.tree_map( + genotypes = jax.tree.map( lambda x, y: jnp.concatenate([x, y], axis=0), offsprings_pg, offspring_actor, ) - return genotypes, {}, random_key + return genotypes, {} - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit_pg( self, emitter_state: QualityPGEmitterState, parents: Genotype ) -> Genotype: @@ -264,10 +257,7 @@ def emit_pg( return offsprings - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit_actor(self, emitter_state: QualityPGEmitterState) -> Genotype: """Emit the greedy actor. 
@@ -322,7 +312,7 @@ def state_update( emitter_state = emitter_state.replace(replay_buffer=replay_buffer) def scan_train_critics( - carry: QualityPGEmitterState, unused: Any + carry: QualityPGEmitterState, _: Any ) -> Tuple[QualityPGEmitterState, Any]: emitter_state = carry new_emitter_state = self._train_critics(emitter_state) @@ -355,27 +345,26 @@ def _train_critics( New emitter state where the critic and the greedy actor have been updated. Optimizer states have also been updated in the process. """ + key = emitter_state.key # Sample a batch of transitions in the buffer - random_key = emitter_state.random_key + key, subkey = jax.random.split(key) replay_buffer = emitter_state.replay_buffer - transitions, random_key = replay_buffer.sample( - random_key, sample_size=self._config.batch_size - ) + transitions = replay_buffer.sample(subkey, sample_size=self._config.batch_size) # Update Critic + key, subkey = jax.random.split(key) ( critic_optimizer_state, critic_params, target_critic_params, - random_key, ) = self._update_critic( critic_params=emitter_state.critic_params, target_critic_params=emitter_state.target_critic_params, target_actor_params=emitter_state.target_actor_params, critic_optimizer_state=emitter_state.critic_optimizer_state, transitions=transitions, - random_key=random_key, + key=subkey, ) # Update greedy actor @@ -408,7 +397,7 @@ def _train_critics( actor_opt_state=actor_optimizer_state, target_critic_params=target_critic_params, target_actor_params=target_actor_params, - random_key=random_key, + key=key, steps=emitter_state.steps + 1, replay_buffer=replay_buffer, ) @@ -423,17 +412,16 @@ def _update_critic( target_actor_params: Params, critic_optimizer_state: Params, transitions: QDTransition, - random_key: RNGKey, - ) -> Tuple[Params, Params, Params, RNGKey]: + key: RNGKey, + ) -> Tuple[Params, Params, Params]: # compute loss and gradients - random_key, subkey = jax.random.split(random_key) critic_gradient = jax.grad(self._critic_loss_fn)( critic_params, target_actor_params, target_critic_params, transitions, - subkey, + key, ) critic_updates, critic_optimizer_state = self._critic_optimizer.update( critic_gradient, critic_optimizer_state @@ -443,14 +431,14 @@ def _update_critic( critic_params = optax.apply_updates(critic_params, critic_updates) # Soft update of target critic network - target_critic_params = jax.tree_util.tree_map( + target_critic_params = jax.tree.map( lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + self._config.soft_tau_update * x2, target_critic_params, critic_params, ) - return critic_optimizer_state, critic_params, target_critic_params, random_key + return critic_optimizer_state, critic_params, target_critic_params @partial(jax.jit, static_argnames=("self",)) def _update_actor( @@ -475,7 +463,7 @@ def _update_actor( actor_params = optax.apply_updates(actor_params, policy_updates) # Soft update of target greedy actor - target_actor_params = jax.tree_util.tree_map( + target_actor_params = jax.tree.map( lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + self._config.soft_tau_update * x2, target_actor_params, @@ -513,7 +501,7 @@ def _mutation_function_pg( def scan_train_policy( carry: Tuple[QualityPGEmitterState, Genotype, optax.OptState], - unused: Any, + _: Any, ) -> Tuple[Tuple[QualityPGEmitterState, Genotype, optax.OptState], Any]: emitter_state, policy_params, policy_optimizer_state = carry ( @@ -561,13 +549,12 @@ def _train_policy( Returns: The new emitter state and new params of the NN. 
""" + key = emitter_state.key # Sample a batch of transitions in the buffer - random_key = emitter_state.random_key + key, subkey = jax.random.split(key) replay_buffer = emitter_state.replay_buffer - transitions, random_key = replay_buffer.sample( - random_key, sample_size=self._config.batch_size - ) + transitions = replay_buffer.sample(subkey, sample_size=self._config.batch_size) # update policy policy_optimizer_state, policy_params = self._update_policy( @@ -579,7 +566,7 @@ def _train_policy( # Create new training state new_emitter_state = emitter_state.replace( - random_key=random_key, + key=key, replay_buffer=replay_buffer, ) diff --git a/qdax/core/emitters/standard_emitters.py b/qdax/core/emitters/standard_emitters.py index 1d949b2d..4931bcd3 100644 --- a/qdax/core/emitters/standard_emitters.py +++ b/qdax/core/emitters/standard_emitters.py @@ -22,16 +22,13 @@ def __init__( self._variation_percentage = variation_percentage self._batch_size = batch_size - @partial( - jax.jit, - static_argnames=("self",), - ) + @partial(jax.jit, static_argnames=("self",)) def emit( self, repertoire: Repertoire, emitter_state: Optional[EmitterState], - random_key: RNGKey, - ) -> Tuple[Genotype, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores]: """ Emitter that performs both mutation and variation. Two batches of variation_percentage * batch_size genotypes are sampled in the repertoire, @@ -45,37 +42,37 @@ def emit( Params: repertoire: the MAP-Elites repertoire to sample from emitter_state: void - random_key: a jax PRNG random key + key: a jax PRNG random key Returns: a batch of offsprings - a new jax PRNG key """ n_variation = int(self._batch_size * self._variation_percentage) n_mutation = self._batch_size - n_variation if n_variation > 0: - x1, random_key = repertoire.sample(random_key, n_variation) - x2, random_key = repertoire.sample(random_key, n_variation) - - x_variation, random_key = self._variation_fn(x1, x2, random_key) + sample_key_1, sample_key_2, variation_key = jax.random.split(key, 3) + x1 = repertoire.sample(sample_key_1, n_variation) + x2 = repertoire.sample(sample_key_2, n_variation) + x_variation = self._variation_fn(x1, x2, variation_key) if n_mutation > 0: - x1, random_key = repertoire.sample(random_key, n_mutation) - x_mutation, random_key = self._mutation_fn(x1, random_key) + sample_key, mutation_key = jax.random.split(key) + x1 = repertoire.sample(sample_key, n_mutation) + x_mutation = self._mutation_fn(x1, mutation_key) if n_variation == 0: genotypes = x_mutation elif n_mutation == 0: genotypes = x_variation else: - genotypes = jax.tree_util.tree_map( + genotypes = jax.tree.map( lambda x_1, x_2: jnp.concatenate([x_1, x_2], axis=0), x_variation, x_mutation, ) - return genotypes, {}, random_key + return genotypes, {} @property def batch_size(self) -> int: diff --git a/qdax/core/map_elites.py b/qdax/core/map_elites.py index c3155dd3..1cacd4ae 100644 --- a/qdax/core/map_elites.py +++ b/qdax/core/map_elites.py @@ -41,7 +41,7 @@ class MAPElites: def __init__( self, scoring_function: Callable[ - [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey] + [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores] ], emitter: Emitter, metrics_function: Callable[[MapElitesRepertoire], Metrics], @@ -55,8 +55,8 @@ def init( self, genotypes: Genotype, centroids: Centroid, - random_key: RNGKey, - ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]: + key: RNGKey, + ) -> Tuple[MapElitesRepertoire, Optional[EmitterState]]: """ Initialize a Map-Elites 
repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method @@ -66,16 +66,15 @@ def init( genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) centroids: tessellation centroids of shape (batch_size, num_descriptors) - random_key: a random key used for stochastic operations. + key: a random key used for stochastic operations. Returns: An initialized MAP-Elite repertoire with the initial state of the emitter, and a random key. """ # score initial genotypes - fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, descriptors, extra_scores = self._scoring_function(genotypes, subkey) # init the repertoire repertoire = MapElitesRepertoire.init( @@ -87,8 +86,9 @@ def init( ) # get initial state of the emitter - emitter_state, random_key = self._emitter.init( - random_key=random_key, + key, subkey = jax.random.split(key) + emitter_state = self._emitter.init( + key=subkey, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, @@ -96,15 +96,15 @@ def init( extra_scores=extra_scores, ) - return repertoire, emitter_state, random_key + return repertoire, emitter_state @partial(jax.jit, static_argnames=("self",)) def update( self, repertoire: MapElitesRepertoire, emitter_state: Optional[EmitterState], - random_key: RNGKey, - ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]: + key: RNGKey, + ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]: """ Performs one iteration of the MAP-Elites algorithm. 1. A batch of genotypes is sampled in the repertoire and the genotypes @@ -116,7 +116,7 @@ def update( Args: repertoire: the MAP-Elites repertoire emitter_state: state of the emitter - random_key: a jax PRNG random key + key: a jax PRNG random key Returns: the updated MAP-Elites repertoire @@ -125,14 +125,12 @@ def update( a new jax PRNG key """ # generate offsprings with the emitter - genotypes, extra_info, random_key = self._emitter.emit( - repertoire, emitter_state, random_key - ) + key, subkey = jax.random.split(key) + genotypes, extra_info = self._emitter.emit(repertoire, emitter_state, subkey) # scores the offsprings - fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, descriptors, extra_scores = self._scoring_function(genotypes, subkey) # add genotypes in the repertoire repertoire = repertoire.add(genotypes, descriptors, fitnesses, extra_scores) @@ -150,13 +148,13 @@ def update( # update the metrics metrics = self._metrics_function(repertoire) - return repertoire, emitter_state, metrics, random_key + return repertoire, emitter_state, metrics @partial(jax.jit, static_argnames=("self",)) def scan_update( self, carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], - unused: Any, + _: Any, ) -> Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics]: """Rewrites the update function in a way that makes it compatible with the jax.lax.scan primitive. @@ -164,21 +162,21 @@ def scan_update( Args: carry: a tuple containing the repertoire, the emitter state and a random key. - unused: unused element, necessary to respect jax.lax.scan API. + _: unused element, necessary to respect jax.lax.scan API. Returns: The updated repertoire and emitter state, with a new random key and metrics. 
""" - repertoire, emitter_state, random_key = carry + repertoire, emitter_state, key = carry + key, subkey = jax.random.split(key) ( repertoire, emitter_state, metrics, - random_key, ) = self.update( repertoire, emitter_state, - random_key, + subkey, ) - return (repertoire, emitter_state, random_key), metrics + return (repertoire, emitter_state, key), metrics diff --git a/qdax/core/mels.py b/qdax/core/mels.py index 8b0e7511..36fa1be3 100644 --- a/qdax/core/mels.py +++ b/qdax/core/mels.py @@ -58,8 +58,8 @@ def init( self, genotypes: Genotype, centroids: Centroid, - random_key: RNGKey, - ) -> Tuple[MELSRepertoire, Optional[EmitterState], RNGKey]: + key: RNGKey, + ) -> Tuple[MELSRepertoire, Optional[EmitterState]]: """Initialize a MAP-Elites Low-Spread repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping. @@ -68,16 +68,15 @@ def init( genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) centroids: tessellation centroids of shape (batch_size, num_descriptors) - random_key: a random key used for stochastic operations. + key: a random key used for stochastic operations. Returns: A tuple of (initialized MAP-Elites Low-Spread repertoire, initial emitter state, JAX random key). """ # score initial genotypes - fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, descriptors, extra_scores = self._scoring_function(genotypes, subkey) # init the repertoire repertoire = MELSRepertoire.init( @@ -89,8 +88,9 @@ def init( ) # get initial state of the emitter - emitter_state, random_key = self._emitter.init( - random_key=random_key, + key, subkey = jax.random.split(key) + emitter_state = self._emitter.init( + key=subkey, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, @@ -107,4 +107,4 @@ def init( descriptors=descriptors, extra_scores=extra_scores, ) - return repertoire, emitter_state, random_key + return repertoire, emitter_state diff --git a/qdax/core/mome.py b/qdax/core/mome.py index c239bd1f..bd81f36f 100644 --- a/qdax/core/mome.py +++ b/qdax/core/mome.py @@ -26,8 +26,8 @@ def init( genotypes: jnp.ndarray, centroids: Centroid, pareto_front_max_length: int, - random_key: RNGKey, - ) -> Tuple[MOMERepertoire, Optional[EmitterState], RNGKey]: + key: RNGKey, + ) -> Tuple[MOMERepertoire, Optional[EmitterState]]: """Initialize a MOME grid with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping. @@ -37,16 +37,15 @@ def init( centroids: centroids of the repertoire. pareto_front_max_length: maximum size of the pareto front. This is necessary to respect jax.jit fixed shape size constraint. - random_key: a random key to handle stochasticity. + key: a random key to handle stochasticity. Returns: The initial repertoire and emitter state, and a new random key. 
""" # first score - fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - genotypes, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, descriptors, extra_scores = self._scoring_function(genotypes, subkey) # init the repertoire repertoire = MOMERepertoire.init( @@ -59,8 +58,9 @@ def init( ) # get initial state of the emitter - emitter_state, random_key = self._emitter.init( - random_key=random_key, + key, subkey = jax.random.split(key) + emitter_state = self._emitter.init( + key=subkey, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, @@ -68,4 +68,4 @@ def init( extra_scores=extra_scores, ) - return repertoire, emitter_state, random_key + return repertoire, emitter_state diff --git a/qdax/core/neuroevolution/buffers/buffer.py b/qdax/core/neuroevolution/buffers/buffer.py index 81f1e896..8978e185 100644 --- a/qdax/core/neuroevolution/buffers/buffer.py +++ b/qdax/core/neuroevolution/buffers/buffer.py @@ -466,22 +466,21 @@ def init( @partial(jax.jit, static_argnames=("sample_size",)) def sample( self, - random_key: RNGKey, + key: RNGKey, sample_size: int, ) -> Tuple[Transition, RNGKey]: """ Sample a batch of transitions in the replay buffer. """ - random_key, subkey = jax.random.split(random_key) idx = jax.random.randint( - subkey, + key, shape=(sample_size,), minval=0, maxval=self.current_size, ) samples = jnp.take(self.data, idx, axis=0, mode="clip") transitions = self.transition.__class__.from_flatten(samples, self.transition) - return transitions, random_key + return transitions @jax.jit def insert(self, transitions: Transition) -> ReplayBuffer: diff --git a/qdax/core/neuroevolution/buffers/trajectory_buffer.py b/qdax/core/neuroevolution/buffers/trajectory_buffer.py index 93e1b2f9..aca96b56 100644 --- a/qdax/core/neuroevolution/buffers/trajectory_buffer.py +++ b/qdax/core/neuroevolution/buffers/trajectory_buffer.py @@ -119,9 +119,9 @@ def init( # type: ignore @partial(jax.jit, static_argnames=("sample_size")) def sample( self, - random_key: RNGKey, + key: RNGKey, sample_size: int, - ) -> Tuple[Transition, RNGKey]: + ) -> Transition: """ Sample transitions from the buffer. If sample_traj=False, returns stacked transitions in the shape (sample_size,), if sample_traj=True, return transitions @@ -130,9 +130,8 @@ def sample( # Here we want to sample single transitions # We sample uniformly at random the indexes of valid transitions - random_key, subkey = jax.random.split(random_key) idx = jax.random.randint( - subkey, + key, shape=(sample_size,), minval=0, maxval=self.current_size, @@ -142,28 +141,27 @@ def sample( # (sample_size, concat_dim) transitions = self.transition.__class__.from_flatten(samples, self.transition) - return transitions, random_key + return transitions def sample_with_returns( self, - random_key: RNGKey, + key: RNGKey, sample_size: int, - ) -> Tuple[Transition, Reward, RNGKey]: + ) -> Tuple[Transition, Reward]: """Sample transitions and the return corresponding to their episode. The returns are compute by the method `compute_returns`. Args: - random_key: a random key + key: a random key sample_size: the number of transitions Returns: - The transitions, the associated returns and a new random key. + The transitions, the associated returns. 
""" # Here we want to sample single transitions # We sample uniformly at random the indexes of valid transitions - random_key, subkey = jax.random.split(random_key) idx = jax.random.randint( - subkey, + key, shape=(sample_size,), minval=0, maxval=self.current_size, @@ -173,7 +171,7 @@ def sample_with_returns( returns = jnp.take(self.returns, idx, mode="clip") # (sample_size, concat_dim) transitions = self.transition.__class__.from_flatten(samples, self.transition) - return transitions, returns, random_key + return transitions, returns @jax.jit def insert(self, transitions: Transition) -> TrajectoryBuffer: diff --git a/qdax/core/neuroevolution/losses/sac_loss.py b/qdax/core/neuroevolution/losses/sac_loss.py index d7289292..6662eb7f 100644 --- a/qdax/core/neuroevolution/losses/sac_loss.py +++ b/qdax/core/neuroevolution/losses/sac_loss.py @@ -71,7 +71,7 @@ def sac_policy_loss_fn( critic_params: Params, alpha: jnp.ndarray, transitions: Transition, - random_key: RNGKey, + key: RNGKey, ) -> jnp.ndarray: """ Creates the policy loss used in SAC. @@ -84,16 +84,14 @@ def sac_policy_loss_fn( critic_params: parameters of the critic alpha: entropy coefficient value transitions: transitions collected by the agent - random_key: random key + key: random key Returns: the loss of the policy """ dist_params = policy_fn(policy_params, transitions.obs) - action = parametric_action_distribution.sample_no_postprocessing( - dist_params, random_key - ) + action = parametric_action_distribution.sample_no_postprocessing(dist_params, key) log_prob = parametric_action_distribution.log_prob(dist_params, action) action = parametric_action_distribution.postprocess(action) q_action = critic_fn(critic_params, transitions.obs, action) @@ -114,7 +112,7 @@ def sac_critic_loss_fn( target_critic_params: Params, alpha: jnp.ndarray, transitions: Transition, - random_key: RNGKey, + key: RNGKey, ) -> jnp.ndarray: """ Creates the critic loss used in SAC. @@ -128,7 +126,7 @@ def sac_critic_loss_fn( target_critic_params: parameters of the target critic alpha: entropy coefficient value transitions: transitions collected by the agent - random_key: random key + key: random key reward_scaling: a multiplicative factor to the reward discount: the discount factor @@ -139,7 +137,7 @@ def sac_critic_loss_fn( q_old_action = critic_fn(critic_params, transitions.obs, transitions.actions) next_dist_params = policy_fn(policy_params, transitions.next_obs) next_action = parametric_action_distribution.sample_no_postprocessing( - next_dist_params, random_key + next_dist_params, key ) next_log_prob = parametric_action_distribution.log_prob( next_dist_params, next_action @@ -168,7 +166,7 @@ def sac_alpha_loss_fn( action_size: int, policy_params: Params, transitions: Transition, - random_key: RNGKey, + key: RNGKey, ) -> jnp.ndarray: """ Creates the alpha loss used in SAC. 
@@ -180,7 +178,7 @@ def sac_alpha_loss_fn( parametric_action_distribution: the distribution over actions policy_params: parameters of the policy transitions: transitions collected by the agent - random_key: random key + key: random key action_size: the size of the environment's action space Returns: @@ -190,9 +188,7 @@ def sac_alpha_loss_fn( target_entropy = -0.5 * action_size dist_params = policy_fn(policy_params, transitions.obs) - action = parametric_action_distribution.sample_no_postprocessing( - dist_params, random_key - ) + action = parametric_action_distribution.sample_no_postprocessing(dist_params, key) log_prob = parametric_action_distribution.log_prob(dist_params, action) alpha = jnp.exp(log_alpha) alpha_loss = alpha * jax.lax.stop_gradient(-log_prob - target_entropy) diff --git a/qdax/core/neuroevolution/losses/td3_loss.py b/qdax/core/neuroevolution/losses/td3_loss.py index 964c2c4f..28fd9eb2 100644 --- a/qdax/core/neuroevolution/losses/td3_loss.py +++ b/qdax/core/neuroevolution/losses/td3_loss.py @@ -56,12 +56,11 @@ def _critic_loss_fn( target_policy_params: Params, target_critic_params: Params, transitions: Transition, - random_key: RNGKey, + key: RNGKey, ) -> jnp.ndarray: """Critics loss function for TD3 agent""" noise = ( - jax.random.normal(random_key, shape=transitions.actions.shape) - * policy_noise + jax.random.normal(key, shape=transitions.actions.shape) * policy_noise ).clip(-noise_clip, noise_clip) next_action = ( @@ -158,12 +157,11 @@ def _critic_loss_fn( target_actor_params: Params, target_critic_params: Params, transitions: Transition, - random_key: RNGKey, + key: RNGKey, ) -> jnp.ndarray: """Descriptor-conditioned critic loss function for TD3 agent""" noise = ( - jax.random.normal(random_key, shape=transitions.actions.shape) - * policy_noise + jax.random.normal(key, shape=transitions.actions.shape) * policy_noise ).clip(-noise_clip, noise_clip) next_action = ( @@ -236,7 +234,7 @@ def td3_critic_loss_fn( reward_scaling: float, discount: float, transitions: Transition, - random_key: RNGKey, + key: RNGKey, ) -> jnp.ndarray: """Critics loss function for TD3 agent. @@ -256,7 +254,7 @@ def td3_critic_loss_fn( Return the loss function used to train the critic in TD3. """ noise = ( - jax.random.normal(random_key, shape=transitions.actions.shape) * policy_noise + jax.random.normal(key, shape=transitions.actions.shape) * policy_noise ).clip(-noise_clip, noise_clip) next_action = (policy_fn(target_policy_params, transitions.next_obs) + noise).clip( diff --git a/qdax/core/neuroevolution/mdp_utils.py b/qdax/core/neuroevolution/mdp_utils.py index f269a22b..4c79657f 100644 --- a/qdax/core/neuroevolution/mdp_utils.py +++ b/qdax/core/neuroevolution/mdp_utils.py @@ -9,7 +9,7 @@ from flax.struct import PyTreeNode from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.custom_types import Descriptor, Genotype, Params, RNGKey +from qdax.custom_types import Genotype, Params, RNGKey class TrainingState(PyTreeNode): @@ -26,7 +26,7 @@ class TrainingState(PyTreeNode): def generate_unroll( init_state: EnvState, policy_params: Params, - random_key: RNGKey, + key: RNGKey, episode_length: int, play_step_fn: Callable[ [EnvState, Params, RNGKey], @@ -44,7 +44,7 @@ def generate_unroll( Args: init_state: first state of the rollout. policy_params: params of the individual. - random_key: random key for stochasiticity handling. + key: random key for stochasiticity handling. episode_length: length of the rollout. play_step_fn: function describing how a step need to be taken. 
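Not part of the diff: a toy version of the scan-based unroll used by `generate_unroll`, just to show how the `(env_state, params, key)` carry is threaded while one transition is emitted per step. The step function and sizes are made up.

```python
import jax
import jax.numpy as jnp

def play_step(env_state, params, key):
    # Toy dynamics: take a random step scaled by the "policy" parameter.
    key, subkey = jax.random.split(key)
    action = params * jax.random.normal(subkey)
    next_state = env_state + action
    transition = jnp.stack([env_state, next_state])
    return next_state, params, key, transition

def scan_step(carry, _):
    next_state, params, key, transition = play_step(*carry)
    return (next_state, params, key), transition

init = (jnp.array(0.0), jnp.array(0.5), jax.random.key(0))
(final_state, _, _), transitions = jax.lax.scan(scan_step, init, (), length=10)
assert transitions.shape == (10, 2)  # one stacked transition per step
```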
@@ -55,66 +55,12 @@ def generate_unroll( def _scan_play_step_fn( carry: Tuple[EnvState, Params, RNGKey], unused_arg: Any ) -> Tuple[Tuple[EnvState, Params, RNGKey], Transition]: - env_state, policy_params, random_key, transitions = play_step_fn(*carry) - return (env_state, policy_params, random_key), transitions + env_state, policy_params, key, transitions = play_step_fn(*carry) + return (env_state, policy_params, key), transitions (state, _, _), transitions = jax.lax.scan( _scan_play_step_fn, - (init_state, policy_params, random_key), - (), - length=episode_length, - ) - return state, transitions - - -@partial(jax.jit, static_argnames=("play_step_actor_dc_fn", "episode_length")) -def generate_unroll_actor_dc( - init_state: EnvState, - actor_dc_params: Params, - desc: Descriptor, - random_key: RNGKey, - episode_length: int, - play_step_actor_dc_fn: Callable[ - [EnvState, Descriptor, Params, RNGKey], - Tuple[ - EnvState, - Descriptor, - Params, - RNGKey, - Transition, - ], - ], -) -> Tuple[EnvState, Transition]: - """Generates an episode according to the agent's policy and descriptor, - returns the final state of the episode and the transitions of the episode. - - Args: - init_state: first state of the rollout. - policy_dc_params: descriptor-conditioned policy params. - desc: descriptor the policy attempts to achieve. - random_key: random key for stochasiticity handling. - episode_length: length of the rollout. - play_step_fn: function describing how a step need to be taken. - - Returns: - A new state, the experienced transition. - """ - - def _scan_play_step_fn( - carry: Tuple[EnvState, Params, Descriptor, RNGKey], unused_arg: Any - ) -> Tuple[Tuple[EnvState, Params, Descriptor, RNGKey], Transition]: - ( - env_state, - actor_dc_params, - desc, - random_key, - transitions, - ) = play_step_actor_dc_fn(*carry) - return (env_state, actor_dc_params, desc, random_key), transitions - - (state, _, _, _), transitions = jax.lax.scan( - _scan_play_step_fn, - (init_state, actor_dc_params, desc, random_key), + (init_state, policy_params, key), (), length=episode_length, ) @@ -134,15 +80,15 @@ def mask_episodes(x: jnp.ndarray) -> jnp.ndarray: # the double transpose trick is here to allow easy broadcasting return jnp.where(mask.T, x.T, jnp.nan * jnp.ones_like(x).T).T - return jax.tree_util.tree_map(mask_episodes, transition) # type: ignore + return jax.tree.map(mask_episodes, transition) # type: ignore def init_population_controllers( policy_network: nn.Module, env: brax.envs.Env, batch_size: int, - random_key: RNGKey, -) -> Tuple[Genotype, RNGKey]: + key: RNGKey, +) -> Genotype: """ Initializes the population of controllers using a policy_network. @@ -151,15 +97,13 @@ def init_population_controllers( controllers. env: the BRAX environment. batch_size: the number of environments we play simultaneously. - random_key: a JAX random key. + key: a JAX random key. Returns: A tuple of the initial population and the new random key. 
""" - random_key, subkey = jax.random.split(random_key) - - keys = jax.random.split(subkey, num=batch_size) + keys = jax.random.split(key, num=batch_size) fake_batch = jnp.zeros(shape=(batch_size, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) - return init_variables, random_key + return init_variables diff --git a/qdax/environments/__init__.py b/qdax/environments/__init__.py index f0b7e9d1..5036d976 100644 --- a/qdax/environments/__init__.py +++ b/qdax/environments/__init__.py @@ -10,7 +10,7 @@ ) from qdax.environments.base_wrappers import QDEnv, StateDescriptorResetWrapper -from qdax.environments.bd_extractors import ( +from qdax.environments.descriptor_extractors import ( get_feet_contact_proportion, get_final_xy_position, ) @@ -42,7 +42,7 @@ "walker2d_uni": 1.413, } -behavior_descriptor_extractor = { +descriptor_extractor = { "pointmaze": get_final_xy_position, "anttrap": get_final_xy_position, "humanoidtrap": get_final_xy_position, diff --git a/qdax/environments/base_wrappers.py b/qdax/environments/base_wrappers.py index 3f709fa7..64ed9a56 100644 --- a/qdax/environments/base_wrappers.py +++ b/qdax/environments/base_wrappers.py @@ -28,12 +28,12 @@ def state_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: @property @abstractmethod - def behavior_descriptor_length(self) -> int: + def descriptor_length(self) -> int: pass @property @abstractmethod - def behavior_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: + def descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: pass @property @@ -76,12 +76,12 @@ def state_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: return self.env.state_descriptor_limits @property - def behavior_descriptor_length(self) -> int: - return self.env.behavior_descriptor_length + def descriptor_length(self) -> int: + return self.env.descriptor_length @property - def behavior_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: - return self.env.behavior_descriptor_limits + def descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: + return self.env.descriptor_limits @property def name(self) -> str: diff --git a/qdax/environments/bd_extractors.py b/qdax/environments/descriptor_extractors.py similarity index 95% rename from qdax/environments/bd_extractors.py rename to qdax/environments/descriptor_extractors.py index 8649b74c..72daf990 100644 --- a/qdax/environments/bd_extractors.py +++ b/qdax/environments/descriptor_extractors.py @@ -14,10 +14,10 @@ def get_final_xy_position(data: QDTransition, mask: jnp.ndarray) -> Descriptor: This function suppose that state descriptor is the xy position, as it just select the final one of the state descriptors given. """ - # reshape mask for bd extraction + # reshape mask for descriptor extraction mask = jnp.expand_dims(mask, axis=-1) - # Get behavior descriptor + # Get descriptor last_index = jnp.int32(jnp.sum(1.0 - mask, axis=1)) - 1 descriptors = jax.vmap(lambda x, y: x[y])(data.state_desc, last_index) @@ -31,10 +31,10 @@ def get_feet_contact_proportion(data: QDTransition, mask: jnp.ndarray) -> Descri This function suppose that state descriptor is the feet contact, as it just computes the mean of the state descriptors given. 
""" - # reshape mask for bd extraction + # reshape mask for descriptor extraction mask = jnp.expand_dims(mask, axis=-1) - # Get behavior descriptor + # Get descriptor descriptors = jnp.sum(data.state_desc * (1.0 - mask), axis=1) descriptors = descriptors / jnp.sum(1.0 - mask, axis=1) diff --git a/qdax/environments/locomotion_wrappers.py b/qdax/environments/locomotion_wrappers.py index 982f5b69..07916686 100644 --- a/qdax/environments/locomotion_wrappers.py +++ b/qdax/environments/locomotion_wrappers.py @@ -96,7 +96,7 @@ def __init__(self, env: Env, env_name: str): @property def state_descriptor_length(self) -> int: - return self.behavior_descriptor_length + return self.descriptor_length @property def state_descriptor_name(self) -> str: @@ -104,16 +104,16 @@ def state_descriptor_name(self) -> str: @property def state_descriptor_limits(self) -> Tuple[List, List]: - return self.behavior_descriptor_limits + return self.descriptor_limits @property - def behavior_descriptor_length(self) -> int: + def descriptor_length(self) -> int: return len(self._feet_contact_idx) @property - def behavior_descriptor_limits(self) -> Tuple[List, List]: - bd_length = self.behavior_descriptor_length - return (jnp.zeros((bd_length,)), jnp.ones((bd_length,))) + def descriptor_limits(self) -> Tuple[List, List]: + descriptor_length = self.descriptor_length + return (jnp.zeros((descriptor_length,)), jnp.ones((descriptor_length,))) @property def name(self) -> str: @@ -246,11 +246,11 @@ def state_descriptor_limits(self) -> Tuple[List[float], List[float]]: return self._minval, self._maxval @property - def behavior_descriptor_length(self) -> int: + def descriptor_length(self) -> int: return self.state_descriptor_length @property - def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]: + def descriptor_limits(self) -> Tuple[List[float], List[float]]: return self.state_descriptor_limits @property diff --git a/qdax/environments/pointmaze.py b/qdax/environments/pointmaze.py index b299864f..743c1516 100644 --- a/qdax/environments/pointmaze.py +++ b/qdax/environments/pointmaze.py @@ -6,7 +6,7 @@ class PointMaze(Env): - """Jax/Brax implementation of the PointMaze. + """JAX/Brax implementation of the PointMaze. Highly inspired from the old python implementation of the PointMaze. 
@@ -91,11 +91,11 @@ def state_descriptor_limits(self) -> Tuple[List[float], List[float]]: return [self._x_min, self._y_min], [self._x_max, self._y_max] @property - def behavior_descriptor_length(self) -> int: + def descriptor_length(self) -> int: return self.state_descriptor_length @property - def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]: + def descriptor_limits(self) -> Tuple[List[float], List[float]]: return self.state_descriptor_limits @property diff --git a/qdax/environments/wrappers.py b/qdax/environments/wrappers.py index babedaed..b1bb0363 100644 --- a/qdax/environments/wrappers.py +++ b/qdax/environments/wrappers.py @@ -22,10 +22,8 @@ def reset(self, rng: jp.ndarray) -> State: reset_state = self.env.reset(rng) reset_state.metrics["reward"] = reset_state.reward eval_metrics = CompletedEvalMetrics( - current_episode_metrics=jax.tree_util.tree_map( - jp.zeros_like, reset_state.metrics - ), - completed_episodes_metrics=jax.tree_util.tree_map( + current_episode_metrics=jax.tree.map(jp.zeros_like, reset_state.metrics), + completed_episodes_metrics=jax.tree.map( lambda x: jp.zeros_like(jp.sum(x)), reset_state.metrics ), completed_episodes=jp.zeros(()), @@ -46,16 +44,16 @@ def step(self, state: State, action: jp.ndarray) -> State: completed_episodes_steps = state_metrics.completed_episodes_steps + jp.sum( nstate.info["steps"] * nstate.done ) - current_episode_metrics = jax.tree_util.tree_map( + current_episode_metrics = jax.tree.map( lambda a, b: a + b, state_metrics.current_episode_metrics, nstate.metrics ) completed_episodes = state_metrics.completed_episodes + jp.sum(nstate.done) - completed_episodes_metrics = jax.tree_util.tree_map( + completed_episodes_metrics = jax.tree.map( lambda a, b: a + jp.sum(b * nstate.done), state_metrics.completed_episodes_metrics, current_episode_metrics, ) - current_episode_metrics = jax.tree_util.tree_map( + current_episode_metrics = jax.tree.map( lambda a, b: a * (1 - nstate.done) + b * nstate.done, current_episode_metrics, nstate.metrics, diff --git a/qdax/tasks/README.md b/qdax/tasks/README.md index d35c9125..76744870 100644 --- a/qdax/tasks/README.md +++ b/qdax/tasks/README.md @@ -19,7 +19,7 @@ Notes: import jax from qdax.tasks.arm import arm_scoring_function -random_key = jax.random.key(0) +key = jax.random.key(0) # Get scoring function scoring_fn = arm_scoring_function @@ -31,7 +31,7 @@ min_desc, max_desc = 0., 1. # Get initial batch of parameters num_param_dimensions = ... init_batch_size = ... -random_key, _subkey = jax.random.split(random_key) +key, _subkey = jax.random.split(key) initial_params = jax.random.uniform( _subkey, shape=(init_batch_size, num_param_dimensions), @@ -56,7 +56,7 @@ desc_size = 2 import jax from qdax.tasks.standard_functions import sphere_scoring_function -random_key = jax.random.key(0) +key = jax.random.key(0) # Get scoring function scoring_fn = sphere_scoring_function @@ -68,7 +68,7 @@ min_desc, max_desc = 0., 1. # Get initial batch of parameters num_param_dimensions = ... init_batch_size = ... -random_key, _subkey = jax.random.split(random_key) +key, _subkey = jax.random.split(key) initial_params = jax.random.uniform( _subkey, shape=(init_batch_size, num_param_dimensions), @@ -98,7 +98,7 @@ desc_size = 2 import jax from qdax.tasks.hypervolume_functions import square_scoring_function -random_key = jax.random.key(0) +key = jax.random.key(0) # Get scoring function scoring_fn = square_scoring_function @@ -110,7 +110,7 @@ min_desc, max_desc = 0., 1. 
# Get initial batch of parameters num_param_dimensions = ... init_batch_size = ... -random_key, _subkey = jax.random.split(random_key) +key, _subkey = jax.random.split(key) initial_params = jax.random.uniform( _subkey, shape=(init_batch_size, num_param_dimensions), @@ -127,7 +127,7 @@ desc_size = num_param_dimensions | Task | Parameter Dimensions | Parameter Bounds | Descriptor Dimensions | Descriptor Bounds | Description | |--------------------------------|----------------------|--------------------------------------------------------------------------------|---------------------------------------|-----------------------------------------------------------------------------|-------------| -| archimedean-spiral-v0 | 1 | $[0,\alpha\pi]^n$ (angle param.)
$[0,max-arc-length]$ (arc length param.) | 1 (geodesic BD)<br/> 2 (euclidean BD) | $[0,max-arc-length]$ (geodesic BD)<br/> $[-radius,radius]^2$ (euclidean BD) | |
+| archimedean-spiral-v0 | 1 | $[0,\alpha\pi]^n$ (angle param.)<br/> $[0,max-arc-length]$ (arc length param.) | 1 (geodesic descriptor)<br/> 2 (euclidean descriptor) | $[0,max-arc-length]$ (geodesic descriptor)<br/> $[-radius,radius]^2$ (euclidean descriptor) | |
| SSF-v0 | $n$ | Unbounded | 1 | $[ 0 ,$ ∞ $)$ | |
| deceptive-evolvability-v0
| $n$ (2 by default) | Bounded area including the two gaussian peaks | 1 | $[0,max-sum-gaussians]$ | | @@ -147,7 +147,7 @@ min_param, max_param = task.get_min_max_params() min_desc, max_desc = task.get_bounded_min_max_descriptor() # To consider bounded Descriptor space # If the task has a descriptor space that is not bounded, then the unbounded descriptor # space can be obtained via the following: -# min_bd, max_bd = task.get_min_max_bd() +# min_descriptor, max_descriptor = task.get_min_max_descriptor() # Get initial batch of parameters initial_params = task.get_initial_parameters(batch_size=...) diff --git a/qdax/tasks/arm.py b/qdax/tasks/arm.py index 27782cf3..92614d74 100644 --- a/qdax/tasks/arm.py +++ b/qdax/tasks/arm.py @@ -8,7 +8,7 @@ def arm(params: Genotype) -> Tuple[Fitness, Descriptor]: """ - Compute the fitness and BD of one individual in the Planar Arm task. + Compute the fitness and descriptor of one individual in the Planar Arm task. Based on the Planar Arm implementation in fast_map_elites (https://github.com/hucebot/fast_map-elites). @@ -19,9 +19,9 @@ def arm(params: Genotype) -> Tuple[Fitness, Descriptor]: Returns: f: the fitness of the individual, given as the variance of the angles. - bd: the bd of the individual, given as the [x, y] position of the - end-effector of the arm. - BD is normalized to [0, 1] regardless of the num of DoF. + descriptor: the descriptor of the individual, given as the [x, y] position + of the end-effector of the arm. + Descriptor is normalized to [0, 1] regardless of the DoF. Arm is centered at 0.5, 0.5. """ @@ -40,33 +40,28 @@ def arm(params: Genotype) -> Tuple[Fitness, Descriptor]: def arm_scoring_function( params: Genotype, - random_key: RNGKey, -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + key: RNGKey, +) -> Tuple[Fitness, Descriptor, ExtraScores]: """ Evaluate policies contained in params in parallel. """ fitnesses, descriptors = jax.vmap(arm)(params) - return ( - fitnesses, - descriptors, - {}, - random_key, - ) + return fitnesses, descriptors, {} def noisy_arm_scoring_function( params: Genotype, - random_key: RNGKey, + key: RNGKey, fit_variance: float, desc_variance: float, params_variance: float, -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: +) -> Tuple[Fitness, Descriptor, ExtraScores]: """ Evaluate policies contained in params in parallel. """ - random_key, f_subkey, d_subkey, p_subkey = jax.random.split(random_key, num=4) + key, f_subkey, d_subkey, p_subkey = jax.random.split(key, num=4) # Add noise to the parameters params = params + jax.random.normal(p_subkey, shape=params.shape) * params_variance @@ -83,9 +78,4 @@ def noisy_arm_scoring_function( + jax.random.normal(d_subkey, shape=descriptors.shape) * desc_variance ) - return ( - fitnesses, - descriptors, - {}, - random_key, - ) + return fitnesses, descriptors, {} diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index ea928fba..25c89f3d 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -10,7 +10,7 @@ import qdax.environments from qdax import environments from qdax.core.neuroevolution.buffers.buffer import QDTransition, Transition -from qdax.core.neuroevolution.mdp_utils import generate_unroll, generate_unroll_actor_dc +from qdax.core.neuroevolution.mdp_utils import generate_unroll from qdax.core.neuroevolution.networks.networks import MLP from qdax.custom_types import ( Descriptor, @@ -34,8 +34,8 @@ def make_policy_network_play_step_fn_brax( Creates a function that when called, plays a step of the environment. 
Args: - env: The BRAX environment. - policy_network: The policy network structure used for creating and evaluating + env: The Brax environment. + policy_network: The policy network structure used for creating and evaluating policy controllers. Returns: @@ -46,19 +46,20 @@ def make_policy_network_play_step_fn_brax( def default_play_step_fn( env_state: EnvState, policy_params: Params, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: """ Play an environment step and return the updated EnvState and the transition. - Args: env_state: The state of the environment (containing for instance the - actor joint positions and velocities, the reward...). policy_params: The - parameters of policies/controllers. random_key: JAX random key. + Args: + env_state: The state of the environment (containing for instance the + actor joint positions and velocities, the reward...). + policy_params: The parameters of policies/controllers. key: JAX random key. Returns: - next_state: The updated environment state. + next_env_state: The updated environment state. policy_params: The parameters of policies/controllers (unchanged). - random_key: The updated random key. + key: The updated random key. transition: containing some information about the transition: observation, reward, next observation, policy action... """ @@ -66,20 +67,20 @@ def default_play_step_fn( actions = policy_network.apply(policy_params, env_state.obs) state_desc = env_state.info["state_descriptor"] - next_state = env.step(env_state, actions) + next_env_state = env.step(env_state, actions) transition = QDTransition( obs=env_state.obs, - next_obs=next_state.obs, - rewards=next_state.reward, - dones=next_state.done, + next_obs=next_env_state.obs, + rewards=next_env_state.reward, + dones=next_env_state.done, actions=actions, - truncations=next_state.info["truncation"], + truncations=next_env_state.info["truncation"], state_desc=state_desc, - next_state_desc=next_state.info["state_descriptor"], + next_state_desc=next_env_state.info["state_descriptor"], ) - return next_state, policy_params, random_key, transition + return next_env_state, policy_params, key, transition return default_play_step_fn @@ -93,168 +94,25 @@ def get_mask_from_transitions( return mask -@partial( - jax.jit, - static_argnames=( - "episode_length", - "play_step_fn", - "behavior_descriptor_extractor", - ), -) -def scoring_function_brax_envs( - policies_params: Genotype, - random_key: RNGKey, - init_states: EnvState, - episode_length: int, - play_step_fn: Callable[ - [EnvState, Params, RNGKey], Tuple[EnvState, Params, RNGKey, QDTransition] - ], - behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: - """Evaluates policies contained in policies_params in parallel in - deterministic or pseudo-deterministic environments. - - This rollout is only deterministic when all the init states are the same. - If the init states are fixed but different, as a policy is not necessarily - evaluated with the same environment every time, this won't be deterministic. - When the init states are different, this is not purely stochastic. - - Args: - policies_params: The parameters of closed-loop controllers/policies to evaluate. - random_key: A jax random key - episode_length: The maximal rollout length. - play_step_fn: The function to play a step of the environment. - behavior_descriptor_extractor: The function to extract the behavior descriptor. 
- - Returns: - fitness: Array of fitnesses of all evaluated policies - descriptor: Behavioural descriptors of all evaluated policies - extra_scores: Additional information resulting from evaluation - random_key: The updated random key. - """ - - # Perform rollouts with each policy - random_key, subkey = jax.random.split(random_key) - unroll_fn = partial( - generate_unroll, - episode_length=episode_length, - play_step_fn=play_step_fn, - random_key=subkey, - ) - - _final_state, data = jax.vmap(unroll_fn)(init_states, policies_params) - - # create a mask to extract data properly - mask = get_mask_from_transitions(data) - - # scores - fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) - descriptors = behavior_descriptor_extractor(data, mask) - - return ( - fitnesses, - descriptors, - { - "transitions": data, - }, - random_key, - ) - - -@partial( - jax.jit, - static_argnames=( - "episode_length", - "play_step_actor_dc_fn", - "behavior_descriptor_extractor", - ), -) -def scoring_actor_dc_function_brax_envs( - actors_dc_params: Genotype, - descs: Descriptor, - random_key: RNGKey, - init_states: EnvState, - episode_length: int, - play_step_actor_dc_fn: Callable[ - [EnvState, Descriptor, Params, RNGKey], - Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition], - ], - behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: - """Evaluates policies contained in policy_dc_params in parallel in - deterministic or pseudo-deterministic environments. - - This rollout is only deterministic when all the init states are the same. - If the init states are fixed but different, as a policy is not necessarily - evaluated with the same environment every time, this won't be determinist. - When the init states are different, this is not purely stochastic. - - Args: - policy_dc_params: The parameters of closed-loop - descriptor-conditioned policy to evaluate. - descriptors: The descriptors the - descriptor-conditioned policy attempts to achieve. - random_key: A jax random key - episode_length: The maximal rollout length. - play_step_fn: The function to play a step of the environment. - behavior_descriptor_extractor: The function to extract the behavior descriptor. - - Returns: - fitness: Array of fitnesses of all evaluated policies - descriptor: Behavioural descriptors of all evaluated policies - extra_scores: Additional information resulting from evaluation - random_key: The updated random key. 
- """ - - # Perform rollouts with each policy - random_key, subkey = jax.random.split(random_key) - unroll_fn = partial( - generate_unroll_actor_dc, - episode_length=episode_length, - play_step_actor_dc_fn=play_step_actor_dc_fn, - random_key=subkey, - ) - - _final_state, data = jax.vmap(unroll_fn)(init_states, actors_dc_params, descs) - - # create a mask to extract data properly - is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) - mask = jnp.roll(is_done, 1, axis=1) - mask = mask.at[:, 0].set(0) - - # Scores - add offset to ensure positive fitness (through positive rewards) - fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) - descriptors = behavior_descriptor_extractor(data, mask) - - return ( - fitnesses, - descriptors, - { - "transitions": data, - }, - random_key, - ) - - @partial( jax.jit, static_argnames=( "episode_length", "play_reset_fn", "play_step_fn", - "behavior_descriptor_extractor", + "descriptor_extractor", ), ) -def reset_based_scoring_function_brax_envs( +def scoring_function_brax_envs( policies_params: Genotype, - random_key: RNGKey, + key: RNGKey, episode_length: int, play_reset_fn: Callable[[RNGKey], EnvState], play_step_fn: Callable[ [EnvState, Params, RNGKey], Tuple[EnvState, Params, RNGKey, QDTransition] ], - behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], +) -> Tuple[Fitness, Descriptor, ExtraScores]: """Evaluates policies contained in policies_params in parallel. The play_reset_fn function allows for a more general scoring_function that can be called with different batch-size and not only with a batch-size of the same @@ -264,125 +122,52 @@ def reset_based_scoring_function_brax_envs( environment, use "play_reset_fn = env.reset". To define purely deterministic environments, as in "scoring_function", generate - a single init_state using "init_state = env.reset(random_key)", then use - "play_reset_fn = lambda random_key: init_state". + a single init_state using "init_state = env.reset(key)", then use + "play_reset_fn = lambda key: init_state". Args: policies_params: The parameters of closed-loop controllers/policies to evaluate. - random_key: A jax random key + key: A jax random key episode_length: The maximal rollout length. play_reset_fn: The function to reset the environment and obtain initial states. play_step_fn: The function to play a step of the environment. - behavior_descriptor_extractor: The function to extract the behavior descriptor. + descriptor_extractor: The function to extract the descriptor. Returns: fitness: Array of fitnesses of all evaluated policies descriptor: Behavioural descriptors of all evaluated policies extra_scores: Additional information resulting from the evaluation - random_key: The updated random key. 
""" - random_key, subkey = jax.random.split(random_key) - keys = jax.random.split( - subkey, jax.tree_util.tree_leaves(policies_params)[0].shape[0] - ) - reset_fn = jax.vmap(play_reset_fn) - init_states = reset_fn(keys) + # Reset environments + key, subkey = jax.random.split(key) + keys = jax.random.split(subkey, jax.tree.leaves(policies_params)[0].shape[0]) + init_states = jax.vmap(play_reset_fn)(keys) - fitnesses, descriptors, extra_scores, random_key = scoring_function_brax_envs( - policies_params=policies_params, - random_key=random_key, - init_states=init_states, + # Step environments + unroll_fn = partial( + generate_unroll, episode_length=episode_length, play_step_fn=play_step_fn, - behavior_descriptor_extractor=behavior_descriptor_extractor, ) + keys = jax.random.split(key, jax.tree.leaves(policies_params)[0].shape[0]) + _, data = jax.vmap(unroll_fn)(init_states, policies_params, keys) - return fitnesses, descriptors, extra_scores, random_key - - -@partial( - jax.jit, - static_argnames=( - "episode_length", - "play_reset_fn", - "play_step_actor_dc_fn", - "behavior_descriptor_extractor", - ), -) -def reset_based_scoring_actor_dc_function_brax_envs( - actors_dc_params: Genotype, - descs: Descriptor, - random_key: RNGKey, - episode_length: int, - play_reset_fn: Callable[[RNGKey], EnvState], - play_step_actor_dc_fn: Callable[ - [EnvState, Descriptor, Params, RNGKey], - Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition], - ], - behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: - """Evaluates policies contained in policy_dc_params in parallel. - The play_reset_fn function allows for a more general scoring_function that can be - called with different batch-size and not only with a batch-size of the same - dimension as init_states. - - To define purely stochastic environments, using the reset function from the - environment, use "play_reset_fn = env.reset". - - To define purely deterministic environments, as in "scoring_function", generate - a single init_state using "init_state = env.reset(random_key)", then use - "play_reset_fn = lambda random_key: init_state". - - Args: - policy_dc_params: The parameters of closed-loop - descriptor-conditioned policy to evaluate. - descriptors: The descriptors the - descriptor-conditioned policy attempts to achieve. - random_key: A jax random key - episode_length: The maximal rollout length. - play_reset_fn: The function to reset the environment - and obtain initial states. - play_step_fn: The function to play a step of the environment. - behavior_descriptor_extractor: The function to extract the behavior descriptor. - - Returns: - fitness: Array of fitnesses of all evaluated policies - descriptor: Behavioural descriptors of all evaluated policies - extra_scores: Additional information resulting from the evaluation - random_key: The updated random key. 
- """ + # Create a mask to extract data properly + mask = get_mask_from_transitions(data) - random_key, subkey = jax.random.split(random_key) - keys = jax.random.split( - subkey, jax.tree_util.tree_leaves(actors_dc_params)[0].shape[0] - ) - reset_fn = jax.vmap(play_reset_fn) - init_states = reset_fn(keys) - - ( - fitnesses, - descriptors, - extra_scores, - random_key, - ) = scoring_actor_dc_function_brax_envs( - actors_dc_params=actors_dc_params, - descs=descs, - random_key=random_key, - init_states=init_states, - episode_length=episode_length, - play_step_actor_dc_fn=play_step_actor_dc_fn, - behavior_descriptor_extractor=behavior_descriptor_extractor, - ) + # Evaluate + fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) + descriptors = descriptor_extractor(data, mask) - return fitnesses, descriptors, extra_scores, random_key + return fitnesses, descriptors, {"transitions": data} def create_brax_scoring_fn( env: brax.envs.Env, policy_network: nn.Module, - bd_extraction_fn: Callable[[QDTransition, jnp.ndarray], Descriptor], - random_key: RNGKey, + descriptor_extraction_fn: Callable[[QDTransition, jnp.ndarray], Descriptor], + key: RNGKey, play_step_fn: Optional[ Callable[ [EnvState, Params, RNGKey], Tuple[EnvState, Params, RNGKey, QDTransition] @@ -391,18 +176,15 @@ def create_brax_scoring_fn( episode_length: int = 100, deterministic: bool = True, play_reset_fn: Optional[Callable[[RNGKey], EnvState]] = None, -) -> Tuple[ - Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]], - RNGKey, -]: +) -> Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores]]: """ Creates a scoring function to evaluate a policy in a BRAX task. Args: - env: The BRAX environment. + env: The Brax environment. policy_network: The policy network controller. - bd_extraction_fn: The behaviour descriptor extraction function. - random_key: a random key used for stochastic operations. + descriptor_extraction_fn: The behaviour descriptor extraction function. + key: a random key used for stochastic operations. play_step_fn: the function used to perform environment rollouts and collect evaluation episodes. If None, we use make_policy_network_play_step_fn_brax to generate it. @@ -416,7 +198,6 @@ def create_brax_scoring_fn( Returns: The scoring function: a function that takes a batch of genotypes and compute their fitnesses and descriptors - The updated random key. """ if play_step_fn is None: play_step_fn = make_policy_network_play_step_fn_brax(env, policy_network) @@ -424,7 +205,7 @@ def create_brax_scoring_fn( # Deterministic case if deterministic: # Create the initial environment states - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) init_state = env.reset(subkey) # Define the function to deterministically reset the environment @@ -438,34 +219,33 @@ def deterministic_reset(_: RNGKey, _init_state: EnvState) -> EnvState: play_reset_fn = env.reset scoring_fn = functools.partial( - reset_based_scoring_function_brax_envs, + scoring_function_brax_envs, episode_length=episode_length, play_reset_fn=play_reset_fn, play_step_fn=play_step_fn, - behavior_descriptor_extractor=bd_extraction_fn, + descriptor_extractor=descriptor_extraction_fn, ) - return scoring_fn, random_key + return scoring_fn def create_default_brax_task_components( env_name: str, - random_key: RNGKey, + key: RNGKey, episode_length: int = 100, mlp_policy_hidden_layer_sizes: Tuple[int, ...] 
= (64, 64), deterministic: bool = True, ) -> Tuple[ brax.envs.Env, MLP, - Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]], - RNGKey, + Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores]], ]: """ Creates default environment, policy network and scoring function for a BRAX task. Args: env_name: Name of the BRAX environment (e.g. "ant_omni", "walker2d_uni"...). - random_key: Jax random key + key: JAX random key episode_length: The maximal rollout length. mlp_policy_hidden_layer_sizes: Hidden layer sizes of the policy network. deterministic: Whether we reset the initial state of the robot to the same @@ -477,7 +257,6 @@ def create_default_brax_task_components( policy controllers. scoring_fn: a function that takes a batch of genotypes and compute their fitnesses and descriptors. - random_key: The updated random key. """ env = environments.create(env_name, episode_length=episode_length) @@ -489,28 +268,24 @@ def create_default_brax_task_components( final_activation=jnp.tanh, ) - bd_extraction_fn = qdax.environments.behavior_descriptor_extractor[env_name] + descriptor_extraction_fn = qdax.environments.descriptor_extractor[env_name] - scoring_fn, random_key = create_brax_scoring_fn( + scoring_fn = create_brax_scoring_fn( env, policy_network, - bd_extraction_fn, - random_key, + descriptor_extraction_fn, + key, episode_length=episode_length, deterministic=deterministic, ) - return env, policy_network, scoring_fn, random_key + return env, policy_network, scoring_fn def get_aurora_scoring_fn( - scoring_fn: Callable[ - [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey] - ], + scoring_fn: Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores]], observation_extractor_fn: Callable[[Transition], Observation], -) -> Callable[ - [Genotype, RNGKey], Tuple[Fitness, Optional[Descriptor], ExtraScores, RNGKey] -]: +) -> Callable[[Genotype, RNGKey], Tuple[Fitness, Optional[Descriptor], ExtraScores]]: """Evaluates policies contained in flatten_variables in parallel This rollout is only deterministic when all the init states are the same. 
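Since `create_brax_scoring_fn` and `create_default_brax_task_components` no longer return an updated key, a caller-side sketch of the new pattern (the environment name and hyperparameters below are illustrative only) might be:

```python
import jax

from qdax.tasks.brax_envs import create_default_brax_task_components

key = jax.random.key(0)

key, subkey = jax.random.split(key)
env, policy_network, scoring_fn = create_default_brax_task_components(
    env_name="ant_omni",
    key=subkey,
    episode_length=100,
)

# scoring_fn now has the signature (genotypes, key) -> (fitnesses, descriptors, extra_scores)
```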
@@ -525,12 +300,12 @@ def get_aurora_scoring_fn( @functools.wraps(scoring_fn) def _wrapper( - params: Params, random_key: RNGKey # Perform rollouts with each policy - ) -> Tuple[Fitness, Optional[Descriptor], ExtraScores, RNGKey]: - fitnesses, _, extra_scores, random_key = scoring_fn(params, random_key) + params: Params, key: RNGKey # Perform rollouts with each policy + ) -> Tuple[Fitness, Optional[Descriptor], ExtraScores]: + fitnesses, _, extra_scores = scoring_fn(params, key) data = extra_scores["transitions"] observation = observation_extractor_fn(data) # type: ignore extra_scores["last_valid_observations"] = observation - return fitnesses, None, extra_scores, random_key + return fitnesses, None, extra_scores return _wrapper diff --git a/qdax/tasks/hypervolume_functions.py b/qdax/tasks/hypervolume_functions.py index bd9ac933..4d5f75d4 100644 --- a/qdax/tasks/hypervolume_functions.py +++ b/qdax/tasks/hypervolume_functions.py @@ -14,29 +14,29 @@ def square(params: Genotype) -> Tuple[Fitness, Descriptor]: """ Search space should be [0,1]^n - BD space should be [0,1]^n + Descriptor space should be [0,1]^n """ freq = 5 f = 1 - jnp.prod(params) - bd = jnp.sin(freq * params) - return f, bd + descriptor = jnp.sin(freq * params) + return f, descriptor def checkered(params: Genotype) -> Tuple[Fitness, Descriptor]: """ Search space should be [0,1]^n - BD space should be [0,1]^n + Descriptor space should be [0,1]^n """ freq = 5 f = jnp.prod(jnp.sin(params * 50)) - bd = jnp.sin(params * freq) - return f, bd + descriptor = jnp.sin(params * freq) + return f, descriptor def empty_circle(params: Genotype) -> Tuple[Fitness, Descriptor]: """ Search space should be [0,1]^n - BD space should be [0,1]^n + Descriptor space should be [0,1]^n """ def _gaussian(x: jnp.ndarray, mu: float, sig: float) -> jnp.ndarray: @@ -46,44 +46,45 @@ def _gaussian(x: jnp.ndarray, mu: float, sig: float) -> jnp.ndarray: centre = jnp.ones_like(params) * 0.5 distance_from_centre = jnp.linalg.norm(params - centre) f = _gaussian(distance_from_centre, mu=0.5, sig=0.3) - bd = jnp.sin(freq * params) - return f, bd + descriptor = jnp.sin(freq * params) + return f, descriptor def non_continous_islands(params: Genotype) -> Tuple[Fitness, Descriptor]: """ Search space should be [0,1]^n - BD space should be [0,1]^n + Descriptor space should be [0,1]^n """ f = jnp.prod(params) - bd = jnp.round(10 * params) / 10 - return f, bd + descriptor = jnp.round(10 * params) / 10 + return f, descriptor def continous_islands(params: Genotype) -> Tuple[Fitness, Descriptor]: """ Search space should be [0,1]^n - BD space should be [0,1]^n + Descriptor space should be [0,1]^n """ coeff = 20 f = jnp.prod(params) - bd = params - jnp.sin(coeff * jnp.pi * params) / (coeff * jnp.pi) - return f, bd + descriptor = params - jnp.sin(coeff * jnp.pi * params) / (coeff * jnp.pi) + return f, descriptor def get_scoring_function( task_fn: Callable[[Genotype], Tuple[Fitness, Descriptor]] -) -> Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]]: +) -> Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores]]: + def scoring_function( params: Genotype, - random_key: RNGKey, - ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Fitness, Descriptor, ExtraScores]: """ Evaluate params in parallel """ fitnesses, descriptors = jax.vmap(task_fn)(params) - return (fitnesses, descriptors, {}, random_key) + return fitnesses, descriptors, {} return scoring_function diff --git a/qdax/tasks/jumanji_envs.py 
b/qdax/tasks/jumanji_envs.py index 68f2409c..ed4c4580 100644 --- a/qdax/tasks/jumanji_envs.py +++ b/qdax/tasks/jumanji_envs.py @@ -47,7 +47,7 @@ def default_play_step_fn( env_state: JumanjiState, timestep: JumanjiTimeStep, policy_params: Params, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[JumanjiState, JumanjiTimeStep, Params, RNGKey, QDTransition]: """Play an environment step and return the updated state and the transition. Everything is deterministic in this simple example. @@ -75,7 +75,7 @@ def default_play_step_fn( next_state_desc=next_state_desc, ) - return next_state, next_timestep, policy_params, random_key, transition + return next_state, next_timestep, policy_params, key, transition return default_play_step_fn @@ -85,7 +85,7 @@ def generate_jumanji_unroll( init_state: JumanjiState, init_timestep: JumanjiTimeStep, policy_params: Params, - random_key: RNGKey, + key: RNGKey, episode_length: int, play_step_fn: Callable[ [JumanjiState, JumanjiTimeStep, Params, RNGKey], @@ -104,7 +104,7 @@ def generate_jumanji_unroll( Args: init_state: first state of the rollout. policy_params: params of the individual. - random_key: random key for stochasiticity handling. + key: random key for stochasiticity handling. episode_length: length of the rollout. play_step_fn: function describing how a step need to be taken. @@ -115,14 +115,12 @@ def generate_jumanji_unroll( def _scan_play_step_fn( carry: Tuple[JumanjiState, JumanjiTimeStep, Params, RNGKey], unused_arg: Any ) -> Tuple[Tuple[JumanjiState, JumanjiTimeStep, Params, RNGKey], Transition]: - env_state, timestep, policy_params, random_key, transitions = play_step_fn( - *carry - ) - return (env_state, timestep, policy_params, random_key), transitions + env_state, timestep, policy_params, key, transitions = play_step_fn(*carry) + return (env_state, timestep, policy_params, key), transitions (state, timestep, _, _), transitions = jax.lax.scan( _scan_play_step_fn, - (init_state, init_timestep, policy_params, random_key), + (init_state, init_timestep, policy_params, key), (), length=episode_length, ) @@ -134,12 +132,12 @@ def _scan_play_step_fn( static_argnames=( "episode_length", "play_step_fn", - "behavior_descriptor_extractor", + "descriptor_extractor", ), ) def jumanji_scoring_function( policies_params: Genotype, - random_key: RNGKey, + key: RNGKey, init_states: JumanjiState, init_timesteps: JumanjiTimeStep, episode_length: int, @@ -147,8 +145,8 @@ def jumanji_scoring_function( [JumanjiState, JumanjiTimeStep, Params, RNGKey, jumanji.env.Environment], Tuple[JumanjiState, JumanjiTimeStep, Params, RNGKey, QDTransition], ], - behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], +) -> Tuple[Fitness, Descriptor, ExtraScores]: """Evaluates policies contained in policies_params in parallel in deterministic or pseudo-deterministic environments. @@ -158,33 +156,22 @@ def jumanji_scoring_function( When the init states are different, this is not purely stochastic. 
""" - # Perform rollouts with each policy - random_key, subkey = jax.random.split(random_key) + # Step environments unroll_fn = partial( generate_jumanji_unroll, episode_length=episode_length, play_step_fn=play_step_fn, - random_key=subkey, - ) - - _final_state, _final_timestep, data = jax.vmap(unroll_fn)( - init_states, init_timesteps, policies_params ) + keys = jax.random.split(key, jax.tree.leaves(policies_params)[0].shape[0]) + _, _, data = jax.vmap(unroll_fn)(init_states, init_timesteps, policies_params, keys) - # create a mask to extract data properly + # Create a mask to extract data properly is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) mask = jnp.roll(is_done, 1, axis=1) mask = mask.at[:, 0].set(0) - # scores + # Evaluate fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) - descriptors = behavior_descriptor_extractor(data, mask) - - return ( - fitnesses, - descriptors, - { - "transitions": data, - }, - random_key, - ) + descriptors = descriptor_extractor(data, mask) + + return fitnesses, descriptors, {"transitions": data} diff --git a/qdax/tasks/qd_suite/__init__.py b/qdax/tasks/qd_suite/__init__.py index ae3de3d6..34ddde00 100644 --- a/qdax/tasks/qd_suite/__init__.py +++ b/qdax/tasks/qd_suite/__init__.py @@ -1,5 +1,5 @@ from qdax.tasks.qd_suite.archimedean_spiral import ( - ArchimedeanBD, + ArchimedeanDescriptor, ArchimedeanSpiralV0, ParameterizationGenotype, ) @@ -8,19 +8,19 @@ archimedean_spiral_v0_angle_euclidean_task = ArchimedeanSpiralV0( ParameterizationGenotype.angle, - ArchimedeanBD.euclidean, + ArchimedeanDescriptor.euclidean, ) archimedean_spiral_v0_angle_geodesic_task = ArchimedeanSpiralV0( ParameterizationGenotype.angle, - ArchimedeanBD.geodesic, + ArchimedeanDescriptor.geodesic, ) archimedean_spiral_v0_arc_length_euclidean_task = ArchimedeanSpiralV0( ParameterizationGenotype.arc_length, - ArchimedeanBD.euclidean, + ArchimedeanDescriptor.euclidean, ) archimedean_spiral_v0_arc_length_geodesic_task = ArchimedeanSpiralV0( ParameterizationGenotype.arc_length, - ArchimedeanBD.geodesic, + ArchimedeanDescriptor.geodesic, ) deceptive_evolvability_v0_task = DeceptiveEvolvabilityV0() ssf_v0_param_size_1_task = SsfV0(param_size=1) diff --git a/qdax/tasks/qd_suite/archimedean_spiral.py b/qdax/tasks/qd_suite/archimedean_spiral.py index 59108ae5..7f1fba99 100644 --- a/qdax/tasks/qd_suite/archimedean_spiral.py +++ b/qdax/tasks/qd_suite/archimedean_spiral.py @@ -13,7 +13,7 @@ class ParameterizationGenotype(Enum): arc_length = "arc_length" -class ArchimedeanBD(Enum): +class ArchimedeanDescriptor(Enum): euclidean = "euclidean" geodesic = "geodesic" @@ -22,7 +22,7 @@ class ArchimedeanSpiralV0(QDSuiteTask): def __init__( self, parameterization: ParameterizationGenotype, - archimedean_bd: ArchimedeanBD, + archimedean_descriptor: ArchimedeanDescriptor, amplitude: float = 0.01, precision: Optional[float] = None, alpha: float = 40.0, @@ -34,15 +34,15 @@ def __init__( Args: parameterization: The parameterization of the genotype, can be either angle or arc length. - archimedean_bd: The Archimedean BD, can be either euclidean or - geodesic. + archimedean_descriptor: The Archimedean descriptor, can be either euclidean + or geodesic. amplitude: The amplitude of the Archimedean spiral. precision: The precision of the approximation of the angle from the arc length. alpha: controls the length/maximal angle of the Archimedean spiral. 
""" self.parameterization = parameterization - self.archimedean_bd = archimedean_bd + self.archimedean_descriptor = archimedean_descriptor self.amplitude = amplitude if precision is None: self.precision = alpha * jnp.pi / 1e7 @@ -158,34 +158,34 @@ def evaluation(self, params: Genotype) -> Tuple[Fitness, Descriptor]: constant_fitness = jnp.asarray(1.0) if ( - self.archimedean_bd == ArchimedeanBD.geodesic + self.archimedean_descriptor == ArchimedeanDescriptor.geodesic and self.parameterization == ParameterizationGenotype.arc_length ): arc_length = params return constant_fitness, arc_length elif ( - self.archimedean_bd == ArchimedeanBD.geodesic + self.archimedean_descriptor == ArchimedeanDescriptor.geodesic and self.parameterization == ParameterizationGenotype.angle ): angle = params arc_length = self.get_arc_length(angle) return constant_fitness, arc_length elif ( - self.archimedean_bd == ArchimedeanBD.euclidean + self.archimedean_descriptor == ArchimedeanDescriptor.euclidean and self.parameterization == ParameterizationGenotype.arc_length ): arc_length = params angle = self._approximate_angle_from_arc_length(arc_length[0]) - euclidean_bd = self._gamma(angle) - return constant_fitness, euclidean_bd + euclidean_descriptor = self._gamma(angle) + return constant_fitness, euclidean_descriptor elif ( - self.archimedean_bd == ArchimedeanBD.euclidean + self.archimedean_descriptor == ArchimedeanDescriptor.euclidean and self.parameterization == ParameterizationGenotype.angle ): angle = params return constant_fitness, self._gamma(angle) else: - raise ValueError("Invalid parameterization and/or BD") + raise ValueError("Invalid parameterization and/or Descriptor") def get_descriptor_size(self) -> int: """ @@ -194,12 +194,12 @@ def get_descriptor_size(self) -> int: Returns: The size of the descriptor. """ - if self.archimedean_bd == ArchimedeanBD.euclidean: + if self.archimedean_descriptor == ArchimedeanDescriptor.euclidean: return 2 - elif self.archimedean_bd == ArchimedeanBD.geodesic: + elif self.archimedean_descriptor == ArchimedeanDescriptor.geodesic: return 1 else: - raise ValueError("Invalid BD") + raise ValueError("Invalid descriptor") def get_min_max_descriptor(self) -> Tuple[float, float]: """ @@ -212,13 +212,13 @@ def get_min_max_descriptor(self) -> Tuple[float, float]: max_angle = self.alpha * jnp.pi max_norm = jnp.linalg.norm(self._gamma(max_angle)) - if self.archimedean_bd == ArchimedeanBD.euclidean: + if self.archimedean_descriptor == ArchimedeanDescriptor.euclidean: return -max_norm, max_norm - elif self.archimedean_bd == ArchimedeanBD.geodesic: + elif self.archimedean_descriptor == ArchimedeanDescriptor.geodesic: max_arc_length = self.get_arc_length(max_angle) return 0.0, max_arc_length.item() else: - raise ValueError("Invalid BD") + raise ValueError("Invalid descriptor") def get_min_max_params(self) -> Tuple[float, float]: """ diff --git a/qdax/tasks/qd_suite/deceptive_evolvability.py b/qdax/tasks/qd_suite/deceptive_evolvability.py index 830ad523..3be2d5c1 100644 --- a/qdax/tasks/qd_suite/deceptive_evolvability.py +++ b/qdax/tasks/qd_suite/deceptive_evolvability.py @@ -79,11 +79,11 @@ def evaluation(self, params: Genotype) -> Tuple[Fitness, Descriptor]: Returns: The fitness and descriptor. 
""" - bd = multivariate_normal( + descriptor = multivariate_normal( params, self.mu_1, self.sigma_1 ) + self.beta * multivariate_normal(params, self.mu_2, self.sigma_2) constant_fitness = jnp.asarray(1.0) - return constant_fitness, bd + return constant_fitness, descriptor def get_saddle_point(self) -> Genotype: """ diff --git a/qdax/tasks/qd_suite/qd_suite_task.py b/qdax/tasks/qd_suite/qd_suite_task.py index 0d79317f..fde6fc4a 100644 --- a/qdax/tasks/qd_suite/qd_suite_task.py +++ b/qdax/tasks/qd_suite/qd_suite_task.py @@ -25,14 +25,14 @@ def evaluation(self, params: Genotype) -> Tuple[Fitness, Descriptor]: def scoring_function( self, params: Genotype, - random_key: RNGKey, - ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + key: RNGKey, + ) -> Tuple[Fitness, Descriptor, ExtraScores]: """ Evaluate params in parallel """ fitnesses, descriptors = jax.vmap(self.evaluation)(params) - return fitnesses, descriptors, {}, random_key + return fitnesses, descriptors, {} @abc.abstractmethod def get_descriptor_size(self) -> int: @@ -64,13 +64,13 @@ def get_bounded_min_max_descriptor( The minimum and maximum descriptor assuming that the descriptor space is bounded. """ - min_bd, max_bd = self.get_min_max_descriptor() - if jnp.isinf(max_bd) or jnp.isinf(min_bd): + min_descriptor, max_descriptor = self.get_min_max_descriptor() + if jnp.isinf(max_descriptor) or jnp.isinf(min_descriptor): raise NotImplementedError( "Boundedness has not been implemented " "for this unbounded task" ) else: - return min_bd, max_bd + return min_descriptor, max_descriptor @abc.abstractmethod def get_min_max_params( diff --git a/qdax/tasks/qd_suite/ssf.py b/qdax/tasks/qd_suite/ssf.py index 601aa6ad..06713a94 100644 --- a/qdax/tasks/qd_suite/ssf.py +++ b/qdax/tasks/qd_suite/ssf.py @@ -39,10 +39,10 @@ def evaluation( norm = jnp.linalg.norm(params, ord=2) r_2k_plus_1, _, k = self._get_k(params) index = jnp.floor(norm / r_2k_plus_1) - bd = jax.lax.cond(index == 0.0, lambda: norm, lambda: r_2k_plus_1) + descriptor = jax.lax.cond(index == 0.0, lambda: norm, lambda: r_2k_plus_1) constant_fitness = jnp.asarray(1.0) - bd = jnp.asarray(bd).reshape((self.get_descriptor_size(),)) - return constant_fitness, bd + descriptor = jnp.asarray(descriptor).reshape((self.get_descriptor_size(),)) + return constant_fitness, descriptor def get_descriptor_size(self) -> int: """ diff --git a/qdax/tasks/standard_functions.py b/qdax/tasks/standard_functions.py index 82b2f875..de850c1b 100644 --- a/qdax/tasks/standard_functions.py +++ b/qdax/tasks/standard_functions.py @@ -8,7 +8,7 @@ def rastrigin(params: Genotype) -> Tuple[Fitness, Descriptor]: """ - 2-D BD + 2-D descriptor """ x = params * 10 - 5 # scaling to [-5, 5] f = jnp.asarray(10.0 * x.shape[0]) + jnp.sum(x * x - 10 * jnp.cos(2 * jnp.pi * x)) @@ -17,7 +17,7 @@ def rastrigin(params: Genotype) -> Tuple[Fitness, Descriptor]: def sphere(params: Genotype) -> Tuple[Fitness, Descriptor]: """ - 2-D BD + 2-D descriptor """ x = params * 10 - 5 # scaling to [-5, 5] f = (x * x).sum() @@ -26,26 +26,26 @@ def sphere(params: Genotype) -> Tuple[Fitness, Descriptor]: def rastrigin_scoring_function( params: Genotype, - random_key: RNGKey, -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + key: RNGKey, +) -> Tuple[Fitness, Descriptor, ExtraScores]: """ Scoring function for the rastrigin function """ fitnesses, descriptors = jax.vmap(rastrigin)(params) - return fitnesses, descriptors, {}, random_key + return fitnesses, descriptors, {} def sphere_scoring_function( params: Genotype, - random_key: RNGKey, -) -> 
Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + key: RNGKey, +) -> Tuple[Fitness, Descriptor, ExtraScores]: """ Scoring function for the sphere function """ fitnesses, descriptors = jax.vmap(sphere)(params) - return fitnesses, descriptors, {}, random_key + return fitnesses, descriptors, {} def _rastrigin_proj_scoring( @@ -105,8 +105,8 @@ def rastrigin_descriptors(x: jnp.ndarray) -> jnp.ndarray: def rastrigin_proj_scoring_function( - params: Genotype, random_key: RNGKey, minval: float = -5.12, maxval: float = 5.12 -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + params: Genotype, key: RNGKey, minval: float = -5.12, maxval: float = 5.12 +) -> Tuple[Fitness, Descriptor, ExtraScores]: """ Scoring function for the rastrigin function with a folding of the behaviour space. @@ -117,4 +117,4 @@ def rastrigin_proj_scoring_function( _rastrigin_proj_scoring, in_axes=(0, None, None) )(params, minval, maxval) - return fitnesses, descriptors, extra_scores, random_key + return fitnesses, descriptors, extra_scores diff --git a/qdax/utils/plotting.py b/qdax/utils/plotting.py index 8320ba89..fa2c965a 100644 --- a/qdax/utils/plotting.py +++ b/qdax/utils/plotting.py @@ -199,8 +199,8 @@ def plot_2d_map_elites_repertoire( ) # aesthetic - ax.set_xlabel("Behavior Dimension 1") - ax.set_ylabel("Behavior Dimension 2") + ax.set_xlabel("Descriptor Dimension 1") + ax.set_ylabel("Descriptor Dimension 2") divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.05) cbar = plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=my_cmap), cax=cax) @@ -216,8 +216,8 @@ def plot_map_elites_results( env_steps: jnp.ndarray, metrics: Dict, repertoire: MapElitesRepertoire, - min_bd: jnp.ndarray, - max_bd: jnp.ndarray, + min_descriptor: jnp.ndarray, + max_descriptor: jnp.ndarray, ) -> Tuple[Optional[Figure], Axes]: """Plots three usual QD metrics, namely the coverage, the maximum fitness and the QD-score, along the number of environment steps. This function also @@ -229,8 +229,8 @@ def plot_map_elites_results( env_steps: the array containing the number of steps done in the environment. metrics: a dictionary containing metrics from the optimizatoin process. repertoire: the final repertoire obtained. - min_bd: the minimal possible values for the bd. - max_bd: the maximal possible values for the bd. + min_descriptor: the minimal possible values for the descriptor. + max_descriptor: the maximal possible values for the descriptor. Returns: A figure and axes with the plots of the metrics and visualisation of the grid. 
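With the plotting helpers now taking `min_descriptor`/`max_descriptor`, a call-site sketch would be as below; `env_steps`, `metrics` and `repertoire` are assumed to come from an earlier MAP-Elites run and are not defined here.

```python
from qdax.utils.plotting import plot_map_elites_results

fig, axes = plot_map_elites_results(
    env_steps=env_steps,
    metrics=metrics,
    repertoire=repertoire,
    min_descriptor=min_descriptor,
    max_descriptor=max_descriptor,
)
```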
@@ -275,8 +275,8 @@ def plot_map_elites_results( _, axes = plot_2d_map_elites_repertoire( centroids=repertoire.centroids, repertoire_fitnesses=repertoire.fitnesses, - minval=min_bd, - maxval=max_bd, + minval=min_descriptor, + maxval=max_descriptor, repertoire_descriptors=repertoire.descriptors, ax=axes[3], ) @@ -361,8 +361,8 @@ def plot_skills_trajectory( # set aesthetics ax.set_ylim(min_values[1], max_values[1]) ax.set_xlim(min_values[0], max_values[0]) - ax.set_xlabel("Behavior Dimension 1") - ax.set_ylabel("Behavior Dimension 2") + ax.set_xlabel("Descriptor Dimension 1") + ax.set_ylabel("Descriptor Dimension 2") ax.set_aspect("equal") divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.05) @@ -705,8 +705,8 @@ def plot_multidimensional_map_elites_grid( ) # aesthetic - ax.set_xlabel("Behavior Dimension 1") - ax.set_ylabel("Behavior Dimension 2") + ax.set_xlabel("Descriptor Dimension 1") + ax.set_ylabel("Descriptor Dimension 2") divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.05) norm = Normalize(vmin=vmin, vmax=vmax) diff --git a/qdax/utils/sampling.py b/qdax/utils/sampling.py index fb0d37d3..6984cb80 100644 --- a/qdax/utils/sampling.py +++ b/qdax/utils/sampling.py @@ -115,22 +115,16 @@ def dummy_extra_scores_extractor( return extra_scores -@partial( - jax.jit, - static_argnames=( - "scoring_fn", - "num_samples", - ), -) +@partial(jax.jit, static_argnames=("scoring_fn", "num_samples")) def multi_sample_scoring_function( policies_params: Genotype, - random_key: RNGKey, + key: RNGKey, scoring_fn: Callable[ [Genotype, RNGKey], - Tuple[Fitness, Descriptor, ExtraScores, RNGKey], + Tuple[Fitness, Descriptor, ExtraScores], ], num_samples: int, -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: +) -> Tuple[Fitness, Descriptor, ExtraScores]: """ Wrap scoring_function to perform sampling. @@ -139,19 +133,16 @@ def multi_sample_scoring_function( Args: policies_params: policies to evaluate - random_key: JAX random key + key: JAX random key scoring_fn: scoring function used for evaluation num_samples: number of samples to generate for each individual Returns: (n, num_samples) array of fitnesses, (n, num_samples, num_descriptors) array of descriptors, - dict with num_samples extra_scores per individual, - JAX random key + dict with num_samples extra_scores per individual """ - - random_key, subkey = jax.random.split(random_key) - keys = jax.random.split(subkey, num=num_samples) + keys = jax.random.split(key, num=num_samples) # evaluate sample_scoring_fn = jax.vmap( @@ -160,13 +151,13 @@ def multi_sample_scoring_function( in_axes=(None, 0), # indicates that the vectorized axis will become axis 1, i.e., the final # output is shape (batch_size, num_samples, ...) 
except for the random key - out_axes=(1, 1, 1, 0), + out_axes=(1, 1, 1), ) - all_fitnesses, all_descriptors, all_extra_scores, _ = sample_scoring_fn( + all_fitnesses, all_descriptors, all_extra_scores = sample_scoring_fn( policies_params, keys ) - return all_fitnesses, all_descriptors, all_extra_scores, random_key + return all_fitnesses, all_descriptors, all_extra_scores @partial( @@ -181,10 +172,10 @@ def multi_sample_scoring_function( ) def sampling( policies_params: Genotype, - random_key: RNGKey, + key: RNGKey, scoring_fn: Callable[ [Genotype, RNGKey], - Tuple[Fitness, Descriptor, ExtraScores, RNGKey], + Tuple[Fitness, Descriptor, ExtraScores], ], num_samples: int, extra_scores_extractor: Callable[ @@ -192,7 +183,7 @@ def sampling( ] = dummy_extra_scores_extractor, fitness_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average, descriptor_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average, -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: +) -> Tuple[Fitness, Descriptor, ExtraScores]: """Wrap scoring_function to perform sampling. This function return the expected fitnesses and descriptors for each @@ -201,7 +192,7 @@ def sampling( Args: policies_params: policies to evaluate - random_key: JAX random key + key: JAX random key scoring_fn: scoring function used for evaluation num_samples: number of samples to generate for each individual extra_scores_extractor: function to extract the extra_scores from @@ -221,17 +212,14 @@ def sampling( all_fitnesses, all_descriptors, all_extra_scores, - random_key, - ) = multi_sample_scoring_function( - policies_params, random_key, scoring_fn, num_samples - ) + ) = multi_sample_scoring_function(policies_params, key, scoring_fn, num_samples) # Extract final scores descriptors = descriptor_extractor(all_descriptors) fitnesses = fitness_extractor(all_fitnesses) extra_scores = extra_scores_extractor(all_extra_scores, num_samples) - return fitnesses, descriptors, extra_scores, random_key + return fitnesses, descriptors, extra_scores @partial( @@ -248,7 +236,7 @@ def sampling( ) def sampling_reproducibility( policies_params: Genotype, - random_key: RNGKey, + key: RNGKey, scoring_fn: Callable[ [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey], @@ -261,7 +249,7 @@ def sampling_reproducibility( descriptor_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average, fitness_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray] = std, descriptor_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray] = std, -) -> Tuple[Fitness, Descriptor, ExtraScores, Fitness, Descriptor, RNGKey]: +) -> Tuple[Fitness, Descriptor, ExtraScores, Fitness, Descriptor]: """Wrap scoring_function to perform sampling and compute the expectation and reproducibility. @@ -271,7 +259,7 @@ def sampling_reproducibility( Args: policies_params: policies to evaluate - random_key: JAX random key + key: JAX random key scoring_fn: scoring function used for evaluation num_samples: number of samples to generate for each individual extra_scores_extractor: function to extract the extra_scores from @@ -287,8 +275,7 @@ def sampling_reproducibility( Returns: The expected fitnesses, descriptors and extra_scores of the individuals - The fitnesses and descriptors reproducibility of the individuals - A new random key + The fitnesses and descriptors reproducibility of the individuals. 
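`sampling` keeps its wrapping role but, like the scoring functions, no longer returns a key. A sketch of wrapping a scoring function of the new form follows; `scoring_fn`, `genotypes` and `key` are assumed to exist from earlier steps, and `average` is the module's default extractor.

```python
import functools

import jax

from qdax.utils.sampling import average, sampling

# Wrap any scoring function of the form (genotypes, key) -> (fitnesses, descriptors, extra_scores)
sampled_scoring_fn = functools.partial(
    sampling,
    scoring_fn=scoring_fn,
    num_samples=8,
    fitness_extractor=average,
    descriptor_extractor=average,
)

key, subkey = jax.random.split(key)
fitnesses, descriptors, extra_scores = sampled_scoring_fn(genotypes, subkey)
```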
""" # Perform sampling @@ -296,10 +283,7 @@ def sampling_reproducibility( all_fitnesses, all_descriptors, all_extra_scores, - random_key, - ) = multi_sample_scoring_function( - policies_params, random_key, scoring_fn, num_samples - ) + ) = multi_sample_scoring_function(policies_params, key, scoring_fn, num_samples) # Extract final scores descriptors = descriptor_extractor(all_descriptors) @@ -316,5 +300,4 @@ def sampling_reproducibility( extra_scores, fitnesses_reproducibility, descriptors_reproducibility, - random_key, ) diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py index fa7825b0..aac3102a 100644 --- a/qdax/utils/train_seq2seq.py +++ b/qdax/utils/train_seq2seq.py @@ -17,7 +17,7 @@ from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire from qdax.core.neuroevolution.networks.seq2seq_networks import Seq2seq from qdax.custom_types import Params, RNGKey -from qdax.environments.bd_extractors import AuroraExtraInfoNormalization +from qdax.environments.descriptor_extractors import AuroraExtraInfoNormalization def get_model( @@ -37,17 +37,17 @@ def get_model( def get_initial_params( - model: Seq2seq, random_key: RNGKey, encoder_input_shape: Tuple[int, ...] + model: Seq2seq, key: RNGKey, encoder_input_shape: Tuple[int, ...] ) -> Dict[str, Any]: """ Returns the initial parameters of a seq2seq model. Args: model: the seq2seq model. - random_key: the random number generator. + key: the random number generator. encoder_input_shape: the shape of the encoder input. """ - random_key, rng1, rng2, rng3 = jax.random.split(random_key, 4) + key, rng1, rng2, rng3 = jax.random.split(key, 4) variables = model.init( {"params": rng1, "lstm": rng2, "dropout": rng3}, jnp.ones(encoder_input_shape, jnp.float32), @@ -60,7 +60,7 @@ def get_initial_params( def train_step( state: train_state.TrainState, batch: jax.Array, - lstm_random_key: RNGKey, + key: RNGKey, ) -> Tuple[train_state.TrainState, Dict[str, float]]: """ Trains for one step. @@ -68,12 +68,12 @@ def train_step( Args: state: the training state. batch: the batch of data. - lstm_random_key: the random number key. + key: a random key. """ """Trains one step.""" - lstm_key = jax.random.fold_in(lstm_random_key, state.step) - dropout_key, lstm_key = jax.random.split(lstm_key, 2) + key = jax.random.fold_in(key, state.step) + dropout_key, lstm_key = jax.random.split(key) # Shift input by one to avoid leakage batch_decoder = jnp.roll(batch, shift=1, axis=1) @@ -109,7 +109,7 @@ def mean_squared_error(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: def lstm_ae_train( - random_key: RNGKey, + key: RNGKey, repertoire: UnstructuredRepertoire, params: Params, epoch: int, @@ -150,7 +150,7 @@ def lstm_ae_train( # select repertoire_size indexes going from 0 to the total number of # valid individuals. Those indexes will be used to select the individuals # in the training dataset. 
- random_key, key_select_p1 = jax.random.split(random_key, 2) + key, key_select_p1 = jax.random.split(key, 2) idx_p1 = jax.random.randint( key_select_p1, shape=(repertoire_size,), minval=0, maxval=num_indivs ) @@ -174,7 +174,7 @@ def lstm_ae_train( loss_val = 0.0 for epoch in range(num_epochs): - random_key, shuffle_key = jax.random.split(random_key, 2) + key, shuffle_key = jax.random.split(key, 2) valid_indexes = jax.random.permutation(shuffle_key, valid_indexes, axis=0) # create dataset with the observation from the sample of valid indexes @@ -194,7 +194,7 @@ def lstm_ae_train( # print(batch.shape) continue - state, loss_val = train_step(state, batch, random_key) + state, loss_val = train_step(state, batch, key) # To see the actual value we cannot jit this function (i.e. the _one_es_epoch # function nor the train function) diff --git a/qdax/utils/uncertainty_metrics.py b/qdax/utils/uncertainty_metrics.py index 62a976b1..1fb2ec0c 100644 --- a/qdax/utils/uncertainty_metrics.py +++ b/qdax/utils/uncertainty_metrics.py @@ -27,11 +27,11 @@ ) def reevaluation_function( repertoire: MapElitesRepertoire, - random_key: RNGKey, + key: RNGKey, empty_corrected_repertoire: MapElitesRepertoire, scoring_fn: Callable[ [Genotype, RNGKey], - Tuple[Fitness, Descriptor, ExtraScores, RNGKey], + Tuple[Fitness, Descriptor, ExtraScores], ], num_reevals: int, fitness_extractor: Callable[[jnp.ndarray], jnp.ndarray] = median, @@ -40,7 +40,7 @@ def reevaluation_function( [ExtraScores, int], ExtraScores ] = dummy_extra_scores_extractor, scan_size: int = 0, -) -> Tuple[MapElitesRepertoire, RNGKey]: +) -> MapElitesRepertoire: """ Perform reevaluation of a repertoire and construct a corrected repertoire from it. @@ -48,7 +48,7 @@ def reevaluation_function( repertoire: repertoire to reevaluate. empty_corrected_repertoire: repertoire to be filled with reevaluated solutions, allow to use a different type of repertoire than the one from the algorithm. - random_key: JAX random key. + key: JAX random key. scoring_fn: scoring function used for evaluation. num_reevals: number of samples to generate for each individual. fitness_extractor: function to extract the final fitness from @@ -60,22 +60,22 @@ def reevaluation_function( scan_size: allow to split the reevaluations in multiple batch to reduce the memory load of the reevaluation. Returns: - The corrected repertoire and a random key. + The corrected repertoire. 
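Accordingly, `reevaluation_function` now returns only the corrected repertoire. A call-site sketch, under the assumption that `repertoire`, `empty_corrected_repertoire` and `scoring_fn` already exist, and with an arbitrary number of reevaluations:

```python
import jax

from qdax.utils.uncertainty_metrics import reevaluation_function

key, subkey = jax.random.split(key)
corrected_repertoire = reevaluation_function(
    repertoire=repertoire,
    key=subkey,
    empty_corrected_repertoire=empty_corrected_repertoire,
    scoring_fn=scoring_fn,
    num_reevals=8,
)
```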
""" # If no reevaluations, return copies of the original container if num_reevals == 0: - return repertoire, random_key + return repertoire # Perform reevaluation + key, subkey = jax.random.split(key) ( all_fitnesses, all_descriptors, all_extra_scores, - random_key, ) = _perform_reevaluation( policies_params=repertoire.genotypes, - random_key=random_key, + key=subkey, scoring_fn=scoring_fn, num_reevals=num_reevals, scan_size=scan_size, @@ -97,7 +97,7 @@ def reevaluation_function( batch_of_extra_scores=extra_scores, ) - return corrected_repertoire, random_key + return corrected_repertoire # type: ignore @partial( @@ -115,11 +115,11 @@ def reevaluation_function( ) def reevaluation_reproducibility_function( repertoire: MapElitesRepertoire, - random_key: RNGKey, + key: RNGKey, empty_corrected_repertoire: MapElitesRepertoire, scoring_fn: Callable[ [Genotype, RNGKey], - Tuple[Fitness, Descriptor, ExtraScores, RNGKey], + Tuple[Fitness, Descriptor, ExtraScores], ], num_reevals: int, fitness_extractor: Callable[[jnp.ndarray], jnp.ndarray] = median, @@ -130,7 +130,7 @@ def reevaluation_reproducibility_function( [ExtraScores, int], ExtraScores ] = dummy_extra_scores_extractor, scan_size: int = 0, -) -> Tuple[MapElitesRepertoire, MapElitesRepertoire, MapElitesRepertoire, RNGKey]: +) -> Tuple[MapElitesRepertoire, MapElitesRepertoire, MapElitesRepertoire]: """ Perform reevaluation of a repertoire and construct a corrected repertoire and a reproducibility repertoire from it. @@ -139,7 +139,7 @@ def reevaluation_reproducibility_function( repertoire: repertoire to reevaluate. empty_corrected_repertoire: repertoire to be filled with reevaluated solutions, allow to use a different type of repertoire than the one from the algorithm. - random_key: JAX random key. + key: JAX random key. scoring_fn: scoring function used for evaluation. num_reevals: number of samples to generate for each individual. fitness_extractor: function to extract the final fitness from @@ -158,7 +158,6 @@ def reevaluation_reproducibility_function( The corrected repertoire. A repertoire storing reproducibility in fitness. A repertoire storing reproducibility in descriptor. - A random key. """ # If no reevaluations, return copies of the original container @@ -167,7 +166,6 @@ def reevaluation_reproducibility_function( repertoire, repertoire, repertoire, - random_key, ) # Perform reevaluation @@ -175,10 +173,9 @@ def reevaluation_reproducibility_function( all_fitnesses, all_descriptors, all_extra_scores, - random_key, ) = _perform_reevaluation( policies_params=repertoire.genotypes, - random_key=random_key, + key=key, scoring_fn=scoring_fn, num_reevals=num_reevals, scan_size=scan_size, @@ -231,7 +228,6 @@ def reevaluation_reproducibility_function( corrected_repertoire, fit_reproducibility_repertoire, desc_reproducibility_repertoire, - random_key, ) @@ -245,39 +241,38 @@ def reevaluation_reproducibility_function( ) def _perform_reevaluation( policies_params: Genotype, - random_key: RNGKey, + key: RNGKey, scoring_fn: Callable[ [Genotype, RNGKey], - Tuple[Fitness, Descriptor, ExtraScores, RNGKey], + Tuple[Fitness, Descriptor, ExtraScores], ], num_reevals: int, scan_size: int = 0, -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: +) -> Tuple[Fitness, Descriptor, ExtraScores]: """ Sub-function used to perform reevaluation of a repertoire in uncertain applications. Args: policies_params: genotypes to reevaluate. - random_key: JAX random key. + key: JAX random key. scoring_fn: scoring function used for evaluation. 
num_reevals: number of samples to generate for each individual. scan_size: allow to split the reevaluations in multiple batch to reduce the memory load of the reevaluation. Returns: - The fitnesses, descriptors and extra score from the reevaluation, - and a randon key. + The fitnesses, descriptors and extra score from the reevaluation. """ # If no need for scan, call the sampling function if scan_size == 0: + key, subkey = jax.random.split(key) ( all_fitnesses, all_descriptors, all_extra_scores, - random_key, ) = multi_sample_scoring_function( policies_params=policies_params, - random_key=random_key, + key=subkey, scoring_fn=scoring_fn, num_samples=num_reevals, ) @@ -292,32 +287,32 @@ def _perform_reevaluation( num_loops = num_reevals // scan_size def _sampling_scan( - random_key: RNGKey, - unused: Tuple[()], + key: RNGKey, + _: Tuple[()], ) -> Tuple[Tuple[RNGKey], Tuple[Fitness, Descriptor, ExtraScores]]: + key, subkey = jax.random.split(key) ( all_fitnesses, all_descriptors, all_extra_scores, - random_key, ) = multi_sample_scoring_function( policies_params=policies_params, - random_key=random_key, + key=subkey, scoring_fn=scoring_fn, num_samples=scan_size, ) - return (random_key), ( + return (key), ( all_fitnesses, all_descriptors, all_extra_scores, ) - (random_key), ( + (key), ( all_fitnesses, all_descriptors, all_extra_scores, - ) = jax.lax.scan(_sampling_scan, (random_key), (), length=num_loops) + ) = jax.lax.scan(_sampling_scan, (key), (), length=num_loops) all_fitnesses = jnp.hstack(all_fitnesses) all_descriptors = jnp.hstack(all_descriptors) - return all_fitnesses, all_descriptors, all_extra_scores, random_key + return all_fitnesses, all_descriptors, all_extra_scores diff --git a/tests/baselines_test/cmame_test.py b/tests/baselines_test/cmame_test.py index 82d7e54a..f8e8a9cd 100644 --- a/tests/baselines_test/cmame_test.py +++ b/tests/baselines_test/cmame_test.py @@ -32,8 +32,8 @@ def test_cma_me(emitter_type: Type[CMAEmitter]) -> None: sigma_g = 0.5 minval = -5.12 maxval = 5.12 - min_bd = -5.12 * 0.5 * num_dimensions - max_bd = 5.12 * 0.5 * num_dimensions + min_descriptor = -5.12 * 0.5 * num_dimensions + max_descriptor = 5.12 * 0.5 * num_dimensions pool_size = 3 def sphere_scoring(x: jnp.ndarray) -> jnp.ndarray: @@ -45,24 +45,24 @@ def clip(x: jnp.ndarray) -> jnp.ndarray: in_bound = (x <= maxval) * (x >= minval) return jnp.where(in_bound, x, (maxval / x)) - def _behavior_descriptor_1(x: jnp.ndarray) -> jnp.ndarray: + def _descriptor_1(x: jnp.ndarray) -> jnp.ndarray: return jnp.sum(clip(x[: x.shape[-1] // 2])) - def _behavior_descriptor_2(x: jnp.ndarray) -> jnp.ndarray: + def _descriptor_2(x: jnp.ndarray) -> jnp.ndarray: return jnp.sum(clip(x[x.shape[-1] // 2 :])) - def _behavior_descriptors(x: jnp.ndarray) -> jnp.ndarray: - return jnp.array([_behavior_descriptor_1(x), _behavior_descriptor_2(x)]) + def _descriptors(x: jnp.ndarray) -> jnp.ndarray: + return jnp.array([_descriptor_1(x), _descriptor_2(x)]) def scoring_function(x: jnp.ndarray) -> Tuple[Fitness, Descriptor, Dict]: - scores, descriptors = fitness_scoring(x), _behavior_descriptors(x) + scores, descriptors = fitness_scoring(x), _descriptors(x) return scores, descriptors, {} def scoring_fn( - x: jnp.ndarray, random_key: RNGKey - ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + x: jnp.ndarray, key: RNGKey + ) -> Tuple[Fitness, Descriptor, ExtraScores]: fitnesses, descriptors, extra_scores = jax.vmap(scoring_function)(x) - return fitnesses, descriptors, extra_scores, random_key + return fitnesses, descriptors, 
extra_scores worst_objective = fitness_scoring(-jnp.ones(num_dimensions) * 5.12) best_objective = fitness_scoring(jnp.ones(num_dimensions) * 5.12 * 0.4) @@ -81,15 +81,17 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: max_fitness = jnp.max(adjusted_fitness) return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} - random_key = jax.random.key(0) + key = jax.random.key(0) + + key, subkey = jax.random.split(key) initial_population = ( - jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.0 + jax.random.uniform(subkey, shape=(batch_size, num_dimensions)) * 0.0 ) centroids = compute_euclidean_centroids( grid_shape=grid_shape, - minval=min_bd, - maxval=max_bd, + minval=min_descriptor, + maxval=max_descriptor, ) emitter_kwargs = { @@ -109,17 +111,16 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: scoring_function=scoring_fn, emitter=emitter, metrics_function=metrics_fn ) - repertoire, emitter_state, random_key = map_elites.init( - initial_population, centroids, random_key - ) + key, subkey = jax.random.split(key) + repertoire, emitter_state = map_elites.init(initial_population, centroids, subkey) ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( map_elites.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/baselines_test/cmamega_test.py b/tests/baselines_test/cmamega_test.py index 02fdad13..2d569d77 100644 --- a/tests/baselines_test/cmamega_test.py +++ b/tests/baselines_test/cmamega_test.py @@ -64,21 +64,21 @@ def scoring_function(x: jnp.ndarray) -> Tuple[Fitness, Descriptor, ExtraScores]: gradients = jnp.nan_to_num(gradients) # Compute normalized gradients - norm_gradients = jax.tree_util.tree_map( + norm_gradients = jax.tree.map( lambda x: jnp.linalg.norm(x, axis=1, keepdims=True), gradients, ) - grads = jax.tree_util.tree_map(lambda x, y: x / y, gradients, norm_gradients) + grads = jax.tree.map(lambda x, y: x / y, gradients, norm_gradients) grads = jnp.nan_to_num(grads) extra_scores = {"gradients": gradients, "normalized_grads": grads} return scores, descriptors, extra_scores def scoring_fn( - x: jnp.ndarray, random_key: RNGKey - ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + x: jnp.ndarray, key: RNGKey + ) -> Tuple[Fitness, Descriptor, ExtraScores]: fitnesses, descriptors, extra_scores = jax.vmap(scoring_function)(x) - return fitnesses, descriptors, extra_scores, random_key + return fitnesses, descriptors, extra_scores worst_objective = rastrigin_scoring(-jnp.ones(num_dimensions) * 5.12) best_objective = rastrigin_scoring(jnp.ones(num_dimensions) * 5.12 * 0.4) @@ -95,18 +95,19 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: max_fitness = jnp.max(adjusted_fitness) return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} - random_key = jax.random.key(0) - initial_population = jax.random.uniform( - random_key, shape=(batch_size, num_dimensions) - ) + key = jax.random.key(0) + + key, subkey = jax.random.split(key) + initial_population = jax.random.uniform(subkey, shape=(batch_size, num_dimensions)) - centroids, random_key = compute_cvt_centroids( + key, subkey = jax.random.split(key) + centroids = compute_cvt_centroids( num_descriptors=2, num_init_cvt_samples=10000, num_centroids=num_centroids, minval=minval, maxval=maxval, - random_key=random_key, + key=subkey, ) emitter = CMAMEGAEmitter( @@ -121,17 +122,17 @@ def 
metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: map_elites = MAPElites( scoring_function=scoring_fn, emitter=emitter, metrics_function=metrics_fn ) - repertoire, emitter_state, random_key = map_elites.init( - initial_population, centroids, random_key - ) + + key, subkey = jax.random.split(key) + repertoire, emitter_state = map_elites.init(initial_population, centroids, subkey) ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( map_elites.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/baselines_test/dads_smerl_test.py b/tests/baselines_test/dads_smerl_test.py index 2ecc25f6..e3633e9d 100644 --- a/tests/baselines_test/dads_smerl_test.py +++ b/tests/baselines_test/dads_smerl_test.py @@ -73,14 +73,16 @@ def test_dads_smerl() -> None: ) key = jax.random.key(seed) - env_state = jax.jit(env.reset)(rng=key) - eval_env_first_state = jax.jit(eval_env.reset)(rng=key) + + key, subkey_1, subkey_2 = jax.random.split(key, 3) + env_state = jax.jit(env.reset)(rng=subkey_1) + eval_env_first_state = jax.jit(eval_env.reset)(rng=subkey_2) # Initialize buffer dummy_transition = QDTransition.init_dummy( observation_dim=env.observation_size + num_skills, action_dim=env.action_size, - descriptor_dim=env.behavior_descriptor_length, + descriptor_dim=env.descriptor_length, ) replay_buffer = TrajectoryBuffer.init( buffer_size=buffer_size, @@ -92,7 +94,7 @@ def test_dads_smerl() -> None: if descriptor_full_state: descriptor_size = env.observation_size else: - descriptor_size = env.behavior_descriptor_length + descriptor_size = env.descriptor_length dads_smerl_config = DadsSmerlConfig( # SAC config @@ -110,7 +112,7 @@ def test_dads_smerl() -> None: # DADS config num_skills=num_skills, descriptor_full_state=descriptor_full_state, - omit_input_dynamics_dim=env.behavior_descriptor_length, + omit_input_dynamics_dim=env.descriptor_length, dynamics_update_freq=dynamics_update_freq, normalize_target=normalize_target, # SMERL config @@ -123,8 +125,9 @@ def test_dads_smerl() -> None: action_size=env.action_size, descriptor_size=env.state_descriptor_length, ) + key, subkey = jax.random.split(key) training_state = dads_smerl.init( - key, + subkey, action_size=env.action_size, observation_size=env.observation_size, descriptor_size=descriptor_size, diff --git a/tests/baselines_test/dads_test.py b/tests/baselines_test/dads_test.py index 76da834e..1bfee7ec 100644 --- a/tests/baselines_test/dads_test.py +++ b/tests/baselines_test/dads_test.py @@ -67,14 +67,16 @@ def test_dads() -> None: ) key = jax.random.key(seed) - env_state = jax.jit(env.reset)(rng=key) - eval_env_first_state = jax.jit(eval_env.reset)(rng=key) + + key, subkey_1, subkey_2 = jax.random.split(key, 3) + env_state = jax.jit(env.reset)(rng=subkey_1) + eval_env_first_state = jax.jit(eval_env.reset)(rng=subkey_2) # Initialize buffer dummy_transition = QDTransition.init_dummy( observation_dim=env.observation_size + num_skills, action_dim=env.action_size, - descriptor_dim=env.behavior_descriptor_length, + descriptor_dim=env.descriptor_length, ) replay_buffer = ReplayBuffer.init( buffer_size=buffer_size, transition=dummy_transition @@ -83,7 +85,7 @@ def test_dads() -> None: if descriptor_full_state: descriptor_size = env.observation_size else: - descriptor_size = env.behavior_descriptor_length + descriptor_size = env.descriptor_length dads_config = DadsConfig( # SAC config @@ -101,7 +103,7 @@ def test_dads() -> None: # DADS config 
num_skills=num_skills, descriptor_full_state=descriptor_full_state, - omit_input_dynamics_dim=env.behavior_descriptor_length, + omit_input_dynamics_dim=env.descriptor_length, dynamics_update_freq=dynamics_update_freq, normalize_target=normalize_target, ) @@ -110,8 +112,9 @@ def test_dads() -> None: action_size=env.action_size, descriptor_size=descriptor_size, ) + key, subkey = jax.random.split(key) training_state = dads.init( - key, + subkey, action_size=env.action_size, observation_size=env.observation_size, descriptor_size=descriptor_size, diff --git a/tests/baselines_test/dcrlme_test.py b/tests/baselines_test/dcrlme_test.py index 942abd67..3d0c72de 100644 --- a/tests/baselines_test/dcrlme_test.py +++ b/tests/baselines_test/dcrlme_test.py @@ -13,9 +13,9 @@ from qdax.core.neuroevolution.buffers.buffer import DCRLTransition from qdax.core.neuroevolution.networks.networks import MLP, MLPDC from qdax.custom_types import EnvState, Params, RNGKey -from qdax.environments import behavior_descriptor_extractor +from qdax.environments import descriptor_extractor from qdax.environments.wrappers import ClipRewardWrapper, OffsetRewardWrapper -from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs +from qdax.tasks.brax_envs import scoring_function_brax_envs from qdax.utils.metrics import default_qd_metrics @@ -61,7 +61,7 @@ def test_dcrlme() -> None: policy_delay = 2 # Init a random key - random_key = jax.random.key(seed) + key = jax.random.key(seed) # Init environment env = environments.create(env_name, episode_length=episode_length) @@ -76,13 +76,14 @@ def test_dcrlme() -> None: reset_fn = jax.jit(env.reset) # Compute the centroids - centroids, random_key = compute_cvt_centroids( - num_descriptors=env.behavior_descriptor_length, + key, subkey = jax.random.split(key) + centroids = compute_cvt_centroids( + num_descriptors=env.descriptor_length, num_init_cvt_samples=num_init_cvt_samples, num_centroids=num_centroids, minval=min_bd, maxval=max_bd, - random_key=random_key, + key=subkey, ) # Init policy network @@ -99,14 +100,14 @@ def test_dcrlme() -> None: ) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jax.random.split(subkey, num=batch_size) fake_batch_obs = jnp.zeros(shape=(batch_size, env.observation_size)) init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs) # Define the function to play a step with the policy in the environment def play_step_fn( - env_state: EnvState, policy_params: Params, random_key: RNGKey + env_state: EnvState, policy_params: Params, key: RNGKey ) -> Tuple[EnvState, Params, RNGKey, DCRLTransition]: actions = policy_network.apply(policy_params, env_state.obs) state_desc = env_state.info["state_descriptor"] @@ -122,25 +123,25 @@ def play_step_fn( state_desc=state_desc, next_state_desc=next_state.info["state_descriptor"], desc=jnp.zeros( - env.behavior_descriptor_length, + env.descriptor_length, ) * jnp.nan, desc_prime=jnp.zeros( - env.behavior_descriptor_length, + env.descriptor_length, ) * jnp.nan, ) - return next_state, policy_params, random_key, transition + return next_state, policy_params, key, transition # Prepare the scoring function - bd_extraction_fn = behavior_descriptor_extractor[env_name] + bd_extraction_fn = descriptor_extractor[env_name] scoring_fn = functools.partial( - reset_based_scoring_function_brax_envs, + scoring_function_brax_envs, episode_length=episode_length, play_reset_fn=reset_fn, play_step_fn=play_step_fn, - 
behavior_descriptor_extractor=bd_extraction_fn, + descriptor_extractor=bd_extraction_fn, ) # Get minimum reward value to make sure qd_score are positive @@ -195,26 +196,29 @@ def play_step_fn( ) # compute initial repertoire - repertoire, emitter_state, random_key = map_elites.init( - init_params, centroids, random_key - ) + key, subkey = jax.random.split(key) + repertoire, emitter_state = map_elites.init(init_params, centroids, subkey) @jax.jit - def update_scan_fn(carry: Any, unused: Any) -> Any: + def update_scan_fn(carry: Any, key: RNGKey) -> Any: + repertoire, emitter_state = carry + # iterate over grid - repertoire, emitter_state, metrics, random_key = map_elites.update(*carry) + repertoire, emitter_state, metrics = map_elites.update( + repertoire, emitter_state, key + ) - return (repertoire, emitter_state, random_key), metrics + return (repertoire, emitter_state), metrics # Run the algorithm + keys = jax.random.split(key, num=num_iterations) ( repertoire, emitter_state, - random_key, ), metrics = jax.lax.scan( update_scan_fn, - (repertoire, emitter_state, random_key), - (), + (repertoire, emitter_state), + keys, length=num_iterations, ) diff --git a/tests/baselines_test/diayn_smerl_test.py b/tests/baselines_test/diayn_smerl_test.py index f06a4298..87613264 100644 --- a/tests/baselines_test/diayn_smerl_test.py +++ b/tests/baselines_test/diayn_smerl_test.py @@ -70,14 +70,16 @@ def test_diayn_smerl() -> None: ) key = jax.random.key(seed) - env_state = jax.jit(env.reset)(rng=key) - eval_env_first_state = jax.jit(eval_env.reset)(rng=key) + + key, subkey_1, subkey_2 = jax.random.split(key, 3) + env_state = jax.jit(env.reset)(rng=subkey_1) + eval_env_first_state = jax.jit(eval_env.reset)(rng=subkey_2) # Initialize buffer dummy_transition = QDTransition.init_dummy( observation_dim=env.observation_size + num_skills, action_dim=env.action_size, - descriptor_dim=env.behavior_descriptor_length, + descriptor_dim=env.descriptor_length, ) replay_buffer = TrajectoryBuffer.init( buffer_size=buffer_size, @@ -89,7 +91,7 @@ def test_diayn_smerl() -> None: if descriptor_full_state: descriptor_size = env.observation_size else: - descriptor_size = env.behavior_descriptor_length + descriptor_size = env.descriptor_length diayn_smerl_config = DiaynSmerlConfig( # SAC config @@ -114,8 +116,9 @@ def test_diayn_smerl() -> None: ) diayn_smerl = DIAYNSMERL(config=diayn_smerl_config, action_size=env.action_size) + key, subkey = jax.random.split(key) training_state = diayn_smerl.init( - key, + subkey, action_size=env.action_size, observation_size=env.observation_size, descriptor_size=descriptor_size, diff --git a/tests/baselines_test/diayn_test.py b/tests/baselines_test/diayn_test.py index 856e0174..fb75024b 100644 --- a/tests/baselines_test/diayn_test.py +++ b/tests/baselines_test/diayn_test.py @@ -63,14 +63,16 @@ def test_diayn() -> None: ) key = jax.random.key(seed) - env_state = jax.jit(env.reset)(rng=key) - eval_env_first_state = jax.jit(eval_env.reset)(rng=key) + + key, subkey_1, subkey_2 = jax.random.split(key, 3) + env_state = jax.jit(env.reset)(rng=subkey_1) + eval_env_first_state = jax.jit(eval_env.reset)(rng=subkey_2) # Initialize buffer dummy_transition = QDTransition.init_dummy( observation_dim=env.observation_size + num_skills, action_dim=env.action_size, - descriptor_dim=env.behavior_descriptor_length, + descriptor_dim=env.descriptor_length, ) replay_buffer = ReplayBuffer.init( buffer_size=buffer_size, transition=dummy_transition @@ -99,10 +101,11 @@ def test_diayn() -> None: if descriptor_full_state: 
descriptor_size = env.observation_size else: - descriptor_size = env.behavior_descriptor_length + descriptor_size = env.descriptor_length + key, subkey = jax.random.split(key) training_state = diayn.init( - key, + subkey, action_size=env.action_size, observation_size=env.observation_size, descriptor_size=descriptor_size, diff --git a/tests/baselines_test/ga_test.py b/tests/baselines_test/ga_test.py index 619c76e2..5fb8a28c 100644 --- a/tests/baselines_test/ga_test.py +++ b/tests/baselines_test/ga_test.py @@ -64,15 +64,14 @@ def rastrigin_scorer( scoring_function = partial(rastrigin_scorer, base_lag=base_lag, lag=lag) - def scoring_fn( - genotypes: jnp.ndarray, random_key: RNGKey - ) -> Tuple[Fitness, ExtraScores, RNGKey]: + def scoring_fn(genotypes: jnp.ndarray, key: RNGKey) -> Tuple[Fitness, ExtraScores]: fitnesses, _ = scoring_function(genotypes) - return fitnesses, {}, random_key + return fitnesses, {} # initial population - random_key = jax.random.key(42) - random_key, subkey = jax.random.split(random_key) + key = jax.random.key(42) + + key, subkey = jax.random.split(key) genotypes = jax.random.uniform( subkey, (batch_size, genotype_dim), @@ -109,23 +108,24 @@ def scoring_fn( metrics_function=default_ga_metrics, ) + key, subkey = jax.random.split(key) if isinstance(algo_instance, SPEA2): - repertoire, emitter_state, random_key = algo_instance.init( - genotypes, population_size, num_neighbours, random_key + repertoire, emitter_state = algo_instance.init( + genotypes, population_size, num_neighbours, subkey ) else: - repertoire, emitter_state, random_key = algo_instance.init( - genotypes, population_size, random_key + repertoire, emitter_state = algo_instance.init( + genotypes, population_size, subkey ) # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( algo_instance.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/baselines_test/me_pbt_sac_test.py b/tests/baselines_test/me_pbt_sac_test.py index c8dcc5af..c9afbe0f 100644 --- a/tests/baselines_test/me_pbt_sac_test.py +++ b/tests/baselines_test/me_pbt_sac_test.py @@ -67,9 +67,10 @@ def test_me_pbt_sac() -> None: episode_length=episode_length, auto_reset=True, ) - min_bd, max_bd = env.behavior_descriptor_limits + min_descriptor, max_descriptor = env.descriptor_limits key = jax.random.key(seed) + key, subkey = jax.random.split(key) eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey) @@ -114,19 +115,23 @@ def test_me_pbt_sac() -> None: ) # get scoring function - bd_extraction_fn = environments.behavior_descriptor_extractor[env_name] - eval_policy = agent.get_eval_qd_fn(eval_env, bd_extraction_fn=bd_extraction_fn) + descriptor_extraction_fn = environments.descriptor_extractor[env_name] + eval_policy = agent.get_eval_qd_fn( + eval_env, descriptor_extraction_fn=descriptor_extraction_fn + ) - def scoring_function(genotypes, random_key): # type: ignore - population_size = jax.tree_util.tree_leaves(genotypes)[0].shape[0] - first_states = jax.tree_util.tree_map( + def scoring_function(genotypes, key): # type: ignore + population_size = jax.tree.leaves(genotypes)[0].shape[0] + first_states = jax.tree.map( lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states ) - first_states = jax.tree_util.tree_map( + first_states = jax.tree.map( lambda x: jnp.repeat(x, population_size, axis=0), first_states ) - population_returns, population_bds, _, _ = eval_policy(genotypes, first_states) - return 
population_returns, population_bds, {}, random_key + population_returns, population_descriptors, _, _ = eval_policy( + genotypes, first_states + ) + return population_returns, population_descriptors, {} # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] @@ -144,13 +149,14 @@ def scoring_function(genotypes, random_key): # type: ignore metrics_function=metrics_function, ) - centroids, key = compute_cvt_centroids( - num_descriptors=env.behavior_descriptor_length, + key, subkey = jax.random.split(key) + centroids = compute_cvt_centroids( + num_descriptors=env.descriptor_length, num_init_cvt_samples=num_init_cvt_samples, num_centroids=num_centroids, - minval=min_bd, - maxval=max_bd, - random_key=key, + minval=min_descriptor, + maxval=max_descriptor, + key=subkey, ) key, *keys = jax.random.split(key, num=1 + num_devices) @@ -166,9 +172,7 @@ def scoring_function(genotypes, random_key): # type: ignore # Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647 keys = jax.random.key_data(keys) - keys, training_states, _ = jax.pmap(agent_init_fn, axis_name="p", devices=devices)( - keys - ) + training_states, _ = jax.pmap(agent_init_fn, axis_name="p", devices=devices)(keys) # empty optimizers states to avoid too heavy repertories training_states = jax.pmap( @@ -178,10 +182,10 @@ def scoring_function(genotypes, random_key): # type: ignore )(training_states) # initialize map-elites - repertoire, emitter_state, keys = map_elites.get_distributed_init_fn( + repertoire, emitter_state = map_elites.get_distributed_init_fn( devices=devices, centroids=centroids )( - genotypes=training_states, random_key=keys + genotypes=training_states, key=keys ) # type: ignore update_fn = map_elites.get_distributed_update_fn(num_iterations=1, devices=devices) @@ -189,17 +193,15 @@ def scoring_function(genotypes, random_key): # type: ignore initial_metrics = jax.pmap(metrics_function, axis_name="p", devices=devices)( repertoire ) - initial_metrics_cpu = jax.tree_util.tree_map( + initial_metrics_cpu = jax.tree.map( lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], initial_metrics ) initial_qd_score = initial_metrics_cpu["qd_score"] for _ in range(num_loops): - repertoire, emitter_state, keys, metrics = update_fn( - repertoire, emitter_state, keys - ) - metrics_cpu = jax.tree_util.tree_map( + repertoire, emitter_state, metrics = update_fn(repertoire, emitter_state, keys) + metrics_cpu = jax.tree.map( lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], metrics ) diff --git a/tests/baselines_test/me_pbt_td3_test.py b/tests/baselines_test/me_pbt_td3_test.py index f243725e..bd2a7a16 100644 --- a/tests/baselines_test/me_pbt_td3_test.py +++ b/tests/baselines_test/me_pbt_td3_test.py @@ -67,9 +67,10 @@ def test_me_pbt_td3() -> None: episode_length=episode_length, auto_reset=True, ) - min_bd, max_bd = env.behavior_descriptor_limits + min_descriptor, max_descriptor = env.descriptor_limits key = jax.random.key(seed) + key, subkey = jax.random.split(key) eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey) @@ -112,19 +113,23 @@ def test_me_pbt_td3() -> None: ) # get scoring function - bd_extraction_fn = environments.behavior_descriptor_extractor[env_name] - eval_policy = agent.get_eval_qd_fn(eval_env, bd_extraction_fn=bd_extraction_fn) + descriptor_extraction_fn = environments.descriptor_extractor[env_name] + eval_policy = agent.get_eval_qd_fn( + eval_env, descriptor_extraction_fn=descriptor_extraction_fn + ) - def scoring_function(genotypes, 
random_key): # type: ignore - population_size = jax.tree_util.tree_leaves(genotypes)[0].shape[0] - first_states = jax.tree_util.tree_map( + def scoring_function(genotypes, key): # type: ignore + population_size = jax.tree.leaves(genotypes)[0].shape[0] + first_states = jax.tree.map( lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states ) - first_states = jax.tree_util.tree_map( + first_states = jax.tree.map( lambda x: jnp.repeat(x, population_size, axis=0), first_states ) - population_returns, population_bds, _, _ = eval_policy(genotypes, first_states) - return population_returns, population_bds, {}, random_key + population_returns, population_descriptors, _, _ = eval_policy( + genotypes, first_states + ) + return population_returns, population_descriptors, {} # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] @@ -142,13 +147,14 @@ def scoring_function(genotypes, random_key): # type: ignore metrics_function=metrics_function, ) - centroids, key = compute_cvt_centroids( - num_descriptors=env.behavior_descriptor_length, + key, subkey = jax.random.split(key) + centroids = compute_cvt_centroids( + num_descriptors=env.descriptor_length, num_init_cvt_samples=num_init_cvt_samples, num_centroids=num_centroids, - minval=min_bd, - maxval=max_bd, - random_key=key, + minval=min_descriptor, + maxval=max_descriptor, + key=subkey, ) key, *keys = jax.random.split(key, num=1 + num_devices) @@ -164,9 +170,7 @@ def scoring_function(genotypes, random_key): # type: ignore # Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647 keys = jax.random.key_data(keys) - keys, training_states, _ = jax.pmap(agent_init_fn, axis_name="p", devices=devices)( - keys - ) + training_states, _ = jax.pmap(agent_init_fn, axis_name="p", devices=devices)(keys) # empty optimizers states to avoid too heavy repertories training_states = jax.pmap( @@ -176,10 +180,10 @@ def scoring_function(genotypes, random_key): # type: ignore )(training_states) # initialize map-elites - repertoire, emitter_state, keys = map_elites.get_distributed_init_fn( + repertoire, emitter_state = map_elites.get_distributed_init_fn( devices=devices, centroids=centroids )( - genotypes=training_states, random_key=keys + genotypes=training_states, key=keys ) # type: ignore update_fn = map_elites.get_distributed_update_fn(num_iterations=1, devices=devices) @@ -187,17 +191,15 @@ def scoring_function(genotypes, random_key): # type: ignore initial_metrics = jax.pmap(metrics_function, axis_name="p", devices=devices)( repertoire ) - initial_metrics_cpu = jax.tree_util.tree_map( + initial_metrics_cpu = jax.tree.map( lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], initial_metrics ) initial_qd_score = initial_metrics_cpu["qd_score"] for _ in range(num_loops): - repertoire, emitter_state, keys, metrics = update_fn( - repertoire, emitter_state, keys - ) - metrics_cpu = jax.tree_util.tree_map( + repertoire, emitter_state, metrics = update_fn(repertoire, emitter_state, keys) + metrics_cpu = jax.tree.map( lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], metrics ) diff --git a/tests/baselines_test/mees_test.py b/tests/baselines_test/mees_test.py index 2eb280f1..924b557c 100644 --- a/tests/baselines_test/mees_test.py +++ b/tests/baselines_test/mees_test.py @@ -15,7 +15,7 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP from qdax.custom_types import EnvState, Params, RNGKey -from qdax.tasks.brax_envs import 
scoring_function_brax_envs +from qdax.tasks.brax_envs import scoring_function_brax_envs as scoring_function def test_mees() -> None: @@ -26,8 +26,8 @@ def test_mees() -> None: policy_hidden_layer_sizes = (64, 64) num_init_cvt_samples = 1000 num_centroids = 50 - min_bd = 0.0 - max_bd = 1.0 + min_descriptor = 0.0 + max_descriptor = 1.0 # MEES Emitter params sample_number = 128 @@ -45,9 +45,10 @@ def test_mees() -> None: # Init environment env = environments.create(env_name, episode_length=episode_length) + reset_fn = jax.jit(env.reset) # Init a random key - random_key = jax.random.key(seed) + key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) @@ -58,7 +59,7 @@ def test_mees() -> None: ) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=1, axis=0) fake_batch = jnp.zeros(shape=(1, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) @@ -67,7 +68,7 @@ def test_mees() -> None: def play_step_fn( env_state: EnvState, policy_params: Params, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: """ Play an environment step and return the updated state and the transition. @@ -89,32 +90,16 @@ def play_step_fn( next_state_desc=next_state.info["state_descriptor"], ) - return next_state, policy_params, random_key, transition + return next_state, policy_params, key, transition - # Create the initial environment states for samples and final indivs - reset_fn = jax.jit(jax.vmap(env.reset)) - random_key, subkey = jax.random.split(random_key) - keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=sample_number, axis=0) - init_states_samples = reset_fn(keys) - random_key, subkey = jax.random.split(random_key) - keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=1, axis=0) - init_states = reset_fn(keys) - - # Prepare the scoring function for samples and final indivs - bd_extraction_fn = environments.behavior_descriptor_extractor[env_name] + # Prepare the scoring function + descriptor_extraction_fn = environments.descriptor_extractor[env_name] scoring_fn = functools.partial( - scoring_function_brax_envs, - init_states=init_states, - episode_length=episode_length, - play_step_fn=play_step_fn, - behavior_descriptor_extractor=bd_extraction_fn, - ) - scoring_samples_fn = functools.partial( - scoring_function_brax_envs, - init_states=init_states_samples, + scoring_function, episode_length=episode_length, + play_reset_fn=reset_fn, play_step_fn=play_step_fn, - behavior_descriptor_extractor=bd_extraction_fn, + descriptor_extractor=descriptor_extraction_fn, ) # Get minimum reward value to make sure qd_score are positive @@ -152,18 +137,19 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict: mees_emitter = MEESEmitter( config=mees_emitter_config, total_generations=num_iterations, - scoring_fn=scoring_samples_fn, - num_descriptors=env.behavior_descriptor_length, + scoring_fn=scoring_fn, + num_descriptors=env.descriptor_length, ) # Compute the centroids - centroids, random_key = compute_cvt_centroids( - num_descriptors=env.behavior_descriptor_length, + key, subkey = jax.random.split(key) + centroids = compute_cvt_centroids( + num_descriptors=env.descriptor_length, num_init_cvt_samples=num_init_cvt_samples, num_centroids=num_centroids, - minval=min_bd, - maxval=max_bd, - random_key=random_key, + minval=min_descriptor, + 
maxval=max_descriptor, + key=subkey, ) # Instantiate MAP Elites @@ -173,25 +159,27 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict: metrics_function=metrics_function, ) - repertoire, emitter_state, random_key = map_elites.init( - init_variables, centroids, random_key - ) + key, subkey = jax.random.split(key) + repertoire, emitter_state = map_elites.init(init_variables, centroids, subkey) @jax.jit - def update_scan_fn(carry: Any, unused: Any) -> Any: + def update_scan_fn(carry: Any, _: Any) -> Any: # iterate over grid - repertoire, emitter_state, metrics, random_key = map_elites.update(*carry) - - return (repertoire, emitter_state, random_key), metrics + repertoire, emitter_state, key = carry + key, subkey = jax.random.split(key) + repertoire, emitter_state, metrics = map_elites.update( + repertoire, emitter_state, subkey + ) + return (repertoire, emitter_state, key), metrics # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( update_scan_fn, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/baselines_test/omgmega_test.py b/tests/baselines_test/omgmega_test.py index 632dc993..08e96ae1 100644 --- a/tests/baselines_test/omgmega_test.py +++ b/tests/baselines_test/omgmega_test.py @@ -61,11 +61,9 @@ def scoring_function(x: jnp.ndarray) -> Tuple[Fitness, Descriptor, ExtraScores]: gradients = jnp.nan_to_num(gradients) return fitnesses, descriptors, {"gradients": gradients} - def scoring_fn( - x: Genotype, random_key: RNGKey - ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + def scoring_fn(x: Genotype, key: RNGKey) -> Tuple[Fitness, Descriptor, ExtraScores]: fitnesses, descriptors, extra_scores = jax.vmap(scoring_function)(x) - return fitnesses, descriptors, extra_scores, random_key + return fitnesses, descriptors, extra_scores worst_objective = rastrigin_scoring(-jnp.ones(num_dimensions) * maxval) best_objective = rastrigin_scoring(jnp.ones(num_dimensions) * maxval * 0.4) @@ -82,10 +80,10 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: max_fitness = jnp.max(adjusted_fitness) return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} - random_key = jax.random.key(0) + key = jax.random.key(0) # defines the population - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) initial_population = jax.random.uniform(subkey, shape=(100, num_dimensions)) sqrt_centroids = int(math.sqrt(num_centroids)) # 2-D grid @@ -109,17 +107,16 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: scoring_function=scoring_fn, emitter=emitter, metrics_function=metrics_fn ) - repertoire, emitter_state, random_key = map_elites.init( - initial_population, centroids, random_key - ) + key, subkey = jax.random.split(key) + repertoire, emitter_state = map_elites.init(initial_population, centroids, subkey) ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( map_elites.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/baselines_test/pbt_sac_test.py b/tests/baselines_test/pbt_sac_test.py index 9c4b2c83..c1e2af85 100644 --- a/tests/baselines_test/pbt_sac_test.py +++ b/tests/baselines_test/pbt_sac_test.py @@ -54,12 +54,12 @@ def test_pbt_sac() -> None: ) @jax.jit - def init_environments(random_key): # type: ignore - env_states = jax.jit(env.reset)(rng=random_key) - 
eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key) + def init_environments(key): # type: ignore + env_states = jax.jit(env.reset)(rng=key) + eval_env_first_states = jax.jit(eval_env.reset)(rng=key) reshape_fn = jax.jit( - lambda tree: jax.tree_util.tree_map( + lambda tree: jax.tree.map( lambda x: jnp.reshape( x, ( @@ -77,6 +77,7 @@ def init_environments(random_key): # type: ignore return env_states, eval_env_first_states key = jax.random.key(seed) + key, *keys = jax.random.split(key, num=1 + num_devices) keys = jnp.stack(keys) env_states, eval_env_first_states = jax.pmap( @@ -107,7 +108,7 @@ def init_environments(random_key): # type: ignore # Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647 keys = jax.random.key_data(keys) - keys, training_states, replay_buffers = jax.pmap( + training_states, replay_buffers = jax.pmap( agent_init_fn, axis_name="p", devices=devices )(keys) @@ -152,7 +153,7 @@ def init_environments(random_key): # type: ignore # PBT selection if i < (num_loops - 1): - keys, training_states, replay_buffers = select_fn( + training_states, replay_buffers = select_fn( keys, population_returns, training_states, replay_buffers ) diff --git a/tests/baselines_test/pbt_td3_test.py b/tests/baselines_test/pbt_td3_test.py index e45a9701..9a061e49 100644 --- a/tests/baselines_test/pbt_td3_test.py +++ b/tests/baselines_test/pbt_td3_test.py @@ -52,12 +52,12 @@ def test_pbt_td3() -> None: ) @jax.jit - def init_environments(random_key): # type: ignore - env_states = jax.jit(env.reset)(rng=random_key) - eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key) + def init_environments(key): # type: ignore + env_states = jax.jit(env.reset)(rng=key) + eval_env_first_states = jax.jit(eval_env.reset)(rng=key) reshape_fn = jax.jit( - lambda tree: jax.tree_util.tree_map( + lambda tree: jax.tree.map( lambda x: jnp.reshape( x, ( @@ -103,7 +103,7 @@ def init_environments(random_key): # type: ignore # Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647 keys = jax.random.key_data(keys) - keys, training_states, replay_buffers = jax.pmap( + training_states, replay_buffers = jax.pmap( agent_init_fn, axis_name="p", devices=devices )(keys) @@ -148,7 +148,7 @@ def init_environments(random_key): # type: ignore # PBT selection if i < (num_loops - 1): - keys, training_states, replay_buffers = select_fn( + training_states, replay_buffers = select_fn( keys, population_returns, training_states, replay_buffers ) diff --git a/tests/baselines_test/pgame_test.py b/tests/baselines_test/pgame_test.py index a9fd336e..5dec3745 100644 --- a/tests/baselines_test/pgame_test.py +++ b/tests/baselines_test/pgame_test.py @@ -16,7 +16,7 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP from qdax.custom_types import EnvState, Params, RNGKey -from qdax.tasks.brax_envs import scoring_function_brax_envs +from qdax.tasks.brax_envs import scoring_function_brax_envs as scoring_function def test_pgame() -> None: @@ -27,8 +27,8 @@ def test_pgame() -> None: policy_hidden_layer_sizes = (64, 64) num_init_cvt_samples = 1000 num_centroids = 50 - min_bd = 0.0 - max_bd = 1.0 + min_descriptor = 0.0 + max_descriptor = 1.0 # @title PGA-ME Emitter Definitions Fields proportion_mutation_ga = 0.5 @@ -52,9 +52,10 @@ def test_pgame() -> None: # Init environment env = environments.create(env_name, episode_length=episode_length) + reset_fn = jax.jit(env.reset) # Init a random key - random_key = jax.random.key(seed) + 
key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) @@ -65,7 +66,7 @@ def test_pgame() -> None: ) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jax.random.split(subkey, num=env_batch_size) fake_batch = jnp.zeros(shape=(env_batch_size, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) @@ -74,7 +75,7 @@ def test_pgame() -> None: def play_step_fn( env_state: EnvState, policy_params: Params, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: """ Play an environment step and return the updated state and the transition. @@ -96,7 +97,7 @@ def play_step_fn( next_state_desc=next_state.info["state_descriptor"], ) - return next_state, policy_params, random_key, transition + return next_state, policy_params, key, transition # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] @@ -144,30 +145,25 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict: variation_fn=variation_fn, ) - # Create the initial environment states - random_key, subkey = jax.random.split(random_key) - keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=env_batch_size, axis=0) - reset_fn = jax.jit(jax.vmap(env.reset)) - init_states = reset_fn(keys) - # Prepare the scoring function - bd_extraction_fn = environments.behavior_descriptor_extractor[env_name] + descriptor_extraction_fn = environments.descriptor_extractor[env_name] scoring_fn = functools.partial( - scoring_function_brax_envs, - init_states=init_states, + scoring_function, episode_length=episode_length, + play_reset_fn=reset_fn, play_step_fn=play_step_fn, - behavior_descriptor_extractor=bd_extraction_fn, + descriptor_extractor=descriptor_extraction_fn, ) # Compute the centroids - centroids, random_key = compute_cvt_centroids( - num_descriptors=env.behavior_descriptor_length, + key, subkey = jax.random.split(key) + centroids = compute_cvt_centroids( + num_descriptors=env.descriptor_length, num_init_cvt_samples=num_init_cvt_samples, num_centroids=num_centroids, - minval=min_bd, - maxval=max_bd, - random_key=random_key, + minval=min_descriptor, + maxval=max_descriptor, + key=subkey, ) # Instantiate MAP Elites @@ -177,25 +173,27 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict: metrics_function=metrics_function, ) - repertoire, emitter_state, random_key = map_elites.init( - init_variables, centroids, random_key - ) + key, subkey = jax.random.split(key) + repertoire, emitter_state = map_elites.init(init_variables, centroids, subkey) @jax.jit - def update_scan_fn(carry: Any, unused: Any) -> Any: + def update_scan_fn(carry: Any, _: Any) -> Any: # iterate over grid - repertoire, emitter_state, metrics, random_key = map_elites.update(*carry) - - return (repertoire, emitter_state, random_key), metrics + repertoire, emitter_state, key = carry + key, subkey = jax.random.split(key) + repertoire, emitter_state, metrics = map_elites.update( + repertoire, emitter_state, subkey + ) + return (repertoire, emitter_state, key), metrics # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( update_scan_fn, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/baselines_test/qdpg_test.py b/tests/baselines_test/qdpg_test.py index dfed7bb8..63e5c0a7 100644 --- 
a/tests/baselines_test/qdpg_test.py +++ b/tests/baselines_test/qdpg_test.py @@ -18,7 +18,7 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP from qdax.custom_types import EnvState, Params, RNGKey -from qdax.tasks.brax_envs import scoring_function_brax_envs +from qdax.tasks.brax_envs import scoring_function_brax_envs as scoring_function def test_qdpg() -> None: @@ -29,8 +29,8 @@ def test_qdpg() -> None: policy_hidden_layer_sizes = (64, 64) num_init_cvt_samples = 1000 num_centroids = 50 - min_bd = 0.0 - max_bd = 1.0 + min_descriptor = 0.0 + max_descriptor = 1.0 # mutations size quality_pg_batch_size = 3 @@ -67,9 +67,10 @@ def test_qdpg() -> None: # Init environment env = environments.create(env_name, episode_length=episode_length) + reset_fn = jax.jit(env.reset) # Init a random key - random_key = jax.random.key(seed) + key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) @@ -80,7 +81,7 @@ def test_qdpg() -> None: ) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jax.random.split(subkey, num=env_batch_size) fake_batch = jnp.zeros(shape=(env_batch_size, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) @@ -89,7 +90,7 @@ def test_qdpg() -> None: def play_step_fn( env_state: EnvState, policy_params: Params, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: """ Play an environment step and return the updated state and the transition. @@ -111,7 +112,7 @@ def play_step_fn( next_state_desc=next_state.info["state_descriptor"], ) - return next_state, policy_params, random_key, transition + return next_state, policy_params, key, transition # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] @@ -194,30 +195,25 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict: score_novelty=score_novelty, ) - # Create the initial environment states - random_key, subkey = jax.random.split(random_key) - keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=env_batch_size, axis=0) - reset_fn = jax.jit(jax.vmap(env.reset)) - init_states = reset_fn(keys) - # Prepare the scoring function - bd_extraction_fn = environments.behavior_descriptor_extractor[env_name] + descriptor_extraction_fn = environments.descriptor_extractor[env_name] scoring_fn = functools.partial( - scoring_function_brax_envs, - init_states=init_states, + scoring_function, episode_length=episode_length, + play_reset_fn=reset_fn, play_step_fn=play_step_fn, - behavior_descriptor_extractor=bd_extraction_fn, + descriptor_extractor=descriptor_extraction_fn, ) # Compute the centroids - centroids, random_key = compute_cvt_centroids( - num_descriptors=env.behavior_descriptor_length, + key, subkey = jax.random.split(key) + centroids = compute_cvt_centroids( + num_descriptors=env.descriptor_length, num_init_cvt_samples=num_init_cvt_samples, num_centroids=num_centroids, - minval=min_bd, - maxval=max_bd, - random_key=random_key, + minval=min_descriptor, + maxval=max_descriptor, + key=subkey, ) # Instantiate MAP Elites @@ -227,25 +223,27 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict: metrics_function=metrics_function, ) - repertoire, emitter_state, random_key = map_elites.init( - init_variables, centroids, random_key - ) + key, subkey = jax.random.split(key) + repertoire, 
emitter_state = map_elites.init(init_variables, centroids, subkey) @jax.jit - def update_scan_fn(carry: Any, unused: Any) -> Any: + def update_scan_fn(carry: Any, _: Any) -> Any: # iterate over grid - repertoire, emitter_state, metrics, random_key = map_elites.update(*carry) - - return (repertoire, emitter_state, random_key), metrics + repertoire, emitter_state, key = carry + key, subkey = jax.random.split(key) + repertoire, emitter_state, metrics = map_elites.update( + repertoire, emitter_state, subkey + ) + return (repertoire, emitter_state, key), metrics # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), _metrics = jax.lax.scan( update_scan_fn, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/baselines_test/sac_test.py b/tests/baselines_test/sac_test.py index 57554b92..31510670 100644 --- a/tests/baselines_test/sac_test.py +++ b/tests/baselines_test/sac_test.py @@ -82,7 +82,7 @@ def test_sac() -> None: sac = SAC(config=sac_config, action_size=env.action_size) key, subkey = jax.random.split(key) training_state = sac.init( - random_key=subkey, + key=subkey, action_size=env.action_size, observation_size=env.observation_size, ) diff --git a/tests/baselines_test/td3_test.py b/tests/baselines_test/td3_test.py index 55c09811..eab9f3bc 100644 --- a/tests/baselines_test/td3_test.py +++ b/tests/baselines_test/td3_test.py @@ -34,6 +34,8 @@ def test_td3() -> None: critic_learning_rate = 3e-4 policy_learning_rate = 3e-4 + key = jax.random.key(seed) + # Create environment env = environments.create( env_name=env_name, @@ -49,10 +51,10 @@ def test_td3() -> None: auto_reset=True, eval_metrics=True, ) - key = jax.random.key(seed) - key, subkey = jax.random.split(key) - env_state = jax.jit(env.reset)(rng=key) - eval_env_first_state = jax.jit(eval_env.reset)(rng=key) + + key, subkey_1, subkey_2 = jax.random.split(key, 3) + env_state = jax.jit(env.reset)(rng=subkey_1) + eval_env_first_state = jax.jit(eval_env.reset)(rng=subkey_2) # Initialize buffer dummy_transition = Transition.init_dummy( @@ -83,7 +85,7 @@ def test_td3() -> None: key, subkey = jax.random.split(key) training_state = td3.init( - key, action_size=env.action_size, observation_size=env.observation_size + subkey, action_size=env.action_size, observation_size=env.observation_size ) # Wrap and jit play step function diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index 9ec55a78..c7ff6691 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -12,7 +12,7 @@ from qdax.core.aurora import AURORA from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.custom_types import Observation -from qdax.environments.bd_extractors import ( +from qdax.environments.descriptor_extractors import ( AuroraExtraInfoNormalization, get_aurora_encoding, ) @@ -75,16 +75,17 @@ def test_aurora(env_name: str, batch_size: int) -> None: log_freq = 5 # Init a random key - random_key = jax.random.key(seed) + key = jax.random.key(seed) # Init environment - env, policy_network, scoring_fn, random_key = create_default_brax_task_components( + key, subkey = jax.random.split(key) + env, policy_network, scoring_fn = create_default_brax_task_components( env_name=env_name, - random_key=random_key, + key=subkey, ) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jax.random.split(subkey, num=batch_size) fake_batch = 
jnp.zeros(shape=(batch_size, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) @@ -165,7 +166,7 @@ def observation_extractor_fn( ) # init the model params - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) model_params = train_seq2seq.get_initial_params( model, subkey, (1, *observations_dims) ) @@ -182,18 +183,19 @@ def observation_extractor_fn( ) # init step of the aurora algorithm - repertoire, emitter_state, aurora_extra_info, random_key = aurora.init( + key, subkey = jax.random.split(key) + repertoire, emitter_state, aurora_extra_info = aurora.init( init_variables, aurora_extra_info, jnp.asarray(l_value_init), max_size, - random_key, + subkey, ) # initializing means and stds and AURORA - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) repertoire, aurora_extra_info = aurora.train( - repertoire, model_params, iteration=0, random_key=subkey + repertoire, model_params, iteration=0, key=subkey ) # design aurora's schedule @@ -215,10 +217,11 @@ def observation_extractor_fn( while iteration < max_iterations: # standard MAP-Elites-like loop for _ in range(log_freq): - repertoire, emitter_state, _, random_key = aurora.update( + key, subkey = jax.random.split(key) + repertoire, emitter_state, _ = aurora.update( repertoire, emitter_state, - random_key, + subkey, aurora_extra_info=aurora_extra_info, ) @@ -228,7 +231,7 @@ def observation_extractor_fn( # autoencoder steps and Container Size Control (CSC) if (iteration + 1) in schedules: # train the autoencoder (includes the CSC) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) repertoire, aurora_extra_info = aurora.train( repertoire, model_params, iteration, subkey ) diff --git a/tests/core_test/cmaes_test.py b/tests/core_test/cmaes_test.py index c81ee19e..8e22a18f 100644 --- a/tests/core_test/cmaes_test.py +++ b/tests/core_test/cmaes_test.py @@ -32,14 +32,15 @@ def sphere_scoring(x: jnp.ndarray) -> jnp.ndarray: ) state = cmaes.init() - random_key = jax.random.key(0) + key = jax.random.key(0) iteration_count = 0 for _ in range(num_iterations): iteration_count += 1 # sample - samples, random_key = cmaes.sample(state, random_key) + key, subkey = jax.random.split(key) + samples = cmaes.sample(state, subkey) # update state = cmaes.update(state, samples) diff --git a/tests/core_test/emitters_test/multi_emitter_test.py b/tests/core_test/emitters_test/multi_emitter_test.py index cee20261..de56b124 100644 --- a/tests/core_test/emitters_test/multi_emitter_test.py +++ b/tests/core_test/emitters_test/multi_emitter_test.py @@ -23,14 +23,14 @@ def test_multi_emitter() -> None: grid_shape = (100, 100) min_param = 0.0 max_param = 1.0 - min_bd = min_param - max_bd = max_param + min_descriptor = min_param + max_descriptor = max_param # Init a random key - random_key = jax.random.key(seed) + key = jax.random.key(seed) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) init_variables = jax.random.uniform( subkey, shape=(init_batch_size, num_param_dimensions) ) @@ -86,23 +86,22 @@ def test_multi_emitter() -> None: # Compute the centroids centroids = compute_euclidean_centroids( grid_shape=grid_shape, - minval=min_bd, - maxval=max_bd, + minval=min_descriptor, + maxval=max_descriptor, ) # Compute initial repertoire - repertoire, emitter_state, random_key = map_elites.init( - init_variables, centroids, random_key - ) + key, subkey = 
jax.random.split(key) + repertoire, emitter_state = map_elites.init(init_variables, centroids, subkey) # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( map_elites.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/core_test/map_elites_test.py b/tests/core_test/map_elites_test.py index 0702b61e..a39db52d 100644 --- a/tests/core_test/map_elites_test.py +++ b/tests/core_test/map_elites_test.py @@ -15,7 +15,7 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP from qdax.custom_types import EnvState, Params, RNGKey -from qdax.tasks.brax_envs import scoring_function_brax_envs +from qdax.tasks.brax_envs import scoring_function_brax_envs as scoring_function from qdax.utils.metrics import default_qd_metrics @@ -44,14 +44,15 @@ def test_map_elites(env_name: str, batch_size: int) -> None: policy_hidden_layer_sizes = (64, 64) num_init_cvt_samples = 1000 num_centroids = 50 - min_bd = 0.0 - max_bd = 1.0 + min_descriptor = 0.0 + max_descriptor = 1.0 # Init environment env = environments.create(env_name, episode_length=episode_length) + reset_fn = jax.jit(env.reset) # Init a random key - random_key = jax.random.key(seed) + key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) @@ -62,22 +63,16 @@ def test_map_elites(env_name: str, batch_size: int) -> None: ) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jax.random.split(subkey, num=batch_size) fake_batch = jnp.zeros(shape=(batch_size, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) - # Create the initial environment states - random_key, subkey = jax.random.split(random_key) - keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0) - reset_fn = jax.jit(jax.vmap(env.reset)) - init_states = reset_fn(keys) - # Define the function to play a step with the policy in the environment def play_step_fn( env_state: EnvState, policy_params: Params, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: """ Play an environment step and return the updated state and the transition. 
@@ -99,16 +94,16 @@ def play_step_fn( next_state_desc=next_state.info["state_descriptor"], ) - return next_state, policy_params, random_key, transition + return next_state, policy_params, key, transition # Prepare the scoring function - bd_extraction_fn = environments.behavior_descriptor_extractor[env_name] + descriptor_extraction_fn = environments.descriptor_extractor[env_name] scoring_fn = functools.partial( - scoring_function_brax_envs, - init_states=init_states, + scoring_function, episode_length=episode_length, + play_reset_fn=reset_fn, play_step_fn=play_step_fn, - behavior_descriptor_extractor=bd_extraction_fn, + descriptor_extractor=descriptor_extraction_fn, ) # Define emitter @@ -128,28 +123,28 @@ def play_step_fn( ) # Compute the centroids - centroids, random_key = compute_cvt_centroids( - num_descriptors=env.behavior_descriptor_length, + key, subkey = jax.random.split(key) + centroids = compute_cvt_centroids( + num_descriptors=env.descriptor_length, num_init_cvt_samples=num_init_cvt_samples, num_centroids=num_centroids, - minval=min_bd, - maxval=max_bd, - random_key=random_key, + minval=min_descriptor, + maxval=max_descriptor, + key=subkey, ) # Compute initial repertoire - repertoire, emitter_state, random_key = map_elites.init( - init_variables, centroids, random_key - ) + key, subkey = jax.random.split(key) + repertoire, emitter_state = map_elites.init(init_variables, centroids, subkey) # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( map_elites.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/core_test/mels_test.py b/tests/core_test/mels_test.py index 383ab55a..08ae9863 100644 --- a/tests/core_test/mels_test.py +++ b/tests/core_test/mels_test.py @@ -16,7 +16,7 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP from qdax.custom_types import EnvState, Params, RNGKey -from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs +from qdax.tasks.brax_envs import scoring_function_brax_envs @pytest.mark.parametrize( @@ -33,14 +33,14 @@ def test_mels(env_name: str, batch_size: int) -> None: policy_hidden_layer_sizes = (64, 64) num_init_cvt_samples = 1000 num_centroids = 50 - min_bd = 0.0 - max_bd = 1.0 + min_descriptor = 0.0 + max_descriptor = 1.0 # Init environment env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.key(seed) + key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) @@ -52,7 +52,7 @@ def test_mels(env_name: str, batch_size: int) -> None: # Init population of controllers. There are batch_size controllers, and each # controller will be evaluated num_samples times. 
- random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jax.random.split(subkey, num=batch_size) fake_batch = jnp.zeros(shape=(batch_size, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) @@ -61,7 +61,7 @@ def test_mels(env_name: str, batch_size: int) -> None: def play_step_fn( env_state: EnvState, policy_params: Params, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: """Play an environment step and return the updated state and the transition.""" @@ -82,16 +82,16 @@ def play_step_fn( next_state_desc=next_state.info["state_descriptor"], ) - return next_state, policy_params, random_key, transition + return next_state, policy_params, key, transition # Prepare the scoring function - bd_extraction_fn = environments.behavior_descriptor_extractor[env_name] + descriptor_extraction_fn = environments.descriptor_extractor[env_name] scoring_fn = functools.partial( - reset_based_scoring_function_brax_envs, + scoring_function_brax_envs, episode_length=episode_length, play_reset_fn=env.reset, play_step_fn=play_step_fn, - behavior_descriptor_extractor=bd_extraction_fn, + descriptor_extractor=descriptor_extraction_fn, ) # Define emitter @@ -127,28 +127,28 @@ def metrics_fn(repertoire: MELSRepertoire) -> Dict: ) # Compute the centroids - centroids, random_key = compute_cvt_centroids( - num_descriptors=env.behavior_descriptor_length, + key, subkey = jax.random.split(key) + centroids = compute_cvt_centroids( + num_descriptors=env.descriptor_length, num_init_cvt_samples=num_init_cvt_samples, num_centroids=num_centroids, - minval=min_bd, - maxval=max_bd, - random_key=random_key, + minval=min_descriptor, + maxval=max_descriptor, + key=subkey, ) # Compute initial repertoire - repertoire, emitter_state, random_key = mels.init( - init_variables, centroids, random_key - ) + key, subkey = jax.random.split(key) + repertoire, emitter_state = mels.init(init_variables, centroids, subkey) # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( mels.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/core_test/mome_test.py b/tests/core_test/mome_test.py index 4e0eb574..14b08697 100644 --- a/tests/core_test/mome_test.py +++ b/tests/core_test/mome_test.py @@ -38,6 +38,8 @@ def test_mome(num_descriptors: int) -> None: lag = 2.2 base_lag = 0.0 + key = jax.random.key(42) + def rastrigin_scorer( genotypes: jnp.ndarray, base_lag: float, lag: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: @@ -68,10 +70,10 @@ def rastrigin_scorer( scoring_function = partial(rastrigin_scorer, base_lag=base_lag, lag=lag) def scoring_fn( - genotypes: jnp.ndarray, random_key: RNGKey - ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + genotypes: jnp.ndarray, key: RNGKey + ) -> Tuple[Fitness, Descriptor, ExtraScores]: fitnesses, descriptors = scoring_function(genotypes) - return fitnesses, descriptors, {}, random_key + return fitnesses, descriptors, {} reference_point = jnp.array([-150, -150]) @@ -79,8 +81,7 @@ def scoring_fn( metrics_function = partial(default_moqd_metrics, reference_point=reference_point) # initial population - random_key = jax.random.key(42) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) genotypes = jax.random.uniform( subkey, (batch_size, num_variables), @@ -111,13 +112,14 @@ def scoring_fn( batch_size=batch_size, ) - centroids, random_key = 
compute_cvt_centroids( + key, subkey = jax.random.split(key) + centroids = compute_cvt_centroids( num_descriptors=num_descriptors, num_init_cvt_samples=20000, num_centroids=num_centroids, minval=minval, maxval=maxval, - random_key=random_key, + key=subkey, ) mome = MOME( @@ -126,18 +128,19 @@ def scoring_fn( metrics_function=metrics_function, ) - repertoire, emitter_state, random_key = mome.init( - genotypes, centroids, pareto_front_max_length, random_key + key, subkey = jax.random.split(key) + repertoire, emitter_state = mome.init( + genotypes, centroids, pareto_front_max_length, subkey ) # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( mome.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py b/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py index 07726b94..c0d91ff7 100644 --- a/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py +++ b/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py @@ -42,9 +42,7 @@ def test_insert_batch() -> None: buffer_size=buffer_size, transition=dummy_transition ) - simple_transition = jax.tree_util.tree_map( - lambda x: x.repeat(3, axis=0), dummy_transition - ) + simple_transition = jax.tree.map(lambda x: x.repeat(3, axis=0), dummy_transition) simple_transition = simple_transition.replace(rewards=jnp.arange(3)) data = QDTransition.from_flatten(replay_buffer.data, dummy_transition) pytest.assume( @@ -85,16 +83,15 @@ def test_sample() -> None: buffer_size=buffer_size, transition=dummy_transition ) - simple_transition = jax.tree_util.tree_map( - lambda x: x.repeat(3, axis=0), dummy_transition - ) + simple_transition = jax.tree.map(lambda x: x.repeat(3, axis=0), dummy_transition) simple_transition = simple_transition.replace(rewards=jnp.arange(3)) replay_buffer = replay_buffer.insert(simple_transition) - random_key = jax.random.key(0) + key = jax.random.key(0) - samples, random_key = replay_buffer.sample(random_key, 3) + key, subkey = jax.random.split(key) + samples = replay_buffer.sample(subkey, 3) - samples_shapes = jax.tree_util.tree_map(lambda x: x.shape, samples) - transition_shapes = jax.tree_util.tree_map(lambda x: x.shape, simple_transition) + samples_shapes = jax.tree.map(lambda x: x.shape, samples) + transition_shapes = jax.tree.map(lambda x: x.shape, simple_transition) pytest.assume((samples_shapes == transition_shapes)) diff --git a/tests/default_tasks_test/arm_test.py b/tests/default_tasks_test/arm_test.py index 31e0e0c1..261dc215 100644 --- a/tests/default_tasks_test/arm_test.py +++ b/tests/default_tasks_test/arm_test.py @@ -37,14 +37,14 @@ def test_arm(task_name: str, batch_size: int) -> None: grid_shape = (100, 100) min_param = 0.0 max_param = 1.0 - min_bd = 0.0 - max_bd = 1.0 + min_descriptor = 0.0 + max_descriptor = 1.0 # Init a random key - random_key = jax.random.key(seed) + key = jax.random.key(seed) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) init_variables = jax.random.uniform( subkey, shape=(init_batch_size, num_param_dimensions), @@ -86,23 +86,22 @@ def test_arm(task_name: str, batch_size: int) -> None: # Compute the centroids centroids = compute_euclidean_centroids( grid_shape=grid_shape, - minval=min_bd, - maxval=max_bd, + minval=min_descriptor, + maxval=max_descriptor, ) # Compute initial repertoire - repertoire, emitter_state, 
random_key = map_elites.init(
-        init_variables, centroids, random_key
-    )
+    key, subkey = jax.random.split(key)
+    repertoire, emitter_state = map_elites.init(init_variables, centroids, subkey)
    # Run the algorithm
    (
        repertoire,
        emitter_state,
-        random_key,
+        key,
    ), metrics = jax.lax.scan(
        map_elites.scan_update,
-        (repertoire, emitter_state, random_key),
+        (repertoire, emitter_state, key),
        (),
        length=num_iterations,
    )
@@ -114,9 +113,9 @@ def test_arm_scoring_function() -> None:
    # Init a random key
    seed = 42
-    random_key = jax.random.key(seed)
+    key = jax.random.key(seed)
-    # arm has xy BD centered at 0.5 0.5 and min max range is [0,1]
+    # arm has xy descriptor centered at 0.5 0.5 and min max range is [0,1]
    # 0 params of first genotype is horizontal and points towards negative x axis
    # angles move in anticlockwise direction
    genotypes_1 = jnp.ones(shape=(1, 4)) * 0.5  # 0.5
@@ -131,27 +130,13 @@ def test_arm_scoring_function() -> None:
    genotypes_6 = jnp.array([[0.5, 0.5]])
    genotypes_7 = jnp.array([[0.75, 0.5]])
-    fitness_1, descriptors_1, _, random_key = arm_scoring_function(
-        genotypes_1, random_key
-    )
-    fitness_2, descriptors_2, _, random_key = arm_scoring_function(
-        genotypes_2, random_key
-    )
-    fitness_3, descriptors_3, _, random_key = arm_scoring_function(
-        genotypes_3, random_key
-    )
-    fitness_4, descriptors_4, _, random_key = arm_scoring_function(
-        genotypes_4, random_key
-    )
-    fitness_5, descriptors_5, _, random_key = arm_scoring_function(
-        genotypes_5, random_key
-    )
-    fitness_6, descriptors_6, _, random_key = arm_scoring_function(
-        genotypes_6, random_key
-    )
-    fitness_7, descriptors_7, _, random_key = arm_scoring_function(
-        genotypes_7, random_key
-    )
+    fitness_1, descriptors_1, _ = arm_scoring_function(genotypes_1, key)
+    fitness_2, descriptors_2, _ = arm_scoring_function(genotypes_2, key)
+    fitness_3, descriptors_3, _ = arm_scoring_function(genotypes_3, key)
+    fitness_4, descriptors_4, _ = arm_scoring_function(genotypes_4, key)
+    fitness_5, descriptors_5, _ = arm_scoring_function(genotypes_5, key)
+    fitness_6, descriptors_6, _ = arm_scoring_function(genotypes_6, key)
+    fitness_7, descriptors_7, _ = arm_scoring_function(genotypes_7, key)
    # use rounding to avoid some numerical floating point errors
    pytest.assume(
diff --git a/tests/default_tasks_test/brax_task_test.py b/tests/default_tasks_test/brax_task_test.py
index 55768171..e3de0814 100644
--- a/tests/default_tasks_test/brax_task_test.py
+++ b/tests/default_tasks_test/brax_task_test.py
@@ -30,15 +30,16 @@ def test_map_elites(env_name: str, batch_size: int, is_task_reset_based: bool) -
    seed = 42
    num_init_cvt_samples = 1000
    num_centroids = 50
-    min_bd = 0.0
-    max_bd = 1.0
+    min_descriptor = 0.0
+    max_descriptor = 1.0
    # Init a random key
-    random_key = jax.random.key(seed)
+    key = jax.random.key(seed)
-    env, policy_network, scoring_fn, random_key = create_default_brax_task_components(
+    key, subkey = jax.random.split(key)
+    env, policy_network, scoring_fn = create_default_brax_task_components(
        env_name=env_name,
-        random_key=random_key,
+        key=subkey,
    )
    # Define emitter
@@ -64,33 +65,34 @@ def test_map_elites(env_name: str, batch_size: int, is_task_reset_based: bool) -
    )
    # Compute the centroids
-    centroids, random_key = compute_cvt_centroids(
-        num_descriptors=env.behavior_descriptor_length,
+    key, subkey = jax.random.split(key)
+    centroids = compute_cvt_centroids(
+        num_descriptors=env.descriptor_length,
        num_init_cvt_samples=num_init_cvt_samples,
        num_centroids=num_centroids,
-        minval=min_bd,
-        maxval=max_bd,
-
random_key=random_key, + minval=min_descriptor, + maxval=max_descriptor, + key=subkey, ) # Init population of controllers - init_variables, random_key = init_population_controllers( - policy_network, env, batch_size, random_key + key, subkey = jax.random.split(key) + init_variables = init_population_controllers( + policy_network, env, batch_size, subkey ) # Compute initial repertoire - repertoire, emitter_state, random_key = map_elites.init( - init_variables, centroids, random_key - ) + key, subkey = jax.random.split(key) + repertoire, emitter_state = map_elites.init(init_variables, centroids, subkey) # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( map_elites.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/default_tasks_test/hypervolume_functions_test.py b/tests/default_tasks_test/hypervolume_functions_test.py index 152c245d..1bdd97b3 100644 --- a/tests/default_tasks_test/hypervolume_functions_test.py +++ b/tests/default_tasks_test/hypervolume_functions_test.py @@ -46,14 +46,14 @@ def test_standard_functions(task_name: str, batch_size: int) -> None: grid_shape = (100, 100) min_param = 0.0 max_param = 1.0 - min_bd = 0.0 - max_bd = 1.0 + min_descriptor = 0.0 + max_descriptor = 1.0 # Init a random key - random_key = jax.random.key(seed) + key = jax.random.key(seed) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) init_variables = jax.random.uniform( subkey, shape=(init_batch_size, num_param_dimensions) ) @@ -92,23 +92,22 @@ def test_standard_functions(task_name: str, batch_size: int) -> None: # Compute the centroids centroids = compute_euclidean_centroids( grid_shape=grid_shape, - minval=min_bd, - maxval=max_bd, + minval=min_descriptor, + maxval=max_descriptor, ) # Compute initial repertoire - repertoire, emitter_state, random_key = map_elites.init( - init_variables, centroids, random_key - ) + key, subkey = jax.random.split(key) + repertoire, emitter_state = map_elites.init(init_variables, centroids, subkey) # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( map_elites.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/default_tasks_test/jumanji_envs_test.py b/tests/default_tasks_test/jumanji_envs_test.py index dba574ac..77a4ddee 100644 --- a/tests/default_tasks_test/jumanji_envs_test.py +++ b/tests/default_tasks_test/jumanji_envs_test.py @@ -38,7 +38,7 @@ def test_jumanji_utils() -> None: state, _timestep = jax.jit(env.step)(state, action) # Init a random key - random_key = jax.random.key(seed) + key = jax.random.key(seed) # get number of actions num_actions = env.action_spec().maximum + 1 @@ -69,7 +69,7 @@ def observation_processing( ) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jax.random.split(subkey, num=batch_size) # compute observation size from observation spec @@ -84,17 +84,17 @@ def observation_processing( init_variables = jax.vmap(policy_network.init)(keys, fake_batch) # Create the initial environment states - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0) reset_fn = jax.jit(jax.vmap(env.reset)) init_states, init_timesteps = 
reset_fn(keys) # Prepare the scoring function - def bd_extraction( + def descriptor_extraction( data: QDTransition, mask: jnp.ndarray, linear_projection: jnp.ndarray ) -> Descriptor: - """Extract a behavior descriptor from a trajectory. + """Extract a descriptor from a trajectory. This extractor takes the mean observation in the trajectory and project it in a two dimension space. @@ -105,7 +105,7 @@ def bd_extraction( linear_projection: a linear projection. Returns: - Behavior descriptors. + Descriptors. """ # pre-process the observation @@ -120,13 +120,13 @@ def bd_extraction( return descriptors # create a random projection to a two dim space - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) linear_projection = jax.random.uniform( subkey, (2, observation_size), minval=-1, maxval=1, dtype=jnp.float32 ) - bd_extraction_fn = functools.partial( - bd_extraction, linear_projection=linear_projection + descriptor_extraction_fn = functools.partial( + descriptor_extraction, linear_projection=linear_projection ) # define the scoring function @@ -136,12 +136,11 @@ def bd_extraction( init_timesteps=init_timesteps, episode_length=episode_length, play_step_fn=play_step_fn, - behavior_descriptor_extractor=bd_extraction_fn, + descriptor_extractor=descriptor_extraction_fn, ) - fitnesses, descriptors, extra_scores, random_key = scoring_fn( - init_variables, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, descriptors, extra_scores = scoring_fn(init_variables, subkey) pytest.assume(fitnesses.shape == (population_size,)) pytest.assume(jnp.sum(jnp.isnan(fitnesses)) == 0.0) diff --git a/tests/default_tasks_test/qd_suite_test.py b/tests/default_tasks_test/qd_suite_test.py index 9424cc63..ae015866 100644 --- a/tests/default_tasks_test/qd_suite_test.py +++ b/tests/default_tasks_test/qd_suite_test.py @@ -55,20 +55,20 @@ def test_qd_suite(task_name: str, batch_size: int) -> None: batch_size = batch_size num_iterations = 5 min_param, max_param = task.get_min_max_params() - min_bd, max_bd = task.get_bounded_min_max_descriptor() - bd_size = task.get_descriptor_size() + min_descriptor, max_descriptor = task.get_bounded_min_max_descriptor() + descriptor_size = task.get_descriptor_size() grid_shape: Tuple[int, ...] 
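The jumanji test changes above illustrate the scoring convention used throughout this diff: the caller splits its key and passes a fresh subkey, and the scoring function returns `(fitnesses, descriptors, extra_scores)` without threading a key back. A minimal, self-contained sketch of that calling pattern — `toy_scoring_fn` below is a stand-in for illustration, not a QDax API:

```python
import jax
import jax.numpy as jnp


def toy_scoring_fn(genotypes, key):
    # Stand-in scoring function following the new convention: it consumes
    # the subkey it is given and does not return a key.
    noise = 1e-3 * jax.random.normal(key, (genotypes.shape[0],))
    fitnesses = -jnp.sum(genotypes**2, axis=-1) + noise
    descriptors = genotypes[:, :2]  # first two genes as toy descriptors
    extra_scores = {}               # no auxiliary data in this toy
    return fitnesses, descriptors, extra_scores


key = jax.random.key(0)
key, subkey = jax.random.split(key)
genotypes = jax.random.uniform(subkey, (8, 4))

key, subkey = jax.random.split(key)
fitnesses, descriptors, extra_scores = toy_scoring_fn(genotypes, subkey)
```

Because the callee never returns a key, the caller alone is responsible for splitting, which makes accidental key reuse easy to spot at the call site.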
- if bd_size == 1: + if descriptor_size == 1: grid_shape = (100,) - elif bd_size == 2: + elif descriptor_size == 2: grid_shape = (100, 100) else: - resolution_per_axis = math.floor(math.pow(10000.0, 1.0 / bd_size)) - grid_shape = tuple([resolution_per_axis for _ in range(bd_size)]) + resolution_per_axis = math.floor(math.pow(10000.0, 1.0 / descriptor_size)) + grid_shape = tuple([resolution_per_axis for _ in range(descriptor_size)]) # Init a random key - random_key = jax.random.key(seed) + key = jax.random.key(seed) # Init population of parameters init_variables = task.get_initial_parameters(init_batch_size) @@ -107,23 +107,22 @@ def test_qd_suite(task_name: str, batch_size: int) -> None: # Compute the centroids centroids = compute_euclidean_centroids( grid_shape=grid_shape, - minval=min_bd, - maxval=max_bd, + minval=min_descriptor, + maxval=max_descriptor, ) # Compute initial repertoire - repertoire, emitter_state, random_key = map_elites.init( - init_variables, centroids, random_key - ) + key, subkey = jax.random.split(key) + repertoire, emitter_state = map_elites.init(init_variables, centroids, subkey) # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( map_elites.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/default_tasks_test/standard_functions_test.py b/tests/default_tasks_test/standard_functions_test.py index b30cd7cc..f69ee8af 100644 --- a/tests/default_tasks_test/standard_functions_test.py +++ b/tests/default_tasks_test/standard_functions_test.py @@ -36,14 +36,14 @@ def test_standard_functions(task_name: str, batch_size: int) -> None: grid_shape = (100, 100) min_param = 0.0 max_param = 1.0 - min_bd = min_param - max_bd = max_param + min_descriptor = min_param + max_descriptor = max_param # Init a random key - random_key = jax.random.key(seed) + key = jax.random.key(seed) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) init_variables = jax.random.uniform( subkey, shape=(init_batch_size, num_param_dimensions) ) @@ -82,23 +82,22 @@ def test_standard_functions(task_name: str, batch_size: int) -> None: # Compute the centroids centroids = compute_euclidean_centroids( grid_shape=grid_shape, - minval=min_bd, - maxval=max_bd, + minval=min_descriptor, + maxval=max_descriptor, ) # Compute initial repertoire - repertoire, emitter_state, random_key = map_elites.init( - init_variables, centroids, random_key - ) + key, subkey = jax.random.split(key) + repertoire, emitter_state = map_elites.init(init_variables, centroids, subkey) # Run the algorithm ( repertoire, emitter_state, - random_key, + key, ), metrics = jax.lax.scan( map_elites.scan_update, - (repertoire, emitter_state, random_key), + (repertoire, emitter_state, key), (), length=num_iterations, ) diff --git a/tests/environments_test/wrapper_test.py b/tests/environments_test/wrapper_test.py index b29d89e1..14190153 100644 --- a/tests/environments_test/wrapper_test.py +++ b/tests/environments_test/wrapper_test.py @@ -110,8 +110,8 @@ def test_wrapper(env_name: str) -> None: print("Observation size: ", env.observation_size) print("Action size: ", env.action_size) - random_key = jax.random.key(seed) - init_state = env.reset(random_key) + key = jax.random.key(seed) + init_state = env.reset(key) joint_angle = jp.concatenate( [joint.angle_vel(init_state.qp)[0] for joint in env.sys.joints] diff --git a/tests/utils_test/plotting_test.py 
b/tests/utils_test/plotting_test.py index 807760ce..234de9e8 100644 --- a/tests/utils_test/plotting_test.py +++ b/tests/utils_test/plotting_test.py @@ -38,8 +38,8 @@ def test_onion_grid(num_descriptors: int, grid_shape: Tuple[int, ...]) -> None: minval = jnp.array([0] * num_descriptors) maxval = jnp.array([1] * num_descriptors) - random_key = jax.random.key(seed=0) - random_key, key_desc, key_fit = jax.random.split(random_key, num=3) + key = jax.random.key(seed=0) + key, key_desc, key_fit = jax.random.split(key, num=3) number_samples_test = 300 descriptors = jax.random.uniform( diff --git a/tests/utils_test/sampling_test.py b/tests/utils_test/sampling_test.py index 5c3c880f..13f65f03 100644 --- a/tests/utils_test/sampling_test.py +++ b/tests/utils_test/sampling_test.py @@ -9,7 +9,7 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP from qdax.custom_types import EnvState, Params, RNGKey -from qdax.tasks.brax_envs import scoring_function_brax_envs +from qdax.tasks.brax_envs import scoring_function_brax_envs as scoring_function from qdax.utils.sampling import ( average, closest, @@ -34,7 +34,7 @@ def test_sampling() -> None: env = environments.create(env_name, episode_length=episode_length) # Init a random key - random_key = jax.random.key(seed) + key = jax.random.key(seed) # Init policy network policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) @@ -45,7 +45,7 @@ def test_sampling() -> None: ) # Init population of controllers - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=1, axis=0) fake_batch = jnp.zeros(shape=(1, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) @@ -54,7 +54,7 @@ def test_sampling() -> None: def play_step_fn( env_state: EnvState, policy_params: Params, - random_key: RNGKey, + key: RNGKey, ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: """ Play an environment step and return the updated state and the transition. 
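The MAP-Elites tests above drive the loop with `jax.lax.scan` over `map_elites.scan_update`, carrying `(repertoire, emitter_state, key)` and collecting one metrics entry per iteration. A minimal sketch of how such a scan behaves, with a toy carry standing in for the repertoire and emitter state:

```python
import jax
import jax.numpy as jnp


def scan_update(carry, _):
    # Toy stand-in for MAPElites.scan_update: split the carried key, perturb
    # the state, and report a per-iteration metric.
    state, key = carry
    key, subkey = jax.random.split(key)
    state = state + jax.random.uniform(subkey)
    metrics = {"qd_score": state}
    return (state, key), metrics


num_iterations = 5
init_carry = (jnp.zeros(()), jax.random.key(42))
(final_state, final_key), metrics = jax.lax.scan(
    scan_update, init_carry, (), length=num_iterations
)
# metrics["qd_score"] has shape (num_iterations,): one value per iteration.
```

Expressing the loop with `lax.scan` keeps all iterations inside one compiled computation instead of a Python `for` loop.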
@@ -76,22 +76,19 @@ def play_step_fn( next_state_desc=next_state.info["state_descriptor"], ) - return next_state, policy_params, random_key, transition + return next_state, policy_params, key, transition - # Create the initial environment states for samples and final indivs - reset_fn = jax.jit(jax.vmap(env.reset)) - random_key, subkey = jax.random.split(random_key) - keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=1, axis=0) - init_states = reset_fn(keys) + key, subkey = jax.random.split(key) + init_state = env.reset(subkey) - # Create the scoring function - bd_extraction_fn = environments.behavior_descriptor_extractor[env_name] + # Prepare the scoring function + descriptor_extraction_fn = environments.descriptor_extractor[env_name] scoring_fn = functools.partial( - scoring_function_brax_envs, - init_states=init_states, + scoring_function, episode_length=episode_length, + play_reset_fn=lambda _: init_state, play_step_fn=play_step_fn, - behavior_descriptor_extractor=bd_extraction_fn, + descriptor_extractor=descriptor_extraction_fn, ) # Test function for different extractors @@ -110,9 +107,9 @@ def sampling_test( ) # Evaluate individuals using the scoring functions - fitnesses, descriptors, _, _ = scoring_fn(init_variables, random_key) - sample_fitnesses, sample_descriptors, _, _ = scoring_1_sample_fn( - init_variables, random_key + fitnesses, descriptors, _ = scoring_fn(init_variables, key) + sample_fitnesses, sample_descriptors, _ = scoring_1_sample_fn( + init_variables, key ) # Compare @@ -131,8 +128,8 @@ def sampling_test( ) # Evaluate individuals using the scoring functions - sample_fitnesses, sample_descriptors, _, _ = scoring_multi_sample_fn( - init_variables, random_key + sample_fitnesses, sample_descriptors, _ = scoring_multi_sample_fn( + init_variables, key ) # Compare @@ -151,6 +148,7 @@ def sampling_test( def sampling_reproducibility_test( fitness_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray], descriptor_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray], + key: RNGKey, ) -> None: # Compare scoring against perforing a single sample @@ -163,14 +161,14 @@ def sampling_reproducibility_test( ) # Evaluate individuals using the scoring functions + key, subkey = jax.random.split(key) ( _, _, _, fitnesses_reproducibility, descriptors_reproducibility, - _, - ) = scoring_1_sample_fn(init_variables, random_key) + ) = scoring_1_sample_fn(init_variables, subkey) # Compare - all reproducibility should be 0 pytest.assume( @@ -200,14 +198,14 @@ def sampling_reproducibility_test( ) # Evaluate individuals using the scoring functions + key, subkey = jax.random.split(key) ( _, _, _, fitnesses_reproducibility, descriptors_reproducibility, - _, - ) = scoring_multi_sample_fn(init_variables, random_key) + ) = scoring_multi_sample_fn(init_variables, subkey) # Compare - all reproducibility should be 0 pytest.assume( @@ -228,6 +226,7 @@ def sampling_reproducibility_test( ) # Call the test for each type of extractor - sampling_reproducibility_test(std, std) - sampling_reproducibility_test(mad, mad) - sampling_reproducibility_test(iqr, iqr) + key_1, key_2, key_3 = jax.random.split(key, 3) + sampling_reproducibility_test(std, std, key_1) + sampling_reproducibility_test(mad, mad, key_2) + sampling_reproducibility_test(iqr, iqr, key_3) diff --git a/tests/utils_test/uncertainty_metrics_test.py b/tests/utils_test/uncertainty_metrics_test.py index 3f2caea1..f5770862 100644 --- a/tests/utils_test/uncertainty_metrics_test.py +++ 
b/tests/utils_test/uncertainty_metrics_test.py @@ -25,24 +25,26 @@ def test_uncertainty_metrics() -> None: genotype_dim = 8 # Init a random key - random_key = jax.random.key(seed) + key = jax.random.key(seed) # First, init a deterministic environment + key, subkey = jax.random.split(key) init_policies = jax.random.uniform( - random_key, shape=(batch_size, genotype_dim), minval=0, maxval=1 - ) - fitnesses, descriptors, extra_scores, random_key = arm_scoring_function( - init_policies, random_key + subkey, shape=(batch_size, genotype_dim), minval=0, maxval=1 ) + key, subkey = jax.random.split(key) + fitnesses, descriptors, extra_scores = arm_scoring_function(init_policies, subkey) + # Initialise a container - centroids, random_key = compute_cvt_centroids( + key, subkey = jax.random.split(key) + centroids = compute_cvt_centroids( num_descriptors=2, num_init_cvt_samples=num_init_cvt_samples, num_centroids=num_centroids, minval=jnp.array([0.0, 0.0]), maxval=jnp.array([1.0, 1.0]), - random_key=random_key, + key=subkey, ) repertoire = MapElitesRepertoire.init( genotypes=init_policies, @@ -63,12 +65,13 @@ def test_uncertainty_metrics() -> None: ) # Test that reevaluation_function accurately predicts no change - corrected_repertoire, random_key = reevaluation_function( + key, subkey = jax.random.split(key) + corrected_repertoire = reevaluation_function( repertoire=repertoire, empty_corrected_repertoire=empty_corrected_repertoire, scoring_fn=arm_scoring_function, num_reevals=num_reevals, - random_key=random_key, + key=subkey, ) pytest.assume( jnp.allclose( @@ -77,12 +80,13 @@ def test_uncertainty_metrics() -> None: ) # Test that scanned reevaluation_function accurately predicts no change - corrected_repertoire, random_key = reevaluation_function( + key, subkey = jax.random.split(key) + corrected_repertoire = reevaluation_function( repertoire=repertoire, empty_corrected_repertoire=empty_corrected_repertoire, scoring_fn=arm_scoring_function, num_reevals=num_reevals, - random_key=random_key, + key=subkey, scan_size=scan_size, ) pytest.assume( @@ -92,17 +96,17 @@ def test_uncertainty_metrics() -> None: ) # Test that reevaluation_reproducibility_function accurately predicts no change + key, subkey = jax.random.split(key) ( corrected_repertoire, fit_reproducibility_repertoire, desc_reproducibility_repertoire, - random_key, ) = reevaluation_reproducibility_function( repertoire=repertoire, empty_corrected_repertoire=empty_corrected_repertoire, scoring_fn=arm_scoring_function, num_reevals=num_reevals, - random_key=random_key, + key=subkey, ) pytest.assume( jnp.allclose( @@ -132,8 +136,9 @@ def test_uncertainty_metrics() -> None: ) # Second, init a stochastic environment + key, subkey = jax.random.split(key) init_policies = jax.random.uniform( - random_key, shape=(batch_size, genotype_dim), minval=0, maxval=1 + subkey, shape=(batch_size, genotype_dim), minval=0, maxval=1 ) noisy_scoring_function = functools.partial( noisy_arm_scoring_function, @@ -141,18 +146,18 @@ def test_uncertainty_metrics() -> None: desc_variance=0.01, params_variance=0.0, ) - fitnesses, descriptors, extra_scores, random_key = noisy_scoring_function( - init_policies, random_key - ) + key, subkey = jax.random.split(key) + fitnesses, descriptors, extra_scores = noisy_scoring_function(init_policies, subkey) # Initialise a container - centroids, random_key = compute_cvt_centroids( + key, subkey = jax.random.split(key) + centroids = compute_cvt_centroids( num_descriptors=2, num_init_cvt_samples=num_init_cvt_samples, 
num_centroids=num_centroids, minval=jnp.array([0.0, 0.0]), maxval=jnp.array([1.0, 1.0]), - random_key=random_key, + key=subkey, ) repertoire = MapElitesRepertoire.init( genotypes=init_policies, @@ -173,17 +178,17 @@ def test_uncertainty_metrics() -> None: ) # Test that reevaluation_function runs and keeps at least one solution + key, subkey = jax.random.split(key) ( corrected_repertoire, fit_reproducibility_repertoire, desc_reproducibility_repertoire, - random_key, ) = reevaluation_reproducibility_function( repertoire=repertoire, empty_corrected_repertoire=empty_corrected_repertoire, scoring_fn=noisy_scoring_function, num_reevals=num_reevals, - random_key=random_key, + key=subkey, ) pytest.assume(jnp.any(corrected_repertoire.fitnesses > -jnp.inf)) pytest.assume(jnp.any(fit_reproducibility_repertoire.fitnesses > -jnp.inf))