Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up codebase - standardize terminology and random key usage #190

Merged
merged 17 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
[![codecov](https://codecov.io/gh/adaptive-intelligent-robotics/QDax/branch/feat/add-codecov/graph/badge.svg)](https://codecov.io/gh/adaptive-intelligent-robotics/QDax)


QDax is a tool to accelerate Quality-Diversity (QD) and neuro-evolution algorithms through hardware accelerators and massive parallelization. QD algorithms usually take days/weeks to run on large CPU clusters. With QDax, QD algorithms can now be run in minutes! ⏩ ⏩ 🕛
QDax is a tool to accelerate Quality-Diversity (QD) and neuroevolution algorithms through hardware accelerators and massive parallelization. QD algorithms usually take days/weeks to run on large CPU clusters. With QDax, QD algorithms can now be run in minutes! ⏩ ⏩ 🕛

QDax has been developed as a research framework: it is flexible and easy to extend and build on and can be used for any problem setting. Get started with simple example and run a QD algorithm in minutes here! [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb)

Expand Down Expand Up @@ -60,14 +60,14 @@ num_iterations = 50
grid_shape = (100, 100)
min_param = 0.0
max_param = 1.0
min_bd = 0.0
max_bd = 1.0
min_descriptor = 0.0
max_descriptor = 1.0

# Init a random key
random_key = jax.random.key(seed)
key = jax.random.key(seed)

# Init population of controllers
random_key, subkey = jax.random.split(random_key)
key, subkey = jax.random.split(key)
init_variables = jax.random.uniform(
subkey,
shape=(init_batch_size, num_param_dimensions),
Expand Down Expand Up @@ -106,19 +106,19 @@ map_elites = MAPElites(
# Compute the centroids
centroids = compute_euclidean_centroids(
grid_shape=grid_shape,
minval=min_bd,
maxval=max_bd,
minval=min_descriptor,
maxval=max_descriptor,
)

# Initializes repertoire and emitter state
repertoire, emitter_state, random_key = map_elites.init(init_variables, centroids, random_key)
repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key)

# Run MAP-Elites loop
for i in range(num_iterations):
(repertoire, emitter_state, metrics, random_key,) = map_elites.update(
(repertoire, emitter_state, metrics, key,) = map_elites.update(
repertoire,
emitter_state,
random_key,
key,
)

# Get contents of repertoire
Expand All @@ -133,6 +133,7 @@ QDax currently supports the following algorithms:
| Algorithm | Example |
|-------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [MAP-Elites](https://arxiv.org/abs/1504.04909) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb) |
| [AURORA](https://arxiv.org/abs/2106.05648) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/aurora.ipynb) |
| [CVT MAP-Elites](https://arxiv.org/abs/1610.05729) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb) |
| [Policy Gradient Assisted MAP-Elites (PGA-ME)](https://hal.archives-ouvertes.fr/hal-03135723v2/file/PGA_MAP_Elites_GECCO.pdf) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/pgame.ipynb) |
| [DCRL-ME](https://arxiv.org/abs/2401.08632) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/dcrlme.ipynb) |
Expand Down
2 changes: 1 addition & 1 deletion docs/api_documentation/core/mels.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

[ME-LS](https://dl.acm.org/doi/abs/10.1145/3583131.3590433) is a variant of
MAP-Elites that thrives the search process towards solutions that are consistent
in the behavior space for uncertain domains.
in the descriptor space for uncertain domains.

::: qdax.core.mels.MELS
10 changes: 5 additions & 5 deletions docs/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@ More importantly, QDax handles the archive management which is the key idea of Q
## Code Example
```python
# Initializes repertoire and emitter state
repertoire, emitter_state, random_key = map_elites.init(init_variables, centroids, random_key)
repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key)

for i in range(num_iterations):

# generate new population with the emitter
genotypes, random_key = map_elites._emitter.emit(
repertoire, emitter_state, random_key
genotypes, key = map_elites._emitter.emit(
repertoire, emitter_state, key
)

# scores/evaluates the population
fitnesses, descriptors, extra_scores, random_key = map_elites._scoring_function(
genotypes, random_key
fitnesses, descriptors, extra_scores, key = map_elites._scoring_function(
genotypes, key
)

# update repertoire
Expand Down
69 changes: 32 additions & 37 deletions examples/aurora.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Optimizing with AURORA in Jax\n",
"# Optimizing with AURORA in JAX\n",
"\n",
"This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [AURORA](https://arxiv.org/pdf/1905.11874.pdf).\n",
"It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n",
Expand Down Expand Up @@ -49,8 +49,7 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"!pip install ipympl |tail -n 1\n",
"!pip install ipympl | tail -n 1\n",
"# %matplotlib widget\n",
"# from google.colab import output\n",
"# output.enable_custom_widget_manager()\n",
Expand All @@ -71,7 +70,7 @@
" create_default_brax_task_components,\n",
" get_aurora_scoring_fn,\n",
")\n",
"from qdax.environments.bd_extractors import (\n",
"from qdax.environments.descriptor_extractors import (\n",
" AuroraExtraInfoNormalization,\n",
" get_aurora_encoding,\n",
")\n",
Expand All @@ -85,8 +84,8 @@
"\n",
"\n",
"if \"COLAB_TPU_ADDR\" in os.environ:\n",
" from jax.tools import colab_tpu\n",
" colab_tpu.setup_tpu()\n",
" from jax.tools import colab_tpu\n",
" colab_tpu.setup_tpu()\n",
"\n",
"\n",
"clear_output()"
Expand All @@ -110,8 +109,8 @@
"line_sigma = 0.05 #@param {type:\"number\"}\n",
"num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n",
"num_centroids = 1024 #@param {type:\"integer\"}\n",
"min_bd = 0. #@param {type:\"number\"}\n",
"max_bd = 1.0 #@param {type:\"number\"}\n",
"min_descriptor = 0. #@param {type:\"number\"}\n",
"max_descriptor = 1.0 #@param {type:\"number\"}\n",
"\n",
"lstm_batch_size = 128 #@param {type:\"integer\"}\n",
"\n",
Expand Down Expand Up @@ -146,7 +145,7 @@
"env = environments.create(env_name, episode_length=episode_length)\n",
"\n",
"# Init a random key\n",
"random_key = jax.random.key(seed)\n",
"key = jax.random.key(seed)\n",
"\n",
"# Init policy network\n",
"policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
Expand All @@ -157,14 +156,14 @@
")\n",
"\n",
"# Init population of controllers\n",
"random_key, subkey = jax.random.split(random_key)\n",
"key, subkey = jax.random.split(key)\n",
"keys = jax.random.split(subkey, num=batch_size)\n",
"fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))\n",
"init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n",
"\n",
"\n",
"# Create the initial environment states\n",
"random_key, subkey = jax.random.split(random_key)\n",
"key, subkey = jax.random.split(key)\n",
"keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)\n",
"reset_fn = jax.jit(jax.vmap(env.reset))\n",
"init_states = reset_fn(keys)"
Expand All @@ -187,9 +186,9 @@
"source": [
"# Define the fonction to play a step with the policy in the environment\n",
"def play_step_fn(\n",
" env_state,\n",
" policy_params,\n",
" random_key,\n",
" env_state,\n",
" policy_params,\n",
" key,\n",
"):\n",
" \"\"\"\n",
" Play an environment step and return the updated state and the transition.\n",
Expand All @@ -211,7 +210,7 @@
" next_state_desc=next_state.info[\"state_descriptor\"],\n",
" )\n",
"\n",
" return next_state, policy_params, random_key, transition"
" return next_state, policy_params, key, transition"
]
},
{
Expand All @@ -220,7 +219,7 @@
"source": [
"## Define the scoring function and the way metrics are computed\n",
"\n",
"The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual."
"The scoring function is used in the evaluation step to determine the fitness and descriptor of each individual."
]
},
{
Expand All @@ -230,9 +229,10 @@
"outputs": [],
"source": [
"# Prepare the scoring function\n",
"env, policy_network, scoring_fn, random_key = create_default_brax_task_components(\n",
"key, subkey = jax.random.split(key)\n",
"env, policy_network, scoring_fn = create_default_brax_task_components(\n",
" env_name=env_name,\n",
" random_key=random_key,\n",
" key=subkey,\n",
")\n",
"\n",
"def observation_extractor_fn(\n",
Expand Down Expand Up @@ -324,24 +324,18 @@
"@jax.jit\n",
"def update_scan_fn(carry: Any, unused: Any) -> Any:\n",
" \"\"\"Scan the udpate function.\"\"\"\n",
" (\n",
" repertoire,\n",
" random_key,\n",
" aurora_extra_info\n",
" ) = carry\n",
" repertoire, key, aurora_extra_info = carry\n",
"\n",
" # update\n",
" (repertoire, _, metrics, random_key,) = aurora.update(\n",
" key, subkey = jax.random.split(key)\n",
" repertoire, _, metrics = aurora.update(\n",
" repertoire,\n",
" None,\n",
" random_key,\n",
" subkey,\n",
" aurora_extra_info=aurora_extra_info,\n",
" )\n",
"\n",
" return (\n",
" (repertoire, random_key, aurora_extra_info),\n",
" metrics,\n",
" )\n",
" return (repertoire, key, aurora_extra_info), metrics\n",
"\n",
"# Init algorithm\n",
"# AutoEncoder Params and INIT\n",
Expand All @@ -367,7 +361,7 @@
")\n",
"\n",
"# Init the model params\n",
"random_key, subkey = jax.random.split(random_key)\n",
"key, subkey = jax.random.split(key)\n",
"model_params = train_seq2seq.get_initial_params(\n",
" model, subkey, (1, *observations_dims)\n",
")\n",
Expand Down Expand Up @@ -410,18 +404,19 @@
")\n",
"\n",
"# init step of the aurora algorithm\n",
"repertoire, emitter_state, aurora_extra_info, random_key = aurora.init(\n",
"key, subkey = jax.random.split(key)\n",
"repertoire, emitter_state, aurora_extra_info = aurora.init(\n",
" init_variables,\n",
" aurora_extra_info,\n",
" jnp.asarray(l_value_init),\n",
" max_observation_size,\n",
" random_key,\n",
" subkey,\n",
")\n",
"\n",
"# initializing means and stds and AURORA\n",
"random_key, subkey = jax.random.split(random_key)\n",
"key, subkey = jax.random.split(key)\n",
"repertoire, aurora_extra_info = aurora.train(\n",
" repertoire, model_params, iteration=0, random_key=subkey\n",
" repertoire, model_params, iteration=0, key=subkey\n",
")\n",
"\n",
"# design aurora's schedule\n",
Expand Down Expand Up @@ -455,11 +450,11 @@
"while iteration < max_iterations:\n",
"\n",
" (\n",
" (repertoire, random_key, aurora_extra_info),\n",
" (repertoire, key, aurora_extra_info),\n",
" metrics,\n",
" ) = jax.lax.scan(\n",
" update_scan_fn,\n",
" (repertoire, random_key, aurora_extra_info),\n",
" (repertoire, key, aurora_extra_info),\n",
" (),\n",
" length=log_freq,\n",
" )\n",
Expand All @@ -472,7 +467,7 @@
" # autoencoder steps and CVC\n",
" if (iteration + 1) in schedules:\n",
" # train the autoencoder\n",
" random_key, subkey = jax.random.split(random_key)\n",
" key, subkey = jax.random.split(key)\n",
" repertoire, aurora_extra_info = aurora.train(\n",
" repertoire, model_params, iteration, subkey\n",
" )\n",
Expand Down
11 changes: 5 additions & 6 deletions examples/cmaes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"id": "1",
"metadata": {},
"source": [
"# Optimizing with CMA-ES in Jax\n",
"# Optimizing with CMA-ES in JAX\n",
"\n",
"This notebook shows how to use QDax to find performing parameters on Rastrigin and Sphere problems with [CMA-ES](https://arxiv.org/pdf/1604.00772.pdf). It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n",
"\n",
Expand Down Expand Up @@ -178,7 +178,7 @@
"outputs": [],
"source": [
"state = cmaes.init()\n",
"random_key = jax.random.key(0)"
"key = jax.random.key(0)"
]
},
{
Expand All @@ -204,8 +204,6 @@
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"means = [state.mean]\n",
"covs = [(state.sigma**2) * state.cov_matrix]\n",
"\n",
Expand All @@ -214,7 +212,8 @@
" iteration_count += 1\n",
"\n",
" # sample\n",
" samples, random_key = cmaes.sample(state, random_key)\n",
" key, subkey = jax.random.split(key)\n",
" samples = cmaes.sample(state, subkey)\n",
"\n",
" # udpate\n",
" state = cmaes.update(state, samples)\n",
Expand Down Expand Up @@ -285,7 +284,7 @@
"fig, ax = plt.subplots(figsize=(12, 6))\n",
"\n",
"# sample points to show fitness landscape\n",
"random_key, subkey = jax.random.split(random_key)\n",
"key, subkey = jax.random.split(key)\n",
"x = jax.random.uniform(subkey, minval=-4, maxval=8, shape=(100000, 2))\n",
"f_x = fitness_fn(x)\n",
"\n",
Expand Down
Loading
Loading