Skip to content

Commit

Permalink
Hotfix: libraries installation in Colab notebooks (#193)
Browse files Browse the repository at this point in the history
* fix installations of dependencies of QDax in all notebooks
* add missing colab parameters and update description for DCRL
  • Loading branch information
Lookatator authored Sep 9, 2024
1 parent 8c9d2ec commit f16b0da
Show file tree
Hide file tree
Showing 23 changed files with 532 additions and 812 deletions.
51 changes: 19 additions & 32 deletions examples/aurora.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,25 @@
"metadata": {},
"outputs": [],
"source": [
"#@title Installs and Imports\n",
"from IPython.display import clear_output\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" print(\"QDax not found. Installing...\")\n",
" !pip install qdax[cuda12]\n",
" import qdax\n",
"\n",
"clear_output()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"!pip install ipympl |tail -n 1\n",
"# %matplotlib widget\n",
"# from google.colab import output\n",
Expand All @@ -46,37 +64,6 @@
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"try:\n",
" import brax\n",
"except:\n",
" !pip install git+https://github.com/google/[email protected] |tail -n 1\n",
" import brax\n",
"\n",
"try:\n",
" import flax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/google/[email protected] |tail -n 1\n",
" import flax\n",
"\n",
"try:\n",
" import chex\n",
"except:\n",
" !pip install --no-deps git+https://github.com/deepmind/[email protected] |tail -n 1\n",
" import chex\n",
"\n",
"try:\n",
" import jumanji\n",
"except:\n",
" !pip install \"jumanji==0.3.1\"\n",
" import jumanji\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n",
" import qdax\n",
"\n",
"\n",
"from qdax.core.aurora import AURORA\n",
"from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire\n",
"from qdax import environments\n",
Expand Down
71 changes: 30 additions & 41 deletions examples/cmaes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,39 +30,28 @@
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"try:\n",
" import brax\n",
"except:\n",
" !pip install git+https://github.com/google/[email protected] |tail -n 1\n",
" import brax\n",
"\n",
"try:\n",
" import flax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/google/[email protected] |tail -n 1\n",
" import flax\n",
"\n",
"try:\n",
" import chex\n",
"except:\n",
" !pip install --no-deps git+https://github.com/deepmind/[email protected] |tail -n 1\n",
" import chex\n",
"\n",
"try:\n",
" import jumanji\n",
"except:\n",
" !pip install \"jumanji==0.3.1\"\n",
" import jumanji\n",
"from IPython.display import clear_output\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n",
" print(\"QDax not found. Installing...\")\n",
" !pip install qdax[cuda12]\n",
" import qdax\n",
"\n",
"clear_output()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Ellipse\n",
"\n",
Expand All @@ -71,7 +60,7 @@
},
{
"cell_type": "markdown",
"id": "3",
"id": "4",
"metadata": {},
"source": [
"## Set the hyperparameters"
Expand All @@ -80,7 +69,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"id": "5",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -98,7 +87,7 @@
},
{
"cell_type": "markdown",
"id": "5",
"id": "6",
"metadata": {
"pycharm": {
"name": "#%% md\n"
Expand All @@ -111,7 +100,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"id": "7",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -133,7 +122,7 @@
},
{
"cell_type": "markdown",
"id": "7",
"id": "8",
"metadata": {
"pycharm": {
"name": "#%% md\n"
Expand All @@ -146,7 +135,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"id": "9",
"metadata": {
"pycharm": {
"name": "#%%\n"
Expand All @@ -167,7 +156,7 @@
},
{
"cell_type": "markdown",
"id": "9",
"id": "10",
"metadata": {
"pycharm": {
"name": "#%% md\n"
Expand All @@ -180,7 +169,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "10",
"id": "11",
"metadata": {
"pycharm": {
"name": "#%%\n"
Expand All @@ -194,7 +183,7 @@
},
{
"cell_type": "markdown",
"id": "11",
"id": "12",
"metadata": {
"pycharm": {
"name": "#%% md\n"
Expand All @@ -207,7 +196,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "12",
"id": "13",
"metadata": {
"pycharm": {
"name": "#%%\n"
Expand Down Expand Up @@ -245,7 +234,7 @@
},
{
"cell_type": "markdown",
"id": "13",
"id": "14",
"metadata": {},
"source": [
"## Check final fitnesses and distribution mean"
Expand All @@ -254,7 +243,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "14",
"id": "15",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -272,7 +261,7 @@
},
{
"cell_type": "markdown",
"id": "15",
"id": "16",
"metadata": {
"pycharm": {
"name": "#%% md\n"
Expand All @@ -285,7 +274,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "16",
"id": "17",
"metadata": {
"pycharm": {
"name": "#%%\n"
Expand Down
48 changes: 18 additions & 30 deletions examples/cmame.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,24 @@
"- how to visualise the optimization process"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import clear_output\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" print(\"QDax not found. Installing...\")\n",
" !pip install qdax[cuda12]\n",
" import qdax\n",
"\n",
"clear_output()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -38,36 +56,6 @@
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"try:\n",
" import brax\n",
"except:\n",
" !pip install git+https://github.com/google/[email protected] |tail -n 1\n",
" import brax\n",
"\n",
"try:\n",
" import flax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/google/[email protected] |tail -n 1\n",
" import flax\n",
"\n",
"try:\n",
" import chex\n",
"except:\n",
" !pip install --no-deps git+https://github.com/deepmind/[email protected] |tail -n 1\n",
" import chex\n",
"\n",
"try:\n",
" import jumanji\n",
"except:\n",
" !pip install \"jumanji==0.3.1\"\n",
" import jumanji\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n",
" import qdax\n",
"\n",
"from qdax.core.map_elites import MAPElites\n",
"from qdax.core.emitters.cma_opt_emitter import CMAOptimizingEmitter\n",
"from qdax.core.emitters.cma_rnd_emitter import CMARndEmitter\n",
Expand Down
42 changes: 15 additions & 27 deletions examples/cmamega.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,39 +29,27 @@
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"try:\n",
" import brax\n",
"except:\n",
" !pip install git+https://github.com/google/[email protected] |tail -n 1\n",
" import brax\n",
"\n",
"try:\n",
" import flax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/google/[email protected] |tail -n 1\n",
" import flax\n",
"\n",
"try:\n",
" import chex\n",
"except:\n",
" !pip install --no-deps git+https://github.com/deepmind/[email protected] |tail -n 1\n",
" import chex\n",
"\n",
"try:\n",
" import jumanji\n",
"except:\n",
" !pip install \"jumanji==0.3.1\"\n",
" import jumanji\n",
"from IPython.display import clear_output\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n",
" print(\"QDax not found. Installing...\")\n",
" !pip install qdax[cuda12]\n",
" import qdax\n",
"\n",
"clear_output()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"from qdax.core.map_elites import MAPElites\n",
"from qdax.core.emitters.cma_mega_emitter import CMAMEGAEmitter\n",
"from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids, MapElitesRepertoire\n",
Expand Down
Loading

0 comments on commit f16b0da

Please sign in to comment.