diff --git a/examples/cmaes.ipynb b/examples/cmaes.ipynb index d7b30b1d..ba059cdc 100644 --- a/examples/cmaes.ipynb +++ b/examples/cmaes.ipynb @@ -30,39 +30,28 @@ "metadata": {}, "outputs": [], "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", + "from IPython.display import clear_output\n", "\n", "try:\n", " import qdax\n", "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", " import qdax\n", "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "\n", "import matplotlib.pyplot as plt\n", "from matplotlib.patches import Ellipse\n", "\n", @@ -71,7 +60,7 @@ }, { "cell_type": "markdown", - "id": "3", + "id": "4", "metadata": {}, "source": [ "## Set the hyperparameters" @@ -80,7 +69,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -98,7 +87,7 @@ }, { "cell_type": "markdown", - "id": "5", + "id": "6", "metadata": { "pycharm": { "name": "#%% md\n" @@ -111,7 +100,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -133,7 +122,7 @@ }, { "cell_type": "markdown", - "id": "7", + "id": "8", "metadata": { "pycharm": { "name": "#%% md\n" @@ -146,7 +135,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "9", "metadata": { "pycharm": { "name": "#%%\n" @@ -167,7 +156,7 @@ }, { "cell_type": "markdown", - "id": "9", + "id": "10", "metadata": { "pycharm": { "name": "#%% md\n" @@ -180,7 +169,7 @@ { "cell_type": "code", "execution_count": null, - "id": "10", + "id": "11", "metadata": { "pycharm": { "name": "#%%\n" @@ -194,7 +183,7 @@ }, { "cell_type": "markdown", - "id": "11", + "id": "12", "metadata": { "pycharm": { "name": "#%% md\n" @@ -207,7 +196,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "13", "metadata": { "pycharm": { "name": "#%%\n" @@ -245,7 +234,7 @@ }, { "cell_type": "markdown", - "id": "13", + "id": "14", "metadata": {}, "source": [ "## Check final fitnesses and distribution mean" @@ -254,7 +243,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -272,7 +261,7 @@ }, { "cell_type": "markdown", - "id": "15", + "id": "16", "metadata": { "pycharm": { "name": "#%% md\n" @@ -285,7 +274,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16", + "id": "17", "metadata": { "pycharm": { "name": "#%%\n" diff --git a/examples/dads.ipynb b/examples/dads.ipynb index 47abd1ec..57d1df05 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -28,7 +28,25 @@ "metadata": {}, "outputs": [], "source": [ - "#@title Installs and Imports\n", + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", "!pip install ipympl |tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", @@ -42,37 +60,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", - "\n", "from qdax import environments\n", "from qdax.baselines.dads import DADS, DadsConfig, DadsTrainingState\n", "from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer\n", diff --git a/examples/diayn.ipynb b/examples/diayn.ipynb index 8e085fce..cdee8b4b 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -28,7 +28,25 @@ "metadata": {}, "outputs": [], "source": [ - "#@title Installs and Imports\n", + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", "!pip install ipympl |tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", @@ -42,36 +60,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", "\n", "from qdax import environments\n", "from qdax.baselines.diayn import DIAYN, DiaynConfig, DiaynTrainingState\n", diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index fc6fbe8b..86deebc4 100644 --- a/examples/me_sac_pbt.ipynb +++ b/examples/me_sac_pbt.ipynb @@ -1,5 +1,23 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -13,36 +31,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", "import optax\n", "from brax.v1.io import html\n", "from IPython.display import HTML\n", @@ -54,7 +42,7 @@ "from qdax.core.distributed_map_elites import DistributedMAPElites\n", "from qdax.core.emitters.pbt_me_emitter import PBTEmitter, PBTEmitterConfig\n", "from qdax.core.emitters.pbt_variation_operators import sac_pbt_variation_fn\n", - "from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey\n", + "from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey\n", "from qdax.utils.metrics import CSVLogger, default_qd_metrics\n", "from qdax.utils.plotting import plot_map_elites_results" ] diff --git a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index f28a1db4..f72ccda1 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -1,5 +1,23 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -14,36 +32,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", "import matplotlib.pyplot as plt\n", "from brax.v1.io import html\n", "from IPython.display import HTML\n", @@ -56,7 +44,7 @@ "from qdax.core.emitters.pbt_me_emitter import PBTEmitter, PBTEmitterConfig\n", "from qdax.core.emitters.pbt_variation_operators import td3_pbt_variation_fn\n", "from qdax.core.distributed_map_elites import DistributedMAPElites\n", - "from qdax.types import RNGKey\n", + "from qdax.custom_types import RNGKey\n", "from qdax.utils.metrics import default_qd_metrics\n", "from qdax.utils.plotting import plot_2d_map_elites_repertoire, plot_map_elites_results" ] diff --git a/examples/mome.ipynb b/examples/mome.ipynb index 7e28b608..0d005dbe 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -32,42 +32,31 @@ "metadata": {}, "outputs": [], "source": [ - "import jax.numpy as jnp\n", - "import jax\n", - "from typing import Tuple\n", - "\n", - "from functools import partial\n", - "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", + "from IPython.display import clear_output\n", "\n", "try:\n", " import qdax\n", "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", " import qdax\n", "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import jax.numpy as jnp\n", + "import jax\n", + "from typing import Tuple\n", + "\n", + "from functools import partial\n", + "\n", "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n", "from qdax.core.mome import MOME\n", "from qdax.core.emitters.mutation_operators import (\n", @@ -81,12 +70,12 @@ "\n", "import matplotlib.pyplot as plt\n", "\n", - "from qdax.types import Fitness, Descriptor, RNGKey, ExtraScores" + "from qdax.custom_types import Fitness, Descriptor, RNGKey, ExtraScores" ] }, { "cell_type": "markdown", - "id": "3", + "id": "4", "metadata": {}, "source": [ "## Set the hyperparameters" @@ -95,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -119,7 +108,7 @@ }, { "cell_type": "markdown", - "id": "5", + "id": "6", "metadata": {}, "source": [ "## Define the scoring function: rastrigin multi-objective\n", @@ -130,7 +119,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -165,7 +154,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "8", "metadata": {}, "outputs": [], "source": [ @@ -178,7 +167,7 @@ }, { "cell_type": "markdown", - "id": "8", + "id": "9", "metadata": {}, "source": [ "## Define the metrics function that will be used" @@ -187,7 +176,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -202,7 +191,7 @@ }, { "cell_type": "markdown", - "id": "10", + "id": "11", "metadata": {}, "source": [ "## Define the initial population and the emitter" @@ -211,7 +200,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ -248,7 +237,7 @@ }, { "cell_type": "markdown", - "id": "12", + "id": "13", "metadata": {}, "source": [ "## Compute the centroids" @@ -257,7 +246,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -273,7 +262,7 @@ }, { "cell_type": "markdown", - "id": "14", + "id": "15", "metadata": {}, "source": [ "## Define a MOME instance" @@ -282,7 +271,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -295,7 +284,7 @@ }, { "cell_type": "markdown", - "id": "16", + "id": "17", "metadata": {}, "source": [ "## Init the algorithm" @@ -304,7 +293,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -318,7 +307,7 @@ }, { "cell_type": "markdown", - "id": "18", + "id": "19", "metadata": {}, "source": [ "## Run MOME iterations" @@ -327,7 +316,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "20", "metadata": {}, "outputs": [], "source": [ @@ -344,7 +333,7 @@ }, { "cell_type": "markdown", - "id": "20", + "id": "21", "metadata": {}, "source": [ "## Plot the results" @@ -353,7 +342,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -363,7 +352,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -391,7 +380,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "24", "metadata": {}, "outputs": [], "source": [ diff --git a/examples/sac_pbt.ipynb b/examples/sac_pbt.ipynb index a484b035..915cc272 100644 --- a/examples/sac_pbt.ipynb +++ b/examples/sac_pbt.ipynb @@ -4,6 +4,25 @@ "cell_type": "code", "execution_count": null, "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", "metadata": { "jupyter": { "outputs_hidden": false @@ -19,36 +38,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", "from brax.v1.io import html\n", "from IPython.display import HTML\n", "from tqdm import tqdm\n", @@ -61,7 +50,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1", + "id": "2", "metadata": { "jupyter": { "outputs_hidden": false @@ -78,7 +67,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2", + "id": "3", "metadata": { "jupyter": { "outputs_hidden": false @@ -98,7 +87,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3", + "id": "4", "metadata": { "jupyter": { "outputs_hidden": false @@ -139,7 +128,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4", + "id": "5", "metadata": { "jupyter": { "outputs_hidden": false @@ -170,7 +159,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5", + "id": "6", "metadata": { "jupyter": { "outputs_hidden": false @@ -209,7 +198,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6", + "id": "7", "metadata": { "jupyter": { "outputs_hidden": false @@ -232,7 +221,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "8", "metadata": { "jupyter": { "outputs_hidden": false @@ -261,7 +250,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "9", "metadata": { "jupyter": { "outputs_hidden": false @@ -288,7 +277,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "10", "metadata": { "jupyter": { "outputs_hidden": false @@ -306,7 +295,7 @@ { "cell_type": "code", "execution_count": null, - "id": "10", + "id": "11", "metadata": { "jupyter": { "outputs_hidden": false @@ -331,7 +320,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "12", "metadata": { "jupyter": { "outputs_hidden": false @@ -357,7 +346,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "13", "metadata": { "jupyter": { "outputs_hidden": false @@ -379,7 +368,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "14", "metadata": { "jupyter": { "outputs_hidden": false @@ -402,7 +391,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "15", "metadata": { "jupyter": { "outputs_hidden": false @@ -442,7 +431,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "16", "metadata": { "pycharm": { "name": "#%%\n" @@ -456,7 +445,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16", + "id": "17", "metadata": { "pycharm": { "name": "#%%\n" @@ -472,7 +461,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "18", "metadata": { "pycharm": { "name": "#%%\n" @@ -486,7 +475,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "19", "metadata": { "pycharm": { "name": "#%%\n" @@ -504,7 +493,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "20", "metadata": { "pycharm": { "name": "#%%\n" @@ -518,7 +507,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "21", "metadata": { "pycharm": { "name": "#%%\n" @@ -547,7 +536,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "22", "metadata": { "pycharm": { "name": "#%%\n" @@ -564,7 +553,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "23", "metadata": { "pycharm": { "name": "#%%\n" diff --git a/examples/smerl.ipynb b/examples/smerl.ipynb index d50f448f..0e332192 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -28,7 +28,25 @@ "metadata": {}, "outputs": [], "source": [ - "#@title Installs and Imports\n", + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", "!pip install ipympl |tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", @@ -42,37 +60,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", - "\n", "from qdax import environments\n", "from qdax.baselines.diayn_smerl import DIAYNSMERL, DiaynSmerlConfig, DiaynTrainingState\n", "from qdax.core.neuroevolution.buffers.buffer import QDTransition\n", diff --git a/examples/td3_pbt.ipynb b/examples/td3_pbt.ipynb index 484f6d12..d2d98f85 100644 --- a/examples/td3_pbt.ipynb +++ b/examples/td3_pbt.ipynb @@ -4,6 +4,25 @@ "cell_type": "code", "execution_count": null, "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", "metadata": { "pycharm": { "name": "#%%\n" @@ -16,36 +35,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", "from tqdm import tqdm\n", "\n", "from qdax import environments\n", @@ -56,7 +45,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1", + "id": "2", "metadata": { "pycharm": { "name": "#%%\n" @@ -70,7 +59,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2", + "id": "3", "metadata": { "pycharm": { "name": "#%%\n" @@ -87,7 +76,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3", + "id": "4", "metadata": { "pycharm": { "name": "#%%\n" @@ -124,7 +113,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4", + "id": "5", "metadata": { "pycharm": { "name": "#%%\n" @@ -152,7 +141,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5", + "id": "6", "metadata": { "pycharm": { "name": "#%%\n" @@ -183,7 +172,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6", + "id": "7", "metadata": { "pycharm": { "name": "#%%\n" @@ -203,7 +192,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "8", "metadata": { "pycharm": { "name": "#%%\n" @@ -227,7 +216,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "9", "metadata": { "pycharm": { "name": "#%%\n" @@ -251,7 +240,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "10", "metadata": { "pycharm": { "name": "#%%\n" @@ -266,7 +255,7 @@ { "cell_type": "code", "execution_count": null, - "id": "10", + "id": "11", "metadata": { "pycharm": { "name": "#%%\n" @@ -288,7 +277,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "12", "metadata": { "pycharm": { "name": "#%%\n" @@ -311,7 +300,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "13", "metadata": { "pycharm": { "name": "#%%\n" @@ -330,7 +319,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "14", "metadata": { "pycharm": { "name": "#%%\n" @@ -350,7 +339,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "15", "metadata": { "pycharm": { "name": "#%%\n" @@ -387,7 +376,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "16", "metadata": { "pycharm": { "name": "#%%\n"