diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a1cd2511..d037820e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,7 +24,7 @@ repos: rev: 0.3.9 hooks: - id: nbstripout - args: ["notebooks/"] + args: ["examples/"] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.0.1 hooks: diff --git a/README.md b/README.md index 7fbb061b..0f55d24e 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ QDax is a tool to accelerate Quality-Diversity (QD) and neuro-evolution algorithms through hardware accelerators and massive parallelization. QD algorithms usually take days/weeks to run on large CPU clusters. With QDax, QD algorithms can now be run in minutes! ⏩ ⏩ 🕛 -QDax has been developed as a research framework: it is flexible and easy to extend and build on and can be used for any problem setting. Get started with simple example and run a QD algorithm in minutes here! [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/mapelites_example.ipynb) +QDax has been developed as a research framework: it is flexible and easy to extend and build on and can be used for any problem setting. Get started with simple example and run a QD algorithm in minutes here! [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb) - QDax [paper](https://arxiv.org/abs/2202.01258) - QDax [documentation](https://qdax.readthedocs.io/en/latest/) @@ -32,7 +32,7 @@ Installing QDax via ```pip``` installs a CPU-only version of JAX by default. To However, we also provide and recommend using either Docker, Singularity or conda environments to use the repository which by default provides GPU support. Detailed steps to do so are available in the [documentation](https://qdax.readthedocs.io/en/latest/installation/). ## Basic API Usage -For a full and interactive example to see how QDax works, we recommend starting with the tutorial-style [Colab notebook](./examples/notebooks/mapelites_example.ipynb). It is an example of the MAP-Elites algorithm used to evolve a population of controllers on a chosen Brax environment (Walker by default). +For a full and interactive example to see how QDax works, we recommend starting with the tutorial-style [Colab notebook](./examples/mapelites.ipynb). It is an example of the MAP-Elites algorithm used to evolve a population of controllers on a chosen Brax environment (Walker by default). 
However, a summary of the main API usage is provided below: ```python @@ -124,13 +124,14 @@ QDax currently supports the following algorithms: | Algorithm | Example | | --- | --- | -| [MAP-Elites](https://arxiv.org/abs/1504.04909) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/mapelites_example.ipynb) | -| [CVT MAP-Elites](https://arxiv.org/abs/1610.05729) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/mapelites_example.ipynb) | -| [Policy Gradient Assisted MAP-Elites (PGA-ME)](https://hal.archives-ouvertes.fr/hal-03135723v2/file/PGA_MAP_Elites_GECCO.pdf) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/pgame_example.ipynb) | -| [OMG-MEGA](https://arxiv.org/abs/2106.03894) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/omgmega_example.ipynb) | -| [CMA-MEGA](https://arxiv.org/abs/2106.03894) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/cmamega_example.ipynb) | -| [Multi-Objective MAP-Elites (MOME)](https://arxiv.org/abs/2202.03057) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/mome_example.ipynb) | -| [MAP-Elites Evolution Strategies (MEES)](https://dl.acm.org/doi/pdf/10.1145/3377930.3390217) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/mees_example.ipynb) | +| [MAP-Elites](https://arxiv.org/abs/1504.04909) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb) | +| [CVT MAP-Elites](https://arxiv.org/abs/1610.05729) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb) | +| [Policy Gradient Assisted MAP-Elites (PGA-ME)](https://hal.archives-ouvertes.fr/hal-03135723v2/file/PGA_MAP_Elites_GECCO.pdf) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/pgame.ipynb) | +| [CMA-ME](https://arxiv.org/pdf/1912.02400.pdf) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/cmame.ipynb) | +| [OMG-MEGA](https://arxiv.org/abs/2106.03894) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/omgmega.ipynb) | +| [CMA-MEGA](https://arxiv.org/abs/2106.03894) | [![Open All 
Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/cmamega.ipynb) | +| [Multi-Objective MAP-Elites (MOME)](https://arxiv.org/abs/2202.03057) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mome.ipynb) | +| [MAP-Elites Evolution Strategies (MEES)](https://dl.acm.org/doi/pdf/10.1145/3377930.3390217) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mees.ipynb) | ## QDax baseline algorithms @@ -138,11 +139,11 @@ The QDax library also provides implementations for some useful baseline algorith | Algorithm | Example | | --- | --- | -| [DIAYN](https://arxiv.org/abs/1802.06070) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/diayn_example.ipynb) | -| [DADS](https://arxiv.org/abs/1907.01657) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/dads_example.ipynb) | -| [SMERL](https://arxiv.org/abs/2010.14484) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/smerl_example.ipynb) | -| [NSGA2](https://ieeexplore.ieee.org/document/996017) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/nsga2_spea2_example.ipynb) | -| [SPEA2](https://www.semanticscholar.org/paper/SPEA2%3A-Improving-the-strength-pareto-evolutionary-Zitzler-Laumanns/b13724cb54ae4171916f3f969d304b9e9752a57f) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/nsga2_spea2_example.ipynb) | +| [DIAYN](https://arxiv.org/abs/1802.06070) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/diayn.ipynb) | +| [DADS](https://arxiv.org/abs/1907.01657) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/dads.ipynb) | +| [SMERL](https://arxiv.org/abs/2010.14484) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/smerl.ipynb) | +| [NSGA2](https://ieeexplore.ieee.org/document/996017) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/nsga2_spea2.ipynb) | +| [SPEA2](https://www.semanticscholar.org/paper/SPEA2%3A-Improving-the-strength-pareto-evolutionary-Zitzler-Laumanns/b13724cb54ae4171916f3f969d304b9e9752a57f) | [![Open All 
Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/nsga2_spea2.ipynb) | ## QDax Tasks The QDax library also provides numerous implementations for several standard Quality-Diversity tasks. diff --git a/docs/api_documentation/core/cmame.md b/docs/api_documentation/core/cmame.md new file mode 100644 index 00000000..350ab39d --- /dev/null +++ b/docs/api_documentation/core/cmame.md @@ -0,0 +1,13 @@ +# Covariance Matrix Adaptation MAP Elites (CMAME) + +To create an instance of CMAME, one needs to use an instance of [MAP-Elites](map_elites.md) with the desired CMA Emitter - optimizing, random direction, improvement - detailed below. To use the pool of emitters mechanism, use the CMAPoolEmitter. + +Three emitter types: + +::: qdax.core.emitters.cma_emitter.CMAEmitter +::: qdax.core.emitters.cma_rnd_emitter.CMARndEmitter +::: qdax.core.emitters.cma_opt_emitter.CMAOptimizingEmitter + +Pool of homogeneous emitters: + +::: qdax.core.emitters.cma_pool_emitter.CMAPoolEmitter diff --git a/docs/examples b/docs/examples new file mode 120000 index 00000000..785887f7 --- /dev/null +++ b/docs/examples @@ -0,0 +1 @@ +../examples/ \ No newline at end of file diff --git a/docs/notebooks b/docs/notebooks deleted file mode 120000 index 50699112..00000000 --- a/docs/notebooks +++ /dev/null @@ -1 +0,0 @@ -../examples/notebooks/ \ No newline at end of file diff --git a/examples/cmaes.ipynb b/examples/cmaes.ipynb new file mode 100644 index 00000000..26f8a764 --- /dev/null +++ b/examples/cmaes.ipynb @@ -0,0 +1,308 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "222bbe00", + "metadata": {}, + "source": [ + "# Optimizing with CMA-ES in Jax\n", + "\n", + "This notebook shows how to use QDax to find high-performing parameters on the Rastrigin and Sphere problems with [CMA-ES](https://arxiv.org/pdf/1604.00772.pdf). It can be run locally or on Google Colab. We recommend using a GPU. 
This notebook will show:\n", + "\n", + "- how to define the problem\n", + "- how to create a CMA-ES optimizer\n", + "- how to launch a certain number of optimizing steps\n", + "- how to visualise the optimization process" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d731f067", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.patches import Ellipse\n", + "\n", + "from qdax.core.cmaes import CMAES" + ] + }, + { + "cell_type": "markdown", + "id": "7b6e910b", + "metadata": {}, + "source": [ + "## Set the hyperparameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "404fb0dc", + "metadata": {}, + "outputs": [], + "source": [ + "#@title Hyperparameters\n", + "#@markdown ---\n", + "num_iterations = 1000 #@param {type:\"integer\"}\n", + "num_dimensions = 100 #@param {type:\"integer\"}\n", + "batch_size = 36 #@param {type:\"integer\"}\n", + "num_best = 18 #@param {type:\"integer\"}\n", + "sigma_g = 0.5 # 0.5 #@param {type:\"number\"}\n", + "minval = -5.12 #@param {type:\"number\"}\n", + "optim_problem = \"sphere\" #@param[\"rastrigin\", \"sphere\"]\n", + "#@markdown ---" + ] + }, + { + "cell_type": "markdown", + "id": "ccc7cbeb", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Define the fitness function - choose rastrigin or sphere" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "436dccbb", + "metadata": {}, + "outputs": [], + "source": [ + "def rastrigin_scoring(x: jnp.ndarray):\n", + " first_term = 10 * x.shape[-1]\n", + " second_term = jnp.sum((x + minval * 0.4) ** 2 - 10 * jnp.cos(2 * jnp.pi * (x + minval * 0.4)))\n", + " return -(first_term + second_term)\n", + "\n", + "def sphere_scoring(x: jnp.ndarray):\n", + " return -jnp.sum((x + minval * 0.4) * (x + minval * 0.4), axis=-1)\n", + "\n", + "if optim_problem == \"sphere\":\n", + " fitness_fn = sphere_scoring\n", + "elif optim_problem == \"rastrigin\":\n", + " fitness_fn = jax.vmap(rastrigin_scoring)\n", + "else:\n", + " raise Exception(\"Invalid opt function name given\")" + ] + }, + { + "cell_type": "markdown", + "id": "62bdd2a4", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Define a CMA-ES optimizer instance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4cf03f55", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "cmaes = CMAES(\n", + " population_size=batch_size,\n", + " num_best=num_best,\n", + " search_dim=num_dimensions,\n", + " fitness_function=fitness_fn,\n", + " mean_init=jnp.zeros((num_dimensions,)),\n", + " init_sigma=sigma_g,\n", + " delay_eigen_decomposition=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f1f69f50", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Init the CMA-ES optimizer state" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a95b74d", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "state = cmaes.init()\n", + "random_key = jax.random.PRNGKey(0)" + ] + }, + { + "cell_type": "markdown", + "id": "ac2d5c0d", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Run optimization iterations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "363198ca", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + 
"source": [ + "%%time\n", + "\n", + "means = [state.mean]\n", + "covs = [(state.sigma**2) * state.cov_matrix]\n", + "\n", + "iteration_count = 0\n", + "for _ in range(num_iterations):\n", + " iteration_count += 1\n", + " \n", + " # sample\n", + " samples, random_key = cmaes.sample(state, random_key)\n", + " \n", + " # udpate\n", + " state = cmaes.update(state, samples)\n", + " \n", + " # check stop condition\n", + " stop_condition = cmaes.stop_condition(state)\n", + "\n", + " if stop_condition:\n", + " break\n", + " \n", + " # store data for plotting\n", + " means.append(state.mean)\n", + " covs.append((state.sigma**2) * state.cov_matrix)\n", + " \n", + "print(\"Num iterations before stop condition: \", iteration_count)" + ] + }, + { + "cell_type": "markdown", + "id": "0e5820b8", + "metadata": {}, + "source": [ + "## Check final fitnesses and distribution mean" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e4a2c7b", + "metadata": {}, + "outputs": [], + "source": [ + "# checking final fitness values\n", + "fitnesses = fitness_fn(samples)\n", + "\n", + "print(\"Min fitness in the final population: \", jnp.min(fitnesses))\n", + "print(\"Mean fitness in the final population: \", jnp.mean(fitnesses))\n", + "print(\"Max fitness in the final population: \", jnp.max(fitnesses))\n", + "\n", + "# checking mean of the final distribution\n", + "print(\"Final mean of the distribution: \\n\", means[-1])\n", + "# print(\"Final covariance matrix of the distribution: \", covs[-1])" + ] + }, + { + "cell_type": "markdown", + "id": "f3bd2b0f", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Visualization of the optimization trajectory" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ad85551c", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(12, 6))\n", + "\n", + "# sample points to show fitness landscape\n", + "random_key, subkey = jax.random.split(random_key)\n", + "x = jax.random.uniform(subkey, minval=-4, maxval=8, shape=(100000, 2))\n", + "f_x = fitness_fn(x)\n", + "\n", + "# plot fitness landscape\n", + "points = ax.scatter(x[:, 0], x[:, 1], c=f_x, s=0.1)\n", + "fig.colorbar(points)\n", + "\n", + "# plot cma-es trajectory\n", + "traj_min = 0\n", + "traj_max = iteration_count\n", + "for mean, cov in zip(means[traj_min:traj_max], covs[traj_min:traj_max]):\n", + " ellipse = Ellipse((mean[0], mean[1]), cov[0, 0], cov[1, 1], fill=False, color='k', ls='--')\n", + " ax.add_patch(ellipse)\n", + " ax.plot(mean[0], mean[1], color='k', marker='x')\n", + " \n", + "ax.set_title(f\"Optimization trajectory of CMA-ES between step {traj_min} and step {traj_max}\")\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + }, + "vscode": { + "interpreter": { + "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/cmame.ipynb b/examples/cmame.ipynb new file mode 100644 index 00000000..c9d6f67e --- /dev/null +++ b/examples/cmame.ipynb @@ -0,0 +1,388 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + 
"source": [ + "###### [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/cmamega_example.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Optimizing with CMA-ME in Jax\n", + "\n", + "This notebook shows how to use QDax to find diverse and performing parameters on Rastrigin or Sphere problem with [CMA-ME](https://arxiv.org/pdf/1912.02400.pdf). It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "\n", + "- how to define the problem\n", + "- how to create a CMA-ME emitter\n", + "- how to create a MAP-Elites instance\n", + "- which functions must be defined before training\n", + "- how to launch a certain number of training steps\n", + "- how to visualise the optimization process" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "\n", + "import matplotlib as mpl\n", + "import matplotlib.cm as cm\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import jax \n", + "import jax.numpy as jnp\n", + "\n", + "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.4.1 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.3 |tail -n 1\n", + " import chex\n", + " \n", + "try:\n", + " import qdax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " import qdax\n", + "\n", + "from qdax.core.map_elites import MAPElites\n", + "from qdax.core.emitters.cma_opt_emitter import CMAOptimizingEmitter\n", + "from qdax.core.emitters.cma_rnd_emitter import CMARndEmitter\n", + "from qdax.core.emitters.cma_improvement_emitter import CMAImprovementEmitter\n", + "from qdax.core.emitters.cma_pool_emitter import CMAPoolEmitter\n", + "from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids, MapElitesRepertoire\n", + "\n", + "from typing import Dict" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set the hyperparameters\n", + "\n", + "Most hyperparameters are similar to those introduced in [Differentiable Quality Diversity paper](https://arxiv.org/pdf/2106.03894.pdf)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title QD Training Definitions Fields\n", + "#@markdown ---\n", + "num_iterations = 70000 #70000 #10000\n", + "num_dimensions = 100 #1000 #@param {type:\"integer\"}\n", + "grid_shape = (500, 500) # (500, 500) \n", + "batch_size = 36 #36 #@param {type:\"integer\"}\n", + "sigma_g = .5 #@param {type:\"number\"}\n", + "minval = -5.12 #@param {type:\"number\"}\n", + "maxval = 5.12 #@param {type:\"number\"}\n", + "min_bd = -5.12 * 0.5 * num_dimensions #@param {type:\"number\"}\n", + "max_bd = 5.12 * 0.5 * num_dimensions #@param {type:\"number\"}\n", + "emitter_type = \"imp\" #@param[\"opt\", \"imp\", \"rnd\"]\n", + "pool_size = 15 #@param {type:\"integer\"}\n", + "optim_problem = \"rastrigin\" #@param[\"rastrigin\", \"sphere\"]\n", + "#@markdown ---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Defines the scoring function: rastrigin or sphere" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def rastrigin_scoring(x: jnp.ndarray):\n", + " first_term = 10 * x.shape[-1]\n", + " second_term = jnp.sum((x + minval * 0.4) ** 2 - 10 * jnp.cos(2 * jnp.pi * (x + minval * 0.4)))\n", + " return -(first_term + second_term)\n", + "\n", + "def sphere_scoring(x: jnp.ndarray):\n", + " return -jnp.sum((x + minval * 0.4) * (x + minval * 0.4), axis=-1)\n", + "\n", + "if optim_problem == \"sphere\":\n", + " fitness_scoring = sphere_scoring\n", + "elif optim_problem == \"rastrigin\":\n", + " fitness_scoring = rastrigin_scoring\n", + "else:\n", + " raise Exception(\"Invalid opt function name given\")\n", + "\n", + "def clip(x: jnp.ndarray):\n", + " in_bound = (x <= maxval) * (x >= minval)\n", + " return jnp.where(\n", + " condition=in_bound,\n", + " x=x,\n", + " y=(maxval / x)\n", + " )\n", + "\n", + "def _behavior_descriptor_1(x: jnp.ndarray):\n", + " return jnp.sum(clip(x[:x.shape[-1]//2]))\n", + "\n", + "def _behavior_descriptor_2(x: jnp.ndarray):\n", + " return jnp.sum(clip(x[x.shape[-1]//2:]))\n", + "\n", + "def _behavior_descriptors(x: jnp.ndarray):\n", + " return jnp.array([_behavior_descriptor_1(x), _behavior_descriptor_2(x)])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def scoring_function(x):\n", + " scores, descriptors = fitness_scoring(x), _behavior_descriptors(x)\n", + " return scores, descriptors, {}\n", + "\n", + "def scoring_fn(x, random_key):\n", + " fitnesses, descriptors, extra_scores = jax.vmap(scoring_function)(x)\n", + " return fitnesses, descriptors, extra_scores, random_key" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the metrics that will be used" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "worst_objective = fitness_scoring(-jnp.ones(num_dimensions) * 5.12)\n", + "best_objective = fitness_scoring(jnp.ones(num_dimensions) * 5.12 * 0.4)\n", + "\n", + "num_centroids = math.prod(grid_shape)\n", + "\n", + "def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]:\n", + "\n", + " # get metrics\n", + " grid_empty = repertoire.fitnesses == -jnp.inf\n", + " adjusted_fitness = (\n", + " (repertoire.fitnesses - worst_objective) * 100 / (best_objective - worst_objective)\n", + " )\n", + " qd_score = jnp.sum(adjusted_fitness, where=~grid_empty) # / num_centroids\n", + " coverage = 100 * 
jnp.mean(1.0 - grid_empty)\n", + " max_fitness = jnp.max(adjusted_fitness)\n", + " return {\"qd_score\": qd_score, \"max_fitness\": max_fitness, \"coverage\": coverage}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define initial population, emitter and MAP Elites instance\n", + "\n", + "The emitter is defined using the CMAME emitter class. This emitter is given to a MAP-Elites instance to create an instance of the CMA-ME algorithm." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "random_key = jax.random.PRNGKey(0)\n", + "# in CMA-ME settings (from the paper), there is no init population\n", + "# we multipy by zero to reproduce this setting\n", + "initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n", + "\n", + "centroids = compute_euclidean_centroids(\n", + " grid_shape=grid_shape,\n", + " minval=min_bd,\n", + " maxval=max_bd,\n", + ")\n", + "\n", + "emitter_kwargs = {\n", + " \"batch_size\": batch_size,\n", + " \"genotype_dim\": num_dimensions,\n", + " \"centroids\": centroids,\n", + " \"sigma_g\": sigma_g,\n", + " \"min_count\": 1,\n", + " \"max_count\": None,\n", + "}\n", + "\n", + "if emitter_type == \"opt\":\n", + " emitter = CMAOptimizingEmitter(**emitter_kwargs)\n", + "elif emitter_type == \"imp\":\n", + " emitter = CMAImprovementEmitter(**emitter_kwargs)\n", + "elif emitter_type == \"rnd\":\n", + " emitter = CMARndEmitter(**emitter_kwargs)\n", + "else:\n", + " raise Exception(\"Invalid emitter type\")\n", + "\n", + "emitter = CMAPoolEmitter(\n", + " num_states=pool_size,\n", + " emitter=emitter\n", + ")\n", + "\n", + "map_elites = MAPElites(\n", + " scoring_function=scoring_fn,\n", + " emitter=emitter,\n", + " metrics_function=metrics_fn\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Init the repertoire and emitter state" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "repertoire, emitter_state, random_key = map_elites.init(initial_population, centroids, random_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run optimization/illumination process" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "\n", + "(repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + " map_elites.scan_update,\n", + " (repertoire, emitter_state, random_key),\n", + " (),\n", + " length=num_iterations,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for k, v in metrics.items():\n", + " print(f\"{k} after {num_iterations * batch_size}: {v[-1]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plot results\n", + "\n", + "Update the savefig variable to save your results locally." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env_steps = jnp.arange(num_iterations) * batch_size\n", + "\n", + "\n", + "# Customize matplotlib params\n", + "font_size = 16\n", + "params = {\n", + " \"axes.labelsize\": font_size,\n", + " \"axes.titlesize\": font_size,\n", + " \"legend.fontsize\": font_size,\n", + " \"xtick.labelsize\": font_size,\n", + " \"ytick.labelsize\": font_size,\n", + " \"text.usetex\": False,\n", + " \"axes.titlepad\": 10,\n", + "}\n", + "\n", + "mpl.rcParams.update(params)\n", + "\n", + "# Visualize the training evolution and final repertoire\n", + "fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(40, 10))\n", + "\n", + "# env_steps = jnp.arange(num_iterations) * episode_length * batch_size\n", + "\n", + "axes[0].plot(env_steps, metrics[\"coverage\"])\n", + "axes[0].set_xlabel(\"Environment steps\")\n", + "axes[0].set_ylabel(\"Coverage in %\")\n", + "axes[0].set_title(\"Coverage evolution during training\")\n", + "axes[0].set_aspect(0.95 / axes[0].get_data_ratio(), adjustable=\"box\")\n", + "\n", + "axes[1].plot(env_steps, metrics[\"max_fitness\"])\n", + "axes[1].set_xlabel(\"Environment steps\")\n", + "axes[1].set_ylabel(\"Maximum fitness\")\n", + "axes[1].set_title(\"Maximum fitness evolution during training\")\n", + "axes[1].set_aspect(0.95 / axes[1].get_data_ratio(), adjustable=\"box\")\n", + "\n", + "axes[2].plot(env_steps, metrics[\"qd_score\"])\n", + "axes[2].set_xlabel(\"Environment steps\")\n", + "axes[2].set_ylabel(\"QD Score\")\n", + "axes[2].set_title(\"QD Score evolution during training\")\n", + "axes[2].set_aspect(0.95 / axes[2].get_data_ratio(), adjustable=\"box\")\n", + "\n", + "# udpate this variable to save your results locally\n", + "savefig = False\n", + "if savefig:\n", + " figname = \"cma_me_\" + optim_problem + \"_\" + str(num_dimensions) + \"_\" + emitter_type + \".png\"\n", + " print(\"Save figure in: \", figname)\n", + " plt.savefig(figname)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + }, + "vscode": { + "interpreter": { + "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/notebooks/cmamega_example.ipynb b/examples/cmamega.ipynb similarity index 91% rename from examples/notebooks/cmamega_example.ipynb rename to examples/cmamega.ipynb index 73920dc9..509e52ea 100644 --- a/examples/notebooks/cmamega_example.ipynb +++ b/examples/cmamega.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/cmamega_example.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/cmamega.ipynb)" ] }, { @@ -75,14 +75,14 @@ "source": [ "#@title QD Training Definitions Fields\n", "#@markdown ---\n", - "num_iterations = 20000\n", + "num_iterations = 20000 #@param {type:\"integer\"}\n", "num_dimensions = 1000 #@param 
{type:\"integer\"}\n", "num_centroids = 10000 #@param {type:\"integer\"}\n", "minval = -5.12 #@param {type:\"number\"}\n", "maxval = 5.12 #@param {type:\"number\"}\n", "batch_size = 36 #@param {type:\"integer\"}\n", "learning_rate = 1 #@param {type:\"number\"}\n", - "sigma_g = 10 #@param {type:\"number\"}\n", + "sigma_g = 3.16 #@param {type:\"number\"} # square root of 10, the value given in the paper\n", "minval = -5.12 #@param {type:\"number\"}\n", "maxval = 5.12 #@param {type:\"number\"}\n", "#@markdown ---" @@ -110,10 +110,10 @@ " return x*(x<=maxval)*(x>=+minval) + maxval/x*((x>maxval)+(x<+minval))\n", "\n", "def _rastrigin_descriptor_1(x: jnp.ndarray):\n", - " return jnp.mean(clip(x[:x.shape[0]//2]))\n", + " return jnp.mean(clip(x[:x.shape[-1]//2]))\n", "\n", "def _rastrigin_descriptor_2(x: jnp.ndarray):\n", - " return jnp.mean(clip(x[x.shape[0]//2:]))\n", + " return jnp.mean(clip(x[x.shape[-1]//2:]))\n", "\n", "def rastrigin_descriptors(x: jnp.ndarray):\n", " return jnp.array([_rastrigin_descriptor_1(x), _rastrigin_descriptor_2(x)])\n", @@ -199,7 +199,8 @@ "outputs": [], "source": [ "random_key = jax.random.PRNGKey(0)\n", - "initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions))\n", + "# no initial population - give all the same value as emitter init value\n", + "initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n", "\n", "centroids, random_key = compute_cvt_centroids(\n", " num_descriptors=2, \n", @@ -215,6 +216,7 @@ " batch_size=batch_size,\n", " learning_rate=learning_rate,\n", " num_descriptors=2,\n", + " centroids=centroids,\n", " sigma_g=sigma_g,\n", ")\n", "\n", @@ -250,6 +252,23 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for k, v in metrics.items():\n", + " print(f\"{k} after {num_iterations * batch_size}: {v[-1]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualise results" + ] + }, { "cell_type": "code", "execution_count": null, @@ -270,7 +289,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.13 ('qdaxpy38')", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, diff --git a/examples/notebooks/dads_example.ipynb b/examples/dads.ipynb similarity index 99% rename from examples/notebooks/dads_example.ipynb rename to examples/dads.ipynb index a6fb3fbd..d4476c6d 100644 --- a/examples/notebooks/dads_example.ipynb +++ b/examples/dads.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/dads_example.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/dads.ipynb)" ] }, { diff --git a/examples/notebooks/diayn_example.ipynb b/examples/diayn.ipynb similarity index 99% rename from examples/notebooks/diayn_example.ipynb rename to examples/diayn.ipynb index d114ab3f..f7744e79 100644 --- a/examples/notebooks/diayn_example.ipynb +++ b/examples/diayn.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/diayn_example.ipynb)" + "[![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/diayn.ipynb)" ] }, { diff --git a/examples/notebooks/mapelites_example.ipynb b/examples/mapelites.ipynb similarity index 99% rename from examples/notebooks/mapelites_example.ipynb rename to examples/mapelites.ipynb index 1a324ea5..120e7b54 100644 --- a/examples/notebooks/mapelites_example.ipynb +++ b/examples/mapelites.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/mapelites_example.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb)" ] }, { diff --git a/examples/notebooks/mees_example.ipynb b/examples/mees.ipynb similarity index 100% rename from examples/notebooks/mees_example.ipynb rename to examples/mees.ipynb diff --git a/examples/notebooks/mome_example.ipynb b/examples/mome.ipynb similarity index 99% rename from examples/notebooks/mome_example.ipynb rename to examples/mome.ipynb index 86f14106..6a6f7d39 100644 --- a/examples/notebooks/mome_example.ipynb +++ b/examples/mome.ipynb @@ -5,7 +5,7 @@ "id": "59f748d3", "metadata": {}, "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/mome_example.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mome.ipynb)" ] }, { diff --git a/examples/notebooks/nsga2_spea2_example.ipynb b/examples/nsga2_spea2.ipynb similarity index 99% rename from examples/notebooks/nsga2_spea2_example.ipynb rename to examples/nsga2_spea2.ipynb index df903fd5..5cbe02a2 100644 --- a/examples/notebooks/nsga2_spea2_example.ipynb +++ b/examples/nsga2_spea2.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/nsga2_spea2_example.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/nsga2_spea2.ipynb)" ] }, { diff --git a/examples/notebooks/omgmega_example.ipynb b/examples/omgmega.ipynb similarity index 97% rename from examples/notebooks/omgmega_example.ipynb rename to examples/omgmega.ipynb index 6b10e0f5..d75a0077 100644 --- a/examples/notebooks/omgmega_example.ipynb +++ b/examples/omgmega.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/omgmega_example.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/omgmega.ipynb)" ] }, { @@ -188,7 +188,7 @@ "\n", "# defines the population\n", "random_key, subkey = jax.random.split(random_key)\n", - "initial_population = 
jax.random.uniform(subkey, shape=(init_population_size, num_dimensions))\n", + "initial_population = jax.random.normal(subkey, shape=(init_population_size, num_dimensions))\n", "\n", "sqrt_centroids = int(math.sqrt(num_centroids)) # 2-D grid \n", "grid_shape = (sqrt_centroids, sqrt_centroids)\n", @@ -280,7 +280,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.13 ('qdaxpy38')", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, diff --git a/examples/notebooks/pgame_example.ipynb b/examples/pgame.ipynb similarity index 99% rename from examples/notebooks/pgame_example.ipynb rename to examples/pgame.ipynb index 7d790d3e..bc295cca 100644 --- a/examples/notebooks/pgame_example.ipynb +++ b/examples/pgame.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/pgame_example.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/pgame.ipynb)" ] }, { diff --git a/examples/notebooks/smerl_example.ipynb b/examples/smerl.ipynb similarity index 99% rename from examples/notebooks/smerl_example.ipynb rename to examples/smerl.ipynb index 82c71fbb..46d92eea 100644 --- a/examples/notebooks/smerl_example.ipynb +++ b/examples/smerl.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/smerl_example.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/smerl.ipynb)" ] }, { diff --git a/mkdocs.yml b/mkdocs.yml index a368a160..132bb1ff 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -112,20 +112,24 @@ nav: - Guides: - Contributing: guides/CONTRIBUTING.md - Examples: - - MAPElites: notebooks/mapelites_example.ipynb - - PGAME: notebooks/pgame_example.ipynb - - OMG MEGA: notebooks/omgmega_example.ipynb - - CMA MEGA: notebooks/cmamega_example.ipynb - - MOME: notebooks/mome_example.ipynb - - DIAYN: notebooks/diayn_example.ipynb - - DADS: notebooks/dads_example.ipynb - - SMERL: notebooks/smerl_example.ipynb - - MEES: notebooks/mees_example.ipynb + - MAPElites: examples/mapelites.ipynb + - PGAME: examples/pgame.ipynb + - CMA ME: examples/cmame.ipynb + - OMG MEGA: examples/omgmega.ipynb + - CMA MEGA: examples/cmamega.ipynb + - MOME: examples/mome.ipynb + - MEES: examples/mees.ipynb + - DIAYN: examples/diayn.ipynb + - DADS: examples/dads.ipynb + - SMERL: examples/smerl.ipynb + - CMA ES: examples/cmaes.ipynb + - NSGA2/SPEA2: examples/nsga2_spea2.ipynb - API documentation: - Core: - Core algorithms: - MAP Elites: api_documentation/core/map_elites.md - PGAME: api_documentation/core/pgame.md + - CMA ME: api_documentation/core/cmame.md - OMG MEGA: api_documentation/core/omg_mega.md - CMA MEGA: api_documentation/core/cma_mega.md - MOME: api_documentation/core/mome.md diff --git a/qdax/core/cmaes.py b/qdax/core/cmaes.py index fd2d8eb3..481a49bf 100644 --- a/qdax/core/cmaes.py +++ b/qdax/core/cmaes.py @@ -2,28 +2,44 @@ Definition of CMAES class, containing main functions necessary to build a CMA optimization script. 
Link to the paper: https://arxiv.org/abs/1604.00772 """ -import functools +from functools import partial from typing import Callable, Optional, Tuple import flax import jax import jax.numpy as jnp -from qdax.types import Fitness, Genotype, RNGKey +from qdax.types import Fitness, Genotype, Mask, RNGKey class CMAESState(flax.struct.PyTreeNode): - """ - Describe a state of the Covariance matrix adaptation evolution strategy + """Describe a state of the Covariance Matrix Adaptation Evolution Strategy (CMA-ES) algorithm. + + Args: + mean: mean of the gaussian distribution used to generate solutions + cov_matrix: covariance matrix of the gaussian distribution used to + generate solutions - (multiplied by sigma for sampling). + num_updates: number of updates made by the CMAES optimizer since the + beginning of the process. + sigma: the step size of the optimization steps. Multiplies the cov matrix + to get the real cov matrix used for the sampling process. + p_c: evolution path + p_s: evolution path + eigen_updates: track the latest update to know when to do the next one. + eigenvalues: latest eigenvalues + invsqrt_cov: latest inv sqrt value of the cov matrix. """ mean: jnp.ndarray cov_matrix: jnp.ndarray num_updates: int - step_size: float + sigma: float p_c: jnp.ndarray p_s: jnp.ndarray + eigen_updates: int + eigenvalues: jnp.ndarray + invsqrt_cov: jnp.ndarray class CMAES: @@ -37,18 +53,32 @@ def __init__( search_dim: int, fitness_function: Callable[[Genotype], Fitness], num_best: Optional[int] = None, - weight_decay: float = 0.01, init_sigma: float = 1e-3, mean_init: Optional[jnp.ndarray] = None, bias_weights: bool = True, - init_step_size: float = 1e-3, + delay_eigen_decomposition: bool = False, ): + """Instantiate a CMA-ES optimizer. + + Args: + population_size: size of the running population. + search_dim: number of dimensions in the search space. + fitness_function: fitness function that is being optimized. + num_best: number of best individuals in the population being considered + for the update of the distributions. Defaults to None. + init_sigma: Initial value of the step size. Defaults to 1e-3. + mean_init: Initial value of the distribution mean. Defaults to None. + bias_weights: Should the weights be biased towards best individuals. + Defaults to True. + delay_eigen_decomposition: should the update of the inverse of the + cov matrix be delayed. As this operation is a time bottleneck, having + it delayed improves the time perfs by a significant margin. + Defaults to False. 
+ """ self._population_size = population_size - self._weight_decay = weight_decay self._search_dim = search_dim self._fitness_function = fitness_function self._init_sigma = init_sigma - self._init_step_size = init_step_size # Default values if values are not provided if num_best is None: @@ -63,8 +93,9 @@ def __init__( # weights parameters if bias_weights: + # heuristic from Nicolas Hansen original implementation self._weights = jnp.log( - (self._num_best + 1) / jnp.arange(start=1, stop=(self._num_best + 1)) + (self._num_best + 0.5) / jnp.arange(start=1, stop=(self._num_best + 1)) ) else: self._weights = jnp.ones(self._num_best) @@ -78,20 +109,35 @@ def __init__( self._c_c = (4 + self._parents_eff / self._search_dim) / ( self._search_dim + 4 + 2 * self._parents_eff / self._search_dim ) - tmp = self._parents_eff - 2 + 1 / self._parents_eff + + # learning rate for rank-1 update of C self._c_1 = 2 / (self._parents_eff + (self._search_dim + jnp.sqrt(2)) ** 2) + + # learning rate for rank-(num best) updates + tmp = 2 * (self._parents_eff - 2 + 1 / self._parents_eff) self._c_cov = min( 1 - self._c_1, tmp / (self._parents_eff + (self._search_dim + 2) ** 2) ) + + # damping for sigma self._d_s = ( 1 - + 2 * max(0, jnp.sqrt((self._parents_eff - 1) / (self._search_dim + 1) - 1)) + + 2 * max(0, jnp.sqrt((self._parents_eff - 1) / (self._search_dim + 1)) - 1) + self._c_s ) self._chi = jnp.sqrt(self._search_dim) * ( 1 - 1 / (4 * self._search_dim) + 1 / (21 * self._search_dim**2) ) + # threshold for new eigen decomposition - from pyribs + self._eigen_comput_period = 1 + if delay_eigen_decomposition: + self._eigen_comput_period = ( + 0.5 + * self._population_size + / (self._search_dim * (self._c_1 + self._c_cov)) + ) + def init(self) -> CMAESState: """ Init the CMA-ES algorithm. @@ -99,16 +145,26 @@ def init(self) -> CMAESState: Returns: an initial state for the algorithm """ + + # initial cov matrix + cov_matrix = jnp.eye(self._search_dim) + + # initial inv sqrt of the cov matrix - cov is already diag + invsqrt_cov = jnp.diag(1 / jnp.sqrt(jnp.diag(cov_matrix))) + return CMAESState( mean=self._mean_init, - cov_matrix=self._init_sigma * jnp.eye(self._search_dim), - step_size=self._init_step_size, - num_updates=1, + cov_matrix=cov_matrix, + sigma=self._init_sigma, + num_updates=0, p_c=jnp.zeros(shape=(self._search_dim,)), p_s=jnp.zeros(shape=(self._search_dim,)), + eigen_updates=0, + eigenvalues=jnp.ones(shape=(self._search_dim,)), + invsqrt_cov=invsqrt_cov, ) - @functools.partial(jax.jit, static_argnames=("self",)) + @partial(jax.jit, static_argnames=("self",)) def sample( self, cmaes_state: CMAESState, random_key: RNGKey ) -> Tuple[Genotype, RNGKey]: @@ -128,58 +184,129 @@ def sample( subkey, shape=(self._population_size,), mean=cmaes_state.mean, - cov=cmaes_state.cov_matrix, + cov=(cmaes_state.sigma**2) * cmaes_state.cov_matrix, ) return samples, random_key - @functools.partial(jax.jit, static_argnames=("self",)) + @partial(jax.jit, static_argnames=("self",)) def update_state( - self, cmaes_state: CMAESState, sorted_candidates: Genotype + self, + cmaes_state: CMAESState, + sorted_candidates: Genotype, ) -> CMAESState: + return self._update_state( # type: ignore + cmaes_state=cmaes_state, + sorted_candidates=sorted_candidates, + weights=self._weights, + ) + + @partial(jax.jit, static_argnames=("self",)) + def update_state_with_mask( + self, cmaes_state: CMAESState, sorted_candidates: Genotype, mask: Mask + ) -> CMAESState: + """Update weights with a mask, then update the state. 
+ Convention: 1 stays, 0 a removed. """ - Updates the state when candidates have already been sorted and selected. + + # update weights by multiplying by a mask + weights = jnp.multiply(self._weights, mask) + weights = weights / (weights.sum()) + + return self._update_state( # type: ignore + cmaes_state=cmaes_state, + sorted_candidates=sorted_candidates, + weights=weights, + ) + + @partial(jax.jit, static_argnames=("self",)) + def _update_state( + self, + cmaes_state: CMAESState, + sorted_candidates: Genotype, + weights: jnp.ndarray, + ) -> CMAESState: + """Updates the state when candidates have already been + sorted and selected. Args: cmaes_state: current state of the algorithm sorted_candidates: a batch of sorted and selected genotypes + weights: weights used to recombine the candidates Returns: An updated algorithm state """ + # retrieve elements from the current state p_c = cmaes_state.p_c p_s = cmaes_state.p_s - step_size = cmaes_state.step_size + sigma = cmaes_state.sigma num_updates = cmaes_state.num_updates cov = cmaes_state.cov_matrix mean = cmaes_state.mean - # update mean + eigen_updates = cmaes_state.eigen_updates + eigenvalues = cmaes_state.eigenvalues + invsqrt_cov = cmaes_state.invsqrt_cov + + # update mean by recombination old_mean = mean - mean = self._weights @ sorted_candidates - z = 1 / step_size * (mean - old_mean).T - eig, u = jnp.linalg.eigh(cov) - invsqrt = u @ jnp.diag(1 / jnp.sqrt(eig)) @ u.T + mean = weights @ sorted_candidates + + def update_eigen( + operand: Tuple[jnp.ndarray, int] + ) -> Tuple[int, jnp.ndarray, jnp.ndarray]: + + # unpack data + cov, num_updates = operand + + # enfore symmetry - did not change anything + cov = jnp.triu(cov) + jnp.triu(cov, 1).T + + # get eigen decomposition: eigenvalues, eigenvectors + eig, u = jnp.linalg.eigh(cov) + + # compute new invsqrt + invsqrt = u @ jnp.diag(1 / jnp.sqrt(eig)) @ u.T + + # update the eigen value decomposition tracker + eigen_updates = num_updates + + return eigen_updates, eig, invsqrt + + # condition for recomputing the eig decomposition + eigen_condition = (num_updates - eigen_updates) >= self._eigen_comput_period + + # decomposition of cov + eigen_updates, eigenvalues, invsqrt = jax.lax.cond( + eigen_condition, + update_eigen, + lambda _: (eigen_updates, eigenvalues, invsqrt_cov), + operand=(cov, num_updates), + ) + + z = (1 / sigma) * (mean - old_mean) z_w = invsqrt @ z - # update evolution paths + # update evolution paths - cumulation p_s = (1 - self._c_s) * p_s + jnp.sqrt( self._c_s * (2 - self._c_s) * self._parents_eff - ) * z_w.squeeze() + ) * z_w tmp_1 = jnp.linalg.norm(p_s) / jnp.sqrt( 1 - (1 - self._c_s) ** (2 * num_updates) ) <= self._chi * (1.4 + 2 / (self._search_dim + 1)) - p_c = (1 - self._c_c) * p_c + 1 * jnp.sqrt( + p_c = (1 - self._c_c) * p_c + tmp_1 * jnp.sqrt( self._c_c * (2 - self._c_c) * self._parents_eff - ) * z.squeeze() + ) * z # update covariance matrix pp_c = jnp.expand_dims(p_c, axis=1) - coeff_tmp = 1 / step_size * (sorted_candidates - mean) - cov_rank = coeff_tmp.T @ jnp.diag(self._weights.squeeze()) @ coeff_tmp + + coeff_tmp = (sorted_candidates - old_mean) / sigma + cov_rank = coeff_tmp.T @ jnp.diag(weights.squeeze()) @ coeff_tmp cov = ( (1 - self._c_cov - self._c_1) * cov @@ -189,24 +316,27 @@ def update_state( ) # update step size - step_size = step_size * jnp.exp( + sigma = sigma * jnp.exp( (self._c_s / self._d_s) * (jnp.linalg.norm(p_s) / self._chi - 1) ) cmaes_state = CMAESState( mean=mean, cov_matrix=cov, - step_size=step_size, + sigma=sigma, num_updates=num_updates 
+ 1, p_c=p_c, p_s=p_s, + eigen_updates=eigen_updates, + eigenvalues=eigenvalues, + invsqrt_cov=invsqrt, ) + return cmaes_state - @functools.partial(jax.jit, static_argnames=("self",)) + @partial(jax.jit, static_argnames=("self",)) def update(self, cmaes_state: CMAESState, samples: Genotype) -> CMAESState: - """ - Updates the distribution. + """Updates the distribution. Args: cmaes_state: current state of the algorithm @@ -223,3 +353,41 @@ def update(self, cmaes_state: CMAESState, samples: Genotype) -> CMAESState: new_state = self.update_state(cmaes_state, sorted_candidates) return new_state # type: ignore + + @partial(jax.jit, static_argnames=("self",)) + def stop_condition(self, cmaes_state: CMAESState) -> bool: + """Determines if the current optimization path must be stopped. + + A set of 5 conditions is computed; meeting any one of them is enough to + stop the process. This function does not stop the process but simply + retrieves the value. It is not called in the update function but can be + used to manually stop the process (see example in CMA ME emitter). + + Args: + cmaes_state: current CMAES state + + Returns: + A boolean stating if the process should be stopped. + """ + + # NaN appears when the float precision limit is reached + nan_condition = jnp.sum(jnp.isnan(cmaes_state.eigenvalues)) > 0 + + eig_dispersion = jnp.max(cmaes_state.eigenvalues) / jnp.min( + cmaes_state.eigenvalues + ) + first_condition = eig_dispersion > 1e14 + + area = cmaes_state.sigma * jnp.sqrt(jnp.max(cmaes_state.eigenvalues)) + second_condition = area < 1e-11 + + third_condition = jnp.max(cmaes_state.eigenvalues) < 1e-7 + fourth_condition = jnp.min(cmaes_state.eigenvalues) > 1e7 + + return ( # type: ignore + nan_condition + + first_condition + + second_condition + + third_condition + + fourth_condition + ) diff --git a/qdax/core/emitters/cma_emitter.py b/qdax/core/emitters/cma_emitter.py new file mode 100644 index 00000000..f9d58caa --- /dev/null +++ b/qdax/core/emitters/cma_emitter.py @@ -0,0 +1,381 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from functools import partial +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp + +from qdax.core.cmaes import CMAES, CMAESState +from qdax.core.containers.mapelites_repertoire import ( + MapElitesRepertoire, + get_cells_indices, +) +from qdax.core.emitters.emitter import Emitter, EmitterState +from qdax.types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, RNGKey + + +class CMAEmitterState(EmitterState): + """ + Emitter state for the CMA-ME emitter. + + Args: + random_key: a random key to handle stochastic operations. Used for + state update only, another key is used to emit. This might be + subject to refactoring discussions in the future. + cmaes_state: state of the underlying CMA-ES algorithm + previous_fitnesses: store last fitnesses of the repertoire. Used to + compute the improvement. + emit_count: count the number of emission events. + """ + + random_key: RNGKey + cmaes_state: CMAESState + previous_fitnesses: Fitness + emit_count: int + + +class CMAEmitter(Emitter, ABC): + def __init__( + self, + batch_size: int, + genotype_dim: int, + centroids: Centroid, + sigma_g: float, + min_count: Optional[int] = None, + max_count: Optional[float] = None, + ): + """ + Class for the emitter of CMA ME from "Covariance Matrix Adaptation for the + Rapid Illumination of Behavior Space" by Fontaine et al. 
+ + Args: + batch_size: number of solutions sampled at each iteration. + genotype_dim: dimension of the genotype space. + centroids: centroids used for the repertoire. + sigma_g: standard deviation for the coefficients - called step size. + min_count: minimum number of CMAES opt steps before being considered for + reinitialisation. + max_count: maximum number of CMAES opt steps authorized. + """ + self._batch_size = batch_size + + # define a CMAES instance + self._cmaes = CMAES( + population_size=batch_size, + search_dim=genotype_dim, + # no need for fitness function in that specific case + fitness_function=None, # type: ignore + num_best=batch_size, + init_sigma=sigma_g, + mean_init=None, # will be init at zeros in cmaes + bias_weights=True, + delay_eigen_decomposition=True, + ) + + # minimum number of emitted solutions before an emitter can be re-initialized + if min_count is None: + min_count = 0 + + self._min_count = min_count + + if max_count is None: + max_count = jnp.inf + + self._max_count = max_count + + self._centroids = centroids + + self._cma_initial_state = self._cmaes.init() + + @property + def batch_size(self) -> int: + """ + Returns: + the batch size emitted by the emitter. + """ + return self._batch_size + + @partial(jax.jit, static_argnames=("self",)) + def init( + self, init_genotypes: Genotype, random_key: RNGKey + ) -> Tuple[CMAEmitterState, RNGKey]: + """ + Initializes the CMA-ME emitter + + + Args: + init_genotypes: initial genotypes to add to the grid. + random_key: a random key to handle stochastic operations. + + Returns: + The initial state of the emitter. + """ + + # Initialize repertoire with default values + num_centroids = self._centroids.shape[0] + default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) + + # return the initial state + random_key, subkey = jax.random.split(random_key) + return ( + CMAEmitterState( + random_key=subkey, + cmaes_state=self._cma_initial_state, + previous_fitnesses=default_fitnesses, + emit_count=0, + ), + random_key, + ) + + @partial(jax.jit, static_argnames=("self",)) + def emit( + self, + repertoire: Optional[MapElitesRepertoire], + emitter_state: CMAEmitterState, + random_key: RNGKey, + ) -> Tuple[Genotype, RNGKey]: + """ + Emits new individuals. Interestingly, this method does not directly modify + individuals from the repertoire but samples from a distribution. Hence the + repertoire is not used in the emit function. + + Args: + repertoire: a repertoire of genotypes (unused). + emitter_state: the state of the CMA-ME emitter. + random_key: a random key to handle random operations. + + Returns: + New genotypes and a new random key. + """ + # emit from CMA-ES + offsprings, random_key = self._cmaes.sample( + cmaes_state=emitter_state.cmaes_state, random_key=random_key + ) + + return offsprings, random_key + + @partial( + jax.jit, + static_argnames=("self",), + ) + def state_update( + self, + emitter_state: CMAEmitterState, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: Optional[ExtraScores] = None, + ) -> Optional[EmitterState]: + """ + Updates the CMA-ME emitter state. + + Note: we use the update_state function from CMAES, a function that assumes + that the candidates are already sorted. We do this because we have to sort + them in this function anyway, in order to apply the right weights to the + terms when updating theta. 
+ + Args: + emitter_state: current emitter state + repertoire: the current genotypes repertoire + genotypes: the genotypes of the batch of emitted offspring (unused). + fitnesses: the fitnesses of the batch of emitted offspring. + descriptors: the descriptors of the emitted offspring. + extra_scores: unused + + Returns: + The updated emitter state. + """ + + # retrieve elements from the emitter state + cmaes_state = emitter_state.cmaes_state + + # Compute the improvements - needed for re-init condition + indices = get_cells_indices(descriptors, repertoire.centroids) + improvements = fitnesses - emitter_state.previous_fitnesses[indices] + + ranking_criteria = self._ranking_criteria( + emitter_state=emitter_state, + repertoire=repertoire, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + extra_scores=extra_scores, + improvements=improvements, + ) + + # get the indices + sorted_indices = jnp.flip(jnp.argsort(ranking_criteria)) + + # sort the candidates + sorted_candidates = jax.tree_util.tree_map( + lambda x: x[sorted_indices], genotypes + ) + sorted_improvements = improvements[sorted_indices] + + # compute reinitialize condition + emit_count = emitter_state.emit_count + 1 + + # check if the criteria are too similar + sorted_criteria = ranking_criteria[sorted_indices] + flat_criteria_condition = ( + jnp.linalg.norm(sorted_criteria[0] - sorted_criteria[-1]) < 1e-12 + ) + + # check all conditions + reinitialize = ( + jnp.all(improvements < 0) * (emit_count > self._min_count) + + (emit_count > self._max_count) + + self._cmaes.stop_condition(cmaes_state) + + flat_criteria_condition + ) + + # If true, draw randomly and re-initialize parameters + def update_and_reinit( + operand: Tuple[ + CMAESState, CMAEmitterState, MapElitesRepertoire, int, RNGKey + ], + ) -> Tuple[CMAEmitterState, RNGKey]: + return self._update_and_init_emitter_state(*operand) + + def update_wo_reinit( + operand: Tuple[ + CMAESState, CMAEmitterState, MapElitesRepertoire, int, RNGKey + ], + ) -> Tuple[CMAEmitterState, RNGKey]: + """Update the emitter when no reinit event happened. + + Here lies a divergence compared to the original implementation. We + are getting better results when using no mask and doing the update + with the whole batch of individuals rather than keeping only the one + than were added to the archive. + + Interestingly, keeping the best half was not doing better. We think that + this might be due to the small batch size used. + + This applies for the setting from the paper CMA-ME. Those facts might + not be true with other problems and hyperparameters. + + To replicate the code described in the paper, replace: + `mask = jnp.ones_like(sorted_improvements)` + + by: + ``` + mask = sorted_improvements >= 0 + mask = mask + 1e-6 + ``` + + RMQ: the addition of 1e-6 is here to fix a numerical + instability. 
+ """ + + (cmaes_state, emitter_state, repertoire, emit_count, random_key) = operand + + # Update CMA Parameters + mask = jnp.ones_like(sorted_improvements) + + cmaes_state = self._cmaes.update_state_with_mask( + cmaes_state, sorted_candidates, mask=mask + ) + + emitter_state = emitter_state.replace( + cmaes_state=cmaes_state, + emit_count=emit_count, + ) + + return emitter_state, random_key + + # Update CMA Parameters + emitter_state, random_key = jax.lax.cond( + reinitialize, + update_and_reinit, + update_wo_reinit, + operand=( + cmaes_state, + emitter_state, + repertoire, + emit_count, + emitter_state.random_key, + ), + ) + + # update the emitter state + emitter_state = emitter_state.replace( + random_key=random_key, previous_fitnesses=repertoire.fitnesses + ) + + return emitter_state + + def _update_and_init_emitter_state( + self, + cmaes_state: CMAESState, + emitter_state: CMAEmitterState, + repertoire: MapElitesRepertoire, + emit_count: int, + random_key: RNGKey, + ) -> Tuple[CMAEmitterState, RNGKey]: + """Update the emitter state in the case of a reinit event. + Reinit the cmaes state and use an individual from the repertoire + as the starting mean. + + Args: + cmaes_state: current cmaes state + emitter_state: current cmame state + repertoire: most recent repertoire + emit_count: counter of the emitter + random_key: key to handle stochastic events + + Returns: + The updated emitter state. + """ + + # re-sample + random_genotype, random_key = repertoire.sample(random_key, 1) + + # remove the batch dim + new_mean = jax.tree_util.tree_map(lambda x: x.squeeze(0), random_genotype) + + cmaes_init_state = self._cma_initial_state.replace(mean=new_mean, num_updates=0) + + emitter_state = emitter_state.replace( + cmaes_state=cmaes_init_state, emit_count=0 + ) + + return emitter_state, random_key + + @abstractmethod + def _ranking_criteria( + self, + emitter_state: CMAEmitterState, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: Optional[ExtraScores], + improvements: jnp.ndarray, + ) -> jnp.ndarray: + """Defines how the genotypes should be sorted. Impacts the update + of the CMAES state. In the end, this defines the type of CMAES emitter + used (optimizing, random direction or improvement). + + Args: + emitter_state: current state of the emitter. + repertoire: latest repertoire of genotypes. + genotypes: emitted genotypes. + fitnesses: corresponding fitnesses. + descriptors: corresponding fitnesses. + extra_scores: corresponding extra scores. + improvements: improvments of the emitted genotypes. This corresponds + to the difference between their fitness and the fitness of the + individual occupying the cell of corresponding fitness. + + Returns: + The values to take into account in order to rank the emitted genotypes. + Here, it's the improvement, or the fitness when the cell was previously + unoccupied. Additionally, genotypes that discovered a new cell are + given on offset to be ranked in front of other genotypes. 
+ """ + + pass diff --git a/qdax/core/emitters/cma_improvement_emitter.py b/qdax/core/emitters/cma_improvement_emitter.py new file mode 100644 index 00000000..28424f3f --- /dev/null +++ b/qdax/core/emitters/cma_improvement_emitter.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from typing import Optional + +import jax.numpy as jnp + +from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire +from qdax.core.emitters.cma_emitter import CMAEmitter, CMAEmitterState +from qdax.types import Descriptor, ExtraScores, Fitness, Genotype + + +class CMAImprovementEmitter(CMAEmitter): + """Class for the emitter of CMA ME from "Covariance Matrix Adaptation + for the Rapid Illumination of Behavior Space" by Fontaine et al. + + This class implements the improvement emitter, where the update of the + distribution is biased towards solution that improve the QD score. + + Args: + batch_size: number of solutions sampled at each iteration + genotype_dim: dimension of the genotype space. + centroids: centroids used for the repertoire. + sigma_g: standard deviation for the coefficients - called step size. + min_count: minimum number of CMAES opt step before being considered for + reinitialisation. + max_count: maximum number of CMAES opt step authorized. + """ + + def _ranking_criteria( + self, + emitter_state: CMAEmitterState, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: Optional[ExtraScores], + improvements: jnp.ndarray, + ) -> jnp.ndarray: + """Defines how the genotypes should be sorted. Impacts the update + of the CMAES state. In the end, this defines the type of CMAES emitter + used (optimizing, random direction or improvement). + + Args: + emitter_state: current state of the emitter. + repertoire: latest repertoire of genotypes. + genotypes: emitted genotypes. + fitnesses: corresponding fitnesses. + descriptors: corresponding fitnesses. + extra_scores: corresponding extra scores. + improvements: improvments of the emitted genotypes. This corresponds + to the difference between their fitness and the fitness of the + individual occupying the cell of corresponding fitness. + + Returns: + The values to take into account in order to rank the emitted genotypes. + Here, it's the improvement, or the fitness when the cell was previously + unoccupied. Additionally, genotypes that discovered a new cell are + given on offset to be ranked in front of other genotypes. 
+ """ + + # condition for being a new cell + condition = improvements == jnp.inf + + # criteria: fitness if new cell, improvement else + ranking_criteria = jnp.where(condition, x=fitnesses, y=improvements) + + # make sure to have all the new cells first + new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria) + + ranking_criteria = jnp.where( + condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria + ) + + return ranking_criteria # type: ignore diff --git a/qdax/core/emitters/cma_mega_emitter.py b/qdax/core/emitters/cma_mega_emitter.py index cce2a8a3..f63654fd 100644 --- a/qdax/core/emitters/cma_mega_emitter.py +++ b/qdax/core/emitters/cma_mega_emitter.py @@ -12,7 +12,15 @@ get_cells_indices, ) from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, Gradient, RNGKey +from qdax.types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + Gradient, + RNGKey, +) class CMAMEGAState(EmitterState): @@ -26,12 +34,15 @@ class CMAMEGAState(EmitterState): state update only, another key is used to emit. This might be subject to refactoring discussions in the future. cmaes_state: state of the underlying CMA-ES algorithm + previous_fitnesses: store last fitnesses of the repertoire. Used to + compute the improvment. """ theta: Genotype theta_grads: Gradient random_key: RNGKey cmaes_state: CMAESState + previous_fitnesses: Fitness class CMAMEGAEmitter(Emitter): @@ -43,8 +54,8 @@ def __init__( batch_size: int, learning_rate: float, num_descriptors: int, + centroids: Centroid, sigma_g: float, - step_size: Optional[float] = None, ): """ Class for the emitter of CMA Mega from "Differentiable Quality Diversity" by @@ -57,21 +68,21 @@ def __init__( batch_size: number of solutions sampled at each iteration learning_rate: rate at which the mean of the distribution is updated. num_descriptors: number of descriptors + centroids: centroids of the repertoire used to store the genotypes sigma_g: standard deviation for the coefficients - step_size: size of the steps used in CMAES updates """ + self._scoring_function = scoring_function self._batch_size = batch_size self._learning_rate = learning_rate + + # weights used to update the gradient direction through a linear combination self._weights = jnp.expand_dims( jnp.log(batch_size + 0.5) - jnp.log(jnp.arange(1, batch_size + 1)), axis=-1 ) self._weights = self._weights / (self._weights.sum()) - if step_size is None: - step_size = 1.0 - - # define a CMAES instance + # define a CMAES instance - used to update the coeffs self._cmaes = CMAES( population_size=batch_size, search_dim=num_descriptors + 1, @@ -79,10 +90,12 @@ def __init__( fitness_function=None, # type: ignore num_best=batch_size, init_sigma=sigma_g, - init_step_size=step_size, bias_weights=True, + delay_eigen_decomposition=True, ) + self._centroids = centroids + self._cma_initial_state = self._cmaes.init() @partial(jax.jit, static_argnames=("self",)) @@ -90,7 +103,7 @@ def init( self, init_genotypes: Genotype, random_key: RNGKey ) -> Tuple[CMAMEGAState, RNGKey]: """ - Initializes the CMA-MEGA emitter + Initializes the CMA-MEGA emitter. 
Args: @@ -111,19 +124,24 @@ def init( _, _, extra_score, random_key = self._scoring_function(theta, random_key) theta_grads = extra_score["normalized_grads"] + # Initialize repertoire with default values + num_centroids = self._centroids.shape[0] + default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) + # return the initial state random_key, subkey = jax.random.split(random_key) return ( CMAMEGAState( theta=theta, theta_grads=theta_grads, - cmaes_state=self._cma_initial_state, random_key=subkey, + cmaes_state=self._cma_initial_state, + previous_fitnesses=default_fitnesses, ), random_key, ) - @partial(jax.jit, static_argnames=("self", "batch_size")) + @partial(jax.jit, static_argnames=("self",)) def emit( self, repertoire: Optional[MapElitesRepertoire], @@ -131,7 +149,7 @@ def emit( random_key: RNGKey, ) -> Tuple[Genotype, RNGKey]: """ - Emits new individuals. Interestingly, this method does not directly modify + Emits new individuals. Interestingly, this method does not directly modifies individuals from the repertoire but sample from a distribution. Hence the repertoire is not used in the emit function. @@ -152,7 +170,7 @@ def emit( grads = jnp.nan_to_num(emitter_state.theta_grads.squeeze(axis=0)) # Draw random coefficients - use the emitter state key - coeffs, _ = self._cmaes.sample( + coeffs, random_key = self._cmaes.sample( cmaes_state=cmaes_state, random_key=emitter_state.random_key ) @@ -208,9 +226,23 @@ def state_update( # Update the archive and compute the improvements indices = get_cells_indices(descriptors, repertoire.centroids) - improvements = fitnesses - repertoire.fitnesses[indices] + improvements = fitnesses - emitter_state.previous_fitnesses[indices] + + # condition for being a new cell + condition = improvements == jnp.inf - sorted_indices = jnp.argsort(improvements)[::-1] + # criteria: fitness if new cell, improvement else + ranking_criteria = jnp.where(condition, x=fitnesses, y=improvements) + + # make sure to have all the new cells first + new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria) + + ranking_criteria = jnp.where( + condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria + ) + + # sort indices according to the criteria + sorted_indices = jnp.flip(jnp.argsort(ranking_criteria)) # Draw the coeffs - reuse the emitter state key to get same coeffs coeffs, random_key = self._cmaes.sample( @@ -235,38 +267,23 @@ def state_update( cmaes_state = self._cmaes.update_state(cmaes_state, sorted_candidates) # If no improvement draw randomly and re-initialize parameters - reinitialize = jnp.all(improvements < 0) + reinitialize = jnp.all(improvements < 0) + self._cmaes.stop_condition( + cmaes_state + ) # re-sample random_theta, random_key = repertoire.sample(random_key, 1) - # update - theta = jnp.nan_to_num(theta) * (1 - reinitialize) + random_theta * reinitialize - mean = self._cma_initial_state.mean * reinitialize + jnp.nan_to_num( - cmaes_state.mean - ) * (1 - reinitialize) - cov = self._cma_initial_state.cov_matrix * reinitialize + jnp.nan_to_num( - cmaes_state.cov_matrix - ) * (1 - reinitialize) - p_c = self._cma_initial_state.p_c * reinitialize + jnp.nan_to_num( - cmaes_state.p_c - ) * (1 - reinitialize) - p_s = self._cma_initial_state.p_s * reinitialize + jnp.nan_to_num( - cmaes_state.p_s - ) * (1 - reinitialize) - step_size = self._cma_initial_state.step_size * reinitialize + jnp.nan_to_num( - cmaes_state.step_size - ) * (1 - reinitialize) - num_updates = 1 * reinitialize + cmaes_state.num_updates * (1 - reinitialize) - - # define 
new cmaes state - cmaes_state = CMAESState( - mean=mean, - cov_matrix=cov, - p_c=p_c, - p_s=p_s, - step_size=step_size, - num_updates=num_updates, + # update theta in case of reinit + theta = jax.tree_util.tree_map( + lambda x, y: jnp.where(reinitialize, x=x, y=y), random_theta, theta + ) + + # update cmaes state in case of reinit + cmaes_state = jax.tree_util.tree_map( + lambda x, y: jnp.where(reinitialize, x=x, y=y), + self._cma_initial_state, + cmaes_state, ) # score theta @@ -276,8 +293,9 @@ def state_update( emitter_state = CMAMEGAState( theta=theta, theta_grads=extra_score["normalized_grads"], - cmaes_state=cmaes_state, random_key=random_key, + cmaes_state=cmaes_state, + previous_fitnesses=repertoire.fitnesses, ) return emitter_state diff --git a/qdax/core/emitters/cma_opt_emitter.py b/qdax/core/emitters/cma_opt_emitter.py new file mode 100644 index 00000000..d9c5bf71 --- /dev/null +++ b/qdax/core/emitters/cma_opt_emitter.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import Optional + +import jax.numpy as jnp + +from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire +from qdax.core.emitters.cma_emitter import CMAEmitter, CMAEmitterState +from qdax.types import Descriptor, ExtraScores, Fitness, Genotype + + +class CMAOptimizingEmitter(CMAEmitter): + def _ranking_criteria( + self, + emitter_state: CMAEmitterState, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: Optional[ExtraScores], + improvements: jnp.ndarray, + ) -> jnp.ndarray: + """Defines how the genotypes should be sorted. Impacts the update + of the CMAES state. In the end, this defines the type of CMAES emitter + used (optimizing, random direction or improvement). + + Args: + emitter_state: current state of the emitter. + repertoire: latest repertoire of genotypes. + genotypes: emitted genotypes. + fitnesses: corresponding fitnesses. + descriptors: corresponding fitnesses. + extra_scores: corresponding extra scores. + improvements: improvments of the emitted genotypes. This corresponds + to the difference between their fitness and the fitness of the + individual occupying the cell of corresponding fitness. + + Returns: + The values to take into account in order to rank the emitted genotypes. + Here, it is the fitness of the genotype. + """ + + return fitnesses diff --git a/qdax/core/emitters/cma_pool_emitter.py b/qdax/core/emitters/cma_pool_emitter.py new file mode 100644 index 00000000..d5424a01 --- /dev/null +++ b/qdax/core/emitters/cma_pool_emitter.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +from functools import partial +from typing import Any, Optional, Tuple + +import jax +import jax.numpy as jnp + +from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire +from qdax.core.emitters.cma_emitter import CMAEmitter, CMAEmitterState +from qdax.core.emitters.emitter import Emitter, EmitterState +from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey + + +class CMAPoolEmitterState(EmitterState): + """ + Emitter state for the pool of CMA emitters. + + This is for a pool of homogeneous emitters. + + Args: + current_index: the index of the current emitter state used. + emitter_states: the batch of emitter states currently used. + """ + + current_index: int + emitter_states: CMAEmitterState + + +class CMAPoolEmitter(Emitter): + def __init__(self, num_states: int, emitter: CMAEmitter): + """Instantiate a pool of homogeneous emitters. 
+ + Args: + num_states: the number of emitters to consider. We can use a + single emitter object and a batched emitter state. + emitter: the type of emitter for the pool. + """ + self._num_states = num_states + self._emitter = emitter + + @property + def batch_size(self) -> int: + """ + Returns: + the batch size emitted by the emitter. + """ + return self._emitter.batch_size + + @partial(jax.jit, static_argnames=("self",)) + def init( + self, init_genotypes: Genotype, random_key: RNGKey + ) -> Tuple[CMAPoolEmitterState, RNGKey]: + """ + Initializes the CMA-MEGA emitter + + + Args: + init_genotypes: initial genotypes to add to the grid. + random_key: a random key to handle stochastic operations. + + Returns: + The initial state of the emitter. + """ + + def scan_emitter_init( + carry: RNGKey, unused: Any + ) -> Tuple[RNGKey, CMAEmitterState]: + random_key = carry + emitter_state, random_key = self._emitter.init(init_genotypes, random_key) + return random_key, emitter_state + + # init all the emitter states + random_key, emitter_states = jax.lax.scan( + scan_emitter_init, random_key, (), length=self._num_states + ) + + # define the emitter state of the pool + emitter_state = CMAPoolEmitterState( + current_index=0, emitter_states=emitter_states + ) + + return ( + emitter_state, + random_key, + ) + + @partial(jax.jit, static_argnames=("self",)) + def emit( + self, + repertoire: Optional[MapElitesRepertoire], + emitter_state: CMAPoolEmitterState, + random_key: RNGKey, + ) -> Tuple[Genotype, RNGKey]: + """ + Emits new individuals. + + Args: + repertoire: a repertoire of genotypes (unused). + emitter_state: the state of the CMA-MEGA emitter. + random_key: a random key to handle random operations. + + Returns: + New genotypes and a new random key. + """ + + # retrieve the relevant emitter state + current_index = emitter_state.current_index + used_emitter_state = jax.tree_util.tree_map( + lambda x: x[current_index], emitter_state.emitter_states + ) + + # use it to emit offsprings + offsprings, random_key = self._emitter.emit( + repertoire, used_emitter_state, random_key + ) + + return offsprings, random_key + + @partial( + jax.jit, + static_argnames=("self",), + ) + def state_update( + self, + emitter_state: CMAPoolEmitterState, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: Optional[ExtraScores] = None, + ) -> Optional[EmitterState]: + """ + Updates the emitter state. + + Args: + emitter_state: current emitter state + repertoire: the current genotypes repertoire + genotypes: the genotypes of the batch of emitted offspring (unused). + fitnesses: the fitnesses of the batch of emitted offspring. + descriptors: the descriptors of the emitted offspring. + extra_scores: unused + + Returns: + The updated emitter state. 
+ """ + + # retrieve the emitter that has been used and it's emitter state + current_index = emitter_state.current_index + emitter_states = emitter_state.emitter_states + + used_emitter_state = jax.tree_util.tree_map( + lambda x: x[current_index], emitter_states + ) + + # update the used emitter state + used_emitter_state = self._emitter.state_update( + emitter_state=used_emitter_state, + repertoire=repertoire, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + extra_scores=extra_scores, + ) + + # update the emitter state + emitter_states = jax.tree_util.tree_map( + lambda x, y: x.at[current_index].set(y), emitter_states, used_emitter_state + ) + + # determine the next emitter to be used + emit_counts = emitter_states.emit_count + + new_index = jnp.argmin(emit_counts) + + emitter_state = emitter_state.replace( + current_index=new_index, emitter_states=emitter_states + ) + + return emitter_state # type: ignore diff --git a/qdax/core/emitters/cma_rnd_emitter.py b/qdax/core/emitters/cma_rnd_emitter.py new file mode 100644 index 00000000..4afb2f5d --- /dev/null +++ b/qdax/core/emitters/cma_rnd_emitter.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +from functools import partial +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp + +from qdax.core.cmaes import CMAESState +from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire +from qdax.core.emitters.cma_emitter import CMAEmitter, CMAEmitterState +from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey + + +class CMARndEmitterState(CMAEmitterState): + """ + Emitter state for the CMA-ME random direction emitter. + + + Args: + random_key: a random key to handle stochastic operations. Used for + state update only, another key is used to emit. This might be + subject to refactoring discussions in the future. + cmaes_state: state of the underlying CMA-ES algorithm + previous_fitnesses: store last fitnesses of the repertoire. Used to + compute the improvment. + emit_count: count the number of emission events. + random_direction: direction of the behavior space we are trying to + explore. + """ + + random_direction: Descriptor + + +class CMARndEmitter(CMAEmitter): + @partial(jax.jit, static_argnames=("self",)) + def init( + self, init_genotypes: Genotype, random_key: RNGKey + ) -> Tuple[CMARndEmitterState, RNGKey]: + """ + Initializes the CMA-MEGA emitter + + + Args: + init_genotypes: initial genotypes to add to the grid. + random_key: a random key to handle stochastic operations. + + Returns: + The initial state of the emitter. + """ + + # Initialize repertoire with default values + num_centroids = self._centroids.shape[0] + default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) + + # take a random direction + random_key, subkey = jax.random.split(random_key) + random_direction = jax.random.uniform( + subkey, + shape=(self._centroids.shape[-1],), + ) + + # return the initial state + random_key, subkey = jax.random.split(random_key) + + return ( + CMARndEmitterState( + random_key=subkey, + cmaes_state=self._cma_initial_state, + previous_fitnesses=default_fitnesses, + emit_count=0, + random_direction=random_direction, + ), + random_key, + ) + + def _update_and_init_emitter_state( + self, + cmaes_state: CMAESState, + emitter_state: CMAEmitterState, + repertoire: MapElitesRepertoire, + emit_count: int, + random_key: RNGKey, + ) -> Tuple[CMAEmitterState, RNGKey]: + """Update the emitter state in the case of a reinit event. 
+ Reinit the cmaes state and use an individual from the repertoire + as the starting mean. + + Args: + cmaes_state: current cmaes state + emitter_state: current cmame state + repertoire: most recent repertoire + emit_count: counter of the emitter + random_key: key to handle stochastic events + + Returns: + The updated emitter state. + """ + + # re-sample + random_genotype, random_key = repertoire.sample(random_key, 1) + + # get new mean - remove the batch dim + new_mean = jax.tree_util.tree_map(lambda x: x.squeeze(0), random_genotype) + + # define the corresponding cmaes init state + cmaes_init_state = self._cma_initial_state.replace(mean=new_mean, num_updates=0) + + # take a new random direction + random_key, subkey = jax.random.split(random_key) + random_direction = jax.random.uniform( + subkey, + shape=(self._centroids.shape[-1],), + ) + + emitter_state = emitter_state.replace( + cmaes_state=cmaes_init_state, + emit_count=0, + random_direction=random_direction, + ) + + return emitter_state, random_key + + def _ranking_criteria( + self, + emitter_state: CMARndEmitterState, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: Optional[ExtraScores], + improvements: jnp.ndarray, + ) -> jnp.ndarray: + """Defines how the genotypes should be sorted. Impacts the update + of the CMAES state. In the end, this defines the type of CMAES emitter + used (optimizing, random direction or improvement). + + Args: + emitter_state: current state of the emitter. + repertoire: latest repertoire of genotypes. + genotypes: emitted genotypes. + fitnesses: corresponding fitnesses. + descriptors: corresponding fitnesses. + extra_scores: corresponding extra scores. + improvements: improvments of the emitted genotypes. This corresponds + to the difference between their fitness and the fitness of the + individual occupying the cell of corresponding fitness. + + Returns: + The values to take into account in order to rank the emitted genotypes. + Here, it is the dot product of the descriptor with the current random + direction. + """ + + # criteria: projection of the descriptors along the random direction + ranking_criteria = jnp.dot(descriptors, emitter_state.random_direction) + + # make sure to have all the new cells first + new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria) + + # condition for being a new cell + condition = improvements == jnp.inf + + ranking_criteria = jnp.where( + condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria + ) + + return ranking_criteria # type: ignore diff --git a/qdax/core/emitters/omg_mega_emitter.py b/qdax/core/emitters/omg_mega_emitter.py index 503de574..9de27eb0 100644 --- a/qdax/core/emitters/omg_mega_emitter.py +++ b/qdax/core/emitters/omg_mega_emitter.py @@ -28,7 +28,7 @@ class OMGMEGAEmitter(Emitter): NOTE: in order to implement this emitter while staying in the MAPElites framework, we had to make two temporary design choices: - - in the emit_fn function, we use the same random key to sample from the + - in the emit function, we use the same random key to sample from the genotypes and gradients repertoire, in order to get the gradients that correspond to the right genotypes. Although acceptable, this is definitely not the best coding practice and we would prefer to get rid of this in a @@ -37,7 +37,7 @@ class OMGMEGAEmitter(Emitter): sampling the indices to be sampled, the other one retrieving the corresponding elements. 
This would enable to reuse the indices instead of doing this double sampling. - - in the state_update_fn, we have to insert the gradients in the gradients + - in the state_update, we have to insert the gradients in the gradients repertoire in the same way the individuals were inserted. Once again, this is slightly unoptimal because the same addition mecanism has to be computed two times. One solution that we are discussing and that is very similar to the first @@ -65,13 +65,20 @@ def __init__( Args: batch_size: number of solutions sampled at each iteration - sigma_g: standard deviation for the coefficients + sigma_g: CAUTION - square of the standard deviation for the coefficients. + This notation can be misleading as, although it's called sigma, it + refers to the variance and not the standard deviation. num_descriptors: number of descriptors centroids: centroids used to create the repertoire of solutions. This will be used to create the repertoire of gradients. """ + # set the mean of the coeff distribution to zero self._mu = jnp.zeros(num_descriptors + 1) + + # set the cov matrix to sigma * I self._sigma = jnp.eye(num_descriptors + 1) * sigma_g + + # define other parameters of the distribution self._batch_size = batch_size self._centroids = centroids self._num_descriptors = num_descriptors @@ -103,7 +110,7 @@ def init( shape=(num_centroids, self._centroids.shape[-1]) ) - # instantiate de gradients repertoire + # instantiate the gradients repertoire gradients_repertoire = MapElitesRepertoire( genotypes=default_gradients, fitnesses=default_fitnesses, @@ -147,7 +154,7 @@ def emit( # sample gradients - use the same random key for sampling # See class docstrings for discussion about this choice - (gradients, random_key,) = emitter_state.gradients_repertoire.sample( + gradients, random_key = emitter_state.gradients_repertoire.sample( random_key, num_samples=self._batch_size ) diff --git a/tests/baselines_test/cmame_test.py b/tests/baselines_test/cmame_test.py new file mode 100644 index 00000000..c86bd622 --- /dev/null +++ b/tests/baselines_test/cmame_test.py @@ -0,0 +1,129 @@ +"""Tests CMA ME implementation""" + +from typing import Dict, Tuple, Type + +import jax +import jax.numpy as jnp +import pytest + +from qdax.core.containers.mapelites_repertoire import ( + MapElitesRepertoire, + compute_euclidean_centroids, +) +from qdax.core.emitters.cma_emitter import CMAEmitter +from qdax.core.emitters.cma_improvement_emitter import CMAImprovementEmitter +from qdax.core.emitters.cma_opt_emitter import CMAOptimizingEmitter +from qdax.core.emitters.cma_pool_emitter import CMAPoolEmitter +from qdax.core.emitters.cma_rnd_emitter import CMARndEmitter +from qdax.core.map_elites import MAPElites +from qdax.types import Descriptor, ExtraScores, Fitness, RNGKey + + +@pytest.mark.parametrize( + "emitter_type", + [CMAOptimizingEmitter, CMARndEmitter, CMAImprovementEmitter], +) +def test_cma_me(emitter_type: Type[CMAEmitter]) -> None: + + num_iterations = 1000 + num_dimensions = 20 + grid_shape = (50, 50) + batch_size = 36 + sigma_g = 0.5 + minval = -5.12 + maxval = 5.12 + min_bd = -5.12 * 0.5 * num_dimensions + max_bd = 5.12 * 0.5 * num_dimensions + pool_size = 3 + + def sphere_scoring(x: jnp.ndarray) -> jnp.ndarray: + return -jnp.sum((x + minval * 0.4) * (x + minval * 0.4), axis=-1) + + fitness_scoring = sphere_scoring + + def clip(x: jnp.ndarray) -> jnp.ndarray: + in_bound = (x <= maxval) * (x >= minval) + return jnp.where(condition=in_bound, x=x, y=(maxval / x)) + + def _behavior_descriptor_1(x: jnp.ndarray) 
-> jnp.ndarray: + return jnp.sum(clip(x[: x.shape[-1] // 2])) + + def _behavior_descriptor_2(x: jnp.ndarray) -> jnp.ndarray: + return jnp.sum(clip(x[x.shape[-1] // 2 :])) + + def _behavior_descriptors(x: jnp.ndarray) -> jnp.ndarray: + return jnp.array([_behavior_descriptor_1(x), _behavior_descriptor_2(x)]) + + def scoring_function(x: jnp.ndarray) -> Tuple[Fitness, Descriptor, Dict]: + scores, descriptors = fitness_scoring(x), _behavior_descriptors(x) + return scores, descriptors, {} + + def scoring_fn( + x: jnp.ndarray, random_key: RNGKey + ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + fitnesses, descriptors, extra_scores = jax.vmap(scoring_function)(x) + return fitnesses, descriptors, extra_scores, random_key + + worst_objective = fitness_scoring(-jnp.ones(num_dimensions) * 5.12) + best_objective = fitness_scoring(jnp.ones(num_dimensions) * 5.12 * 0.4) + + def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: + + # get metrics + grid_empty = repertoire.fitnesses == -jnp.inf + adjusted_fitness = ( + (repertoire.fitnesses - worst_objective) + * 100 + / (best_objective - worst_objective) + ) + qd_score = jnp.sum(adjusted_fitness, where=~grid_empty) # / num_centroids + coverage = 100 * jnp.mean(1.0 - grid_empty) + max_fitness = jnp.max(adjusted_fitness) + return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} + + random_key = jax.random.PRNGKey(0) + initial_population = ( + jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.0 + ) + + centroids = compute_euclidean_centroids( + grid_shape=grid_shape, + minval=min_bd, + maxval=max_bd, + ) + + emitter_kwargs = { + "batch_size": batch_size, + "genotype_dim": num_dimensions, + "centroids": centroids, + "sigma_g": sigma_g, + "min_count": 1, + "max_count": None, + } + + emitter = emitter_type(**emitter_kwargs) + + emitter = CMAPoolEmitter(num_states=pool_size, emitter=emitter) + + map_elites = MAPElites( + scoring_function=scoring_fn, emitter=emitter, metrics_function=metrics_fn + ) + + repertoire, emitter_state, random_key = map_elites.init( + initial_population, centroids, random_key + ) + + (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + map_elites.scan_update, + (repertoire, emitter_state, random_key), + (), + length=num_iterations, + ) + + pytest.assume(metrics["coverage"][-1] > 25) + pytest.assume(metrics["max_fitness"][-1] > 95) + pytest.assume(metrics["qd_score"][-1] > 50000) + + +if __name__ == "__main__": + test_cma_me(emitter_type=CMAEmitter) diff --git a/tests/baselines_test/cmamega_test.py b/tests/baselines_test/cmamega_test.py index 8f327a95..fdd9330b 100644 --- a/tests/baselines_test/cmamega_test.py +++ b/tests/baselines_test/cmamega_test.py @@ -114,6 +114,7 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: batch_size=batch_size, learning_rate=learning_rate, num_descriptors=2, + centroids=centroids, sigma_g=sigma_g, ) diff --git a/tests/core_test/cmaes_test.py b/tests/core_test/cmaes_test.py new file mode 100644 index 00000000..16321fd4 --- /dev/null +++ b/tests/core_test/cmaes_test.py @@ -0,0 +1,59 @@ +"""Tests CMA ES implementation""" + +import jax +import jax.numpy as jnp +import pytest + +from qdax.core.cmaes import CMAES + + +def test_cmaes() -> None: + + num_iterations = 10000 + num_dimensions = 100 + batch_size = 36 + num_best = 36 + sigma_g = 0.5 + minval = -5.12 + + def sphere_scoring(x: jnp.ndarray) -> jnp.ndarray: + return -jnp.sum((x + minval * 0.4) * (x + minval * 0.4), axis=-1) + + fitness_fn = 
sphere_scoring
+
+    cmaes = CMAES(
+        population_size=batch_size,
+        num_best=num_best,
+        search_dim=num_dimensions,
+        fitness_function=fitness_fn,  # type: ignore
+        mean_init=jnp.zeros((num_dimensions,)),
+        init_sigma=sigma_g,
+        delay_eigen_decomposition=True,
+    )
+
+    state = cmaes.init()
+    random_key = jax.random.PRNGKey(0)
+
+    iteration_count = 0
+    for _ in range(num_iterations):
+        iteration_count += 1
+
+        # sample
+        samples, random_key = cmaes.sample(state, random_key)
+
+        # update
+        state = cmaes.update(state, samples)
+
+        # check stop condition
+        stop_condition = cmaes.stop_condition(state)
+
+        if stop_condition:
+            break
+
+    fitnesses = fitness_fn(samples)
+
+    pytest.assume(jnp.min(fitnesses) < 0.001)
+
+
+if __name__ == "__main__":
+    test_cmaes()
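Two short usage notes on the components added in this patch.

First, the ranking rule shared by the improvement emitter and the CMA-MEGA update: offspring that discover a new cell are ranked by their fitness and pushed in front of the rest via an offset, while offspring landing in occupied cells are ranked by their improvement. This is a minimal sketch with hypothetical values, not part of the patch itself:

```python
import jax.numpy as jnp

# Hypothetical batch: offspring 0 and 2 discovered new cells
# (improvement is +inf), offspring 1 and 3 landed in occupied cells.
fitnesses = jnp.array([0.3, -1.2, 0.8, 0.1])
improvements = jnp.array([jnp.inf, -0.5, jnp.inf, 0.2])

# condition for being a new cell
condition = improvements == jnp.inf

# criteria: fitness if the cell is new, improvement otherwise
ranking_criteria = jnp.where(condition, fitnesses, improvements)

# shift the new-cell entries so they always rank first
new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria)
ranking_criteria = jnp.where(
    condition, ranking_criteria + new_cell_offset, ranking_criteria
)

# descending order: new cells first (by fitness), then occupied cells (by improvement)
sorted_indices = jnp.flip(jnp.argsort(ranking_criteria))
```

Second, a minimal end-to-end sketch of how the new emitters are meant to be wired into MAP-Elites, mirroring `tests/baselines_test/cmame_test.py`. The toy scoring function, grid bounds and hyperparameters below are illustrative only:

```python
import jax
import jax.numpy as jnp

from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from qdax.core.emitters.cma_improvement_emitter import CMAImprovementEmitter
from qdax.core.emitters.cma_pool_emitter import CMAPoolEmitter
from qdax.core.map_elites import MAPElites

num_dimensions = 20
batch_size = 36

def scoring_fn(x, random_key):
    # toy problem: negative sphere fitness, first two coordinates as descriptors
    fitnesses = -jnp.sum(x**2, axis=-1)
    descriptors = x[:, :2]
    return fitnesses, descriptors, {}, random_key

centroids = compute_euclidean_centroids(grid_shape=(50, 50), minval=-5.0, maxval=5.0)

# improvement emitter wrapped in a pool of 3 homogeneous emitter states
emitter = CMAPoolEmitter(
    num_states=3,
    emitter=CMAImprovementEmitter(
        batch_size=batch_size,
        genotype_dim=num_dimensions,
        centroids=centroids,
        sigma_g=0.5,
        min_count=1,
    ),
)

map_elites = MAPElites(
    scoring_function=scoring_fn,
    emitter=emitter,
    metrics_function=lambda repertoire: {"max_fitness": jnp.max(repertoire.fitnesses)},
)

random_key = jax.random.PRNGKey(0)
init_genotypes = jnp.zeros((batch_size, num_dimensions))
repertoire, emitter_state, random_key = map_elites.init(
    init_genotypes, centroids, random_key
)

# run 100 iterations with jax.lax.scan, as in the test
(repertoire, emitter_state, random_key), metrics = jax.lax.scan(
    map_elites.scan_update,
    (repertoire, emitter_state, random_key),
    (),
    length=100,
)
```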