Skip to content

Commit

Permalink
fix bad merge of notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
Lookatator committed Sep 10, 2024
1 parent 6238d83 commit 433e6be
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 347 deletions.
71 changes: 30 additions & 41 deletions examples/cmaes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,39 +30,28 @@
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"try:\n",
" import brax\n",
"except:\n",
" !pip install git+https://github.com/google/[email protected] |tail -n 1\n",
" import brax\n",
"\n",
"try:\n",
" import flax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/google/[email protected] |tail -n 1\n",
" import flax\n",
"\n",
"try:\n",
" import chex\n",
"except:\n",
" !pip install --no-deps git+https://github.com/deepmind/[email protected] |tail -n 1\n",
" import chex\n",
"\n",
"try:\n",
" import jumanji\n",
"except:\n",
" !pip install \"jumanji==0.3.1\"\n",
" import jumanji\n",
"from IPython.display import clear_output\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n",
" print(\"QDax not found. Installing...\")\n",
" !pip install qdax[cuda12]\n",
" import qdax\n",
"\n",
"clear_output()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Ellipse\n",
"\n",
Expand All @@ -71,7 +60,7 @@
},
{
"cell_type": "markdown",
"id": "3",
"id": "4",
"metadata": {},
"source": [
"## Set the hyperparameters"
Expand All @@ -80,7 +69,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"id": "5",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -98,7 +87,7 @@
},
{
"cell_type": "markdown",
"id": "5",
"id": "6",
"metadata": {
"pycharm": {
"name": "#%% md\n"
Expand All @@ -111,7 +100,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"id": "7",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -133,7 +122,7 @@
},
{
"cell_type": "markdown",
"id": "7",
"id": "8",
"metadata": {
"pycharm": {
"name": "#%% md\n"
Expand All @@ -146,7 +135,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"id": "9",
"metadata": {
"pycharm": {
"name": "#%%\n"
Expand All @@ -167,7 +156,7 @@
},
{
"cell_type": "markdown",
"id": "9",
"id": "10",
"metadata": {
"pycharm": {
"name": "#%% md\n"
Expand All @@ -180,7 +169,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "10",
"id": "11",
"metadata": {
"pycharm": {
"name": "#%%\n"
Expand All @@ -194,7 +183,7 @@
},
{
"cell_type": "markdown",
"id": "11",
"id": "12",
"metadata": {
"pycharm": {
"name": "#%% md\n"
Expand All @@ -207,7 +196,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "12",
"id": "13",
"metadata": {
"pycharm": {
"name": "#%%\n"
Expand Down Expand Up @@ -245,7 +234,7 @@
},
{
"cell_type": "markdown",
"id": "13",
"id": "14",
"metadata": {},
"source": [
"## Check final fitnesses and distribution mean"
Expand All @@ -254,7 +243,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "14",
"id": "15",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -272,7 +261,7 @@
},
{
"cell_type": "markdown",
"id": "15",
"id": "16",
"metadata": {
"pycharm": {
"name": "#%% md\n"
Expand All @@ -285,7 +274,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "16",
"id": "17",
"metadata": {
"pycharm": {
"name": "#%%\n"
Expand Down
51 changes: 19 additions & 32 deletions examples/dads.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,25 @@
"metadata": {},
"outputs": [],
"source": [
"#@title Installs and Imports\n",
"from IPython.display import clear_output\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" print(\"QDax not found. Installing...\")\n",
" !pip install qdax[cuda12]\n",
" import qdax\n",
"\n",
"clear_output()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"!pip install ipympl |tail -n 1\n",
"# %matplotlib widget\n",
"# from google.colab import output\n",
Expand All @@ -42,37 +60,6 @@
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"try:\n",
" import brax\n",
"except:\n",
" !pip install git+https://github.com/google/[email protected] |tail -n 1\n",
" import brax\n",
"\n",
"try:\n",
" import flax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/google/[email protected] |tail -n 1\n",
" import flax\n",
"\n",
"try:\n",
" import chex\n",
"except:\n",
" !pip install --no-deps git+https://github.com/deepmind/[email protected] |tail -n 1\n",
" import chex\n",
"\n",
"try:\n",
" import jumanji\n",
"except:\n",
" !pip install \"jumanji==0.3.1\"\n",
" import jumanji\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n",
" import qdax\n",
"\n",
"\n",
"from qdax import environments\n",
"from qdax.baselines.dads import DADS, DadsConfig, DadsTrainingState\n",
"from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer\n",
Expand Down
50 changes: 19 additions & 31 deletions examples/diayn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,25 @@
"metadata": {},
"outputs": [],
"source": [
"#@title Installs and Imports\n",
"from IPython.display import clear_output\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" print(\"QDax not found. Installing...\")\n",
" !pip install qdax[cuda12]\n",
" import qdax\n",
"\n",
"clear_output()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"!pip install ipympl |tail -n 1\n",
"# %matplotlib widget\n",
"# from google.colab import output\n",
Expand All @@ -42,36 +60,6 @@
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"try:\n",
" import brax\n",
"except:\n",
" !pip install git+https://github.com/google/[email protected] |tail -n 1\n",
" import brax\n",
"\n",
"try:\n",
" import flax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/google/[email protected] |tail -n 1\n",
" import flax\n",
"\n",
"try:\n",
" import chex\n",
"except:\n",
" !pip install --no-deps git+https://github.com/deepmind/[email protected] |tail -n 1\n",
" import chex\n",
"\n",
"try:\n",
" import jumanji\n",
"except:\n",
" !pip install \"jumanji==0.3.1\"\n",
" import jumanji\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n",
" import qdax\n",
"\n",
"\n",
"from qdax import environments\n",
"from qdax.baselines.diayn import DIAYN, DiaynConfig, DiaynTrainingState\n",
Expand Down
50 changes: 19 additions & 31 deletions examples/me_sac_pbt.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import clear_output\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" print(\"QDax not found. Installing...\")\n",
" !pip install qdax[cuda12]\n",
" import qdax\n",
"\n",
"clear_output()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -13,36 +31,6 @@
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"try:\n",
" import brax\n",
"except:\n",
" !pip install git+https://github.com/google/[email protected] |tail -n 1\n",
" import brax\n",
"\n",
"try:\n",
" import flax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/google/[email protected] |tail -n 1\n",
" import flax\n",
"\n",
"try:\n",
" import chex\n",
"except:\n",
" !pip install --no-deps git+https://github.com/deepmind/[email protected] |tail -n 1\n",
" import chex\n",
"\n",
"try:\n",
" import jumanji\n",
"except:\n",
" !pip install \"jumanji==0.3.1\"\n",
" import jumanji\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n",
" import qdax\n",
"\n",
"import optax\n",
"from brax.v1.io import html\n",
"from IPython.display import HTML\n",
Expand All @@ -54,7 +42,7 @@
"from qdax.core.distributed_map_elites import DistributedMAPElites\n",
"from qdax.core.emitters.pbt_me_emitter import PBTEmitter, PBTEmitterConfig\n",
"from qdax.core.emitters.pbt_variation_operators import sac_pbt_variation_fn\n",
"from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey\n",
"from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey\n",
"from qdax.utils.metrics import CSVLogger, default_qd_metrics\n",
"from qdax.utils.plotting import plot_map_elites_results"
]
Expand Down
Loading

0 comments on commit 433e6be

Please sign in to comment.