From 26c010e1698fb0be773bd7ef9b0ec78f89380ca3 Mon Sep 17 00:00:00 2001 From: Bryan Lim <46229436+limbryan@users.noreply.github.com> Date: Fri, 28 Apr 2023 16:35:23 +0100 Subject: [PATCH 01/16] chore: remove singularity tools and update haiku version (#147) * update haiku version * update jaxlib and optax dependencies * remove singularity scripts and remove singularity from documentation --- .readthedocs.yaml | 4 +- README.md | 2 +- docs/installation.md | 53 ----- environment.yaml | 1 - requirements.txt | 4 +- setup.py | 1 + singularity/build_final_image | 347 ------------------------------- singularity/build_final_image.py | 347 ------------------------------- singularity/singularity.def | 41 ---- singularity/start_container | 184 ---------------- singularity/start_container.py | 184 ---------------- 11 files changed, 7 insertions(+), 1161 deletions(-) delete mode 100755 singularity/build_final_image delete mode 100755 singularity/build_final_image.py delete mode 100644 singularity/singularity.def delete mode 100755 singularity/start_container delete mode 100755 singularity/start_container.py diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 2ef47062..7eec359d 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -18,7 +18,7 @@ mkdocs: # Optionally declare the Python requirements required to build your docs python: install: - - requirements: requirements.txt - - requirements: docs/requirements.txt - method: pip path: . + - requirements: requirements.txt + - requirements: docs/requirements.txt diff --git a/README.md b/README.md index 8b321c6b..2477348d 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ pip install git+https://github.com/adaptive-intelligent-robotics/QDax.git@main ``` Installing QDax via ```pip``` installs a CPU-only version of JAX by default. To use QDax with NVidia GPUs, you must first install [CUDA, CuDNN, and JAX with GPU support](https://github.com/google/jax#installation). 
-However, we also provide and recommend using either Docker, Singularity or conda environments to use the repository which by default provides GPU support. Detailed steps to do so are available in the [documentation](https://qdax.readthedocs.io/en/latest/installation/). +However, we also provide and recommend using either Docker or conda environments to use the repository which by default provides GPU support. Detailed steps to do so are available in the [documentation](https://qdax.readthedocs.io/en/latest/installation/). ## Basic API Usage For a full and interactive example to see how QDax works, we recommend starting with the tutorial-style [Colab notebook](./examples/mapelites.ipynb). It is an example of the MAP-Elites algorithm used to evolve a population of controllers on a chosen Brax environment (Walker by default). diff --git a/docs/installation.md b/docs/installation.md index 319a88c7..2c273da1 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -56,59 +56,6 @@ sudo docker run --rm -it -v $QDAX_PATH:/app instadeep/qdax:$USER /bin/bash sudo docker run --rm -it --gpus '"device=0,1"' -v $QDAX_PATH:/app instadeep/qdax:$USER /bin/bash ``` - - -### Using singularity - -First, follow these initial steps: - -1. If it is not already done, install Singularity, following [these instructions](https://docs.sylabs.io/guides/3.0/user-guide/installation.html). - -2. Clone `qdax` -```zsh -git clone git@github.com:adaptive-intelligent-robotics/QDax.git -``` - -3. Enter the singularity folder -```zsh -cd qdax/singularity/ -``` - -You can build two distinct types of images with singularity: "final images" or "sandbox images". -A final image is a single file with the `.sif` extension, it is immutable. -On the contrary, a sandbox image is not a file but a folder, it allows you to develop inside the singularity container to test your code while writing it. 
- -To build a final image, execute the `build_final_image` script: -```zsh -./build_final_image -``` -It will generate a `.sif` file: `[image_name].sif`. If you execute this file using singularity, as follows, it will run the default application of the image, defined in the `singularity.def` file that you can find in the `singularity` folder as well. At the moment, this is just running the MAP-Elites algorithm on a simple task. -```zsh -singularity run --nv [image_name].sif -``` - -!!! warning "Using GPU" - The `--nv` flag of the `singularity run` command allows the container to use the GPU, it is thus important to use it for QDax. - - -To build a sandbox image, execute the `start_container` script: -```zsh -./start_container -n -``` - -!!! warning "Using GPU" - The `-n` flag of the `start_container` command allow the container to use the GPU, it is thus important to use it for QDax. - -This command will generate a sandbox container `qdax.sif/` and enter it. If you execute this command again later, it will not generate a new container but enter directly the existing one. -Once inside the sandbox container, enter the qdax development folder: -```zsh -cd /git/exp/qdax -``` -This folder is linked with the `qdax` folder on your machine, meaning that any modification inside the container will directly modify the files on your machine. You can now use this development environment to develop your own QDax-based code. - - - - ### Using conda 1. 
If it is not already done, install conda from [here](https://docs.conda.io/projects/conda/en/latest/user-guide/install/linux.html) diff --git a/environment.yaml b/environment.yaml index 78058b9d..e46c034e 100644 --- a/environment.yaml +++ b/environment.yaml @@ -8,6 +8,5 @@ dependencies: - conda>=4.9.2 - pip: - --find-links https://storage.googleapis.com/jax-releases/jax_releases.html - - jaxlib==0.3.15 - -r requirements.txt - -r requirements-dev.txt diff --git a/requirements.txt b/requirements.txt index b97297fa..718d6213 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,16 @@ absl-py==1.0.0 brax==0.0.15 chex==0.1.5 -dm-haiku==0.0.5 +dm-haiku==0.0.9 flax==0.6.0 gym==0.23.1 ipython jax==0.3.17 +jaxlib==0.3.15 jumanji==0.1.3 jupyter numpy==1.22.3 +optax==0.1.4 protobuf==3.19.4 scikit-learn==1.0.2 scipy==1.8.0 diff --git a/setup.py b/setup.py index 2e50e0ea..a71f3174 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ "brax>=0.0.15", "gym>=0.23.1", "numpy>=1.22.3", + "optax>=0.1, <0.1.5", "scikit-learn>=1.0.2", "scipy>=1.8.0", ], diff --git a/singularity/build_final_image b/singularity/build_final_image deleted file mode 100755 index 99911407..00000000 --- a/singularity/build_final_image +++ /dev/null @@ -1,347 +0,0 @@ -#!/usr/bin/env python3 -import argparse -import os -import subprocess -import sys -import time -from typing import Tuple, Union - -SINGULARITY_DEFINITION_FILE_NAME = "singularity.def" - - -class BColors: - HEADER = "\033[95m" - OKBLUE = "\033[94m" - OKCYAN = "\033[96m" - OKGREEN = "\033[92m" - WARNING = "\033[93m" - FAIL = "\033[91m" - ENDC = "\033[0m" - BOLD = "\033[1m" - UNDERLINE = "\033[4m" - - -def error_print(message: str) -> None: - print(f"{BColors.FAIL}{message}{BColors.ENDC}", file=sys.stderr) - - -def bold(message: str) -> str: - return f"{BColors.BOLD}{message}{BColors.ENDC}" - - -def load_singularity_file(path_to_singularity_definition_file: str) -> str: - try: - # read input file - fin = 
open(path_to_singularity_definition_file, "rt") - - except IOError: - error_print(f"ERROR, {path_to_singularity_definition_file} file not found!") - - finally: - data = fin.read() - # close the input file - fin.close() - return data - - -def get_repo_address() -> str: - # Search projects - command = os.popen("git config --local remote.origin.url") - url = command.read()[:-1] - - # if it is using the ssh protocal, we need to convert it into an address - # compatible with https as the key is not available inside the container - if url.startswith("git@"): - url = url.replace(":", "/") - url = url.replace("git@", "") - - if url.startswith("https://"): - url = url[len("https://") :] # Removing the https header - - return url - - -def get_commit_sha_and_branch_name( - project_commit_sha_to_consider: str, -) -> Tuple[str, str]: - # Search projects - command = os.popen(f"git rev-parse --short {project_commit_sha_to_consider}") - sha = command.read()[:-1] - command = os.popen(f"git rev-parse --abbrev-ref {project_commit_sha_to_consider}") - branch = command.read()[:-1] - - return sha, branch - - -def check_local_changes() -> None: - command = os.popen("git status --porcelain --untracked-files=no") - output = command.read()[:-1] - if output: - error_print("WARNING: There are currently unpushed changes:") - error_print(output) - - -def check_local_commit_is_pushed(project_commit_ref_to_consider: str) -> None: - command = os.popen(f"git branch -r --contains {project_commit_ref_to_consider}") - remote_branches_containing_commit = command.read()[:-1] - - if not remote_branches_containing_commit: - error_print( - f"WARNING: local commit {project_commit_ref_to_consider} not pushed, " - f"build is likely to fail!" 
- ) - - -def get_project_folder_name() -> str: - return ( - os.path.basename(os.path.dirname(os.getcwd())).strip().lower().replace(" ", "_") - ) - - -def clone_commands( - project_commit_ref_to_consider: str, - ci_job_token: str, - personal_token: str, - project_name: str, - no_check: bool = False, -) -> str: - repo_address = get_repo_address() - sha, branch = get_commit_sha_and_branch_name(project_commit_ref_to_consider) - - if ci_job_token: # we are in a CI environment - repo_address = f"http://gitlab-ci-token:{ci_job_token}@{repo_address}" - elif personal_token: # if a personal token is available - repo_address = f"https://oauth:{personal_token}@{repo_address}" - else: - repo_address = f"https://{repo_address}" - - print( - f"Building final image using branch: {bold(branch)} with sha: {bold(sha)} \n" - f"URL: {bold(repo_address)}" - ) - - if not no_check: - code_block = f""" - if [ ! -d {project_name} ] - then - echo 'ERROR: you are probably not cloning your project in the right directory' - echo 'Consider using the --project option of build_final_image' - echo 'with one of the folders shown below:' - ls - echo 'if you want to build your image anyway, use the --no-check option' - exit 1 - fi - - """ - else: - code_block = "" - - code_block += f""" - git clone --recurse-submodules --shallow-submodules {repo_address} {project_name} - cd {project_name} - git checkout {sha} - git submodule update - cd .. 
- """ - - return code_block - - -def apply_changes( - original_file: str, - project_commit_ref_to_consider: str, - ci_job_token: str, - personal_token: str, - project_name: str, - no_check: bool = False, -) -> None: - fout = open("./tmp.def", "w") - for line in original_file.splitlines(): - if "#NOTFORFINAL" in line: - continue - if "#CLONEHERE" in line: - line = clone_commands( - project_commit_ref_to_consider, - ci_job_token, - personal_token, - project_name, - no_check, - ) - fout.write(line + "\n") - fout.close() - - -def compile_container( - project_name: str, image_name: Union[str, None], debug: bool -) -> None: - if not image_name: - image_name = f"final_{project_name}_{time.strftime('%Y-%m-%d_%H_%M_%S')}.sif" - subprocess.run( - ["singularity", "build", "--force", "--fakeroot", image_name, "./tmp.def"] - ) - if not debug: - os.remove("./tmp.def") - - -def get_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Build a read-only final container " - "in which the entire project repository is cloned", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - parser.add_argument( - "--path-def", - required=False, - type=str, - default=SINGULARITY_DEFINITION_FILE_NAME, - help="path to singularity definition file.", - ) - - parser.add_argument( - "--commit-ref", - "-c", - required=False, - type=str, - default="HEAD", - help="commit/branch/tag to consider in the project repository " - "(useful only when using #CLONEHERE).", - ) - - parser.add_argument( - "--ci-job-token", - required=False, - type=str, - default=get_ci_job_token(), - help="Gitlab CI job token (useful in particular when using #CLONEHERE). " - "If not specified, it takes the value of the environment variable " - "CI_JOB_TOKEN, if it exists. 
" - "If the environment variable SINGULARITYENV_CI_JOB_TOKEN is not set yet, " - "then it is set the value provided.", - ) - parser.add_argument( - "--personal-token", - required=False, - type=str, - default=get_personal_token(), - help="Gitlab Personal token (useful in particular when using #CLONEHERE). " - "If not specified, it takes the value of the environment variable " - "PERSONAL_TOKEN, if it exists. " - "If the environment variable SINGULARITYENV_PERSONAL_TOKEN is not set yet, " - "then it is set the value provided.", - ) - - parser.add_argument( - "--project", - required=False, - type=str, - default=get_project_folder_name(), - help="Specify the name of the project. This corresponds to: " - "(1) Name of the folder in which the current repository will be cloned " - "(useful only when using #CLONEHERE); " - "(2) the name in the final singularity image " - '"final__YYYY_mm_DD_HH_MM_SS.sif". ' - "By default, it uses the name of the parent folder, as it is considered that " - "the script is executed in the 'singularity/' folder of the project.", - ) - - parser.add_argument( - "--image", - "-i", - required=False, - type=str, - default=None, - help="Name of the image to create. By default: " - '"final__YYYY_mm_DD_HH_MM_SS.sif"', - ) - - parser.add_argument( - "--no-check", - action="store_true", - help="Avoids standard verifications (checking if the repository is " - "cloned at the right place).", - ) - - parser.add_argument( - "--debug", - "-d", - action="store_true", - help="Shows debugging information. 
Temporary files are not removed.", - ) - - args = parser.parse_args() - return args - - -def get_ci_job_token() -> Union[str, None]: - if "CI_JOB_TOKEN" in os.environ: - return os.getenv("CI_JOB_TOKEN") - else: - return None - - -def get_personal_token() -> Union[str, None]: - if "PERSONAL_TOKEN" in os.environ: - return os.getenv("PERSONAL_TOKEN") - else: - return None - - -def generate_singularity_environment_variables( - ci_job_token: Union[str, None], - personal_token: Union[str, None], - project_folder: Union[str, None], -) -> None: - key_singularityenv_ci_job_token = "SINGULARITYENV_CI_JOB_TOKEN" - if ci_job_token and key_singularityenv_ci_job_token not in os.environ: - os.environ[key_singularityenv_ci_job_token] = ci_job_token - - key_singularityenv_personal_token = "SINGULARITYENV_PERSONAL_TOKEN" - if personal_token and key_singularityenv_personal_token not in os.environ: - os.environ[key_singularityenv_personal_token] = personal_token - - key_singularityenv_project_folder = "SINGULARITYENV_PROJECT_FOLDER" - if project_folder and key_singularityenv_project_folder not in os.environ: - os.environ[key_singularityenv_project_folder] = project_folder - - -def main() -> None: - args = get_args() - - path_to_singularity_definition_file = args.path_def - project_commit_ref_to_consider = args.commit_ref - ci_job_token = args.ci_job_token - personal_token = args.personal_token - project_name = args.project - debug = args.debug - image_name = args.image - no_check = args.no_check - - # doing some checks and print warnings - check_local_changes() - check_local_commit_is_pushed(project_commit_ref_to_consider) - - # getting the orignal singularity file - data = load_singularity_file(path_to_singularity_definition_file) - - # appling the changes and writing this in ./tmp.def - apply_changes( - data, - project_commit_ref_to_consider, - ci_job_token, - personal_token, - project_name, - no_check, - ) - - # Create environment variables for singularity - 
generate_singularity_environment_variables( - ci_job_token, personal_token, project_folder=project_name - ) - - # compiling and deleting ./tmp.def - compile_container(project_name, image_name, debug) - - -if __name__ == "__main__": - main() diff --git a/singularity/build_final_image.py b/singularity/build_final_image.py deleted file mode 100755 index 99911407..00000000 --- a/singularity/build_final_image.py +++ /dev/null @@ -1,347 +0,0 @@ -#!/usr/bin/env python3 -import argparse -import os -import subprocess -import sys -import time -from typing import Tuple, Union - -SINGULARITY_DEFINITION_FILE_NAME = "singularity.def" - - -class BColors: - HEADER = "\033[95m" - OKBLUE = "\033[94m" - OKCYAN = "\033[96m" - OKGREEN = "\033[92m" - WARNING = "\033[93m" - FAIL = "\033[91m" - ENDC = "\033[0m" - BOLD = "\033[1m" - UNDERLINE = "\033[4m" - - -def error_print(message: str) -> None: - print(f"{BColors.FAIL}{message}{BColors.ENDC}", file=sys.stderr) - - -def bold(message: str) -> str: - return f"{BColors.BOLD}{message}{BColors.ENDC}" - - -def load_singularity_file(path_to_singularity_definition_file: str) -> str: - try: - # read input file - fin = open(path_to_singularity_definition_file, "rt") - - except IOError: - error_print(f"ERROR, {path_to_singularity_definition_file} file not found!") - - finally: - data = fin.read() - # close the input file - fin.close() - return data - - -def get_repo_address() -> str: - # Search projects - command = os.popen("git config --local remote.origin.url") - url = command.read()[:-1] - - # if it is using the ssh protocal, we need to convert it into an address - # compatible with https as the key is not available inside the container - if url.startswith("git@"): - url = url.replace(":", "/") - url = url.replace("git@", "") - - if url.startswith("https://"): - url = url[len("https://") :] # Removing the https header - - return url - - -def get_commit_sha_and_branch_name( - project_commit_sha_to_consider: str, -) -> Tuple[str, str]: - # Search 
projects - command = os.popen(f"git rev-parse --short {project_commit_sha_to_consider}") - sha = command.read()[:-1] - command = os.popen(f"git rev-parse --abbrev-ref {project_commit_sha_to_consider}") - branch = command.read()[:-1] - - return sha, branch - - -def check_local_changes() -> None: - command = os.popen("git status --porcelain --untracked-files=no") - output = command.read()[:-1] - if output: - error_print("WARNING: There are currently unpushed changes:") - error_print(output) - - -def check_local_commit_is_pushed(project_commit_ref_to_consider: str) -> None: - command = os.popen(f"git branch -r --contains {project_commit_ref_to_consider}") - remote_branches_containing_commit = command.read()[:-1] - - if not remote_branches_containing_commit: - error_print( - f"WARNING: local commit {project_commit_ref_to_consider} not pushed, " - f"build is likely to fail!" - ) - - -def get_project_folder_name() -> str: - return ( - os.path.basename(os.path.dirname(os.getcwd())).strip().lower().replace(" ", "_") - ) - - -def clone_commands( - project_commit_ref_to_consider: str, - ci_job_token: str, - personal_token: str, - project_name: str, - no_check: bool = False, -) -> str: - repo_address = get_repo_address() - sha, branch = get_commit_sha_and_branch_name(project_commit_ref_to_consider) - - if ci_job_token: # we are in a CI environment - repo_address = f"http://gitlab-ci-token:{ci_job_token}@{repo_address}" - elif personal_token: # if a personal token is available - repo_address = f"https://oauth:{personal_token}@{repo_address}" - else: - repo_address = f"https://{repo_address}" - - print( - f"Building final image using branch: {bold(branch)} with sha: {bold(sha)} \n" - f"URL: {bold(repo_address)}" - ) - - if not no_check: - code_block = f""" - if [ ! 
-d {project_name} ] - then - echo 'ERROR: you are probably not cloning your project in the right directory' - echo 'Consider using the --project option of build_final_image' - echo 'with one of the folders shown below:' - ls - echo 'if you want to build your image anyway, use the --no-check option' - exit 1 - fi - - """ - else: - code_block = "" - - code_block += f""" - git clone --recurse-submodules --shallow-submodules {repo_address} {project_name} - cd {project_name} - git checkout {sha} - git submodule update - cd .. - """ - - return code_block - - -def apply_changes( - original_file: str, - project_commit_ref_to_consider: str, - ci_job_token: str, - personal_token: str, - project_name: str, - no_check: bool = False, -) -> None: - fout = open("./tmp.def", "w") - for line in original_file.splitlines(): - if "#NOTFORFINAL" in line: - continue - if "#CLONEHERE" in line: - line = clone_commands( - project_commit_ref_to_consider, - ci_job_token, - personal_token, - project_name, - no_check, - ) - fout.write(line + "\n") - fout.close() - - -def compile_container( - project_name: str, image_name: Union[str, None], debug: bool -) -> None: - if not image_name: - image_name = f"final_{project_name}_{time.strftime('%Y-%m-%d_%H_%M_%S')}.sif" - subprocess.run( - ["singularity", "build", "--force", "--fakeroot", image_name, "./tmp.def"] - ) - if not debug: - os.remove("./tmp.def") - - -def get_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Build a read-only final container " - "in which the entire project repository is cloned", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - parser.add_argument( - "--path-def", - required=False, - type=str, - default=SINGULARITY_DEFINITION_FILE_NAME, - help="path to singularity definition file.", - ) - - parser.add_argument( - "--commit-ref", - "-c", - required=False, - type=str, - default="HEAD", - help="commit/branch/tag to consider in the project repository " - "(useful only when using 
#CLONEHERE).", - ) - - parser.add_argument( - "--ci-job-token", - required=False, - type=str, - default=get_ci_job_token(), - help="Gitlab CI job token (useful in particular when using #CLONEHERE). " - "If not specified, it takes the value of the environment variable " - "CI_JOB_TOKEN, if it exists. " - "If the environment variable SINGULARITYENV_CI_JOB_TOKEN is not set yet, " - "then it is set the value provided.", - ) - parser.add_argument( - "--personal-token", - required=False, - type=str, - default=get_personal_token(), - help="Gitlab Personal token (useful in particular when using #CLONEHERE). " - "If not specified, it takes the value of the environment variable " - "PERSONAL_TOKEN, if it exists. " - "If the environment variable SINGULARITYENV_PERSONAL_TOKEN is not set yet, " - "then it is set the value provided.", - ) - - parser.add_argument( - "--project", - required=False, - type=str, - default=get_project_folder_name(), - help="Specify the name of the project. This corresponds to: " - "(1) Name of the folder in which the current repository will be cloned " - "(useful only when using #CLONEHERE); " - "(2) the name in the final singularity image " - '"final__YYYY_mm_DD_HH_MM_SS.sif". ' - "By default, it uses the name of the parent folder, as it is considered that " - "the script is executed in the 'singularity/' folder of the project.", - ) - - parser.add_argument( - "--image", - "-i", - required=False, - type=str, - default=None, - help="Name of the image to create. By default: " - '"final__YYYY_mm_DD_HH_MM_SS.sif"', - ) - - parser.add_argument( - "--no-check", - action="store_true", - help="Avoids standard verifications (checking if the repository is " - "cloned at the right place).", - ) - - parser.add_argument( - "--debug", - "-d", - action="store_true", - help="Shows debugging information. 
Temporary files are not removed.", - ) - - args = parser.parse_args() - return args - - -def get_ci_job_token() -> Union[str, None]: - if "CI_JOB_TOKEN" in os.environ: - return os.getenv("CI_JOB_TOKEN") - else: - return None - - -def get_personal_token() -> Union[str, None]: - if "PERSONAL_TOKEN" in os.environ: - return os.getenv("PERSONAL_TOKEN") - else: - return None - - -def generate_singularity_environment_variables( - ci_job_token: Union[str, None], - personal_token: Union[str, None], - project_folder: Union[str, None], -) -> None: - key_singularityenv_ci_job_token = "SINGULARITYENV_CI_JOB_TOKEN" - if ci_job_token and key_singularityenv_ci_job_token not in os.environ: - os.environ[key_singularityenv_ci_job_token] = ci_job_token - - key_singularityenv_personal_token = "SINGULARITYENV_PERSONAL_TOKEN" - if personal_token and key_singularityenv_personal_token not in os.environ: - os.environ[key_singularityenv_personal_token] = personal_token - - key_singularityenv_project_folder = "SINGULARITYENV_PROJECT_FOLDER" - if project_folder and key_singularityenv_project_folder not in os.environ: - os.environ[key_singularityenv_project_folder] = project_folder - - -def main() -> None: - args = get_args() - - path_to_singularity_definition_file = args.path_def - project_commit_ref_to_consider = args.commit_ref - ci_job_token = args.ci_job_token - personal_token = args.personal_token - project_name = args.project - debug = args.debug - image_name = args.image - no_check = args.no_check - - # doing some checks and print warnings - check_local_changes() - check_local_commit_is_pushed(project_commit_ref_to_consider) - - # getting the orignal singularity file - data = load_singularity_file(path_to_singularity_definition_file) - - # appling the changes and writing this in ./tmp.def - apply_changes( - data, - project_commit_ref_to_consider, - ci_job_token, - personal_token, - project_name, - no_check, - ) - - # Create environment variables for singularity - 
generate_singularity_environment_variables( - ci_job_token, personal_token, project_folder=project_name - ) - - # compiling and deleting ./tmp.def - compile_container(project_name, image_name, debug) - - -if __name__ == "__main__": - main() diff --git a/singularity/singularity.def b/singularity/singularity.def deleted file mode 100644 index ee8ca27a..00000000 --- a/singularity/singularity.def +++ /dev/null @@ -1,41 +0,0 @@ -Bootstrap: library -From: airl_lab/default/airl_env:qdax_f57720d0 - -%labels - Author adaptive.intelligent.robotics@gmail.com - Version v0.0.1 - -%environment - export PYTHONPATH=$PYTHONPATH:/workspace/lib/python3.8/site-packages/ - export LD_LIBRARY_PATH="/workspace/lib:$LD_LIBRARY_PATH" - export PATH=$PATH:/usr/local/go/bin - -%post - export LD_LIBRARY_PATH="/workspace/lib:$LD_LIBRARY_PATH" - apt-get update -y - pip3 install --upgrade pip - - # Create working directory - mkdir -p /git/exp/qdax/ - - #================================================================================== - exit 0 #NOTFORFINAL - the lines below this "exit" will be executed only when building the final image - #================================================================================== - - # Enter working directory - cd /git/exp/ - - #CLONEHERE - -%runscript - # Entering directory - cd /git/exp/qdax/ - - # Running the test file as a demo - echo - echo 'Running the test of MAP-Elites algorithm as a demo' - echo - pytest tests/core_test/map_elites_test.py - -%help - This is the development and running environment of QDax diff --git a/singularity/start_container b/singularity/start_container deleted file mode 100755 index 09b8dd9b..00000000 --- a/singularity/start_container +++ /dev/null @@ -1,184 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import os -import subprocess -import tempfile - -import build_final_image - -EXP_PATH = "git/exp/" -ABSOLUTE_EXP_PATH = "/" + EXP_PATH - - -def get_default_image_name() -> str: - return 
f"{build_final_image.get_project_folder_name()}.sif" - - -def build_sandbox(path_singularity_def: str, image_name: str) -> None: - # check if the sandbox has already been created - if os.path.exists(image_name): - return - - print(f"{image_name} does not exist, building it now from {path_singularity_def}") - assert os.path.exists( - path_singularity_def - ) # exit if path_singularity_definition_file is not found - - # run commands - command = ( - f"singularity build --force --fakeroot --sandbox {image_name} " - f"{path_singularity_def}" - ) - subprocess.run(command.split()) - - -def run_container( - nvidia: bool, - use_no_home: bool, - use_tmp_home: bool, - image_name: str, - binding_folder_inside_container: str, -) -> None: - additional_args = "" - - if nvidia: - print("Nvidia runtime ON") - additional_args += " " + "--nv" - - if use_no_home: - print("Using --no-home") - additional_args += " " + "--no-home --containall" - - if use_tmp_home: - tmp_home_folder = tempfile.mkdtemp(dir="/tmp") - additional_args += " " + f"--home {tmp_home_folder}" - build_final_image.error_print( - f"Warning: The HOME folder is a temporary directory located in " - f"{tmp_home_folder}! Do not store any result there!" - ) - - if not binding_folder_inside_container: - binding_folder_inside_container = build_final_image.get_project_folder_name() - - path_folder_binding_in_container = os.path.join( - image_name, EXP_PATH, binding_folder_inside_container - ) - if not os.path.exists(path_folder_binding_in_container): - list_possible_folder_binding_in_container = next( - os.walk(os.path.join(image_name, EXP_PATH)) - )[1] - list_possible_options = [ - f" --binding-folder {existing_folder}" - for existing_folder in list_possible_folder_binding_in_container - ] - build_final_image.error_print( - f"Warning: The folder " - f"{os.path.join(ABSOLUTE_EXP_PATH, binding_folder_inside_container)} " - f"does not exist in the container. 
The Binding between your project folder " - f"and your container is likely to be unsuccessful.\n" - f"You may want to consider adding one of the following options to the " - f"'start_container' command:\n" + "\n".join(list_possible_options) - ) - - command = ( - f"singularity shell -w {additional_args} " - f"--bind {os.path.dirname(os.getcwd())}:" - f"{ABSOLUTE_EXP_PATH}/{binding_folder_inside_container} " - f"{image_name}" - ) - subprocess.run(command.split()) - - -def get_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Build a sandbox container and shell into it.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "-n", "--nv", action="store_true", help="enable experimental Nvidia support" - ) - parser.add_argument( - "--no-home", action="store_true", help='apply --no-home to "singularity shell"' - ) - parser.add_argument( - "--tmp-home", - action="store_true", - help="binds HOME directory of the singularity container to a temporary folder", - ) - - parser.add_argument( - "--path-def", - required=False, - type=str, - default=build_final_image.SINGULARITY_DEFINITION_FILE_NAME, - help="path to singularity definition file", - ) - - parser.add_argument( - "--personal-token", - required=False, - type=str, - default=build_final_image.get_personal_token(), - help="Gitlab Personal token. " - "If not specified, it takes the value of the environment variable " - "PERSONAL_TOKEN, if it exists. " - "If the environment variable SINGULARITYENV_PERSONAL_TOKEN is not set yet, " - "then it is set the value provided.", - ) - - parser.add_argument( - "-b", - "--binding-folder", - required=False, - type=str, - default=build_final_image.get_project_folder_name(), - help=f"If specified, it corresponds to the name folder in {ABSOLUTE_EXP_PATH} " - f"from which the binding is performed to the current project source code. 
" - f"By default, it corresponds to the image name (without the .sif extension)", - ) - - parser.add_argument( - "-i", - "--image", - required=False, - type=str, - default=get_default_image_name(), - help="name of the sandbox image to start", - ) - - args = parser.parse_args() - - return args - - -def main() -> None: - args = get_args() - - enable_nvidia_support = args.nv - use_no_home = args.no_home - use_tmp_home = args.tmp_home - path_singularity_definition_file = args.path_def - image_name = args.image - binding_folder_inside_container = args.binding_folder - personal_token = args.personal_token - - # Create environment variables for singularity - build_final_image.generate_singularity_environment_variables( - ci_job_token=None, - personal_token=personal_token, - project_folder=binding_folder_inside_container, - ) - - build_sandbox(path_singularity_definition_file, image_name) - run_container( - enable_nvidia_support, - use_no_home, - use_tmp_home, - image_name, - binding_folder_inside_container, - ) - - -if __name__ == "__main__": - main() diff --git a/singularity/start_container.py b/singularity/start_container.py deleted file mode 100755 index 09b8dd9b..00000000 --- a/singularity/start_container.py +++ /dev/null @@ -1,184 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import os -import subprocess -import tempfile - -import build_final_image - -EXP_PATH = "git/exp/" -ABSOLUTE_EXP_PATH = "/" + EXP_PATH - - -def get_default_image_name() -> str: - return f"{build_final_image.get_project_folder_name()}.sif" - - -def build_sandbox(path_singularity_def: str, image_name: str) -> None: - # check if the sandbox has already been created - if os.path.exists(image_name): - return - - print(f"{image_name} does not exist, building it now from {path_singularity_def}") - assert os.path.exists( - path_singularity_def - ) # exit if path_singularity_definition_file is not found - - # run commands - command = ( - f"singularity build --force --fakeroot --sandbox {image_name} 
" - f"{path_singularity_def}" - ) - subprocess.run(command.split()) - - -def run_container( - nvidia: bool, - use_no_home: bool, - use_tmp_home: bool, - image_name: str, - binding_folder_inside_container: str, -) -> None: - additional_args = "" - - if nvidia: - print("Nvidia runtime ON") - additional_args += " " + "--nv" - - if use_no_home: - print("Using --no-home") - additional_args += " " + "--no-home --containall" - - if use_tmp_home: - tmp_home_folder = tempfile.mkdtemp(dir="/tmp") - additional_args += " " + f"--home {tmp_home_folder}" - build_final_image.error_print( - f"Warning: The HOME folder is a temporary directory located in " - f"{tmp_home_folder}! Do not store any result there!" - ) - - if not binding_folder_inside_container: - binding_folder_inside_container = build_final_image.get_project_folder_name() - - path_folder_binding_in_container = os.path.join( - image_name, EXP_PATH, binding_folder_inside_container - ) - if not os.path.exists(path_folder_binding_in_container): - list_possible_folder_binding_in_container = next( - os.walk(os.path.join(image_name, EXP_PATH)) - )[1] - list_possible_options = [ - f" --binding-folder {existing_folder}" - for existing_folder in list_possible_folder_binding_in_container - ] - build_final_image.error_print( - f"Warning: The folder " - f"{os.path.join(ABSOLUTE_EXP_PATH, binding_folder_inside_container)} " - f"does not exist in the container. 
The Binding between your project folder " - f"and your container is likely to be unsuccessful.\n" - f"You may want to consider adding one of the following options to the " - f"'start_container' command:\n" + "\n".join(list_possible_options) - ) - - command = ( - f"singularity shell -w {additional_args} " - f"--bind {os.path.dirname(os.getcwd())}:" - f"{ABSOLUTE_EXP_PATH}/{binding_folder_inside_container} " - f"{image_name}" - ) - subprocess.run(command.split()) - - -def get_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Build a sandbox container and shell into it.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "-n", "--nv", action="store_true", help="enable experimental Nvidia support" - ) - parser.add_argument( - "--no-home", action="store_true", help='apply --no-home to "singularity shell"' - ) - parser.add_argument( - "--tmp-home", - action="store_true", - help="binds HOME directory of the singularity container to a temporary folder", - ) - - parser.add_argument( - "--path-def", - required=False, - type=str, - default=build_final_image.SINGULARITY_DEFINITION_FILE_NAME, - help="path to singularity definition file", - ) - - parser.add_argument( - "--personal-token", - required=False, - type=str, - default=build_final_image.get_personal_token(), - help="Gitlab Personal token. " - "If not specified, it takes the value of the environment variable " - "PERSONAL_TOKEN, if it exists. " - "If the environment variable SINGULARITYENV_PERSONAL_TOKEN is not set yet, " - "then it is set the value provided.", - ) - - parser.add_argument( - "-b", - "--binding-folder", - required=False, - type=str, - default=build_final_image.get_project_folder_name(), - help=f"If specified, it corresponds to the name folder in {ABSOLUTE_EXP_PATH} " - f"from which the binding is performed to the current project source code. 
" - f"By default, it corresponds to the image name (without the .sif extension)", - ) - - parser.add_argument( - "-i", - "--image", - required=False, - type=str, - default=get_default_image_name(), - help="name of the sandbox image to start", - ) - - args = parser.parse_args() - - return args - - -def main() -> None: - args = get_args() - - enable_nvidia_support = args.nv - use_no_home = args.no_home - use_tmp_home = args.tmp_home - path_singularity_definition_file = args.path_def - image_name = args.image - binding_folder_inside_container = args.binding_folder - personal_token = args.personal_token - - # Create environment variables for singularity - build_final_image.generate_singularity_environment_variables( - ci_job_token=None, - personal_token=personal_token, - project_folder=binding_folder_inside_container, - ) - - build_sandbox(path_singularity_definition_file, image_name) - run_container( - enable_nvidia_support, - use_no_home, - use_tmp_home, - image_name, - binding_folder_inside_container, - ) - - -if __name__ == "__main__": - main() From a07397d933477b815bdd49c0bc0bc98efa2bd105 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Wed, 17 May 2023 12:48:39 +0200 Subject: [PATCH 02/16] fix: dependencies in notebook examples --- examples/cmame.ipynb | 8 +- examples/cmamega.ipynb | 8 +- examples/dads.ipynb | 8 +- examples/diayn.ipynb | 8 +- examples/distributed_mapelites.ipynb | 140 ++++++--------------------- examples/mapelites.ipynb | 8 +- examples/mees.ipynb | 8 +- examples/mome.ipynb | 8 +- examples/nsga2_spea2.ipynb | 6 ++ examples/omgmega.ipynb | 6 ++ examples/pgame.ipynb | 8 +- examples/qdpg.ipynb | 8 +- examples/smerl.ipynb | 10 +- 13 files changed, 112 insertions(+), 122 deletions(-) mode change 100755 => 100644 examples/mees.ipynb diff --git a/examples/cmame.ipynb b/examples/cmame.ipynb index c9d6f67e..1d7337d4 100644 --- a/examples/cmame.ipynb +++ b/examples/cmame.ipynb @@ -49,7 +49,13 @@ "except:\n", " !pip install --no-deps 
git+https://github.com/deepmind/chex.git@v0.1.3 |tail -n 1\n", " import chex\n", - " \n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", "try:\n", " import qdax\n", "except:\n", diff --git a/examples/cmamega.ipynb b/examples/cmamega.ipynb index 509e52ea..2e00d660 100644 --- a/examples/cmamega.ipynb +++ b/examples/cmamega.ipynb @@ -43,7 +43,13 @@ "except:\n", " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.3 |tail -n 1\n", " import chex\n", - " \n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", "try:\n", " import qdax\n", "except:\n", diff --git a/examples/dads.ipynb b/examples/dads.ipynb index f64f4685..deba8835 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -45,10 +45,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import haiku\n", "except:\n", " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", diff --git a/examples/diayn.ipynb b/examples/diayn.ipynb index 10cfda49..c725da4b 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -45,10 +45,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import haiku\n", "except:\n", " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", diff --git 
a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb index b8a08b52..434725a3 100644 --- a/examples/distributed_mapelites.ipynb +++ b/examples/distributed_mapelites.ipynb @@ -2,22 +2,14 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/distributed_mapelites.ipynb)" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "# Optimizing with MAP-Elites in Jax (multi-devices example)\n", "\n", @@ -34,11 +26,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "#@title Installs and Imports\n", @@ -61,10 +49,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -93,22 +87,14 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Setup and get devices" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "Setup the default platform where the MAP-Elites will be stored and MAP-Elite updates will happen. 
" ] @@ -116,11 +102,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "default_device = 'cpu'\n", @@ -130,11 +112,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Get devices (change gpu by tpu if needed)\n", @@ -146,11 +124,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Setup run parameters" ] @@ -158,11 +132,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "#@title QD Training Definitions Fields\n", @@ -185,11 +155,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Init environment, policy, population params, init states of the env\n", "\n", @@ -199,11 +165,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "%%time\n", @@ -237,11 +199,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Define the way the policy interacts with the env\n", "\n", @@ -251,11 +209,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Define the fonction to play a step with the policy in the environment\n", @@ -289,11 +243,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Define the scoring function and the way metrics are computed\n", "\n", @@ -303,11 +253,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": 
"#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Prepare the scoring function\n", @@ -332,11 +278,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Define the emitter\n", "\n", @@ -346,11 +288,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Define emitter\n", @@ -367,11 +305,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Instantiate and initialise the MAP Elites algorithm" ] @@ -379,11 +313,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "%%time\n", @@ -423,11 +353,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Launch MAP-Elites iterations" ] @@ -435,11 +361,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "log_period = 10\n", @@ -493,11 +415,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Get the repertoire from the first device\n", diff --git a/examples/mapelites.ipynb b/examples/mapelites.ipynb index c456cf5b..18728e73 100644 --- a/examples/mapelites.ipynb +++ b/examples/mapelites.ipynb @@ -49,10 +49,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", 
"except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git a/examples/mees.ipynb b/examples/mees.ipynb old mode 100755 new mode 100644 index 8f1dc444..ab5fad93 --- a/examples/mees.ipynb +++ b/examples/mees.ipynb @@ -54,10 +54,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@feat/add-algo-mees |tail -n 1\n", diff --git a/examples/mome.ipynb b/examples/mome.ipynb index 6a6f7d39..a4ca36a6 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -49,7 +49,13 @@ "except:\n", " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.3 |tail -n 1\n", " import chex\n", - " \n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", "try:\n", " import qdax\n", "except:\n", diff --git a/examples/nsga2_spea2.ipynb b/examples/nsga2_spea2.ipynb index 5cbe02a2..51c5f5bd 100644 --- a/examples/nsga2_spea2.ipynb +++ b/examples/nsga2_spea2.ipynb @@ -52,6 +52,12 @@ " import chex\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git a/examples/omgmega.ipynb b/examples/omgmega.ipynb index d75a0077..0a28876a 100644 --- a/examples/omgmega.ipynb +++ b/examples/omgmega.ipynb @@ -47,6 +47,12 @@ " import chex\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install 
\"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git a/examples/pgame.ipynb b/examples/pgame.ipynb index 24222ddf..7a51a0bd 100644 --- a/examples/pgame.ipynb +++ b/examples/pgame.ipynb @@ -48,10 +48,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git a/examples/qdpg.ipynb b/examples/qdpg.ipynb index 5642fd3b..d778ad1d 100644 --- a/examples/qdpg.ipynb +++ b/examples/qdpg.ipynb @@ -48,10 +48,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git a/examples/smerl.ipynb b/examples/smerl.ipynb index 8042c8cf..47ff96e9 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -45,8 +45,14 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", - " import \n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " 
import jumanji\n", " \n", "try:\n", " import haiku\n", From 5531831288607edbae08214caec8dfb8b0ec6ea4 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Wed, 17 May 2023 14:14:49 +0200 Subject: [PATCH 03/16] make all tests pass --- .readthedocs.yaml | 4 ++-- environment.yaml | 1 - setup.py | 1 + 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 2ef47062..7eec359d 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -18,7 +18,7 @@ mkdocs: # Optionally declare the Python requirements required to build your docs python: install: - - requirements: requirements.txt - - requirements: docs/requirements.txt - method: pip path: . + - requirements: requirements.txt + - requirements: docs/requirements.txt diff --git a/environment.yaml b/environment.yaml index 78058b9d..e46c034e 100644 --- a/environment.yaml +++ b/environment.yaml @@ -8,6 +8,5 @@ dependencies: - conda>=4.9.2 - pip: - --find-links https://storage.googleapis.com/jax-releases/jax_releases.html - - jaxlib==0.3.15 - -r requirements.txt - -r requirements-dev.txt diff --git a/setup.py b/setup.py index 2e50e0ea..a71f3174 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ "brax>=0.0.15", "gym>=0.23.1", "numpy>=1.22.3", + "optax>=0.1, <0.1.5", "scikit-learn>=1.0.2", "scipy>=1.8.0", ], From b07d1a7bc29da73b5780c07b662f0a7b31bd42f1 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Wed, 17 May 2023 14:27:51 +0200 Subject: [PATCH 04/16] make all tests pass --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index b97297fa..16c91bc3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,9 +6,11 @@ flax==0.6.0 gym==0.23.1 ipython jax==0.3.17 +jaxlib==0.3.15 jumanji==0.1.3 jupyter numpy==1.22.3 +optax==0.1.4 protobuf==3.19.4 scikit-learn==1.0.2 scipy==1.8.0 From 1e016f1ee8581a0a0ed67bdf3bbada98f3d706d8 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Wed, 17 May 2023 15:15:03 +0200 Subject: [PATCH 
05/16] fix: dependencies in notebook examples (#149) fix: dependencies in notebook examples --- .readthedocs.yaml | 4 +- environment.yaml | 1 - examples/cmame.ipynb | 8 +- examples/cmamega.ipynb | 8 +- examples/dads.ipynb | 8 +- examples/diayn.ipynb | 8 +- examples/distributed_mapelites.ipynb | 140 ++++++--------------------- examples/mapelites.ipynb | 8 +- examples/mees.ipynb | 8 +- examples/mome.ipynb | 8 +- examples/nsga2_spea2.ipynb | 6 ++ examples/omgmega.ipynb | 6 ++ examples/pgame.ipynb | 8 +- examples/qdpg.ipynb | 8 +- examples/smerl.ipynb | 10 +- requirements.txt | 2 + setup.py | 1 + 17 files changed, 117 insertions(+), 125 deletions(-) mode change 100755 => 100644 examples/mees.ipynb diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 2ef47062..7eec359d 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -18,7 +18,7 @@ mkdocs: # Optionally declare the Python requirements required to build your docs python: install: - - requirements: requirements.txt - - requirements: docs/requirements.txt - method: pip path: . 
+ - requirements: requirements.txt + - requirements: docs/requirements.txt diff --git a/environment.yaml b/environment.yaml index 78058b9d..e46c034e 100644 --- a/environment.yaml +++ b/environment.yaml @@ -8,6 +8,5 @@ dependencies: - conda>=4.9.2 - pip: - --find-links https://storage.googleapis.com/jax-releases/jax_releases.html - - jaxlib==0.3.15 - -r requirements.txt - -r requirements-dev.txt diff --git a/examples/cmame.ipynb b/examples/cmame.ipynb index c9d6f67e..1d7337d4 100644 --- a/examples/cmame.ipynb +++ b/examples/cmame.ipynb @@ -49,7 +49,13 @@ "except:\n", " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.3 |tail -n 1\n", " import chex\n", - " \n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", "try:\n", " import qdax\n", "except:\n", diff --git a/examples/cmamega.ipynb b/examples/cmamega.ipynb index 509e52ea..2e00d660 100644 --- a/examples/cmamega.ipynb +++ b/examples/cmamega.ipynb @@ -43,7 +43,13 @@ "except:\n", " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.3 |tail -n 1\n", " import chex\n", - " \n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", "try:\n", " import qdax\n", "except:\n", diff --git a/examples/dads.ipynb b/examples/dads.ipynb index f64f4685..deba8835 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -45,10 +45,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import haiku\n", "except:\n", " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", diff --git a/examples/diayn.ipynb 
b/examples/diayn.ipynb index 10cfda49..c725da4b 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -45,10 +45,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import haiku\n", "except:\n", " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", diff --git a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb index b8a08b52..434725a3 100644 --- a/examples/distributed_mapelites.ipynb +++ b/examples/distributed_mapelites.ipynb @@ -2,22 +2,14 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/distributed_mapelites.ipynb)" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "# Optimizing with MAP-Elites in Jax (multi-devices example)\n", "\n", @@ -34,11 +26,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "#@title Installs and Imports\n", @@ -61,10 +49,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps 
git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -93,22 +87,14 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Setup and get devices" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "Setup the default platform where the MAP-Elites will be stored and MAP-Elite updates will happen. " ] @@ -116,11 +102,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "default_device = 'cpu'\n", @@ -130,11 +112,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Get devices (change gpu by tpu if needed)\n", @@ -146,11 +124,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Setup run parameters" ] @@ -158,11 +132,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "#@title QD Training Definitions Fields\n", @@ -185,11 +155,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Init environment, policy, population params, init states of the env\n", "\n", @@ -199,11 +165,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "%%time\n", @@ -237,11 +199,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Define the way the policy interacts with the env\n", "\n", @@ -251,11 +209,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - 
"name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Define the fonction to play a step with the policy in the environment\n", @@ -289,11 +243,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Define the scoring function and the way metrics are computed\n", "\n", @@ -303,11 +253,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Prepare the scoring function\n", @@ -332,11 +278,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Define the emitter\n", "\n", @@ -346,11 +288,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Define emitter\n", @@ -367,11 +305,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Instantiate and initialise the MAP Elites algorithm" ] @@ -379,11 +313,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "%%time\n", @@ -423,11 +353,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Launch MAP-Elites iterations" ] @@ -435,11 +361,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "log_period = 10\n", @@ -493,11 +415,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Get the repertoire from the first device\n", diff --git a/examples/mapelites.ipynb b/examples/mapelites.ipynb index 
c456cf5b..18728e73 100644 --- a/examples/mapelites.ipynb +++ b/examples/mapelites.ipynb @@ -49,10 +49,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git a/examples/mees.ipynb b/examples/mees.ipynb old mode 100755 new mode 100644 index 8f1dc444..ab5fad93 --- a/examples/mees.ipynb +++ b/examples/mees.ipynb @@ -54,10 +54,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@feat/add-algo-mees |tail -n 1\n", diff --git a/examples/mome.ipynb b/examples/mome.ipynb index 6a6f7d39..a4ca36a6 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -49,7 +49,13 @@ "except:\n", " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.3 |tail -n 1\n", " import chex\n", - " \n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", "try:\n", " import qdax\n", "except:\n", diff --git a/examples/nsga2_spea2.ipynb b/examples/nsga2_spea2.ipynb index 5cbe02a2..51c5f5bd 100644 --- a/examples/nsga2_spea2.ipynb +++ b/examples/nsga2_spea2.ipynb @@ -52,6 +52,12 @@ " import chex\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " 
!pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git a/examples/omgmega.ipynb b/examples/omgmega.ipynb index d75a0077..0a28876a 100644 --- a/examples/omgmega.ipynb +++ b/examples/omgmega.ipynb @@ -47,6 +47,12 @@ " import chex\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git a/examples/pgame.ipynb b/examples/pgame.ipynb index 24222ddf..7a51a0bd 100644 --- a/examples/pgame.ipynb +++ b/examples/pgame.ipynb @@ -48,10 +48,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git a/examples/qdpg.ipynb b/examples/qdpg.ipynb index 5642fd3b..d778ad1d 100644 --- a/examples/qdpg.ipynb +++ b/examples/qdpg.ipynb @@ -48,10 +48,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git 
a/examples/smerl.ipynb b/examples/smerl.ipynb index 8042c8cf..47ff96e9 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -45,8 +45,14 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", - " import \n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", " \n", "try:\n", " import haiku\n", diff --git a/requirements.txt b/requirements.txt index b97297fa..16c91bc3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,9 +6,11 @@ flax==0.6.0 gym==0.23.1 ipython jax==0.3.17 +jaxlib==0.3.15 jumanji==0.1.3 jupyter numpy==1.22.3 +optax==0.1.4 protobuf==3.19.4 scikit-learn==1.0.2 scipy==1.8.0 diff --git a/setup.py b/setup.py index 2e50e0ea..a71f3174 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ "brax>=0.0.15", "gym>=0.23.1", "numpy>=1.22.3", + "optax>=0.1, <0.1.5", "scikit-learn>=1.0.2", "scipy>=1.8.0", ], From b44969f94aaa70dc6e53aaed95193f65f20400c2 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Wed, 17 May 2023 14:17:37 +0100 Subject: [PATCH 06/16] Update version QDax after bugfix --- qdax/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qdax/__init__.py b/qdax/__init__.py index b5fdc753..d31c31ea 100644 --- a/qdax/__init__.py +++ b/qdax/__init__.py @@ -1 +1 @@ -__version__ = "0.2.2" +__version__ = "0.2.3" From 79939ee0f8e6dabccc67baf0ef300eb5906ba314 Mon Sep 17 00:00:00 2001 From: Bryon Tjanaka <38124174+btjanaka@users.noreply.github.com> Date: Wed, 23 Aug 2023 09:13:34 -0400 Subject: [PATCH 07/16] feat(algo): Add MAP-Elites Low-Spread (#152) * Add MELS Repertoire * Create MELS Algorithm class * Introduce Spread type * Add multi_sample_scoring_function Authored-by: b-tjanaka@wings --- README.md | 1 + docs/api_documentation/core/mels.md | 7 + examples/mels.ipynb | 559 
++++++++++++++++++ mkdocs.yml | 1 + qdax/core/containers/mels_repertoire.py | 311 ++++++++++ qdax/core/mels.py | 104 ++++ qdax/types.py | 1 + qdax/utils/sampling.py | 81 ++- .../containers_test/mels_repertoire_test.py | 236 ++++++++ tests/core_test/mels_test.py | 156 +++++ 10 files changed, 1444 insertions(+), 13 deletions(-) create mode 100644 docs/api_documentation/core/mels.md create mode 100644 examples/mels.ipynb create mode 100644 qdax/core/containers/mels_repertoire.py create mode 100644 qdax/core/mels.py create mode 100644 tests/core_test/containers_test/mels_repertoire_test.py create mode 100644 tests/core_test/mels_test.py diff --git a/README.md b/README.md index 2477348d..0881da1d 100644 --- a/README.md +++ b/README.md @@ -134,6 +134,7 @@ QDax currently supports the following algorithms: | [Multi-Objective MAP-Elites (MOME)](https://arxiv.org/abs/2202.03057) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mome.ipynb) | | [MAP-Elites Evolution Strategies (MEES)](https://dl.acm.org/doi/pdf/10.1145/3377930.3390217) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mees.ipynb) | | [MAP-Elites PBT (ME-PBT)](https://openreview.net/forum?id=CBfYffLqWqb) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/me_sac_pbt.ipynb) | +| [MAP-Elites Low-Spread (ME-LS)](https://dl.acm.org/doi/abs/10.1145/3583131.3590433) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mels.ipynb) | diff --git a/docs/api_documentation/core/mels.md b/docs/api_documentation/core/mels.md new file
mode 100644 index 00000000..3aa212b5 --- /dev/null +++ b/docs/api_documentation/core/mels.md @@ -0,0 +1,7 @@ +# MAP-Elites Low-Spread (ME-LS) + +[ME-LS](https://dl.acm.org/doi/abs/10.1145/3583131.3590433) is a variant of +MAP-Elites that drives the search process towards solutions that are consistent +in the behavior space for uncertain domains. + +::: qdax.core.mels.MELS diff --git a/examples/mels.ipynb b/examples/mels.ipynb new file mode 100644 index 00000000..1fcd6c42 --- /dev/null +++ b/examples/mels.ipynb @@ -0,0 +1,559 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mels.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Optimizing Uncertain Domains with ME-LS in JAX\n", + "\n", + "This notebook shows how to discover controllers that achieve consistent performance in MDP domains using the [MAP-Elites Low-Spread](https://dl.acm.org/doi/abs/10.1145/3583131.3590433) algorithm. It can be run locally or on Google Colab. We recommend using a GPU.
This notebook will show:\n", + "\n", + "- how to define the problem\n", + "- how to create an emitter\n", + "- how to create an ME-LS instance\n", + "- which functions must be defined before training\n", + "- how to launch a certain number of training steps\n", + "- how to visualise the optimization process\n", + "- how to save/load a repertoire" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#@title Installs and Imports\n", + "!pip install ipympl |tail -n 1\n", + "# %matplotlib widget\n", + "# from google.colab import output\n", + "# output.enable_custom_widget_manager()\n", + "\n", + "import os\n", + "\n", + "from IPython.display import clear_output\n", + "import functools\n", + "import time\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " import qdax\n", + "\n", + "\n", + "from qdax.core.mels import MELS\n", + "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n", + "from qdax.core.containers.mels_repertoire import MELSRepertoire\n", + "from qdax import environments\n", + "from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs\n", + "from qdax.core.neuroevolution.buffers.buffer import QDTransition\n", + "from qdax.core.neuroevolution.networks.networks import MLP\n", + "from qdax.core.emitters.mutation_operators import isoline_variation\n", + "from qdax.core.emitters.standard_emitters import MixingEmitter\n", + "from qdax.utils.plotting import plot_map_elites_results\n", + "\n", + "from 
qdax.utils.metrics import CSVLogger, default_qd_metrics\n", + "\n", + "from jax.flatten_util import ravel_pytree\n", + "\n", + "from IPython.display import HTML\n", + "from brax.io import html\n", + "\n", + "\n", + "\n", + "if \"COLAB_TPU_ADDR\" in os.environ:\n", + " from jax.tools import colab_tpu\n", + " colab_tpu.setup_tpu()\n", + "\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#@title QD Training Definitions Fields\n", + "#@markdown ---\n", + "batch_size = 100 #@param {type:\"number\"}\n", + "env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']\n", + "num_samples = 5 #@param {type:\"number\"}\n", + "episode_length = 100 #@param {type:\"integer\"}\n", + "num_iterations = 1000 #@param {type:\"integer\"}\n", + "seed = 42 #@param {type:\"integer\"}\n", + "policy_hidden_layer_sizes = (64, 64) #@param {type:\"raw\"}\n", + "iso_sigma = 0.005 #@param {type:\"number\"}\n", + "line_sigma = 0.05 #@param {type:\"number\"}\n", + "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", + "num_centroids = 1024 #@param {type:\"integer\"}\n", + "min_bd = 0. #@param {type:\"number\"}\n", + "max_bd = 1.0 #@param {type:\"number\"}\n", + "#@markdown ---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Init environment, policy, population params, init states of the env\n", + "\n", + "Define the environment in which the policies will be trained. In this notebook, we consider the problem where each controller is evaluated `num_samples` times, each time in a different environment." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Init environment\n", + "env = environments.create(env_name, episode_length=episode_length)\n", + "\n", + "# Init a random key\n", + "random_key = jax.random.PRNGKey(seed)\n", + "\n", + "# Init policy network\n", + "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", + "policy_network = MLP(\n", + " layer_sizes=policy_layer_sizes,\n", + " kernel_init=jax.nn.initializers.lecun_uniform(),\n", + " final_activation=jnp.tanh,\n", + ")\n", + "\n", + "# Init population of controllers. There are batch_size controllers, and each\n", + "# controller will be evaluated num_samples times.\n", + "random_key, subkey = jax.random.split(random_key)\n", + "keys = jax.random.split(subkey, num=batch_size)\n", + "fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))\n", + "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the way the policy interacts with the env\n", + "\n", + "Now that the environment and policy have been defined, it is necessary to define a function that describes how the policy must be used to interact with the environment and to store transition data. This is identical to the function in the MAP-Elites tutorial."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Define the function to play a step with the policy in the environment\n", + "def play_step_fn(\n", + " env_state,\n", + " policy_params,\n", + " random_key,\n", + "):\n", + " \"\"\"Play an environment step and return the updated state and the\n", + " transition.\"\"\"\n", + "\n", + " actions = policy_network.apply(policy_params, env_state.obs)\n", + "\n", + " state_desc = env_state.info[\"state_descriptor\"]\n", + " next_state = env.step(env_state, actions)\n", + "\n", + " transition = QDTransition(\n", + " obs=env_state.obs,\n", + " next_obs=next_state.obs,\n", + " rewards=next_state.reward,\n", + " dones=next_state.done,\n", + " actions=actions,\n", + " truncations=next_state.info[\"truncation\"],\n", + " state_desc=state_desc,\n", + " next_state_desc=next_state.info[\"state_descriptor\"],\n", + " )\n", + "\n", + " return next_state, policy_params, random_key, transition" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the scoring function and the way metrics are computed\n", + "\n", + "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual. Note that while the MAP-Elites tutorial uses `scoring_function_brax_envs` as the basis for the scoring function, we use `reset_based_scoring_function_brax_envs`. The difference is that `reset_based_scoring_function_brax_envs` generates initial states randomly instead of taking in a fixed set of initial states. This is necessary since we are evaluating each controller across sampled initial states. If the initial states were kept the same for all evaluations, there would be no stochasticity in the behavior." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Prepare the scoring function\n", + "bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]\n", + "scoring_fn = functools.partial(\n", + " reset_based_scoring_function_brax_envs,\n", + " episode_length=episode_length,\n", + " play_reset_fn=env.reset,\n", + " play_step_fn=play_step_fn,\n", + " behavior_descriptor_extractor=bd_extraction_fn,\n", + ")\n", + "\n", + "# Get minimum reward value to make sure qd_score are positive\n", + "reward_offset = environments.reward_offset[env_name]\n", + "\n", + "# Define a metrics function\n", + "metrics_fn = functools.partial(\n", + " default_qd_metrics,\n", + " qd_offset=reward_offset * episode_length,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the emitter\n", + "\n", + "The emitter is used to evolve the population at each mutation step." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Define emitter\n", + "variation_fn = functools.partial(\n", + " isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma\n", + ")\n", + "mixing_emitter = MixingEmitter(\n", + " mutation_fn=None, \n", + " variation_fn=variation_fn, \n", + " variation_percentage=1.0, \n", + " batch_size=batch_size\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instantiate and initialise the ME-LS algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Instantiate ME-LS.\n", + "mels = MELS(\n", + " scoring_function=scoring_fn,\n", + " emitter=mixing_emitter,\n", + " metrics_function=metrics_fn,\n", + " num_samples=num_samples,\n", + ")\n", + "\n", + "# Compute the centroids\n", + "centroids, random_key = compute_cvt_centroids(\n", + " 
num_descriptors=env.behavior_descriptor_length,\n", + " num_init_cvt_samples=num_init_cvt_samples,\n", + " num_centroids=num_centroids,\n", + " minval=min_bd,\n", + " maxval=max_bd,\n", + " random_key=random_key,\n", + ")\n", + "\n", + "# Compute initial repertoire and emitter state\n", + "repertoire, emitter_state, random_key = mels.init(init_variables, centroids, random_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Launch ME-LS iterations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "log_period = 10\n", + "num_loops = int(num_iterations / log_period)\n", + "\n", + "csv_logger = CSVLogger(\n", + " \"mapelites-logs.csv\",\n", + " header=[\"loop\", \"iteration\", \"qd_score\", \"max_fitness\", \"coverage\", \"time\"]\n", + ")\n", + "all_metrics = {}\n", + "\n", + "# main loop\n", + "mels_scan_update = mels.scan_update\n", + "for i in range(num_loops):\n", + " start_time = time.time()\n", + " # main iterations\n", + " (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + " mels_scan_update,\n", + " (repertoire, emitter_state, random_key),\n", + " (),\n", + " length=log_period,\n", + " )\n", + " timelapse = time.time() - start_time\n", + "\n", + " # log metrics\n", + " logged_metrics = {\"time\": timelapse, \"loop\": 1+i, \"iteration\": 1 + i*log_period}\n", + " for key, value in metrics.items():\n", + " # take last value\n", + " logged_metrics[key] = value[-1]\n", + "\n", + " # take all values\n", + " if key in all_metrics.keys():\n", + " all_metrics[key] = jnp.concatenate([all_metrics[key], value])\n", + " else:\n", + " all_metrics[key] = value\n", + "\n", + " csv_logger.log(logged_metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Visualization\n", + "\n", + "# create the x-axis array\n", + "env_steps = jnp.arange(num_iterations) * 
episode_length * batch_size\n", + "\n", + "# create the plots and the grid\n", + "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=all_metrics, repertoire=repertoire, min_bd=min_bd, max_bd=max_bd)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# How to save/load a repertoire\n", + "\n", + "The following cells show how to save or load a repertoire of individuals and add a few lines to visualise the best performing individual in a simulation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save the final repertoire" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "repertoire_path = \"./last_repertoire/\"\n", + "os.makedirs(repertoire_path, exist_ok=True)\n", + "repertoire.save(path=repertoire_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build the reconstruction function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Init population of policies\n", + "random_key, subkey = jax.random.split(random_key)\n", + "fake_batch = jnp.zeros(shape=(env.observation_size,))\n", + "fake_params = policy_network.init(subkey, fake_batch)\n", + "\n", + "_, reconstruction_fn = ravel_pytree(fake_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use the reconstruction function to load and re-create the repertoire" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "repertoire = MELSRepertoire.load(reconstruction_fn=reconstruction_fn, path=repertoire_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get the best individual of the repertoire\n", + "\n", + "Note that in ME-LS, the individual's cell is computed by finding its most frequent archive cell among its `num_samples` behavior descriptors.
Thus, the descriptor associated with each individual in the archive is not its mean descriptor. Rather, we set the descriptor in the archive to be the centroid of the individual's most frequent archive cell." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "best_idx = jnp.argmax(repertoire.fitnesses)\n", + "best_fitness = jnp.max(repertoire.fitnesses)\n", + "best_bd = repertoire.descriptors[best_idx]\n", + "best_spread = repertoire.spreads[best_idx]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " f\"Best fitness in the repertoire: {best_fitness:.2f}\\n\"\n", + " f\"Behavior descriptor of the best individual in the repertoire: {best_bd}\\n\"\n", + " f\"Spread of the best individual in the repertoire: {best_spread}\\n\"\n", + " f\"Index in the repertoire of this individual: {best_idx}\\n\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "my_params = jax.tree_util.tree_map(\n", + " lambda x: x[best_idx],\n", + " repertoire.genotypes\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Play some steps in the environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "jit_env_reset = jax.jit(env.reset)\n", + "jit_env_step = jax.jit(env.step)\n", + "jit_inference_fn = jax.jit(policy_network.apply)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rollout = []\n", + "rng = jax.random.PRNGKey(seed=1)\n", + "state = jit_env_reset(rng=rng)\n", + "while not state.done:\n", + " rollout.append(state)\n", + " action = jit_inference_fn(my_params, state.obs)\n", + " state = jit_env_step(state, action)\n", + "\n", + "print(f\"The trajectory of this individual contains {len(rollout)} 
transitions.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "HTML(html.render(env.sys, [s.qp for s in rollout[:500]]))" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/mkdocs.yml b/mkdocs.yml index 702a474a..2c0bbdb6 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -140,6 +140,7 @@ nav: - MOME: api_documentation/core/mome.md - ME ES: api_documentation/core/mees.md - ME PBT: api_documentation/core/me_pbt.md + - ME LS: api_documentation/core/mels.md - Baseline algorithms: - SMERL: api_documentation/core/smerl.md - DIAYN: api_documentation/core/diayn.md diff --git a/qdax/core/containers/mels_repertoire.py b/qdax/core/containers/mels_repertoire.py new file mode 100644 index 00000000..a2e99971 --- /dev/null +++ b/qdax/core/containers/mels_repertoire.py @@ -0,0 +1,311 @@ +"""This file contains the class to define the repertoire used to +store individuals in the Multi-Objective MAP-Elites algorithm as +well as several variants.""" + +from __future__ import annotations + +from typing import Callable, Optional + +import jax +import jax.numpy as jnp +from jax.flatten_util import ravel_pytree + +from qdax.core.containers.mapelites_repertoire import ( + MapElitesRepertoire, + get_cells_indices, +) +from qdax.types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, Spread + + +def _dispersion(descriptors: jnp.ndarray) -> jnp.ndarray: + """Computes dispersion of a batch of num_samples 
descriptors. + + Args: + descriptors: (num_samples, num_descriptors) array of descriptors. + Returns: + The float dispersion of the descriptors (this is represented as a scalar + jnp.ndarray). + """ + + # Pairwise distances between the descriptors. + dists = jnp.linalg.norm(descriptors[:, None] - descriptors, axis=2) + + # Compute dispersion -- this is the mean of the unique pairwise distances. + # + # Zero out the duplicate distances since the distance matrix is diagonal. + # Setting k=1 will also remove entries on the diagonal since they are zero. + dists = jnp.triu(dists, k=1) + + num_samples = len(descriptors) + n_pairwise = num_samples * (num_samples - 1) / 2.0 + + return jnp.sum(dists) / n_pairwise + + +def _mode(x: jnp.ndarray) -> jnp.ndarray: + """Computes mode (most common item) of an array. + + The return type is a scalar ndarray. + """ + unique_vals, counts = jnp.unique(x, return_counts=True, size=x.size) + return unique_vals[jnp.argmax(counts)] + + +class MELSRepertoire(MapElitesRepertoire): + """Class for the repertoire in MAP-Elites Low-Spread. + + This class inherits from MapElitesRepertoire. In addition to the stored data in + MapElitesRepertoire (genotypes, fitnesses, descriptors, centroids), this repertoire + also maintains an array of spreads. We overload the save, load, add, and + init_default methods of MapElitesRepertoire. + + Refer to Mace 2023 for more info on MAP-Elites Low-Spread: + https://dl.acm.org/doi/abs/10.1145/3583131.3590433 + + Args: + genotypes: a PyTree containing all the genotypes in the repertoire ordered + by the centroids. Each leaf has a shape (num_centroids, num_features). The + PyTree can be a simple Jax array or a more complex nested structure such + as to represent parameters of neural network in Flax. + fitnesses: an array that contains the fitness of solutions in each cell of the + repertoire, ordered by centroids. The array shape is (num_centroids,). 
+ descriptors: an array that contains the descriptors of solutions in each cell + of the repertoire, ordered by centroids. The array shape + is (num_centroids, num_descriptors). + centroids: an array that contains the centroids of the tessellation. The array + shape is (num_centroids, num_descriptors). + spreads: an array that contains the spread of solutions in each cell of the + repertoire, ordered by centroids. The array shape is (num_centroids,). + """ + + spreads: Spread + + def save(self, path: str = "./") -> None: + """Saves the repertoire on disk in the form of .npy files. + + Flattens the genotypes to store it with .npy format. Supposes that + a user will have access to the reconstruction function when loading + the genotypes. + + Args: + path: Path where the data will be saved. Defaults to "./". + """ + + def flatten_genotype(genotype: Genotype) -> jnp.ndarray: + flatten_genotype, _ = ravel_pytree(genotype) + return flatten_genotype + + # flatten all the genotypes + flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes) + + # save data + jnp.save(path + "genotypes.npy", flat_genotypes) + jnp.save(path + "fitnesses.npy", self.fitnesses) + jnp.save(path + "descriptors.npy", self.descriptors) + jnp.save(path + "centroids.npy", self.centroids) + jnp.save(path + "spreads.npy", self.spreads) + + @classmethod + def load(cls, reconstruction_fn: Callable, path: str = "./") -> MELSRepertoire: + """Loads a MAP-Elites Low-Spread Repertoire. + + Args: + reconstruction_fn: Function to reconstruct a PyTree + from a flat array. + path: Path where the data is saved. Defaults to "./". + + Returns: + A MAP-Elites Low-Spread Repertoire. 
+ """ + + flat_genotypes = jnp.load(path + "genotypes.npy") + genotypes = jax.vmap(reconstruction_fn)(flat_genotypes) + + fitnesses = jnp.load(path + "fitnesses.npy") + descriptors = jnp.load(path + "descriptors.npy") + centroids = jnp.load(path + "centroids.npy") + spreads = jnp.load(path + "spreads.npy") + + return cls( + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + centroids=centroids, + spreads=spreads, + ) + + @jax.jit + def add( + self, + batch_of_genotypes: Genotype, + batch_of_descriptors: Descriptor, + batch_of_fitnesses: Fitness, + batch_of_extra_scores: Optional[ExtraScores] = None, + ) -> MELSRepertoire: + """ + Add a batch of elements to the repertoire. + + The key difference between this method and the default add() in + MapElitesRepertoire is that it expects each individual to be evaluated + `num_samples` times, resulting in `num_samples` fitnesses and + `num_samples` descriptors per individual. + + If multiple individuals may be added to a single cell, this method will + arbitrarily pick one -- the exact choice depends on the implementation of + jax.at[].set(), which can be non-deterministic: + https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html + We do not currently check if one of the multiple individuals dominates the + others (dominate means that the individual has both highest fitness and lowest + spread among the individuals for that cell). + + If `num_samples` is only 1, the spreads will default to 0. + + Args: + batch_of_genotypes: a batch of genotypes to be added to the repertoire. + Similarly to the self.genotypes argument, this is a PyTree in which + the leaves have a shape (batch_size, num_features) + batch_of_descriptors: an array that contains the descriptors of the + aforementioned genotypes over all evals. Its shape is + (batch_size, num_samples, num_descriptors). Note that we "aggregate" + descriptors by finding the most frequent cell of each individual. 
Thus, + the actual descriptors stored in the repertoire are just the coordinates + of the centroid of the most frequent cell. + batch_of_fitnesses: an array that contains the fitnesses of the + aforementioned genotypes over all evals. Its shape is (batch_size, + num_samples) + batch_of_extra_scores: unused tree that contains the extra_scores of + aforementioned genotypes. + + Returns: + The updated repertoire. + """ + batch_size, num_samples = batch_of_fitnesses.shape + + # Compute indices/cells of all descriptors. + batch_of_all_indices = get_cells_indices( + batch_of_descriptors.reshape(batch_size * num_samples, -1), self.centroids + ).reshape((batch_size, num_samples)) + + # Compute most frequent cell of each solution. + batch_of_indices = jax.vmap(_mode)(batch_of_all_indices)[:, None] + + # Compute dispersion / spread. The dispersion is set to zero if + # num_samples is 1. + batch_of_spreads = jax.lax.cond( + num_samples == 1, + lambda desc: jnp.zeros(batch_size), + lambda desc: jax.vmap(_dispersion)( + desc.reshape((batch_size, num_samples, -1)) + ), + batch_of_descriptors, + ) + batch_of_spreads = jnp.expand_dims(batch_of_spreads, axis=-1) + + # Compute canonical descriptors as the descriptor of the centroid of the most + # frequent cell. Note that this line redefines the earlier batch_of_descriptors. + batch_of_descriptors = jnp.take_along_axis( + self.centroids, batch_of_indices, axis=0 + ) + + # Compute canonical fitnesses as the average fitness. 
+ # + # Shape: (batch_size, 1) + batch_of_fitnesses = batch_of_fitnesses.mean(axis=-1, keepdims=True) + + num_centroids = self.centroids.shape[0] + + # get current repertoire fitnesses and spreads + repertoire_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1) + current_fitnesses = jnp.take_along_axis( + repertoire_fitnesses, batch_of_indices, 0 + ) + + repertoire_spreads = jnp.expand_dims(self.spreads, axis=-1) + current_spreads = jnp.take_along_axis(repertoire_spreads, batch_of_indices, 0) + + # get addition condition + addition_condition_fitness = batch_of_fitnesses > current_fitnesses + addition_condition_spread = batch_of_spreads <= current_spreads + addition_condition = jnp.logical_and( + addition_condition_fitness, addition_condition_spread + ) + + # assign fake position when relevant : num_centroids is out of bound + batch_of_indices = jnp.where( + addition_condition, x=batch_of_indices, y=num_centroids + ) + + # create new repertoire + new_repertoire_genotypes = jax.tree_util.tree_map( + lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[ + batch_of_indices.squeeze(axis=-1) + ].set(new_genotypes), + self.genotypes, + batch_of_genotypes, + ) + + # compute new fitness and descriptors + new_fitnesses = self.fitnesses.at[batch_of_indices.squeeze(axis=-1)].set( + batch_of_fitnesses.squeeze(axis=-1) + ) + new_descriptors = self.descriptors.at[batch_of_indices.squeeze(axis=-1)].set( + batch_of_descriptors + ) + new_spreads = self.spreads.at[batch_of_indices.squeeze(axis=-1)].set( + batch_of_spreads.squeeze(axis=-1) + ) + + return MELSRepertoire( + genotypes=new_repertoire_genotypes, + fitnesses=new_fitnesses, + descriptors=new_descriptors, + centroids=self.centroids, + spreads=new_spreads, + ) + + @classmethod + def init_default( + cls, + genotype: Genotype, + centroids: Centroid, + ) -> MELSRepertoire: + """Initialize a MAP-Elites Low-Spread repertoire with an initial population of + genotypes. 
Requires the definition of centroids that can be computed with any + method such as CVT or Euclidean mapping. + + Note: this function has been kept outside of the object MELS, so + it can be called easily called from other modules. + + Args: + genotype: the typical genotype that will be stored. + centroids: the centroids of the repertoire. + + Returns: + A repertoire filled with default values. + """ + + # get number of centroids + num_centroids = centroids.shape[0] + + # default fitness is -inf + default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) + + # default genotypes is all 0 + default_genotypes = jax.tree_util.tree_map( + lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype), + genotype, + ) + + # default descriptor is all zeros + default_descriptors = jnp.zeros_like(centroids) + + # default spread is inf so that any spread will be less + default_spreads = jnp.full(shape=num_centroids, fill_value=jnp.inf) + + return cls( + genotypes=default_genotypes, + fitnesses=default_fitnesses, + descriptors=default_descriptors, + centroids=centroids, + spreads=default_spreads, + ) diff --git a/qdax/core/mels.py b/qdax/core/mels.py new file mode 100644 index 00000000..6c06b785 --- /dev/null +++ b/qdax/core/mels.py @@ -0,0 +1,104 @@ +"""Core components of the MAP-Elites Low-Spread algorithm.""" +from __future__ import annotations + +from functools import partial +from typing import Callable, Optional, Tuple + +import jax + +from qdax.core.containers.mels_repertoire import MELSRepertoire +from qdax.core.emitters.emitter import Emitter, EmitterState +from qdax.core.map_elites import MAPElites +from qdax.types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + Metrics, + RNGKey, +) +from qdax.utils.sampling import multi_sample_scoring_function + + +class MELS(MAPElites): + """Core elements of the MAP-Elites Low-Spread algorithm. + + Most methods in this class are inherited from MAPElites. 
+ + The same scoring function can be passed into both MAPElites and this class. + We have overridden __init__ such that it takes in the scoring function and + wraps it such that every solution is evaluated `num_samples` times. + + We also overrode the init method to use the MELSRepertoire instead of + MapElitesRepertoire. + """ + + def __init__( + self, + scoring_function: Callable[ + [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey] + ], + emitter: Emitter, + metrics_function: Callable[[MELSRepertoire], Metrics], + num_samples: int, + ) -> None: + self._scoring_function = partial( + multi_sample_scoring_function, + scoring_fn=scoring_function, + num_samples=num_samples, + ) + self._emitter = emitter + self._metrics_function = metrics_function + self._num_samples = num_samples + + @partial(jax.jit, static_argnames=("self",)) + def init( + self, + init_genotypes: Genotype, + centroids: Centroid, + random_key: RNGKey, + ) -> Tuple[MELSRepertoire, Optional[EmitterState], RNGKey]: + """Initialize a MAP-Elites Low-Spread repertoire with an initial + population of genotypes. Requires the definition of centroids that can + be computed with any method such as CVT or Euclidean mapping. + + Args: + init_genotypes: initial genotypes, pytree in which leaves + have shape (batch_size, num_features) + centroids: tessellation centroids of shape (batch_size, num_descriptors) + random_key: a random key used for stochastic operations. + + Returns: + A tuple of (initialized MAP-Elites Low-Spread repertoire, initial emitter + state, JAX random key). 
+ """ + # score initial genotypes + fitnesses, descriptors, extra_scores, random_key = self._scoring_function( + init_genotypes, random_key + ) + + # init the repertoire + repertoire = MELSRepertoire.init( + genotypes=init_genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + centroids=centroids, + extra_scores=extra_scores, + ) + + # get initial state of the emitter + emitter_state, random_key = self._emitter.init( + init_genotypes=init_genotypes, random_key=random_key + ) + + # update emitter state + emitter_state = self._emitter.state_update( + emitter_state=emitter_state, + repertoire=repertoire, + genotypes=init_genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + extra_scores=extra_scores, + ) + return repertoire, emitter_state, random_key diff --git a/qdax/types.py b/qdax/types.py index 67fbb8a0..5000869b 100644 --- a/qdax/types.py +++ b/qdax/types.py @@ -26,6 +26,7 @@ Genotype: TypeAlias = ArrayTree Descriptor: TypeAlias = jnp.ndarray Centroid: TypeAlias = jnp.ndarray +Spread: TypeAlias = jnp.ndarray Gradient: TypeAlias = jnp.ndarray Skill: TypeAlias = jnp.ndarray diff --git a/qdax/utils/sampling.py b/qdax/utils/sampling.py index 88b6286e..a25e190f 100644 --- a/qdax/utils/sampling.py +++ b/qdax/utils/sampling.py @@ -8,7 +8,7 @@ from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey -@partial(jax.jit, static_argnames=("num_samples")) +@partial(jax.jit, static_argnames=("num_samples",)) def dummy_extra_scores_extractor( extra_scores: ExtraScores, num_samples: int, @@ -29,6 +29,60 @@ def dummy_extra_scores_extractor( return extra_scores +@partial( + jax.jit, + static_argnames=( + "scoring_fn", + "num_samples", + ), +) +def multi_sample_scoring_function( + policies_params: Genotype, + random_key: RNGKey, + scoring_fn: Callable[ + [Genotype, RNGKey], + Tuple[Fitness, Descriptor, ExtraScores, RNGKey], + ], + num_samples: int, +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """ + Wrap scoring_function to perform 
sampling. + + This function returns the fitnesses, descriptors, and extra_scores computed + over num_samples evaluations with the scoring_fn. + + Args: + policies_params: policies to evaluate + random_key: JAX random key + scoring_fn: scoring function used for evaluation + num_samples: number of samples to generate for each individual + + Returns: + (n, num_samples) array of fitnesses, + (n, num_samples, num_descriptors) array of descriptors, + dict with num_samples extra_scores per individual, + JAX random key + """ + + random_key, subkey = jax.random.split(random_key) + keys = jax.random.split(subkey, num=num_samples) + + # evaluate + sample_scoring_fn = jax.vmap( + scoring_fn, + # vectorizing over axis 0 vectorizes over the num_samples random keys + in_axes=(None, 0), + # indicates that the vectorized axis will become axis 1, i.e., the final + # output is shape (batch_size, num_samples, ...) + out_axes=1, + ) + all_fitnesses, all_descriptors, all_extra_scores, _ = sample_scoring_fn( + policies_params, keys + ) + + return all_fitnesses, all_descriptors, all_extra_scores, random_key + + @partial( jax.jit, static_argnames=( @@ -49,14 +103,16 @@ def sampling( [ExtraScores, int], ExtraScores ] = dummy_extra_scores_extractor, ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: - """ - Wrap scoring_function to perform sampling. + """Wrap scoring_function to perform sampling. + + This function averages the fitnesses and descriptors for each individual + over `num_samples` evaluations. Args: policies_params: policies to evaluate - random_key + random_key: JAX random key scoring_fn: scoring function used for evaluation - num_samples + num_samples: number of samples to generate for each individual extra_scores_extractor: function to extract the extra_scores from multiple samples of the same policy. 
@@ -65,14 +121,13 @@ def sampling( The extra_score extract from samples with extra_scores_extractor A new random key """ - - random_key, subkey = jax.random.split(random_key) - keys = jax.random.split(subkey, num=num_samples) - - # evaluate - sample_scoring_fn = jax.vmap(scoring_fn, (None, 0), 1) - all_fitnesses, all_descriptors, all_extra_scores, _ = sample_scoring_fn( - policies_params, keys + ( + all_fitnesses, + all_descriptors, + all_extra_scores, + random_key, + ) = multi_sample_scoring_function( + policies_params, random_key, scoring_fn, num_samples ) # average results diff --git a/tests/core_test/containers_test/mels_repertoire_test.py b/tests/core_test/containers_test/mels_repertoire_test.py new file mode 100644 index 00000000..2fb1bd76 --- /dev/null +++ b/tests/core_test/containers_test/mels_repertoire_test.py @@ -0,0 +1,236 @@ +import jax.numpy as jnp +import pytest + +from qdax.core.containers.mels_repertoire import MELSRepertoire +from qdax.types import ExtraScores + + +def test_add_to_mels_repertoire() -> None: + """Test several additions to the MELSRepertoire, including adding a solution + and overwriting it by adding multiple solutions.""" + genotype_size = 12 + num_centroids = 4 + num_descriptors = 2 + + # create a repertoire instance + repertoire = MELSRepertoire( + genotypes=jnp.zeros(shape=(num_centroids, genotype_size)), + fitnesses=jnp.ones(shape=(num_centroids,)) * (-jnp.inf), + descriptors=jnp.zeros(shape=(num_centroids, num_descriptors)), + centroids=jnp.array( + [ + [1.0, 1.0], + [2.0, 1.0], + [2.0, 2.0], + [1.0, 2.0], + ] + ), + spreads=jnp.full(shape=(num_centroids,), fill_value=jnp.inf), + ) + + # + # Test 1: Insert a single solution. 
+ # + + # create fake genotypes and scores to add + fake_genotypes = jnp.ones(shape=(1, genotype_size)) + # each solution gets two fitnesses and two descriptors + fake_fitnesses = jnp.array([[0.0, 0.0]]) + fake_descriptors = jnp.array([[[0.0, 1.0], [1.0, 1.0]]]) + fake_extra_scores: ExtraScores = {} + + # do an addition + repertoire = repertoire.add( + fake_genotypes, fake_descriptors, fake_fitnesses, fake_extra_scores + ) + + # check that the repertoire looks as expected + expected_genotypes = jnp.array( + [ + [1.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + ] + ) + expected_fitnesses = jnp.array([0.0, -jnp.inf, -jnp.inf, -jnp.inf]) + expected_descriptors = jnp.array( + [ + [1.0, 1.0], # Centroid coordinates. + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ] + ) + expected_spreads = jnp.array([1.0, jnp.inf, jnp.inf, jnp.inf]) + + # check values + pytest.assume(jnp.allclose(repertoire.genotypes, expected_genotypes, atol=1e-6)) + pytest.assume(jnp.allclose(repertoire.fitnesses, expected_fitnesses, atol=1e-6)) + pytest.assume(jnp.allclose(repertoire.descriptors, expected_descriptors, atol=1e-6)) + pytest.assume(jnp.allclose(repertoire.spreads, expected_spreads, atol=1e-6)) + + # + # Test 2: Adding solutions into the same cell as above. + # + + # create fake genotypes and scores to add + fake_genotypes = jnp.concatenate( + ( + jnp.full(shape=(1, genotype_size), fill_value=2.0), + jnp.full(shape=(1, genotype_size), fill_value=3.0), + ), + axis=0, + ) + # Each solution gets two fitnesses and two descriptors (i.e. num_evals = 2). One + # solution has fitness 1.0 and spread 0.75, while the other has fitness 0.5 and + # spread 0.5. Thus, neither solution dominates the other (by having both higher + # fitness and lower spread). However, both solutions would be valid candidates for + # the archive due to dominating the current solution there. 
+ fake_fitnesses = jnp.array([[1.0, 1.0], [0.5, 0.5]]) + fake_descriptors = jnp.array([[[1.0, 0.25], [1.0, 1.0]], [[1.0, 0.5], [1.0, 1.0]]]) + fake_extra_scores: ExtraScores = {} + + # do an addition + repertoire = repertoire.add( + fake_genotypes, fake_descriptors, fake_fitnesses, fake_extra_scores + ) + + # Either solution may be added due to the behavior of jax.at[].set(): + # https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html + # Thus, we provide possible values for each scenario. + + # check that the repertoire looks like expected + expected_genotypes_1 = jnp.array( + [ + [2.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + ] + ) + expected_fitnesses_1 = jnp.array([1.0, -jnp.inf, -jnp.inf, -jnp.inf]) + expected_descriptors_1 = jnp.array( + [ + [1.0, 1.0], # Centroid coordinates. + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ] + ) + expected_spreads_1 = jnp.array([0.75, jnp.inf, jnp.inf, jnp.inf]) + + expected_genotypes_2 = jnp.array( + [ + [3.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + ] + ) + expected_fitnesses_2 = jnp.array([0.5, -jnp.inf, -jnp.inf, -jnp.inf]) + expected_descriptors_2 = jnp.array( + [ + [1.0, 1.0], # Centroid coordinates. 
+ [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ] + ) + expected_spreads_2 = jnp.array([0.5, jnp.inf, jnp.inf, jnp.inf]) + + # check values + pytest.assume( + jnp.allclose(repertoire.genotypes, expected_genotypes_1, atol=1e-6) + or jnp.allclose(repertoire.genotypes, expected_genotypes_2, atol=1e-6) + ) + + if jnp.allclose(repertoire.genotypes, expected_genotypes_1, atol=1e-6): + pytest.assume( + jnp.allclose(repertoire.genotypes, expected_genotypes_1, atol=1e-6) + ) + pytest.assume( + jnp.allclose(repertoire.fitnesses, expected_fitnesses_1, atol=1e-6) + ) + pytest.assume( + jnp.allclose(repertoire.descriptors, expected_descriptors_1, atol=1e-6) + ) + pytest.assume(jnp.allclose(repertoire.spreads, expected_spreads_1, atol=1e-6)) + elif jnp.allclose(repertoire.genotypes, expected_genotypes_2, atol=1e-6): + pytest.assume( + jnp.allclose(repertoire.genotypes, expected_genotypes_2, atol=1e-6) + ) + pytest.assume( + jnp.allclose(repertoire.fitnesses, expected_fitnesses_2, atol=1e-6) + ) + pytest.assume( + jnp.allclose(repertoire.descriptors, expected_descriptors_2, atol=1e-6) + ) + pytest.assume(jnp.allclose(repertoire.spreads, expected_spreads_2, atol=1e-6)) + + +def test_add_with_single_eval() -> None: + """Tries adding with a single evaluation. + + This is a special case because the spread defaults to 0. + """ + genotype_size = 12 + num_centroids = 4 + num_descriptors = 2 + + # create a repertoire instance + repertoire = MELSRepertoire( + genotypes=jnp.zeros(shape=(num_centroids, genotype_size)), + fitnesses=jnp.ones(shape=(num_centroids,)) * (-jnp.inf), + descriptors=jnp.zeros(shape=(num_centroids, num_descriptors)), + centroids=jnp.array( + [ + [1.0, 1.0], + [2.0, 1.0], + [2.0, 2.0], + [1.0, 2.0], + ] + ), + spreads=jnp.full(shape=(num_centroids,), fill_value=jnp.inf), + ) + + # Insert a single solution with only one eval. 
+ + # create fake genotypes and scores to add + fake_genotypes = jnp.ones(shape=(1, genotype_size)) + # the solution gets one fitness and one descriptor. + fake_fitnesses = jnp.array([[0.0]]) + fake_descriptors = jnp.array([[[0.0, 1.0]]]) + fake_extra_scores: ExtraScores = {} + + # do an addition + repertoire = repertoire.add( + fake_genotypes, fake_descriptors, fake_fitnesses, fake_extra_scores + ) + + # check that the repertoire looks as expected + expected_genotypes = jnp.array( + [ + [1.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + ] + ) + expected_fitnesses = jnp.array([0.0, -jnp.inf, -jnp.inf, -jnp.inf]) + expected_descriptors = jnp.array( + [ + [1.0, 1.0], # Centroid coordinates. + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ] + ) + # Spread should be 0 since there's only one eval. + expected_spreads = jnp.array([0.0, jnp.inf, jnp.inf, jnp.inf]) + + # check values + pytest.assume(jnp.allclose(repertoire.genotypes, expected_genotypes, atol=1e-6)) + pytest.assume(jnp.allclose(repertoire.fitnesses, expected_fitnesses, atol=1e-6)) + pytest.assume(jnp.allclose(repertoire.descriptors, expected_descriptors, atol=1e-6)) + pytest.assume(jnp.allclose(repertoire.spreads, expected_spreads, atol=1e-6)) diff --git a/tests/core_test/mels_test.py b/tests/core_test/mels_test.py new file mode 100644 index 00000000..21f90517 --- /dev/null +++ b/tests/core_test/mels_test.py @@ -0,0 +1,156 @@ +"""Tests MAP-Elites Low-Spread implementation.""" + +import functools +from typing import Dict, Tuple + +import jax +import jax.numpy as jnp +import pytest + +from qdax import environments +from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids +from qdax.core.containers.mels_repertoire import MELSRepertoire +from qdax.core.emitters.mutation_operators import isoline_variation +from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.core.mels import 
MELS +from qdax.core.neuroevolution.buffers.buffer import QDTransition +from qdax.core.neuroevolution.networks.networks import MLP +from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs +from qdax.types import EnvState, Params, RNGKey + + +@pytest.mark.parametrize( + "env_name, batch_size", + [("walker2d_uni", 1), ("walker2d_uni", 10), ("hopper_uni", 10)], +) +def test_mels(env_name: str, batch_size: int) -> None: + batch_size = batch_size + env_name = env_name + num_samples = 5 + episode_length = 100 + num_iterations = 5 + seed = 42 + policy_hidden_layer_sizes = (64, 64) + num_init_cvt_samples = 1000 + num_centroids = 50 + min_bd = 0.0 + max_bd = 1.0 + + # Init environment + env = environments.create(env_name, episode_length=episode_length) + + # Init a random key + random_key = jax.random.PRNGKey(seed) + + # Init policy network + policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) + policy_network = MLP( + layer_sizes=policy_layer_sizes, + kernel_init=jax.nn.initializers.lecun_uniform(), + final_activation=jnp.tanh, + ) + + # Init population of controllers. There are batch_size controllers, and each + # controller will be evaluated num_samples times. 
+ random_key, subkey = jax.random.split(random_key) + keys = jax.random.split(subkey, num=batch_size) + fake_batch = jnp.zeros(shape=(batch_size, env.observation_size)) + init_variables = jax.vmap(policy_network.init)(keys, fake_batch) + + # Define the function to play a step with the policy in the environment + def play_step_fn( + env_state: EnvState, + policy_params: Params, + random_key: RNGKey, + ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: + """Play an environment step and return the updated state and the + transition.""" + + actions = policy_network.apply(policy_params, env_state.obs) + + state_desc = env_state.info["state_descriptor"] + next_state = env.step(env_state, actions) + + transition = QDTransition( + obs=env_state.obs, + next_obs=next_state.obs, + rewards=next_state.reward, + dones=next_state.done, + actions=actions, + truncations=next_state.info["truncation"], + state_desc=state_desc, + next_state_desc=next_state.info["state_descriptor"], + ) + + return next_state, policy_params, random_key, transition + + # Prepare the scoring function + bd_extraction_fn = environments.behavior_descriptor_extractor[env_name] + scoring_fn = functools.partial( + reset_based_scoring_function_brax_envs, + episode_length=episode_length, + play_reset_fn=env.reset, + play_step_fn=play_step_fn, + behavior_descriptor_extractor=bd_extraction_fn, + ) + + # Define emitter + variation_fn = functools.partial(isoline_variation, iso_sigma=0.05, line_sigma=0.1) + mixing_emitter = MixingEmitter( + mutation_fn=lambda x, y: (x, y), + variation_fn=variation_fn, + variation_percentage=1.0, + batch_size=batch_size, + ) + + # Get minimum reward value to make sure qd_score are positive + reward_offset = environments.reward_offset[env_name] + + # Define a metrics function + def metrics_fn(repertoire: MELSRepertoire) -> Dict: + # Get metrics + grid_empty = repertoire.fitnesses == -jnp.inf + qd_score = jnp.sum(repertoire.fitnesses, where=~grid_empty) + # Add offset for positive 
qd_score + qd_score += reward_offset * episode_length * jnp.sum(1.0 - grid_empty) + coverage = 100 * jnp.mean(1.0 - grid_empty) + max_fitness = jnp.max(repertoire.fitnesses) + + return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} + + # Instantiate ME-LS. + mels = MELS( + scoring_function=scoring_fn, + emitter=mixing_emitter, + metrics_function=metrics_fn, + num_samples=num_samples, + ) + + # Compute the centroids + centroids, random_key = compute_cvt_centroids( + num_descriptors=env.behavior_descriptor_length, + num_init_cvt_samples=num_init_cvt_samples, + num_centroids=num_centroids, + minval=min_bd, + maxval=max_bd, + random_key=random_key, + ) + + # Compute initial repertoire + repertoire, emitter_state, random_key = mels.init( + init_variables, centroids, random_key + ) + + # Run the algorithm + (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + mels.scan_update, + (repertoire, emitter_state, random_key), + (), + length=num_iterations, + ) + + pytest.assume(repertoire is not None) + + +if __name__ == "__main__": + test_mels(env_name="pointmaze", batch_size=10) From 64ea81d5aa7800e760418a49344cf04433d55ae1 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Mon, 9 Oct 2023 11:31:20 +0100 Subject: [PATCH 08/16] Update ci.yaml to accept PRs from outside repositories --- .github/workflows/ci.yaml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index fab4bb44..d1e0c060 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -3,7 +3,8 @@ name: ci on: push: branches: [main, develop] - pull_request: + pull_request_target: + types: [assigned, opened, synchronize, reopened] env: REGISTRY: ghcr.io @@ -25,6 +26,11 @@ jobs: build: needs: [setup] runs-on: ubuntu-latest + + permissions: + contents: read + packages: write + steps: - name: Checkout uses: actions/checkout@v2 From c111cea61928b04bf64ac6634290d88370fb08f4 Mon Sep 17 00:00:00 
2001 From: Bryon Tjanaka <38124174+btjanaka@users.noreply.github.com> Date: Mon, 9 Oct 2023 05:28:59 -0700 Subject: [PATCH 09/16] fix: Various spelling errors (#150) * Update map_elites.py distributed_map_elites.py mome_repertoire.py mapelites_repertoire.py and mapelites.ipynb --- examples/mapelites.ipynb | 2 +- qdax/core/containers/mapelites_repertoire.py | 8 ++++---- qdax/core/containers/mome_repertoire.py | 2 +- qdax/core/distributed_map_elites.py | 2 +- qdax/core/map_elites.py | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/mapelites.ipynb b/examples/mapelites.ipynb index 18728e73..49765438 100644 --- a/examples/mapelites.ipynb +++ b/examples/mapelites.ipynb @@ -173,7 +173,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Define the fonction to play a step with the policy in the environment\n", + "# Define the function to play a step with the policy in the environment\n", "def play_step_fn(\n", " env_state,\n", " policy_params,\n", diff --git a/qdax/core/containers/mapelites_repertoire.py b/qdax/core/containers/mapelites_repertoire.py index 1babb2c9..aed74c78 100644 --- a/qdax/core/containers/mapelites_repertoire.py +++ b/qdax/core/containers/mapelites_repertoire.py @@ -26,10 +26,10 @@ def compute_cvt_centroids( maxval: Union[float, List[float]], random_key: RNGKey, ) -> Tuple[jnp.ndarray, RNGKey]: - """Compute centroids for CVT tesselation. + """Compute centroids for CVT tessellation. Args: - num_descriptors: number od scalar descriptors + num_descriptors: number of scalar descriptors num_init_cvt_samples: number of sampled point to be sued for clustering to determine the centroids. The larger the number of centroids and the number of descriptors, the higher this value must be (e.g. 100000 for @@ -69,7 +69,7 @@ def compute_euclidean_centroids( minval: Union[float, List[float]], maxval: Union[float, List[float]], ) -> jnp.ndarray: - """Compute centroids for square Euclidean tesselation. 
+ """Compute centroids for square Euclidean tessellation. Args: grid_shape: number of centroids per BD dimension @@ -144,7 +144,7 @@ class MapElitesRepertoire(flax.struct.PyTreeNode): descriptors: an array that contains the descriptors of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids, num_descriptors). - centroids: an array the contains the centroids of the tesselation. The array + centroids: an array that contains the centroids of the tessellation. The array shape is (num_centroids, num_descriptors). """ diff --git a/qdax/core/containers/mome_repertoire.py b/qdax/core/containers/mome_repertoire.py index d1da327a..0e2b6d3e 100644 --- a/qdax/core/containers/mome_repertoire.py +++ b/qdax/core/containers/mome_repertoire.py @@ -377,7 +377,7 @@ def init( # type: ignore (batch_size, num_criteria) descriptors: descriptors of the initial genotypes of shape (batch_size, num_descriptors) - centroids: tesselation centroids of shape (batch_size, num_descriptors) + centroids: tessellation centroids of shape (batch_size, num_descriptors) pareto_front_max_length: maximum size of the pareto fronts extra_scores: unused extra_scores of the initial genotypes diff --git a/qdax/core/distributed_map_elites.py b/qdax/core/distributed_map_elites.py index 79b39f49..c8a1ea44 100644 --- a/qdax/core/distributed_map_elites.py +++ b/qdax/core/distributed_map_elites.py @@ -32,7 +32,7 @@ def init( Args: init_genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) - centroids: tesselation centroids of shape (batch_size, num_descriptors) + centroids: tessellation centroids of shape (batch_size, num_descriptors) random_key: a random key used for stochastic operations. Returns: diff --git a/qdax/core/map_elites.py b/qdax/core/map_elites.py index d306d11e..c71b0013 100644 --- a/qdax/core/map_elites.py +++ b/qdax/core/map_elites.py @@ -23,7 +23,7 @@ class MAPElites: """Core elements of the MAP-Elites algorithm. 
Note: Although very similar to the GeneticAlgorithm, we decided to keep the - MAPElites class independant of the GeneticAlgorithm class at the moment to keep + MAPElites class independent of the GeneticAlgorithm class at the moment to keep elements explicit. Args: @@ -64,7 +64,7 @@ def init( Args: init_genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) - centroids: tesselation centroids of shape (batch_size, num_descriptors) + centroids: tessellation centroids of shape (batch_size, num_descriptors) random_key: a random key used for stochastic operations. Returns: From df954c291171278bb149117ee6559f2c755c274e Mon Sep 17 00:00:00 2001 From: David Braun <2096055+DBraun@users.noreply.github.com> Date: Mon, 9 Oct 2023 11:23:20 -0400 Subject: [PATCH 10/16] fix: use brax.v1 and update requirements (#156) * Change brax for brax.v1 * update jax/jaxlib and flax dependencies --- .readthedocs.yaml | 2 +- dev.Dockerfile | 12 +++++------ qdax/__init__.py | 2 +- qdax/environments/__init__.py | 25 ++++++++++++----------- qdax/environments/base_wrappers.py | 4 ++-- qdax/environments/exploration_wrappers.py | 14 ++++++------- qdax/environments/humanoidtrap.py | 16 +++++++-------- qdax/environments/init_state_wrapper.py | 8 ++++---- qdax/environments/locomotion_wrappers.py | 10 ++++----- qdax/environments/pointmaze.py | 16 +++++++-------- qdax/environments/wrappers.py | 12 +++++------ requirements.txt | 20 +++++++++--------- setup.py | 12 +++++------ 13 files changed, 77 insertions(+), 76 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 7eec359d..d9f0965b 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -8,7 +8,7 @@ version: 2 build: os: ubuntu-20.04 tools: - python: "3.8" + python: "3.9" apt_packages: - swig diff --git a/dev.Dockerfile b/dev.Dockerfile index a6ac351f..458599db 100644 --- a/dev.Dockerfile +++ b/dev.Dockerfile @@ -1,4 +1,4 @@ -FROM mambaorg/micromamba:0.22.0 as conda +FROM mambaorg/micromamba:1.5.1 
as conda # Speed up the build, and avoid unnecessary writes to disk ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1 CONDA_DIR=/opt/conda @@ -16,7 +16,7 @@ RUN micromamba create -y --file /tmp/environment.yaml \ FROM python as test-image -ENV PATH=/opt/conda/envs/qdaxpy38/bin/:$PATH APP_FOLDER=/app +ENV PATH=/opt/conda/envs/qdaxpy39/bin/:$PATH APP_FOLDER=/app ENV PYTHONPATH=$APP_FOLDER:$PYTHONPATH COPY --from=conda /opt/conda/envs/. /opt/conda/envs/ @@ -25,8 +25,8 @@ COPY requirements-dev.txt ./ RUN pip install -r requirements-dev.txt -FROM nvidia/cuda:11.4.1-cudnn8-devel-ubuntu20.04 as cuda-image -ENV PATH=/opt/conda/envs/qdaxpy38/bin/:$PATH APP_FOLDER=/app +FROM nvidia/cuda:11.5.2-cudnn8-devel-ubuntu20.04 as cuda-image +ENV PATH=/opt/conda/envs/qdaxpy39/bin/:$PATH APP_FOLDER=/app ENV PYTHONPATH=$APP_FOLDER:$PYTHONPATH @@ -40,7 +40,7 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.0/targets/x86_64-linux/l ENV TZ=Europe/Paris RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone -RUN pip --no-cache-dir install jaxlib==0.3.15+cuda11.cudnn82 \ +RUN pip --no-cache-dir install jaxlib==0.4.16+cuda11.cudnn86 \ -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \ && rm -rf /tmp/* @@ -70,7 +70,7 @@ RUN apt-get update && \ libosmesa6-dev \ patchelf \ python3-opengl \ - python3-dev=3.8* \ + python3-dev=3.9* \ python3-pip \ screen \ sudo \ diff --git a/qdax/__init__.py b/qdax/__init__.py index b5fdc753..788da1fb 100644 --- a/qdax/__init__.py +++ b/qdax/__init__.py @@ -1 +1 @@ -__version__ = "0.2.2" +__version__ = "0.2.4" diff --git a/qdax/environments/__init__.py b/qdax/environments/__init__.py index 3e89df1a..be336a14 100644 --- a/qdax/environments/__init__.py +++ b/qdax/environments/__init__.py @@ -1,8 +1,9 @@ import functools from typing import Any, Callable, List, Optional, Union -import brax -import brax.envs +from brax.v1.envs import Env +from brax.v1.envs import _envs +from 
brax.v1.envs.wrappers import EpisodeWrapper, AutoResetWrapper, EvalWrapper, VectorWrapper from qdax.environments.base_wrappers import QDEnv, StateDescriptorResetWrapper from qdax.environments.bd_extractors import ( @@ -122,20 +123,20 @@ def create( fixed_init_state: bool = False, qdax_wrappers_kwargs: Optional[List] = None, **kwargs: Any, -) -> Union[brax.envs.env.Env, QDEnv]: +) -> Union[Env, QDEnv]: """Creates an Env with a specified brax system. Please use namespace to avoid confusion between this function and brax.envs.create. """ - if env_name in brax.envs._envs.keys(): - env = brax.envs._envs[env_name](legacy_spring=True, **kwargs) + if env_name in _envs.keys(): + env = _envs[env_name](legacy_spring=True, **kwargs) elif env_name in _qdax_envs.keys(): env = _qdax_envs[env_name](**kwargs) elif env_name in _qdax_custom_envs.keys(): base_env_name = _qdax_custom_envs[env_name]["env"] - if base_env_name in brax.envs._envs.keys(): - env = brax.envs._envs[base_env_name](legacy_spring=True, **kwargs) + if base_env_name in _envs.keys(): + env = _envs[base_env_name](legacy_spring=True, **kwargs) elif base_env_name in _qdax_envs.keys(): env = _qdax_envs[base_env_name](**kwargs) # type: ignore else: @@ -152,9 +153,9 @@ def create( env = wrapper(env, base_env_name, **kwargs) # type: ignore if episode_length is not None: - env = brax.envs.wrappers.EpisodeWrapper(env, episode_length, action_repeat) + env = EpisodeWrapper(env, episode_length, action_repeat) if batch_size: - env = brax.envs.wrappers.VectorWrapper(env, batch_size) + env = VectorWrapper(env, batch_size) if fixed_init_state: # retrieve the base env if env_name not in _qdax_custom_envs.keys(): @@ -162,17 +163,17 @@ def create( # wrap the env env = FixedInitialStateWrapper(env, base_env_name=base_env_name) # type: ignore if auto_reset: - env = brax.envs.wrappers.AutoResetWrapper(env) + env = AutoResetWrapper(env) if env_name in _qdax_custom_envs.keys(): env = StateDescriptorResetWrapper(env) if eval_metrics: - env 
= brax.envs.wrappers.EvalWrapper(env) + env = EvalWrapper(env) env = CompletedEvalWrapper(env) return env -def create_fn(env_name: str, **kwargs: Any) -> Callable[..., brax.envs.Env]: +def create_fn(env_name: str, **kwargs: Any) -> Callable[..., Env]: """Returns a function that when called, creates an Env. Please use namespace to avoid confusion between this function and brax.envs.create_fn. diff --git a/qdax/environments/base_wrappers.py b/qdax/environments/base_wrappers.py index 69ef782f..6f317e7f 100644 --- a/qdax/environments/base_wrappers.py +++ b/qdax/environments/base_wrappers.py @@ -1,8 +1,8 @@ from abc import abstractmethod from typing import Any, List, Tuple -from brax import jumpy as jp -from brax.envs.env import Env, State +from brax.v1 import jumpy as jp +from brax.v1.envs import Env, State class QDEnv(Env): diff --git a/qdax/environments/exploration_wrappers.py b/qdax/environments/exploration_wrappers.py index 33428994..fbffad5f 100644 --- a/qdax/environments/exploration_wrappers.py +++ b/qdax/environments/exploration_wrappers.py @@ -1,9 +1,9 @@ import warnings -import brax +import brax.v1 as brax import jax.numpy as jnp -from brax import jumpy as jp -from brax.envs import State, env +from brax.v1 import jumpy as jp +from brax.v1.envs import Env, State, Wrapper from google.protobuf import text_format # type: ignore from qdax.environments.locomotion_wrappers import COG_NAMES @@ -103,7 +103,7 @@ } -class TrapWrapper(env.Wrapper): +class TrapWrapper(Wrapper): """Wraps gym environments to add a Trap in the environment. 
Utilisation is simple: create an environment with Brax, pass @@ -143,7 +143,7 @@ class TrapWrapper(env.Wrapper): """ - def __init__(self, env: env.Env, env_name: str) -> None: + def __init__(self, env: Env, env_name: str) -> None: if ( env_name not in ENV_SYSTEM_CONFIG.keys() or env_name not in COG_NAMES.keys() @@ -323,7 +323,7 @@ def step(self, state: State, action: jp.ndarray) -> State: } -class MazeWrapper(env.Wrapper): +class MazeWrapper(Wrapper): """Wraps gym environments to add a maze in the environment and a new reward (distance to the goal). @@ -364,7 +364,7 @@ class MazeWrapper(env.Wrapper): """ - def __init__(self, env: env.Env, env_name: str) -> None: + def __init__(self, env: Env, env_name: str) -> None: if ( env_name not in ENV_SYSTEM_CONFIG.keys() or env_name not in COG_NAMES.keys() diff --git a/qdax/environments/humanoidtrap.py b/qdax/environments/humanoidtrap.py index 2d0773a9..b24226b9 100644 --- a/qdax/environments/humanoidtrap.py +++ b/qdax/environments/humanoidtrap.py @@ -3,10 +3,10 @@ from typing import Any, Dict -import brax -from brax import jumpy as jp -from brax.envs import env -from brax.physics import bodies +import brax.v1 as brax +from brax.v1 import jumpy as jp +from brax.v1.envs import Env, State +from brax.v1.physics import bodies TRAP_CONFIG = """bodies { name: "Trap" @@ -58,7 +58,7 @@ """ -class HumanoidTrap(env.Env): +class HumanoidTrap(Env): """Trains a humanoid to run in the +x direction. RMQ: uses legacy spring from Brax. 
@@ -76,7 +76,7 @@ def __init__(self, **kwargs: Dict[str, Any]) -> None: self.inertia = body.inertia self.inertia_matrix = jp.array([jp.diag(a) for a in self.inertia]) - def reset(self, rng: jp.ndarray) -> env.State: + def reset(self, rng: jp.ndarray) -> State: """Resets the environment to an initial state.""" rng, rng1, rng2 = jp.random_split(rng, 3) qpos = self.sys.default_angle() + jp.random_uniform( @@ -93,9 +93,9 @@ def reset(self, rng: jp.ndarray) -> env.State: "reward_alive": zero, "reward_impact": zero, } - return env.State(qp, obs, reward, done, metrics) + return State(qp, obs, reward, done, metrics) - def step(self, state: env.State, action: jp.ndarray) -> env.State: + def step(self, state: State, action: jp.ndarray) -> State: """Run one timestep of the environment's dynamics.""" qp, info = self.sys.step(state.qp, action) obs = self._get_obs(qp, info, action) diff --git a/qdax/environments/init_state_wrapper.py b/qdax/environments/init_state_wrapper.py index cdeb9517..be395e50 100644 --- a/qdax/environments/init_state_wrapper.py +++ b/qdax/environments/init_state_wrapper.py @@ -1,8 +1,8 @@ from typing import Callable, Optional -import brax -from brax import jumpy as jp -from brax.envs import Env, State, Wrapper +import brax.v1 as brax +from brax.v1 import jumpy as jp +from brax.v1.envs import Env, State, Wrapper class FixedInitialStateWrapper(Wrapper): @@ -51,7 +51,7 @@ def reset(self, rng: jp.ndarray) -> State: # Run the default reset method of parent environment state = self.env.reset(rng) - # Compute new initial positions and velicities + # Compute new initial positions and velocities qpos = self.sys.default_angle() qvel = jp.zeros((self.sys.num_joint_dof,)) diff --git a/qdax/environments/locomotion_wrappers.py b/qdax/environments/locomotion_wrappers.py index 022662e9..a727479e 100644 --- a/qdax/environments/locomotion_wrappers.py +++ b/qdax/environments/locomotion_wrappers.py @@ -1,11 +1,11 @@ from typing import Any, List, Optional, Sequence, Tuple 
import jax.numpy as jnp -from brax import jumpy as jp -from brax.envs import Env, State, Wrapper -from brax.physics import config_pb2 -from brax.physics.base import QP, Info -from brax.physics.system import System +from brax.v1 import jumpy as jp +from brax.v1.envs import Env, State, Wrapper +from brax.v1.physics import config_pb2 +from brax.v1.physics.base import QP, Info +from brax.v1.physics.system import System from qdax.environments.base_wrappers import QDEnv diff --git a/qdax/environments/pointmaze.py b/qdax/environments/pointmaze.py index 30854283..b5f86ef5 100644 --- a/qdax/environments/pointmaze.py +++ b/qdax/environments/pointmaze.py @@ -1,18 +1,18 @@ from typing import Any, Dict, List, Tuple, Union -import brax -from brax import jumpy as jp -from brax.envs import env +import brax.v1 as brax +from brax.v1 import jumpy as jp +from brax.v1.envs import Env, State -class PointMaze(env.Env): +class PointMaze(Env): """Jax/Brax implementation of the PointMaze. Highly inspired from the old python implementation of the PointMaze. In order to stay in the Brax API, I will use a fake QP at several moment of the implementation. This enable to - use the brax.envs.env.State from Brax. To avoid this, + use the brax.envs.State from Brax. To avoid this, it would be good to ask Brax to enlarge a bit their API for environments that are not physically simulated. 
""" @@ -103,7 +103,7 @@ def action_size(self) -> int: """The size of the observation vector returned in step and reset.""" return 2 - def reset(self, rng: jp.ndarray) -> env.State: + def reset(self, rng: jp.ndarray) -> State: """Resets the environment to an initial state.""" rng, rng1, rng2 = jp.random_split(rng, 3) # get initial position - reproduce the old implementation @@ -117,9 +117,9 @@ def reset(self, rng: jp.ndarray) -> env.State: metrics: Dict = {} # managing state descriptor by our own info_init = {"state_descriptor": obs_init} - return env.State(fake_qp, obs_init, reward, done, metrics, info_init) + return State(fake_qp, obs_init, reward, done, metrics, info_init) - def step(self, state: env.State, action: jp.ndarray) -> env.State: + def step(self, state: State, action: jp.ndarray) -> State: """Run one timestep of the environment's dynamics.""" # clip action taken diff --git a/qdax/environments/wrappers.py b/qdax/environments/wrappers.py index 45be0787..274b9073 100644 --- a/qdax/environments/wrappers.py +++ b/qdax/environments/wrappers.py @@ -1,9 +1,9 @@ from typing import Dict -import brax.envs +from brax.v1.envs import State, Wrapper import flax.struct import jax -from brax import jumpy as jp +from brax.v1 import jumpy as jp class CompletedEvalMetrics(flax.struct.PyTreeNode): @@ -13,12 +13,12 @@ class CompletedEvalMetrics(flax.struct.PyTreeNode): completed_episodes_steps: jp.ndarray -class CompletedEvalWrapper(brax.envs.env.Wrapper): +class CompletedEvalWrapper(Wrapper): """Brax env with eval metrics for completed episodes.""" STATE_INFO_KEY = "completed_eval_metrics" - def reset(self, rng: jp.ndarray) -> brax.envs.env.State: + def reset(self, rng: jp.ndarray) -> State: reset_state = self.env.reset(rng) reset_state.metrics["reward"] = reset_state.reward eval_metrics = CompletedEvalMetrics( @@ -35,8 +35,8 @@ def reset(self, rng: jp.ndarray) -> brax.envs.env.State: return reset_state def step( - self, state: brax.envs.env.State, action: jp.ndarray - ) 
-> brax.envs.env.State: + self, state: State, action: jp.ndarray + ) -> State: state_metrics = state.info[self.STATE_INFO_KEY] if not isinstance(state_metrics, CompletedEvalMetrics): raise ValueError(f"Incorrect type for state_metrics: {type(state_metrics)}") diff --git a/requirements.txt b/requirements.txt index 718d6213..31008a59 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,19 +1,19 @@ absl-py==1.0.0 -brax==0.0.15 -chex==0.1.5 +brax==0.9.2 +chex==0.1.83 dm-haiku==0.0.9 -flax==0.6.0 -gym==0.23.1 +flax==0.7.4 +gym==0.26.2 ipython -jax==0.3.17 -jaxlib==0.3.15 -jumanji==0.1.3 +jax==0.4.16 +jaxlib==0.4.16 +jumanji==0.3.1 jupyter -numpy==1.22.3 -optax==0.1.4 +numpy==1.24.1 +optax==0.1.7 protobuf==3.19.4 scikit-learn==1.0.2 scipy==1.8.0 seaborn==0.11.2 tensorflow-probability==0.15.0 -typing-extensions==4.3.0 +typing-extensions==4.3.0 \ No newline at end of file diff --git a/setup.py b/setup.py index a71f3174..f597ad32 100644 --- a/setup.py +++ b/setup.py @@ -22,15 +22,15 @@ long_description_content_type="text/markdown", install_requires=[ "absl-py>=1.0.0", - "jax>=0.3.16", - "jaxlib>=0.3.15", # necessary to build the doc atm + "jax>=0.4.16", + "jaxlib>=0.4.16", # necessary to build the doc atm "jinja2<3.1.0", - "jumanji>=0.1.3", - "flax>=0.6, <0.6.2", - "brax>=0.0.15", + "jumanji>=0.3.1", + "flax>=0.7.4", + "brax>=0.9.2", "gym>=0.23.1", "numpy>=1.22.3", - "optax>=0.1, <0.1.5", + "optax>=0.1.7", "scikit-learn>=1.0.2", "scipy>=1.8.0", ], From a63df5b1cabcb3dd22ca112bc8b193bc6e6e7b64 Mon Sep 17 00:00:00 2001 From: Manon Flageat <61653012+manon-but-yes@users.noreply.github.com> Date: Mon, 27 Nov 2023 18:07:56 +0000 Subject: [PATCH 11/16] Add multiple variants of sampling extractors (#158) --- qdax/utils/sampling.py | 197 ++++++++++++++++++++++++++++-- tests/utils_test/sampling_test.py | 175 +++++++++++++++++++++----- 2 files changed, 334 insertions(+), 38 deletions(-) diff --git a/qdax/utils/sampling.py b/qdax/utils/sampling.py index a25e190f..bf5c1ae4 100644 
--- a/qdax/utils/sampling.py +++ b/qdax/utils/sampling.py @@ -8,6 +8,91 @@ from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +@jax.jit +def average(quantities: jnp.ndarray) -> jnp.ndarray: + """Default expectation extractor using average.""" + return jnp.average(quantities, axis=1) + + +@jax.jit +def median(quantities: jnp.ndarray) -> jnp.ndarray: + """Alternative expectation extractor using median. + More robust to outliers than average.""" + return jnp.median(quantities, axis=1) + + +@jax.jit +def mode(quantities: jnp.ndarray) -> jnp.ndarray: + """Alternative expectation extractor using mode. + More robust to outliers than average. + WARNING: for multidimensional objects such as descriptor, do + dimension-wise selection. + """ + + def _mode(quantity: jnp.ndarray) -> jnp.ndarray: + + # Ensure correct dimensions for both single and multi-dimension + quantity = jnp.reshape(quantity, (quantity.shape[0], -1)) + + # Dimension-wise voting in case of multi-dimension + def _dim_mode(dim_quantity: jnp.ndarray) -> jnp.ndarray: + unique_vals, counts = jnp.unique( + dim_quantity, return_counts=True, size=dim_quantity.size + ) + return unique_vals[jnp.argmax(counts)] + + # vmap over dimensions + return jnp.squeeze(jax.vmap(_dim_mode)(jnp.transpose(quantity))) + + # vmap over individuals + return jax.vmap(_mode)(quantities) + + +@jax.jit +def closest(quantities: jnp.ndarray) -> jnp.ndarray: + """Alternative expectation extractor selecting individual + that has the minimum distance to all other individuals. This + is an approximation of the geometric median. 
+ More robust to outliers than average.""" + + def _closest(values: jnp.ndarray) -> jnp.ndarray: + def distance(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + return jnp.sqrt(jnp.sum(jnp.square(x - y))) + + distances = jax.vmap( + jax.vmap(partial(distance), in_axes=(None, 0)), in_axes=(0, None) + )(values, values) + return values[jnp.argmin(jnp.mean(distances, axis=0))] + + return jax.vmap(_closest)(quantities) + + +@jax.jit +def std(quantities: jnp.ndarray) -> jnp.ndarray: + """Default reproducibility extractor using standard deviation.""" + return jnp.std(quantities, axis=1) + + +@jax.jit +def mad(quantities: jnp.ndarray) -> jnp.ndarray: + """Alternative reproducibility extractor using Median Absolute Deviation. + More robust to outliers than standard deviation.""" + num_samples = quantities.shape[1] + median = jnp.repeat( + jnp.median(quantities, axis=1, keepdims=True), num_samples, axis=1 + ) + return jnp.median(jnp.abs(quantities - median), axis=1) + + +@jax.jit +def iqr(quantities: jnp.ndarray) -> jnp.ndarray: + """Alternative reproducibility extractor using Inter-Quartile Range. + More robust to outliers than standard deviation.""" + q1 = jnp.quantile(quantities, 0.25, axis=1) + q4 = jnp.quantile(quantities, 0.75, axis=1) + return q4 - q1 + + @partial(jax.jit, static_argnames=("num_samples",)) def dummy_extra_scores_extractor( extra_scores: ExtraScores, @@ -89,6 +174,8 @@ def multi_sample_scoring_function( "scoring_fn", "num_samples", "extra_scores_extractor", + "fitness_extractor", + "descriptor_extractor", ), ) def sampling( @@ -102,11 +189,14 @@ def sampling( extra_scores_extractor: Callable[ [ExtraScores, int], ExtraScores ] = dummy_extra_scores_extractor, + fitness_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average, + descriptor_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average, ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: """Wrap scoring_function to perform sampling. 
- This function averages the fitnesses and descriptors for each individual - over `num_samples` evaluations. + This function returns the expected fitnesses and descriptors for each + individual over `num_samples` evaluations using the provided extractor + function for the fitness and the descriptor. Args: policies_params: policies to evaluate @@ -115,12 +205,17 @@ def sampling( num_samples: number of samples to generate for each individual extra_scores_extractor: function to extract the extra_scores from multiple samples of the same policy. + fitness_extractor: function to extract the fitness expectation from + multiple samples of the same policy. + descriptor_extractor: function to extract the descriptor expectation + from multiple samples of the same policy. Returns: - The average fitness and descriptor of the individuals - The extra_score extract from samples with extra_scores_extractor + The expected fitnesses, descriptors and extra_scores of the individuals A new random key """ + + # Perform sampling ( all_fitnesses, all_descriptors, @@ -130,11 +225,95 @@ def sampling( policies_params, random_key, scoring_fn, num_samples ) - # average results - descriptors = jnp.average(all_descriptors, axis=1) - fitnesses = jnp.average(all_fitnesses, axis=1) - - # extract extra scores and add number of evaluations to it + # Extract final scores + descriptors = descriptor_extractor(all_descriptors) + fitnesses = fitness_extractor(all_fitnesses) extra_scores = extra_scores_extractor(all_extra_scores, num_samples) return fitnesses, descriptors, extra_scores, random_key + + +@partial( + jax.jit, + static_argnames=( + "scoring_fn", + "num_samples", + "extra_scores_extractor", + "fitness_extractor", + "descriptor_extractor", + "fitness_reproducibility_extractor", + "descriptor_reproducibility_extractor", + ), +) +def sampling_reproducibility( + policies_params: Genotype, + random_key: RNGKey, + scoring_fn: Callable[ + [Genotype, RNGKey], + Tuple[Fitness, Descriptor, ExtraScores,
RNGKey], + ], + num_samples: int, + extra_scores_extractor: Callable[ + [ExtraScores, int], ExtraScores + ] = dummy_extra_scores_extractor, + fitness_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average, + descriptor_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average, + fitness_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray] = std, + descriptor_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray] = std, +) -> Tuple[Fitness, Descriptor, ExtraScores, Fitness, Descriptor, RNGKey]: + """Wrap scoring_function to perform sampling and compute the + expectation and reproducibility. + + This function returns the reproducibility of fitnesses and descriptors for each + individual over `num_samples` evaluations using the provided extractor + function for the fitness and the descriptor. + + Args: + policies_params: policies to evaluate + random_key: JAX random key + scoring_fn: scoring function used for evaluation + num_samples: number of samples to generate for each individual + extra_scores_extractor: function to extract the extra_scores from + multiple samples of the same policy. + fitness_extractor: function to extract the fitness expectation from + multiple samples of the same policy. + descriptor_extractor: function to extract the descriptor expectation + from multiple samples of the same policy. + fitness_reproducibility_extractor: function to extract the fitness + reproducibility from multiple samples of the same policy. + descriptor_reproducibility_extractor: function to extract the descriptor + reproducibility from multiple samples of the same policy.
+ + Returns: + The expected fitnesses, descriptors and extra_scores of the individuals + The fitnesses and descriptors reproducibility of the individuals + A new random key + """ + + # Perform sampling + ( + all_fitnesses, + all_descriptors, + all_extra_scores, + random_key, + ) = multi_sample_scoring_function( + policies_params, random_key, scoring_fn, num_samples + ) + + # Extract final scores + descriptors = descriptor_extractor(all_descriptors) + fitnesses = fitness_extractor(all_fitnesses) + extra_scores = extra_scores_extractor(all_extra_scores, num_samples) + + # Extract reproducibility + descriptors_reproducibility = descriptor_reproducibility_extractor(all_descriptors) + fitnesses_reproducibility = fitness_reproducibility_extractor(all_fitnesses) + + return ( + fitnesses, + descriptors, + extra_scores, + fitnesses_reproducibility, + descriptors_reproducibility, + random_key, + ) diff --git a/tests/utils_test/sampling_test.py b/tests/utils_test/sampling_test.py index a7b2d15d..6ce6cbe9 100644 --- a/tests/utils_test/sampling_test.py +++ b/tests/utils_test/sampling_test.py @@ -1,5 +1,5 @@ import functools -from typing import Tuple +from typing import Callable, Tuple import jax import jax.numpy as jnp @@ -10,7 +10,17 @@ from qdax.core.neuroevolution.networks.networks import MLP from qdax.tasks.brax_envs import scoring_function_brax_envs from qdax.types import EnvState, Params, RNGKey -from qdax.utils.sampling import sampling +from qdax.utils.sampling import ( + average, + closest, + iqr, + mad, + median, + mode, + sampling, + sampling_reproducibility, + std, +) def test_sampling() -> None: @@ -74,7 +84,7 @@ def play_step_fn( keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=1, axis=0) init_states = reset_fn(keys) - # Compare scoring against perforing a single sample + # Create the scoring function bd_extraction_fn = environments.behavior_descriptor_extractor[env_name] scoring_fn = functools.partial( scoring_function_brax_envs, @@ -83,34 +93,141 @@ def 
play_step_fn( play_step_fn=play_step_fn, behavior_descriptor_extractor=bd_extraction_fn, ) - scoring_1_sample_fn = functools.partial( - sampling, - scoring_fn=scoring_fn, - num_samples=1, - ) - # Evaluate individuals using the scoring functions - fitnesses, descriptors, _, _ = scoring_fn(init_variables, random_key) - sample_fitnesses, sample_descriptors, _, _ = scoring_1_sample_fn( - init_variables, random_key - ) + # Test function for different extractors + def sampling_test( + fitness_extractor: Callable[[jnp.ndarray], jnp.ndarray], + descriptor_extractor: Callable[[jnp.ndarray], jnp.ndarray], + ) -> None: + + # Compare scoring against perforing a single sample + scoring_1_sample_fn = functools.partial( + sampling, + scoring_fn=scoring_fn, + num_samples=1, + fitness_extractor=fitness_extractor, + descriptor_extractor=descriptor_extractor, + ) - # Compare - pytest.assume(jnp.allclose(descriptors, sample_descriptors, rtol=1e-05, atol=1e-08)) - pytest.assume(jnp.allclose(fitnesses, sample_fitnesses, rtol=1e-05, atol=1e-08)) + # Evaluate individuals using the scoring functions + fitnesses, descriptors, _, _ = scoring_fn(init_variables, random_key) + sample_fitnesses, sample_descriptors, _, _ = scoring_1_sample_fn( + init_variables, random_key + ) - # Compare scoring against perforing multiple samples - scoring_multi_sample_fn = functools.partial( - sampling, - scoring_fn=scoring_fn, - num_samples=sample_number, - ) + # Compare + pytest.assume( + jnp.allclose(descriptors, sample_descriptors, rtol=1e-05, atol=1e-08) + ) + pytest.assume(jnp.allclose(fitnesses, sample_fitnesses, rtol=1e-05, atol=1e-08)) + + # Compare scoring against perforing multiple samples + scoring_multi_sample_fn = functools.partial( + sampling, + scoring_fn=scoring_fn, + num_samples=sample_number, + fitness_extractor=fitness_extractor, + descriptor_extractor=descriptor_extractor, + ) - # Evaluate individuals using the scoring functions - sample_fitnesses, sample_descriptors, _, _ = 
scoring_multi_sample_fn( - init_variables, random_key - ) + # Evaluate individuals using the scoring functions + sample_fitnesses, sample_descriptors, _, _ = scoring_multi_sample_fn( + init_variables, random_key + ) + + # Compare + pytest.assume( + jnp.allclose(descriptors, sample_descriptors, rtol=1e-05, atol=1e-08) + ) + pytest.assume(jnp.allclose(fitnesses, sample_fitnesses, rtol=1e-05, atol=1e-08)) + + # Call the test for each type of extractor + sampling_test(average, average) + sampling_test(median, median) + sampling_test(mode, mode) + sampling_test(closest, closest) + + # Test function for different reproducibility extractors + def sampling_reproducibility_test( + fitness_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray], + descriptor_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray], + ) -> None: + + # Compare scoring against perforing a single sample + scoring_1_sample_fn = functools.partial( + sampling_reproducibility, + scoring_fn=scoring_fn, + num_samples=1, + fitness_reproducibility_extractor=fitness_reproducibility_extractor, + descriptor_reproducibility_extractor=descriptor_reproducibility_extractor, + ) + + # Evaluate individuals using the scoring functions + ( + _, + _, + _, + fitnesses_reproducibility, + descriptors_reproducibility, + _, + ) = scoring_1_sample_fn(init_variables, random_key) + + # Compare - all reproducibility should be 0 + pytest.assume( + jnp.allclose( + fitnesses_reproducibility, + jnp.zeros_like(fitnesses_reproducibility), + rtol=1e-05, + atol=1e-05, + ) + ) + pytest.assume( + jnp.allclose( + descriptors_reproducibility, + jnp.zeros_like(descriptors_reproducibility), + rtol=1e-05, + atol=1e-05, + ) + ) + + # Compare scoring against perforing multiple samples + scoring_multi_sample_fn = functools.partial( + sampling_reproducibility, + scoring_fn=scoring_fn, + num_samples=sample_number, + fitness_reproducibility_extractor=fitness_reproducibility_extractor, + 
descriptor_reproducibility_extractor=descriptor_reproducibility_extractor, + ) + + # Evaluate individuals using the scoring functions + ( + _, + _, + _, + fitnesses_reproducibility, + descriptors_reproducibility, + _, + ) = scoring_multi_sample_fn(init_variables, random_key) + + # Compare - all reproducibility should be 0 + pytest.assume( + jnp.allclose( + fitnesses_reproducibility, + jnp.zeros_like(fitnesses_reproducibility), + rtol=1e-05, + atol=1e-05, + ) + ) + pytest.assume( + jnp.allclose( + descriptors_reproducibility, + jnp.zeros_like(descriptors_reproducibility), + rtol=1e-05, + atol=1e-05, + ) + ) - # Compare - pytest.assume(jnp.allclose(descriptors, sample_descriptors, rtol=1e-05, atol=1e-08)) - pytest.assume(jnp.allclose(fitnesses, sample_fitnesses, rtol=1e-05, atol=1e-08)) + # Call the test for each type of extractor + sampling_reproducibility_test(std, std) + sampling_reproducibility_test(mad, mad) + sampling_reproducibility_test(iqr, iqr) From aaa3ece03ee91068c3bf2bdaae2be05cf60db637 Mon Sep 17 00:00:00 2001 From: Felix Chalumeau Date: Tue, 28 Nov 2023 06:31:08 +0200 Subject: [PATCH 12/16] fix(readme): update the paper to cite (#161) * fix(readme): update the paper to cite * Fix python version and dependencies for CI --- .pre-commit-config.yaml | 4 ++-- .readthedocs.yaml | 2 +- README.md | 12 +++++++----- dev.Dockerfile | 12 ++++++------ docs/installation.md | 2 +- environment.yaml | 4 ++-- requirements.txt | 20 ++++++++++---------- tool.Dockerfile | 2 +- 8 files changed, 30 insertions(+), 28 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index af8f2bc2..a9329a64 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,8 +8,8 @@ repos: rev: 22.3.0 hooks: - id: black - language_version: python3.8 - args: ["--target-version", "py38"] + language_version: python3.9 + args: ["--target-version", "py39"] - repo: https://github.com/PyCQA/flake8 rev: 3.8.4 hooks: diff --git a/.readthedocs.yaml b/.readthedocs.yaml 
index 7eec359d..d9f0965b 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -8,7 +8,7 @@ version: 2 build: os: ubuntu-20.04 tools: - python: "3.8" + python: "3.9" apt_packages: - swig diff --git a/README.md b/README.md index 8b321c6b..249de523 100644 --- a/README.md +++ b/README.md @@ -164,11 +164,13 @@ Issues and contributions are welcome. Please refer to the [contribution guide](h ## Citing QDax If you use QDax in your research and want to cite it in your work, please use: ``` -@article{lim2022accelerated, - title={Accelerated Quality-Diversity for Robotics through Massive Parallelism}, - author={Lim, Bryan and Allard, Maxime and Grillotti, Luca and Cully, Antoine}, - journal={arXiv preprint arXiv:2202.01258}, - year={2022} +@misc{chalumeau2023qdax, + title={QDax: A Library for Quality-Diversity and Population-based Algorithms with Hardware Acceleration}, + author={Felix Chalumeau and Bryan Lim and Raphael Boige and Maxime Allard and Luca Grillotti and Manon Flageat and Valentin Macé and Arthur Flajolet and Thomas Pierrot and Antoine Cully}, + year={2023}, + eprint={2308.03665}, + archivePrefix={arXiv}, + primaryClass={cs.AI} } ``` diff --git a/dev.Dockerfile b/dev.Dockerfile index a6ac351f..458599db 100644 --- a/dev.Dockerfile +++ b/dev.Dockerfile @@ -1,4 +1,4 @@ -FROM mambaorg/micromamba:0.22.0 as conda +FROM mambaorg/micromamba:1.5.1 as conda # Speed up the build, and avoid unnecessary writes to disk ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1 CONDA_DIR=/opt/conda @@ -16,7 +16,7 @@ RUN micromamba create -y --file /tmp/environment.yaml \ FROM python as test-image -ENV PATH=/opt/conda/envs/qdaxpy38/bin/:$PATH APP_FOLDER=/app +ENV PATH=/opt/conda/envs/qdaxpy39/bin/:$PATH APP_FOLDER=/app ENV PYTHONPATH=$APP_FOLDER:$PYTHONPATH COPY --from=conda /opt/conda/envs/. 
/opt/conda/envs/ @@ -25,8 +25,8 @@ COPY requirements-dev.txt ./ RUN pip install -r requirements-dev.txt -FROM nvidia/cuda:11.4.1-cudnn8-devel-ubuntu20.04 as cuda-image -ENV PATH=/opt/conda/envs/qdaxpy38/bin/:$PATH APP_FOLDER=/app +FROM nvidia/cuda:11.5.2-cudnn8-devel-ubuntu20.04 as cuda-image +ENV PATH=/opt/conda/envs/qdaxpy39/bin/:$PATH APP_FOLDER=/app ENV PYTHONPATH=$APP_FOLDER:$PYTHONPATH @@ -40,7 +40,7 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.0/targets/x86_64-linux/l ENV TZ=Europe/Paris RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone -RUN pip --no-cache-dir install jaxlib==0.3.15+cuda11.cudnn82 \ +RUN pip --no-cache-dir install jaxlib==0.4.16+cuda11.cudnn86 \ -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \ && rm -rf /tmp/* @@ -70,7 +70,7 @@ RUN apt-get update && \ libosmesa6-dev \ patchelf \ python3-opengl \ - python3-dev=3.8* \ + python3-dev=3.9* \ python3-pip \ screen \ sudo \ diff --git a/docs/installation.md b/docs/installation.md index 319a88c7..10e45ea1 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -139,7 +139,7 @@ git clone git@github.com:adaptive-intelligent-robotics/QDax.git 2. Activate the environment and manually install the package qdax ```zsh - conda activate qdaxpy38 + conda activate qdaxpy39 pip install -e . 
``` diff --git a/environment.yaml b/environment.yaml index e46c034e..0ddf80d5 100644 --- a/environment.yaml +++ b/environment.yaml @@ -1,9 +1,9 @@ -name: qdaxpy38 +name: qdaxpy39 channels: - defaults - conda-forge dependencies: -- python=3.8 +- python=3.9 - pip>=20.3.3 - conda>=4.9.2 - pip: diff --git a/requirements.txt b/requirements.txt index 16c91bc3..40626724 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,19 +1,19 @@ absl-py==1.0.0 -brax==0.0.15 -chex==0.1.5 -dm-haiku==0.0.5 -flax==0.6.0 -gym==0.23.1 +brax==0.1.2 +chex==0.1.83 +dm-haiku==0.0.10 +flax==0.7.4 +gym==0.26.2 ipython -jax==0.3.17 -jaxlib==0.3.15 +jax==0.4.16 +jaxlib==0.4.16 jumanji==0.1.3 jupyter -numpy==1.22.3 -optax==0.1.4 +numpy==1.24.1 +optax==0.1.7 protobuf==3.19.4 scikit-learn==1.0.2 scipy==1.8.0 seaborn==0.11.2 -tensorflow-probability==0.15.0 +tensorflow-probability==0.19.0 typing-extensions==4.3.0 diff --git a/tool.Dockerfile b/tool.Dockerfile index 2def88da..10b15b02 100644 --- a/tool.Dockerfile +++ b/tool.Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.8.13-slim +FROM python:3.9.18-slim ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1 ENV PIPENV_VENV_IN_PROJECT=true PIP_NO_CACHE_DIR=false PIP_DISABLE_PIP_VERSION_CHECK=1 From 68a41d37c44eb5e1afb3e266702d6874e8bff3e4 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Tue, 28 Nov 2023 21:42:26 +0900 Subject: [PATCH 13/16] update ci --- .github/workflows/ci.yaml | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d1e0c060..7b08ab6c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -33,7 +33,14 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 + if: ${{ github.event_name == 'push' }} + + - name: Checkout + uses: actions/checkout@v3 + if: ${{ github.event_name == 'pull_request_target' }} + with: + ref: "${{ github.event.pull_request.merge_commit_sha 
}}" - name: Set up Docker Buildx id: buildx @@ -99,7 +106,14 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 + if: ${{ github.event_name == 'push' }} + + - name: Checkout + uses: actions/checkout@v3 + if: ${{ github.event_name == 'pull_request_target' }} + with: + ref: "${{ github.event.pull_request.merge_commit_sha }}" - name: Run pre-commits run: | @@ -117,7 +131,14 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 + if: ${{ github.event_name == 'push' }} + + - name: Checkout + uses: actions/checkout@v3 + if: ${{ github.event_name == 'pull_request_target' }} + with: + ref: "${{ github.event.pull_request.merge_commit_sha }}" - name: Run pytests run: | From a217319c553d52d18c714de75edb6d15e0d8e339 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Tue, 28 Nov 2023 22:25:28 +0900 Subject: [PATCH 14/16] update ci --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 7b08ab6c..83c688df 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -138,7 +138,7 @@ jobs: uses: actions/checkout@v3 if: ${{ github.event_name == 'pull_request_target' }} with: - ref: "${{ github.event.pull_request.merge_commit_sha }}" + ref: ${{ github.event.pull_request.head.sha }} - name: Run pytests run: | From dd6d71d0e4f14ac3498fdb2aa182cdd25d25c845 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Wed, 29 Nov 2023 19:26:38 +0900 Subject: [PATCH 15/16] update ci with new checkout test --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 83c688df..3723e950 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -138,7 +138,7 @@ jobs: uses: actions/checkout@v3 if: ${{ github.event_name == 'pull_request_target' }} with: - ref: ${{ github.event.pull_request.head.sha }} + ref: 
"refs/pull/${{ github.event.number }}/merge" - name: Run pytests run: | From a190a626f8a2f1276c7377ec03347abefd9d1bbd Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Thu, 30 Nov 2023 20:06:59 +0900 Subject: [PATCH 16/16] update README for CI test --- .github/workflows/ci.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 3723e950..33a1fd93 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -40,7 +40,7 @@ jobs: uses: actions/checkout@v3 if: ${{ github.event_name == 'pull_request_target' }} with: - ref: "${{ github.event.pull_request.merge_commit_sha }}" + ref: ${{ github.event.pull_request.head.sha }} - name: Set up Docker Buildx id: buildx @@ -113,7 +113,7 @@ jobs: uses: actions/checkout@v3 if: ${{ github.event_name == 'pull_request_target' }} with: - ref: "${{ github.event.pull_request.merge_commit_sha }}" + ref: ${{ github.event.pull_request.head.sha }} - name: Run pre-commits run: | @@ -138,7 +138,7 @@ jobs: uses: actions/checkout@v3 if: ${{ github.event_name == 'pull_request_target' }} with: - ref: "refs/pull/${{ github.event.number }}/merge" + ref: ${{ github.event.pull_request.head.sha }} - name: Run pytests run: |