diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb index e4b86238..645ed911 100644 --- a/examples/aurora.ipynb +++ b/examples/aurora.ipynb @@ -512,11 +512,8 @@ } ], "metadata": { - "interpreter": { - "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" - }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "venv", "language": "python", "name": "python3" }, @@ -530,7 +527,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/cmaes.ipynb b/examples/cmaes.ipynb index ab2d97fd..d7b30b1d 100644 --- a/examples/cmaes.ipynb +++ b/examples/cmaes.ipynb @@ -333,7 +333,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/cmame.ipynb b/examples/cmame.ipynb index 125b739f..ff5fa5c2 100644 --- a/examples/cmame.ipynb +++ b/examples/cmame.ipynb @@ -141,7 +141,7 @@ "def clip(x: jnp.ndarray):\n", " in_bound = (x <= maxval) * (x >= minval)\n", " return jnp.where(\n", - " condition=in_bound,\n", + " in_bound,\n", " x,\n", " (maxval / x)\n", " )\n", @@ -387,7 +387,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/cmamega.ipynb b/examples/cmamega.ipynb index e5749993..739ac3d5 100644 --- a/examples/cmamega.ipynb +++ b/examples/cmamega.ipynb @@ -315,7 +315,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/dads.ipynb b/examples/dads.ipynb index b3cc43b5..47abd1ec 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -67,12 +67,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -554,7 +548,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.2" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/diayn.ipynb b/examples/diayn.ipynb index 0562e7c2..8e085fce 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -67,12 +67,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -544,7 +538,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.2" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb index e55041b6..ea2f9b9b 100644 --- a/examples/distributed_mapelites.ipynb +++ b/examples/distributed_mapelites.ipynb @@ -40,7 +40,12 @@ "from IPython.display import clear_output\n", "import functools\n", "\n", - "from tqdm import tqdm\n", + "try:\n", + " from tqdm import tqdm\n", + "except:\n", + " !pip install tqdm | tail -n 1\n", + " from tqdm import tqdm\n", + "\n", "import time\n", "\n", "import jax\n", @@ -128,8 +133,7 @@ "outputs": [], "source": [ "# Get devices (change gpu by tpu if needed)\n", - "# devices = jax.devices('gpu')\n", - "devices = jax.devices('tpu')\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f'Detected the following {num_devices} device(s): {devices}')" ] @@ -454,7 +458,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/jumanji_snake.ipynb b/examples/jumanji_snake.ipynb index 3ab28d85..bfba1e5a 100644 --- a/examples/jumanji_snake.ipynb +++ b/examples/jumanji_snake.ipynb @@ -97,7 +97,7 @@ "population_size = 100\n", "batch_size = population_size\n", "\n", - "num_iterations = 5000\n", + "num_iterations = 1000\n", "\n", "iso_sigma = 0.005\n", "line_sigma = 0.05" @@ -177,7 +177,7 @@ "outputs": [], "source": [ "def observation_processing(observation):\n", - " network_input = jnp.ravel(observation)\n", + " network_input = jnp.concatenate([jnp.ravel(observation.grid), jnp.array([observation.step_count]), observation.action_mask.ravel()])\n", " return network_input\n", "\n", "\n", @@ -207,7 +207,7 @@ " obs=timestep.observation,\n", " next_obs=next_timestep.observation,\n", " rewards=next_timestep.reward,\n", - " dones=jnp.where(next_timestep.last(), x=jnp.array(1), y=jnp.array(0)),\n", + " dones=jnp.where(next_timestep.last(), jnp.array(1), jnp.array(0)),\n", " actions=action,\n", " truncations=jnp.array(0),\n", " state_desc=state_desc,\n", @@ -240,7 +240,7 @@ "\n", "# compute observation size from observation spec\n", "obs_spec = env.observation_spec()\n", - "observation_size = np.prod(np.array(obs_spec.grid.shape + obs_spec.step_count.shape + obs_spec.action_mask.shape))\n", + "observation_size = int(np.prod(obs_spec.grid.shape) + np.prod(obs_spec.step_count.shape) + np.prod(obs_spec.action_mask.shape))\n", "\n", "fake_batch = jnp.zeros(shape=(batch_size, observation_size))\n", "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index b2de823e..fc6fbe8b 100644 --- a/examples/me_sac_pbt.ipynb +++ b/examples/me_sac_pbt.ipynb @@ -38,12 +38,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -80,7 +74,8 @@ "metadata": {}, "outputs": [], "source": [ - "devices = jax.devices(\"tpu\")\n", + "# Get devices (change gpu by tpu if needed)\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f\"Detected the following {num_devices} device(s): {devices}\")" ] @@ -261,7 +256,7 @@ " lambda x: jnp.repeat(x, population_size, axis=0), first_states\n", " )\n", " population_returns, population_bds, _, _ = eval_policy(genotypes, first_states)\n", - " return population_returns, population_bds, None, random_key" + " return population_returns, population_bds, {}, random_key" ] }, { diff --git a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index 238f703c..f28a1db4 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -39,12 +39,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -82,7 +76,8 @@ "metadata": {}, "outputs": [], "source": [ - "devices = jax.devices(\"tpu\")\n", + "# Get devices (change gpu by tpu if needed)\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f\"Detected the following {num_devices} device(s): {devices}\")" ] @@ -264,7 +259,7 @@ " lambda x: jnp.repeat(x, population_size, axis=0), first_states\n", " )\n", " population_returns, population_bds, _, _ = eval_policy(genotypes, first_states)\n", - " return population_returns, population_bds, None, random_key" + " return population_returns, population_bds, {}, random_key" ] }, { @@ -443,7 +438,7 @@ "num_cols = 5\n", "\n", "fig, axes = plt.subplots(\n", - " nrows=math.ceil(num_repertoires / num_cols), ncols=num_cols, figsize=(30, 30)\n", + " nrows=math.ceil(num_repertoires / num_cols), ncols=num_cols, figsize=(30, 30), squeeze=False,\n", ")\n", "for i, repertoire in enumerate(repertoires):\n", "\n", @@ -492,7 +487,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/pgame.ipynb b/examples/pgame.ipynb index 9b638b2d..03fd9c00 100644 --- a/examples/pgame.ipynb +++ b/examples/pgame.ipynb @@ -106,7 +106,7 @@ "#@markdown ---\n", "env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']\n", "episode_length = 250 #@param {type:\"integer\"}\n", - "num_iterations = 4000 #@param {type:\"integer\"}\n", + "num_iterations = 1000 #@param {type:\"integer\"}\n", "seed = 42 #@param {type:\"integer\"}\n", "policy_hidden_layer_sizes = (256, 256) #@param {type:\"raw\"}\n", "iso_sigma = 0.005 #@param {type:\"number\"}\n", diff --git a/examples/sac_pbt.ipynb b/examples/sac_pbt.ipynb index 3176b774..a484b035 100644 --- a/examples/sac_pbt.ipynb +++ b/examples/sac_pbt.ipynb @@ -44,12 +44,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -95,7 +89,8 @@ }, "outputs": [], "source": [ - "devices = jax.devices(\"tpu\")\n", + "# Get devices (change gpu by tpu if needed)\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f\"Detected the following {num_devices} device(s): {devices}\")" ] @@ -123,8 +118,8 @@ "buffer_size = 100000\n", "\n", "# PBT Config\n", - "num_best_to_replace_from = 20\n", - "num_worse_to_replace = 40\n", + "num_best_to_replace_from = 1\n", + "num_worse_to_replace = 1\n", "\n", "# SAC config\n", "batch_size = 256\n", diff --git a/examples/smerl.ipynb b/examples/smerl.ipynb index fe655fe2..d50f448f 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -67,12 +67,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git a/examples/td3_pbt.ipynb b/examples/td3_pbt.ipynb index b96e13bc..484f6d12 100644 --- a/examples/td3_pbt.ipynb +++ b/examples/td3_pbt.ipynb @@ -41,12 +41,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -84,7 +78,8 @@ }, "outputs": [], "source": [ - "devices = jax.devices(\"gpu\")\n", + "# Get devices (change gpu by tpu if needed)\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f\"Detected the following {num_devices} device(s): {devices}\")" ] diff --git a/requirements.txt b/requirements.txt index ea9cdd29..702b4e60 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,4 @@ protobuf==3.19.4 scikit-learn==1.5.1 scipy==1.10.1 tensorflow-probability==0.24.0 -typing-extensions==4.3.0 +typing-extensions==4.12.2