Skip to content

Commit

Permalink
fix all notebooks and update typing extensions (for running notebooks)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lookatator committed Aug 27, 2024
1 parent 1191a14 commit 9bf9177
Show file tree
Hide file tree
Showing 15 changed files with 36 additions and 73 deletions.
7 changes: 2 additions & 5 deletions examples/aurora.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -512,11 +512,8 @@
}
],
"metadata": {
"interpreter": {
"hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "venv",
"language": "python",
"name": "python3"
},
Expand All @@ -530,7 +527,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion examples/cmaes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.10.12"
},
"vscode": {
"interpreter": {
Expand Down
4 changes: 2 additions & 2 deletions examples/cmame.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@
"def clip(x: jnp.ndarray):\n",
" in_bound = (x <= maxval) * (x >= minval)\n",
" return jnp.where(\n",
" condition=in_bound,\n",
" in_bound,\n",
" x,\n",
" (maxval / x)\n",
" )\n",
Expand Down Expand Up @@ -387,7 +387,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.10.12"
},
"vscode": {
"interpreter": {
Expand Down
2 changes: 1 addition & 1 deletion examples/cmamega.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.10.12"
},
"vscode": {
"interpreter": {
Expand Down
8 changes: 1 addition & 7 deletions examples/dads.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,6 @@
" import jumanji\n",
"\n",
"try:\n",
" import haiku\n",
"except:\n",
" !pip install git+https://github.com/deepmind/[email protected] |tail -n 1\n",
" import haiku\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n",
Expand Down Expand Up @@ -554,7 +548,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
"version": "3.10.12"
},
"vscode": {
"interpreter": {
Expand Down
8 changes: 1 addition & 7 deletions examples/diayn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,6 @@
" import jumanji\n",
"\n",
"try:\n",
" import haiku\n",
"except:\n",
" !pip install git+https://github.com/deepmind/[email protected] |tail -n 1\n",
" import haiku\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n",
Expand Down Expand Up @@ -544,7 +538,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
"version": "3.10.12"
},
"vscode": {
"interpreter": {
Expand Down
12 changes: 8 additions & 4 deletions examples/distributed_mapelites.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@
"from IPython.display import clear_output\n",
"import functools\n",
"\n",
"from tqdm import tqdm\n",
"try:\n",
" from tqdm import tqdm\n",
"except:\n",
" !pip install tqdm | tail -n 1\n",
" from tqdm import tqdm\n",
"\n",
"import time\n",
"\n",
"import jax\n",
Expand Down Expand Up @@ -128,8 +133,7 @@
"outputs": [],
"source": [
"# Get devices (change gpu by tpu if needed)\n",
"# devices = jax.devices('gpu')\n",
"devices = jax.devices('tpu')\n",
"devices = jax.devices('gpu')\n",
"num_devices = len(devices)\n",
"print(f'Detected the following {num_devices} device(s): {devices}')"
]
Expand Down Expand Up @@ -454,7 +458,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
8 changes: 4 additions & 4 deletions examples/jumanji_snake.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
"population_size = 100\n",
"batch_size = population_size\n",
"\n",
"num_iterations = 5000\n",
"num_iterations = 1000\n",
"\n",
"iso_sigma = 0.005\n",
"line_sigma = 0.05"
Expand Down Expand Up @@ -177,7 +177,7 @@
"outputs": [],
"source": [
"def observation_processing(observation):\n",
" network_input = jnp.ravel(observation)\n",
" network_input = jnp.concatenate([jnp.ravel(observation.grid), jnp.array([observation.step_count]), observation.action_mask.ravel()])\n",
" return network_input\n",
"\n",
"\n",
Expand Down Expand Up @@ -207,7 +207,7 @@
" obs=timestep.observation,\n",
" next_obs=next_timestep.observation,\n",
" rewards=next_timestep.reward,\n",
" dones=jnp.where(next_timestep.last(), x=jnp.array(1), y=jnp.array(0)),\n",
" dones=jnp.where(next_timestep.last(), jnp.array(1), jnp.array(0)),\n",
" actions=action,\n",
" truncations=jnp.array(0),\n",
" state_desc=state_desc,\n",
Expand Down Expand Up @@ -240,7 +240,7 @@
"\n",
"# compute observation size from observation spec\n",
"obs_spec = env.observation_spec()\n",
"observation_size = np.prod(np.array(obs_spec.grid.shape + obs_spec.step_count.shape + obs_spec.action_mask.shape))\n",
"observation_size = int(np.prod(obs_spec.grid.shape) + np.prod(obs_spec.step_count.shape) + np.prod(obs_spec.action_mask.shape))\n",
"\n",
"fake_batch = jnp.zeros(shape=(batch_size, observation_size))\n",
"init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n",
Expand Down
11 changes: 3 additions & 8 deletions examples/me_sac_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,6 @@
" import jumanji\n",
"\n",
"try:\n",
" import haiku\n",
"except:\n",
" !pip install git+https://github.com/deepmind/[email protected] |tail -n 1\n",
" import haiku\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n",
Expand Down Expand Up @@ -80,7 +74,8 @@
"metadata": {},
"outputs": [],
"source": [
"devices = jax.devices(\"tpu\")\n",
"# Get devices (change gpu by tpu if needed)\n",
"devices = jax.devices('gpu')\n",
"num_devices = len(devices)\n",
"print(f\"Detected the following {num_devices} device(s): {devices}\")"
]
Expand Down Expand Up @@ -261,7 +256,7 @@
" lambda x: jnp.repeat(x, population_size, axis=0), first_states\n",
" )\n",
" population_returns, population_bds, _, _ = eval_policy(genotypes, first_states)\n",
" return population_returns, population_bds, None, random_key"
" return population_returns, population_bds, {}, random_key"
]
},
{
Expand Down
15 changes: 5 additions & 10 deletions examples/me_td3_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,6 @@
" import jumanji\n",
"\n",
"try:\n",
" import haiku\n",
"except:\n",
" !pip install git+https://github.com/deepmind/[email protected] |tail -n 1\n",
" import haiku\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n",
Expand Down Expand Up @@ -82,7 +76,8 @@
"metadata": {},
"outputs": [],
"source": [
"devices = jax.devices(\"tpu\")\n",
"# Get devices (change gpu by tpu if needed)\n",
"devices = jax.devices('gpu')\n",
"num_devices = len(devices)\n",
"print(f\"Detected the following {num_devices} device(s): {devices}\")"
]
Expand Down Expand Up @@ -264,7 +259,7 @@
" lambda x: jnp.repeat(x, population_size, axis=0), first_states\n",
" )\n",
" population_returns, population_bds, _, _ = eval_policy(genotypes, first_states)\n",
" return population_returns, population_bds, None, random_key"
" return population_returns, population_bds, {}, random_key"
]
},
{
Expand Down Expand Up @@ -443,7 +438,7 @@
"num_cols = 5\n",
"\n",
"fig, axes = plt.subplots(\n",
" nrows=math.ceil(num_repertoires / num_cols), ncols=num_cols, figsize=(30, 30)\n",
" nrows=math.ceil(num_repertoires / num_cols), ncols=num_cols, figsize=(30, 30), squeeze=False,\n",
")\n",
"for i, repertoire in enumerate(repertoires):\n",
"\n",
Expand Down Expand Up @@ -492,7 +487,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion examples/pgame.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@
"#@markdown ---\n",
"env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']\n",
"episode_length = 250 #@param {type:\"integer\"}\n",
"num_iterations = 4000 #@param {type:\"integer\"}\n",
"num_iterations = 1000 #@param {type:\"integer\"}\n",
"seed = 42 #@param {type:\"integer\"}\n",
"policy_hidden_layer_sizes = (256, 256) #@param {type:\"raw\"}\n",
"iso_sigma = 0.005 #@param {type:\"number\"}\n",
Expand Down
13 changes: 4 additions & 9 deletions examples/sac_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,6 @@
" import jumanji\n",
"\n",
"try:\n",
" import haiku\n",
"except:\n",
" !pip install git+https://github.com/deepmind/[email protected] |tail -n 1\n",
" import haiku\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n",
Expand Down Expand Up @@ -95,7 +89,8 @@
},
"outputs": [],
"source": [
"devices = jax.devices(\"tpu\")\n",
"# Get devices (change gpu by tpu if needed)\n",
"devices = jax.devices('gpu')\n",
"num_devices = len(devices)\n",
"print(f\"Detected the following {num_devices} device(s): {devices}\")"
]
Expand Down Expand Up @@ -123,8 +118,8 @@
"buffer_size = 100000\n",
"\n",
"# PBT Config\n",
"num_best_to_replace_from = 20\n",
"num_worse_to_replace = 40\n",
"num_best_to_replace_from = 1\n",
"num_worse_to_replace = 1\n",
"\n",
"# SAC config\n",
"batch_size = 256\n",
Expand Down
6 changes: 0 additions & 6 deletions examples/smerl.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,6 @@
" import jumanji\n",
"\n",
"try:\n",
" import haiku\n",
"except:\n",
" !pip install git+https://github.com/deepmind/[email protected] |tail -n 1\n",
" import haiku\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n",
Expand Down
9 changes: 2 additions & 7 deletions examples/td3_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,6 @@
" import jumanji\n",
"\n",
"try:\n",
" import haiku\n",
"except:\n",
" !pip install git+https://github.com/deepmind/[email protected] |tail -n 1\n",
" import haiku\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n",
Expand Down Expand Up @@ -84,7 +78,8 @@
},
"outputs": [],
"source": [
"devices = jax.devices(\"gpu\")\n",
"# Get devices (change gpu by tpu if needed)\n",
"devices = jax.devices('gpu')\n",
"num_devices = len(devices)\n",
"print(f\"Detected the following {num_devices} device(s): {devices}\")"
]
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ protobuf==3.19.4
scikit-learn==1.5.1
scipy==1.10.1
tensorflow-probability==0.24.0
typing-extensions==4.3.0
typing-extensions==4.12.2

0 comments on commit 9bf9177

Please sign in to comment.