Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add algorithm CMA-ME #86

Merged
merged 50 commits into from
Nov 24, 2022
Merged
Show file tree
Hide file tree
Changes from 48 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
a4b5ef5
WIP: cma me emitter
felixchalumeau Sep 6, 2022
f20d645
create notebook for CMA ME - works
felixchalumeau Sep 6, 2022
7ae227e
WIP: refactor code
felixchalumeau Sep 6, 2022
abbd611
WIP: debug cma me
felixchalumeau Sep 6, 2022
c653666
wip: refactor cma me emitter
felixchalumeau Sep 7, 2022
bff2493
wip: implementation of the random direction emitter
felixchalumeau Sep 7, 2022
1f54bdb
WIP: refacto + implem opt and rnd emitters
felixchalumeau Sep 7, 2022
00c367d
update todos comments
felixchalumeau Sep 7, 2022
77ffd24
wip: reproduce paper results
felixchalumeau Sep 8, 2022
8c84b90
WIP: pool of emitters
felixchalumeau Sep 8, 2022
72336c6
implement pool of emitter
felixchalumeau Sep 9, 2022
dda24ad
core fixes in cma es __init__ function
felixchalumeau Oct 3, 2022
b5ad06f
WIP: debug cma es optimizer
felixchalumeau Oct 3, 2022
a56a947
WIP: debugging cma es optimizer
felixchalumeau Oct 4, 2022
805ea97
WIP: debugging cma es
felixchalumeau Oct 4, 2022
b638534
no ci
felixchalumeau Oct 4, 2022
ae55d75
WIP: implement delayed cov matrix decomposition
felixchalumeau Oct 5, 2022
d6319b1
implement delay and afix strange shape issue
felixchalumeau Oct 5, 2022
b410124
implement stop condition in cmaes and use it in cma me emitters
felixchalumeau Oct 5, 2022
3c021e8
clean code
felixchalumeau Oct 5, 2022
ca1cd2d
fix cma mega
felixchalumeau Oct 5, 2022
ee8285b
cmamega fix + start cleaning and adding docstrings
felixchalumeau Oct 6, 2022
120cbf3
Merge branch 'develop' into feat/add-algo-cma-me
felixchalumeau Oct 7, 2022
dc3cc3f
clean cmaes + (WIP) cmaes test
felixchalumeau Oct 7, 2022
b9cb394
fix key bug + implement max emissions
felixchalumeau Oct 7, 2022
42142d3
fix cmaes opt - by decoupling sigma and cov
felixchalumeau Oct 25, 2022
dbf71d4
additional updates
felixchalumeau Oct 25, 2022
b69661c
pre-commits
felixchalumeau Oct 25, 2022
63acbeb
minor fix
felixchalumeau Oct 26, 2022
300c042
update rnd emitter ranking
felixchalumeau Oct 27, 2022
557ca2d
start cleaning
felixchalumeau Oct 27, 2022
028d8e6
fix and clean
felixchalumeau Oct 28, 2022
0bad7ff
clean docstrinfs + add cmaes test
felixchalumeau Oct 28, 2022
d5bc8da
add test for cma me
felixchalumeau Oct 28, 2022
9b717ec
clean notebooks - update README
felixchalumeau Oct 31, 2022
51dd1b7
WIP: documentation
felixchalumeau Oct 31, 2022
d750af5
update doc
felixchalumeau Oct 31, 2022
709c460
fix path in doc
felixchalumeau Oct 31, 2022
eb49a84
Merge branch 'develop' into feat/add-algo-cma-me
felixchalumeau Oct 31, 2022
7b9a1dd
conflicts
felixchalumeau Oct 31, 2022
67f619b
fix conflicts and pre-commit issues
felixchalumeau Oct 31, 2022
bf25222
fix pre-commit issues
felixchalumeau Oct 31, 2022
ebc03c5
start resolving comments
felixchalumeau Nov 22, 2022
9921d13
untested - update cma emitters design
felixchalumeau Nov 22, 2022
58da203
udpate style and fix style
felixchalumeau Nov 22, 2022
ea6ab01
update test and notebook
felixchalumeau Nov 22, 2022
dcb2851
Merge branch 'develop' into feat/add-algo-cma-me
felixchalumeau Nov 22, 2022
92f7c79
Merge branch 'develop' into feat/add-algo-cma-me
felixchalumeau Nov 22, 2022
e1041ac
Merge branch 'develop' into feat/add-algo-cma-me
felixchalumeau Nov 24, 2022
d44dc76
add the batch size properties
felixchalumeau Nov 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ repos:
rev: 0.3.9
hooks:
- id: nbstripout
args: ["notebooks/"]
args: ["examples/"]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
hooks:
Expand Down
27 changes: 14 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

