From 529ddf65ec0ae050e05e9b805d47a80e22867550 Mon Sep 17 00:00:00 2001
From: Jeroen Van Goey
Date: Fri, 26 Jan 2024 11:42:41 +0200
Subject: [PATCH 1/5] fix: fix typos and add codespell pre-commit hook

---
 .pre-commit-config.yaml | 22 ++++++++++++-----
 docs/api_documentation/core/pbt.md | 2 +-
 docs/overview.md | 4 ++--
 examples/aurora.ipynb | 6 ++---
 examples/cmaes.ipynb | 4 ++--
 examples/cmame.ipynb | 6 ++---
 examples/cmamega.ipynb | 2 +-
 examples/dads.ipynb | 8 +++----
 examples/diayn.ipynb | 6 ++---
 examples/distributed_mapelites.ipynb | 4 ++--
 examples/mapelites.ipynb | 4 ++--
 examples/mees.ipynb | 6 ++---
 examples/mome.ipynb | 2 +-
 examples/nsga2_spea2.ipynb | 2 +-
 examples/omgmega.ipynb | 2 +-
 examples/pga_aurora.ipynb | 6 ++---
 examples/pgame.ipynb | 4 ++--
 examples/qdpg.ipynb | 4 ++--
 examples/sac_pbt.ipynb | 2 +-
 examples/smerl.ipynb | 4 ++--
 examples/td3_pbt.ipynb | 2 +-
 qdax/baselines/dads.py | 4 ++--
 qdax/baselines/diayn.py | 4 ++--
 qdax/baselines/sac.py | 2 +-
 qdax/core/cmaes.py | 2 +-
 qdax/core/containers/archive.py | 12 +++++-----
 qdax/core/containers/mapelites_repertoire.py | 2 +-
 qdax/core/containers/repertoire.py | 2 +-
 qdax/core/containers/spea2_repertoire.py | 2 +-
 .../containers/unstructured_repertoire.py | 2 +-
 qdax/core/distributed_map_elites.py | 3 ++-
 qdax/core/emitters/cma_emitter.py | 4 ++--
 qdax/core/emitters/cma_improvement_emitter.py | 2 +-
 qdax/core/emitters/cma_mega_emitter.py | 4 ++--
 qdax/core/emitters/cma_opt_emitter.py | 2 +-
 qdax/core/emitters/cma_rnd_emitter.py | 4 ++--
 qdax/core/emitters/dpg_emitter.py | 2 +-
 qdax/core/emitters/emitter.py | 4 ++--
 qdax/core/emitters/mees_emitter.py | 4 ++--
 qdax/core/emitters/omg_mega_emitter.py | 10 ++++-----
 qdax/core/emitters/qdpg_emitter.py | 2 +-
 qdax/environments/__init__.py | 2 +-
 qdax/environments/bd_extractors.py | 2 +-
 qdax/environments/pointmaze.py | 2 +-
 qdax/tasks/README.md | 6 ++---
 qdax/tasks/brax_envs.py | 6 ++---
 qdax/tasks/hypervolume_functions.py | 10 ++++-----
 qdax/tasks/jumanji_envs.py | 4 ++--
 qdax/utils/plotting.py | 4 ++--
 qdax/utils/sampling.py | 2 +-
 tests/baselines_test/mees_test.py | 2 +-
 tests/baselines_test/pgame_test.py | 2 +-
 tests/baselines_test/qdpg_test.py | 2 +-
 tests/core_test/cmaes_test.py | 2 +-
 tests/core_test/map_elites_test.py | 2 +-
 tests/utils_test/sampling_test.py | 2 +-
 56 files changed, 118 insertions(+), 107 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a9329a64..3281d8fe 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,17 +1,17 @@
 repos:
 - repo: https://github.com/pycqa/isort
-  rev: 5.12.0
+  rev: 5.13.2
   hooks:
   - id: isort
     args: ["--profile", "black"]
 - repo: https://github.com/ambv/black
-  rev: 22.3.0
+  rev: 24.1.0
   hooks:
   - id: black
     language_version: python3.9
     args: ["--target-version", "py39"]
 - repo: https://github.com/PyCQA/flake8
-  rev: 3.8.4
+  rev: 7.0.0
   hooks:
   - id: flake8
     args: ['--max-line-length=88', '--extend-ignore=E203']
@@ -21,12 +21,12 @@ repos:
     - flake8-comprehensions
     - flake8-bugbear
 - repo: https://github.com/kynan/nbstripout
-  rev: 0.3.9
+  rev: 0.6.1
   hooks:
   - id: nbstripout
     args: ["examples/"]
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v4.0.1
+  rev: v4.5.0
   hooks:
   - id: debug-statements
   - id: requirements-txt-fixer
@@ -42,6 +42,16 @@ repos:
   - id: trailing-whitespace # This hook trims trailing whitespace

 - repo: https://github.com/pre-commit/mirrors-mypy
-  rev: v0.942
+  rev: v1.8.0
   hooks:
   - id: mypy
+
+- repo: https://github.com/codespell-project/codespell
+  rev: v2.2.6
+  hooks:
+  - id: codespell
+    name: codespell
+    description: Checks for common misspellings in text files.
+    entry: codespell
+    language: python
+    types: [text]
\ No newline at end of file
diff --git a/docs/api_documentation/core/pbt.md b/docs/api_documentation/core/pbt.md
index 6f482d32..82cc1396 100644
--- a/docs/api_documentation/core/pbt.md
+++ b/docs/api_documentation/core/pbt.md
@@ -2,7 +2,7 @@

 [PBT](https://arxiv.org/abs/1711.09846) is optimization method to jointly optimise a population of models and their hyperparameters to maximize performance.

-To use PBT in QDax to train SAC, one can use the two following components (see [examples](../../examples/sac_pbt.ipynb) to see how to use the components appropriatly):
+To use PBT in QDax to train SAC, one can use the two following components (see [examples](../../examples/sac_pbt.ipynb) to see how to use the components appropriately):

 ::: qdax.baselines.sac_pbt.PBTSAC

diff --git a/docs/overview.md b/docs/overview.md
index a5b6def7..00de8b20 100644
--- a/docs/overview.md
+++ b/docs/overview.md
@@ -1,6 +1,6 @@
 # QDax Overview

-QDax has been designed to be modular yet flexible so it's easy for anyone to use and extend on the different state-of-the-art QD algortihms available.
+QDax has been designed to be modular yet flexible so it's easy for anyone to use and extend on the different state-of-the-art QD algorithms available.
 For instance, MAP-Elites is designed to work with a few modular and simple components: `container`, `emitter`, and `scoring_function`.

 ## Key concepts
@@ -17,7 +17,7 @@ The `scoring_function` defines the problem/task we want to solve and functions t
 With this modularity, a user can easily swap out any one of the components and pass it to the `MAPElites` class, avoiding having to re-implement all the steps of the algorithm.

 Under one layer of abstraction, users have a bit more flexibility. QDax has similarities to the simple and commonly found `ask`/`tell` interface. The `ask` function is similar to the `emit` function in QDax and the `tell` function is similar to the `update` function in QDax. Likewise, the `eval` of solutions is analogous to the `scoring function` in QDax.
-More importantly, QDax handles the archive management which is the key idea of QD algorihtms and not present or needed in standard optimization algorihtms or evolutionary strategies.
+More importantly, QDax handles the archive management which is the key idea of QD algorithms and not present or needed in standard optimization algorithms or evolutionary strategies.

 ## Code Example
 ```python
diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb
index e4b86238..29be77ca 100644
--- a/examples/aurora.ipynb
+++ b/examples/aurora.ipynb
@@ -14,7 +14,7 @@
     "# Optimizing with AURORA in Jax\n",
     "\n",
     "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [AURORA](https://arxiv.org/pdf/1905.11874.pdf).\n",
-    "It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n",
+    "It can be run locally or on Google Colab. We recommend to use a GPU. 
This notebook will show:\n", "\n", "- how to define the problem\n", "- how to create an emitter\n", @@ -198,7 +198,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Define the fonction to play a step with the policy in the environment\n", + "# Define the function to play a step with the policy in the environment\n", "def play_step_fn(\n", " env_state,\n", " policy_params,\n", @@ -336,7 +336,7 @@ "\n", "@jax.jit\n", "def update_scan_fn(carry: Any, unused: Any) -> Any:\n", - " \"\"\"Scan the udpate function.\"\"\"\n", + " \"\"\"Scan the update function.\"\"\"\n", " (\n", " repertoire,\n", " random_key,\n", diff --git a/examples/cmaes.ipynb b/examples/cmaes.ipynb index c8e2a9fe..41a59ed7 100644 --- a/examples/cmaes.ipynb +++ b/examples/cmaes.ipynb @@ -15,7 +15,7 @@ "source": [ "# Optimizing with CMA-ES in Jax\n", "\n", - "This notebook shows how to use QDax to find performing parameters on Rastrigin and Sphere problems with [CMA-ES](https://arxiv.org/pdf/1604.00772.pdf). It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "This notebook shows how to use QDax to find performing parameters on Rastrigin and Sphere problems with [CMA-ES](https://arxiv.org/pdf/1604.00772.pdf). It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", "\n", "- how to define the problem\n", "- how to create a CMA-ES optimizer\n", @@ -227,7 +227,7 @@ " # sample\n", " samples, random_key = cmaes.sample(state, random_key)\n", "\n", - " # udpate\n", + " # update\n", " state = cmaes.update(state, samples)\n", "\n", " # check stop condition\n", diff --git a/examples/cmame.ipynb b/examples/cmame.ipynb index 3c355eea..e70506a8 100644 --- a/examples/cmame.ipynb +++ b/examples/cmame.ipynb @@ -13,7 +13,7 @@ "source": [ "# Optimizing with CMA-ME in Jax\n", "\n", - "This notebook shows how to use QDax to find diverse and performing parameters on Rastrigin or Sphere problem with [CMA-ME](https://arxiv.org/pdf/1912.02400.pdf). It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "This notebook shows how to use QDax to find diverse and performing parameters on Rastrigin or Sphere problem with [CMA-ME](https://arxiv.org/pdf/1912.02400.pdf). It can be run locally or on Google Colab. We recommend to use a GPU. 
This notebook will show:\n", "\n", "- how to define the problem\n", "- how to create a CMA-ME emitter\n", @@ -219,7 +219,7 @@ "source": [ "random_key = jax.random.PRNGKey(0)\n", "# in CMA-ME settings (from the paper), there is no init population\n", - "# we multipy by zero to reproduce this setting\n", + "# we multiply by zero to reproduce this setting\n", "initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n", "\n", "centroids = compute_euclidean_centroids(\n", @@ -362,7 +362,7 @@ "axes[2].set_title(\"QD Score evolution during training\")\n", "axes[2].set_aspect(0.95 / axes[2].get_data_ratio(), adjustable=\"box\")\n", "\n", - "# udpate this variable to save your results locally\n", + "# update this variable to save your results locally\n", "savefig = False\n", "if savefig:\n", " figname = \"cma_me_\" + optim_problem + \"_\" + str(num_dimensions) + \"_\" + emitter_type + \".png\"\n", diff --git a/examples/cmamega.ipynb b/examples/cmamega.ipynb index e5749993..3a49dcd7 100644 --- a/examples/cmamega.ipynb +++ b/examples/cmamega.ipynb @@ -13,7 +13,7 @@ "source": [ "# Optimizing with CMA-MEGA in Jax\n", "\n", - "This notebook shows how to use QDax to find diverse and performing parameters on the Rastrigin problem with [CMA-MEGA](https://arxiv.org/pdf/2106.03894.pdf). It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "This notebook shows how to use QDax to find diverse and performing parameters on the Rastrigin problem with [CMA-MEGA](https://arxiv.org/pdf/2106.03894.pdf). It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", "\n", "- how to define the problem\n", "- how to create a cma-mega emitter\n", diff --git a/examples/dads.ipynb b/examples/dads.ipynb index b3cc43b5..ec7c111e 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -13,7 +13,7 @@ "source": [ "# Training DADS with Jax\n", "\n", - "This notebook shows how to use QDax to train [DADS](https://arxiv.org/abs/1907.01657) on a Brax environment. It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "This notebook shows how to use QDax to train [DADS](https://arxiv.org/abs/1907.01657) on a Brax environment. It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", "- how to define an environment\n", "- how to define a replay buffer\n", "- how to create a dads instance\n", @@ -107,7 +107,7 @@ "\n", "Most hyperparameters are similar to those introduced in [SAC paper](https://arxiv.org/abs/1801.01290), [DIAYN paper](https://arxiv.org/abs/1802.06070) and [DADS paper](https://arxiv.org/abs/1907.01657).\n", "\n", - "The parameter `descriptor_full_state` is less straightforward, it concerns the information used for diversity seeking and dynamics. In DADS, one can use the full state for diversity seeking, but one can also use a prior to focus on an interesting aspect of the state. Actually, priors are often used in experiments, for instance, focusing on the x/y position rather than the full position. When `descriptor_full_state` is set to True, it uses the full state, when it is set to False, it uses the 'state descriptor' retrieved by the environment. Hence, it is required that the environment has one. (All the `_uni`, `_omni` do, same for `anttrap`, `antmaze` and `pointmaze`.) In the future, we will add an option to use a prior function direclty on the full state." 
+ "The parameter `descriptor_full_state` is less straightforward, it concerns the information used for diversity seeking and dynamics. In DADS, one can use the full state for diversity seeking, but one can also use a prior to focus on an interesting aspect of the state. Actually, priors are often used in experiments, for instance, focusing on the x/y position rather than the full position. When `descriptor_full_state` is set to True, it uses the full state, when it is set to False, it uses the 'state descriptor' retrieved by the environment. Hence, it is required that the environment has one. (All the `_uni`, `_omni` do, same for `anttrap`, `antmaze` and `pointmaze`.) In the future, we will add an option to use a prior function directly on the full state." ] }, { @@ -277,7 +277,7 @@ " deterministic=True,\n", " env=eval_env,\n", " skills=skills,\n", - " evaluation=True, # needed by normalizatoin mecanism\n", + " evaluation=True, # needed by normalizatoin mechanism\n", ")\n", "\n", "play_step = functools.partial(\n", @@ -327,7 +327,7 @@ "source": [ "## Prepare last utils for the training loop\n", "\n", - "Many Reinforcement Learning algorithm have similar training process, that can be divided in a precise training step that is repeted several times. Most of the differences are captured inside the `play_step` and in the `update` functions. Hence, once those are defined, the iteration works in the same way. For this reason, instead of coding the same function for each algorithm, we have created the `do_iteration_fn` that can be used by most of them. In the training script, the user just has to partial the function to give `play_step`, `update` plus a few other parameter." + "Many Reinforcement Learning algorithm have similar training process, that can be divided in a precise training step that is repeated several times. Most of the differences are captured inside the `play_step` and in the `update` functions. Hence, once those are defined, the iteration works in the same way. For this reason, instead of coding the same function for each algorithm, we have created the `do_iteration_fn` that can be used by most of them. In the training script, the user just has to partial the function to give `play_step`, `update` plus a few other parameter." ] }, { diff --git a/examples/diayn.ipynb b/examples/diayn.ipynb index 0562e7c2..d1755225 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -13,7 +13,7 @@ "source": [ "# Training DIAYN with Jax\n", "\n", - "This notebook shows how to use QDax to train [DIAYN](https://arxiv.org/abs/1802.06070) on a Brax environment. It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "This notebook shows how to use QDax to train [DIAYN](https://arxiv.org/abs/1802.06070) on a Brax environment. It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", "- how to define an environment\n", "- how to define a replay buffer\n", "- how to create a diayn instance\n", @@ -107,7 +107,7 @@ "\n", "Most hyperparameters are similar to those introduced in [SAC paper](https://arxiv.org/abs/1801.01290) and [DIAYN paper](https://arxiv.org/abs/1802.06070).\n", "\n", - "The parameter `descriptor_full_state` is less straightforward, it concerns the information used for diversity seeking and discrimination. In DIAYN, one can use the full state for diversity seeking, but one can also use a prior to focus on an interesting aspect of the state. 
Actually, priors are often used in experiments, for instance, focusing on the x/y position rather than the full position. When `descriptor_full_state` is set to True, it uses the full state, when it is set to False, it uses the 'state descriptor' retrieved by the environment. Hence, it is required that the environment has one. (All the `_uni`, `_omni` do, same for `anttrap`, `antmaze` and `pointmaze`.) In the future, we will add an option to use a prior function direclty on the full state." + "The parameter `descriptor_full_state` is less straightforward, it concerns the information used for diversity seeking and discrimination. In DIAYN, one can use the full state for diversity seeking, but one can also use a prior to focus on an interesting aspect of the state. Actually, priors are often used in experiments, for instance, focusing on the x/y position rather than the full position. When `descriptor_full_state` is set to True, it uses the full state, when it is set to False, it uses the 'state descriptor' retrieved by the environment. Hence, it is required that the environment has one. (All the `_uni`, `_omni` do, same for `anttrap`, `antmaze` and `pointmaze`.) In the future, we will add an option to use a prior function directly on the full state." ] }, { @@ -317,7 +317,7 @@ "source": [ "## Prepare last utils for the training loop\n", "\n", - "Many Reinforcement Learning algorithm have similar training process, that can be divided in a precise training step that is repeted several times. Most of the differences are captured inside the `play_step` and in the `update` functions. Hence, once those are defined, the iteration works in the same way. For this reason, instead of coding the same function for each algorithm, we have created the `do_iteration_fn` that can be used by most of them. In the training script, the user just has to partial the function to give `play_step`, `update` plus a few other parameter." + "Many Reinforcement Learning algorithm have similar training process, that can be divided in a precise training step that is repeated several times. Most of the differences are captured inside the `play_step` and in the `update` functions. Hence, once those are defined, the iteration works in the same way. For this reason, instead of coding the same function for each algorithm, we have created the `do_iteration_fn` that can be used by most of them. In the training script, the user just has to partial the function to give `play_step`, `update` plus a few other parameter." ] }, { diff --git a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb index 2e6fd991..0881b1ed 100644 --- a/examples/distributed_mapelites.ipynb +++ b/examples/distributed_mapelites.ipynb @@ -14,7 +14,7 @@ "# Optimizing with MAP-Elites in Jax (multi-devices example)\n", "\n", "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [MAP-Elites](https://arxiv.org/abs/1504.04909).\n", - "It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "It can be run locally or on Google Colab. We recommend to use a GPU. 
This notebook will show:\n", "\n", "- how to define the problem\n", "- how to create an emitter\n", @@ -224,7 +224,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Define the fonction to play a step with the policy in the environment\n", + "# Define the function to play a step with the policy in the environment\n", "def play_step_fn(\n", " env_state,\n", " policy_params,\n", diff --git a/examples/mapelites.ipynb b/examples/mapelites.ipynb index b1fea651..1a7511a8 100644 --- a/examples/mapelites.ipynb +++ b/examples/mapelites.ipynb @@ -14,7 +14,7 @@ "# Optimizing with MAP-Elites in Jax\n", "\n", "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [MAP-Elites](https://arxiv.org/abs/1504.04909).\n", - "It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", "\n", "- how to define the problem\n", "- how to create an emitter\n", @@ -185,7 +185,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Define the fonction to play a step with the policy in the environment\n", + "# Define the function to play a step with the policy in the environment\n", "def play_step_fn(\n", " env_state,\n", " policy_params,\n", diff --git a/examples/mees.ipynb b/examples/mees.ipynb index ad1a4740..1839f2b7 100644 --- a/examples/mees.ipynb +++ b/examples/mees.ipynb @@ -18,7 +18,7 @@ "# Optimizing with MEES in Jax\n", "\n", "This notebook shows how to use QDax to find diverse and performing controllers with MAP-Elites-ES introduced in [Scaling MAP-Elites to Deep Neuroevolution](https://dl.acm.org/doi/pdf/10.1145/3377930.3390217).\n", - "It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", "\n", "- how to define the problem\n", "- how to create the MEES emitter\n", @@ -201,7 +201,7 @@ }, "outputs": [], "source": [ - "# Define the fonction to play a step with the policy in the environment\n", + "# Define the function to play a step with the policy in the environment\n", "def play_step_fn(\n", " env_state,\n", " policy_params,\n", @@ -259,7 +259,7 @@ " behavior_descriptor_extractor=bd_extraction_fn,\n", ")\n", "\n", - "# Prepare the scoring functions for the offspring generated folllowing\n", + "# Prepare the scoring functions for the offspring generated following\n", "# the approximated gradient (each of them is evaluated 30 times)\n", "sampling_fn = functools.partial(\n", " sampling,\n", diff --git a/examples/mome.ipynb b/examples/mome.ipynb index 555381e6..684f058d 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -15,7 +15,7 @@ "source": [ "# Optimizing multiple objectives with MOME in Jax\n", "\n", - "This notebook shows how to use QDax to find diverse and performing parameters on a multi-objectives Rastrigin problem, using [Multi-Objective MAP-Elites](https://arxiv.org/pdf/2202.03057.pdf) (MOME) algorithm. It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "This notebook shows how to use QDax to find diverse and performing parameters on a multi-objectives Rastrigin problem, using [Multi-Objective MAP-Elites](https://arxiv.org/pdf/2202.03057.pdf) (MOME) algorithm. It can be run locally or on Google Colab. We recommend to use a GPU. 
This notebook will show:\n", "\n", "- how to define the problem\n", "- how to create an emitter instance\n", diff --git a/examples/nsga2_spea2.ipynb b/examples/nsga2_spea2.ipynb index be662981..ad0952e3 100644 --- a/examples/nsga2_spea2.ipynb +++ b/examples/nsga2_spea2.ipynb @@ -13,7 +13,7 @@ "source": [ "# Optimizing multiple objectives with NSGA2 & SPEA2 in Jax\n", "\n", - "This notebook shows how to use QDax to find diverse and performing parameters on a multi-objectives Rastrigin problem, using [NSGA2](https://ieeexplore.ieee.org/document/996017) and [SPEA2](https://www.semanticscholar.org/paper/SPEA2%3A-Improving-the-strength-pareto-evolutionary-Zitzler-Laumanns/b13724cb54ae4171916f3f969d304b9e9752a57f) algorithms. It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "This notebook shows how to use QDax to find diverse and performing parameters on a multi-objectives Rastrigin problem, using [NSGA2](https://ieeexplore.ieee.org/document/996017) and [SPEA2](https://www.semanticscholar.org/paper/SPEA2%3A-Improving-the-strength-pareto-evolutionary-Zitzler-Laumanns/b13724cb54ae4171916f3f969d304b9e9752a57f) algorithms. It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", "\n", "- how to define the problem\n", "- how to create an emitter instance\n", diff --git a/examples/omgmega.ipynb b/examples/omgmega.ipynb index 8d417cc0..afb7c0d0 100644 --- a/examples/omgmega.ipynb +++ b/examples/omgmega.ipynb @@ -14,7 +14,7 @@ "# Optimizing with OMG-MEGA in Jax\n", "\n", "This notebook shows how to use QDax to find diverse and performing parameters on the Rastrigin problem with [OMG-MEGA](https://arxiv.org/pdf/2106.03894.pdf).\n", - "It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", "\n", "- how to define the problem\n", "- how to create an omg-mega emitter\n", diff --git a/examples/pga_aurora.ipynb b/examples/pga_aurora.ipynb index 6152ce63..5df67100 100644 --- a/examples/pga_aurora.ipynb +++ b/examples/pga_aurora.ipynb @@ -14,7 +14,7 @@ "# Optimizing with PGA-AURORA in Jax\n", "\n", "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [PGA-AURORA](https://arxiv.org/abs/2210.03516).\n", - "It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "It can be run locally or on Google Colab. We recommend to use a GPU. 
This notebook will show:\n", "\n", "- how to define the problem\n", "- how to create an emitter\n", @@ -216,7 +216,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Define the fonction to play a step with the policy in the environment\n", + "# Define the function to play a step with the policy in the environment\n", "def play_step_fn(\n", " env_state,\n", " policy_params,\n", @@ -382,7 +382,7 @@ "\n", "@jax.jit\n", "def update_scan_fn(carry: Any, unused: Any) -> Any:\n", - " \"\"\"Scan the udpate function.\"\"\"\n", + " \"\"\"Scan the update function.\"\"\"\n", " (\n", " repertoire,\n", " emitter_state,\n", diff --git a/examples/pgame.ipynb b/examples/pgame.ipynb index 9b638b2d..0b9cc0b0 100644 --- a/examples/pgame.ipynb +++ b/examples/pgame.ipynb @@ -14,7 +14,7 @@ "# Optimizing with PGAME in Jax\n", "\n", "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [Policy Gradient Assisted MAP-Elites](https://hal.archives-ouvertes.fr/hal-03135723v2/file/PGA_MAP_Elites_GECCO.pdf).\n", - "It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", "\n", "- how to define the problem\n", "- how to create the PGAME emitter\n", @@ -192,7 +192,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Define the fonction to play a step with the policy in the environment\n", + "# Define the function to play a step with the policy in the environment\n", "def play_step_fn(\n", " env_state,\n", " policy_params,\n", diff --git a/examples/qdpg.ipynb b/examples/qdpg.ipynb index 102d5262..23282162 100644 --- a/examples/qdpg.ipynb +++ b/examples/qdpg.ipynb @@ -14,7 +14,7 @@ "# Optimizing with QDPG in Jax\n", "\n", "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [QDPG - Quality Diversity Policy Gradient in MAP-Elites](https://arxiv.org/abs/2006.08505).\n", - "It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "It can be run locally or on Google Colab. We recommend to use a GPU. This notebook will show:\n", "\n", "- how to define the problem\n", "- how to create the QDPG emitter\n", @@ -205,7 +205,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Define the fonction to play a step with the policy in the environment\n", + "# Define the function to play a step with the policy in the environment\n", "def play_step_fn(\n", " env_state,\n", " policy_params,\n", diff --git a/examples/sac_pbt.ipynb b/examples/sac_pbt.ipynb index 7762083f..d4281203 100644 --- a/examples/sac_pbt.ipynb +++ b/examples/sac_pbt.ipynb @@ -304,7 +304,7 @@ }, "outputs": [], "source": [ - "# get eval policy fonction\n", + "# get eval policy function\n", "eval_policy = jax.pmap(agent.get_eval_fn(eval_env), axis_name=\"p\", devices=devices)" ] }, diff --git a/examples/smerl.ipynb b/examples/smerl.ipynb index fe655fe2..15da3fd2 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -13,7 +13,7 @@ "source": [ "# Training DIAYN SMERL with Jax\n", "\n", - "This notebook shows how to use QDax to train DIAYN SMERL on a Brax environment. It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "This notebook shows how to use QDax to train DIAYN SMERL on a Brax environment. It can be run locally or on Google Colab. We recommend to use a GPU. 
This notebook will show:\n", "- how to define an environment\n", "- how to define a replay buffer\n", "- how to create a diayn smerl instance\n", @@ -108,7 +108,7 @@ "\n", "Most hyperparameters are similar to those introduced in [SAC paper](https://arxiv.org/abs/1801.01290), [DIAYN paper](https://arxiv.org/abs/1802.06070) and [SMERL paper](https://arxiv.org/pdf/2010.14484.pdf).\n", "\n", - "The parameter `descriptor_full_state` is less straightforward, it concerns the information used for diversity seeking and discrimination. In DIAYN, one can use the full state for diversity seeking, but one can also use a prior to focus on an interesting aspect of the state. Actually, priors are often used in experiments, for instance, focusing on the x/y position rather than the full position. When `descriptor_full_state` is set to True, it uses the full state, when it is set to False, it uses the 'state descriptor' retrieved by the environment. Hence, it is required that the environment has one. (All the `_uni`, `_omni` do, same for `anttrap`, `antmaze` and `pointmaze`.) In the future, we will add an option to use a prior function direclty on the full state." + "The parameter `descriptor_full_state` is less straightforward, it concerns the information used for diversity seeking and discrimination. In DIAYN, one can use the full state for diversity seeking, but one can also use a prior to focus on an interesting aspect of the state. Actually, priors are often used in experiments, for instance, focusing on the x/y position rather than the full position. When `descriptor_full_state` is set to True, it uses the full state, when it is set to False, it uses the 'state descriptor' retrieved by the environment. Hence, it is required that the environment has one. (All the `_uni`, `_omni` do, same for `anttrap`, `antmaze` and `pointmaze`.) In the future, we will add an option to use a prior function directly on the full state." ] }, { diff --git a/examples/td3_pbt.ipynb b/examples/td3_pbt.ipynb index ec98b9da..00a98bae 100644 --- a/examples/td3_pbt.ipynb +++ b/examples/td3_pbt.ipynb @@ -264,7 +264,7 @@ }, "outputs": [], "source": [ - "# get eval policy fonction\n", + "# get eval policy function\n", "eval_policy = jax.pmap(agent.get_eval_fn(eval_env), axis_name=\"p\", devices=devices)" ] }, diff --git a/qdax/baselines/dads.py b/qdax/baselines/dads.py index 41f2ff08..88296f62 100644 --- a/qdax/baselines/dads.py +++ b/qdax/baselines/dads.py @@ -68,7 +68,7 @@ class DADS(SAC): of skills, is used to evaluate the skills in the environment and hence to generate transitions. The sampling is hence fixed and perfectly uniform. - We plan to add continous skill as an option in the future. We also plan + We plan to add continuous skill as an option in the future. We also plan to release the current constraint on the number of batched environments by sampling from the skills rather than having this fixed setting. """ @@ -490,7 +490,7 @@ def _update_networks( (training_state, transitions), ) - # udpate alpha + # update alpha ( alpha_params, alpha_optimizer_state, diff --git a/qdax/baselines/diayn.py b/qdax/baselines/diayn.py index c03cfb3f..f0bd31ec 100644 --- a/qdax/baselines/diayn.py +++ b/qdax/baselines/diayn.py @@ -64,7 +64,7 @@ class DIAYN(SAC): Since we are using categorical skills, the current loss function used to train the discriminator is the categorical cross entropy loss. - We plan to add continous skill as an option in the future. We also plan + We plan to add continuous skill as an option in the future. 
We also plan to release the current constraint on the number of batched environments by sampling from the skills rather than having this fixed setting. """ @@ -408,7 +408,7 @@ def _update_networks( training_state.discriminator_params, discriminator_updates ) - # udpate alpha + # update alpha ( alpha_params, alpha_optimizer_state, diff --git a/qdax/baselines/sac.py b/qdax/baselines/sac.py index a5ce15c5..e53533f8 100644 --- a/qdax/baselines/sac.py +++ b/qdax/baselines/sac.py @@ -162,7 +162,7 @@ def select_action( random_key: RNGKey, deterministic: bool = False, ) -> Tuple[Action, RNGKey]: - """Selects an action acording to SAC policy. + """Selects an action according to SAC policy. Args: obs: agent observation(s) diff --git a/qdax/core/cmaes.py b/qdax/core/cmaes.py index 481a49bf..05e6fffc 100644 --- a/qdax/core/cmaes.py +++ b/qdax/core/cmaes.py @@ -261,7 +261,7 @@ def update_eigen( # unpack data cov, num_updates = operand - # enfore symmetry - did not change anything + # enforce symmetry - did not change anything cov = jnp.triu(cov) + jnp.triu(cov, 1).T # get eigen decomposition: eigenvalues, eigenvectors diff --git a/qdax/core/containers/archive.py b/qdax/core/containers/archive.py index 8af808f3..cb3c57a0 100644 --- a/qdax/core/containers/archive.py +++ b/qdax/core/containers/archive.py @@ -15,9 +15,9 @@ class Archive(PyTreeNode): An example of use of the archive is the algorithm QDPG: state descriptors are stored in this archive and a novelty scorer compares - new state desciptors to the state descriptors stored in this archive. + new state descriptors to the state descriptors stored in this archive. - Note: notations suppose that the elements are called state desciptors. + Note: notations suppose that the elements are called state descriptors. If we where to use this structure for another application, it would be better to change the variables name for another one. Does not seem necessary at the moment though. @@ -157,7 +157,7 @@ def insert(self, state_descriptors: jnp.ndarray) -> Archive: """ state_descriptors = state_descriptors.reshape((-1, state_descriptors.shape[-1])) - # get nearest neigbor for each new state descriptor + # get nearest neighbor for each new state descriptor values, _indices = knn(self.data, state_descriptors, 1) # get indices where distance bigger than threshold @@ -187,7 +187,7 @@ def iterate_fn( state_descriptor = condition_data["state_descriptor"] # do the filtering among the added elements - # get nearest neigbor for each new state descriptor + # get nearest neighbor for each new state descriptor values, _indices = knn(new_elements, state_descriptor.reshape(1, -1), 1) # get indices where distance bigger than threshold @@ -255,7 +255,7 @@ def score_euclidean_novelty( def knn( data: jnp.ndarray, new_data: jnp.ndarray, k: jnp.ndarray ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """K nearest neigbors - Brute force implementation. + """K nearest neighbors - Brute force implementation. Using euclidean distance. Code from https://www.kernel-operations.io/keops/_auto_benchmarks/ @@ -264,7 +264,7 @@ def knn( Args: data: given reference data. new_data: data to be compared to the reference data. - k: number of neigbors to consider. + k: number of neighbors to consider. Returns: The distances and indices of the nearest neighbors. 
diff --git a/qdax/core/containers/mapelites_repertoire.py b/qdax/core/containers/mapelites_repertoire.py index aed74c78..a6fae9a6 100644 --- a/qdax/core/containers/mapelites_repertoire.py +++ b/qdax/core/containers/mapelites_repertoire.py @@ -333,7 +333,7 @@ def init( fitnesses: fitness of the initial genotypes of shape (batch_size,) descriptors: descriptors of the initial genotypes of shape (batch_size, num_descriptors) - centroids: tesselation centroids of shape (batch_size, num_descriptors) + centroids: tessellation centroids of shape (batch_size, num_descriptors) extra_scores: unused extra_scores of the initial genotypes Returns: diff --git a/qdax/core/containers/repertoire.py b/qdax/core/containers/repertoire.py index f50d53b7..64082bcc 100644 --- a/qdax/core/containers/repertoire.py +++ b/qdax/core/containers/repertoire.py @@ -47,6 +47,6 @@ def add(self) -> Repertoire: repertoire. Returns: - The udpated repertoire. + The updated repertoire. """ pass diff --git a/qdax/core/containers/spea2_repertoire.py b/qdax/core/containers/spea2_repertoire.py index 54870db4..2cbbd539 100644 --- a/qdax/core/containers/spea2_repertoire.py +++ b/qdax/core/containers/spea2_repertoire.py @@ -60,7 +60,7 @@ def add( """Updates the population with the new solutions. To decide which individuals to keep, we count, for each solution, - the number of solutions by which tey are dominated. We keep only + the number of solutions by which they are dominated. We keep only the solutions that are the less dominated ones. Args: diff --git a/qdax/core/containers/unstructured_repertoire.py b/qdax/core/containers/unstructured_repertoire.py index f4cc0c98..903c1650 100644 --- a/qdax/core/containers/unstructured_repertoire.py +++ b/qdax/core/containers/unstructured_repertoire.py @@ -139,7 +139,7 @@ class UnstructuredRepertoire(flax.struct.PyTreeNode): descriptors: an array that contains the descriptors of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids, num_descriptors). - centroids: an array the contains the centroids of the tesselation. The array + centroids: an array the contains the centroids of the tessellation. The array shape is (num_centroids, num_descriptors). observations: observations that the genotype gathered in the environment. """ diff --git a/qdax/core/distributed_map_elites.py b/qdax/core/distributed_map_elites.py index c8a1ea44..7659e779 100644 --- a/qdax/core/distributed_map_elites.py +++ b/qdax/core/distributed_map_elites.py @@ -158,7 +158,8 @@ def get_distributed_init_fn( devices: hardware devices. Returns: - A callable function that inits the MAP-Elites algorithm in a ditributed way. + A callable function that inits the MAP-Elites algorithm in a distributed + way. """ return jax.pmap( # type: ignore partial(self.init, centroids=centroids), diff --git a/qdax/core/emitters/cma_emitter.py b/qdax/core/emitters/cma_emitter.py index f9d58caa..cee37ca1 100644 --- a/qdax/core/emitters/cma_emitter.py +++ b/qdax/core/emitters/cma_emitter.py @@ -26,7 +26,7 @@ class CMAEmitterState(EmitterState): subject to refactoring discussions in the future. cmaes_state: state of the underlying CMA-ES algorithm previous_fitnesses: store last fitnesses of the repertoire. Used to - compute the improvment. + compute the improvement. emit_count: count the number of emission events. """ @@ -367,7 +367,7 @@ def _ranking_criteria( fitnesses: corresponding fitnesses. descriptors: corresponding fitnesses. extra_scores: corresponding extra scores. 
- improvements: improvments of the emitted genotypes. This corresponds + improvements: improvements of the emitted genotypes. This corresponds to the difference between their fitness and the fitness of the individual occupying the cell of corresponding fitness. diff --git a/qdax/core/emitters/cma_improvement_emitter.py b/qdax/core/emitters/cma_improvement_emitter.py index 28424f3f..f408a3ef 100644 --- a/qdax/core/emitters/cma_improvement_emitter.py +++ b/qdax/core/emitters/cma_improvement_emitter.py @@ -47,7 +47,7 @@ def _ranking_criteria( fitnesses: corresponding fitnesses. descriptors: corresponding fitnesses. extra_scores: corresponding extra scores. - improvements: improvments of the emitted genotypes. This corresponds + improvements: improvements of the emitted genotypes. This corresponds to the difference between their fitness and the fitness of the individual occupying the cell of corresponding fitness. diff --git a/qdax/core/emitters/cma_mega_emitter.py b/qdax/core/emitters/cma_mega_emitter.py index f63654fd..3564fc52 100644 --- a/qdax/core/emitters/cma_mega_emitter.py +++ b/qdax/core/emitters/cma_mega_emitter.py @@ -35,7 +35,7 @@ class CMAMEGAState(EmitterState): subject to refactoring discussions in the future. cmaes_state: state of the underlying CMA-ES algorithm previous_fitnesses: store last fitnesses of the repertoire. Used to - compute the improvment. + compute the improvement. """ theta: Genotype @@ -62,7 +62,7 @@ def __init__( Fontaine et al. Args: - scoring_function: a function to score individuals, outputing fitness, + scoring_function: a function to score individuals, outputting fitness, descriptors and extra scores. With this emitter, the extra score contains gradients and normalized gradients. batch_size: number of solutions sampled at each iteration diff --git a/qdax/core/emitters/cma_opt_emitter.py b/qdax/core/emitters/cma_opt_emitter.py index d9c5bf71..cb230f84 100644 --- a/qdax/core/emitters/cma_opt_emitter.py +++ b/qdax/core/emitters/cma_opt_emitter.py @@ -31,7 +31,7 @@ def _ranking_criteria( fitnesses: corresponding fitnesses. descriptors: corresponding fitnesses. extra_scores: corresponding extra scores. - improvements: improvments of the emitted genotypes. This corresponds + improvements: improvements of the emitted genotypes. This corresponds to the difference between their fitness and the fitness of the individual occupying the cell of corresponding fitness. diff --git a/qdax/core/emitters/cma_rnd_emitter.py b/qdax/core/emitters/cma_rnd_emitter.py index 4afb2f5d..5189e1de 100644 --- a/qdax/core/emitters/cma_rnd_emitter.py +++ b/qdax/core/emitters/cma_rnd_emitter.py @@ -23,7 +23,7 @@ class CMARndEmitterState(CMAEmitterState): subject to refactoring discussions in the future. cmaes_state: state of the underlying CMA-ES algorithm previous_fitnesses: store last fitnesses of the repertoire. Used to - compute the improvment. + compute the improvement. emit_count: count the number of emission events. random_direction: direction of the behavior space we are trying to explore. @@ -142,7 +142,7 @@ def _ranking_criteria( fitnesses: corresponding fitnesses. descriptors: corresponding fitnesses. extra_scores: corresponding extra scores. - improvements: improvments of the emitted genotypes. This corresponds + improvements: improvements of the emitted genotypes. This corresponds to the difference between their fitness and the fitness of the individual occupying the cell of corresponding fitness. 
diff --git a/qdax/core/emitters/dpg_emitter.py b/qdax/core/emitters/dpg_emitter.py index 8b858db4..65b59524 100644 --- a/qdax/core/emitters/dpg_emitter.py +++ b/qdax/core/emitters/dpg_emitter.py @@ -53,7 +53,7 @@ class DiversityPGEmitter(QualityPGEmitter): """ A diversity policy gradient emitter used to implement QDPG algorithm. - Please not that the inheritence between DiversityPGEmitter and QualityPGEmitter + Please not that the inheritance between DiversityPGEmitter and QualityPGEmitter could be increased with changes in the way transitions samples are handled in the QualityPGEmitter. But this would modify the computation/memory strategy of the current implementation. Hence, we won't apply this yet and will discuss this with diff --git a/qdax/core/emitters/emitter.py b/qdax/core/emitters/emitter.py index d32ed981..47f3a81d 100644 --- a/qdax/core/emitters/emitter.py +++ b/qdax/core/emitters/emitter.py @@ -12,7 +12,7 @@ class EmitterState(PyTreeNode): """The state of an emitter. Emitters are used to suggest offspring when evolving a population of genotypes. To emit new genotypes, some - emitters need to have a state, that carries useful informations, like + emitters need to have a state, that carries useful information, like running means, distribution parameters, critics, replay buffers etc... The object emitter state is used to store them and is updated along @@ -83,7 +83,7 @@ def state_update( """This function gives an opportunity to update the emitter state after the genotypes have been scored. - As a matter of fact, many emitter states needs informations from + As a matter of fact, many emitter states needs information from the evaluations of the genotypes in order to be updated, for instance: - CMA emitter: to update the rank of the covariance matrix - PGA emitter: to fill the replay buffer and update the critic/greedy diff --git a/qdax/core/emitters/mees_emitter.py b/qdax/core/emitters/mees_emitter.py index b5bb1ada..f2052428 100644 --- a/qdax/core/emitters/mees_emitter.py +++ b/qdax/core/emitters/mees_emitter.py @@ -442,7 +442,7 @@ def _es_emitter( scores_fn: Callable[[Fitness, Descriptor], jnp.ndarray], ) -> Tuple[Genotype, optax.OptState, RNGKey]: """Main es component, given a parent and a way to infer the score from - the fitnesses and descriptors fo its es-samples, return its + the fitnesses and descriptors of its es-samples, return its approximated-gradient-generated offspring. Args: @@ -670,7 +670,7 @@ def state_update( assert jax.tree_util.tree_leaves(genotypes)[0].shape[0] == 1, ( "ERROR: MAP-Elites-ES generates 1 offspring per generation, " - + "batch_size should be 1, the inputed batch has size:" + + "batch_size should be 1, the inputted batch has size:" + str(jax.tree_util.tree_leaves(genotypes)[0].shape[0]) ) diff --git a/qdax/core/emitters/omg_mega_emitter.py b/qdax/core/emitters/omg_mega_emitter.py index 7336750d..228997b3 100644 --- a/qdax/core/emitters/omg_mega_emitter.py +++ b/qdax/core/emitters/omg_mega_emitter.py @@ -15,7 +15,7 @@ class OMGMEGAEmitterState(EmitterState): Args: gradients_repertoire: MapElites repertoire containing the gradients - of the indivuals. + of the individuals. """ gradients_repertoire: MapElitesRepertoire @@ -39,11 +39,11 @@ class OMGMEGAEmitter(Emitter): sampling. - in the state_update, we have to insert the gradients in the gradients repertoire in the same way the individuals were inserted. 
Once again, this is - slightly unoptimal because the same addition mecanism has to be computed two + slightly unoptimal because the same addition mechanism has to be computed two times. One solution that we are discussing and that is very similar to the first - solution discussed above, would be to decompose the addition mecanism in two - phases: one outputing the indices at which individuals will be added, and then - the actual insertion step. This would enable to re-use the same indices to add + solution discussed above, would be to decompose the addition mechanism in two + phases: one outputting the indices at which individuals will be added, and then + the actual insertion step. This would enable to reusethe same indices to add the gradients instead of having to recompute them. The two design choices seem acceptable and enable to have OMG MEGA compatible diff --git a/qdax/core/emitters/qdpg_emitter.py b/qdax/core/emitters/qdpg_emitter.py index eefd1566..26a0b693 100644 --- a/qdax/core/emitters/qdpg_emitter.py +++ b/qdax/core/emitters/qdpg_emitter.py @@ -1,7 +1,7 @@ """Implementation of an updated version of the algorithm QDPG presented in the paper https://arxiv.org/abs/2006.08505. -QDPG has been udpated to enter in the container+emitter framework of QD. Furthermore, +QDPG has been updated to enter in the container+emitter framework of QD. Furthermore, it has been updated to work better with Jax in term of time cost. Those changes have been made in accordance with the authors of this algorithm. """ diff --git a/qdax/environments/__init__.py b/qdax/environments/__init__.py index 054c75f7..f0b7e9d1 100644 --- a/qdax/environments/__init__.py +++ b/qdax/environments/__init__.py @@ -25,7 +25,7 @@ from qdax.environments.pointmaze import PointMaze from qdax.environments.wrappers import CompletedEvalWrapper -# experimentally determinated offset (except for antmaze) +# experimentally determined offset (except for antmaze) # should be sufficient to have only positive rewards but no guarantee reward_offset = { "pointmaze": 2.3431, diff --git a/qdax/environments/bd_extractors.py b/qdax/environments/bd_extractors.py index af1d51ba..69f6b924 100644 --- a/qdax/environments/bd_extractors.py +++ b/qdax/environments/bd_extractors.py @@ -9,7 +9,7 @@ def get_final_xy_position(data: QDTransition, mask: jnp.ndarray) -> Descriptor: - """Compute final xy positon. + """Compute final xy position. This function suppose that state descriptor is the xy position, as it just select the final one of the state descriptors given. diff --git a/qdax/environments/pointmaze.py b/qdax/environments/pointmaze.py index b5f86ef5..00faa7d7 100644 --- a/qdax/environments/pointmaze.py +++ b/qdax/environments/pointmaze.py @@ -110,7 +110,7 @@ def reset(self, rng: jp.ndarray) -> State: x_init = jp.random_uniform(rng1, (), low=self._x_min, high=self._x_max) / 10 y_init = jp.random_uniform(rng2, (), low=self._y_min, high=-0.7) obs_init = jp.array([x_init, y_init]) - # create fake qp (to re-use brax.State) + # create fake qp (to reusebrax.State) fake_qp = brax.QP.zero() # init reward, metrics and infos reward, done = jp.zeros(2) diff --git a/qdax/tasks/README.md b/qdax/tasks/README.md index 56528323..0232b1a7 100644 --- a/qdax/tasks/README.md +++ b/qdax/tasks/README.md @@ -1,7 +1,7 @@ # QD Tasks The `tasks` directory provides default `scoring_function`'s to import easily to perform experiments without the boilerplate code so that the main script is kept simple and is not bloated. 
It provides a set of fixed tasks that is not meant to be modified. If you are developing and require the flexibility of modifying the task and the details that come along with it, we recommend copying and writing your own custom `scoring_function` in your main script instead of importing from `tasks`. -The `tasks` directory also serves as a way to maintain a QD benchmark task suite that can be easily accesed. We implement several benchmark task across a range of domains. The tasks here are classical tasks from QD literature as well as more recent benchmarks tasks proposed at the [QD Benchmarks Workshop at GECCO 2022](https://quality-diversity.github.io/workshop). +The `tasks` directory also serves as a way to maintain a QD benchmark task suite that can be easily accessed. We implement several benchmark task across a range of domains. The tasks here are classical tasks from QD literature as well as more recent benchmarks tasks proposed at the [QD Benchmarks Workshop at GECCO 2022](https://quality-diversity.github.io/workshop). ## Arm | Task | Parameter Dimensions | Parameter Bounds | Descriptor Dimensions | Descriptor Bounds | Description | @@ -89,8 +89,8 @@ desc_size = 2 | Square | n | $[0,1]^n$ | n | $[0,1]^n$ | | | Checkered | n | $[0,1]^n$ | n | $[0,1]^n$ | | | Empty Circle | n | $[0,1]^n$ | n | $[0,1]^n$ | | -| Non-continous Islands | n | $[0,1]^n$ | n | $[0,1]^n$ | | -| Continous Islands | n | $[0,1]^n$ | n | $[0,1]^n$ | | +| Non-continuous Islands | n | $[0,1]^n$ | n | $[0,1]^n$ | | +| Continuous Islands | n | $[0,1]^n$ | n | $[0,1]^n$ | | ### Example Usage diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index 931ee9d3..1a622b42 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -115,7 +115,7 @@ def scoring_function_brax_envs( This rollout is only deterministic when all the init states are the same. If the init states are fixed but different, as a policy is not necessarily - evaluated with the same environment everytime, this won't be determinist. + evaluated with the same environment every time, this won't be deterministic. When the init states are different, this is not purely stochastic. Args: @@ -361,8 +361,8 @@ def get_aurora_scoring_fn( """Evaluates policies contained in flatten_variables in parallel This rollout is only deterministic when all the init states are the same. - If the init states are fixed but different, as a policy is not necessarly - evaluated with the same environment everytime, this won't be determinist. + If the init states are fixed but different, as a policy is not necessary + evaluated with the same environment every time, this won't be deterministic. When the init states are different, this is not purely stochastic. 
This choice was made for performance reason, as the reset function of brax envs diff --git a/qdax/tasks/hypervolume_functions.py b/qdax/tasks/hypervolume_functions.py index f4936574..89e9c50d 100644 --- a/qdax/tasks/hypervolume_functions.py +++ b/qdax/tasks/hypervolume_functions.py @@ -13,7 +13,7 @@ def square(params: Genotype) -> Tuple[Fitness, Descriptor]: """ - Seach space should be [0,1]^n + Search space should be [0,1]^n BD space should be [0,1]^n """ freq = 5 @@ -24,7 +24,7 @@ def square(params: Genotype) -> Tuple[Fitness, Descriptor]: def checkered(params: Genotype) -> Tuple[Fitness, Descriptor]: """ - Seach space should be [0,1]^n + Search space should be [0,1]^n BD space should be [0,1]^n """ freq = 5 @@ -35,7 +35,7 @@ def checkered(params: Genotype) -> Tuple[Fitness, Descriptor]: def empty_circle(params: Genotype) -> Tuple[Fitness, Descriptor]: """ - Seach space should be [0,1]^n + Search space should be [0,1]^n BD space should be [0,1]^n """ @@ -52,7 +52,7 @@ def _gaussian(x: jnp.ndarray, mu: float, sig: float) -> jnp.ndarray: def non_continous_islands(params: Genotype) -> Tuple[Fitness, Descriptor]: """ - Seach space should be [0,1]^n + Search space should be [0,1]^n BD space should be [0,1]^n """ f = jnp.prod(params) @@ -62,7 +62,7 @@ def non_continous_islands(params: Genotype) -> Tuple[Fitness, Descriptor]: def continous_islands(params: Genotype) -> Tuple[Fitness, Descriptor]: """ - Seach space should be [0,1]^n + Search space should be [0,1]^n BD space should be [0,1]^n """ coeff = 20 diff --git a/qdax/tasks/jumanji_envs.py b/qdax/tasks/jumanji_envs.py index 14455d66..03695d66 100644 --- a/qdax/tasks/jumanji_envs.py +++ b/qdax/tasks/jumanji_envs.py @@ -152,8 +152,8 @@ def jumanji_scoring_function( deterministic or pseudo-deterministic environments. This rollout is only deterministic when all the init states are the same. - If the init states are fixed but different, as a policy is not necessarly - evaluated with the same environment everytime, this won't be determinist. + If the init states are fixed but different, as a policy is not necessary + evaluated with the same environment every time, this won't be deterministic. When the init states are different, this is not purely stochastic. """ diff --git a/qdax/utils/plotting.py b/qdax/utils/plotting.py index 9b107c7e..68fff64a 100644 --- a/qdax/utils/plotting.py +++ b/qdax/utils/plotting.py @@ -102,7 +102,7 @@ def plot_2d_map_elites_repertoire( Args: centroids: the centroids of the repertoire repertoire_fitnesses: the fitness of the repertoire - minval: minimum values for the descritors + minval: minimum values for the descriptors maxval: maximum values for the descriptors repertoire_descriptors: the descriptors. Defaults to None. ax: a matplotlib axe for the figure to plot. Defaults to None. @@ -229,7 +229,7 @@ def plot_map_elites_results( env_steps: the array containing the number of steps done in the environment. metrics: a dictionary containing metrics from the optimizatoin process. repertoire: the final repertoire obtained. - min_bd: the mimimal possible values for the bd. + min_bd: the minimal possible values for the bd. max_bd: the maximal possible values for the bd. 
Returns: diff --git a/qdax/utils/sampling.py b/qdax/utils/sampling.py index bf5c1ae4..618fc72a 100644 --- a/qdax/utils/sampling.py +++ b/qdax/utils/sampling.py @@ -262,7 +262,7 @@ def sampling_reproducibility( descriptor_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray] = std, ) -> Tuple[Fitness, Descriptor, ExtraScores, Fitness, Descriptor, RNGKey]: """Wrap scoring_function to perform sampling and compute the - expectation and reproduciblity. + expectation and reproducibility. This function return the reproducibility of fitnesses and descriptors for each individual over `num_samples` evaluations using the provided extractor diff --git a/tests/baselines_test/mees_test.py b/tests/baselines_test/mees_test.py index 3f3314fd..cd7477ed 100644 --- a/tests/baselines_test/mees_test.py +++ b/tests/baselines_test/mees_test.py @@ -63,7 +63,7 @@ def test_mees() -> None: fake_batch = jnp.zeros(shape=(1, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) - # Define the fonction to play a step with the policy in the environment + # Define the function to play a step with the policy in the environment def play_step_fn( env_state: EnvState, policy_params: Params, diff --git a/tests/baselines_test/pgame_test.py b/tests/baselines_test/pgame_test.py index 9cb1b3fb..24434275 100644 --- a/tests/baselines_test/pgame_test.py +++ b/tests/baselines_test/pgame_test.py @@ -70,7 +70,7 @@ def test_pgame() -> None: fake_batch = jnp.zeros(shape=(env_batch_size, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) - # Define the fonction to play a step with the policy in the environment + # Define the function to play a step with the policy in the environment def play_step_fn( env_state: EnvState, policy_params: Params, diff --git a/tests/baselines_test/qdpg_test.py b/tests/baselines_test/qdpg_test.py index 1889f197..e8d0c257 100644 --- a/tests/baselines_test/qdpg_test.py +++ b/tests/baselines_test/qdpg_test.py @@ -85,7 +85,7 @@ def test_qdpg() -> None: fake_batch = jnp.zeros(shape=(env_batch_size, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) - # Define the fonction to play a step with the policy in the environment + # Define the function to play a step with the policy in the environment def play_step_fn( env_state: EnvState, policy_params: Params, diff --git a/tests/core_test/cmaes_test.py b/tests/core_test/cmaes_test.py index 16321fd4..e73fac56 100644 --- a/tests/core_test/cmaes_test.py +++ b/tests/core_test/cmaes_test.py @@ -41,7 +41,7 @@ def sphere_scoring(x: jnp.ndarray) -> jnp.ndarray: # sample samples, random_key = cmaes.sample(state, random_key) - # udpate + # update state = cmaes.update(state, samples) # check stop condition diff --git a/tests/core_test/map_elites_test.py b/tests/core_test/map_elites_test.py index b532aa65..170843df 100644 --- a/tests/core_test/map_elites_test.py +++ b/tests/core_test/map_elites_test.py @@ -73,7 +73,7 @@ def test_map_elites(env_name: str, batch_size: int) -> None: reset_fn = jax.jit(jax.vmap(env.reset)) init_states = reset_fn(keys) - # Define the fonction to play a step with the policy in the environment + # Define the function to play a step with the policy in the environment def play_step_fn( env_state: EnvState, policy_params: Params, diff --git a/tests/utils_test/sampling_test.py b/tests/utils_test/sampling_test.py index 6ce6cbe9..2288552a 100644 --- a/tests/utils_test/sampling_test.py +++ b/tests/utils_test/sampling_test.py @@ -50,7 +50,7 @@ def 
test_sampling() -> None: fake_batch = jnp.zeros(shape=(1, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) - # Define the fonction to play a step with the policy in the environment + # Define the function to play a step with the policy in the environment def play_step_fn( env_state: EnvState, policy_params: Params, From d70694065a9851d47a4b66da8302dd7120c07ae5 Mon Sep 17 00:00:00 2001 From: Jeroen Van Goey Date: Fri, 26 Jan 2024 11:48:47 +0200 Subject: [PATCH 2/5] fix: add space --- qdax/core/emitters/omg_mega_emitter.py | 2 +- qdax/environments/pointmaze.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/qdax/core/emitters/omg_mega_emitter.py b/qdax/core/emitters/omg_mega_emitter.py index 228997b3..113abfab 100644 --- a/qdax/core/emitters/omg_mega_emitter.py +++ b/qdax/core/emitters/omg_mega_emitter.py @@ -43,7 +43,7 @@ class OMGMEGAEmitter(Emitter): times. One solution that we are discussing and that is very similar to the first solution discussed above, would be to decompose the addition mechanism in two phases: one outputting the indices at which individuals will be added, and then - the actual insertion step. This would enable to reusethe same indices to add + the actual insertion step. This would enable to reuse the same indices to add the gradients instead of having to recompute them. The two design choices seem acceptable and enable to have OMG MEGA compatible diff --git a/qdax/environments/pointmaze.py b/qdax/environments/pointmaze.py index 00faa7d7..a9715329 100644 --- a/qdax/environments/pointmaze.py +++ b/qdax/environments/pointmaze.py @@ -110,7 +110,7 @@ def reset(self, rng: jp.ndarray) -> State: x_init = jp.random_uniform(rng1, (), low=self._x_min, high=self._x_max) / 10 y_init = jp.random_uniform(rng2, (), low=self._y_min, high=-0.7) obs_init = jp.array([x_init, y_init]) - # create fake qp (to reusebrax.State) + # create fake qp (to reuse brax.State) fake_qp = brax.QP.zero() # init reward, metrics and infos reward, done = jp.zeros(2) From 38a5f6fc0b83a9750828ce1a1c9449aac73b04a2 Mon Sep 17 00:00:00 2001 From: Jeroen Van Goey Date: Sun, 28 Jan 2024 11:33:22 +0200 Subject: [PATCH 3/5] fix: fix flake8 errors --- qdax/core/neuroevolution/sac_td3_utils.py | 5 +++-- .../buffers_test/trajectory_buffer_test.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/qdax/core/neuroevolution/sac_td3_utils.py b/qdax/core/neuroevolution/sac_td3_utils.py index 1c54511a..e6617480 100644 --- a/qdax/core/neuroevolution/sac_td3_utils.py +++ b/qdax/core/neuroevolution/sac_td3_utils.py @@ -5,6 +5,7 @@ We are currently thinking about elegant ways to unify both in order to avoid code repetition. """ + # TODO: Uniformize with the functions in mdp_utils from functools import partial from typing import Any, Callable, Tuple @@ -75,8 +76,8 @@ def generate_unroll( ], ], ) -> Tuple[EnvState, TrainingState, Transition]: - """Generates an episode according to the agent's policy, returns the final state of the - episode and the transitions of the episode. + """Generates an episode according to the agent's policy, returns the final state of + the episode and the transitions of the episode. 
""" def _scan_play_step_fn( diff --git a/tests/core_test/neuroevolution_test/buffers_test/trajectory_buffer_test.py b/tests/core_test/neuroevolution_test/buffers_test/trajectory_buffer_test.py index 75f68b40..97a91b0d 100644 --- a/tests/core_test/neuroevolution_test/buffers_test/trajectory_buffer_test.py +++ b/tests/core_test/neuroevolution_test/buffers_test/trajectory_buffer_test.py @@ -202,8 +202,8 @@ def test_trajectory_buffer_insert() -> None: multy_step_episodic_data, equal_nan=True, ), - "Episodic data when transitions are added sequentially is not consistent to when\ - theya are added as batch.", + "Episodic data when transitions are added sequentially is not consistent to " + "when they are added as batch.", ) pytest.assume( From 84a543e4d3ccffc9a784b7841b421f111e26cff9 Mon Sep 17 00:00:00 2001 From: Jeroen Van Goey Date: Sun, 28 Jan 2024 11:34:56 +0200 Subject: [PATCH 4/5] fix: fix line ending --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3281d8fe..c96676b8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -54,4 +54,4 @@ repos: description: Checks for common misspellings in text files. entry: codespell language: python - types: [text] \ No newline at end of file + types: [text] From 302be48da5ba11735610424530793e5b6f47a0bc Mon Sep 17 00:00:00 2001 From: Jeroen Van Goey Date: Sun, 28 Jan 2024 11:44:42 +0200 Subject: [PATCH 5/5] fix: run black on all files --- examples/scripts/me_example.py | 7 ++++++- qdax/baselines/dads.py | 15 ++++++++++++--- qdax/baselines/genetic_algorithm.py | 1 + qdax/baselines/sac.py | 15 ++++++++++++--- qdax/baselines/td3.py | 10 ++++++++-- qdax/baselines/td3_pbt.py | 5 ++++- qdax/core/cmaes.py | 1 + qdax/core/distributed_map_elites.py | 14 ++++++++++++-- qdax/core/emitters/dpg_emitter.py | 18 +++++++++++++++--- qdax/core/emitters/mees_emitter.py | 1 + qdax/core/emitters/pbt_variation_operators.py | 5 ++++- qdax/core/emitters/qdpg_emitter.py | 1 + qdax/core/emitters/qpg_emitter.py | 12 ++++++++++-- qdax/core/map_elites.py | 8 +++++++- qdax/core/mels.py | 1 + .../networks/seq2seq_networks.py | 1 - .../core/neuroevolution/normalization_utils.py | 1 - qdax/tasks/brax_envs.py | 1 + qdax/tasks/jumanji_envs.py | 1 + qdax/utils/sampling.py | 1 + tests/baselines_test/cmame_test.py | 6 +++++- tests/baselines_test/cmamega_test.py | 6 +++++- tests/baselines_test/dads_smerl_test.py | 1 + tests/baselines_test/dads_test.py | 1 + tests/baselines_test/ga_test.py | 6 +++++- tests/baselines_test/mees_test.py | 6 +++++- tests/baselines_test/omgmega_test.py | 6 +++++- tests/baselines_test/pgame_test.py | 6 +++++- tests/baselines_test/qdpg_test.py | 6 +++++- .../emitters_test/multi_emitter_test.py | 6 +++++- tests/core_test/map_elites_test.py | 6 +++++- tests/core_test/mels_test.py | 6 +++++- tests/core_test/mome_test.py | 6 +++++- tests/default_tasks_test/arm_test.py | 6 +++++- tests/default_tasks_test/brax_task_test.py | 6 +++++- .../hypervolume_functions_test.py | 6 +++++- tests/default_tasks_test/qd_suite_test.py | 6 +++++- .../standard_functions_test.py | 6 +++++- 38 files changed, 180 insertions(+), 37 deletions(-) diff --git a/examples/scripts/me_example.py b/examples/scripts/me_example.py index 699c6aba..433bc1d2 100644 --- a/examples/scripts/me_example.py +++ b/examples/scripts/me_example.py @@ -79,7 +79,12 @@ def run_me() -> None: # Run MAP-Elites loop for _ in range(num_iterations): - (repertoire, emitter_state, metrics, random_key,) = 
map_elites.update( + ( + repertoire, + emitter_state, + metrics, + random_key, + ) = map_elites.update( repertoire, emitter_state, random_key, diff --git a/qdax/baselines/dads.py b/qdax/baselines/dads.py index 88296f62..3600842b 100644 --- a/qdax/baselines/dads.py +++ b/qdax/baselines/dads.py @@ -430,12 +430,17 @@ def _update_dynamics( """ training_state, transitions = operand - dynamics_loss, dynamics_gradient = jax.value_and_grad(self._dynamics_loss_fn,)( + dynamics_loss, dynamics_gradient = jax.value_and_grad( + self._dynamics_loss_fn, + )( training_state.dynamics_params, transitions=transitions, ) - (dynamics_updates, dynamics_optimizer_state,) = self._dynamics_optimizer.update( + ( + dynamics_updates, + dynamics_optimizer_state, + ) = self._dynamics_optimizer.update( dynamics_gradient, training_state.dynamics_optimizer_state ) dynamics_params = optax.apply_updates( @@ -483,7 +488,11 @@ def _update_networks( random_key = training_state.random_key # Update skill-dynamics - (dynamics_params, dynamics_loss, dynamics_optimizer_state,) = jax.lax.cond( + ( + dynamics_params, + dynamics_loss, + dynamics_optimizer_state, + ) = jax.lax.cond( training_state.steps % self._config.dynamics_update_freq == 0, self._update_dynamics, self._not_update_dynamics, diff --git a/qdax/baselines/genetic_algorithm.py b/qdax/baselines/genetic_algorithm.py index 0714fb6c..c1a17b8b 100644 --- a/qdax/baselines/genetic_algorithm.py +++ b/qdax/baselines/genetic_algorithm.py @@ -1,4 +1,5 @@ """Core components of a basic genetic algorithm.""" + from functools import partial from typing import Any, Callable, Optional, Tuple diff --git a/qdax/baselines/sac.py b/qdax/baselines/sac.py index e53533f8..0ff30bf9 100644 --- a/qdax/baselines/sac.py +++ b/qdax/baselines/sac.py @@ -449,7 +449,10 @@ def _update_alpha( random_key=subkey, ) alpha_optimizer = optax.adam(learning_rate=alpha_lr) - (alpha_updates, alpha_optimizer_state,) = alpha_optimizer.update( + ( + alpha_updates, + alpha_optimizer_state, + ) = alpha_optimizer.update( alpha_gradient, training_state.alpha_optimizer_state ) alpha_params = optax.apply_updates( @@ -503,7 +506,10 @@ def _update_critic( random_key=subkey, ) critic_optimizer = optax.adam(learning_rate=critic_lr) - (critic_updates, critic_optimizer_state,) = critic_optimizer.update( + ( + critic_updates, + critic_optimizer_state, + ) = critic_optimizer.update( critic_gradient, training_state.critic_optimizer_state ) critic_params = optax.apply_updates( @@ -556,7 +562,10 @@ def _update_actor( random_key=subkey, ) policy_optimizer = optax.adam(learning_rate=policy_lr) - (policy_updates, policy_optimizer_state,) = policy_optimizer.update( + ( + policy_updates, + policy_optimizer_state, + ) = policy_optimizer.update( policy_gradient, training_state.policy_optimizer_state ) policy_params = optax.apply_updates( diff --git a/qdax/baselines/td3.py b/qdax/baselines/td3.py index e09b5254..8b697277 100644 --- a/qdax/baselines/td3.py +++ b/qdax/baselines/td3.py @@ -76,7 +76,10 @@ class TD3: def __init__(self, config: TD3Config, action_size: int): self._config = config - self._policy, self._critic, = make_td3_networks( + ( + self._policy, + self._critic, + ) = make_td3_networks( action_size=action_size, critic_hidden_layer_sizes=self._config.critic_hidden_layer_size, policy_hidden_layer_sizes=self._config.policy_hidden_layer_size, @@ -421,7 +424,10 @@ def update_policy_step() -> Tuple[Params, Params, optax.OptState]: policy_optimizer = optax.adam( learning_rate=self._config.policy_learning_rate ) - (policy_updates, 
policy_optimizer_state,) = policy_optimizer.update( + ( + policy_updates, + policy_optimizer_state, + ) = policy_optimizer.update( policy_gradient, training_state.policy_optimizer_state ) policy_params = optax.apply_updates( diff --git a/qdax/baselines/td3_pbt.py b/qdax/baselines/td3_pbt.py index 60cd8a38..6f7abbe0 100644 --- a/qdax/baselines/td3_pbt.py +++ b/qdax/baselines/td3_pbt.py @@ -291,7 +291,10 @@ def update( def update_policy_step() -> Tuple[Params, Params, optax.OptState]: policy_optimizer = optax.adam(learning_rate=training_state.policy_lr) - (policy_updates, policy_optimizer_state,) = policy_optimizer.update( + ( + policy_updates, + policy_optimizer_state, + ) = policy_optimizer.update( policy_gradient, training_state.policy_optimizer_state ) policy_params = optax.apply_updates( diff --git a/qdax/core/cmaes.py b/qdax/core/cmaes.py index 05e6fffc..e7f29229 100644 --- a/qdax/core/cmaes.py +++ b/qdax/core/cmaes.py @@ -2,6 +2,7 @@ Definition of CMAES class, containing main functions necessary to build a CMA optimization script. Link to the paper: https://arxiv.org/abs/1604.00772 """ + from functools import partial from typing import Callable, Optional, Tuple diff --git a/qdax/core/distributed_map_elites.py b/qdax/core/distributed_map_elites.py index 7659e779..bb214e28 100644 --- a/qdax/core/distributed_map_elites.py +++ b/qdax/core/distributed_map_elites.py @@ -1,4 +1,5 @@ """Core components of the MAP-Elites algorithm.""" + from __future__ import annotations from functools import partial @@ -196,7 +197,12 @@ def _scan_update( repertoire, emitter_state, random_key = carry # apply one step of update - (repertoire, emitter_state, metrics, random_key,) = self.update( + ( + repertoire, + emitter_state, + metrics, + random_key, + ) = self.update( repertoire, emitter_state, random_key, @@ -210,7 +216,11 @@ def update_fn( random_key: RNGKey, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey, Metrics]: """Apply num_iterations of update.""" - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( _scan_update, (repertoire, emitter_state, random_key), (), diff --git a/qdax/core/emitters/dpg_emitter.py b/qdax/core/emitters/dpg_emitter.py index 65b59524..e08694d2 100644 --- a/qdax/core/emitters/dpg_emitter.py +++ b/qdax/core/emitters/dpg_emitter.py @@ -1,6 +1,7 @@ """ Implements the Diversity PG inspired by QDPG algorithm in jax for brax environments, based on: https://arxiv.org/abs/2006.08505 """ + from dataclasses import dataclass from functools import partial from typing import Any, Callable, Optional, Tuple @@ -161,7 +162,10 @@ def scan_train_critics( return new_emitter_state, () # sample transitions - (transitions, random_key,) = emitter_state.replay_buffer.sample( + ( + transitions, + random_key, + ) = emitter_state.replay_buffer.sample( random_key=emitter_state.random_key, sample_size=self._config.num_critic_training_steps * self._config.batch_size, @@ -230,7 +234,11 @@ def _train_critics( ) # Update greedy policy - (policy_optimizer_state, actor_params, target_actor_params,) = jax.lax.cond( + ( + policy_optimizer_state, + actor_params, + target_actor_params, + ) = jax.lax.cond( emitter_state.steps % self._config.policy_delay == 0, lambda x: self._update_actor(*x), lambda _: ( @@ -329,7 +337,11 @@ def scan_train_policy( transitions, ) - (emitter_state, policy_params, policy_optimizer_state,), _ = jax.lax.scan( + ( + emitter_state, + policy_params, + policy_optimizer_state, + ), _ = 
jax.lax.scan( scan_train_policy, (emitter_state, policy_params, policy_optimizer_state), (transitions), diff --git a/qdax/core/emitters/mees_emitter.py b/qdax/core/emitters/mees_emitter.py index f2052428..4a089fe6 100644 --- a/qdax/core/emitters/mees_emitter.py +++ b/qdax/core/emitters/mees_emitter.py @@ -3,6 +3,7 @@ from "Scaling MAP-Elites to Deep Neuroevolution" by Colas et al: https://dl.acm.org/doi/pdf/10.1145/3377930.3390217 """ + from __future__ import annotations from dataclasses import dataclass diff --git a/qdax/core/emitters/pbt_variation_operators.py b/qdax/core/emitters/pbt_variation_operators.py index bd76ecd1..cef42edf 100644 --- a/qdax/core/emitters/pbt_variation_operators.py +++ b/qdax/core/emitters/pbt_variation_operators.py @@ -94,7 +94,10 @@ def td3_pbt_variation_fn( training_state1.critic_params, training_state2.critic_params, ) - (policy_params, critic_params,), random_key = isoline_variation( + ( + policy_params, + critic_params, + ), random_key = isoline_variation( x1=(policy_params1, critic_params1), x2=(policy_params2, critic_params2), random_key=random_key, diff --git a/qdax/core/emitters/qdpg_emitter.py b/qdax/core/emitters/qdpg_emitter.py index 26a0b693..81072b5d 100644 --- a/qdax/core/emitters/qdpg_emitter.py +++ b/qdax/core/emitters/qdpg_emitter.py @@ -5,6 +5,7 @@ it has been updated to work better with Jax in term of time cost. Those changes have been made in accordance with the authors of this algorithm. """ + import functools from dataclasses import dataclass from typing import Callable diff --git a/qdax/core/emitters/qpg_emitter.py b/qdax/core/emitters/qpg_emitter.py index c07e3b18..27635732 100644 --- a/qdax/core/emitters/qpg_emitter.py +++ b/qdax/core/emitters/qpg_emitter.py @@ -366,7 +366,11 @@ def _train_critics( ) # Update greedy actor - (actor_optimizer_state, actor_params, target_actor_params,) = jax.lax.cond( + ( + actor_optimizer_state, + actor_params, + target_actor_params, + ) = jax.lax.cond( emitter_state.steps % self._config.policy_delay == 0, lambda x: self._update_actor(*x), lambda _: ( @@ -514,7 +518,11 @@ def scan_train_policy( new_policy_optimizer_state, ), () - (emitter_state, policy_params, policy_optimizer_state,), _ = jax.lax.scan( + ( + emitter_state, + policy_params, + policy_optimizer_state, + ), _ = jax.lax.scan( scan_train_policy, (emitter_state, policy_params, policy_optimizer_state), (), diff --git a/qdax/core/map_elites.py b/qdax/core/map_elites.py index c71b0013..30bb4e5b 100644 --- a/qdax/core/map_elites.py +++ b/qdax/core/map_elites.py @@ -1,4 +1,5 @@ """Core components of the MAP-Elites algorithm.""" + from __future__ import annotations from functools import partial @@ -173,7 +174,12 @@ def scan_update( The updated repertoire and emitter state, with a new random key and metrics. 
""" repertoire, emitter_state, random_key = carry - (repertoire, emitter_state, metrics, random_key,) = self.update( + ( + repertoire, + emitter_state, + metrics, + random_key, + ) = self.update( repertoire, emitter_state, random_key, diff --git a/qdax/core/mels.py b/qdax/core/mels.py index 6c06b785..c8266931 100644 --- a/qdax/core/mels.py +++ b/qdax/core/mels.py @@ -1,4 +1,5 @@ """Core components of the MAP-Elites Low-Spread algorithm.""" + from __future__ import annotations from functools import partial diff --git a/qdax/core/neuroevolution/networks/seq2seq_networks.py b/qdax/core/neuroevolution/networks/seq2seq_networks.py index ea7618ba..3cb52a3e 100644 --- a/qdax/core/neuroevolution/networks/seq2seq_networks.py +++ b/qdax/core/neuroevolution/networks/seq2seq_networks.py @@ -7,7 +7,6 @@ Licensed under the Apache License, Version 2.0 (the "License") """ - import functools from typing import Any, Tuple diff --git a/qdax/core/neuroevolution/normalization_utils.py b/qdax/core/neuroevolution/normalization_utils.py index 63820921..62c12b8a 100644 --- a/qdax/core/neuroevolution/normalization_utils.py +++ b/qdax/core/neuroevolution/normalization_utils.py @@ -1,6 +1,5 @@ """Utilities functions to perform normalization (generally on observations in RL).""" - from typing import NamedTuple import jax.numpy as jnp diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index 1a622b42..1850fd2b 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -41,6 +41,7 @@ def make_policy_network_play_step_fn_brax( Returns: default_play_step_fn: A function that plays a step of the environment. """ + # Define the function to play a step with the policy in the environment def default_play_step_fn( env_state: EnvState, diff --git a/qdax/tasks/jumanji_envs.py b/qdax/tasks/jumanji_envs.py index 03695d66..a22418e8 100644 --- a/qdax/tasks/jumanji_envs.py +++ b/qdax/tasks/jumanji_envs.py @@ -41,6 +41,7 @@ def make_policy_network_play_step_fn_jumanji( Returns: default_play_step_fn: A function that plays a step of the environment. 
""" + # Define the function to play a step with the policy in the environment def default_play_step_fn( env_state: JumanjiState, diff --git a/qdax/utils/sampling.py b/qdax/utils/sampling.py index 618fc72a..f74235ef 100644 --- a/qdax/utils/sampling.py +++ b/qdax/utils/sampling.py @@ -1,4 +1,5 @@ """Core components of the MAP-Elites-sampling algorithm.""" + from functools import partial from typing import Callable, Tuple diff --git a/tests/baselines_test/cmame_test.py b/tests/baselines_test/cmame_test.py index c86bd622..1f3f43fd 100644 --- a/tests/baselines_test/cmame_test.py +++ b/tests/baselines_test/cmame_test.py @@ -113,7 +113,11 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: initial_population, centroids, random_key ) - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/cmamega_test.py b/tests/baselines_test/cmamega_test.py index fdd9330b..3d144bf3 100644 --- a/tests/baselines_test/cmamega_test.py +++ b/tests/baselines_test/cmamega_test.py @@ -125,7 +125,11 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: initial_population, centroids, random_key ) - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/dads_smerl_test.py b/tests/baselines_test/dads_smerl_test.py index 1e782f2a..2a8d3d1f 100644 --- a/tests/baselines_test/dads_smerl_test.py +++ b/tests/baselines_test/dads_smerl_test.py @@ -1,4 +1,5 @@ """Testing script for the algorithm DADS""" + from functools import partial from typing import Any, Tuple diff --git a/tests/baselines_test/dads_test.py b/tests/baselines_test/dads_test.py index 0b9af46e..77094ffd 100644 --- a/tests/baselines_test/dads_test.py +++ b/tests/baselines_test/dads_test.py @@ -1,5 +1,6 @@ """Training script for the algorithm DADS, should be launched with hydra. e.g. 
python train_dads.py config=dads_ant""" + from functools import partial from typing import Any, Tuple diff --git a/tests/baselines_test/ga_test.py b/tests/baselines_test/ga_test.py index 4e11370b..1dfd7bb0 100644 --- a/tests/baselines_test/ga_test.py +++ b/tests/baselines_test/ga_test.py @@ -119,7 +119,11 @@ def scoring_fn( ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( algo_instance.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/mees_test.py b/tests/baselines_test/mees_test.py index cd7477ed..5bd2dc0d 100644 --- a/tests/baselines_test/mees_test.py +++ b/tests/baselines_test/mees_test.py @@ -185,7 +185,11 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: return (repertoire, emitter_state, random_key), metrics # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( update_scan_fn, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/omgmega_test.py b/tests/baselines_test/omgmega_test.py index 7b0f0639..bda441ba 100644 --- a/tests/baselines_test/omgmega_test.py +++ b/tests/baselines_test/omgmega_test.py @@ -113,7 +113,11 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: initial_population, centroids, random_key ) - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/pgame_test.py b/tests/baselines_test/pgame_test.py index 24434275..a3e32eee 100644 --- a/tests/baselines_test/pgame_test.py +++ b/tests/baselines_test/pgame_test.py @@ -189,7 +189,11 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: return (repertoire, emitter_state, random_key), metrics # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( update_scan_fn, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/qdpg_test.py b/tests/baselines_test/qdpg_test.py index e8d0c257..be58c453 100644 --- a/tests/baselines_test/qdpg_test.py +++ b/tests/baselines_test/qdpg_test.py @@ -239,7 +239,11 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: return (repertoire, emitter_state, random_key), metrics # Run the algorithm - (repertoire, emitter_state, random_key,), _metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), _metrics = jax.lax.scan( update_scan_fn, (repertoire, emitter_state, random_key), (), diff --git a/tests/core_test/emitters_test/multi_emitter_test.py b/tests/core_test/emitters_test/multi_emitter_test.py index 93b3e081..ebf712d5 100644 --- a/tests/core_test/emitters_test/multi_emitter_test.py +++ b/tests/core_test/emitters_test/multi_emitter_test.py @@ -96,7 +96,11 @@ def test_multi_emitter() -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/core_test/map_elites_test.py b/tests/core_test/map_elites_test.py index 170843df..4e1151f3 100644 --- a/tests/core_test/map_elites_test.py +++ b/tests/core_test/map_elites_test.py @@ -143,7 
+143,11 @@ def play_step_fn( ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/core_test/mels_test.py b/tests/core_test/mels_test.py index 21f90517..e970e08d 100644 --- a/tests/core_test/mels_test.py +++ b/tests/core_test/mels_test.py @@ -142,7 +142,11 @@ def metrics_fn(repertoire: MELSRepertoire) -> Dict: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( mels.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/core_test/mome_test.py b/tests/core_test/mome_test.py index c70683ef..20dafcdc 100644 --- a/tests/core_test/mome_test.py +++ b/tests/core_test/mome_test.py @@ -131,7 +131,11 @@ def scoring_fn( ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( mome.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/default_tasks_test/arm_test.py b/tests/default_tasks_test/arm_test.py index e71e761c..98361b23 100644 --- a/tests/default_tasks_test/arm_test.py +++ b/tests/default_tasks_test/arm_test.py @@ -96,7 +96,11 @@ def test_arm(task_name: str, batch_size: int) -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/default_tasks_test/brax_task_test.py b/tests/default_tasks_test/brax_task_test.py index f8c63259..c12518fb 100644 --- a/tests/default_tasks_test/brax_task_test.py +++ b/tests/default_tasks_test/brax_task_test.py @@ -84,7 +84,11 @@ def test_map_elites(env_name: str, batch_size: int, is_task_reset_based: bool) - ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/default_tasks_test/hypervolume_functions_test.py b/tests/default_tasks_test/hypervolume_functions_test.py index a390f709..3d619353 100644 --- a/tests/default_tasks_test/hypervolume_functions_test.py +++ b/tests/default_tasks_test/hypervolume_functions_test.py @@ -102,7 +102,11 @@ def test_standard_functions(task_name: str, batch_size: int) -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/default_tasks_test/qd_suite_test.py b/tests/default_tasks_test/qd_suite_test.py index a0542e9b..46f6ce9b 100644 --- a/tests/default_tasks_test/qd_suite_test.py +++ b/tests/default_tasks_test/qd_suite_test.py @@ -117,7 +117,11 @@ def test_qd_suite(task_name: str, batch_size: int) -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/default_tasks_test/standard_functions_test.py b/tests/default_tasks_test/standard_functions_test.py index 
7b310389..87913364 100644 --- a/tests/default_tasks_test/standard_functions_test.py +++ b/tests/default_tasks_test/standard_functions_test.py @@ -92,7 +92,11 @@ def test_standard_functions(task_name: str, batch_size: int) -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (),