diff --git a/README.md b/README.md index 304fe843..551680eb 100644 --- a/README.md +++ b/README.md @@ -123,6 +123,7 @@ repertoire.genotypes, repertoire.fitnesses, repertoire.descriptors ## QDax core algorithms QDax currently supports the following algorithms: + | Algorithm | Example | |-------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | [MAP-Elites](https://arxiv.org/abs/1504.04909) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb) | @@ -138,7 +139,6 @@ QDax currently supports the following algorithms: | [MAP-Elites Low-Spread (ME-LS)](https://dl.acm.org/doi/abs/10.1145/3583131.3590433) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/me_ls.ipynb) | - ## QDax baseline algorithms The QDax library also provides implementations for some useful baseline algorithms: diff --git a/docs/api_documentation/core/aurora.md b/docs/api_documentation/core/aurora.md new file mode 100644 index 00000000..1088b2cd --- /dev/null +++ b/docs/api_documentation/core/aurora.md @@ -0,0 +1,7 @@ +# AURORA class + +This class implement the base mechanism of AURORA. It must be used with an emitter. To get the usual AURORA algorithm, one must use the [mixing emitter](emitters.md#qdax.core.emitters.standard_emitters.MixingEmitter). + +The AURORA class can be used with other emitters to create variants, like [PGA-AURORA](pga_aurora.md). + +::: qdax.core.aurora.AURORA diff --git a/docs/api_documentation/core/pga_aurora.md b/docs/api_documentation/core/pga_aurora.md new file mode 100644 index 00000000..dc4fd6d1 --- /dev/null +++ b/docs/api_documentation/core/pga_aurora.md @@ -0,0 +1,5 @@ +# Policy Gradient Assisted AURORA (PGA-AURORA) + +To create an instance of PGA-AURORA (introduced [in this paper](https://arxiv.org/abs/2210.03516)), one needs to use an instance of [AURORA](map_elites.md) with the PGAMEEmitter, detailed below. + +::: qdax.core.emitters.pga_me_emitter.PGAMEEmitter diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb new file mode 100644 index 00000000..a6a47894 --- /dev/null +++ b/examples/aurora.ipynb @@ -0,0 +1,524 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/aurora.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Optimizing with AURORA in Jax\n", + "\n", + "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [AURORA](https://arxiv.org/pdf/1905.11874.pdf).\n", + "It can be run locally or on Google Colab. We recommand to use a GPU. 
This notebook will show:\n", + "\n", + "- how to define the problem\n", + "- how to create an emitter\n", + "- how to create an AURORA instance\n", + "- which functions must be defined before training\n", + "- how to launch a certain number of training steps\n", + "- how to visualise the optimization process\n", + "- how to save/load a repertoire" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Installs and Imports\n", + "!pip install ipympl |tail -n 1\n", + "# %matplotlib widget\n", + "# from google.colab import output\n", + "# output.enable_custom_widget_manager()\n", + "\n", + "import os\n", + "\n", + "from IPython.display import clear_output\n", + "import functools\n", + "from typing import Dict, Any\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " import qdax\n", + "\n", + "\n", + "from qdax.core.aurora import AURORA\n", + "from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire\n", + "from qdax import environments\n", + "from qdax.tasks.brax_envs import scoring_aurora_function\n", + "from qdax.environments.bd_extractors import get_aurora_bd\n", + "from qdax.core.neuroevolution.buffers.buffer import QDTransition\n", + "from qdax.core.neuroevolution.networks.networks import MLP\n", + "from qdax.core.emitters.mutation_operators import isoline_variation\n", + "from qdax.core.emitters.standard_emitters import MixingEmitter\n", + "\n", + "from qdax.utils import train_seq2seq\n", + "\n", + "\n", + "if \"COLAB_TPU_ADDR\" in os.environ:\n", + " from jax.tools import colab_tpu\n", + " colab_tpu.setup_tpu()\n", + "\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title QD Training Definitions Fields\n", + "#@markdown ---\n", + "batch_size = 100 #@param {type:\"number\"}\n", + "env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']\n", + "episode_length = 250 #@param {type:\"integer\"}\n", + "max_iterations = 50 #@param {type:\"integer\"}\n", + "seed = 42 #@param {type:\"integer\"}\n", + "policy_hidden_layer_sizes = (64, 64) #@param {type:\"raw\"}\n", + "iso_sigma = 0.005 #@param {type:\"number\"}\n", + "line_sigma = 0.05 #@param {type:\"number\"}\n", + "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", + "num_centroids = 1024 #@param {type:\"integer\"}\n", + "min_bd = 0. 
#@param {type:\"number\"}\n", + "max_bd = 1.0 #@param {type:\"number\"}\n", + "\n", + "lstm_batch_size = 128 #@param {type:\"integer\"}\n", + "\n", + "observation_option = \"no_sd\" #@param['no_sd', 'only_sd', 'full']\n", + "hidden_size = 5 #@param {type:\"integer\"}\n", + "l_value_init = 0.2 #@param {type:\"number\"}\n", + "\n", + "traj_sampling_freq = 10 #@param {type:\"integer\"}\n", + "max_observation_size = 25 #@param {type:\"integer\"}\n", + "prior_descriptor_dim = 2 #@param {type:\"integer\"}\n", + "\n", + "log_freq = 5 #@param {type:\"integer\"}\n", + "#@markdown ---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Init environment, policy, population params, init states of the env\n", + "\n", + "Define the environment in which the policies will be trained. In this notebook, we focus on controllers learning to move a robot in a physical simulation. We also define the shared policy, that every individual in the population will use. Once the policy is defined, all individuals are defined by their parameters, that corresponds to their genotype." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Init environment\n", + "env = environments.create(env_name, episode_length=episode_length)\n", + "\n", + "# Init a random key\n", + "random_key = jax.random.PRNGKey(seed)\n", + "\n", + "# Init policy network\n", + "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", + "policy_network = MLP(\n", + " layer_sizes=policy_layer_sizes,\n", + " kernel_init=jax.nn.initializers.lecun_uniform(),\n", + " final_activation=jnp.tanh,\n", + ")\n", + "\n", + "# Init population of controllers\n", + "random_key, subkey = jax.random.split(random_key)\n", + "keys = jax.random.split(subkey, num=batch_size)\n", + "fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))\n", + "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", + "\n", + "\n", + "# Create the initial environment states\n", + "random_key, subkey = jax.random.split(random_key)\n", + "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)\n", + "reset_fn = jax.jit(jax.vmap(env.reset))\n", + "init_states = reset_fn(keys)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the way the policy interacts with the env\n", + "\n", + "Now that the environment and policy has been defined, it is necessary to define a function that describes how the policy must be used to interact with the environment and to store transition data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the fonction to play a step with the policy in the environment\n", + "def play_step_fn(\n", + " env_state,\n", + " policy_params,\n", + " random_key,\n", + "):\n", + " \"\"\"\n", + " Play an environment step and return the updated state and the transition.\n", + " \"\"\"\n", + "\n", + " actions = policy_network.apply(policy_params, env_state.obs)\n", + " \n", + " state_desc = env_state.info[\"state_descriptor\"]\n", + " next_state = env.step(env_state, actions)\n", + "\n", + " transition = QDTransition(\n", + " obs=env_state.obs,\n", + " next_obs=next_state.obs,\n", + " rewards=next_state.reward,\n", + " dones=next_state.done,\n", + " actions=actions,\n", + " truncations=next_state.info[\"truncation\"],\n", + " state_desc=state_desc,\n", + " next_state_desc=next_state.info[\"state_descriptor\"],\n", + " )\n", + "\n", + " return next_state, policy_params, random_key, transition" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the scoring function and the way metrics are computed\n", + "\n", + "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Prepare the scoring function\n", + "bd_extraction_fn = functools.partial(\n", + " get_aurora_bd,\n", + " option=observation_option,\n", + " hidden_size=hidden_size,\n", + " traj_sampling_freq=traj_sampling_freq,\n", + " max_observation_size=max_observation_size,\n", + ")\n", + "scoring_fn = functools.partial(\n", + " scoring_aurora_function,\n", + " init_states=init_states,\n", + " episode_length=episode_length,\n", + " play_step_fn=play_step_fn,\n", + " behavior_descriptor_extractor=bd_extraction_fn,\n", + ")\n", + "\n", + "# Get minimum reward value to make sure qd_score are positive\n", + "reward_offset = environments.reward_offset[env_name]\n", + "\n", + "# Define a metrics function\n", + "def metrics_fn(repertoire: UnstructuredRepertoire) -> Dict:\n", + "\n", + " # Get metrics\n", + " grid_empty = repertoire.fitnesses == -jnp.inf\n", + " qd_score = jnp.sum(repertoire.fitnesses, where=~grid_empty)\n", + " # Add offset for positive qd_score\n", + " qd_score += reward_offset * episode_length * jnp.sum(1.0 - grid_empty)\n", + " coverage = 100 * jnp.mean(1.0 - grid_empty)\n", + " max_fitness = jnp.max(repertoire.fitnesses)\n", + "\n", + " return {\"qd_score\": qd_score, \"max_fitness\": max_fitness, \"coverage\": coverage}\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the emitter\n", + "\n", + "The emitter is used to evolve the population at each mutation step." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define emitter\n", + "variation_fn = functools.partial(\n", + " isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma\n", + ")\n", + "mixing_emitter = MixingEmitter(\n", + " mutation_fn=lambda x, y: (x, y),\n", + " variation_fn=variation_fn,\n", + " variation_percentage=1.0,\n", + " batch_size=batch_size,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instantiate and initialise the MAP Elites algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Instantiate AURORA\n", + "aurora = AURORA(\n", + " scoring_function=scoring_fn,\n", + " emitter=mixing_emitter,\n", + " metrics_function=metrics_fn,\n", + ")\n", + "\n", + "aurora_dims = hidden_size\n", + "centroids = jnp.zeros(shape=(num_centroids, aurora_dims))\n", + "\n", + "@jax.jit\n", + "def update_scan_fn(carry: Any, unused: Any) -> Any:\n", + " \"\"\"Scan the udpate function.\"\"\"\n", + " (\n", + " repertoire,\n", + " random_key,\n", + " model_params,\n", + " mean_observations,\n", + " std_observations,\n", + " ) = carry\n", + "\n", + " # update\n", + " (repertoire, _, metrics, random_key,) = aurora.update(\n", + " repertoire,\n", + " None,\n", + " random_key,\n", + " model_params,\n", + " mean_observations,\n", + " std_observations,\n", + " )\n", + "\n", + " return (\n", + " (repertoire, random_key, model_params, mean_observations, std_observations),\n", + " metrics,\n", + " )\n", + "\n", + "# Init algorithm\n", + "# AutoEncoder Params and INIT\n", + "obs_dim = jnp.minimum(env.observation_size, max_observation_size)\n", + "if observation_option == \"full\":\n", + " observations_dims = (\n", + " episode_length // traj_sampling_freq,\n", + " obs_dim + prior_descriptor_dim,\n", + " )\n", + "elif observation_option == \"no_sd\":\n", + " observations_dims = (\n", + " episode_length // traj_sampling_freq,\n", + " obs_dim,\n", + " )\n", + "elif observation_option == \"only_sd\":\n", + " observations_dims = (episode_length // traj_sampling_freq, prior_descriptor_dim)\n", + "else:\n", + " ValueError(\"The chosen option is not correct.\")\n", + "\n", + "# define the seq2seq model\n", + "model = train_seq2seq.get_model(\n", + " observations_dims[-1], True, hidden_size=hidden_size\n", + ")\n", + "\n", + "# init the model params\n", + "random_key, subkey = jax.random.split(random_key)\n", + "model_params = train_seq2seq.get_initial_params(\n", + " model, subkey, (1, *observations_dims)\n", + ")\n", + "\n", + "print(jax.tree_map(lambda x: x.shape, model_params))\n", + "\n", + "# define arbitrary observation's mean/std\n", + "mean_observations = jnp.zeros(observations_dims[-1])\n", + "std_observations = jnp.ones(observations_dims[-1])\n", + "\n", + "# init step of the aurora algorithm\n", + "repertoire, _, random_key = aurora.init(\n", + " init_variables,\n", + " centroids,\n", + " random_key,\n", + " model_params,\n", + " mean_observations,\n", + " std_observations,\n", + " l_value_init,\n", + ")\n", + "\n", + "# initializing means and stds and AURORA\n", + "random_key, subkey = jax.random.split(random_key)\n", + "model_params, mean_observations, std_observations = train_seq2seq.lstm_ae_train(\n", + " subkey, repertoire, model_params, 0, hidden_size=hidden_size, batch_size=lstm_batch_size\n", + ")\n", + "\n", + "# design aurora's schedule\n", + "default_update_base = 10\n", + "update_base = 
int(jnp.ceil(default_update_base / log_freq))\n", + "schedules = jnp.cumsum(jnp.arange(update_base, 1000, update_base))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Launch AURORA iterations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "current_step_estimation = 0\n", + "num_iterations = 0\n", + "\n", + "# Main loop\n", + "n_target = 1024\n", + "\n", + "previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - n_target\n", + "\n", + "iteration = 0\n", + "while iteration < max_iterations:\n", + "\n", + " (\n", + " (repertoire, random_key, model_params, mean_observations, std_observations),\n", + " metrics,\n", + " ) = jax.lax.scan(\n", + " update_scan_fn,\n", + " (repertoire, random_key, model_params, mean_observations, std_observations),\n", + " (),\n", + " length=log_freq,\n", + " )\n", + "\n", + " num_iterations = iteration * log_freq\n", + "\n", + " # update nb steps estimation\n", + " current_step_estimation += batch_size * episode_length * log_freq\n", + "\n", + " # autoencoder steps and CVC\n", + " if (iteration + 1) in schedules:\n", + " # train the autoencoder\n", + " random_key, subkey = jax.random.split(random_key)\n", + " (\n", + " model_params,\n", + " mean_observations,\n", + " std_observations,\n", + " ) = train_seq2seq.lstm_ae_train(\n", + " subkey,\n", + " repertoire,\n", + " model_params,\n", + " iteration,\n", + " hidden_size=hidden_size,\n", + " batch_size=lstm_batch_size\n", + " )\n", + "\n", + " # re-addition of all the new behavioural descriotpors with the new ae\n", + " normalized_observations = (\n", + " repertoire.observations - mean_observations\n", + " ) / std_observations\n", + "\n", + " new_descriptors = model.apply(\n", + " {\"params\": model_params}, normalized_observations, method=model.encode\n", + " )\n", + " repertoire = repertoire.init(\n", + " genotypes=repertoire.genotypes,\n", + " centroids=repertoire.centroids,\n", + " fitnesses=repertoire.fitnesses,\n", + " descriptors=new_descriptors,\n", + " observations=repertoire.observations,\n", + " l_value=repertoire.l_value,\n", + " )\n", + " num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf)\n", + "\n", + " elif iteration % 2 == 0:\n", + " # update the l value\n", + " num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf)\n", + "\n", + " # CVC Implementation to keep a constant number of individuals in the archive\n", + " current_error = num_indivs - n_target\n", + " change_rate = current_error - previous_error\n", + " prop_gain = 1 * 10e-6\n", + " l_value = (\n", + " repertoire.l_value\n", + " + (prop_gain * (current_error))\n", + " + (prop_gain * change_rate)\n", + " )\n", + "\n", + " previous_error = current_error\n", + "\n", + " repertoire = repertoire.init(\n", + " genotypes=repertoire.genotypes,\n", + " centroids=repertoire.centroids,\n", + " fitnesses=repertoire.fitnesses,\n", + " descriptors=repertoire.descriptors,\n", + " observations=repertoire.observations,\n", + " l_value=l_value,\n", + " )\n", + "\n", + " iteration += 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for k, v in metrics.items():\n", + " print(k, \" - \", v[-1])" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": 
"ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/pga_aurora.ipynb b/examples/pga_aurora.ipynb new file mode 100644 index 00000000..11ed6afe --- /dev/null +++ b/examples/pga_aurora.ipynb @@ -0,0 +1,571 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/pga_aurora.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Optimizing with PGA-AURORA in Jax\n", + "\n", + "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [PGA-AURORA](https://arxiv.org/abs/2210.03516).\n", + "It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "\n", + "- how to define the problem\n", + "- how to create an emitter\n", + "- how to create an AURORA instance and mix it with the right emitter to define PGA-AURORA\n", + "- which functions must be defined before training\n", + "- how to launch a certain number of training steps\n", + "- how to visualise the optimization process\n", + "- how to save/load a repertoire" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Installs and Imports\n", + "!pip install ipympl |tail -n 1\n", + "# %matplotlib widget\n", + "# from google.colab import output\n", + "# output.enable_custom_widget_manager()\n", + "\n", + "import os\n", + "\n", + "from IPython.display import clear_output\n", + "import functools\n", + "from typing import Dict, Any\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " import qdax\n", + "\n", + "\n", + "from qdax.core.aurora import AURORA\n", + "from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire\n", + "from qdax import environments\n", + "from qdax.tasks.brax_envs import scoring_aurora_function\n", + "from qdax.environments.bd_extractors import get_aurora_bd\n", + "from qdax.core.neuroevolution.buffers.buffer import QDTransition\n", + "from qdax.core.neuroevolution.networks.networks import MLP\n", + "from qdax.core.emitters.mutation_operators import isoline_variation\n", + "from qdax.core.emitters.pga_me_emitter import PGAMEConfig, PGAMEEmitter\n", + "\n", + "from qdax.utils import train_seq2seq\n", + "\n", + "\n", + "if \"COLAB_TPU_ADDR\" in os.environ:\n", + " from jax.tools import colab_tpu\n", + " colab_tpu.setup_tpu()\n", + "\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title QD Training Definitions Fields\n", + "#@markdown ---\n", + "env_batch_size = 100 #@param {type:\"number\"}\n", + "env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']\n", + "episode_length = 250 #@param {type:\"integer\"}\n", + "max_iterations = 
50 #@param {type:\"integer\"}\n", + "seed = 42 #@param {type:\"integer\"}\n", + "policy_hidden_layer_sizes = (64, 64) #@param {type:\"raw\"}\n", + "iso_sigma = 0.005 #@param {type:\"number\"}\n", + "line_sigma = 0.05 #@param {type:\"number\"}\n", + "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", + "num_centroids = 1024 #@param {type:\"integer\"}\n", + "min_bd = 0. #@param {type:\"number\"}\n", + "max_bd = 1.0 #@param {type:\"number\"}\n", + "\n", + "lstm_batch_size = 128 #@param {type:\"integer\"}\n", + "\n", + "observation_option = \"no_sd\" #@param['no_sd', 'only_sd', 'full']\n", + "hidden_size = 5 #@param {type:\"integer\"}\n", + "l_value_init = 0.2 #@param {type:\"number\"}\n", + "\n", + "traj_sampling_freq = 10 #@param {type:\"integer\"}\n", + "max_observation_size = 25 #@param {type:\"integer\"}\n", + "prior_descriptor_dim = 2 #@param {type:\"integer\"}\n", + "\n", + "proportion_mutation_ga = 0.5 #@param {type:\"number\"}\n", + "\n", + "# TD3 params\n", + "replay_buffer_size = 1000000 #@param {type:\"number\"}\n", + "critic_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n", + "critic_learning_rate = 3e-4 #@param {type:\"number\"}\n", + "greedy_learning_rate = 3e-4 #@param {type:\"number\"}\n", + "policy_learning_rate = 1e-3 #@param {type:\"number\"}\n", + "noise_clip = 0.5 #@param {type:\"number\"}\n", + "policy_noise = 0.2 #@param {type:\"number\"}\n", + "discount = 0.99 #@param {type:\"number\"}\n", + "reward_scaling = 1.0 #@param {type:\"number\"}\n", + "transitions_batch_size = 256 #@param {type:\"number\"}\n", + "soft_tau_update = 0.005 #@param {type:\"number\"}\n", + "num_critic_training_steps = 300 #@param {type:\"number\"}\n", + "num_pg_training_steps = 100 #@param {type:\"number\"}\n", + "policy_delay = 2 #@param {type:\"number\"}\n", + "\n", + "log_freq = 5 #@param {type:\"integer\"}\n", + "#@markdown ---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Init environment, policy, population params, init states of the env\n", + "\n", + "Define the environment in which the policies will be trained. In this notebook, we focus on controllers learning to move a robot in a physical simulation. We also define the shared policy, that every individual in the population will use. Once the policy is defined, all individuals are defined by their parameters, that corresponds to their genotype." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Init environment\n", + "env = environments.create(env_name, episode_length=episode_length)\n", + "\n", + "# Init a random key\n", + "random_key = jax.random.PRNGKey(seed)\n", + "\n", + "# Init policy network\n", + "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", + "policy_network = MLP(\n", + " layer_sizes=policy_layer_sizes,\n", + " kernel_init=jax.nn.initializers.lecun_uniform(),\n", + " final_activation=jnp.tanh,\n", + ")\n", + "\n", + "# Init population of controllers\n", + "random_key, subkey = jax.random.split(random_key)\n", + "keys = jax.random.split(subkey, num=env_batch_size)\n", + "fake_batch = jnp.zeros(shape=(env_batch_size, env.observation_size))\n", + "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", + "\n", + "\n", + "# Create the initial environment states\n", + "random_key, subkey = jax.random.split(random_key)\n", + "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=env_batch_size, axis=0)\n", + "reset_fn = jax.jit(jax.vmap(env.reset))\n", + "init_states = reset_fn(keys)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the way the policy interacts with the env\n", + "\n", + "Now that the environment and policy has been defined, it is necessary to define a function that describes how the policy must be used to interact with the environment and to store transition data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the fonction to play a step with the policy in the environment\n", + "def play_step_fn(\n", + " env_state,\n", + " policy_params,\n", + " random_key,\n", + "):\n", + " \"\"\"\n", + " Play an environment step and return the updated state and the transition.\n", + " \"\"\"\n", + "\n", + " actions = policy_network.apply(policy_params, env_state.obs)\n", + " \n", + " state_desc = env_state.info[\"state_descriptor\"]\n", + " next_state = env.step(env_state, actions)\n", + "\n", + " transition = QDTransition(\n", + " obs=env_state.obs,\n", + " next_obs=next_state.obs,\n", + " rewards=next_state.reward,\n", + " dones=next_state.done,\n", + " actions=actions,\n", + " truncations=next_state.info[\"truncation\"],\n", + " state_desc=state_desc,\n", + " next_state_desc=next_state.info[\"state_descriptor\"],\n", + " )\n", + "\n", + " return next_state, policy_params, random_key, transition" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the scoring function and the way metrics are computed\n", + "\n", + "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Prepare the scoring function\n", + "bd_extraction_fn = functools.partial(\n", + " get_aurora_bd,\n", + " option=observation_option,\n", + " hidden_size=hidden_size,\n", + " traj_sampling_freq=traj_sampling_freq,\n", + " max_observation_size=max_observation_size,\n", + ")\n", + "scoring_fn = functools.partial(\n", + " scoring_aurora_function,\n", + " init_states=init_states,\n", + " episode_length=episode_length,\n", + " play_step_fn=play_step_fn,\n", + " behavior_descriptor_extractor=bd_extraction_fn,\n", + ")\n", + "\n", + "# Get minimum reward value to make sure qd_score are positive\n", + "reward_offset = environments.reward_offset[env_name]\n", + "\n", + "# Define a metrics function\n", + "def metrics_fn(repertoire: UnstructuredRepertoire) -> Dict:\n", + "\n", + " # Get metrics\n", + " grid_empty = repertoire.fitnesses == -jnp.inf\n", + " qd_score = jnp.sum(repertoire.fitnesses, where=~grid_empty)\n", + " # Add offset for positive qd_score\n", + " qd_score += reward_offset * episode_length * jnp.sum(1.0 - grid_empty)\n", + " coverage = 100 * jnp.mean(1.0 - grid_empty)\n", + " max_fitness = jnp.max(repertoire.fitnesses)\n", + "\n", + " return {\"qd_score\": qd_score, \"max_fitness\": max_fitness, \"coverage\": coverage}\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the emitter\n", + "\n", + "The emitter is used to evolve the population at each mutation step." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the PG-emitter config\n", + "pga_emitter_config = PGAMEConfig(\n", + " env_batch_size=env_batch_size,\n", + " batch_size=transitions_batch_size,\n", + " proportion_mutation_ga=proportion_mutation_ga,\n", + " critic_hidden_layer_size=critic_hidden_layer_size,\n", + " critic_learning_rate=critic_learning_rate,\n", + " greedy_learning_rate=greedy_learning_rate,\n", + " policy_learning_rate=policy_learning_rate,\n", + " noise_clip=noise_clip,\n", + " policy_noise=policy_noise,\n", + " discount=discount,\n", + " reward_scaling=reward_scaling,\n", + " replay_buffer_size=replay_buffer_size,\n", + " soft_tau_update=soft_tau_update,\n", + " num_critic_training_steps=num_critic_training_steps,\n", + " num_pg_training_steps=num_pg_training_steps,\n", + " policy_delay=policy_delay,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get the emitter\n", + "variation_fn = functools.partial(\n", + " isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma\n", + ")\n", + "\n", + "pg_emitter = PGAMEEmitter(\n", + " config=pga_emitter_config,\n", + " policy_network=policy_network,\n", + " env=env,\n", + " variation_fn=variation_fn,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instantiate and initialise the MAP Elites algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Instantiate AURORA\n", + "aurora = AURORA(\n", + " scoring_function=scoring_fn,\n", + " emitter=pg_emitter,\n", + " metrics_function=metrics_fn,\n", + ")\n", + "\n", + "aurora_dims = hidden_size\n", + "centroids = jnp.zeros(shape=(num_centroids, aurora_dims))\n", + "\n", + "@jax.jit\n", + "def update_scan_fn(carry: Any, unused: Any) -> Any:\n", + " \"\"\"Scan the udpate function.\"\"\"\n", + " (\n", + " 
repertoire,\n", + " emitter_state,\n", + " random_key,\n", + " model_params,\n", + " mean_observations,\n", + " std_observations,\n", + " ) = carry\n", + "\n", + " # update\n", + " (repertoire, emitter_state, metrics, random_key,) = aurora.update(\n", + " repertoire,\n", + " emitter_state,\n", + " random_key,\n", + " model_params,\n", + " mean_observations,\n", + " std_observations,\n", + " )\n", + "\n", + " return (\n", + " (repertoire, emitter_state, random_key, model_params, mean_observations, std_observations),\n", + " metrics,\n", + " )\n", + "\n", + "# Init algorithm\n", + "# AutoEncoder Params and INIT\n", + "obs_dim = jnp.minimum(env.observation_size, max_observation_size)\n", + "if observation_option == \"full\":\n", + " observations_dims = (\n", + " episode_length // traj_sampling_freq,\n", + " obs_dim + prior_descriptor_dim,\n", + " )\n", + "elif observation_option == \"no_sd\":\n", + " observations_dims = (\n", + " episode_length // traj_sampling_freq,\n", + " obs_dim,\n", + " )\n", + "elif observation_option == \"only_sd\":\n", + " observations_dims = (episode_length // traj_sampling_freq, prior_descriptor_dim)\n", + "else:\n", + " ValueError(\"The chosen option is not correct.\")\n", + "\n", + "# define the seq2seq model\n", + "model = train_seq2seq.get_model(\n", + " observations_dims[-1], True, hidden_size=hidden_size\n", + ")\n", + "\n", + "# init the model params\n", + "random_key, subkey = jax.random.split(random_key)\n", + "model_params = train_seq2seq.get_initial_params(\n", + " model, subkey, (1, *observations_dims)\n", + ")\n", + "\n", + "print(jax.tree_map(lambda x: x.shape, model_params))\n", + "\n", + "# define arbitrary observation's mean/std\n", + "mean_observations = jnp.zeros(observations_dims[-1])\n", + "std_observations = jnp.ones(observations_dims[-1])\n", + "\n", + "# init step of the aurora algorithm\n", + "repertoire, emitter_state, random_key = aurora.init(\n", + " init_variables,\n", + " centroids,\n", + " random_key,\n", + " model_params,\n", + " mean_observations,\n", + " std_observations,\n", + " l_value_init,\n", + ")\n", + "\n", + "# initializing means and stds and AURORA\n", + "random_key, subkey = jax.random.split(random_key)\n", + "model_params, mean_observations, std_observations = train_seq2seq.lstm_ae_train(\n", + " subkey, repertoire, model_params, 0, hidden_size=hidden_size, batch_size=lstm_batch_size\n", + ")\n", + "\n", + "# design aurora's schedule\n", + "default_update_base = 10\n", + "update_base = int(jnp.ceil(default_update_base / log_freq))\n", + "schedules = jnp.cumsum(jnp.arange(update_base, 1000, update_base))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Launch AURORA iterations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "current_step_estimation = 0\n", + "num_iterations = 0\n", + "\n", + "# Main loop\n", + "n_target = 1024\n", + "\n", + "previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - n_target\n", + "\n", + "iteration = 0\n", + "while iteration < max_iterations:\n", + "\n", + " (\n", + " (repertoire, emitter_state, random_key, model_params, mean_observations, std_observations),\n", + " metrics,\n", + " ) = jax.lax.scan(\n", + " update_scan_fn,\n", + " (repertoire, emitter_state, random_key, model_params, mean_observations, std_observations),\n", + " (),\n", + " length=log_freq,\n", + " )\n", + "\n", + " num_iterations = iteration * log_freq\n", + "\n", + " # update nb steps estimation\n", + " 
current_step_estimation += env_batch_size * episode_length * log_freq\n", + "\n", + " # autoencoder steps and CVC\n", + " if (iteration + 1) in schedules:\n", + " # train the autoencoder\n", + " random_key, subkey = jax.random.split(random_key)\n", + " (\n", + " model_params,\n", + " mean_observations,\n", + " std_observations,\n", + " ) = train_seq2seq.lstm_ae_train(\n", + " subkey,\n", + " repertoire,\n", + " model_params,\n", + " iteration,\n", + " hidden_size=hidden_size,\n", + " batch_size=lstm_batch_size\n", + " )\n", + "\n", + " # re-addition of all the new behavioural descriotpors with the new ae\n", + " normalized_observations = (\n", + " repertoire.observations - mean_observations\n", + " ) / std_observations\n", + "\n", + " new_descriptors = model.apply(\n", + " {\"params\": model_params}, normalized_observations, method=model.encode\n", + " )\n", + " repertoire = repertoire.init(\n", + " genotypes=repertoire.genotypes,\n", + " centroids=repertoire.centroids,\n", + " fitnesses=repertoire.fitnesses,\n", + " descriptors=new_descriptors,\n", + " observations=repertoire.observations,\n", + " l_value=repertoire.l_value,\n", + " )\n", + " num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf)\n", + "\n", + " elif iteration % 2 == 0:\n", + " # update the l value\n", + " num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf)\n", + "\n", + " # CVC Implementation to keep a constant number of individuals in the archive\n", + " current_error = num_indivs - n_target\n", + " change_rate = current_error - previous_error\n", + " prop_gain = 1 * 10e-6\n", + " l_value = (\n", + " repertoire.l_value\n", + " + (prop_gain * (current_error))\n", + " + (prop_gain * change_rate)\n", + " )\n", + "\n", + " previous_error = current_error\n", + "\n", + " repertoire = repertoire.init(\n", + " genotypes=repertoire.genotypes,\n", + " centroids=repertoire.centroids,\n", + " fitnesses=repertoire.fitnesses,\n", + " descriptors=repertoire.descriptors,\n", + " observations=repertoire.observations,\n", + " l_value=l_value,\n", + " )\n", + "\n", + " iteration += 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for k, v in metrics.items():\n", + " print(k, \" - \", v[-1])" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/pgame.ipynb b/examples/pgame.ipynb index 7a51a0bd..ab1ae221 100644 --- a/examples/pgame.ipynb +++ b/examples/pgame.ipynb @@ -104,8 +104,7 @@ "min_bd = 0. 
#@param {type:\"number\"}\n",
    "max_bd = 1.0 #@param {type:\"number\"}\n",
    "\n",
-    "#@title PGA-ME Emitter Definitions Fields\n",
-    "proportion_mutation_ga = 0.5\n",
+    "proportion_mutation_ga = 0.5 #@param {type:\"number\"}\n",
    "\n",
    "# TD3 params\n",
    "env_batch_size = 100 #@param {type:\"number\"}\n",
diff --git a/mkdocs.yml b/mkdocs.yml
index 2c0bbdb6..b71ad0b0 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -125,6 +125,8 @@ nav:
     - SMERL: examples/smerl.ipynb
     - CMA ES: examples/cmaes.ipynb
     - NSGA2/SPEA2: examples/nsga2_spea2.ipynb
+    - AURORA: examples/aurora.ipynb
+    - PGA AURORA: examples/pga_aurora.ipynb
     - PBT: examples/sac_pbt.ipynb
     - MAPElites PBT: examples/me_sac_pbt.ipynb
     - Jumanji Snake: examples/jumanji_snake.ipynb
@@ -139,6 +141,8 @@ nav:
     - CMA MEGA: api_documentation/core/cma_mega.md
     - MOME: api_documentation/core/mome.md
     - ME ES: api_documentation/core/mees.md
+    - AURORA: api_documentation/core/aurora.md
+    - PGA AURORA: api_documentation/core/pga_aurora.md
     - ME PBT: api_documentation/core/me_pbt.md
     - ME LS: api_documentation/core/mels.md
  - Baseline algorithms:
diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py
new file mode 100644
index 00000000..fed716e3
--- /dev/null
+++ b/qdax/core/aurora.py
@@ -0,0 +1,244 @@
+"""Core class of the AURORA algorithm."""
+
+from __future__ import annotations
+
+from functools import partial
+from typing import Callable, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+from chex import ArrayTree
+
+from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire
+from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire
+from qdax.core.emitters.emitter import Emitter, EmitterState
+from qdax.environments.bd_extractors import AuroraExtraInfo
+from qdax.types import (
+    Descriptor,
+    Fitness,
+    Genotype,
+    Metrics,
+    Observation,
+    Params,
+    RNGKey,
+)
+
+
+class AURORA:
+    """Core elements of the AURORA algorithm.
+
+    Args:
+        scoring_function: a function that takes a batch of genotypes and computes
+            their fitnesses and descriptors
+        emitter: an emitter used to suggest offsprings given a MAP-Elites
+            repertoire. It has two compulsory functions: a function that
+            emits a new population, and a function that updates the internal state
+            of the emitter.
+ metrics_function: a function that takes a repertoire and computes + any useful metric to track its evolution + """ + + def __init__( + self, + scoring_function: Callable[ + [Genotype, RNGKey], + Tuple[Fitness, Descriptor, ArrayTree, RNGKey], + ], + emitter: Emitter, + metrics_function: Callable[[MapElitesRepertoire], Metrics], + encoder_function: Callable[[Observation, AuroraExtraInfo], Descriptor], + training_function: Callable[ + [RNGKey, UnstructuredRepertoire, Params, int], AuroraExtraInfo + ], + ) -> None: + self._scoring_function = scoring_function + self._emitter = emitter + self._metrics_function = metrics_function + self._encoder_fn = encoder_function + self._train_fn = training_function + + def train( + self, + repertoire: UnstructuredRepertoire, + model_params: Params, + iteration: int, + random_key: RNGKey, + ) -> Tuple[UnstructuredRepertoire, AuroraExtraInfo]: + random_key, subkey = jax.random.split(random_key) + aurora_extra_info = self._train_fn( + random_key, + repertoire, + model_params, + iteration, + ) + + # re-addition of all the new behavioural descriptors with the new ae + new_descriptors = self._encoder_fn(repertoire.observations, aurora_extra_info) + + return ( + repertoire.init( + genotypes=repertoire.genotypes, + fitnesses=repertoire.fitnesses, + descriptors=new_descriptors, + observations=repertoire.observations, + l_value=repertoire.l_value, + max_size=repertoire.max_size, + ), + aurora_extra_info, + ) + + @partial(jax.jit, static_argnames=("self",)) + def container_size_control( + self, + repertoire: UnstructuredRepertoire, + target_size: int, + previous_error: jnp.ndarray, + ) -> Tuple[UnstructuredRepertoire, jnp.ndarray]: + # update the l value + num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf) + + # CVC Implementation to keep a constant number of individuals in the archive + current_error = num_indivs - target_size + change_rate = current_error - previous_error + prop_gain = 1 * 10e-6 + l_value = ( + repertoire.l_value + (prop_gain * current_error) + (prop_gain * change_rate) + ) + + repertoire = repertoire.init( + genotypes=repertoire.genotypes, + fitnesses=repertoire.fitnesses, + descriptors=repertoire.descriptors, + observations=repertoire.observations, + l_value=l_value, + max_size=repertoire.max_size, + ) + + return repertoire, current_error + + def init( + self, + init_genotypes: Genotype, + aurora_extra_info: AuroraExtraInfo, + l_value: jnp.ndarray, + max_size: int, + random_key: RNGKey, + ) -> Tuple[UnstructuredRepertoire, Optional[EmitterState], AuroraExtraInfo, RNGKey]: + """Initialize an unstructured repertoire with an initial population of + genotypes. Also performs the first training of the AURORA encoder. + + Args: + init_genotypes: initial genotypes, pytree in which leaves + have shape (batch_size, num_features) + aurora_extra_info: information to perform AURORA encodings, + such as the encoder parameters + l_value: threshold distance for the unstructured repertoire + max_size: maximum size of the repertoire + random_key: a random key used for stochastic operations. 
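+
+        A rough usage sketch (names are illustrative; ``aurora_extra_info`` is
+        assumed to have been built beforehand, e.g. from the initial encoder
+        parameters):
+
+            repertoire, emitter_state, aurora_extra_info, random_key = aurora.init(
+                init_genotypes,
+                aurora_extra_info,
+                jnp.array(l_value_init),
+                max_size,
+                random_key,
+            )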
+ + Returns: + an initialized unstructured repertoire, with the initial state of + the emitter, and the updated information to perform AURORA encodings + """ + fitnesses, descriptors, extra_scores, random_key = self._scoring_function( + init_genotypes, + random_key, + ) + + observations = extra_scores["last_valid_observations"] + + descriptors = self._encoder_fn(observations, aurora_extra_info) + + repertoire = UnstructuredRepertoire.init( + genotypes=init_genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + observations=observations, + l_value=l_value, + max_size=max_size, + ) + + # get initial state of the emitter + emitter_state, random_key = self._emitter.init( + init_genotypes=init_genotypes, random_key=random_key + ) + + # update emitter state + emitter_state = self._emitter.state_update( + emitter_state=emitter_state, + genotypes=init_genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + extra_scores=extra_scores, + ) + + random_key, subkey = jax.random.split(random_key) + repertoire, updated_aurora_extra_info = self.train( + repertoire, aurora_extra_info.model_params, iteration=0, random_key=subkey + ) + + return repertoire, emitter_state, updated_aurora_extra_info, random_key + + @partial(jax.jit, static_argnames=("self",)) + def update( + self, + repertoire: MapElitesRepertoire, + emitter_state: Optional[EmitterState], + random_key: RNGKey, + aurora_extra_info: AuroraExtraInfo, + ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]: + """Main step of the AURORA algorithm. + + + Performs one iteration of the AURORA algorithm. + 1. A batch of genotypes is sampled in the archive and the genotypes are copied. + 2. The copies are mutated and crossed-over + 3. The obtained offsprings are scored and then added to the archive. 
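+
+        A minimal usage sketch (names are illustrative; the objects come from a
+        previous call to ``init`` or ``update``):
+
+            repertoire, emitter_state, metrics, random_key = aurora.update(
+                repertoire,
+                emitter_state,
+                random_key,
+                aurora_extra_info,
+            )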
+ + Args: + repertoire: unstructured repertoire + emitter_state: state of the emitter + random_key: a jax PRNG random key + aurora_extra_info: extra info for computing encodings + + Results: + the updated MAP-Elites repertoire + the updated (if needed) emitter state + metrics about the updated repertoire + a new key + """ + # generate offsprings with the emitter + genotypes, random_key = self._emitter.emit( + repertoire, emitter_state, random_key + ) + # scores the offsprings + fitnesses, descriptors, extra_scores, random_key = self._scoring_function( + genotypes, + random_key, + ) + + observations = extra_scores["last_valid_observations"] + + descriptors = self._encoder_fn(observations, aurora_extra_info) + + # add genotypes and observations in the repertoire + repertoire = repertoire.add( + genotypes, + descriptors, + fitnesses, + observations, + ) + + # update emitter state after scoring is made + emitter_state = self._emitter.state_update( + emitter_state=emitter_state, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + extra_scores=extra_scores, + ) + + # update the metrics + metrics = self._metrics_function(repertoire) + + return repertoire, emitter_state, metrics, random_key diff --git a/qdax/core/containers/unstructured_repertoire.py b/qdax/core/containers/unstructured_repertoire.py new file mode 100644 index 00000000..f4cc0c98 --- /dev/null +++ b/qdax/core/containers/unstructured_repertoire.py @@ -0,0 +1,459 @@ +from __future__ import annotations + +from functools import partial +from typing import Callable, Tuple + +import flax.struct +import jax +import jax.numpy as jnp +from jax.flatten_util import ravel_pytree + +from qdax.types import Centroid, Descriptor, Fitness, Genotype, Observation, RNGKey + + +@partial(jax.jit, static_argnames=("k_nn",)) +def get_cells_indices( + batch_of_descriptors: Descriptor, centroids: Centroid, k_nn: int +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Returns the array of cells indices for a batch of descriptors + given the centroids of the grid. + + Args: + batch_of_descriptors: a batch of descriptors + of shape (batch_size, num_descriptors) + centroids: centroids array of shape (num_centroids, num_descriptors) + + Returns: + the indices of the centroids corresponding to each vector of descriptors + in the batch with shape (batch_size,) + """ + + def _get_cells_indices( + _descriptors: jnp.ndarray, + _centroids: jnp.ndarray, + _k_nn: int, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Inner function. 
+ + descriptors of shape (1, num_descriptors) + centroids of shape (num_centroids, num_descriptors) + """ + + distances = jax.vmap(jnp.linalg.norm)(_descriptors - _centroids) + + # Negating distances because we want the smallest ones + min_dist, min_args = jax.lax.top_k(-1 * distances, _k_nn) + + return min_args, -1 * min_dist + + func = jax.vmap( + _get_cells_indices, + in_axes=( + 0, + None, + None, + ), + ) + + return func(batch_of_descriptors, centroids, k_nn) # type: ignore + + +@jax.jit +def intra_batch_comp( + normed: jnp.ndarray, + current_index: jnp.ndarray, + normed_all: jnp.ndarray, + eval_scores: jnp.ndarray, + l_value: jnp.ndarray, +) -> jnp.ndarray: + """Function to know if an individual should be kept or not.""" + + # Check for individuals that are Nans, we remove them at the end + not_existent = jnp.where((jnp.isnan(normed)).any(), True, False) + + # Fill in Nans to do computations + normed = jnp.where(jnp.isnan(normed), jnp.full(normed.shape[-1], jnp.inf), normed) + eval_scores = jnp.where( + jnp.isinf(eval_scores), jnp.full(eval_scores.shape[-1], jnp.nan), eval_scores + ) + + # If we do not use a fitness (i.e same fitness everywhere), we create a virtual + # fitness function to add individuals with the same bd + additional_score = jnp.where( + jnp.nanmax(eval_scores) == jnp.nanmin(eval_scores), 1.0, 0.0 + ) + additional_scores = jnp.linspace(0.0, additional_score, num=eval_scores.shape[0]) + + # Add scores to empty individuals + eval_scores = jnp.where( + jnp.isnan(eval_scores), jnp.full(eval_scores.shape[0], -jnp.inf), eval_scores + ) + # Virtual eval_scores + eval_scores = eval_scores + additional_scores + + # For each point we check what other points are the closest ones. + knn_relevant_scores, knn_relevant_indices = jax.lax.top_k( + -1 * jax.vmap(jnp.linalg.norm)(normed - normed_all), eval_scores.shape[0] + ) + # We negated the scores to use top_k so we reverse it. + knn_relevant_scores = knn_relevant_scores * -1 + + # Check if the individual is close enough to compare (under l-value) + fitness = jnp.where(jnp.squeeze(knn_relevant_scores < l_value), True, False) + + # We want to eliminate the same individual (distance 0) + fitness = jnp.where(knn_relevant_indices == current_index, False, fitness) + current_fitness = jnp.squeeze( + eval_scores.at[knn_relevant_indices.at[0].get()].get() + ) + + # Is the fitness of the other individual higher? + # If both are True then we discard the current individual since this individual + # would be replaced by the better one. + discard_indiv = jnp.logical_and( + jnp.where( + eval_scores.at[knn_relevant_indices].get() > current_fitness, True, False + ), + fitness, + ).any() + + # Discard Individuals with Nans as their BD (mainly for the readdition where we + # have NaN bds) + discard_indiv = jnp.logical_or(discard_indiv, not_existent) + + # Negate to know if we keep the individual + return jnp.logical_not(discard_indiv) + + +class UnstructuredRepertoire(flax.struct.PyTreeNode): + """ + Class for the unstructured repertoire in Map Elites. + + Args: + genotypes: a PyTree containing all the genotypes in the repertoire ordered + by the centroids. Each leaf has a shape (num_centroids, num_features). The + PyTree can be a simple Jax array or a more complex nested structure such + as to represent parameters of neural network in Flax. + fitnesses: an array that contains the fitness of solutions in each cell of the + repertoire, ordered by centroids. The array shape is (num_centroids,). 
+ descriptors: an array that contains the descriptors of solutions in each cell + of the repertoire, ordered by centroids. The array shape + is (num_centroids, num_descriptors). + centroids: an array the contains the centroids of the tesselation. The array + shape is (num_centroids, num_descriptors). + observations: observations that the genotype gathered in the environment. + """ + + genotypes: Genotype + fitnesses: Fitness + descriptors: Descriptor + observations: Observation + l_value: jnp.ndarray + max_size: int = flax.struct.field(pytree_node=False) + + def get_maximal_size(self) -> int: + """Returns the maximal number of individuals in the repertoire.""" + return self.max_size + + def get_number_genotypes(self) -> jnp.ndarray: + """Returns the number of genotypes in the repertoire.""" + return jnp.sum(self.fitnesses != -jnp.inf) + + def save(self, path: str = "./") -> None: + """Saves the grid on disk in the form of .npy files. + + Flattens the genotypes to store it with .npy format. Supposes that + a user will have access to the reconstruction function when loading + the genotypes. + + Args: + path: Path where the data will be saved. Defaults to "./". + """ + + def flatten_genotype(genotype: Genotype) -> jnp.ndarray: + flatten_genotype, _unravel_pytree = ravel_pytree(genotype) + return flatten_genotype + + # flatten all the genotypes + flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes) + + # save data + jnp.save(path + "genotypes.npy", flat_genotypes) + jnp.save(path + "fitnesses.npy", self.fitnesses) + jnp.save(path + "descriptors.npy", self.descriptors) + jnp.save(path + "observations.npy", self.observations) + jnp.save(path + "l_value.npy", self.l_value) + jnp.save(path + "max_size.npy", self.max_size) + + @classmethod + def load( + cls, reconstruction_fn: Callable, path: str = "./" + ) -> UnstructuredRepertoire: + """Loads an unstructured repertoire. + + Args: + reconstruction_fn: Function to reconstruct a PyTree + from a flat array. + path: Path where the data is saved. Defaults to "./". + + Returns: + An unstructured repertoire. + """ + + flat_genotypes = jnp.load(path + "genotypes.npy") + genotypes = jax.vmap(reconstruction_fn)(flat_genotypes) + + fitnesses = jnp.load(path + "fitnesses.npy") + descriptors = jnp.load(path + "descriptors.npy") + observations = jnp.load(path + "observations.npy") + l_value = jnp.load(path + "l_value.npy") + max_size = int(jnp.load(path + "max_size.npy").item()) + + return UnstructuredRepertoire( + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + observations=observations, + l_value=l_value, + max_size=max_size, + ) + + @jax.jit + def add( + self, + batch_of_genotypes: Genotype, + batch_of_descriptors: Descriptor, + batch_of_fitnesses: Fitness, + batch_of_observations: Observation, + ) -> UnstructuredRepertoire: + """Adds a batch of genotypes to the repertoire. + + Args: + batch_of_genotypes: genotypes of the individuals to be considered + for addition in the repertoire. + batch_of_descriptors: associated descriptors. + batch_of_fitnesses: associated fitness. + batch_of_observations: associated observations. + + Returns: + A new unstructured repertoire where the relevant individuals have been + added. 
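+
+        Note:
+            As implemented below, a candidate is placed in a free slot when it
+            is further than ``l_value`` from its nearest neighbour in the
+            repertoire, and otherwise competes on fitness with that neighbour.
+            Candidates that lie within ``l_value`` of their second nearest
+            neighbour, or that are dominated by a batch-mate closer than
+            ``l_value``, are discarded.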
+ """ + + # We need to replace all the descriptors that are not filled with jnp inf + filtered_descriptors = jnp.where( + jnp.expand_dims((self.fitnesses == -jnp.inf), axis=-1), + jnp.full(self.descriptors.shape[-1], fill_value=jnp.inf), + self.descriptors, + ) + + batch_of_indices, batch_of_distances = get_cells_indices( + batch_of_descriptors, filtered_descriptors, 2 + ) + + # Save the second-nearest neighbours to check a condition + second_neighbours = batch_of_distances.at[..., 1].get() + + # Keep the Nearest neighbours + batch_of_indices = batch_of_indices.at[..., 0].get() + + # Keep the Nearest neighbours + batch_of_distances = batch_of_distances.at[..., 0].get() + + # We remove individuals that are too close to the second nn. + # This avoids having clusters of individuals after adding them. + not_novel_enough = jnp.where( + jnp.squeeze(second_neighbours <= self.l_value), True, False + ) + + # batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1) + batch_of_fitnesses = jnp.expand_dims(batch_of_fitnesses, axis=-1) + batch_of_observations = jnp.expand_dims(batch_of_observations, axis=-1) + + # TODO: Doesn't Work if Archive is full. Need to use the closest individuals + # in that case. + empty_indexes = jnp.squeeze( + jnp.nonzero( + jnp.where(jnp.isinf(self.fitnesses), 1, 0), + size=batch_of_indices.shape[0], + fill_value=-1, + )[0] + ) + batch_of_indices = jnp.where( + jnp.squeeze(batch_of_distances <= self.l_value), + jnp.squeeze(batch_of_indices), + -1, + ) + + # We get all the indices of the empty bds first and then the filled ones + # (because of -1) + sorted_bds = jax.lax.top_k( + -1 * batch_of_indices.squeeze(), batch_of_indices.shape[0] + )[1] + batch_of_indices = jnp.where( + jnp.squeeze(batch_of_distances.at[sorted_bds].get() <= self.l_value), + batch_of_indices.at[sorted_bds].get(), + empty_indexes, + ) + + batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1) + + # ReIndexing of all the inputs to the correct sorted way + batch_of_descriptors = batch_of_descriptors.at[sorted_bds].get() + batch_of_genotypes = jax.tree_map( + lambda x: x.at[sorted_bds].get(), batch_of_genotypes + ) + batch_of_fitnesses = batch_of_fitnesses.at[sorted_bds].get() + batch_of_observations = batch_of_observations.at[sorted_bds].get() + not_novel_enough = not_novel_enough.at[sorted_bds].get() + + # Check to find Individuals with same BD within the Batch + keep_indiv = jax.jit( + jax.vmap(intra_batch_comp, in_axes=(0, 0, None, None, None), out_axes=(0)) + )( + batch_of_descriptors.squeeze(), + jnp.arange( + 0, batch_of_descriptors.shape[0], 1 + ), # keep track of where we are in the batch to assure right comparisons + batch_of_descriptors.squeeze(), + batch_of_fitnesses.squeeze(), + self.l_value, + ) + + keep_indiv = jnp.logical_and(keep_indiv, jnp.logical_not(not_novel_enough)) + + # get fitness segment max + best_fitnesses = jax.ops.segment_max( + batch_of_fitnesses, + batch_of_indices.astype(jnp.int32).squeeze(), + num_segments=self.max_size, + ) + + cond_values = jnp.take_along_axis(best_fitnesses, batch_of_indices, 0) + + # put dominated fitness to -jnp.inf + batch_of_fitnesses = jnp.where( + batch_of_fitnesses == cond_values, x=batch_of_fitnesses, y=-jnp.inf + ) + + # get addition condition + grid_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1) + current_fitnesses = jnp.take_along_axis(grid_fitnesses, batch_of_indices, 0) + addition_condition = batch_of_fitnesses > current_fitnesses + addition_condition = jnp.logical_and( + addition_condition, jnp.expand_dims(keep_indiv, 
axis=-1) + ) + + # assign fake position when relevant : num_centroids is out of bounds + batch_of_indices = jnp.where( + addition_condition, + x=batch_of_indices, + y=self.max_size, + ) + + # create new grid + new_grid_genotypes = jax.tree_map( + lambda grid_genotypes, new_genotypes: grid_genotypes.at[ + batch_of_indices.squeeze() + ].set(new_genotypes), + self.genotypes, + batch_of_genotypes, + ) + + # compute new fitness and descriptors + new_fitnesses = self.fitnesses.at[batch_of_indices.squeeze()].set( + batch_of_fitnesses.squeeze() + ) + new_descriptors = self.descriptors.at[batch_of_indices.squeeze()].set( + batch_of_descriptors.squeeze() + ) + + new_observations = self.observations.at[batch_of_indices.squeeze()].set( + batch_of_observations.squeeze() + ) + + return UnstructuredRepertoire( + genotypes=new_grid_genotypes, + fitnesses=new_fitnesses.squeeze(), + descriptors=new_descriptors.squeeze(), + observations=new_observations.squeeze(), + l_value=self.l_value, + max_size=self.max_size, + ) + + @partial(jax.jit, static_argnames=("num_samples",)) + def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]: + """Sample elements in the repertoire. + + Args: + random_key: a jax PRNG random key + num_samples: the number of elements to be sampled + + Returns: + samples: a batch of genotypes sampled in the repertoire + random_key: an updated jax PRNG random key + """ + + random_key, sub_key = jax.random.split(random_key) + grid_empty = self.fitnesses == -jnp.inf + p = (1.0 - grid_empty) / jnp.sum(1.0 - grid_empty) + + samples = jax.tree_map( + lambda x: jax.random.choice(sub_key, x, shape=(num_samples,), p=p), + self.genotypes, + ) + + return samples, random_key + + @classmethod + def init( + cls, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + observations: Observation, + l_value: jnp.ndarray, + max_size: int, + ) -> UnstructuredRepertoire: + """Initialize a Map-Elites repertoire with an initial population of genotypes. + Requires the definition of centroids that can be computed with any method + such as CVT or Euclidean mapping. + + Args: + genotypes: initial genotypes, pytree in which leaves + have shape (batch_size, num_features) + fitnesses: fitness of the initial genotypes of shape (batch_size,) + descriptors: descriptors of the initial genotypes + of shape (batch_size, num_descriptors) + observations: observations experienced in the evaluation task. + l_value: threshold distance of the repertoire. + max_size: maximal size of the container + + Returns: + an initialized unstructured repertoire. 
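For reference, a minimal initialisation sketch, assuming the fitnesses, descriptors and observations come from a first evaluation of the initial genotypes (all shapes are illustrative):

```python
import jax.numpy as jnp

repertoire = UnstructuredRepertoire.init(
    genotypes=jnp.ones((100, 10)),         # 100 genotypes of 10 parameters
    fitnesses=jnp.zeros(100),
    descriptors=jnp.zeros((100, 2)),       # 2-D (unsupervised) descriptors
    observations=jnp.zeros((100, 25, 6)),  # traces of 25 steps, 6 features
    l_value=jnp.asarray(0.2),              # novelty threshold
    max_size=500,                          # maximal number of stored solutions
)
```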
+ """ + + # Initialize grid with default values + default_fitnesses = -jnp.inf * jnp.ones(shape=max_size) + default_genotypes = jax.tree_map( + lambda x: jnp.full(shape=(max_size,) + x.shape[1:], fill_value=jnp.nan), + genotypes, + ) + default_descriptors = jnp.zeros(shape=(max_size, descriptors.shape[-1])) + + default_observations = jnp.full( + shape=(max_size,) + observations.shape[1:], fill_value=jnp.nan + ) + + repertoire = UnstructuredRepertoire( + genotypes=default_genotypes, + fitnesses=default_fitnesses, + descriptors=default_descriptors, + observations=default_observations, + l_value=l_value, + max_size=max_size, + ) + + return repertoire.add( # type: ignore + genotypes, descriptors, fitnesses, observations + ) diff --git a/qdax/core/neuroevolution/networks/seq2seq_networks.py b/qdax/core/neuroevolution/networks/seq2seq_networks.py new file mode 100644 index 00000000..ea7618ba --- /dev/null +++ b/qdax/core/neuroevolution/networks/seq2seq_networks.py @@ -0,0 +1,202 @@ +"""seq2seq example: Mode code. + +Inspired by Flax library - +https://github.com/google/flax/blob/main/examples/seq2seq/models.py + +Copyright 2022 The Flax Authors. +Licensed under the Apache License, Version 2.0 (the "License") +""" + + +import functools +from typing import Any, Tuple + +import jax +import jax.numpy as jnp +import numpy as np +from flax import linen as nn + +Array = Any +PRNGKey = Any + + +class EncoderLSTM(nn.Module): + """EncoderLSTM Module wrapped in a lifted scan transform.""" + + @functools.partial( + nn.scan, + variable_broadcast="params", + in_axes=1, + out_axes=1, + split_rngs={"params": False}, + ) + @nn.compact + def __call__( + self, carry: Tuple[Array, Array], x: Array + ) -> Tuple[Tuple[Array, Array], Array]: + """Applies the module.""" + lstm_state, is_eos = carry + features = lstm_state[0].shape[-1] + new_lstm_state, y = nn.LSTMCell(features)(lstm_state, x) + + def select_carried_state(new_state: Array, old_state: Array) -> Array: + return jnp.where(is_eos[:, np.newaxis], old_state, new_state) + + # LSTM state is a tuple (c, h). + carried_lstm_state = tuple( + select_carried_state(*s) for s in zip(new_lstm_state, lstm_state) + ) + + return (carried_lstm_state, is_eos), y + + @staticmethod + def initialize_carry(batch_size: int, hidden_size: int) -> Tuple[Array, Array]: + # Use a dummy key since the default state init fn is just zeros. + return nn.LSTMCell(hidden_size, parent=None).initialize_carry( # type: ignore + jax.random.PRNGKey(0), (batch_size, hidden_size) + ) + + +class Encoder(nn.Module): + """LSTM encoder, returning state after finding the EOS token in the input.""" + + hidden_size: int + + @nn.compact + def __call__(self, inputs: Array) -> Array: + batch_size = inputs.shape[0] + lstm = EncoderLSTM(name="encoder_lstm") + init_lstm_state = lstm.initialize_carry(batch_size, self.hidden_size) + + # We use the `is_eos` array to determine whether the encoder should carry + # over the last lstm state, or apply the LSTM cell on the previous state. + init_is_eos = jnp.zeros(batch_size, dtype=bool) + init_carry = (init_lstm_state, init_is_eos) + (final_state, _), _ = lstm(init_carry, inputs) + + return final_state + + +class DecoderLSTM(nn.Module): + """DecoderLSTM Module wrapped in a lifted scan transform. + + Attributes: + teacher_force: See docstring on Seq2seq module. + obs_size: Size of the observations. 
+ """ + + teacher_force: bool + obs_size: int + + @functools.partial( + nn.scan, + variable_broadcast="params", + in_axes=1, + out_axes=1, + split_rngs={"params": False, "lstm": True}, + ) + @nn.compact + def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array: + """Applies the DecoderLSTM model.""" + + lstm_state, last_prediction = carry + if not self.teacher_force: + x = last_prediction + + features = lstm_state[0].shape[-1] + new_lstm_state, y = nn.LSTMCell(features)(lstm_state, x) + + logits = nn.Dense(features=self.obs_size)(y) + + return (lstm_state, logits), (logits, logits) + + +class Decoder(nn.Module): + """LSTM decoder. + + Attributes: + init_state: [batch_size, hidden_size] + Initial state of the decoder (i.e., the final state of the encoder). + teacher_force: See docstring on Seq2seq module. + obs_size: Size of the observations. + """ + + teacher_force: bool + obs_size: int + + @nn.compact + def __call__(self, inputs: Array, init_state: Any) -> Tuple[Array, Array]: + """Applies the decoder model. + + Args: + inputs: [batch_size, max_output_len-1, obs_size] + Contains the inputs to the decoder at each time step (only used when not + using teacher forcing). Since each token at position i is fed as input + to the decoder at position i+1, the last token is not provided. + + Returns: + Pair (logits, predictions), which are two arrays of respectively decoded + logits and predictions (in one hot-encoding format). + """ + lstm = DecoderLSTM(teacher_force=self.teacher_force, obs_size=self.obs_size) + init_carry = (init_state, inputs[:, 0]) + _, (logits, predictions) = lstm(init_carry, inputs) + return logits, predictions + + +class Seq2seq(nn.Module): + """Sequence-to-sequence class using encoder/decoder architecture. + + Attributes: + teacher_force: whether to use `decoder_inputs` as input to the decoder at + every step. If False, only the first input (i.e., the "=" token) is used, + followed by samples taken from the previous output logits. + hidden_size: int, the number of hidden dimensions in the encoder and decoder + LSTMs. + obs_size: the size of the observations. + eos_id: EOS id. + """ + + teacher_force: bool + hidden_size: int + obs_size: int + + def setup(self) -> None: + self.encoder = Encoder(hidden_size=self.hidden_size) + self.decoder = Decoder(teacher_force=self.teacher_force, obs_size=self.obs_size) + + @nn.compact + def __call__( + self, encoder_inputs: Array, decoder_inputs: Array + ) -> Tuple[Array, Array]: + """Applies the seq2seq model. + + Args: + encoder_inputs: [batch_size, max_input_length, obs_size]. + padded batch of input sequences to encode. + decoder_inputs: [batch_size, max_output_length, obs_size]. + padded batch of expected decoded sequences for teacher forcing. + When sampling (i.e., `teacher_force = False`), only the first token is + input into the decoder (which is the token "="), and samples are used + for the following inputs. The second dimension of this tensor determines + how many steps will be decoded, regardless of the value of + `teacher_force`. + + Returns: + Pair (logits, predictions), which are two arrays of length `batch_size` + containing respectively decoded logits and predictions (in one hot + encoding format). 
+ """ + # encode inputs + init_decoder_state = self.encoder(encoder_inputs) + + # decode outputs + logits, predictions = self.decoder(decoder_inputs, init_decoder_state) + + return logits, predictions + + def encode(self, encoder_inputs: Array) -> Array: + # encode inputs + init_decoder_state = self.encoder(encoder_inputs) + final_output, _hidden_state = init_decoder_state + return final_output diff --git a/qdax/environments/bd_extractors.py b/qdax/environments/bd_extractors.py index 4ec159c6..af1d51ba 100644 --- a/qdax/environments/bd_extractors.py +++ b/qdax/environments/bd_extractors.py @@ -1,8 +1,11 @@ +from __future__ import annotations + +import flax.struct import jax import jax.numpy as jnp from qdax.core.neuroevolution.buffers.buffer import QDTransition -from qdax.types import Descriptor +from qdax.types import Descriptor, Params def get_final_xy_position(data: QDTransition, mask: jnp.ndarray) -> Descriptor: @@ -36,3 +39,66 @@ def get_feet_contact_proportion(data: QDTransition, mask: jnp.ndarray) -> Descri descriptors = descriptors / jnp.sum(1.0 - mask, axis=1) return descriptors + + +class AuroraExtraInfo(flax.struct.PyTreeNode): + """ + Information specific to the AURORA algorithm. + + Args: + model_params: the parameters of the dimensionality reduction model + """ + + model_params: Params + + +class AuroraExtraInfoNormalization(AuroraExtraInfo): + """ + Information specific to the AURORA algorithm. In particular, it contains + the normalization parameters for the observations. + + Args: + model_params: the parameters of the dimensionality reduction model + mean_observations: the mean of observations + std_observations: the std of observations + """ + + mean_observations: jnp.ndarray + std_observations: jnp.ndarray + + @classmethod + def create( + cls, + model_params: Params, + mean_observations: jnp.ndarray, + std_observations: jnp.ndarray, + ) -> AuroraExtraInfoNormalization: + return cls( + model_params=model_params, + mean_observations=mean_observations, + std_observations=std_observations, + ) + + +def get_aurora_encoding( + observations: jnp.ndarray, + aurora_extra_info: AuroraExtraInfoNormalization, + model: flax.linen.Module, +) -> Descriptor: + """ + Compute final aurora embedding. + + This function suppose that state descriptor is the xy position, as it + just select the final one of the state descriptors given. 
+ """ + model_params = aurora_extra_info.model_params + mean_observations = aurora_extra_info.mean_observations + std_observations = aurora_extra_info.std_observations + + # lstm seq2seq + normalized_observations = (observations - mean_observations) / std_observations + descriptors = model.apply( + {"params": model_params}, normalized_observations, method=model.encode + ) + + return descriptors.squeeze() diff --git a/qdax/environments/exploration_wrappers.py b/qdax/environments/exploration_wrappers.py index fbffad5f..ec32e7a2 100644 --- a/qdax/environments/exploration_wrappers.py +++ b/qdax/environments/exploration_wrappers.py @@ -85,11 +85,16 @@ } """ +try: + HALFCHEETAH_SYSTEM_CONFIG = brax.envs.halfcheetah._SYSTEM_CONFIG +except AttributeError: + HALFCHEETAH_SYSTEM_CONFIG = brax.envs.half_cheetah._SYSTEM_CONFIG + # storing the classic env configurations # those are the configs from the official brax repo ENV_SYSTEM_CONFIG = { "ant": brax.envs.ant._SYSTEM_CONFIG, - "halfcheetah": brax.envs.half_cheetah._SYSTEM_CONFIG, + "halfcheetah": HALFCHEETAH_SYSTEM_CONFIG, "walker2d": brax.envs.walker2d._SYSTEM_CONFIG, "hopper": brax.envs.hopper._SYSTEM_CONFIG, # "humanoid": brax.envs.humanoid._SYSTEM_CONFIG, diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index 3c84ccef..931ee9d3 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -9,7 +9,7 @@ import qdax.environments from qdax import environments -from qdax.core.neuroevolution.buffers.buffer import QDTransition +from qdax.core.neuroevolution.buffers.buffer import QDTransition, Transition from qdax.core.neuroevolution.mdp_utils import generate_unroll from qdax.core.neuroevolution.networks.networks import MLP from qdax.types import ( @@ -18,6 +18,7 @@ ExtraScores, Fitness, Genotype, + Observation, Params, RNGKey, ) @@ -82,6 +83,15 @@ def default_play_step_fn( return default_play_step_fn +def get_mask_from_transitions( + data: Transition, +) -> jnp.ndarray: + is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) + mask = jnp.roll(is_done, 1, axis=1) + mask = mask.at[:, 0].set(0) + return mask + + @partial( jax.jit, static_argnames=( @@ -134,9 +144,7 @@ def scoring_function_brax_envs( _final_state, data = jax.vmap(unroll_fn)(init_states, policies_params) # create a mask to extract data properly - is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) - mask = jnp.roll(is_done, 1, axis=1) - mask = mask.at[:, 0].set(0) + mask = get_mask_from_transitions(data) # scores fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) @@ -267,10 +275,10 @@ def create_brax_scoring_fn( init_state = env.reset(subkey) # Define the function to deterministically reset the environment - def deterministic_reset(key: RNGKey, init_state: EnvState) -> EnvState: - return init_state + def deterministic_reset(_: RNGKey, _init_state: EnvState) -> EnvState: + return _init_state - play_reset_fn = partial(deterministic_reset, init_state=init_state) + play_reset_fn = partial(deterministic_reset, _init_state=init_state) # Stochastic case elif play_reset_fn is None: @@ -340,3 +348,36 @@ def create_default_brax_task_components( ) return env, policy_network, scoring_fn, random_key + + +def get_aurora_scoring_fn( + scoring_fn: Callable[ + [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey] + ], + observation_extractor_fn: Callable[[Transition], Observation], +) -> Callable[ + [Genotype, RNGKey], Tuple[Fitness, Optional[Descriptor], ExtraScores, RNGKey] +]: + """Evaluates policies contained in flatten_variables in parallel + + This 
rollout is only deterministic when all the init states are the same. + If the init states are fixed but different, as a policy is not necessarly + evaluated with the same environment everytime, this won't be determinist. + + When the init states are different, this is not purely stochastic. This + choice was made for performance reason, as the reset function of brax envs + is quite time-consuming. If pure stochasticity of the environment is needed + for a use case, please open an issue. + """ + + @functools.wraps(scoring_fn) + def _wrapper( + params: Params, random_key: RNGKey # Perform rollouts with each policy + ) -> Tuple[Fitness, Optional[Descriptor], ExtraScores, RNGKey]: + fitnesses, _, extra_scores, random_key = scoring_fn(params, random_key) + data = extra_scores["transitions"] + observation = observation_extractor_fn(data) # type: ignore + extra_scores["last_valid_observations"] = observation + return fitnesses, None, extra_scores, random_key + + return _wrapper diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py new file mode 100644 index 00000000..acb14a9b --- /dev/null +++ b/qdax/utils/train_seq2seq.py @@ -0,0 +1,208 @@ +"""seq2seq addition example + +Inspired by Flax library - +https://github.com/google/flax/blob/main/examples/seq2seq/train.py + +Copyright 2022 The Flax Authors. +Licensed under the Apache License, Version 2.0 (the "License") +""" + +from typing import Any, Dict, Tuple + +import jax +import jax.numpy as jnp +import optax +from flax.training import train_state + +from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire +from qdax.core.neuroevolution.networks.seq2seq_networks import Seq2seq +from qdax.environments.bd_extractors import AuroraExtraInfoNormalization +from qdax.types import Params, RNGKey + +Array = Any +PRNGKey = Any + + +def get_model( + obs_size: int, teacher_force: bool = False, hidden_size: int = 10 +) -> Seq2seq: + """ + Returns a seq2seq model. + + Args: + obs_size: the size of the observation. + teacher_force: whether to use teacher forcing. + hidden_size: the size of the hidden layer (i.e. the encoding). + """ + return Seq2seq( + teacher_force=teacher_force, hidden_size=hidden_size, obs_size=obs_size + ) + + +def get_initial_params( + model: Seq2seq, random_key: PRNGKey, encoder_input_shape: Tuple[int, ...] +) -> Dict[str, Any]: + """ + Returns the initial parameters of a seq2seq model. + + Args: + model: the seq2seq model. + random_key: the random number generator. + encoder_input_shape: the shape of the encoder input. + """ + random_key, rng1, rng2, rng3 = jax.random.split(random_key, 4) + variables = model.init( + {"params": rng1, "lstm": rng2, "dropout": rng3}, + jnp.ones(encoder_input_shape, jnp.float32), + jnp.ones(encoder_input_shape, jnp.float32), + ) + return variables["params"] # type: ignore + + +@jax.jit +def train_step( + state: train_state.TrainState, + batch: Array, + lstm_random_key: PRNGKey, +) -> Tuple[train_state.TrainState, Dict[str, float]]: + """ + Trains for one step. + + Args: + state: the training state. + batch: the batch of data. + lstm_random_key: the random number key. 
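Putting the pieces above together, one training step could be sketched as follows (hypothetical sizes; `lstm_ae_train` below performs essentially this loop over the repertoire's observations):

```python
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

# Observation traces of 25 steps with 6 features (illustrative).
model = get_model(obs_size=6, teacher_force=True, hidden_size=10)
random_key = jax.random.PRNGKey(0)
params = get_initial_params(model, random_key, encoder_input_shape=(1, 25, 6))

state = train_state.TrainState.create(
    apply_fn=model.apply, params=params, tx=optax.adam(1e-4)
)

# One gradient step on a batch of observation sequences.
batch = jnp.ones((128, 25, 6))
state, loss_val = train_step(state, batch, random_key)
```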
+ """ + + """Trains one step.""" + lstm_key = jax.random.fold_in(lstm_random_key, state.step) + dropout_key, lstm_key = jax.random.split(lstm_key, 2) + + # Shift input by one to avoid leakage + batch_decoder = jnp.roll(batch, shift=1, axis=1) + + # Large number as zero token + batch_decoder = batch_decoder.at[:, 0, :].set(-1000) + + def loss_fn(params: Params) -> Tuple[jnp.ndarray, jnp.ndarray]: + logits, _ = state.apply_fn( + {"params": params}, + batch, + batch_decoder, + rngs={"lstm": lstm_key, "dropout": dropout_key}, + ) + + def mean_squared_error(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + return jnp.inner(y - x, y - x) / x.shape[-1] + + res = jax.vmap(mean_squared_error)( + jnp.reshape(logits.at[:, :-1, ...].get(), (logits.shape[0], -1)), + jnp.reshape( + batch_decoder.at[:, 1:, ...].get(), (batch_decoder.shape[0], -1) + ), + ) + loss = jnp.mean(res, axis=0) + return loss, logits + + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (loss_val, _logits), grads = grad_fn(state.params) + state = state.apply_gradients(grads=grads) + + return state, loss_val + + +def lstm_ae_train( + random_key: RNGKey, + repertoire: UnstructuredRepertoire, + params: Params, + epoch: int, + model: Seq2seq, + batch_size: int = 128, +) -> AuroraExtraInfoNormalization: + + if epoch > 100: + num_epochs = 25 + alpha = 0.0001 # Gradient step size + else: + num_epochs = 100 + alpha = 0.0001 + + # compute mean/std of the obs for normalization + mean_obs = jnp.nanmean(repertoire.observations, axis=(0, 1)) + std_obs = jnp.nanstd(repertoire.observations, axis=(0, 1)) + # the std where they were NaNs was set to zero. But here we divide by the + # std, so we replace the zeros by inf here. + std_obs = jnp.where(std_obs == 0, x=jnp.inf, y=std_obs) + + # TODO: maybe we could just compute this data on the valid dataset + + # create optimizer and optimized state + tx = optax.adam(alpha) + state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx) + + ########################################################################### + # Shuffling indexes of valid individuals in the repertoire + ########################################################################### + + # size of the repertoire + repertoire_size = repertoire.max_size + + # number of individuals in the repertoire + num_indivs = repertoire.get_number_genotypes() + + # select repertoire_size indexes going from 0 to the total number of + # valid individuals. Those indexes will be used to select the individuals + # in the training dataset. + random_key, key_select_p1 = jax.random.split(random_key, 2) + idx_p1 = jax.random.randint( + key_select_p1, shape=(repertoire_size,), minval=0, maxval=num_indivs + ) + + # get indexes where fitness is not -inf. Those are the valid individuals. 
+ indexes = jnp.argwhere( + jnp.logical_not(jnp.isinf(repertoire.fitnesses)), size=repertoire_size + ) + indexes = jnp.transpose(indexes, axes=(1, 0)) + + # get corresponding indices for the flattened repertoire fitnesses + indiv_indices = jnp.array( + jnp.ravel_multi_index(indexes, repertoire.fitnesses.shape, mode="clip") + ).astype(int) + + # filter those indices to get only the indices of valid individuals + valid_indexes = indiv_indices.at[idx_p1].get() + + # Normalising Dataset + steps_per_epoch = repertoire.observations.shape[0] // batch_size + + loss_val = 0.0 + for epoch in range(num_epochs): + random_key, shuffle_key = jax.random.split(random_key, 2) + valid_indexes = jax.random.permutation(shuffle_key, valid_indexes, axis=0) + + # create dataset with the observation from the sample of valid indexes + training_dataset = ( + repertoire.observations.at[valid_indexes, ...].get() - mean_obs + ) / std_obs + training_dataset = training_dataset.at[valid_indexes].get() + + for i in range(steps_per_epoch): + batch = jnp.asarray( + training_dataset.at[ + (i * batch_size) : (i * batch_size) + batch_size, :, : + ].get() + ) + + if batch.shape[0] < batch_size: + # print(batch.shape) + continue + + state, loss_val = train_step(state, batch, random_key) + + # To see the actual value we cannot jit this function (i.e. the _one_es_epoch + # function nor the train function) + print("Eval epoch: {}, loss: {:.4f}".format(epoch + 1, loss_val)) + + params = state.params + + return AuroraExtraInfoNormalization.create(params, mean_obs, std_obs) diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py new file mode 100644 index 00000000..2b238237 --- /dev/null +++ b/tests/core_test/aurora_test.py @@ -0,0 +1,250 @@ +"""Tests AURORA implementation""" + +import functools +from typing import Tuple + +import brax.envs +import jax +import jax.numpy as jnp +import pytest + +from qdax import environments +from qdax.core.aurora import AURORA +from qdax.core.neuroevolution.buffers.buffer import QDTransition +from qdax.environments.bd_extractors import ( + AuroraExtraInfoNormalization, + get_aurora_encoding, +) +from qdax.tasks.brax_envs import ( + create_default_brax_task_components, + get_aurora_scoring_fn, +) +from qdax.types import Observation +from qdax.utils import train_seq2seq +from qdax.utils.metrics import default_qd_metrics +from tests.core_test.map_elites_test import get_mixing_emitter + + +def get_observation_dims( + observation_option: str, + env: brax.envs.Env, + max_observation_size: int, + episode_length: int, + traj_sampling_freq: int, + prior_descriptor_dim: int, +) -> Tuple[int, int]: + obs_dim = jnp.minimum(env.observation_size, max_observation_size) + if observation_option == "full": + observations_dims = ( + episode_length // traj_sampling_freq, + obs_dim + prior_descriptor_dim, + ) + elif observation_option == "no_sd": + observations_dims = ( + episode_length // traj_sampling_freq, + obs_dim, + ) + elif observation_option == "only_sd": + observations_dims = (episode_length // traj_sampling_freq, prior_descriptor_dim) + else: + raise ValueError(f"Unknown observation option: {observation_option}") + + return observations_dims + + +@pytest.mark.parametrize( + "env_name, batch_size", + [("halfcheetah_uni", 10), ("walker2d_uni", 10), ("hopper_uni", 10)], +) +def test_aurora(env_name: str, batch_size: int) -> None: + episode_length = 250 + max_iterations = 5 + seed = 42 + max_size = 50 + + lstm_batch_size = 12 + + observation_option = "no_sd" # "full", "no_sd", "only_sd" + 
hidden_size = 5 + l_value_init = 0.2 + + traj_sampling_freq = 10 + max_observation_size = 25 + prior_descriptor_dim = 2 + + log_freq = 5 + + # Init a random key + random_key = jax.random.PRNGKey(seed) + + # Init environment + env, policy_network, scoring_fn, random_key = create_default_brax_task_components( + env_name=env_name, + random_key=random_key, + ) + + # Init population of controllers + random_key, subkey = jax.random.split(random_key) + keys = jax.random.split(subkey, num=batch_size) + fake_batch = jnp.zeros(shape=(batch_size, env.observation_size)) + init_variables = jax.vmap(policy_network.init)(keys, fake_batch) + + def observation_extractor_fn( + data: QDTransition, + ) -> Observation: + """Extract observation from the state.""" + state_obs = data.obs[:, ::traj_sampling_freq, :max_observation_size] + + # add the x/y position - (batch_size, traj_length, 2) + state_desc = data.state_desc[:, ::traj_sampling_freq] + + if observation_option == "full": + observations = jnp.concatenate([state_desc, state_obs], axis=-1) + elif observation_option == "no_sd": + observations = state_obs + elif observation_option == "only_sd": + observations = state_desc + else: + raise ValueError("Unknown observation option.") + + return observations + + # Prepare the scoring function + aurora_scoring_fn = get_aurora_scoring_fn( + scoring_fn=scoring_fn, + observation_extractor_fn=observation_extractor_fn, + ) + + # Define emitter + mixing_emitter = get_mixing_emitter(batch_size) + + # Get minimum reward value to make sure qd_score are positive + reward_offset = environments.reward_offset[env_name] + + # Define a metrics function + metrics_fn = functools.partial(default_qd_metrics, qd_offset=reward_offset) + + # Init algorithm + # AutoEncoder Params and INIT + observations_dims = get_observation_dims( + observation_option=observation_option, + env=env, + max_observation_size=max_observation_size, + episode_length=episode_length, + traj_sampling_freq=traj_sampling_freq, + prior_descriptor_dim=prior_descriptor_dim, + ) + + # define the seq2seq model + model = train_seq2seq.get_model( + int(observations_dims[-1]), True, hidden_size=hidden_size + ) + + # define the encoder function + encoder_fn = jax.jit( + functools.partial( + get_aurora_encoding, + model=model, + ) + ) + + # define the training function + train_fn = functools.partial( + train_seq2seq.lstm_ae_train, + model=model, + batch_size=lstm_batch_size, + ) + + # Instantiate AURORA algorithm + aurora = AURORA( + scoring_function=aurora_scoring_fn, + emitter=mixing_emitter, + metrics_function=metrics_fn, + encoder_function=encoder_fn, + training_function=train_fn, + ) + + # init the model params + random_key, subkey = jax.random.split(random_key) + model_params = train_seq2seq.get_initial_params( + model, subkey, (1, *observations_dims) + ) + + # define arbitrary observation's mean/std + mean_observations = jnp.zeros(observations_dims[-1]) + std_observations = jnp.ones(observations_dims[-1]) + + # init all the information needed by AURORA to compute encodings + aurora_extra_info = AuroraExtraInfoNormalization.create( + model_params, + mean_observations, + std_observations, + ) + + # init step of the aurora algorithm + repertoire, emitter_state, aurora_extra_info, random_key = aurora.init( + init_variables, + aurora_extra_info, + jnp.asarray(l_value_init), + max_size, + random_key, + ) + + # initializing means and stds and AURORA + random_key, subkey = jax.random.split(random_key) + repertoire, aurora_extra_info = aurora.train( + repertoire, 
model_params, iteration=0, random_key=subkey + ) + + # design aurora's schedule + default_update_base = 10 + update_base = int(jnp.ceil(default_update_base / log_freq)) + schedules = jnp.cumsum(jnp.arange(update_base, 1000, update_base)) + + current_step_estimation = 0 + + ############################ + # Main loop + ############################ + + target_repertoire_size = 1024 + + previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - target_repertoire_size + + iteration = 0 + while iteration < max_iterations: + # standard MAP-Elites-like loop + for _ in range(log_freq): + repertoire, emitter_state, _, random_key = aurora.update( + repertoire, + emitter_state, + random_key, + aurora_extra_info=aurora_extra_info, + ) + + # update nb steps estimation + current_step_estimation += batch_size * episode_length * log_freq + + # autoencoder steps and Container Size Control (CSC) + if (iteration + 1) in schedules: + # train the autoencoder (includes the CSC) + random_key, subkey = jax.random.split(random_key) + repertoire, aurora_extra_info = aurora.train( + repertoire, model_params, iteration, subkey + ) + + elif iteration % 2 == 0: + # only CSC + repertoire, previous_error = aurora.container_size_control( + repertoire, + target_size=target_repertoire_size, + previous_error=previous_error, + ) + + iteration += 1 + + pytest.assume(repertoire is not None) + + +if __name__ == "__main__": + test_aurora(env_name="pointmaze", batch_size=10) diff --git a/tests/core_test/map_elites_test.py b/tests/core_test/map_elites_test.py index 66748079..b532aa65 100644 --- a/tests/core_test/map_elites_test.py +++ b/tests/core_test/map_elites_test.py @@ -1,17 +1,14 @@ """Tests MAP Elites implementation""" import functools -from typing import Dict, Tuple +from typing import Tuple import jax import jax.numpy as jnp import pytest from qdax import environments -from qdax.core.containers.mapelites_repertoire import ( - MapElitesRepertoire, - compute_cvt_centroids, -) +from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids from qdax.core.emitters.mutation_operators import isoline_variation from qdax.core.emitters.standard_emitters import MixingEmitter from qdax.core.map_elites import MAPElites @@ -19,6 +16,19 @@ from qdax.core.neuroevolution.networks.networks import MLP from qdax.tasks.brax_envs import scoring_function_brax_envs from qdax.types import EnvState, Params, RNGKey +from qdax.utils.metrics import default_qd_metrics + + +def get_mixing_emitter(batch_size: int) -> MixingEmitter: + """Create a mixing emitter with a given batch size.""" + variation_fn = functools.partial(isoline_variation, iso_sigma=0.05, line_sigma=0.1) + mixing_emitter = MixingEmitter( + mutation_fn=lambda x, y: (x, y), + variation_fn=variation_fn, + variation_percentage=1.0, + batch_size=batch_size, + ) + return mixing_emitter @pytest.mark.parametrize( @@ -102,29 +112,13 @@ def play_step_fn( ) # Define emitter - variation_fn = functools.partial(isoline_variation, iso_sigma=0.05, line_sigma=0.1) - mixing_emitter = MixingEmitter( - mutation_fn=lambda x, y: (x, y), - variation_fn=variation_fn, - variation_percentage=1.0, - batch_size=batch_size, - ) + mixing_emitter = get_mixing_emitter(batch_size) # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] # Define a metrics function - def metrics_fn(repertoire: MapElitesRepertoire) -> Dict: - - # Get metrics - grid_empty = repertoire.fitnesses == -jnp.inf - qd_score = jnp.sum(repertoire.fitnesses, where=~grid_empty) 
- # Add offset for positive qd_score - qd_score += reward_offset * episode_length * jnp.sum(1.0 - grid_empty) - coverage = 100 * jnp.mean(1.0 - grid_empty) - max_fitness = jnp.max(repertoire.fitnesses) - - return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} + metrics_fn = functools.partial(default_qd_metrics, qd_offset=reward_offset) # Instantiate MAP-Elites map_elites = MAPElites(