QDax is a tool to accelerate Quality-Diversity (QD) and neuro-evolution algorithms through hardware accelerators and massive parallelization. QD algorithms usually take days/weeks to run on large CPU clusters. With QDax, QD algorithms can now be run in minutes! ⏩ ⏩ 🕛

QDax has been developed as a research framework: it is flexible and easy to extend and build on and can be used for any problem setting. Get started with simple example and run a QD algorithm in minutes here! [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/mapelites_example.ipynb)
QDax has been developed as a research framework: it is flexible and easy to extend and build on and can be used for any problem setting. Get started with simple example and run a QD algorithm in minutes here! [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb)

- QDax [paper](https://arxiv.org/abs/2202.01258)
- QDax [documentation](https://qdax.readthedocs.io/en/latest/)
Expand All @@ -32,7 +32,7 @@ Installing QDax via ```pip``` installs a CPU-only version of JAX by default. To
However, we also provide and recommend using either Docker, Singularity or conda environments to use the repository which by default provides GPU support. Detailed steps to do so are available in the [documentation](https://qdax.readthedocs.io/en/latest/installation/).

## Basic API Usage
For a full and interactive example to see how QDax works, we recommend starting with the tutorial-style [Colab notebook](./examples/notebooks/mapelites_example.ipynb). It is an example of the MAP-Elites algorithm used to evolve a population of controllers on a chosen Brax environment (Walker by default).
For a full and interactive example to see how QDax works, we recommend starting with the tutorial-style [Colab notebook](./examples/mapelites.ipynb). It is an example of the MAP-Elites algorithm used to evolve a population of controllers on a chosen Brax environment (Walker by default).

However, a summary of the main API usage is provided below:
```python
Expand Down Expand Up @@ -124,24 +124,25 @@ QDax currently supports the following algorithms:

| Algorithm | Example |
| --- | --- |
| [MAP-Elites](https://arxiv.org/abs/1504.04909) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/mapelites_example.ipynb) |
| [CVT MAP-Elites](https://arxiv.org/abs/1610.05729) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/mapelites_example.ipynb) |
| [Policy Gradient Assisted MAP-Elites (PGA-ME)](https://hal.archives-ouvertes.fr/hal-03135723v2/file/PGA_MAP_Elites_GECCO.pdf) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/pgame_example.ipynb) |
| [OMG-MEGA](https://arxiv.org/abs/2106.03894) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/omgmega_example.ipynb) |
| [CMA-MEGA](https://arxiv.org/abs/2106.03894) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/cmamega_example.ipynb) |
| [Multi-Objective Quality-Diversity (MOME)](https://arxiv.org/abs/2202.03057) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/mome_example.ipynb) |
| [MAP-Elites](https://arxiv.org/abs/1504.04909) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb) |
| [CVT MAP-Elites](https://arxiv.org/abs/1610.05729) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb) |
| [Policy Gradient Assisted MAP-Elites (PGA-ME)](https://hal.archives-ouvertes.fr/hal-03135723v2/file/PGA_MAP_Elites_GECCO.pdf) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/pgame.ipynb) |
| [CMA-ME](https://arxiv.org/pdf/1912.02400.pdf) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/cmame.ipynb) |
| [OMG-MEGA](https://arxiv.org/abs/2106.03894) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/omgmega.ipynb) |
| [CMA-MEGA](https://arxiv.org/abs/2106.03894) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/cmamega.ipynb) |
| [Multi-Objective Quality-Diversity (MOME)](https://arxiv.org/abs/2202.03057) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mome.ipynb) |


## QDax baseline algorithms
The QDax library also provides implementations for some useful baseline algorithms:

| Algorithm | Example |
| --- | --- |
| [DIAYN](https://arxiv.org/abs/1802.06070) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/diayn_example.ipynb) |
| [DADS](https://arxiv.org/abs/1907.01657) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/dads_example.ipynb) |
| [SMERL](https://arxiv.org/abs/2010.14484) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/smerl_example.ipynb) |
| [NSGA2](https://ieeexplore.ieee.org/document/996017) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/nsga2_spea2_example.ipynb) |
| [SPEA2](https://www.semanticscholar.org/paper/SPEA2%3A-Improving-the-strength-pareto-evolutionary-Zitzler-Laumanns/b13724cb54ae4171916f3f969d304b9e9752a57f) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/nsga2_spea2_example.ipynb) |
| [DIAYN](https://arxiv.org/abs/1802.06070) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/diayn.ipynb) |
| [DADS](https://arxiv.org/abs/1907.01657) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/dads.ipynb) |
| [SMERL](https://arxiv.org/abs/2010.14484) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/smerl.ipynb) |
| [NSGA2](https://ieeexplore.ieee.org/document/996017) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/nsga2_spea2.ipynb) |
| [SPEA2](https://www.semanticscholar.org/paper/SPEA2%3A-Improving-the-strength-pareto-evolutionary-Zitzler-Laumanns/b13724cb54ae4171916f3f969d304b9e9752a57f) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/nsga2_spea2.ipynb) |

## QDax Tasks
The QDax library also provides numerous implementations for several standard Quality-Diversity tasks.
Expand Down
13 changes: 13 additions & 0 deletions docs/api_documentation/core/cmame.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Covariance Matrix Adaptation MAP Elites (CMAME)

To create an instance of CMAME, one need to use an instance of [MAP-Elites](map_elites.md) with the desired CMA Emitter - optimizing, random direction, improvement - detailed below.To use the pool of emitter mechanism, use the CMAPoolEmitter.

Three emitter types:

::: qdax.core.emitters.cma_emitter.CMAEmitter
::: qdax.core.emitters.cma_rnd_emitter.CMARndEmitter
::: qdax.core.emitters.cma_opt_emitter.CMAOptimizingEmitter

Pool of homogeneous emitters:

::: qdax.core.emitters.cma_pool_emitter.CMAPoolEmitter
1 change: 1 addition & 0 deletions docs/examples
1 change: 0 additions & 1 deletion docs/notebooks

This file was deleted.

308 changes: 308 additions & 0 deletions examples/cmaes.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "222bbe00",
"metadata": {},
"source": [
"# Optimizing with CMA-ES in Jax\n",
"\n",
"This notebook shows how to use QDax to find performing parameters on Rastrigin and Sphere problems with [CMA-ES](https://arxiv.org/pdf/1604.00772.pdf). It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n",
"\n",
"- how to define the problem\n",
"- how to create a CMA-ES optimizer\n",
"- how to launch a certain number of optimizing steps\n",
"- how to visualise the optimization process"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d731f067",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Ellipse\n",
"\n",
"from qdax.core.cmaes import CMAES"
]
},
{
"cell_type": "markdown",
"id": "7b6e910b",
"metadata": {},
"source": [
"## Set the hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "404fb0dc",
"metadata": {},
"outputs": [],
"source": [
"#@title Hyperparameters\n",
"#@markdown ---\n",
"num_iterations = 1000 #@param {type:\"integer\"}\n",
"num_dimensions = 100 #@param {type:\"integer\"}\n",
"batch_size = 36 #@param {type:\"integer\"}\n",
"num_best = 18 #@param {type:\"integer\"}\n",
"sigma_g = 0.5 # 0.5 #@param {type:\"number\"}\n",
"minval = -5.12 #@param {type:\"number\"}\n",
"optim_problem = \"sphere\" #@param[\"rastrigin\", \"sphere\"]\n",
"#@markdown ---"
]
},
{
"cell_type": "markdown",
"id": "ccc7cbeb",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Define the fitness function - choose rastrigin or sphere"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "436dccbb",
"metadata": {},
"outputs": [],
"source": [
"def rastrigin_scoring(x: jnp.ndarray):\n",
" first_term = 10 * x.shape[-1]\n",
" second_term = jnp.sum((x + minval * 0.4) ** 2 - 10 * jnp.cos(2 * jnp.pi * (x + minval * 0.4)))\n",
" return -(first_term + second_term)\n",
"\n",
"def sphere_scoring(x: jnp.ndarray):\n",
" return -jnp.sum((x + minval * 0.4) * (x + minval * 0.4), axis=-1)\n",
"\n",
"if optim_problem == \"sphere\":\n",
" fitness_fn = sphere_scoring\n",
"elif optim_problem == \"rastrigin\":\n",
" fitness_fn = jax.vmap(rastrigin_scoring)\n",
"else:\n",
" raise Exception(\"Invalid opt function name given\")"
]
},
{
"cell_type": "markdown",
"id": "62bdd2a4",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Define a CMA-ES optimizer instance"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4cf03f55",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"cmaes = CMAES(\n",
" population_size=batch_size,\n",
" num_best=num_best,\n",
" search_dim=num_dimensions,\n",
" fitness_function=fitness_fn,\n",
" mean_init=jnp.zeros((num_dimensions,)),\n",
" init_sigma=sigma_g,\n",
" delay_eigen_decomposition=True,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "f1f69f50",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Init the CMA-ES optimizer state"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a95b74d",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"state = cmaes.init()\n",
"random_key = jax.random.PRNGKey(0)"
]
},
{
"cell_type": "markdown",
"id": "ac2d5c0d",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Run optimization iterations"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "363198ca",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"means = [state.mean]\n",
"covs = [(state.sigma**2) * state.cov_matrix]\n",
"\n",
"iteration_count = 0\n",
"for _ in range(num_iterations):\n",
" iteration_count += 1\n",
" \n",
" # sample\n",
" samples, random_key = cmaes.sample(state, random_key)\n",
" \n",
" # udpate\n",
" state = cmaes.update(state, samples)\n",
" \n",
" # check stop condition\n",
" stop_condition = cmaes.stop_condition(state)\n",
"\n",
" if stop_condition:\n",
" break\n",
" \n",
" # store data for plotting\n",
" means.append(state.mean)\n",
" covs.append((state.sigma**2) * state.cov_matrix)\n",
" \n",
"print(\"Num iterations before stop condition: \", iteration_count)"
]
},
{
"cell_type": "markdown",
"id": "0e5820b8",
"metadata": {},
"source": [
"## Check final fitnesses and distribution mean"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e4a2c7b",
"metadata": {},
"outputs": [],
"source": [
"# checking final fitness values\n",
"fitnesses = fitness_fn(samples)\n",
"\n",
"print(\"Min fitness in the final population: \", jnp.min(fitnesses))\n",
"print(\"Mean fitness in the final population: \", jnp.mean(fitnesses))\n",
"print(\"Max fitness in the final population: \", jnp.max(fitnesses))\n",
"\n",
"# checking mean of the final distribution\n",
"print(\"Final mean of the distribution: \\n\", means[-1])\n",
"# print(\"Final covariance matrix of the distribution: \", covs[-1])"
]
},
{
"cell_type": "markdown",
"id": "f3bd2b0f",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Visualization of the optimization trajectory"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad85551c",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(12, 6))\n",
"\n",
"# sample points to show fitness landscape\n",
"random_key, subkey = jax.random.split(random_key)\n",
"x = jax.random.uniform(subkey, minval=-4, maxval=8, shape=(100000, 2))\n",
"f_x = fitness_fn(x)\n",
"\n",
"# plot fitness landscape\n",
"points = ax.scatter(x[:, 0], x[:, 1], c=f_x, s=0.1)\n",
"fig.colorbar(points)\n",
"\n",
"# plot cma-es trajectory\n",
"traj_min = 0\n",
"traj_max = iteration_count\n",
"for mean, cov in zip(means[traj_min:traj_max], covs[traj_min:traj_max]):\n",
" ellipse = Ellipse((mean[0], mean[1]), cov[0, 0], cov[1, 1], fill=False, color='k', ls='--')\n",
" ax.add_patch(ellipse)\n",
" ax.plot(mean[0], mean[1], color='k', marker='x')\n",
" \n",
"ax.set_title(f\"Optimization trajectory of CMA-ES between step {traj_min} and step {traj_max}\")\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"vscode": {
"interpreter": {
"hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading