From f2c6311401031439da7634b9b83f729498da04f6 Mon Sep 17 00:00:00 2001
From: Luca Grillotti
Date: Fri, 30 Aug 2024 17:32:32 +0100
Subject: [PATCH] feat: Upgrade Library Versions and Python Version (#187)

- move to Python 3.10
- upgrade all library versions in requirements.txt and setup.py
- remove DM-Haiku because it is now deprecated
- all networks are now based on a single MLP class (see the sketch at the end of this excerpt)
- jax.tree_map has been replaced with jax.tree_util.tree_map everywhere to avoid deprecation warnings (pattern illustrated at the end of this excerpt)
- rename qdax/types.py to qdax/custom_types.py
- add extras_require for jax[cuda12] (see the setup.py sketch at the end of this excerpt)
- fix all notebooks and update typing-extensions (for running notebooks)
- fix dependabot security issues
- add instructions for pip install qdax[cuda12] in README
- fix observation space in jumanji test script
---
 .pre-commit-config.yaml | 16 +- .readthedocs.yaml | 2 +- README.md | 6 + dev.Dockerfile | 6 +- docs/installation.md | 2 +- environment.yaml | 4 +- examples/aurora.ipynb | 7 +- examples/cmaes.ipynb | 36 +-- examples/cmame.ipynb | 8 +- examples/cmamega.ipynb | 2 +- examples/dads.ipynb | 8 +- examples/diayn.ipynb | 8 +- examples/distributed_mapelites.ipynb | 16 +- examples/jumanji_snake.ipynb | 74 ++--- examples/me_sac_pbt.ipynb | 11 +- examples/me_td3_pbt.ipynb | 15 +- examples/mome.ipynb | 48 ++-- examples/pgame.ipynb | 2 +- examples/sac_pbt.ipynb | 71 +++-- examples/scripts/me_example.py | 7 +- examples/smerl.ipynb | 6 - examples/td3_pbt.ipynb | 47 ++- qdax/baselines/dads.py | 17 +- qdax/baselines/dads_smerl.py | 2 +- qdax/baselines/diayn.py | 2 +- qdax/baselines/diayn_smerl.py | 2 +- qdax/baselines/genetic_algorithm.py | 3 +- qdax/baselines/nsga2.py | 2 +- qdax/baselines/pbt.py | 2 +- qdax/baselines/sac.py | 17 +- qdax/baselines/sac_pbt.py | 2 +- qdax/baselines/spea2.py | 2 +- qdax/baselines/td3.py | 12 +- qdax/baselines/td3_pbt.py | 7 +- qdax/core/aurora.py | 4 +- qdax/core/cmaes.py | 3 +- qdax/core/containers/archive.py | 10 +- qdax/core/containers/ga_repertoire.py | 2 +- qdax/core/containers/mapelites_repertoire.py | 13 +- qdax/core/containers/mels_repertoire.py | 11 +- qdax/core/containers/mome_repertoire.py | 2 +- qdax/core/containers/nsga2_repertoire.py | 8 +- qdax/core/containers/repertoire.py | 7 +- qdax/core/containers/spea2_repertoire.py | 2 +- .../containers/uniform_replacement_archive.py | 4 +- .../containers/unstructured_repertoire.py | 23 +- qdax/core/distributed_map_elites.py | 18 +- qdax/core/emitters/cma_emitter.py | 9 +- qdax/core/emitters/cma_improvement_emitter.py | 6 +- qdax/core/emitters/cma_mega_emitter.py | 10 +- qdax/core/emitters/cma_opt_emitter.py | 2 +- qdax/core/emitters/cma_pool_emitter.py | 2 +- qdax/core/emitters/cma_rnd_emitter.py | 4 +- qdax/core/emitters/dcg_me_emitter.py | 2 +- qdax/core/emitters/dpg_emitter.py | 22 +- qdax/core/emitters/emitter.py | 2 +- qdax/core/emitters/mees_emitter.py | 25 +- qdax/core/emitters/multi_emitter.py | 2 +- qdax/core/emitters/mutation_operators.py | 2 +- qdax/core/emitters/omg_mega_emitter.py | 9 +- qdax/core/emitters/pbt_me_emitter.py | 2 +- qdax/core/emitters/pbt_variation_operators.py | 7 +- qdax/core/emitters/pga_me_emitter.py | 2 +- qdax/core/emitters/qdcg_emitter.py | 18 +- qdax/core/emitters/qdpg_emitter.py | 3 +- qdax/core/emitters/qpg_emitter.py | 18 +- qdax/core/emitters/standard_emitters.py | 2 +- qdax/core/map_elites.py | 10 +- qdax/core/mels.py | 3 +- qdax/core/mome.py | 2 +- qdax/core/neuroevolution/buffers/buffer.py | 2 +- .../buffers/trajectory_buffer.py | 2 +- qdax/core/neuroevolution/losses/dads_loss.py | 9 +-
qdax/core/neuroevolution/losses/diayn_loss.py | 2 +- qdax/core/neuroevolution/losses/sac_loss.py | 2 +- qdax/core/neuroevolution/losses/td3_loss.py | 2 +- qdax/core/neuroevolution/mdp_utils.py | 4 +- .../neuroevolution/networks/dads_networks.py | 271 +++++++----------- .../neuroevolution/networks/diayn_networks.py | 117 ++++---- .../neuroevolution/networks/sac_networks.py | 89 +++--- .../networks/seq2seq_networks.py | 1 - .../neuroevolution/normalization_utils.py | 3 +- qdax/core/neuroevolution/sac_td3_utils.py | 6 +- qdax/{types.py => custom_types.py} | 0 qdax/environments/bd_extractors.py | 2 +- qdax/environments/exploration_wrappers.py | 8 +- qdax/environments/locomotion_wrappers.py | 4 +- qdax/environments/pointmaze.py | 20 +- qdax/environments/wrappers.py | 12 +- qdax/tasks/arm.py | 2 +- qdax/tasks/brax_envs.py | 3 +- qdax/tasks/hypervolume_functions.py | 2 +- qdax/tasks/jumanji_envs.py | 5 +- qdax/tasks/qd_suite/archimedean_spiral.py | 2 +- qdax/tasks/qd_suite/deceptive_evolvability.py | 2 +- qdax/tasks/qd_suite/qd_suite_task.py | 2 +- qdax/tasks/qd_suite/ssf.py | 2 +- qdax/tasks/standard_functions.py | 2 +- qdax/utils/metrics.py | 2 +- qdax/utils/pareto_front.py | 2 +- qdax/utils/plotting.py | 4 +- qdax/utils/sampling.py | 3 +- qdax/utils/train_seq2seq.py | 4 +- requirements.txt | 27 +- setup.py | 30 +- tests/baselines_test/cmame_test.py | 12 +- tests/baselines_test/cmamega_test.py | 8 +- tests/baselines_test/dads_smerl_test.py | 1 + tests/baselines_test/dads_test.py | 1 + tests/baselines_test/ga_test.py | 12 +- tests/baselines_test/me_pbt_sac_test.py | 8 +- tests/baselines_test/me_pbt_td3_test.py | 8 +- tests/baselines_test/mees_test.py | 8 +- tests/baselines_test/omgmega_test.py | 8 +- tests/baselines_test/pbt_sac_test.py | 2 +- tests/baselines_test/pbt_td3_test.py | 2 +- tests/baselines_test/pgame_test.py | 8 +- tests/baselines_test/qdpg_test.py | 8 +- tests/baselines_test/sac_test.py | 2 +- tests/core_test/aurora_test.py | 2 +- .../mapelites_repertoire_test.py | 2 +- .../containers_test/mels_repertoire_test.py | 2 +- .../emitters_test/multi_emitter_test.py | 6 +- tests/core_test/map_elites_test.py | 8 +- tests/core_test/mels_test.py | 8 +- tests/core_test/mome_test.py | 12 +- .../buffers_test/buffer_test.py | 12 +- .../buffers_test/trajectory_buffer_test.py | 4 +- tests/default_tasks_test/arm_test.py | 6 +- tests/default_tasks_test/brax_task_test.py | 6 +- .../hypervolume_functions_test.py | 6 +- tests/default_tasks_test/jumanji_envs_test.py | 17 +- tests/default_tasks_test/qd_suite_test.py | 6 +- .../standard_functions_test.py | 6 +- tests/environments_test/pointmaze_test.py | 2 +- tests/utils_test/sampling_test.py | 2 +- tool.Dockerfile | 2 +- 137 files changed, 857 insertions(+), 733 deletions(-) rename qdax/{types.py => custom_types.py} (100%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a9329a64..1414d749 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,17 +1,17 @@ repos: - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort args: ["--profile", "black"] - repo: https://github.com/ambv/black - rev: 22.3.0 + rev: 24.8.0 hooks: - id: black - language_version: python3.9 - args: ["--target-version", "py39"] + language_version: python3.10 + args: ["--target-version", "py310"] - repo: https://github.com/PyCQA/flake8 - rev: 3.8.4 + rev: 7.1.1 hooks: - id: flake8 args: ['--max-line-length=88', '--extend-ignore=E203'] @@ -21,12 +21,12 @@ repos: - flake8-comprehensions - flake8-bugbear - repo: 
https://github.com/kynan/nbstripout - rev: 0.3.9 + rev: 0.7.1 hooks: - id: nbstripout args: ["examples/"] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.6.0 hooks: - id: debug-statements - id: requirements-txt-fixer @@ -42,6 +42,6 @@ repos: - id: trailing-whitespace # This hook trims trailing whitespace - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.942 + rev: v1.11.2 hooks: - id: mypy diff --git a/.readthedocs.yaml b/.readthedocs.yaml index d9f0965b..e22967ae 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -8,7 +8,7 @@ version: 2 build: os: ubuntu-20.04 tools: - python: "3.9" + python: "3.10" apt_packages: - swig diff --git a/README.md b/README.md index 551680eb..dab09614 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,12 @@ QDax is available on PyPI and can be installed with: ```bash pip install qdax ``` + +To install QDax with CUDA 12 support, use: +```bash +pip install qdax[cuda12] +``` + Alternatively, the latest commit of QDax can be installed directly from source with: ```bash pip install git+https://github.com/adaptive-intelligent-robotics/QDax.git@main diff --git a/dev.Dockerfile b/dev.Dockerfile index 458599db..305e29e0 100644 --- a/dev.Dockerfile +++ b/dev.Dockerfile @@ -16,7 +16,7 @@ RUN micromamba create -y --file /tmp/environment.yaml \ FROM python as test-image -ENV PATH=/opt/conda/envs/qdaxpy39/bin/:$PATH APP_FOLDER=/app +ENV PATH=/opt/conda/envs/qdaxpy310/bin/:$PATH APP_FOLDER=/app ENV PYTHONPATH=$APP_FOLDER:$PYTHONPATH COPY --from=conda /opt/conda/envs/. /opt/conda/envs/ @@ -26,7 +26,7 @@ RUN pip install -r requirements-dev.txt FROM nvidia/cuda:11.5.2-cudnn8-devel-ubuntu20.04 as cuda-image -ENV PATH=/opt/conda/envs/qdaxpy39/bin/:$PATH APP_FOLDER=/app +ENV PATH=/opt/conda/envs/qdaxpy310/bin/:$PATH APP_FOLDER=/app ENV PYTHONPATH=$APP_FOLDER:$PYTHONPATH @@ -70,7 +70,7 @@ RUN apt-get update && \ libosmesa6-dev \ patchelf \ python3-opengl \ - python3-dev=3.9* \ + python3-dev=3.10* \ python3-pip \ screen \ sudo \ diff --git a/docs/installation.md b/docs/installation.md index 90c62659..585af828 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -86,7 +86,7 @@ git clone git@github.com:adaptive-intelligent-robotics/QDax.git 2. Activate the environment and manually install the package qdax ```zsh - conda activate qdaxpy39 + conda activate qdaxpy310 pip install -e . 
``` diff --git a/environment.yaml b/environment.yaml index 0ddf80d5..d93726af 100644 --- a/environment.yaml +++ b/environment.yaml @@ -1,9 +1,9 @@ -name: qdaxpy39 +name: qdaxpy310 channels: - defaults - conda-forge dependencies: -- python=3.9 +- python=3.10 - pip>=20.3.3 - conda>=4.9.2 - pip: diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb index e4b86238..645ed911 100644 --- a/examples/aurora.ipynb +++ b/examples/aurora.ipynb @@ -512,11 +512,8 @@ } ], "metadata": { - "interpreter": { - "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" - }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "venv", "language": "python", "name": "python3" }, @@ -530,7 +527,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/cmaes.ipynb b/examples/cmaes.ipynb index c8e2a9fe..d7b30b1d 100644 --- a/examples/cmaes.ipynb +++ b/examples/cmaes.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "5c4ab97a", + "id": "0", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/cmaes.ipynb)" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "222bbe00", + "id": "1", "metadata": {}, "source": [ "# Optimizing with CMA-ES in Jax\n", @@ -26,7 +26,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d731f067", + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -71,7 +71,7 @@ }, { "cell_type": "markdown", - "id": "7b6e910b", + "id": "3", "metadata": {}, "source": [ "## Set the hyperparameters" @@ -80,7 +80,7 @@ { "cell_type": "code", "execution_count": null, - "id": "404fb0dc", + "id": "4", "metadata": {}, "outputs": [], "source": [ @@ -98,7 +98,7 @@ }, { "cell_type": "markdown", - "id": "ccc7cbeb", + "id": "5", "metadata": { "pycharm": { "name": "#%% md\n" @@ -111,7 +111,7 @@ { "cell_type": "code", "execution_count": null, - "id": "436dccbb", + "id": "6", "metadata": {}, "outputs": [], "source": [ @@ -133,7 +133,7 @@ }, { "cell_type": "markdown", - "id": "62bdd2a4", + "id": "7", "metadata": { "pycharm": { "name": "#%% md\n" @@ -146,7 +146,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4cf03f55", + "id": "8", "metadata": { "pycharm": { "name": "#%%\n" @@ -167,7 +167,7 @@ }, { "cell_type": "markdown", - "id": "f1f69f50", + "id": "9", "metadata": { "pycharm": { "name": "#%% md\n" @@ -180,7 +180,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1a95b74d", + "id": "10", "metadata": { "pycharm": { "name": "#%%\n" @@ -194,7 +194,7 @@ }, { "cell_type": "markdown", - "id": "ac2d5c0d", + "id": "11", "metadata": { "pycharm": { "name": "#%% md\n" @@ -207,7 +207,7 @@ { "cell_type": "code", "execution_count": null, - "id": "363198ca", + "id": "12", "metadata": { "pycharm": { "name": "#%%\n" @@ -245,7 +245,7 @@ }, { "cell_type": "markdown", - "id": "0e5820b8", + "id": "13", "metadata": {}, "source": [ "## Check final fitnesses and distribution mean" @@ -254,7 +254,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1e4a2c7b", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -272,7 +272,7 @@ }, { "cell_type": "markdown", - "id": "f3bd2b0f", + "id": "15", "metadata": { "pycharm": { "name": "#%% md\n" @@ -285,7 +285,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ad85551c", + "id": "16", "metadata": { "pycharm": { "name": 
"#%%\n" @@ -333,7 +333,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/cmame.ipynb b/examples/cmame.ipynb index 3c355eea..ff5fa5c2 100644 --- a/examples/cmame.ipynb +++ b/examples/cmame.ipynb @@ -141,9 +141,9 @@ "def clip(x: jnp.ndarray):\n", " in_bound = (x <= maxval) * (x >= minval)\n", " return jnp.where(\n", - " condition=in_bound,\n", - " x=x,\n", - " y=(maxval / x)\n", + " in_bound,\n", + " x,\n", + " (maxval / x)\n", " )\n", "\n", "def _behavior_descriptor_1(x: jnp.ndarray):\n", @@ -387,7 +387,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/cmamega.ipynb b/examples/cmamega.ipynb index e5749993..739ac3d5 100644 --- a/examples/cmamega.ipynb +++ b/examples/cmamega.ipynb @@ -315,7 +315,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/dads.ipynb b/examples/dads.ipynb index b3cc43b5..47abd1ec 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -67,12 +67,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -554,7 +548,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.2" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/diayn.ipynb b/examples/diayn.ipynb index 0562e7c2..8e085fce 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -67,12 +67,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -544,7 +538,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.2" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb index 574c56a2..ea2f9b9b 100644 --- a/examples/distributed_mapelites.ipynb +++ b/examples/distributed_mapelites.ipynb @@ -40,7 +40,12 @@ "from IPython.display import clear_output\n", "import functools\n", "\n", - "from tqdm import tqdm\n", + "try:\n", + " from tqdm import tqdm\n", + "except:\n", + " !pip install tqdm | tail -n 1\n", + " from tqdm import tqdm\n", + "\n", "import time\n", "\n", "import jax\n", @@ -128,8 +133,7 @@ "outputs": [], "source": [ "# Get devices (change gpu by tpu if needed)\n", - "# devices = jax.devices('gpu')\n", - "devices = jax.devices('tpu')\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f'Detected the following {num_devices} device(s): {devices}')" ] @@ -351,7 +355,7 @@ "random_key = jnp.stack(random_key)\n", "\n", "# add a dimension for devices\n", - "init_variables = jax.tree_map(\n", + "init_variables = jax.tree_util.tree_map(\n", " lambda x: jnp.reshape(x, (num_devices, batch_size_per_device,) + x.shape[1:]),\n", " 
init_variables\n", ")\n", @@ -397,7 +401,7 @@ " repertoire, emitter_state, random_key, metrics = update_fn(repertoire, emitter_state, random_key)\n", "\n", " # get metrics\n", - " metrics = jax.tree_map(lambda x: x[0], metrics)\n", + " metrics = jax.tree_util.tree_map(lambda x: x[0], metrics)\n", " timelapse = time.time() - start_time\n", "\n", " # log metrics\n", @@ -454,7 +458,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/jumanji_snake.ipynb b/examples/jumanji_snake.ipynb index a6a140fd..bfba1e5a 100644 --- a/examples/jumanji_snake.ipynb +++ b/examples/jumanji_snake.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "233e0f03", + "id": "0", "metadata": {}, "source": [ "# Training a population on Jumanji-Snake with QDax\n", @@ -15,7 +15,7 @@ { "cell_type": "code", "execution_count": null, - "id": "47b46c2f", + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -78,7 +78,7 @@ }, { "cell_type": "markdown", - "id": "03c2f1f7", + "id": "2", "metadata": {}, "source": [ "## Define hyperparameters" @@ -87,7 +87,7 @@ { "cell_type": "code", "execution_count": null, - "id": "52dd1e3b", + "id": "3", "metadata": {}, "outputs": [], "source": [ @@ -97,7 +97,7 @@ "population_size = 100\n", "batch_size = population_size\n", "\n", - "num_iterations = 5000\n", + "num_iterations = 1000\n", "\n", "iso_sigma = 0.005\n", "line_sigma = 0.05" @@ -105,7 +105,7 @@ }, { "cell_type": "markdown", - "id": "8b8c890a", + "id": "4", "metadata": {}, "source": [ "## Instantiate the snake environment" @@ -114,7 +114,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a842cccc", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -132,7 +132,7 @@ }, { "cell_type": "markdown", - "id": "776862f1", + "id": "6", "metadata": {}, "source": [ "## Define the type of policy that will be used to solve the problem" @@ -141,7 +141,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2a1ce7d0", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -161,7 +161,7 @@ }, { "cell_type": "markdown", - "id": "49586b07", + "id": "8", "metadata": {}, "source": [ "## Utils to interact with the environment\n", @@ -172,12 +172,12 @@ { "cell_type": "code", "execution_count": null, - "id": "d1ff7827", + "id": "9", "metadata": {}, "outputs": [], "source": [ "def observation_processing(observation):\n", - " network_input = jnp.ravel(observation)\n", + " network_input = jnp.concatenate([jnp.ravel(observation.grid), jnp.array([observation.step_count]), observation.action_mask.ravel()])\n", " return network_input\n", "\n", "\n", @@ -207,7 +207,7 @@ " obs=timestep.observation,\n", " next_obs=next_timestep.observation,\n", " rewards=next_timestep.reward,\n", - " dones=jnp.where(next_timestep.last(), x=jnp.array(1), y=jnp.array(0)),\n", + " dones=jnp.where(next_timestep.last(), jnp.array(1), jnp.array(0)),\n", " actions=action,\n", " truncations=jnp.array(0),\n", " state_desc=state_desc,\n", @@ -219,7 +219,7 @@ }, { "cell_type": "markdown", - "id": "0078bc01", + "id": "10", "metadata": {}, "source": [ "## Init a population of policies\n", @@ -230,7 +230,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6cbd2065", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -240,7 +240,7 @@ "\n", "# compute observation size from observation spec\n", "obs_spec = env.observation_spec()\n", - "observation_size = np.prod(np.array(obs_spec.grid.shape + obs_spec.step_count.shape 
+ obs_spec.action_mask.shape))\n", + "observation_size = int(np.prod(obs_spec.grid.shape) + np.prod(obs_spec.step_count.shape) + np.prod(obs_spec.action_mask.shape))\n", "\n", "fake_batch = jnp.zeros(shape=(batch_size, observation_size))\n", "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", @@ -255,7 +255,7 @@ }, { "cell_type": "markdown", - "id": "fe6bf07f", + "id": "12", "metadata": {}, "source": [ "## Define a method to extract behavior descriptor when relevant" @@ -264,7 +264,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a264b672", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -311,7 +311,7 @@ }, { "cell_type": "markdown", - "id": "1cdc5f87", + "id": "14", "metadata": {}, "source": [ "## Define the scoring function" @@ -320,7 +320,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7b77d826", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -333,7 +333,7 @@ }, { "cell_type": "markdown", - "id": "6555491a", + "id": "16", "metadata": {}, "source": [ "## Define the emitter used" @@ -342,7 +342,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30061ff4", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -360,7 +360,7 @@ }, { "cell_type": "markdown", - "id": "da7e9b74", + "id": "18", "metadata": {}, "source": [ "## Define the algorithm used and apply the initial step\n", @@ -371,7 +371,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f7b5c2d6", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -415,7 +415,7 @@ }, { "cell_type": "markdown", - "id": "9b1bfee5", + "id": "20", "metadata": {}, "source": [ "## Run the optimization loop" @@ -424,7 +424,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d1af3a35", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -442,7 +442,7 @@ { "cell_type": "code", "execution_count": null, - "id": "114ea4a8", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -452,7 +452,7 @@ { "cell_type": "code", "execution_count": null, - "id": "92a35bf0", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -462,7 +462,7 @@ { "cell_type": "code", "execution_count": null, - "id": "79ada2d5", + "id": "24", "metadata": {}, "outputs": [], "source": [ @@ -472,7 +472,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fe5da301", + "id": "25", "metadata": {}, "outputs": [], "source": [ @@ -489,7 +489,7 @@ }, { "cell_type": "markdown", - "id": "93d8154e", + "id": "26", "metadata": {}, "source": [ "## Play snake with the best policy\n", @@ -500,7 +500,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3ff882f4", + "id": "27", "metadata": {}, "outputs": [], "source": [ @@ -511,7 +511,7 @@ { "cell_type": "code", "execution_count": null, - "id": "762c167e", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -524,7 +524,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07523e33", + "id": "29", "metadata": {}, "outputs": [], "source": [ @@ -537,7 +537,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c75ce088", + "id": "30", "metadata": {}, "outputs": [], "source": [ @@ -550,7 +550,7 @@ { "cell_type": "code", "execution_count": null, - "id": "50ef95f6", + "id": "31", "metadata": {}, "outputs": [], "source": [ @@ -563,7 +563,7 @@ { "cell_type": "code", "execution_count": null, - "id": "40a03409", + "id": "32", "metadata": {}, "outputs": [], "source": [ diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index b2de823e..fc6fbe8b 100644 --- a/examples/me_sac_pbt.ipynb +++ 
b/examples/me_sac_pbt.ipynb @@ -38,12 +38,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -80,7 +74,8 @@ "metadata": {}, "outputs": [], "source": [ - "devices = jax.devices(\"tpu\")\n", + "# Get devices (change gpu by tpu if needed)\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f\"Detected the following {num_devices} device(s): {devices}\")" ] @@ -261,7 +256,7 @@ " lambda x: jnp.repeat(x, population_size, axis=0), first_states\n", " )\n", " population_returns, population_bds, _, _ = eval_policy(genotypes, first_states)\n", - " return population_returns, population_bds, None, random_key" + " return population_returns, population_bds, {}, random_key" ] }, { diff --git a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index 238f703c..f28a1db4 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -39,12 +39,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -82,7 +76,8 @@ "metadata": {}, "outputs": [], "source": [ - "devices = jax.devices(\"tpu\")\n", + "# Get devices (change gpu by tpu if needed)\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f\"Detected the following {num_devices} device(s): {devices}\")" ] @@ -264,7 +259,7 @@ " lambda x: jnp.repeat(x, population_size, axis=0), first_states\n", " )\n", " population_returns, population_bds, _, _ = eval_policy(genotypes, first_states)\n", - " return population_returns, population_bds, None, random_key" + " return population_returns, population_bds, {}, random_key" ] }, { @@ -443,7 +438,7 @@ "num_cols = 5\n", "\n", "fig, axes = plt.subplots(\n", - " nrows=math.ceil(num_repertoires / num_cols), ncols=num_cols, figsize=(30, 30)\n", + " nrows=math.ceil(num_repertoires / num_cols), ncols=num_cols, figsize=(30, 30), squeeze=False,\n", ")\n", "for i, repertoire in enumerate(repertoires):\n", "\n", @@ -492,7 +487,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/mome.ipynb b/examples/mome.ipynb index bf0a5225..7e28b608 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "59f748d3", + "id": "0", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mome.ipynb)" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "a5e13ff6", + "id": "1", "metadata": {}, "source": [ "# Optimizing multiple objectives with MOME in Jax\n", @@ -28,7 +28,7 @@ { "cell_type": "code", "execution_count": null, - "id": "af063418", + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -86,7 +86,7 @@ }, { "cell_type": "markdown", - "id": "22495c16", + "id": "3", "metadata": {}, "source": [ "## Set the hyperparameters" @@ -95,7 +95,7 @@ { "cell_type": "code", "execution_count": null, 
- "id": "b96b5d07", + "id": "4", "metadata": {}, "outputs": [], "source": [ @@ -119,7 +119,7 @@ }, { "cell_type": "markdown", - "id": "c2850d54", + "id": "5", "metadata": {}, "source": [ "## Define the scoring function: rastrigin multi-objective\n", @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b5effe11", + "id": "6", "metadata": {}, "outputs": [], "source": [ @@ -165,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "231d273d", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -178,7 +178,7 @@ }, { "cell_type": "markdown", - "id": "29250e72", + "id": "8", "metadata": {}, "source": [ "## Define the metrics function that will be used" @@ -187,7 +187,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ab5d6334", + "id": "9", "metadata": {}, "outputs": [], "source": [ @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "a4828ca8", + "id": "10", "metadata": {}, "source": [ "## Define the initial population and the emitter" @@ -211,7 +211,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ebf3bd27", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -248,7 +248,7 @@ }, { "cell_type": "markdown", - "id": "c904664b", + "id": "12", "metadata": {}, "source": [ "## Compute the centroids" @@ -257,7 +257,7 @@ { "cell_type": "code", "execution_count": null, - "id": "76547c4c", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -273,7 +273,7 @@ }, { "cell_type": "markdown", - "id": "15936d15", + "id": "14", "metadata": {}, "source": [ "## Define a MOME instance" @@ -282,7 +282,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07a0d1d9", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -295,7 +295,7 @@ }, { "cell_type": "markdown", - "id": "f7ec5a77", + "id": "16", "metadata": {}, "source": [ "## Init the algorithm" @@ -304,7 +304,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c05cbf1e", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -318,7 +318,7 @@ }, { "cell_type": "markdown", - "id": "6de4cedf", + "id": "18", "metadata": {}, "source": [ "## Run MOME iterations" @@ -327,7 +327,7 @@ { "cell_type": "code", "execution_count": null, - "id": "96ea04e6", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -344,7 +344,7 @@ }, { "cell_type": "markdown", - "id": "3ff9ca98", + "id": "20", "metadata": {}, "source": [ "## Plot the results" @@ -353,7 +353,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6766dc4f", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -363,7 +363,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28ab56c9", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -391,7 +391,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2ab36cb7", + "id": "23", "metadata": {}, "outputs": [], "source": [ diff --git a/examples/pgame.ipynb b/examples/pgame.ipynb index 9b638b2d..03fd9c00 100644 --- a/examples/pgame.ipynb +++ b/examples/pgame.ipynb @@ -106,7 +106,7 @@ "#@markdown ---\n", "env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']\n", "episode_length = 250 #@param {type:\"integer\"}\n", - "num_iterations = 4000 #@param {type:\"integer\"}\n", + "num_iterations = 1000 #@param {type:\"integer\"}\n", "seed = 42 #@param {type:\"integer\"}\n", "policy_hidden_layer_sizes = (256, 256) #@param {type:\"raw\"}\n", "iso_sigma = 0.005 #@param {type:\"number\"}\n", diff --git a/examples/sac_pbt.ipynb 
b/examples/sac_pbt.ipynb index 7762083f..a484b035 100644 --- a/examples/sac_pbt.ipynb +++ b/examples/sac_pbt.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1606cdf6", + "id": "0", "metadata": { "jupyter": { "outputs_hidden": false @@ -44,12 +44,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -67,7 +61,7 @@ { "cell_type": "code", "execution_count": null, - "id": "df61dcc5", + "id": "1", "metadata": { "jupyter": { "outputs_hidden": false @@ -84,7 +78,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e7d0b0ef", + "id": "2", "metadata": { "jupyter": { "outputs_hidden": false @@ -95,7 +89,8 @@ }, "outputs": [], "source": [ - "devices = jax.devices(\"tpu\")\n", + "# Get devices (change gpu by tpu if needed)\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f\"Detected the following {num_devices} device(s): {devices}\")" ] @@ -103,7 +98,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1f342948", + "id": "3", "metadata": { "jupyter": { "outputs_hidden": false @@ -123,8 +118,8 @@ "buffer_size = 100000\n", "\n", "# PBT Config\n", - "num_best_to_replace_from = 20\n", - "num_worse_to_replace = 40\n", + "num_best_to_replace_from = 1\n", + "num_worse_to_replace = 1\n", "\n", "# SAC config\n", "batch_size = 256\n", @@ -144,7 +139,7 @@ { "cell_type": "code", "execution_count": null, - "id": "090f8d4d", + "id": "4", "metadata": { "jupyter": { "outputs_hidden": false @@ -175,7 +170,7 @@ { "cell_type": "code", "execution_count": null, - "id": "efac713a", + "id": "5", "metadata": { "jupyter": { "outputs_hidden": false @@ -193,7 +188,7 @@ " eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key)\n", "\n", " reshape_fn = jax.jit(\n", - " lambda tree: jax.tree_map(\n", + " lambda tree: jax.tree_util.tree_map(\n", " lambda x: jnp.reshape(\n", " x,\n", " (\n", @@ -214,7 +209,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0bccfc6d", + "id": "6", "metadata": { "jupyter": { "outputs_hidden": false @@ -237,7 +232,7 @@ { "cell_type": "code", "execution_count": null, - "id": "708eea0a", + "id": "7", "metadata": { "jupyter": { "outputs_hidden": false @@ -266,7 +261,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8e6e2bec", + "id": "8", "metadata": { "jupyter": { "outputs_hidden": false @@ -293,7 +288,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3f09fe1e", + "id": "9", "metadata": { "jupyter": { "outputs_hidden": false @@ -311,7 +306,7 @@ { "cell_type": "code", "execution_count": null, - "id": "66ba826a", + "id": "10", "metadata": { "jupyter": { "outputs_hidden": false @@ -336,7 +331,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a49af55e", + "id": "11", "metadata": { "jupyter": { "outputs_hidden": false @@ -362,7 +357,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b137c8e5", + "id": "12", "metadata": { "jupyter": { "outputs_hidden": false @@ -384,7 +379,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1dbbb855", + "id": "13", "metadata": { "jupyter": { "outputs_hidden": false @@ -397,8 +392,8 @@ "source": [ "@jax.jit\n", "def unshard_fn(sharded_tree):\n", - " tree = jax.tree_map(lambda x: jax.device_put(x, \"cpu\"), sharded_tree)\n", - " tree = 
jax.tree_map(\n", + " tree = jax.tree_util.tree_map(lambda x: jax.device_put(x, \"cpu\"), sharded_tree)\n", + " tree = jax.tree_util.tree_map(\n", " lambda x: jnp.reshape(x, (population_size,) + x.shape[2:]), tree\n", " )\n", " return tree" @@ -407,7 +402,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a052ba2e", + "id": "14", "metadata": { "jupyter": { "outputs_hidden": false @@ -447,7 +442,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4f7354f2", + "id": "15", "metadata": { "pycharm": { "name": "#%%\n" @@ -461,7 +456,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cd1c27e8-1fe7-464d-8af3-72fa8d61852d", + "id": "16", "metadata": { "pycharm": { "name": "#%%\n" @@ -471,13 +466,13 @@ "source": [ "training_states = unshard_fn(training_states)\n", "best_idx = jnp.argmax(population_returns)\n", - "best_training_state = jax.tree_map(lambda x: x[best_idx], training_states)" + "best_training_state = jax.tree_util.tree_map(lambda x: x[best_idx], training_states)" ] }, { "cell_type": "code", "execution_count": null, - "id": "60e8ee82-27cf-4fa6-b189-66e6e10e2177", + "id": "17", "metadata": { "pycharm": { "name": "#%%\n" @@ -491,7 +486,7 @@ { "cell_type": "code", "execution_count": null, - "id": "84bd809c-0127-4241-9556-9e81e550bbd2", + "id": "18", "metadata": { "pycharm": { "name": "#%%\n" @@ -509,7 +504,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2954be53-ffa5-42cf-8696-7f40f139edaf", + "id": "19", "metadata": { "pycharm": { "name": "#%%\n" @@ -523,7 +518,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14026ff2-e7f2-46eb-91d5-6e7394136e96", + "id": "20", "metadata": { "pycharm": { "name": "#%%\n" @@ -537,7 +532,7 @@ "rng = jax.random.PRNGKey(seed=1)\n", "env_state = jax.jit(env.reset)(rng=rng)\n", "\n", - "training_state, env_state = jax.tree_map(\n", + "training_state, env_state = jax.tree_util.tree_map(\n", " lambda x: jnp.expand_dims(x, axis=0), (training_state, env_state)\n", ")\n", "\n", @@ -552,7 +547,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f5701084-e876-43f8-8de0-4216361ef5b4", + "id": "21", "metadata": { "pycharm": { "name": "#%%\n" @@ -561,7 +556,7 @@ "outputs": [], "source": [ "rollout = [\n", - " jax.tree_map(lambda x: jax.device_put(x[0], jax.devices(\"cpu\")[0]), env_state)\n", + " jax.tree_util.tree_map(lambda x: jax.device_put(x[0], jax.devices(\"cpu\")[0]), env_state)\n", " for env_state in rollout\n", "]" ] @@ -569,7 +564,7 @@ { "cell_type": "code", "execution_count": null, - "id": "85bb7556-37bb-4a20-88b3-28b298c8b0a9", + "id": "22", "metadata": { "pycharm": { "name": "#%%\n" diff --git a/examples/scripts/me_example.py b/examples/scripts/me_example.py index 699c6aba..433bc1d2 100644 --- a/examples/scripts/me_example.py +++ b/examples/scripts/me_example.py @@ -79,7 +79,12 @@ def run_me() -> None: # Run MAP-Elites loop for _ in range(num_iterations): - (repertoire, emitter_state, metrics, random_key,) = map_elites.update( + ( + repertoire, + emitter_state, + metrics, + random_key, + ) = map_elites.update( repertoire, emitter_state, random_key, diff --git a/examples/smerl.ipynb b/examples/smerl.ipynb index fe655fe2..d50f448f 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -67,12 +67,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps 
git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git a/examples/td3_pbt.ipynb b/examples/td3_pbt.ipynb index ec98b9da..484f6d12 100644 --- a/examples/td3_pbt.ipynb +++ b/examples/td3_pbt.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bf95707f", + "id": "0", "metadata": { "pycharm": { "name": "#%%\n" @@ -41,12 +41,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -62,7 +56,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15a43429", + "id": "1", "metadata": { "pycharm": { "name": "#%%\n" @@ -76,7 +70,7 @@ { "cell_type": "code", "execution_count": null, - "id": "32d15301", + "id": "2", "metadata": { "pycharm": { "name": "#%%\n" @@ -84,7 +78,8 @@ }, "outputs": [], "source": [ - "devices = jax.devices(\"gpu\")\n", + "# Get devices (change gpu by tpu if needed)\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f\"Detected the following {num_devices} device(s): {devices}\")" ] @@ -92,7 +87,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7520b673", + "id": "3", "metadata": { "pycharm": { "name": "#%%\n" @@ -129,7 +124,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c3718a4c", + "id": "4", "metadata": { "pycharm": { "name": "#%%\n" @@ -157,7 +152,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5485a16c", + "id": "5", "metadata": { "pycharm": { "name": "#%%\n" @@ -172,7 +167,7 @@ " eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key)\n", "\n", " reshape_fn = jax.jit(\n", - " lambda tree: jax.tree_map(\n", + " lambda tree: jax.tree_util.tree_map(\n", " lambda x: jnp.reshape(\n", " x, (population_size_per_device, env_batch_size,) + x.shape[1:]\n", " ),\n", @@ -188,7 +183,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4dc22ec4", + "id": "6", "metadata": { "pycharm": { "name": "#%%\n" @@ -208,7 +203,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d9c610ba", + "id": "7", "metadata": { "pycharm": { "name": "#%%\n" @@ -232,7 +227,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4f6fd3b9", + "id": "8", "metadata": { "pycharm": { "name": "#%%\n" @@ -256,7 +251,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a412cd4f", + "id": "9", "metadata": { "pycharm": { "name": "#%%\n" @@ -271,7 +266,7 @@ { "cell_type": "code", "execution_count": null, - "id": "535250a8", + "id": "10", "metadata": { "pycharm": { "name": "#%%\n" @@ -293,7 +288,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d24156f4", + "id": "11", "metadata": { "pycharm": { "name": "#%%\n" @@ -316,7 +311,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23037e97", + "id": "12", "metadata": { "pycharm": { "name": "#%%\n" @@ -335,7 +330,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d9ebb235", + "id": "13", "metadata": { "pycharm": { "name": "#%%\n" @@ -345,8 +340,8 @@ "source": [ "@jax.jit\n", "def unshard_fn(sharded_tree):\n", - " tree = jax.tree_map(lambda x: jax.device_put(x, \"cpu\"), sharded_tree)\n", - " tree = jax.tree_map(\n", + " tree = jax.tree_util.tree_map(lambda x: jax.device_put(x, \"cpu\"), sharded_tree)\n", + " tree = jax.tree_util.tree_map(\n", " lambda x: jnp.reshape(x, (population_size,) + 
x.shape[2:]), tree\n", " )\n", " return tree" @@ -355,7 +350,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a58253e5", + "id": "14", "metadata": { "pycharm": { "name": "#%%\n" @@ -392,7 +387,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6111e836", + "id": "15", "metadata": { "pycharm": { "name": "#%%\n" diff --git a/qdax/baselines/dads.py b/qdax/baselines/dads.py index 41f2ff08..bd4f4534 100644 --- a/qdax/baselines/dads.py +++ b/qdax/baselines/dads.py @@ -25,7 +25,7 @@ update_running_mean_std, ) from qdax.core.neuroevolution.sac_td3_utils import generate_unroll -from qdax.types import Metrics, Params, Reward, RNGKey, Skill, StateDescriptor +from qdax.custom_types import Metrics, Params, Reward, RNGKey, Skill, StateDescriptor class DadsTrainingState(TrainingState): @@ -430,12 +430,17 @@ def _update_dynamics( """ training_state, transitions = operand - dynamics_loss, dynamics_gradient = jax.value_and_grad(self._dynamics_loss_fn,)( + dynamics_loss, dynamics_gradient = jax.value_and_grad( + self._dynamics_loss_fn, + )( training_state.dynamics_params, transitions=transitions, ) - (dynamics_updates, dynamics_optimizer_state,) = self._dynamics_optimizer.update( + ( + dynamics_updates, + dynamics_optimizer_state, + ) = self._dynamics_optimizer.update( dynamics_gradient, training_state.dynamics_optimizer_state ) dynamics_params = optax.apply_updates( @@ -483,7 +488,11 @@ def _update_networks( random_key = training_state.random_key # Update skill-dynamics - (dynamics_params, dynamics_loss, dynamics_optimizer_state,) = jax.lax.cond( + ( + dynamics_params, + dynamics_loss, + dynamics_optimizer_state, + ) = jax.lax.cond( training_state.steps % self._config.dynamics_update_freq == 0, self._update_dynamics, self._not_update_dynamics, diff --git a/qdax/baselines/dads_smerl.py b/qdax/baselines/dads_smerl.py index 206f0012..5bd8274d 100644 --- a/qdax/baselines/dads_smerl.py +++ b/qdax/baselines/dads_smerl.py @@ -14,7 +14,7 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.buffers.trajectory_buffer import TrajectoryBuffer from qdax.core.neuroevolution.normalization_utils import normalize_with_rmstd -from qdax.types import Metrics, Reward +from qdax.custom_types import Metrics, Reward @dataclass diff --git a/qdax/baselines/diayn.py b/qdax/baselines/diayn.py index c03cfb3f..0ebdfc32 100644 --- a/qdax/baselines/diayn.py +++ b/qdax/baselines/diayn.py @@ -20,7 +20,7 @@ from qdax.core.neuroevolution.mdp_utils import TrainingState, get_first_episode from qdax.core.neuroevolution.networks.diayn_networks import make_diayn_networks from qdax.core.neuroevolution.sac_td3_utils import generate_unroll -from qdax.types import Metrics, Params, Reward, RNGKey, Skill, StateDescriptor +from qdax.custom_types import Metrics, Params, Reward, RNGKey, Skill, StateDescriptor class DiaynTrainingState(TrainingState): diff --git a/qdax/baselines/diayn_smerl.py b/qdax/baselines/diayn_smerl.py index 2966a692..daacaa74 100644 --- a/qdax/baselines/diayn_smerl.py +++ b/qdax/baselines/diayn_smerl.py @@ -13,7 +13,7 @@ from qdax.baselines.diayn import DIAYN, DiaynConfig, DiaynTrainingState from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.buffers.trajectory_buffer import TrajectoryBuffer -from qdax.types import Metrics, Reward +from qdax.custom_types import Metrics, Reward @dataclass diff --git a/qdax/baselines/genetic_algorithm.py b/qdax/baselines/genetic_algorithm.py index a01a13b1..b4c6a32f 100644 --- 
a/qdax/baselines/genetic_algorithm.py +++ b/qdax/baselines/genetic_algorithm.py @@ -1,4 +1,5 @@ """Core components of a basic genetic algorithm.""" + from functools import partial from typing import Any, Callable, Optional, Tuple @@ -6,7 +7,7 @@ from qdax.core.containers.ga_repertoire import GARepertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import ExtraScores, Fitness, Genotype, Metrics, RNGKey +from qdax.custom_types import ExtraScores, Fitness, Genotype, Metrics, RNGKey class GeneticAlgorithm: diff --git a/qdax/baselines/nsga2.py b/qdax/baselines/nsga2.py index a889eadc..afd587af 100644 --- a/qdax/baselines/nsga2.py +++ b/qdax/baselines/nsga2.py @@ -13,7 +13,7 @@ from qdax.baselines.genetic_algorithm import GeneticAlgorithm from qdax.core.containers.nsga2_repertoire import NSGA2Repertoire from qdax.core.emitters.emitter import EmitterState -from qdax.types import Genotype, RNGKey +from qdax.custom_types import Genotype, RNGKey class NSGA2(GeneticAlgorithm): diff --git a/qdax/baselines/pbt.py b/qdax/baselines/pbt.py index 65d1a950..6555c537 100644 --- a/qdax/baselines/pbt.py +++ b/qdax/baselines/pbt.py @@ -6,7 +6,7 @@ from flax.struct import PyTreeNode from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer -from qdax.types import RNGKey +from qdax.custom_types import RNGKey class PBTTrainingState(PyTreeNode): diff --git a/qdax/baselines/sac.py b/qdax/baselines/sac.py index a5ce15c5..482c5715 100644 --- a/qdax/baselines/sac.py +++ b/qdax/baselines/sac.py @@ -32,7 +32,7 @@ update_running_mean_std, ) from qdax.core.neuroevolution.sac_td3_utils import generate_unroll -from qdax.types import ( +from qdax.custom_types import ( Action, Descriptor, Mask, @@ -449,7 +449,10 @@ def _update_alpha( random_key=subkey, ) alpha_optimizer = optax.adam(learning_rate=alpha_lr) - (alpha_updates, alpha_optimizer_state,) = alpha_optimizer.update( + ( + alpha_updates, + alpha_optimizer_state, + ) = alpha_optimizer.update( alpha_gradient, training_state.alpha_optimizer_state ) alpha_params = optax.apply_updates( @@ -503,7 +506,10 @@ def _update_critic( random_key=subkey, ) critic_optimizer = optax.adam(learning_rate=critic_lr) - (critic_updates, critic_optimizer_state,) = critic_optimizer.update( + ( + critic_updates, + critic_optimizer_state, + ) = critic_optimizer.update( critic_gradient, training_state.critic_optimizer_state ) critic_params = optax.apply_updates( @@ -556,7 +562,10 @@ def _update_actor( random_key=subkey, ) policy_optimizer = optax.adam(learning_rate=policy_lr) - (policy_updates, policy_optimizer_state,) = policy_optimizer.update( + ( + policy_updates, + policy_optimizer_state, + ) = policy_optimizer.update( policy_gradient, training_state.policy_optimizer_state ) policy_params = optax.apply_updates( diff --git a/qdax/baselines/sac_pbt.py b/qdax/baselines/sac_pbt.py index 9aa2ff4c..947a7183 100644 --- a/qdax/baselines/sac_pbt.py +++ b/qdax/baselines/sac_pbt.py @@ -22,7 +22,7 @@ ) from qdax.core.neuroevolution.normalization_utils import normalize_with_rmstd from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn -from qdax.types import Descriptor, Mask, Metrics, RNGKey +from qdax.custom_types import Descriptor, Mask, Metrics, RNGKey class PBTSacTrainingState(PBTTrainingState, SacTrainingState): diff --git a/qdax/baselines/spea2.py b/qdax/baselines/spea2.py index c52063b6..10d195ad 100644 --- a/qdax/baselines/spea2.py +++ b/qdax/baselines/spea2.py @@ -15,7 +15,7 @@ from qdax.baselines.genetic_algorithm import GeneticAlgorithm from 
qdax.core.containers.spea2_repertoire import SPEA2Repertoire from qdax.core.emitters.emitter import EmitterState -from qdax.types import Genotype, RNGKey +from qdax.custom_types import Genotype, RNGKey class SPEA2(GeneticAlgorithm): diff --git a/qdax/baselines/td3.py b/qdax/baselines/td3.py index e09b5254..97f37893 100644 --- a/qdax/baselines/td3.py +++ b/qdax/baselines/td3.py @@ -23,7 +23,7 @@ from qdax.core.neuroevolution.mdp_utils import TrainingState, get_first_episode from qdax.core.neuroevolution.networks.td3_networks import make_td3_networks from qdax.core.neuroevolution.sac_td3_utils import generate_unroll -from qdax.types import ( +from qdax.custom_types import ( Action, Descriptor, Mask, @@ -76,7 +76,10 @@ class TD3: def __init__(self, config: TD3Config, action_size: int): self._config = config - self._policy, self._critic, = make_td3_networks( + ( + self._policy, + self._critic, + ) = make_td3_networks( action_size=action_size, critic_hidden_layer_sizes=self._config.critic_hidden_layer_size, policy_hidden_layer_sizes=self._config.policy_hidden_layer_size, @@ -421,7 +424,10 @@ def update_policy_step() -> Tuple[Params, Params, optax.OptState]: policy_optimizer = optax.adam( learning_rate=self._config.policy_learning_rate ) - (policy_updates, policy_optimizer_state,) = policy_optimizer.update( + ( + policy_updates, + policy_optimizer_state, + ) = policy_optimizer.update( policy_gradient, training_state.policy_optimizer_state ) policy_params = optax.apply_updates( diff --git a/qdax/baselines/td3_pbt.py b/qdax/baselines/td3_pbt.py index 60cd8a38..5762956d 100644 --- a/qdax/baselines/td3_pbt.py +++ b/qdax/baselines/td3_pbt.py @@ -25,7 +25,7 @@ td3_policy_loss_fn, ) from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn -from qdax.types import Descriptor, Mask, Metrics, Params, RNGKey +from qdax.custom_types import Descriptor, Mask, Metrics, Params, RNGKey class PBTTD3TrainingState(PBTTrainingState, TD3TrainingState): @@ -291,7 +291,10 @@ def update( def update_policy_step() -> Tuple[Params, Params, optax.OptState]: policy_optimizer = optax.adam(learning_rate=training_state.policy_lr) - (policy_updates, policy_optimizer_state,) = policy_optimizer.update( + ( + policy_updates, + policy_optimizer_state, + ) = policy_optimizer.update( policy_gradient, training_state.policy_optimizer_state ) policy_params = optax.apply_updates( diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index a0968ccc..f67d7b4f 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -12,8 +12,7 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.environments.bd_extractors import AuroraExtraInfo -from qdax.types import ( +from qdax.custom_types import ( Descriptor, Fitness, Genotype, @@ -22,6 +21,7 @@ Params, RNGKey, ) +from qdax.environments.bd_extractors import AuroraExtraInfo class AURORA: diff --git a/qdax/core/cmaes.py b/qdax/core/cmaes.py index 481a49bf..0e9b4084 100644 --- a/qdax/core/cmaes.py +++ b/qdax/core/cmaes.py @@ -2,6 +2,7 @@ Definition of CMAES class, containing main functions necessary to build a CMA optimization script. 
Link to the paper: https://arxiv.org/abs/1604.00772 """ + from functools import partial from typing import Callable, Optional, Tuple @@ -9,7 +10,7 @@ import jax import jax.numpy as jnp -from qdax.types import Fitness, Genotype, Mask, RNGKey +from qdax.custom_types import Fitness, Genotype, Mask, RNGKey class CMAESState(flax.struct.PyTreeNode): diff --git a/qdax/core/containers/archive.py b/qdax/core/containers/archive.py index 8af808f3..036c5892 100644 --- a/qdax/core/containers/archive.py +++ b/qdax/core/containers/archive.py @@ -40,7 +40,7 @@ def size(self) -> float: fake_data = jnp.isnan(self.data) # count number of real data - return sum(~fake_data) + return float(sum(~fake_data)) @classmethod def create( @@ -161,9 +161,7 @@ def insert(self, state_descriptors: jnp.ndarray) -> Archive: values, _indices = knn(self.data, state_descriptors, 1) # get indices where distance bigger than threshold - relevant_indices = jnp.where( - values.squeeze() > self.acceptance_threshold, x=0, y=1 - ) + relevant_indices = jnp.where(values.squeeze() > self.acceptance_threshold, 0, 1) def iterate_fn( carry: Tuple[Archive, jnp.ndarray, int], condition_data: Dict @@ -192,7 +190,7 @@ def iterate_fn( # get indices where distance bigger than threshold not_too_close = jnp.where( - values.squeeze() > self.acceptance_threshold, x=0, y=1 + values.squeeze() > self.acceptance_threshold, 0, 1 ) second_condition = not_too_close.sum() condition = (first_condition + second_condition) == 0 @@ -280,7 +278,7 @@ def knn( dist = jnp.nan_to_num(dist, nan=jnp.inf) # clipping necessary - numerical approx make some distancies negative - dist = jnp.sqrt(jnp.clip(dist, a_min=0.0)) + dist = jnp.sqrt(jnp.clip(dist, min=0.0)) # return values, indices values, indices = qdax_top_k(-dist, k) diff --git a/qdax/core/containers/ga_repertoire.py b/qdax/core/containers/ga_repertoire.py index 87ade54f..403331ff 100644 --- a/qdax/core/containers/ga_repertoire.py +++ b/qdax/core/containers/ga_repertoire.py @@ -10,7 +10,7 @@ from jax.flatten_util import ravel_pytree from qdax.core.containers.repertoire import Repertoire -from qdax.types import Fitness, Genotype, RNGKey +from qdax.custom_types import Fitness, Genotype, RNGKey class GARepertoire(Repertoire): diff --git a/qdax/core/containers/mapelites_repertoire.py b/qdax/core/containers/mapelites_repertoire.py index b1145c34..b473d4b3 100644 --- a/qdax/core/containers/mapelites_repertoire.py +++ b/qdax/core/containers/mapelites_repertoire.py @@ -15,7 +15,14 @@ from numpy.random import RandomState from sklearn.cluster import KMeans -from qdax.types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + RNGKey, +) def compute_cvt_centroids( @@ -303,7 +310,7 @@ def add( # put dominated fitness to -jnp.inf batch_of_fitnesses = jnp.where( - batch_of_fitnesses == cond_values, x=batch_of_fitnesses, y=-jnp.inf + batch_of_fitnesses == cond_values, batch_of_fitnesses, -jnp.inf ) # get addition condition @@ -315,7 +322,7 @@ def add( # assign fake position when relevant : num_centroids is out of bound batch_of_indices = jnp.where( - addition_condition, x=batch_of_indices, y=num_centroids + addition_condition, batch_of_indices, num_centroids ) # create new repertoire diff --git a/qdax/core/containers/mels_repertoire.py b/qdax/core/containers/mels_repertoire.py index a2e99971..7ef57bb9 100644 --- a/qdax/core/containers/mels_repertoire.py +++ b/qdax/core/containers/mels_repertoire.py @@ -14,7 +14,14 @@ 
MapElitesRepertoire, get_cells_indices, ) -from qdax.types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, Spread +from qdax.custom_types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + Spread, +) def _dispersion(descriptors: jnp.ndarray) -> jnp.ndarray: @@ -232,7 +239,7 @@ def add( # assign fake position when relevant : num_centroids is out of bound batch_of_indices = jnp.where( - addition_condition, x=batch_of_indices, y=num_centroids + addition_condition, batch_of_indices, num_centroids ) # create new repertoire diff --git a/qdax/core/containers/mome_repertoire.py b/qdax/core/containers/mome_repertoire.py index 0e2b6d3e..43be3835 100644 --- a/qdax/core/containers/mome_repertoire.py +++ b/qdax/core/containers/mome_repertoire.py @@ -15,7 +15,7 @@ MapElitesRepertoire, get_cells_indices, ) -from qdax.types import ( +from qdax.custom_types import ( Centroid, Descriptor, ExtraScores, diff --git a/qdax/core/containers/nsga2_repertoire.py b/qdax/core/containers/nsga2_repertoire.py index 74b0f454..331ef153 100644 --- a/qdax/core/containers/nsga2_repertoire.py +++ b/qdax/core/containers/nsga2_repertoire.py @@ -6,7 +6,7 @@ import jax.numpy as jnp from qdax.core.containers.ga_repertoire import GARepertoire -from qdax.types import Fitness, Genotype +from qdax.custom_types import Fitness, Genotype from qdax.utils.pareto_front import compute_masked_pareto_front @@ -56,9 +56,9 @@ def _compute_crowding_distances( norm = jnp.max(srt_fitnesses, axis=0) - jnp.min(srt_fitnesses, axis=0) # get the distances - dists = jnp.row_stack( + dists = jnp.vstack( [srt_fitnesses, jnp.full(num_objective, jnp.inf)] - ) - jnp.row_stack([jnp.full(num_objective, -jnp.inf), srt_fitnesses]) + ) - jnp.vstack([jnp.full(num_objective, -jnp.inf), srt_fitnesses]) # Prepare the distance to last and next vectors dist_to_last, dist_to_next = dists, dists @@ -228,7 +228,7 @@ def condition_fn_2(val: Tuple[jnp.ndarray, jnp.ndarray]) -> bool: # get rid of the zeros (that correspond to the False from the mask) fake_indice = num_candidates + 1 # bigger than all the other indices - indices = jnp.where(indices == 0, x=fake_indice, y=indices) + indices = jnp.where(indices == 0, fake_indice, indices) # sort the indices to remove the fake indices indices = jnp.sort(indices)[: self.size] diff --git a/qdax/core/containers/repertoire.py b/qdax/core/containers/repertoire.py index f50d53b7..77c91683 100644 --- a/qdax/core/containers/repertoire.py +++ b/qdax/core/containers/repertoire.py @@ -4,11 +4,11 @@ from __future__ import annotations -from abc import ABC, abstractclassmethod, abstractmethod +from abc import ABC, abstractmethod import flax -from qdax.types import Genotype, RNGKey +from qdax.custom_types import Genotype, RNGKey class Repertoire(flax.struct.PyTreeNode, ABC): @@ -19,7 +19,8 @@ class Repertoire(flax.struct.PyTreeNode, ABC): to keep the parent classes explicit and transparent. 
""" - @abstractclassmethod + @classmethod + @abstractmethod def init(cls) -> Repertoire: # noqa: N805 """Create a repertoire.""" pass diff --git a/qdax/core/containers/spea2_repertoire.py b/qdax/core/containers/spea2_repertoire.py index 54870db4..33c31547 100644 --- a/qdax/core/containers/spea2_repertoire.py +++ b/qdax/core/containers/spea2_repertoire.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from qdax.core.containers.ga_repertoire import GARepertoire -from qdax.types import Fitness, Genotype +from qdax.custom_types import Fitness, Genotype class SPEA2Repertoire(GARepertoire): diff --git a/qdax/core/containers/uniform_replacement_archive.py b/qdax/core/containers/uniform_replacement_archive.py index d6f233db..830878cf 100644 --- a/qdax/core/containers/uniform_replacement_archive.py +++ b/qdax/core/containers/uniform_replacement_archive.py @@ -4,7 +4,7 @@ import jax.numpy as jnp from qdax.core.containers.archive import Archive -from qdax.types import RNGKey +from qdax.custom_types import RNGKey class UniformReplacementArchive(Archive): @@ -74,7 +74,7 @@ def _single_insertion(self, state_descriptor: jnp.ndarray) -> Archive: subkey, shape=(1,), minval=0, maxval=self.max_size ) - index = jnp.where(condition=is_full, x=random_index, y=new_current_position) + index = jnp.where(is_full, random_index, new_current_position) new_data = self.data.at[index].set(state_descriptor) diff --git a/qdax/core/containers/unstructured_repertoire.py b/qdax/core/containers/unstructured_repertoire.py index f4cc0c98..8512d3d6 100644 --- a/qdax/core/containers/unstructured_repertoire.py +++ b/qdax/core/containers/unstructured_repertoire.py @@ -8,7 +8,14 @@ import jax.numpy as jnp from jax.flatten_util import ravel_pytree -from qdax.types import Centroid, Descriptor, Fitness, Genotype, Observation, RNGKey +from qdax.custom_types import ( + Centroid, + Descriptor, + Fitness, + Genotype, + Observation, + RNGKey, +) @partial(jax.jit, static_argnames=("k_nn",)) @@ -300,7 +307,7 @@ def add( # ReIndexing of all the inputs to the correct sorted way batch_of_descriptors = batch_of_descriptors.at[sorted_bds].get() - batch_of_genotypes = jax.tree_map( + batch_of_genotypes = jax.tree_util.tree_map( lambda x: x.at[sorted_bds].get(), batch_of_genotypes ) batch_of_fitnesses = batch_of_fitnesses.at[sorted_bds].get() @@ -333,7 +340,7 @@ def add( # put dominated fitness to -jnp.inf batch_of_fitnesses = jnp.where( - batch_of_fitnesses == cond_values, x=batch_of_fitnesses, y=-jnp.inf + batch_of_fitnesses == cond_values, batch_of_fitnesses, -jnp.inf ) # get addition condition @@ -347,12 +354,12 @@ def add( # assign fake position when relevant : num_centroids is out of bounds batch_of_indices = jnp.where( addition_condition, - x=batch_of_indices, - y=self.max_size, + batch_of_indices, + self.max_size, ) # create new grid - new_grid_genotypes = jax.tree_map( + new_grid_genotypes = jax.tree_util.tree_map( lambda grid_genotypes, new_genotypes: grid_genotypes.at[ batch_of_indices.squeeze() ].set(new_genotypes), @@ -398,7 +405,7 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey grid_empty = self.fitnesses == -jnp.inf p = (1.0 - grid_empty) / jnp.sum(1.0 - grid_empty) - samples = jax.tree_map( + samples = jax.tree_util.tree_map( lambda x: jax.random.choice(sub_key, x, shape=(num_samples,), p=p), self.genotypes, ) @@ -435,7 +442,7 @@ def init( # Initialize grid with default values default_fitnesses = -jnp.inf * jnp.ones(shape=max_size) - default_genotypes = jax.tree_map( + default_genotypes = 
jax.tree_util.tree_map( lambda x: jnp.full(shape=(max_size,) + x.shape[1:], fill_value=jnp.nan), genotypes, ) diff --git a/qdax/core/distributed_map_elites.py b/qdax/core/distributed_map_elites.py index 7b5609f2..dbc6522b 100644 --- a/qdax/core/distributed_map_elites.py +++ b/qdax/core/distributed_map_elites.py @@ -1,4 +1,5 @@ """Core components of the MAP-Elites algorithm.""" + from __future__ import annotations from functools import partial @@ -10,7 +11,7 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.emitter import EmitterState from qdax.core.map_elites import MAPElites -from qdax.types import Centroid, Genotype, Metrics, RNGKey +from qdax.custom_types import Centroid, Genotype, Metrics, RNGKey class DistributedMAPElites(MAPElites): @@ -189,7 +190,7 @@ def get_distributed_update_fn( of MAP-Elites updates. """ - @partial(jax.jit, static_argnames=("self",)) + @jax.jit def _scan_update( carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], unused: Any, @@ -200,7 +201,12 @@ def _scan_update( repertoire, emitter_state, random_key = carry # apply one step of update - (repertoire, emitter_state, metrics, random_key,) = self.update( + ( + repertoire, + emitter_state, + metrics, + random_key, + ) = self.update( repertoire, emitter_state, random_key, @@ -214,7 +220,11 @@ def update_fn( random_key: RNGKey, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey, Metrics]: """Apply num_iterations of update.""" - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( _scan_update, (repertoire, emitter_state, random_key), (), diff --git a/qdax/core/emitters/cma_emitter.py b/qdax/core/emitters/cma_emitter.py index 66e5677a..315dcd9b 100644 --- a/qdax/core/emitters/cma_emitter.py +++ b/qdax/core/emitters/cma_emitter.py @@ -13,7 +13,14 @@ get_cells_indices, ) from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + RNGKey, +) class CMAEmitterState(EmitterState): diff --git a/qdax/core/emitters/cma_improvement_emitter.py b/qdax/core/emitters/cma_improvement_emitter.py index 28424f3f..7c3fc98c 100644 --- a/qdax/core/emitters/cma_improvement_emitter.py +++ b/qdax/core/emitters/cma_improvement_emitter.py @@ -6,7 +6,7 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.cma_emitter import CMAEmitter, CMAEmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype class CMAImprovementEmitter(CMAEmitter): @@ -62,13 +62,13 @@ def _ranking_criteria( condition = improvements == jnp.inf # criteria: fitness if new cell, improvement else - ranking_criteria = jnp.where(condition, x=fitnesses, y=improvements) + ranking_criteria = jnp.where(condition, fitnesses, improvements) # make sure to have all the new cells first new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria) ranking_criteria = jnp.where( - condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria + condition, ranking_criteria + new_cell_offset, ranking_criteria ) return ranking_criteria # type: ignore diff --git a/qdax/core/emitters/cma_mega_emitter.py b/qdax/core/emitters/cma_mega_emitter.py index 1fd0e1e6..c3f87fed 100644 --- 
a/qdax/core/emitters/cma_mega_emitter.py +++ b/qdax/core/emitters/cma_mega_emitter.py @@ -12,7 +12,7 @@ get_cells_indices, ) from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import ( +from qdax.custom_types import ( Centroid, Descriptor, ExtraScores, @@ -238,13 +238,13 @@ def state_update( condition = improvements == jnp.inf # criteria: fitness if new cell, improvement else - ranking_criteria = jnp.where(condition, x=fitnesses, y=improvements) + ranking_criteria = jnp.where(condition, fitnesses, improvements) # make sure to have all the new cells first new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria) ranking_criteria = jnp.where( - condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria + condition, ranking_criteria + new_cell_offset, ranking_criteria ) # sort indices according to the criteria @@ -282,12 +282,12 @@ def state_update( # update theta in case of reinit theta = jax.tree_util.tree_map( - lambda x, y: jnp.where(reinitialize, x=x, y=y), random_theta, theta + lambda x, y: jnp.where(reinitialize, x, y), random_theta, theta ) # update cmaes state in case of reinit cmaes_state = jax.tree_util.tree_map( - lambda x, y: jnp.where(reinitialize, x=x, y=y), + lambda x, y: jnp.where(reinitialize, x, y), self._cma_initial_state, cmaes_state, ) diff --git a/qdax/core/emitters/cma_opt_emitter.py b/qdax/core/emitters/cma_opt_emitter.py index d9c5bf71..9a783585 100644 --- a/qdax/core/emitters/cma_opt_emitter.py +++ b/qdax/core/emitters/cma_opt_emitter.py @@ -6,7 +6,7 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.cma_emitter import CMAEmitter, CMAEmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype class CMAOptimizingEmitter(CMAEmitter): diff --git a/qdax/core/emitters/cma_pool_emitter.py b/qdax/core/emitters/cma_pool_emitter.py index 24556f8b..55ccaa4f 100644 --- a/qdax/core/emitters/cma_pool_emitter.py +++ b/qdax/core/emitters/cma_pool_emitter.py @@ -9,7 +9,7 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.cma_emitter import CMAEmitter, CMAEmitterState from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey class CMAPoolEmitterState(EmitterState): diff --git a/qdax/core/emitters/cma_rnd_emitter.py b/qdax/core/emitters/cma_rnd_emitter.py index e05cc453..27e4f0db 100644 --- a/qdax/core/emitters/cma_rnd_emitter.py +++ b/qdax/core/emitters/cma_rnd_emitter.py @@ -9,7 +9,7 @@ from qdax.core.cmaes import CMAESState from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.cma_emitter import CMAEmitter, CMAEmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey class CMARndEmitterState(CMAEmitterState): @@ -168,7 +168,7 @@ def _ranking_criteria( condition = improvements == jnp.inf ranking_criteria = jnp.where( - condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria + condition, ranking_criteria + new_cell_offset, ranking_criteria ) return ranking_criteria # type: ignore diff --git a/qdax/core/emitters/dcg_me_emitter.py b/qdax/core/emitters/dcg_me_emitter.py index 94e0bb9d..fea237c6 100644 --- 
a/qdax/core/emitters/dcg_me_emitter.py +++ b/qdax/core/emitters/dcg_me_emitter.py @@ -6,8 +6,8 @@ from qdax.core.emitters.multi_emitter import MultiEmitter from qdax.core.emitters.qdcg_emitter import QualityDCGConfig, QualityDCGEmitter from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.custom_types import Params, RNGKey from qdax.environments.base_wrappers import QDEnv -from qdax.types import Params, RNGKey @dataclass diff --git a/qdax/core/emitters/dpg_emitter.py b/qdax/core/emitters/dpg_emitter.py index 2c55cbd2..ea921237 100644 --- a/qdax/core/emitters/dpg_emitter.py +++ b/qdax/core/emitters/dpg_emitter.py @@ -1,6 +1,7 @@ """ Implements the Diversity PG inspired by QDPG algorithm in jax for brax environments, based on: https://arxiv.org/abs/2006.08505 """ + from dataclasses import dataclass from functools import partial from typing import Any, Callable, Optional, Tuple @@ -17,8 +18,7 @@ QualityPGEmitterState, ) from qdax.core.neuroevolution.buffers.buffer import QDTransition -from qdax.environments.base_wrappers import QDEnv -from qdax.types import ( +from qdax.custom_types import ( Descriptor, ExtraScores, Fitness, @@ -28,6 +28,7 @@ RNGKey, StateDescriptor, ) +from qdax.environments.base_wrappers import QDEnv @dataclass @@ -180,7 +181,10 @@ def scan_train_critics( return new_emitter_state, () # sample transitions - (transitions, random_key,) = emitter_state.replay_buffer.sample( + ( + transitions, + random_key, + ) = emitter_state.replay_buffer.sample( random_key=emitter_state.random_key, sample_size=self._config.num_critic_training_steps * self._config.batch_size, @@ -249,7 +253,11 @@ def _train_critics( ) # Update greedy policy - (policy_optimizer_state, actor_params, target_actor_params,) = jax.lax.cond( + ( + policy_optimizer_state, + actor_params, + target_actor_params, + ) = jax.lax.cond( emitter_state.steps % self._config.policy_delay == 0, lambda x: self._update_actor(*x), lambda _: ( @@ -348,7 +356,11 @@ def scan_train_policy( transitions, ) - (emitter_state, policy_params, policy_optimizer_state,), _ = jax.lax.scan( + ( + emitter_state, + policy_params, + policy_optimizer_state, + ), _ = jax.lax.scan( scan_train_policy, (emitter_state, policy_params, policy_optimizer_state), (transitions), diff --git a/qdax/core/emitters/emitter.py b/qdax/core/emitters/emitter.py index 056798ba..21139356 100644 --- a/qdax/core/emitters/emitter.py +++ b/qdax/core/emitters/emitter.py @@ -6,7 +6,7 @@ from flax.struct import PyTreeNode from qdax.core.containers.repertoire import Repertoire -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey class EmitterState(PyTreeNode): diff --git a/qdax/core/emitters/mees_emitter.py b/qdax/core/emitters/mees_emitter.py index 0a03a6ba..4d51326a 100644 --- a/qdax/core/emitters/mees_emitter.py +++ b/qdax/core/emitters/mees_emitter.py @@ -3,6 +3,7 @@ from "Scaling MAP-Elites to Deep Neuroevolution" by Colas et al: https://dl.acm.org/doi/pdf/10.1145/3377930.3390217 """ + from __future__ import annotations from dataclasses import dataclass @@ -19,7 +20,7 @@ get_cells_indices, ) from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey class NoveltyArchive(flax.struct.PyTreeNode): @@ -362,7 +363,7 @@ def _sample( genotypes_empty = fitnesses < min_fitness p = (1.0 - genotypes_empty) / 
jnp.sum(1.0 - genotypes_empty) random_key, subkey = jax.random.split(random_key) - samples = jax.tree_map( + samples = jax.tree_util.tree_map( lambda x: jax.random.choice(subkey, x, shape=(1,), p=p), genotypes, ) @@ -429,7 +430,7 @@ def _sample_explore( repertoire_empty = novelties < min_novelty p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty) random_key, subkey = jax.random.split(random_key) - samples = jax.tree_map( + samples = jax.tree_util.tree_map( lambda x: jax.random.choice(subkey, x, shape=(1,), p=p), repertoire.genotypes, ) @@ -486,7 +487,7 @@ def _es_emitter( # Sampling non-mirror noise else: sample_number = total_sample_number - sample_noise = jax.tree_map( + sample_noise = jax.tree_util.tree_map( lambda x: jax.random.normal( key=subkey, shape=jnp.repeat(x, sample_number, axis=0).shape, @@ -496,11 +497,11 @@ def _es_emitter( gradient_noise = sample_noise # Applying noise - samples = jax.tree_map( + samples = jax.tree_util.tree_map( lambda x: jnp.repeat(x, total_sample_number, axis=0), parent, ) - samples = jax.tree_map( + samples = jax.tree_util.tree_map( lambda mean, noise: mean + self._config.sample_sigma * noise, samples, sample_noise, @@ -526,7 +527,7 @@ def _es_emitter( if self._config.sample_mirror: ranks = jnp.reshape(ranks, (sample_number, 2)) ranks = jnp.apply_along_axis(lambda rank: rank[0] - rank[1], 1, ranks) - ranks = jax.tree_map( + ranks = jax.tree_util.tree_map( lambda x: jnp.reshape( jnp.repeat(ranks.ravel(), x[0].ravel().shape[0], axis=0), x.shape ), @@ -534,16 +535,16 @@ def _es_emitter( ) # Computing the gradients - gradient = jax.tree_map( + gradient = jax.tree_util.tree_map( lambda noise, rank: jnp.multiply(noise, rank), gradient_noise, ranks, ) - gradient = jax.tree_map( + gradient = jax.tree_util.tree_map( lambda x: jnp.reshape(x, (sample_number, -1)), gradient, ) - gradient = jax.tree_map( + gradient = jax.tree_util.tree_map( lambda g, p: jnp.reshape( -jnp.sum(g, axis=0) / (total_sample_number * self._config.sample_sigma), p.shape, @@ -553,7 +554,7 @@ def _es_emitter( ) # Adding regularisation - gradient = jax.tree_map( + gradient = jax.tree_util.tree_map( lambda g, p: g + self._config.l2_coefficient * p, gradient, parent, @@ -626,7 +627,7 @@ def _buffers_update( last_updated_fitnesses = last_updated_fitnesses.at[last_updated_position].set( fitnesses[0] ) - last_updated_genotypes = jax.tree_map( + last_updated_genotypes = jax.tree_util.tree_map( lambda last_gen, gen: last_gen.at[ jnp.expand_dims(last_updated_position, axis=0) ].set(gen), diff --git a/qdax/core/emitters/multi_emitter.py b/qdax/core/emitters/multi_emitter.py index b3ad23c6..17cb8ace 100644 --- a/qdax/core/emitters/multi_emitter.py +++ b/qdax/core/emitters/multi_emitter.py @@ -8,7 +8,7 @@ from qdax.core.containers.repertoire import Repertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey class MultiEmitterState(EmitterState): diff --git a/qdax/core/emitters/mutation_operators.py b/qdax/core/emitters/mutation_operators.py index f39b8060..bda2daca 100644 --- a/qdax/core/emitters/mutation_operators.py +++ b/qdax/core/emitters/mutation_operators.py @@ -6,7 +6,7 @@ import jax import jax.numpy as jnp -from qdax.types import Genotype, RNGKey +from qdax.custom_types import Genotype, RNGKey def _polynomial_mutation( diff --git a/qdax/core/emitters/omg_mega_emitter.py b/qdax/core/emitters/omg_mega_emitter.py index 
54766152..580bd151 100644 --- a/qdax/core/emitters/omg_mega_emitter.py +++ b/qdax/core/emitters/omg_mega_emitter.py @@ -6,7 +6,14 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + RNGKey, +) class OMGMEGAEmitterState(EmitterState): diff --git a/qdax/core/emitters/pbt_me_emitter.py b/qdax/core/emitters/pbt_me_emitter.py index a2266bfa..55bded4e 100644 --- a/qdax/core/emitters/pbt_me_emitter.py +++ b/qdax/core/emitters/pbt_me_emitter.py @@ -12,8 +12,8 @@ from qdax.core.containers.repertoire import Repertoire from qdax.core.emitters.emitter import Emitter, EmitterState from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey from qdax.environments.base_wrappers import QDEnv -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey class PBTEmitterState(EmitterState): diff --git a/qdax/core/emitters/pbt_variation_operators.py b/qdax/core/emitters/pbt_variation_operators.py index bd76ecd1..c8537003 100644 --- a/qdax/core/emitters/pbt_variation_operators.py +++ b/qdax/core/emitters/pbt_variation_operators.py @@ -3,7 +3,7 @@ from qdax.baselines.sac_pbt import PBTSacTrainingState from qdax.baselines.td3_pbt import PBTTD3TrainingState from qdax.core.emitters.mutation_operators import isoline_variation -from qdax.types import RNGKey +from qdax.custom_types import RNGKey def sac_pbt_variation_fn( @@ -94,7 +94,10 @@ def td3_pbt_variation_fn( training_state1.critic_params, training_state2.critic_params, ) - (policy_params, critic_params,), random_key = isoline_variation( + ( + policy_params, + critic_params, + ), random_key = isoline_variation( x1=(policy_params1, critic_params1), x2=(policy_params2, critic_params2), random_key=random_key, diff --git a/qdax/core/emitters/pga_me_emitter.py b/qdax/core/emitters/pga_me_emitter.py index e93eb696..a4f8b33f 100644 --- a/qdax/core/emitters/pga_me_emitter.py +++ b/qdax/core/emitters/pga_me_emitter.py @@ -6,8 +6,8 @@ from qdax.core.emitters.multi_emitter import MultiEmitter from qdax.core.emitters.qpg_emitter import QualityPGConfig, QualityPGEmitter from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.custom_types import Params, RNGKey from qdax.environments.base_wrappers import QDEnv -from qdax.types import Params, RNGKey @dataclass diff --git a/qdax/core/emitters/qdcg_emitter.py b/qdax/core/emitters/qdcg_emitter.py index 0d560cbb..0fb19c4b 100644 --- a/qdax/core/emitters/qdcg_emitter.py +++ b/qdax/core/emitters/qdcg_emitter.py @@ -16,8 +16,8 @@ from qdax.core.neuroevolution.buffers.buffer import DCGTransition, ReplayBuffer from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_dc_fn from qdax.core.neuroevolution.networks.networks import QModuleDC +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey from qdax.environments.base_wrappers import QDEnv -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey @dataclass @@ -521,7 +521,11 @@ def _train_critics( ) # Update greedy actor - (actor_opt_state, actor_params, target_actor_params,) = jax.lax.cond( + ( + actor_opt_state, + actor_params, + target_actor_params, + ) = jax.lax.cond( emitter_state.steps % 
self._config.policy_delay == 0, lambda x: self._update_actor(*x), lambda _: ( @@ -580,7 +584,7 @@ def _update_critic( critic_params = optax.apply_updates(critic_params, critic_updates) # Soft update of target critic network - target_critic_params = jax.tree_map( + target_critic_params = jax.tree_util.tree_map( lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + self._config.soft_tau_update * x2, target_critic_params, @@ -612,7 +616,7 @@ def _update_actor( actor_params = optax.apply_updates(actor_params, policy_updates) # Soft update of target greedy actor - target_actor_params = jax.tree_map( + target_actor_params = jax.tree_util.tree_map( lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + self._config.soft_tau_update * x2, target_actor_params, @@ -701,7 +705,11 @@ def scan_train_policy( new_policy_opt_state, ), () - (emitter_state, policy_params, policy_opt_state,), _ = jax.lax.scan( + ( + emitter_state, + policy_params, + policy_opt_state, + ), _ = jax.lax.scan( scan_train_policy, (emitter_state, policy_params, policy_opt_state), transitions, diff --git a/qdax/core/emitters/qdpg_emitter.py b/qdax/core/emitters/qdpg_emitter.py index eefd1566..b9de6090 100644 --- a/qdax/core/emitters/qdpg_emitter.py +++ b/qdax/core/emitters/qdpg_emitter.py @@ -5,6 +5,7 @@ it has been updated to work better with Jax in term of time cost. Those changes have been made in accordance with the authors of this algorithm. """ + import functools from dataclasses import dataclass from typing import Callable @@ -17,8 +18,8 @@ from qdax.core.emitters.mutation_operators import isoline_variation from qdax.core.emitters.qpg_emitter import QualityPGConfig, QualityPGEmitter from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.custom_types import Reward, StateDescriptor from qdax.environments.base_wrappers import QDEnv -from qdax.types import Reward, StateDescriptor @dataclass diff --git a/qdax/core/emitters/qpg_emitter.py b/qdax/core/emitters/qpg_emitter.py index 4a173b51..c6e2df7e 100644 --- a/qdax/core/emitters/qpg_emitter.py +++ b/qdax/core/emitters/qpg_emitter.py @@ -17,8 +17,8 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_fn from qdax.core.neuroevolution.networks.networks import QModule +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey from qdax.environments.base_wrappers import QDEnv -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey @dataclass @@ -379,7 +379,11 @@ def _train_critics( ) # Update greedy actor - (actor_optimizer_state, actor_params, target_actor_params,) = jax.lax.cond( + ( + actor_optimizer_state, + actor_params, + target_actor_params, + ) = jax.lax.cond( emitter_state.steps % self._config.policy_delay == 0, lambda x: self._update_actor(*x), lambda _: ( @@ -439,7 +443,7 @@ def _update_critic( critic_params = optax.apply_updates(critic_params, critic_updates) # Soft update of target critic network - target_critic_params = jax.tree_map( + target_critic_params = jax.tree_util.tree_map( lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + self._config.soft_tau_update * x2, target_critic_params, @@ -471,7 +475,7 @@ def _update_actor( actor_params = optax.apply_updates(actor_params, policy_updates) # Soft update of target greedy actor - target_actor_params = jax.tree_map( + target_actor_params = jax.tree_util.tree_map( lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + 
self._config.soft_tau_update * x2, target_actor_params, @@ -527,7 +531,11 @@ def scan_train_policy( new_policy_optimizer_state, ), () - (emitter_state, policy_params, policy_optimizer_state,), _ = jax.lax.scan( + ( + emitter_state, + policy_params, + policy_optimizer_state, + ), _ = jax.lax.scan( scan_train_policy, (emitter_state, policy_params, policy_optimizer_state), (), diff --git a/qdax/core/emitters/standard_emitters.py b/qdax/core/emitters/standard_emitters.py index 860962d4..1d949b2d 100644 --- a/qdax/core/emitters/standard_emitters.py +++ b/qdax/core/emitters/standard_emitters.py @@ -6,7 +6,7 @@ from qdax.core.containers.repertoire import Repertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import ExtraScores, Genotype, RNGKey +from qdax.custom_types import ExtraScores, Genotype, RNGKey class MixingEmitter(Emitter): diff --git a/qdax/core/map_elites.py b/qdax/core/map_elites.py index 8b649d0c..d0b075a9 100644 --- a/qdax/core/map_elites.py +++ b/qdax/core/map_elites.py @@ -1,4 +1,5 @@ """Core components of the MAP-Elites algorithm.""" + from __future__ import annotations from functools import partial @@ -8,7 +9,7 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import ( +from qdax.custom_types import ( Centroid, Descriptor, ExtraScores, @@ -169,7 +170,12 @@ def scan_update( The updated repertoire and emitter state, with a new random key and metrics. """ repertoire, emitter_state, random_key = carry - (repertoire, emitter_state, metrics, random_key,) = self.update( + ( + repertoire, + emitter_state, + metrics, + random_key, + ) = self.update( repertoire, emitter_state, random_key, diff --git a/qdax/core/mels.py b/qdax/core/mels.py index 6dc8f551..8b0e7511 100644 --- a/qdax/core/mels.py +++ b/qdax/core/mels.py @@ -1,4 +1,5 @@ """Core components of the MAP-Elites Low-Spread algorithm.""" + from __future__ import annotations from functools import partial @@ -9,7 +10,7 @@ from qdax.core.containers.mels_repertoire import MELSRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState from qdax.core.map_elites import MAPElites -from qdax.types import ( +from qdax.custom_types import ( Centroid, Descriptor, ExtraScores, diff --git a/qdax/core/mome.py b/qdax/core/mome.py index db450b9a..c239bd1f 100644 --- a/qdax/core/mome.py +++ b/qdax/core/mome.py @@ -9,7 +9,7 @@ from qdax.core.containers.mome_repertoire import MOMERepertoire from qdax.core.emitters.emitter import EmitterState from qdax.core.map_elites import MAPElites -from qdax.types import Centroid, RNGKey +from qdax.custom_types import Centroid, RNGKey class MOME(MAPElites): diff --git a/qdax/core/neuroevolution/buffers/buffer.py b/qdax/core/neuroevolution/buffers/buffer.py index d25c8c6c..5057e5e2 100644 --- a/qdax/core/neuroevolution/buffers/buffer.py +++ b/qdax/core/neuroevolution/buffers/buffer.py @@ -7,7 +7,7 @@ import jax import jax.numpy as jnp -from qdax.types import ( +from qdax.custom_types import ( Action, Descriptor, Done, diff --git a/qdax/core/neuroevolution/buffers/trajectory_buffer.py b/qdax/core/neuroevolution/buffers/trajectory_buffer.py index 2cc4ab69..93e1b2f9 100644 --- a/qdax/core/neuroevolution/buffers/trajectory_buffer.py +++ b/qdax/core/neuroevolution/buffers/trajectory_buffer.py @@ -8,7 +8,7 @@ from flax import struct from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Reward, RNGKey +from qdax.custom_types import Reward, 
RNGKey class TrajectoryBuffer(struct.PyTreeNode): diff --git a/qdax/core/neuroevolution/losses/dads_loss.py b/qdax/core/neuroevolution/losses/dads_loss.py index b42ca416..60edfee1 100644 --- a/qdax/core/neuroevolution/losses/dads_loss.py +++ b/qdax/core/neuroevolution/losses/dads_loss.py @@ -6,7 +6,14 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.losses.sac_loss import make_sac_loss_fn -from qdax.types import Action, Observation, Params, RNGKey, Skill, StateDescriptor +from qdax.custom_types import ( + Action, + Observation, + Params, + RNGKey, + Skill, + StateDescriptor, +) def make_dads_loss_fn( diff --git a/qdax/core/neuroevolution/losses/diayn_loss.py b/qdax/core/neuroevolution/losses/diayn_loss.py index 8bca3b4b..e25a73bd 100644 --- a/qdax/core/neuroevolution/losses/diayn_loss.py +++ b/qdax/core/neuroevolution/losses/diayn_loss.py @@ -7,7 +7,7 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.losses.sac_loss import make_sac_loss_fn -from qdax.types import Action, Observation, Params, RNGKey, StateDescriptor +from qdax.custom_types import Action, Observation, Params, RNGKey, StateDescriptor def make_diayn_loss_fn( diff --git a/qdax/core/neuroevolution/losses/sac_loss.py b/qdax/core/neuroevolution/losses/sac_loss.py index b3656b18..d7289292 100644 --- a/qdax/core/neuroevolution/losses/sac_loss.py +++ b/qdax/core/neuroevolution/losses/sac_loss.py @@ -6,7 +6,7 @@ from brax.training.distribution import ParametricDistribution from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Action, Observation, Params, RNGKey +from qdax.custom_types import Action, Observation, Params, RNGKey def make_sac_loss_fn( diff --git a/qdax/core/neuroevolution/losses/td3_loss.py b/qdax/core/neuroevolution/losses/td3_loss.py index e12797b9..964c2c4f 100644 --- a/qdax/core/neuroevolution/losses/td3_loss.py +++ b/qdax/core/neuroevolution/losses/td3_loss.py @@ -6,7 +6,7 @@ import jax.numpy as jnp from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Action, Descriptor, Observation, Params, RNGKey +from qdax.custom_types import Action, Descriptor, Observation, Params, RNGKey def make_td3_loss_fn( diff --git a/qdax/core/neuroevolution/mdp_utils.py b/qdax/core/neuroevolution/mdp_utils.py index 3b077069..f269a22b 100644 --- a/qdax/core/neuroevolution/mdp_utils.py +++ b/qdax/core/neuroevolution/mdp_utils.py @@ -9,7 +9,7 @@ from flax.struct import PyTreeNode from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Descriptor, Genotype, Params, RNGKey +from qdax.custom_types import Descriptor, Genotype, Params, RNGKey class TrainingState(PyTreeNode): @@ -134,7 +134,7 @@ def mask_episodes(x: jnp.ndarray) -> jnp.ndarray: # the double transpose trick is here to allow easy broadcasting return jnp.where(mask.T, x.T, jnp.nan * jnp.ones_like(x).T).T - return jax.tree_map(mask_episodes, transition) # type: ignore + return jax.tree_util.tree_map(mask_episodes, transition) # type: ignore def init_population_controllers( diff --git a/qdax/core/neuroevolution/networks/dads_networks.py b/qdax/core/neuroevolution/networks/dads_networks.py index beb4b77a..863bdab5 100644 --- a/qdax/core/neuroevolution/networks/dads_networks.py +++ b/qdax/core/neuroevolution/networks/dads_networks.py @@ -1,128 +1,129 @@ from typing import Optional, Tuple -import haiku as hk -import jax +import flax.linen as nn import jax.numpy as jnp import 
tensorflow_probability.substrates.jax as tfp -from haiku.initializers import Initializer, VarianceScaling - -from qdax.types import Action, Observation, Skill, StateDescriptor - - -class GaussianMixture(hk.Module): - """Module that outputs a Gaussian Mixture Distribution.""" - - def __init__( - self, - num_dimensions: int, - num_components: int, - reinterpreted_batch_ndims: Optional[int] = None, - identity_covariance: bool = True, - initializer: Optional[Initializer] = None, - name: str = "GaussianMixture", - ): - """Module that outputs a Gaussian Mixture Distribution - with identity covariance matrix.""" - - super().__init__(name=name) - if initializer is None: - initializer = VarianceScaling(1.0, "fan_in", "uniform") - self._num_dimensions = num_dimensions - self._num_components = num_components - self._reinterpreted_batch_ndims = reinterpreted_batch_ndims - self._identity_covariance = identity_covariance - self.initializer = initializer - logits_size = self._num_components - - self.logit_layer = hk.Linear(logits_size, w_init=self.initializer) - - # Create two layers that outputs a location and a scale, respectively, for - # each dimension and each component. - self.loc_layer = hk.Linear( - self._num_dimensions * self._num_components, w_init=self.initializer - ) - if not self._identity_covariance: - self.scale_layer = hk.Linear( - self._num_dimensions * self._num_components, w_init=self.initializer - ) +from jax.nn import initializers + +from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import Action, Observation, Skill, StateDescriptor + +class GaussianMixture(nn.Module): + num_dimensions: int + num_components: int + reinterpreted_batch_ndims: Optional[int] = None + identity_covariance: bool = True + initializer: Optional[initializers.Initializer] = None + + @nn.compact def __call__(self, inputs: jnp.ndarray) -> tfp.distributions.Distribution: - # Compute logits, locs, and scales if necessary. - logits = self.logit_layer(inputs) - locs = self.loc_layer(inputs) + if self.initializer is None: + init = initializers.variance_scaling(1.0, "fan_in", "uniform") + else: + init = self.initializer - shape = [-1, self._num_components, self._num_dimensions] # [B, D, C] + logits = nn.Dense(self.num_components, kernel_init=init)(inputs) + locs = nn.Dense(self.num_dimensions * self.num_components, kernel_init=init)( + inputs + ) - # Reshape the mixture's location and scale parameters appropriately. + shape = [-1, self.num_components, self.num_dimensions] # [B, D, C] locs = locs.reshape(shape) - if not self._identity_covariance: - scales = self.scale_layer(inputs) + if not self.identity_covariance: + scales = nn.Dense( + self.num_dimensions * self.num_components, kernel_init=init + )(inputs) scales = scales.reshape(shape) else: scales = jnp.ones_like(locs) - # Create the mixture distribution components = tfp.distributions.MultivariateNormalDiag( loc=locs, scale_diag=scales ) mixture = tfp.distributions.Categorical(logits=logits) - distribution = tfp.distributions.MixtureSameFamily( + return tfp.distributions.MixtureSameFamily( mixture_distribution=mixture, components_distribution=components ) - return distribution - -class DynamicsNetwork(hk.Module): - """Dynamics network (used in DADS).""" +class DynamicsNetwork(nn.Module): + hidden_layer_sizes: Tuple[int, ...] 
+ output_size: int + omit_input_dynamics_dim: int = 2 + identity_covariance: bool = True + initializer: Optional[initializers.Initializer] = None - def __init__( - self, - hidden_layer_sizes: tuple, - output_size: int, - omit_input_dynamics_dim: int = 2, - name: Optional[str] = None, - identity_covariance: bool = True, - initializer: Optional[Initializer] = None, - ): - super().__init__(name=name) - if initializer is None: - initializer = VarianceScaling(1.0, "fan_in", "uniform") + @nn.compact + def __call__( + self, obs: StateDescriptor, skill: Skill, target: StateDescriptor + ) -> jnp.ndarray: + if self.initializer is None: + init = initializers.variance_scaling(1.0, "fan_in", "uniform") + else: + init = self.initializer - self.distribution = GaussianMixture( - output_size, + distribution = GaussianMixture( + self.output_size, num_components=4, reinterpreted_batch_ndims=None, - identity_covariance=identity_covariance, - initializer=initializer, - ) - self.network = hk.Sequential( - [ - hk.nets.MLP( - list(hidden_layer_sizes), - w_init=initializer, - activation=jax.nn.relu, - activate_final=True, - ), - ] + identity_covariance=self.identity_covariance, + initializer=init, ) - self._omit_input_dynamics_dim = omit_input_dynamics_dim - def __call__( - self, obs: StateDescriptor, skill: Skill, target: StateDescriptor - ) -> jnp.ndarray: - """Normalizes the observation, predicts a distribution probability conditioned - on (obs,skill) and returns the log_prob of the target. - """ - - obs = obs[:, self._omit_input_dynamics_dim :] + obs = obs[:, self.omit_input_dynamics_dim :] obs = jnp.concatenate((obs, skill), axis=1) - out = self.network(obs) - dist = self.distribution(out) + + x = MLP( + layer_sizes=self.hidden_layer_sizes, + kernel_init=init, + activation=nn.relu, + final_activation=nn.relu, + )(obs) + + dist = distribution(x) return dist.log_prob(target) +class Actor(nn.Module): + action_size: int + hidden_layer_sizes: Tuple[int, ...] + + @nn.compact + def __call__(self, obs: Observation) -> jnp.ndarray: + init = initializers.variance_scaling(1.0, "fan_in", "uniform") + + return MLP( + layer_sizes=self.hidden_layer_sizes + (2 * self.action_size,), + kernel_init=init, + activation=nn.relu, + )(obs) + + +class Critic(nn.Module): + hidden_layer_sizes: Tuple[int, ...] + + @nn.compact + def __call__(self, obs: Observation, action: Action) -> jnp.ndarray: + init = initializers.variance_scaling(1.0, "fan_in", "uniform") + input_ = jnp.concatenate([obs, action], axis=-1) + + value_1 = MLP( + layer_sizes=self.hidden_layer_sizes + (1,), + kernel_init=init, + activation=nn.relu, + )(input_) + + value_2 = MLP( + layer_sizes=self.hidden_layer_sizes + (1,), + kernel_init=init, + activation=nn.relu, + )(input_) + + return jnp.concatenate([value_1, value_2], axis=-1) + + def make_dads_networks( action_size: int, descriptor_size: int, @@ -130,78 +131,16 @@ def make_dads_networks( policy_hidden_layer_size: Tuple[int, ...] = (256, 256), omit_input_dynamics_dim: int = 2, identity_covariance: bool = True, - dynamics_initializer: Optional[Initializer] = None, -) -> Tuple[hk.Transformed, hk.Transformed, hk.Transformed]: - """Creates networks used in DADS. - - Args: - action_size: the size of the environment's action space - descriptor_size: the size of the environment's descriptor space (i.e. the - dimension of the dynamics network's input) - hidden_layer_sizes: the number of neurons for hidden layers. - Defaults to (256, 256). 
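The `dads_networks.py` rewrite above is the template for all the Haiku-to-Flax migrations in this patch: each `hk.Module` becomes a `flax.linen` dataclass built on the shared `MLP`, and the `hk.transform`/`hk.without_apply_rng` wrappers disappear. For callers, the visible change is how parameters are created and applied; a minimal sketch with a toy module (`ToyActor` and its sizes are illustrative, not part of QDax):

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class ToyActor(nn.Module):
    action_size: int

    @nn.compact
    def __call__(self, obs: jnp.ndarray) -> jnp.ndarray:
        x = nn.relu(nn.Dense(64)(obs))
        # Mean and log-std concatenated, hence 2 * action_size outputs.
        return nn.Dense(2 * self.action_size)(x)


policy = ToyActor(action_size=4)
fake_obs = jnp.zeros((1, 8))

# Haiku needed hk.transform(fn).init / .apply; a Flax module exposes
# init/apply directly on the (dataclass) instance.
params = policy.init(jax.random.PRNGKey(0), fake_obs)
actor_output = policy.apply(params, fake_obs)
```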
- omit_input_dynamics_dim: how many descriptors we omit when creating the input - of the dynamics networks. Defaults to 2. - identity_covariance: whether to fix the covariance matrix of the Gaussian models - to identity. Defaults to True. - dynamics_initializer: the initializer of the dynamics layers. Defaults to None. - - Returns: - the policy network - the critic network - the dynamics network - """ - - def _actor_fn(obs: Observation) -> jnp.ndarray: - network = hk.Sequential( - [ - hk.nets.MLP( - list(policy_hidden_layer_size) + [2 * action_size], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - return network(obs) - - def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: - network1 = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [1], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - network2 = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [1], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - input_ = jnp.concatenate([obs, action], axis=-1) - value1 = network1(input_) - value2 = network2(input_) - return jnp.concatenate([value1, value2], axis=-1) - - def _dynamics_fn( - obs: StateDescriptor, skill: Skill, target: StateDescriptor - ) -> jnp.ndarray: - dynamics_network = DynamicsNetwork( - critic_hidden_layer_size, - descriptor_size, - omit_input_dynamics_dim=omit_input_dynamics_dim, - identity_covariance=identity_covariance, - initializer=dynamics_initializer, - ) - return dynamics_network(obs, skill, target) - - policy = hk.without_apply_rng(hk.transform(_actor_fn)) - critic = hk.without_apply_rng(hk.transform(_critic_fn)) - dynamics = hk.without_apply_rng(hk.transform(_dynamics_fn)) + dynamics_initializer: Optional[initializers.Initializer] = None, +) -> Tuple[nn.Module, nn.Module, nn.Module]: + policy = Actor(action_size, policy_hidden_layer_size) + critic = Critic(critic_hidden_layer_size) + dynamics = DynamicsNetwork( + critic_hidden_layer_size, + descriptor_size, + omit_input_dynamics_dim=omit_input_dynamics_dim, + identity_covariance=identity_covariance, + initializer=dynamics_initializer, + ) return policy, critic, dynamics diff --git a/qdax/core/neuroevolution/networks/diayn_networks.py b/qdax/core/neuroevolution/networks/diayn_networks.py index c656cace..e292e131 100644 --- a/qdax/core/neuroevolution/networks/diayn_networks.py +++ b/qdax/core/neuroevolution/networks/diayn_networks.py @@ -1,10 +1,60 @@ from typing import Tuple -import haiku as hk -import jax +import flax.linen as nn import jax.numpy as jnp -from qdax.types import Action, Observation +from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import Action, Observation + + +class Actor(nn.Module): + action_size: int + hidden_layer_size: Tuple[int, ...] + + @nn.compact + def __call__(self, obs: Observation) -> jnp.ndarray: + return MLP( + layer_sizes=self.hidden_layer_size + (2 * self.action_size,), + kernel_init=nn.initializers.variance_scaling(1.0, "fan_in", "uniform"), + activation=nn.relu, + )(obs) + + +class Critic(nn.Module): + hidden_layer_size: Tuple[int, ...] 
+ + @nn.compact + def __call__(self, obs: Observation, action: Action) -> jnp.ndarray: + input_ = jnp.concatenate([obs, action], axis=-1) + + kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "uniform") + + value_1 = MLP( + layer_sizes=self.hidden_layer_size + (1,), + kernel_init=kernel_init, + activation=nn.relu, + )(input_) + + value_2 = MLP( + layer_sizes=self.hidden_layer_size + (1,), + kernel_init=kernel_init, + activation=nn.relu, + )(input_) + + return jnp.concatenate([value_1, value_2], axis=-1) + + +class Discriminator(nn.Module): + num_skills: int + hidden_layer_size: Tuple[int, ...] + + @nn.compact + def __call__(self, obs: Observation) -> jnp.ndarray: + return MLP( + layer_sizes=self.hidden_layer_size + (self.num_skills,), + kernel_init=nn.initializers.variance_scaling(1.0, "fan_in", "uniform"), + activation=nn.relu, + )(obs) def make_diayn_networks( @@ -12,71 +62,22 @@ def make_diayn_networks( num_skills: int, critic_hidden_layer_size: Tuple[int, ...] = (256, 256), policy_hidden_layer_size: Tuple[int, ...] = (256, 256), -) -> Tuple[hk.Transformed, hk.Transformed, hk.Transformed]: +) -> Tuple[nn.Module, nn.Module, nn.Module]: """Creates networks used in DIAYN. Args: action_size: the size of the environment's action space num_skills: the number of skills set - hidden_layer_sizes: the number of neurons for hidden layers. - Defaults to (256, 256). + critic_hidden_layer_size: the number of neurons for critic hidden layers. + policy_hidden_layer_size: the number of neurons for policy hidden layers. Returns: the policy network the critic network the discriminator network """ - - def _actor_fn(obs: Observation) -> jnp.ndarray: - network = hk.Sequential( - [ - hk.nets.MLP( - list(policy_hidden_layer_size) + [2 * action_size], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - return network(obs) - - def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: - network1 = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [1], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - network2 = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [1], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - input_ = jnp.concatenate([obs, action], axis=-1) - value1 = network1(input_) - value2 = network2(input_) - return jnp.concatenate([value1, value2], axis=-1) - - def _discriminator_fn(obs: Observation) -> jnp.ndarray: - network = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [num_skills], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - return network(obs) - - policy = hk.without_apply_rng(hk.transform(_actor_fn)) - critic = hk.without_apply_rng(hk.transform(_critic_fn)) - discriminator = hk.without_apply_rng(hk.transform(_discriminator_fn)) + policy = Actor(action_size, policy_hidden_layer_size) + critic = Critic(critic_hidden_layer_size) + discriminator = Discriminator(num_skills, critic_hidden_layer_size) return policy, critic, discriminator diff --git a/qdax/core/neuroevolution/networks/sac_networks.py b/qdax/core/neuroevolution/networks/sac_networks.py index dcadfaa2..a236afd4 100644 --- a/qdax/core/neuroevolution/networks/sac_networks.py +++ b/qdax/core/neuroevolution/networks/sac_networks.py @@ -1,66 +1,65 @@ from typing import Tuple -import haiku as hk -import jax +import flax.linen as 
nn import jax.numpy as jnp -from qdax.types import Action, Observation +from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import Action, Observation + + +class Actor(nn.Module): + action_size: int + hidden_layer_size: Tuple[int, ...] + + @nn.compact + def __call__(self, obs: Observation) -> jnp.ndarray: + return MLP( + layer_sizes=self.hidden_layer_size + (2 * self.action_size,), + kernel_init=nn.initializers.variance_scaling(1.0, "fan_in", "uniform"), + )(obs) + + +class Critic(nn.Module): + hidden_layer_size: Tuple[int, ...] + + @nn.compact + def __call__(self, obs: Observation, action: Action) -> jnp.ndarray: + input_ = jnp.concatenate([obs, action], axis=-1) + + kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "uniform") + + value_1 = MLP( + layer_sizes=self.hidden_layer_size + (1,), + kernel_init=kernel_init, + activation=nn.relu, + )(input_) + + value_2 = MLP( + layer_sizes=self.hidden_layer_size + (1,), + kernel_init=kernel_init, + activation=nn.relu, + )(input_) + + return jnp.concatenate([value_1, value_2], axis=-1) def make_sac_networks( action_size: int, critic_hidden_layer_size: Tuple[int, ...] = (256, 256), policy_hidden_layer_size: Tuple[int, ...] = (256, 256), -) -> Tuple[hk.Transformed, hk.Transformed]: +) -> Tuple[nn.Module, nn.Module]: """Creates networks used in SAC. Args: action_size: the size of the environment's action space - hidden_layer_sizes: the number of neurons for hidden layers. - Defaults to (256, 256). + critic_hidden_layer_size: the number of neurons for critic hidden layers. + policy_hidden_layer_size: the number of neurons for policy hidden layers. Returns: the policy network the critic network """ - - def _actor_fn(obs: Observation) -> jnp.ndarray: - network = hk.Sequential( - [ - hk.nets.MLP( - list(policy_hidden_layer_size) + [2 * action_size], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - return network(obs) - - def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: - network1 = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [1], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - network2 = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [1], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - input_ = jnp.concatenate([obs, action], axis=-1) - value1 = network1(input_) - value2 = network2(input_) - return jnp.concatenate([value1, value2], axis=-1) - - policy = hk.without_apply_rng(hk.transform(_actor_fn)) - critic = hk.without_apply_rng(hk.transform(_critic_fn)) + policy = Actor(action_size, policy_hidden_layer_size) + critic = Critic(critic_hidden_layer_size) return policy, critic diff --git a/qdax/core/neuroevolution/networks/seq2seq_networks.py b/qdax/core/neuroevolution/networks/seq2seq_networks.py index ea7618ba..3cb52a3e 100644 --- a/qdax/core/neuroevolution/networks/seq2seq_networks.py +++ b/qdax/core/neuroevolution/networks/seq2seq_networks.py @@ -7,7 +7,6 @@ Licensed under the Apache License, Version 2.0 (the "License") """ - import functools from typing import Any, Tuple diff --git a/qdax/core/neuroevolution/normalization_utils.py b/qdax/core/neuroevolution/normalization_utils.py index 63820921..0c98b29d 100644 --- a/qdax/core/neuroevolution/normalization_utils.py +++ b/qdax/core/neuroevolution/normalization_utils.py @@ -1,11 +1,10 @@ """Utilities functions to perform 
normalization (generally on observations in RL).""" - from typing import NamedTuple import jax.numpy as jnp -from qdax.types import Observation +from qdax.custom_types import Observation class RunningMeanStdState(NamedTuple): diff --git a/qdax/core/neuroevolution/sac_td3_utils.py b/qdax/core/neuroevolution/sac_td3_utils.py index 1c54511a..32bbe7a4 100644 --- a/qdax/core/neuroevolution/sac_td3_utils.py +++ b/qdax/core/neuroevolution/sac_td3_utils.py @@ -5,6 +5,7 @@ We are currently thinking about elegant ways to unify both in order to avoid code repetition. """ + # TODO: Uniformize with the functions in mdp_utils from functools import partial from typing import Any, Callable, Tuple @@ -14,7 +15,7 @@ from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition from qdax.core.neuroevolution.mdp_utils import TrainingState -from qdax.types import Metrics +from qdax.custom_types import Metrics @partial( @@ -75,7 +76,8 @@ def generate_unroll( ], ], ) -> Tuple[EnvState, TrainingState, Transition]: - """Generates an episode according to the agent's policy, returns the final state of the + """ + Generates an episode according to the agent's policy, returns the final state of the episode and the transitions of the episode. """ diff --git a/qdax/types.py b/qdax/custom_types.py similarity index 100% rename from qdax/types.py rename to qdax/custom_types.py diff --git a/qdax/environments/bd_extractors.py b/qdax/environments/bd_extractors.py index af1d51ba..918fbbfb 100644 --- a/qdax/environments/bd_extractors.py +++ b/qdax/environments/bd_extractors.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from qdax.core.neuroevolution.buffers.buffer import QDTransition -from qdax.types import Descriptor, Params +from qdax.custom_types import Descriptor, Params def get_final_xy_position(data: QDTransition, mask: jnp.ndarray) -> Descriptor: diff --git a/qdax/environments/exploration_wrappers.py b/qdax/environments/exploration_wrappers.py index ec32e7a2..c784b045 100644 --- a/qdax/environments/exploration_wrappers.py +++ b/qdax/environments/exploration_wrappers.py @@ -436,10 +436,8 @@ def step(self, state: State, action: jp.ndarray) -> State: # this line avoid this by increasing the threshold done = jp.where( state.qp.pos[0, 2] < 0.2, - x=jp.array(1, dtype=jp.float32), - y=jp.array(0, dtype=jp.float32), - ) - done = jp.where( - state.qp.pos[0, 2] > 5.0, x=jp.array(1, dtype=jp.float32), y=done + jp.array(1, dtype=jp.float32), + jp.array(0, dtype=jp.float32), ) + done = jp.where(state.qp.pos[0, 2] > 5.0, jp.array(1, dtype=jp.float32), done) return state.replace(obs=new_obs, reward=new_reward, done=done) # type: ignore diff --git a/qdax/environments/locomotion_wrappers.py b/qdax/environments/locomotion_wrappers.py index a727479e..982f5b69 100644 --- a/qdax/environments/locomotion_wrappers.py +++ b/qdax/environments/locomotion_wrappers.py @@ -260,7 +260,7 @@ def name(self) -> str: def reset(self, rng: jp.ndarray) -> State: state = self.env.reset(rng) state.info["state_descriptor"] = jnp.clip( - state.qp.pos[self._cog_idx][:2], a_min=self._minval, a_max=self._maxval + state.qp.pos[self._cog_idx][:2], min=self._minval, max=self._maxval ) return state @@ -268,7 +268,7 @@ def step(self, state: State, action: jp.ndarray) -> State: state = self.env.step(state, action) # get xy position of the center of gravity state.info["state_descriptor"] = jnp.clip( - state.qp.pos[self._cog_idx][:2], a_min=self._minval, a_max=self._maxval + state.qp.pos[self._cog_idx][:2], min=self._minval, max=self._maxval ) return state diff 
--git a/qdax/environments/pointmaze.py b/qdax/environments/pointmaze.py index b5f86ef5..78f7c575 100644 --- a/qdax/environments/pointmaze.py +++ b/qdax/environments/pointmaze.py @@ -150,8 +150,8 @@ def step(self, state: State, action: jp.ndarray) -> State: done = jp.where( jp.array(in_zone), - x=jp.array(1.0), - y=jp.array(0.0), + jp.array(1.0), + jp.array(0.0), ) new_obs = jp.array([x_pos, y_pos]) @@ -199,8 +199,8 @@ def _collision_lower_wall( y_axis_down_contact_condition_1 & y_axis_down_contact_condition_2 & x_axis_contact_condition, - x=jp.array(self.lower_wall_height_offset), - y=y_pos, + jp.array(self.lower_wall_height_offset), + y_pos, ) # From up - boolean style @@ -217,8 +217,8 @@ def _collision_lower_wall( & y_axis_up_contact_condition_2 & y_axis_up_contact_condition_3 & x_axis_contact_condition, - x=jp.array(self.lower_wall_height_offset + self.wallheight), - y=new_y_pos, + jp.array(self.lower_wall_height_offset + self.wallheight), + new_y_pos, ) return new_y_pos @@ -250,8 +250,8 @@ def _collision_upper_wall( y_axis_up_contact_condition_1 & y_axis_up_contact_condition_2 & x_axis_contact_condition, - x=jp.array(self.upper_wall_height_offset + self.wallheight), - y=y_pos, + jp.array(self.upper_wall_height_offset + self.wallheight), + y_pos, ) # From down - boolean style @@ -264,8 +264,8 @@ def _collision_upper_wall( & y_axis_down_contact_condition_2 & y_axis_down_contact_condition_3 & x_axis_contact_condition, - x=jp.array(self.upper_wall_height_offset), - y=new_y_pos, + jp.array(self.upper_wall_height_offset), + new_y_pos, ) return new_y_pos diff --git a/qdax/environments/wrappers.py b/qdax/environments/wrappers.py index cf0c3336..e5e40e4b 100644 --- a/qdax/environments/wrappers.py +++ b/qdax/environments/wrappers.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Optional import flax.struct import jax @@ -80,7 +80,10 @@ class ClipRewardWrapper(Wrapper): """ def __init__( - self, env: Env, clip_min: float = None, clip_max: float = None + self, + env: Env, + clip_min: Optional[float] = None, + clip_max: Optional[float] = None, ) -> None: super().__init__(env) self._clip_min = clip_min @@ -108,7 +111,10 @@ class AffineRewardWrapper(Wrapper): """ def __init__( - self, env: Env, clip_min: float = None, clip_max: float = None + self, + env: Env, + clip_min: Optional[float] = None, + clip_max: Optional[float] = None, ) -> None: super().__init__(env) self._clip_min = clip_min diff --git a/qdax/tasks/arm.py b/qdax/tasks/arm.py index 7122ed63..27782cf3 100644 --- a/qdax/tasks/arm.py +++ b/qdax/tasks/arm.py @@ -3,7 +3,7 @@ import jax import jax.numpy as jnp -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey def arm(params: Genotype) -> Tuple[Fitness, Descriptor]: diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index 1a588e52..07d37d59 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -12,7 +12,7 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition, Transition from qdax.core.neuroevolution.mdp_utils import generate_unroll, generate_unroll_actor_dc from qdax.core.neuroevolution.networks.networks import MLP -from qdax.types import ( +from qdax.custom_types import ( Descriptor, EnvState, ExtraScores, @@ -41,6 +41,7 @@ def make_policy_network_play_step_fn_brax( Returns: default_play_step_fn: A function that plays a step of the environment. 
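The `wrappers.py` hunk above fixes an implicit-`Optional` annotation: `clip_min: float = None` relies on a behaviour that modern type checkers reject, so the arguments are now spelled out as `Optional[float]`. A toy stand-in for the clipping logic (illustrative only, not the wrapper itself):

```python
from typing import Optional


def clip_reward(
    reward: float,
    clip_min: Optional[float] = None,
    clip_max: Optional[float] = None,
) -> float:
    """Clip a reward to [clip_min, clip_max], ignoring unset bounds."""
    if clip_min is not None:
        reward = max(reward, clip_min)
    if clip_max is not None:
        reward = min(reward, clip_max)
    return reward
```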
""" + # Define the function to play a step with the policy in the environment def default_play_step_fn( env_state: EnvState, diff --git a/qdax/tasks/hypervolume_functions.py b/qdax/tasks/hypervolume_functions.py index f4936574..340581ab 100644 --- a/qdax/tasks/hypervolume_functions.py +++ b/qdax/tasks/hypervolume_functions.py @@ -8,7 +8,7 @@ import jax import jax.numpy as jnp -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey def square(params: Genotype) -> Tuple[Fitness, Descriptor]: diff --git a/qdax/tasks/jumanji_envs.py b/qdax/tasks/jumanji_envs.py index 14455d66..5f861f0e 100644 --- a/qdax/tasks/jumanji_envs.py +++ b/qdax/tasks/jumanji_envs.py @@ -7,7 +7,7 @@ import jumanji from qdax.core.neuroevolution.buffers.buffer import QDTransition, Transition -from qdax.types import ( +from qdax.custom_types import ( Descriptor, ExtraScores, Fitness, @@ -41,6 +41,7 @@ def make_policy_network_play_step_fn_jumanji( Returns: default_play_step_fn: A function that plays a step of the environment. """ + # Define the function to play a step with the policy in the environment def default_play_step_fn( env_state: JumanjiState, @@ -67,7 +68,7 @@ def default_play_step_fn( obs=timestep.observation, next_obs=next_timestep.observation, rewards=next_timestep.reward, - dones=jnp.where(next_timestep.last(), x=jnp.array(1), y=jnp.array(0)), + dones=jnp.where(next_timestep.last(), jnp.array(1), jnp.array(0)), actions=action, truncations=jnp.array(0), state_desc=state_desc, diff --git a/qdax/tasks/qd_suite/archimedean_spiral.py b/qdax/tasks/qd_suite/archimedean_spiral.py index 5784f596..59108ae5 100644 --- a/qdax/tasks/qd_suite/archimedean_spiral.py +++ b/qdax/tasks/qd_suite/archimedean_spiral.py @@ -4,8 +4,8 @@ import jax.lax import jax.numpy as jnp +from qdax.custom_types import Descriptor, Fitness, Genotype from qdax.tasks.qd_suite.qd_suite_task import QDSuiteTask -from qdax.types import Descriptor, Fitness, Genotype class ParameterizationGenotype(Enum): diff --git a/qdax/tasks/qd_suite/deceptive_evolvability.py b/qdax/tasks/qd_suite/deceptive_evolvability.py index d5be0688..830ad523 100644 --- a/qdax/tasks/qd_suite/deceptive_evolvability.py +++ b/qdax/tasks/qd_suite/deceptive_evolvability.py @@ -3,8 +3,8 @@ import jax import jax.numpy as jnp +from qdax.custom_types import Descriptor, Fitness, Genotype from qdax.tasks.qd_suite.qd_suite_task import QDSuiteTask -from qdax.types import Descriptor, Fitness, Genotype def multivariate_normal( diff --git a/qdax/tasks/qd_suite/qd_suite_task.py b/qdax/tasks/qd_suite/qd_suite_task.py index 6f1af76f..0d79317f 100644 --- a/qdax/tasks/qd_suite/qd_suite_task.py +++ b/qdax/tasks/qd_suite/qd_suite_task.py @@ -4,7 +4,7 @@ import jax from jax import numpy as jnp -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey class QDSuiteTask(abc.ABC): diff --git a/qdax/tasks/qd_suite/ssf.py b/qdax/tasks/qd_suite/ssf.py index 547bee8d..601aa6ad 100644 --- a/qdax/tasks/qd_suite/ssf.py +++ b/qdax/tasks/qd_suite/ssf.py @@ -3,8 +3,8 @@ import jax import jax.numpy as jnp +from qdax.custom_types import Descriptor, Fitness, Genotype from qdax.tasks.qd_suite.qd_suite_task import QDSuiteTask -from qdax.types import Descriptor, Fitness, Genotype class SsfV0(QDSuiteTask): diff --git a/qdax/tasks/standard_functions.py b/qdax/tasks/standard_functions.py index 53d5b492..82b2f875 100644 --- 
a/qdax/tasks/standard_functions.py +++ b/qdax/tasks/standard_functions.py @@ -3,7 +3,7 @@ import jax import jax.numpy as jnp -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey def rastrigin(params: Genotype) -> Tuple[Fitness, Descriptor]: diff --git a/qdax/utils/metrics.py b/qdax/utils/metrics.py index 2b8355af..509c6d91 100644 --- a/qdax/utils/metrics.py +++ b/qdax/utils/metrics.py @@ -12,7 +12,7 @@ from qdax.core.containers.ga_repertoire import GARepertoire from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.containers.mome_repertoire import MOMERepertoire -from qdax.types import Metrics +from qdax.custom_types import Metrics from qdax.utils.pareto_front import compute_hypervolume diff --git a/qdax/utils/pareto_front.py b/qdax/utils/pareto_front.py index f9bd77ae..54fad3e6 100644 --- a/qdax/utils/pareto_front.py +++ b/qdax/utils/pareto_front.py @@ -4,7 +4,7 @@ import jax import jax.numpy as jnp -from qdax.types import Mask, ParetoFront +from qdax.custom_types import Mask, ParetoFront def compute_pareto_dominance( diff --git a/qdax/utils/plotting.py b/qdax/utils/plotting.py index 9b107c7e..7f0f086d 100644 --- a/qdax/utils/plotting.py +++ b/qdax/utils/plotting.py @@ -544,7 +544,7 @@ def _get_projection_in_1d( for all index i: x[i] < bases_tuple[i]. The vector and tuple of bases must have the same length. - For example if x=jnp.array([3, 1, 2]) and the bases are (5, 7, 3). + For example, if the vector is jnp.array([3, 1, 2]) and the bases are (5, 7, 3), then the projection is 3*(7*3) + 1*(3) + 2 = 47. Args: @@ -574,7 +574,7 @@ def _get_projection_in_2d( """Projects an integer vector into a pair of integers, (given tuple of bases to consider for conversion). - For example if x=jnp.array([3, 1, 2, 5]) and the bases are (5, 2, 3, 7). + For example, if the vector is jnp.array([3, 1, 2, 5]) and the bases are (5, 2, 3, 7), then the projection is obtained by: - projecting in 1D the point jnp.array([3, 2]) with the bases (5, 3) - projecting in 1D the point jnp.array([1, 5]) with the bases (2, 7) diff --git a/qdax/utils/sampling.py b/qdax/utils/sampling.py index bf5c1ae4..be1d336d 100644 --- a/qdax/utils/sampling.py +++ b/qdax/utils/sampling.py @@ -1,11 +1,12 @@ """Core components of the MAP-Elites-sampling algorithm.""" + from functools import partial from typing import Callable, Tuple import jax import jax.numpy as jnp -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey @jax.jit diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py index acb14a9b..bd9570a9 100644 --- a/qdax/utils/train_seq2seq.py +++ b/qdax/utils/train_seq2seq.py @@ -16,8 +16,8 @@ from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire from qdax.core.neuroevolution.networks.seq2seq_networks import Seq2seq +from qdax.custom_types import Params, RNGKey from qdax.environments.bd_extractors import AuroraExtraInfoNormalization -from qdax.types import Params, RNGKey Array = Any PRNGKey = Any @@ -132,7 +132,7 @@ def lstm_ae_train( std_obs = jnp.nanstd(repertoire.observations, axis=(0, 1)) # the std where they were NaNs was set to zero. But here we divide by the # std, so we replace the zeros by inf here.
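The jnp.where and jp.where call sites changed throughout this patch (pointmaze, jumanji, train_seq2seq, the tests below) all drop the x=/y= keyword arguments, which recent JAX releases warn about, and pass the two branches positionally. A small self-contained sketch of the normalization guard above, with a hypothetical observations array standing in for repertoire.observations:

import jax.numpy as jnp

# Hypothetical batch of observations; the second feature is constant.
observations = jnp.array([[[1.0, 2.0], [3.0, 2.0]]])

mean_obs = jnp.nanmean(observations, axis=(0, 1))
std_obs = jnp.nanstd(observations, axis=(0, 1))
# A std of 0 would break the division below; replacing it with inf maps
# constant dimensions to 0 instead. Branches are passed positionally.
std_obs = jnp.where(std_obs == 0, jnp.inf, std_obs)
normalized_obs = (observations - mean_obs) / std_obs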
- std_obs = jnp.where(std_obs == 0, x=jnp.inf, y=std_obs) + std_obs = jnp.where(std_obs == 0, jnp.inf, std_obs) # TODO: maybe we could just compute this data on the valid dataset diff --git a/requirements.txt b/requirements.txt index 50d7899c..f6dea29a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,20 +1,17 @@ ---find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - absl-py==1.0.0 -brax==0.9.2 -chex==0.1.83 -dm-haiku==0.0.10 -flax==0.7.4 +brax==0.10.4 +chex==0.1.86 +flax==0.8.5 gym==0.26.2 ipython -jax[cuda12_pip] +jax==0.4.28 +jaxlib==0.4.28 jumanji==0.3.1 jupyter -numpy==1.24.1 -optax==0.1.7 -protobuf==3.19.4 -scikit-learn==1.0.2 -scipy==1.8.0 -seaborn==0.11.2 -tensorflow-probability==0.19.0 -typing-extensions==4.3.0 +numpy==1.26.4 +optax==0.1.9 +protobuf==3.19.5 +scikit-learn==1.5.1 +scipy==1.10.1 +tensorflow-probability==0.24.0 +typing-extensions==4.12.2 diff --git a/setup.py b/setup.py index 0065bf18..cd7d2b13 100644 --- a/setup.py +++ b/setup.py @@ -22,19 +22,23 @@ long_description_content_type="text/markdown", install_requires=[ "absl-py>=1.0.0", - "jax>=0.4.16", - "jaxlib>=0.4.16", # necessary to build the doc atm - "jinja2<3.1.0", + "brax>=0.10.4", + "chex>=0.1.86", + "flax>=0.8.5", + "gym>=0.26.2", + "jax>=0.4.28", + "jaxlib>=0.4.28", # necessary to build the doc atm + "jinja2>=3.1.4", "jumanji>=0.3.1", - "flax>=0.7.4", - "chex>=0.1.83", - "brax>=0.9.2", - "gym>=0.23.1", - "numpy>=1.22.3", - "optax>=0.1.7", - "scikit-learn>=1.0.2", - "scipy>=1.8.0", + "numpy>=1.26.4", + "optax>=0.1.9", + "scikit-learn>=1.5.1", + "scipy>=1.10.1", + "tensorflow-probability>=0.24.0", ], + extras_require={ + "cuda12": ["jax[cuda12]>=0.4.28"], + }, dependency_links=[ "https://storage.googleapis.com/jax-releases/jax_releases.html", ], @@ -46,7 +50,9 @@ "License :: OSI Approved :: MIT License", "Operating System :: POSIX :: Linux", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], ) diff --git a/tests/baselines_test/cmame_test.py b/tests/baselines_test/cmame_test.py index c86bd622..2dc6fa10 100644 --- a/tests/baselines_test/cmame_test.py +++ b/tests/baselines_test/cmame_test.py @@ -16,7 +16,7 @@ from qdax.core.emitters.cma_pool_emitter import CMAPoolEmitter from qdax.core.emitters.cma_rnd_emitter import CMARndEmitter from qdax.core.map_elites import MAPElites -from qdax.types import Descriptor, ExtraScores, Fitness, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, RNGKey @pytest.mark.parametrize( @@ -25,7 +25,7 @@ ) def test_cma_me(emitter_type: Type[CMAEmitter]) -> None: - num_iterations = 1000 + num_iterations = 2000 num_dimensions = 20 grid_shape = (50, 50) batch_size = 36 @@ -43,7 +43,7 @@ def sphere_scoring(x: jnp.ndarray) -> jnp.ndarray: def clip(x: jnp.ndarray) -> jnp.ndarray: in_bound = (x <= maxval) * (x >= minval) - return jnp.where(condition=in_bound, x=x, y=(maxval / x)) + return jnp.where(in_bound, x, (maxval / x)) def _behavior_descriptor_1(x: jnp.ndarray) -> jnp.ndarray: return jnp.sum(clip(x[: x.shape[-1] // 2])) @@ -113,7 +113,11 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: initial_population, centroids, random_key ) - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( 
map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/cmamega_test.py b/tests/baselines_test/cmamega_test.py index fdd9330b..5bfdfd58 100644 --- a/tests/baselines_test/cmamega_test.py +++ b/tests/baselines_test/cmamega_test.py @@ -12,7 +12,7 @@ ) from qdax.core.emitters.cma_mega_emitter import CMAMEGAEmitter from qdax.core.map_elites import MAPElites -from qdax.types import Descriptor, ExtraScores, Fitness, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, RNGKey def test_cma_mega() -> None: @@ -125,7 +125,11 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: initial_population, centroids, random_key ) - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/dads_smerl_test.py b/tests/baselines_test/dads_smerl_test.py index 1e782f2a..2a8d3d1f 100644 --- a/tests/baselines_test/dads_smerl_test.py +++ b/tests/baselines_test/dads_smerl_test.py @@ -1,4 +1,5 @@ """Testing script for the algorithm DADS""" + from functools import partial from typing import Any, Tuple diff --git a/tests/baselines_test/dads_test.py b/tests/baselines_test/dads_test.py index 0b9af46e..77094ffd 100644 --- a/tests/baselines_test/dads_test.py +++ b/tests/baselines_test/dads_test.py @@ -1,5 +1,6 @@ """Training script for the algorithm DADS, should be launched with hydra. e.g. python train_dads.py config=dads_ant""" + from functools import partial from typing import Any, Tuple diff --git a/tests/baselines_test/ga_test.py b/tests/baselines_test/ga_test.py index 5f9ec5f7..a1eb1b51 100644 --- a/tests/baselines_test/ga_test.py +++ b/tests/baselines_test/ga_test.py @@ -15,7 +15,7 @@ polynomial_mutation, ) from qdax.core.emitters.standard_emitters import MixingEmitter -from qdax.types import ExtraScores, Fitness, RNGKey +from qdax.custom_types import ExtraScores, Fitness, RNGKey from qdax.utils.metrics import default_ga_metrics @@ -32,11 +32,11 @@ def test_ga(algorithm_class: Type[GeneticAlgorithm]) -> None: batch_size = 100 genotype_dim = 6 lag = 2.2 - base_lag = 0 + base_lag = 0.0 num_neighbours = 1 def rastrigin_scorer( - genotypes: jnp.ndarray, base_lag: int, lag: int + genotypes: jnp.ndarray, base_lag: float, lag: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Rastrigin Scorer with first two dimensions as descriptors @@ -119,7 +119,11 @@ def scoring_fn( ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( algo_instance.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/me_pbt_sac_test.py b/tests/baselines_test/me_pbt_sac_test.py index 98a5b960..5058bad6 100644 --- a/tests/baselines_test/me_pbt_sac_test.py +++ b/tests/baselines_test/me_pbt_sac_test.py @@ -119,10 +119,10 @@ def test_me_pbt_sac() -> None: def scoring_function(genotypes, random_key): # type: ignore population_size = jax.tree_util.tree_leaves(genotypes)[0].shape[0] - first_states = jax.tree_map( + first_states = jax.tree_util.tree_map( lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states ) - first_states = jax.tree_map( + first_states = jax.tree_util.tree_map( lambda x: jnp.repeat(x, population_size, axis=0), first_states ) population_returns, population_bds, _, _ = eval_policy(genotypes, first_states) @@ -186,7 
+186,7 @@ def scoring_function(genotypes, random_key): # type: ignore initial_metrics = jax.pmap(metrics_function, axis_name="p", devices=devices)( repertoire ) - initial_metrics_cpu = jax.tree_map( + initial_metrics_cpu = jax.tree_util.tree_map( lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], initial_metrics ) initial_qd_score = initial_metrics_cpu["qd_score"] @@ -196,7 +196,7 @@ def scoring_function(genotypes, random_key): # type: ignore repertoire, emitter_state, keys, metrics = update_fn( repertoire, emitter_state, keys ) - metrics_cpu = jax.tree_map( + metrics_cpu = jax.tree_util.tree_map( lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], metrics ) diff --git a/tests/baselines_test/me_pbt_td3_test.py b/tests/baselines_test/me_pbt_td3_test.py index fc2e89b0..39c3e942 100644 --- a/tests/baselines_test/me_pbt_td3_test.py +++ b/tests/baselines_test/me_pbt_td3_test.py @@ -117,10 +117,10 @@ def test_me_pbt_td3() -> None: def scoring_function(genotypes, random_key): # type: ignore population_size = jax.tree_util.tree_leaves(genotypes)[0].shape[0] - first_states = jax.tree_map( + first_states = jax.tree_util.tree_map( lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states ) - first_states = jax.tree_map( + first_states = jax.tree_util.tree_map( lambda x: jnp.repeat(x, population_size, axis=0), first_states ) population_returns, population_bds, _, _ = eval_policy(genotypes, first_states) @@ -184,7 +184,7 @@ def scoring_function(genotypes, random_key): # type: ignore initial_metrics = jax.pmap(metrics_function, axis_name="p", devices=devices)( repertoire ) - initial_metrics_cpu = jax.tree_map( + initial_metrics_cpu = jax.tree_util.tree_map( lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], initial_metrics ) initial_qd_score = initial_metrics_cpu["qd_score"] @@ -194,7 +194,7 @@ def scoring_function(genotypes, random_key): # type: ignore repertoire, emitter_state, keys, metrics = update_fn( repertoire, emitter_state, keys ) - metrics_cpu = jax.tree_map( + metrics_cpu = jax.tree_util.tree_map( lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], metrics ) diff --git a/tests/baselines_test/mees_test.py b/tests/baselines_test/mees_test.py index 3f3314fd..d1913b02 100644 --- a/tests/baselines_test/mees_test.py +++ b/tests/baselines_test/mees_test.py @@ -14,8 +14,8 @@ from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import EnvState, Params, RNGKey from qdax.tasks.brax_envs import scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey def test_mees() -> None: @@ -185,7 +185,11 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: return (repertoire, emitter_state, random_key), metrics # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( update_scan_fn, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/omgmega_test.py b/tests/baselines_test/omgmega_test.py index 7b0f0639..ad51c7ae 100644 --- a/tests/baselines_test/omgmega_test.py +++ b/tests/baselines_test/omgmega_test.py @@ -11,7 +11,7 @@ ) from qdax.core.emitters.omg_mega_emitter import OMGMEGAEmitter from qdax.core.map_elites import MAPElites -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey def test_omg_mega() -> None: 
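The jax.tree_map to jax.tree_util.tree_map renames in these test hunks are mechanical: both functions take the same arguments, and jax.tree_map only emits a DeprecationWarning on the JAX version pinned by this patch. A short sketch of the device-transfer pattern the PBT tests use, assuming a toy metrics pytree:

import jax
import jax.numpy as jnp

metrics = {"qd_score": jnp.ones((4,)), "coverage": jnp.zeros((4,))}
# Same mapping as before the rename; only the function's import path changed.
metrics_cpu = jax.tree_util.tree_map(
    lambda x: jax.device_put(x, jax.devices("cpu")[0]), metrics
)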
@@ -113,7 +113,11 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: initial_population, centroids, random_key ) - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/pbt_sac_test.py b/tests/baselines_test/pbt_sac_test.py index c83f277c..db7dc69e 100644 --- a/tests/baselines_test/pbt_sac_test.py +++ b/tests/baselines_test/pbt_sac_test.py @@ -59,7 +59,7 @@ def init_environments(random_key): # type: ignore eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key) reshape_fn = jax.jit( - lambda tree: jax.tree_map( + lambda tree: jax.tree_util.tree_map( lambda x: jnp.reshape( x, ( diff --git a/tests/baselines_test/pbt_td3_test.py b/tests/baselines_test/pbt_td3_test.py index 9e6134c9..0be68277 100644 --- a/tests/baselines_test/pbt_td3_test.py +++ b/tests/baselines_test/pbt_td3_test.py @@ -57,7 +57,7 @@ def init_environments(random_key): # type: ignore eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key) reshape_fn = jax.jit( - lambda tree: jax.tree_map( + lambda tree: jax.tree_util.tree_map( lambda x: jnp.reshape( x, ( diff --git a/tests/baselines_test/pgame_test.py b/tests/baselines_test/pgame_test.py index 9cb1b3fb..0490a481 100644 --- a/tests/baselines_test/pgame_test.py +++ b/tests/baselines_test/pgame_test.py @@ -15,8 +15,8 @@ from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import EnvState, Params, RNGKey from qdax.tasks.brax_envs import scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey def test_pgame() -> None: @@ -189,7 +189,11 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: return (repertoire, emitter_state, random_key), metrics # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( update_scan_fn, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/qdpg_test.py b/tests/baselines_test/qdpg_test.py index 1889f197..704416a4 100644 --- a/tests/baselines_test/qdpg_test.py +++ b/tests/baselines_test/qdpg_test.py @@ -17,8 +17,8 @@ from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import EnvState, Params, RNGKey from qdax.tasks.brax_envs import scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey def test_qdpg() -> None: @@ -239,7 +239,11 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: return (repertoire, emitter_state, random_key), metrics # Run the algorithm - (repertoire, emitter_state, random_key,), _metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), _metrics = jax.lax.scan( update_scan_fn, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/sac_test.py b/tests/baselines_test/sac_test.py index c667aa66..8c26b510 100644 --- a/tests/baselines_test/sac_test.py +++ b/tests/baselines_test/sac_test.py @@ -10,7 +10,7 @@ from qdax.baselines.sac import SAC, SacConfig, TrainingState from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn, 
warmstart_buffer -from qdax.types import EnvState +from qdax.custom_types import EnvState def test_sac() -> None: diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index 2b238237..4bbb9d82 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -11,6 +11,7 @@ from qdax import environments from qdax.core.aurora import AURORA from qdax.core.neuroevolution.buffers.buffer import QDTransition +from qdax.custom_types import Observation from qdax.environments.bd_extractors import ( AuroraExtraInfoNormalization, get_aurora_encoding, @@ -19,7 +20,6 @@ create_default_brax_task_components, get_aurora_scoring_fn, ) -from qdax.types import Observation from qdax.utils import train_seq2seq from qdax.utils.metrics import default_qd_metrics from tests.core_test.map_elites_test import get_mixing_emitter diff --git a/tests/core_test/containers_test/mapelites_repertoire_test.py b/tests/core_test/containers_test/mapelites_repertoire_test.py index 55e6ed11..5c0d9d75 100644 --- a/tests/core_test/containers_test/mapelites_repertoire_test.py +++ b/tests/core_test/containers_test/mapelites_repertoire_test.py @@ -5,7 +5,7 @@ MapElitesRepertoire, compute_euclidean_centroids, ) -from qdax.types import ExtraScores +from qdax.custom_types import ExtraScores def test_mapelites_repertoire() -> None: diff --git a/tests/core_test/containers_test/mels_repertoire_test.py b/tests/core_test/containers_test/mels_repertoire_test.py index 2fb1bd76..0b854b32 100644 --- a/tests/core_test/containers_test/mels_repertoire_test.py +++ b/tests/core_test/containers_test/mels_repertoire_test.py @@ -2,7 +2,7 @@ import pytest from qdax.core.containers.mels_repertoire import MELSRepertoire -from qdax.types import ExtraScores +from qdax.custom_types import ExtraScores def test_add_to_mels_repertoire() -> None: diff --git a/tests/core_test/emitters_test/multi_emitter_test.py b/tests/core_test/emitters_test/multi_emitter_test.py index 93b3e081..ebf712d5 100644 --- a/tests/core_test/emitters_test/multi_emitter_test.py +++ b/tests/core_test/emitters_test/multi_emitter_test.py @@ -96,7 +96,11 @@ def test_multi_emitter() -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/core_test/map_elites_test.py b/tests/core_test/map_elites_test.py index b532aa65..c89ce04f 100644 --- a/tests/core_test/map_elites_test.py +++ b/tests/core_test/map_elites_test.py @@ -14,8 +14,8 @@ from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import EnvState, Params, RNGKey from qdax.tasks.brax_envs import scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey from qdax.utils.metrics import default_qd_metrics @@ -143,7 +143,11 @@ def play_step_fn( ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/core_test/mels_test.py b/tests/core_test/mels_test.py index 21f90517..66bcc05f 100644 --- a/tests/core_test/mels_test.py +++ b/tests/core_test/mels_test.py @@ -15,8 +15,8 @@ from qdax.core.mels import MELS from qdax.core.neuroevolution.buffers.buffer import 
QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import EnvState, Params, RNGKey from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey @pytest.mark.parametrize( @@ -142,7 +142,11 @@ def metrics_fn(repertoire: MELSRepertoire) -> Dict: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( mels.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/core_test/mome_test.py b/tests/core_test/mome_test.py index 103f9489..746b94a0 100644 --- a/tests/core_test/mome_test.py +++ b/tests/core_test/mome_test.py @@ -14,7 +14,7 @@ ) from qdax.core.emitters.standard_emitters import MixingEmitter from qdax.core.mome import MOME -from qdax.types import Descriptor, ExtraScores, Fitness, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, RNGKey from qdax.utils.metrics import default_moqd_metrics @@ -36,10 +36,10 @@ def test_mome(num_descriptors: int) -> None: crossover_percentage = 1.0 batch_size = 80 lag = 2.2 - base_lag = 0 + base_lag = 0.0 def rastrigin_scorer( - genotypes: jnp.ndarray, base_lag: int, lag: int + genotypes: jnp.ndarray, base_lag: float, lag: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Rastrigin Scorer with first two dimensions as descriptors @@ -131,7 +131,11 @@ def scoring_fn( ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( mome.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py b/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py index e0e298c1..06e25fcd 100644 --- a/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py +++ b/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py @@ -42,7 +42,9 @@ def test_insert_batch() -> None: buffer_size=buffer_size, transition=dummy_transition ) - simple_transition = jax.tree_map(lambda x: x.repeat(3, axis=0), dummy_transition) + simple_transition = jax.tree_util.tree_map( + lambda x: x.repeat(3, axis=0), dummy_transition + ) simple_transition = simple_transition.replace(rewards=jnp.arange(3)) data = QDTransition.from_flatten(replay_buffer.data, dummy_transition) pytest.assume( @@ -83,7 +85,9 @@ def test_sample() -> None: buffer_size=buffer_size, transition=dummy_transition ) - simple_transition = jax.tree_map(lambda x: x.repeat(3, axis=0), dummy_transition) + simple_transition = jax.tree_util.tree_map( + lambda x: x.repeat(3, axis=0), dummy_transition + ) simple_transition = simple_transition.replace(rewards=jnp.arange(3)) replay_buffer = replay_buffer.insert(simple_transition) @@ -91,6 +95,6 @@ def test_sample() -> None: samples, random_key = replay_buffer.sample(random_key, 3) - samples_shapes = jax.tree_map(lambda x: x.shape, samples) - transition_shapes = jax.tree_map(lambda x: x.shape, simple_transition) + samples_shapes = jax.tree_util.tree_map(lambda x: x.shape, samples) + transition_shapes = jax.tree_util.tree_map(lambda x: x.shape, simple_transition) pytest.assume((samples_shapes == transition_shapes)) diff --git a/tests/core_test/neuroevolution_test/buffers_test/trajectory_buffer_test.py b/tests/core_test/neuroevolution_test/buffers_test/trajectory_buffer_test.py index 75f68b40..12ea0874 100644 --- 
a/tests/core_test/neuroevolution_test/buffers_test/trajectory_buffer_test.py +++ b/tests/core_test/neuroevolution_test/buffers_test/trajectory_buffer_test.py @@ -202,8 +202,8 @@ def test_trajectory_buffer_insert() -> None: multy_step_episodic_data, equal_nan=True, ), - "Episodic data when transitions are added sequentially is not consistent to when\ - theya are added as batch.", + "Episodic data when transitions are added sequentially is not consistent to \ + when they are added as batch.", ) pytest.assume( diff --git a/tests/default_tasks_test/arm_test.py b/tests/default_tasks_test/arm_test.py index e71e761c..98361b23 100644 --- a/tests/default_tasks_test/arm_test.py +++ b/tests/default_tasks_test/arm_test.py @@ -96,7 +96,11 @@ def test_arm(task_name: str, batch_size: int) -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/default_tasks_test/brax_task_test.py b/tests/default_tasks_test/brax_task_test.py index f8c63259..c12518fb 100644 --- a/tests/default_tasks_test/brax_task_test.py +++ b/tests/default_tasks_test/brax_task_test.py @@ -84,7 +84,11 @@ def test_map_elites(env_name: str, batch_size: int, is_task_reset_based: bool) - ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/default_tasks_test/hypervolume_functions_test.py b/tests/default_tasks_test/hypervolume_functions_test.py index a390f709..3d619353 100644 --- a/tests/default_tasks_test/hypervolume_functions_test.py +++ b/tests/default_tasks_test/hypervolume_functions_test.py @@ -102,7 +102,11 @@ def test_standard_functions(task_name: str, batch_size: int) -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/default_tasks_test/jumanji_envs_test.py b/tests/default_tasks_test/jumanji_envs_test.py index eed90127..636a02cf 100644 --- a/tests/default_tasks_test/jumanji_envs_test.py +++ b/tests/default_tasks_test/jumanji_envs_test.py @@ -11,11 +11,11 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import Descriptor, Observation from qdax.tasks.jumanji_envs import ( jumanji_scoring_function, make_policy_network_play_step_fn_jumanji, ) -from qdax.types import Descriptor, Observation def test_jumanji_utils() -> None: @@ -53,7 +53,13 @@ def test_jumanji_utils() -> None: def observation_processing( observation: jumanji.environments.routing.snake.types.Observation, ) -> Observation: - network_input = jnp.ravel(observation.grid) + network_input = jnp.concatenate( + [ + jnp.ravel(observation.grid), + jnp.array([observation.step_count]), + observation.action_mask.ravel(), + ] + ) return network_input play_step_fn = make_policy_network_play_step_fn_jumanji( @@ -67,7 +73,12 @@ def observation_processing( keys = jax.random.split(subkey, num=batch_size) # compute observation size from observation spec - observation_size = np.prod(np.array(env.observation_spec().grid.shape)) + obs_spec = env.observation_spec() + observation_size = 
int( + np.prod(obs_spec.grid.shape) + + np.prod(obs_spec.step_count.shape) + + np.prod(obs_spec.action_mask.shape) + ) fake_batch = jnp.zeros(shape=(batch_size, observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) diff --git a/tests/default_tasks_test/qd_suite_test.py b/tests/default_tasks_test/qd_suite_test.py index a0542e9b..46f6ce9b 100644 --- a/tests/default_tasks_test/qd_suite_test.py +++ b/tests/default_tasks_test/qd_suite_test.py @@ -117,7 +117,11 @@ def test_qd_suite(task_name: str, batch_size: int) -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/default_tasks_test/standard_functions_test.py b/tests/default_tasks_test/standard_functions_test.py index 7b310389..87913364 100644 --- a/tests/default_tasks_test/standard_functions_test.py +++ b/tests/default_tasks_test/standard_functions_test.py @@ -92,7 +92,11 @@ def test_standard_functions(task_name: str, batch_size: int) -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/environments_test/pointmaze_test.py b/tests/environments_test/pointmaze_test.py index a13f41cc..ecc97864 100644 --- a/tests/environments_test/pointmaze_test.py +++ b/tests/environments_test/pointmaze_test.py @@ -6,8 +6,8 @@ from brax.v1.envs import Env import qdax +from qdax.custom_types import EnvState from qdax.environments.pointmaze import PointMaze -from qdax.types import EnvState def test_pointmaze() -> None: diff --git a/tests/utils_test/sampling_test.py b/tests/utils_test/sampling_test.py index 6ce6cbe9..8d19379e 100644 --- a/tests/utils_test/sampling_test.py +++ b/tests/utils_test/sampling_test.py @@ -8,8 +8,8 @@ from qdax import environments from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import EnvState, Params, RNGKey from qdax.tasks.brax_envs import scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey from qdax.utils.sampling import ( average, closest, diff --git a/tool.Dockerfile b/tool.Dockerfile index 10b15b02..26a68236 100644 --- a/tool.Dockerfile +++ b/tool.Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.9.18-slim +FROM python:3.10.14-slim ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1 ENV PIPENV_VENV_IN_PROJECT=true PIP_NO_CACHE_DIR=false PIP_DISABLE_PIP_VERSION_CHECK=1
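Most of the test reformattings above only re-wrap the carry tuple passed to jax.lax.scan; the update loop itself is unchanged. For reference, a minimal self-contained sketch of that loop, with a dummy scan_update standing in for the real MAPElites.scan_update:

import jax
import jax.numpy as jnp

def scan_update(carry, _):
    # Dummy stand-in: advance a counter and emit one metric per step.
    repertoire, step = carry
    return (repertoire, step + 1), {"qd_score": jnp.asarray(step, jnp.float32)}

num_iterations = 5
(repertoire, step), metrics = jax.lax.scan(
    scan_update,
    (jnp.zeros(3), 0),
    (),
    length=num_iterations,
)
# metrics["qd_score"] has shape (num_iterations,).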