diff --git a/examples/jumanji_snake.ipynb b/examples/jumanji_snake.ipynb index a6a140fd..3ab28d85 100644 --- a/examples/jumanji_snake.ipynb +++ b/examples/jumanji_snake.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "233e0f03", + "id": "0", "metadata": {}, "source": [ "# Training a population on Jumanji-Snake with QDax\n", @@ -15,7 +15,7 @@ { "cell_type": "code", "execution_count": null, - "id": "47b46c2f", + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -78,7 +78,7 @@ }, { "cell_type": "markdown", - "id": "03c2f1f7", + "id": "2", "metadata": {}, "source": [ "## Define hyperparameters" @@ -87,7 +87,7 @@ { "cell_type": "code", "execution_count": null, - "id": "52dd1e3b", + "id": "3", "metadata": {}, "outputs": [], "source": [ @@ -105,7 +105,7 @@ }, { "cell_type": "markdown", - "id": "8b8c890a", + "id": "4", "metadata": {}, "source": [ "## Instantiate the snake environment" @@ -114,7 +114,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a842cccc", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -132,7 +132,7 @@ }, { "cell_type": "markdown", - "id": "776862f1", + "id": "6", "metadata": {}, "source": [ "## Define the type of policy that will be used to solve the problem" @@ -141,7 +141,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2a1ce7d0", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -161,7 +161,7 @@ }, { "cell_type": "markdown", - "id": "49586b07", + "id": "8", "metadata": {}, "source": [ "## Utils to interact with the environment\n", @@ -172,7 +172,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d1ff7827", + "id": "9", "metadata": {}, "outputs": [], "source": [ @@ -219,7 +219,7 @@ }, { "cell_type": "markdown", - "id": "0078bc01", + "id": "10", "metadata": {}, "source": [ "## Init a population of policies\n", @@ -230,7 +230,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6cbd2065", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -255,7 +255,7 @@ }, { "cell_type": "markdown", - "id": "fe6bf07f", + "id": "12", "metadata": {}, "source": [ "## Define a method to extract behavior descriptor when relevant" @@ -264,7 +264,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a264b672", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -311,7 +311,7 @@ }, { "cell_type": "markdown", - "id": "1cdc5f87", + "id": "14", "metadata": {}, "source": [ "## Define the scoring function" @@ -320,7 +320,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7b77d826", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -333,7 +333,7 @@ }, { "cell_type": "markdown", - "id": "6555491a", + "id": "16", "metadata": {}, "source": [ "## Define the emitter used" @@ -342,7 +342,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30061ff4", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -360,7 +360,7 @@ }, { "cell_type": "markdown", - "id": "da7e9b74", + "id": "18", "metadata": {}, "source": [ "## Define the algorithm used and apply the initial step\n", @@ -371,7 +371,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f7b5c2d6", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -415,7 +415,7 @@ }, { "cell_type": "markdown", - "id": "9b1bfee5", + "id": "20", "metadata": {}, "source": [ "## Run the optimization loop" @@ -424,7 +424,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d1af3a35", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -442,7 +442,7 @@ { "cell_type": "code", "execution_count": null, - "id": "114ea4a8", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -452,7 +452,7 @@ { "cell_type": "code", "execution_count": null, - "id": "92a35bf0", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -462,7 +462,7 @@ { "cell_type": "code", "execution_count": null, - "id": "79ada2d5", + "id": "24", "metadata": {}, "outputs": [], "source": [ @@ -472,7 +472,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fe5da301", + "id": "25", "metadata": {}, "outputs": [], "source": [ @@ -489,7 +489,7 @@ }, { "cell_type": "markdown", - "id": "93d8154e", + "id": "26", "metadata": {}, "source": [ "## Play snake with the best policy\n", @@ -500,7 +500,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3ff882f4", + "id": "27", "metadata": {}, "outputs": [], "source": [ @@ -511,7 +511,7 @@ { "cell_type": "code", "execution_count": null, - "id": "762c167e", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -524,7 +524,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07523e33", + "id": "29", "metadata": {}, "outputs": [], "source": [ @@ -537,7 +537,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c75ce088", + "id": "30", "metadata": {}, "outputs": [], "source": [ @@ -550,7 +550,7 @@ { "cell_type": "code", "execution_count": null, - "id": "50ef95f6", + "id": "31", "metadata": {}, "outputs": [], "source": [ @@ -563,7 +563,7 @@ { "cell_type": "code", "execution_count": null, - "id": "40a03409", + "id": "32", "metadata": {}, "outputs": [], "source": [ diff --git a/examples/mome.ipynb b/examples/mome.ipynb index bf0a5225..7e28b608 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "59f748d3", + "id": "0", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mome.ipynb)" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "a5e13ff6", + "id": "1", "metadata": {}, "source": [ "# Optimizing multiple objectives with MOME in Jax\n", @@ -28,7 +28,7 @@ { "cell_type": "code", "execution_count": null, - "id": "af063418", + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -86,7 +86,7 @@ }, { "cell_type": "markdown", - "id": "22495c16", + "id": "3", "metadata": {}, "source": [ "## Set the hyperparameters" @@ -95,7 +95,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b96b5d07", + "id": "4", "metadata": {}, "outputs": [], "source": [ @@ -119,7 +119,7 @@ }, { "cell_type": "markdown", - "id": "c2850d54", + "id": "5", "metadata": {}, "source": [ "## Define the scoring function: rastrigin multi-objective\n", @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b5effe11", + "id": "6", "metadata": {}, "outputs": [], "source": [ @@ -165,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "231d273d", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -178,7 +178,7 @@ }, { "cell_type": "markdown", - "id": "29250e72", + "id": "8", "metadata": {}, "source": [ "## Define the metrics function that will be used" @@ -187,7 +187,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ab5d6334", + "id": "9", "metadata": {}, "outputs": [], "source": [ @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "a4828ca8", + "id": "10", "metadata": {}, "source": [ "## Define the initial population and the emitter" @@ -211,7 +211,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ebf3bd27", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -248,7 +248,7 @@ }, { "cell_type": "markdown", - "id": "c904664b", + "id": "12", "metadata": {}, "source": [ "## Compute the centroids" @@ -257,7 +257,7 @@ { "cell_type": "code", "execution_count": null, - "id": "76547c4c", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -273,7 +273,7 @@ }, { "cell_type": "markdown", - "id": "15936d15", + "id": "14", "metadata": {}, "source": [ "## Define a MOME instance" @@ -282,7 +282,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07a0d1d9", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -295,7 +295,7 @@ }, { "cell_type": "markdown", - "id": "f7ec5a77", + "id": "16", "metadata": {}, "source": [ "## Init the algorithm" @@ -304,7 +304,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c05cbf1e", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -318,7 +318,7 @@ }, { "cell_type": "markdown", - "id": "6de4cedf", + "id": "18", "metadata": {}, "source": [ "## Run MOME iterations" @@ -327,7 +327,7 @@ { "cell_type": "code", "execution_count": null, - "id": "96ea04e6", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -344,7 +344,7 @@ }, { "cell_type": "markdown", - "id": "3ff9ca98", + "id": "20", "metadata": {}, "source": [ "## Plot the results" @@ -353,7 +353,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6766dc4f", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -363,7 +363,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28ab56c9", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -391,7 +391,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2ab36cb7", + "id": "23", "metadata": {}, "outputs": [], "source": [ diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index 4dab2d73..f67d7b4f 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -12,7 +12,6 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.environments.bd_extractors import AuroraExtraInfo from qdax.custom_types import ( Descriptor, Fitness, @@ -22,6 +21,7 @@ Params, RNGKey, ) +from qdax.environments.bd_extractors import AuroraExtraInfo class AURORA: diff --git a/qdax/core/containers/mapelites_repertoire.py b/qdax/core/containers/mapelites_repertoire.py index e0b22443..b473d4b3 100644 --- a/qdax/core/containers/mapelites_repertoire.py +++ b/qdax/core/containers/mapelites_repertoire.py @@ -15,7 +15,14 @@ from numpy.random import RandomState from sklearn.cluster import KMeans -from qdax.custom_types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + RNGKey, +) def compute_cvt_centroids( diff --git a/qdax/core/containers/mels_repertoire.py b/qdax/core/containers/mels_repertoire.py index 285db940..7ef57bb9 100644 --- a/qdax/core/containers/mels_repertoire.py +++ b/qdax/core/containers/mels_repertoire.py @@ -14,7 +14,14 @@ MapElitesRepertoire, get_cells_indices, ) -from qdax.custom_types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, Spread +from qdax.custom_types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + Spread, +) def _dispersion(descriptors: jnp.ndarray) -> jnp.ndarray: diff --git a/qdax/core/containers/unstructured_repertoire.py b/qdax/core/containers/unstructured_repertoire.py index 7f66c5d0..8512d3d6 100644 --- a/qdax/core/containers/unstructured_repertoire.py +++ b/qdax/core/containers/unstructured_repertoire.py @@ -8,7 +8,14 @@ import jax.numpy as jnp from jax.flatten_util import ravel_pytree -from qdax.custom_types import Centroid, Descriptor, Fitness, Genotype, Observation, RNGKey +from qdax.custom_types import ( + Centroid, + Descriptor, + Fitness, + Genotype, + Observation, + RNGKey, +) @partial(jax.jit, static_argnames=("k_nn",)) diff --git a/qdax/core/emitters/cma_emitter.py b/qdax/core/emitters/cma_emitter.py index 429895d6..315dcd9b 100644 --- a/qdax/core/emitters/cma_emitter.py +++ b/qdax/core/emitters/cma_emitter.py @@ -13,7 +13,14 @@ get_cells_indices, ) from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.custom_types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + RNGKey, +) class CMAEmitterState(EmitterState): diff --git a/qdax/core/emitters/dcg_me_emitter.py b/qdax/core/emitters/dcg_me_emitter.py index eaee1dd4..fea237c6 100644 --- a/qdax/core/emitters/dcg_me_emitter.py +++ b/qdax/core/emitters/dcg_me_emitter.py @@ -6,8 +6,8 @@ from qdax.core.emitters.multi_emitter import MultiEmitter from qdax.core.emitters.qdcg_emitter import QualityDCGConfig, QualityDCGEmitter from qdax.core.emitters.standard_emitters import MixingEmitter -from qdax.environments.base_wrappers import QDEnv from qdax.custom_types import Params, RNGKey +from qdax.environments.base_wrappers import QDEnv @dataclass diff --git a/qdax/core/emitters/dpg_emitter.py b/qdax/core/emitters/dpg_emitter.py index fee069eb..ea921237 100644 --- a/qdax/core/emitters/dpg_emitter.py +++ b/qdax/core/emitters/dpg_emitter.py @@ -18,7 +18,6 @@ QualityPGEmitterState, ) from qdax.core.neuroevolution.buffers.buffer import QDTransition -from qdax.environments.base_wrappers import QDEnv from qdax.custom_types import ( Descriptor, ExtraScores, @@ -29,6 +28,7 @@ RNGKey, StateDescriptor, ) +from qdax.environments.base_wrappers import QDEnv @dataclass diff --git a/qdax/core/emitters/omg_mega_emitter.py b/qdax/core/emitters/omg_mega_emitter.py index f8863c3b..580bd151 100644 --- a/qdax/core/emitters/omg_mega_emitter.py +++ b/qdax/core/emitters/omg_mega_emitter.py @@ -6,7 +6,14 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.custom_types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + RNGKey, +) class OMGMEGAEmitterState(EmitterState): diff --git a/qdax/core/emitters/pbt_me_emitter.py b/qdax/core/emitters/pbt_me_emitter.py index deefda9e..55bded4e 100644 --- a/qdax/core/emitters/pbt_me_emitter.py +++ b/qdax/core/emitters/pbt_me_emitter.py @@ -12,8 +12,8 @@ from qdax.core.containers.repertoire import Repertoire from qdax.core.emitters.emitter import Emitter, EmitterState from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition -from qdax.environments.base_wrappers import QDEnv from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey +from qdax.environments.base_wrappers import QDEnv class PBTEmitterState(EmitterState): diff --git a/qdax/core/emitters/pga_me_emitter.py b/qdax/core/emitters/pga_me_emitter.py index 8b80c53b..a4f8b33f 100644 --- a/qdax/core/emitters/pga_me_emitter.py +++ b/qdax/core/emitters/pga_me_emitter.py @@ -6,8 +6,8 @@ from qdax.core.emitters.multi_emitter import MultiEmitter from qdax.core.emitters.qpg_emitter import QualityPGConfig, QualityPGEmitter from qdax.core.emitters.standard_emitters import MixingEmitter -from qdax.environments.base_wrappers import QDEnv from qdax.custom_types import Params, RNGKey +from qdax.environments.base_wrappers import QDEnv @dataclass diff --git a/qdax/core/emitters/qdcg_emitter.py b/qdax/core/emitters/qdcg_emitter.py index f36f542b..0fb19c4b 100644 --- a/qdax/core/emitters/qdcg_emitter.py +++ b/qdax/core/emitters/qdcg_emitter.py @@ -16,8 +16,8 @@ from qdax.core.neuroevolution.buffers.buffer import DCGTransition, ReplayBuffer from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_dc_fn from qdax.core.neuroevolution.networks.networks import QModuleDC -from qdax.environments.base_wrappers import QDEnv from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey +from qdax.environments.base_wrappers import QDEnv @dataclass diff --git a/qdax/core/emitters/qdpg_emitter.py b/qdax/core/emitters/qdpg_emitter.py index 1bbd64eb..b9de6090 100644 --- a/qdax/core/emitters/qdpg_emitter.py +++ b/qdax/core/emitters/qdpg_emitter.py @@ -18,8 +18,8 @@ from qdax.core.emitters.mutation_operators import isoline_variation from qdax.core.emitters.qpg_emitter import QualityPGConfig, QualityPGEmitter from qdax.core.emitters.standard_emitters import MixingEmitter -from qdax.environments.base_wrappers import QDEnv from qdax.custom_types import Reward, StateDescriptor +from qdax.environments.base_wrappers import QDEnv @dataclass diff --git a/qdax/core/emitters/qpg_emitter.py b/qdax/core/emitters/qpg_emitter.py index e8d3914e..c6e2df7e 100644 --- a/qdax/core/emitters/qpg_emitter.py +++ b/qdax/core/emitters/qpg_emitter.py @@ -17,8 +17,8 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_fn from qdax.core.neuroevolution.networks.networks import QModule -from qdax.environments.base_wrappers import QDEnv from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey +from qdax.environments.base_wrappers import QDEnv @dataclass diff --git a/qdax/core/neuroevolution/losses/dads_loss.py b/qdax/core/neuroevolution/losses/dads_loss.py index 62b18e78..60edfee1 100644 --- a/qdax/core/neuroevolution/losses/dads_loss.py +++ b/qdax/core/neuroevolution/losses/dads_loss.py @@ -6,7 +6,14 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.losses.sac_loss import make_sac_loss_fn -from qdax.custom_types import Action, Observation, Params, RNGKey, Skill, StateDescriptor +from qdax.custom_types import ( + Action, + Observation, + Params, + RNGKey, + Skill, + StateDescriptor, +) def make_dads_loss_fn( diff --git a/qdax/tasks/qd_suite/archimedean_spiral.py b/qdax/tasks/qd_suite/archimedean_spiral.py index e8494e20..59108ae5 100644 --- a/qdax/tasks/qd_suite/archimedean_spiral.py +++ b/qdax/tasks/qd_suite/archimedean_spiral.py @@ -4,8 +4,8 @@ import jax.lax import jax.numpy as jnp -from qdax.tasks.qd_suite.qd_suite_task import QDSuiteTask from qdax.custom_types import Descriptor, Fitness, Genotype +from qdax.tasks.qd_suite.qd_suite_task import QDSuiteTask class ParameterizationGenotype(Enum): diff --git a/qdax/tasks/qd_suite/deceptive_evolvability.py b/qdax/tasks/qd_suite/deceptive_evolvability.py index f3b5f1d8..830ad523 100644 --- a/qdax/tasks/qd_suite/deceptive_evolvability.py +++ b/qdax/tasks/qd_suite/deceptive_evolvability.py @@ -3,8 +3,8 @@ import jax import jax.numpy as jnp -from qdax.tasks.qd_suite.qd_suite_task import QDSuiteTask from qdax.custom_types import Descriptor, Fitness, Genotype +from qdax.tasks.qd_suite.qd_suite_task import QDSuiteTask def multivariate_normal( diff --git a/qdax/tasks/qd_suite/ssf.py b/qdax/tasks/qd_suite/ssf.py index 783babda..601aa6ad 100644 --- a/qdax/tasks/qd_suite/ssf.py +++ b/qdax/tasks/qd_suite/ssf.py @@ -3,8 +3,8 @@ import jax import jax.numpy as jnp -from qdax.tasks.qd_suite.qd_suite_task import QDSuiteTask from qdax.custom_types import Descriptor, Fitness, Genotype +from qdax.tasks.qd_suite.qd_suite_task import QDSuiteTask class SsfV0(QDSuiteTask): diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py index 6e8081fb..bd9570a9 100644 --- a/qdax/utils/train_seq2seq.py +++ b/qdax/utils/train_seq2seq.py @@ -16,8 +16,8 @@ from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire from qdax.core.neuroevolution.networks.seq2seq_networks import Seq2seq -from qdax.environments.bd_extractors import AuroraExtraInfoNormalization from qdax.custom_types import Params, RNGKey +from qdax.environments.bd_extractors import AuroraExtraInfoNormalization Array = Any PRNGKey = Any diff --git a/tests/baselines_test/mees_test.py b/tests/baselines_test/mees_test.py index ce77ee07..d1913b02 100644 --- a/tests/baselines_test/mees_test.py +++ b/tests/baselines_test/mees_test.py @@ -14,8 +14,8 @@ from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP -from qdax.tasks.brax_envs import scoring_function_brax_envs from qdax.custom_types import EnvState, Params, RNGKey +from qdax.tasks.brax_envs import scoring_function_brax_envs def test_mees() -> None: diff --git a/tests/baselines_test/pgame_test.py b/tests/baselines_test/pgame_test.py index 3468e3ab..0490a481 100644 --- a/tests/baselines_test/pgame_test.py +++ b/tests/baselines_test/pgame_test.py @@ -15,8 +15,8 @@ from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP -from qdax.tasks.brax_envs import scoring_function_brax_envs from qdax.custom_types import EnvState, Params, RNGKey +from qdax.tasks.brax_envs import scoring_function_brax_envs def test_pgame() -> None: diff --git a/tests/baselines_test/qdpg_test.py b/tests/baselines_test/qdpg_test.py index 3a4bdb03..704416a4 100644 --- a/tests/baselines_test/qdpg_test.py +++ b/tests/baselines_test/qdpg_test.py @@ -17,8 +17,8 @@ from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP -from qdax.tasks.brax_envs import scoring_function_brax_envs from qdax.custom_types import EnvState, Params, RNGKey +from qdax.tasks.brax_envs import scoring_function_brax_envs def test_qdpg() -> None: diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index 13149e36..4bbb9d82 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -11,6 +11,7 @@ from qdax import environments from qdax.core.aurora import AURORA from qdax.core.neuroevolution.buffers.buffer import QDTransition +from qdax.custom_types import Observation from qdax.environments.bd_extractors import ( AuroraExtraInfoNormalization, get_aurora_encoding, @@ -19,7 +20,6 @@ create_default_brax_task_components, get_aurora_scoring_fn, ) -from qdax.custom_types import Observation from qdax.utils import train_seq2seq from qdax.utils.metrics import default_qd_metrics from tests.core_test.map_elites_test import get_mixing_emitter diff --git a/tests/core_test/map_elites_test.py b/tests/core_test/map_elites_test.py index 277d9df9..c89ce04f 100644 --- a/tests/core_test/map_elites_test.py +++ b/tests/core_test/map_elites_test.py @@ -14,8 +14,8 @@ from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP -from qdax.tasks.brax_envs import scoring_function_brax_envs from qdax.custom_types import EnvState, Params, RNGKey +from qdax.tasks.brax_envs import scoring_function_brax_envs from qdax.utils.metrics import default_qd_metrics diff --git a/tests/core_test/mels_test.py b/tests/core_test/mels_test.py index ac0b563a..66bcc05f 100644 --- a/tests/core_test/mels_test.py +++ b/tests/core_test/mels_test.py @@ -15,8 +15,8 @@ from qdax.core.mels import MELS from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP -from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs from qdax.custom_types import EnvState, Params, RNGKey +from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs @pytest.mark.parametrize( diff --git a/tests/default_tasks_test/jumanji_envs_test.py b/tests/default_tasks_test/jumanji_envs_test.py index 6683208e..737a4a17 100644 --- a/tests/default_tasks_test/jumanji_envs_test.py +++ b/tests/default_tasks_test/jumanji_envs_test.py @@ -11,11 +11,11 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import Descriptor, Observation from qdax.tasks.jumanji_envs import ( jumanji_scoring_function, make_policy_network_play_step_fn_jumanji, ) -from qdax.custom_types import Descriptor, Observation def test_jumanji_utils() -> None: diff --git a/tests/environments_test/pointmaze_test.py b/tests/environments_test/pointmaze_test.py index 8c42fe4f..ecc97864 100644 --- a/tests/environments_test/pointmaze_test.py +++ b/tests/environments_test/pointmaze_test.py @@ -6,8 +6,8 @@ from brax.v1.envs import Env import qdax -from qdax.environments.pointmaze import PointMaze from qdax.custom_types import EnvState +from qdax.environments.pointmaze import PointMaze def test_pointmaze() -> None: diff --git a/tests/utils_test/sampling_test.py b/tests/utils_test/sampling_test.py index ee66503e..8d19379e 100644 --- a/tests/utils_test/sampling_test.py +++ b/tests/utils_test/sampling_test.py @@ -8,8 +8,8 @@ from qdax import environments from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP -from qdax.tasks.brax_envs import scoring_function_brax_envs from qdax.custom_types import EnvState, Params, RNGKey +from qdax.tasks.brax_envs import scoring_function_brax_envs from qdax.utils.sampling import ( average, closest,