diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb index a6a47894..e4b86238 100644 --- a/examples/aurora.ipynb +++ b/examples/aurora.ipynb @@ -49,10 +49,28 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.3.1\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -62,13 +80,20 @@ "from qdax.core.aurora import AURORA\n", "from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire\n", "from qdax import environments\n", - "from qdax.tasks.brax_envs import scoring_aurora_function\n", - "from qdax.environments.bd_extractors import get_aurora_bd\n", + "from qdax.tasks.brax_envs import (\n", + " create_default_brax_task_components,\n", + " get_aurora_scoring_fn,\n", + ")\n", + "from qdax.environments.bd_extractors import (\n", + " AuroraExtraInfoNormalization,\n", + " get_aurora_encoding,\n", + ")\n", "from qdax.core.neuroevolution.buffers.buffer import QDTransition\n", "from qdax.core.neuroevolution.networks.networks import MLP\n", "from qdax.core.emitters.mutation_operators import isoline_variation\n", "from qdax.core.emitters.standard_emitters import MixingEmitter\n", "\n", + "from qdax.types import Observation\n", "from qdax.utils import train_seq2seq\n", "\n", "\n", @@ -184,7 +209,7 @@ " \"\"\"\n", "\n", " actions = 
policy_network.apply(policy_params, env_state.obs)\n", - " \n", + "\n", " state_desc = env_state.info[\"state_descriptor\"]\n", " next_state = env.step(env_state, actions)\n", "\n", @@ -208,7 +233,7 @@ "source": [ "## Define the scoring function and the way metrics are computed\n", "\n", - "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual. " + "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual." ] }, { @@ -218,19 +243,35 @@ "outputs": [], "source": [ "# Prepare the scoring function\n", - "bd_extraction_fn = functools.partial(\n", - " get_aurora_bd,\n", - " option=observation_option,\n", - " hidden_size=hidden_size,\n", - " traj_sampling_freq=traj_sampling_freq,\n", - " max_observation_size=max_observation_size,\n", + "env, policy_network, scoring_fn, random_key = create_default_brax_task_components(\n", + " env_name=env_name,\n", + " random_key=random_key,\n", ")\n", - "scoring_fn = functools.partial(\n", - " scoring_aurora_function,\n", - " init_states=init_states,\n", - " episode_length=episode_length,\n", - " play_step_fn=play_step_fn,\n", - " behavior_descriptor_extractor=bd_extraction_fn,\n", + "\n", + "def observation_extractor_fn(\n", + " data: QDTransition,\n", + ") -> Observation:\n", + " \"\"\"Extract observation from the state.\"\"\"\n", + " state_obs = data.obs[:, ::traj_sampling_freq, :max_observation_size]\n", + "\n", + " # add the x/y position - (batch_size, traj_length, 2)\n", + " state_desc = data.state_desc[:, ::traj_sampling_freq]\n", + "\n", + " if observation_option == \"full\":\n", + " observations = jnp.concatenate([state_desc, state_obs], axis=-1)\n", + " elif observation_option == \"no_sd\":\n", + " observations = state_obs\n", + " elif observation_option == \"only_sd\":\n", + " observations = state_desc\n", + " else:\n", + " raise ValueError(\"Unknown observation option.\")\n", + "\n", + " return 
observations\n", + "\n", + "# Prepare the scoring function\n", + "aurora_scoring_fn = get_aurora_scoring_fn(\n", + " scoring_fn=scoring_fn,\n", + " observation_extractor_fn=observation_extractor_fn,\n", ")\n", "\n", "# Get minimum reward value to make sure qd_score are positive\n", @@ -290,13 +331,6 @@ "metadata": {}, "outputs": [], "source": [ - "# Instantiate AURORA\n", - "aurora = AURORA(\n", - " scoring_function=scoring_fn,\n", - " emitter=mixing_emitter,\n", - " metrics_function=metrics_fn,\n", - ")\n", - "\n", "aurora_dims = hidden_size\n", "centroids = jnp.zeros(shape=(num_centroids, aurora_dims))\n", "\n", @@ -306,9 +340,7 @@ " (\n", " repertoire,\n", " random_key,\n", - " model_params,\n", - " mean_observations,\n", - " std_observations,\n", + " aurora_extra_info\n", " ) = carry\n", "\n", " # update\n", @@ -316,13 +348,11 @@ " repertoire,\n", " None,\n", " random_key,\n", - " model_params,\n", - " mean_observations,\n", - " std_observations,\n", + " aurora_extra_info=aurora_extra_info,\n", " )\n", "\n", " return (\n", - " (repertoire, random_key, model_params, mean_observations, std_observations),\n", + " (repertoire, random_key, aurora_extra_info),\n", " metrics,\n", " )\n", "\n", @@ -344,12 +374,12 @@ "else:\n", " ValueError(\"The chosen option is not correct.\")\n", "\n", - "# define the seq2seq model\n", + "# Define the seq2seq model\n", "model = train_seq2seq.get_model(\n", " observations_dims[-1], True, hidden_size=hidden_size\n", ")\n", "\n", - "# init the model params\n", + "# Init the model params\n", "random_key, subkey = jax.random.split(random_key)\n", "model_params = train_seq2seq.get_initial_params(\n", " model, subkey, (1, *observations_dims)\n", @@ -357,25 +387,54 @@ "\n", "print(jax.tree_map(lambda x: x.shape, model_params))\n", "\n", + "# Define the encoder function\n", + "encoder_fn = jax.jit(\n", + " functools.partial(\n", + " get_aurora_encoding,\n", + " model=model,\n", + " )\n", + ")\n", + "\n", + "# Define the training function\n", 
+ "train_fn = functools.partial(\n", + " train_seq2seq.lstm_ae_train,\n", + " model=model,\n", + " batch_size=lstm_batch_size,\n", + ")\n", + "\n", + "# Instantiate AURORA\n", + "aurora = AURORA(\n", + " scoring_function=aurora_scoring_fn,\n", + " emitter=mixing_emitter,\n", + " metrics_function=metrics_fn,\n", + " encoder_function=encoder_fn,\n", + " training_function=train_fn,\n", + ")\n", + "\n", "# define arbitrary observation's mean/std\n", "mean_observations = jnp.zeros(observations_dims[-1])\n", "std_observations = jnp.ones(observations_dims[-1])\n", "\n", - "# init step of the aurora algorithm\n", - "repertoire, _, random_key = aurora.init(\n", - " init_variables,\n", - " centroids,\n", - " random_key,\n", + "# init all the information needed by AURORA to compute encodings\n", + "aurora_extra_info = AuroraExtraInfoNormalization.create(\n", " model_params,\n", " mean_observations,\n", " std_observations,\n", - " l_value_init,\n", + ")\n", + "\n", + "# init step of the aurora algorithm\n", + "repertoire, emitter_state, aurora_extra_info, random_key = aurora.init(\n", + " init_variables,\n", + " aurora_extra_info,\n", + " jnp.asarray(l_value_init),\n", + " max_observation_size,\n", + " random_key,\n", ")\n", "\n", "# initializing means and stds and AURORA\n", "random_key, subkey = jax.random.split(random_key)\n", - "model_params, mean_observations, std_observations = train_seq2seq.lstm_ae_train(\n", - " subkey, repertoire, model_params, 0, hidden_size=hidden_size, batch_size=lstm_batch_size\n", + "repertoire, aurora_extra_info = aurora.train(\n", + " repertoire, model_params, iteration=0, random_key=subkey\n", ")\n", "\n", "# design aurora's schedule\n", @@ -409,11 +468,11 @@ "while iteration < max_iterations:\n", "\n", " (\n", - " (repertoire, random_key, model_params, mean_observations, std_observations),\n", + " (repertoire, random_key, aurora_extra_info),\n", " metrics,\n", " ) = jax.lax.scan(\n", " update_scan_fn,\n", - " (repertoire, random_key, 
model_params, mean_observations, std_observations),\n", + " (repertoire, random_key, aurora_extra_info),\n", " (),\n", " length=log_freq,\n", " )\n", @@ -427,60 +486,15 @@ " if (iteration + 1) in schedules:\n", " # train the autoencoder\n", " random_key, subkey = jax.random.split(random_key)\n", - " (\n", - " model_params,\n", - " mean_observations,\n", - " std_observations,\n", - " ) = train_seq2seq.lstm_ae_train(\n", - " subkey,\n", - " repertoire,\n", - " model_params,\n", - " iteration,\n", - " hidden_size=hidden_size,\n", - " batch_size=lstm_batch_size\n", + " repertoire, aurora_extra_info = aurora.train(\n", + " repertoire, model_params, iteration, subkey\n", " )\n", "\n", - " # re-addition of all the new behavioural descriotpors with the new ae\n", - " normalized_observations = (\n", - " repertoire.observations - mean_observations\n", - " ) / std_observations\n", - "\n", - " new_descriptors = model.apply(\n", - " {\"params\": model_params}, normalized_observations, method=model.encode\n", - " )\n", - " repertoire = repertoire.init(\n", - " genotypes=repertoire.genotypes,\n", - " centroids=repertoire.centroids,\n", - " fitnesses=repertoire.fitnesses,\n", - " descriptors=new_descriptors,\n", - " observations=repertoire.observations,\n", - " l_value=repertoire.l_value,\n", - " )\n", - " num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf)\n", - "\n", " elif iteration % 2 == 0:\n", - " # update the l value\n", - " num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf)\n", - "\n", - " # CVC Implementation to keep a constant number of individuals in the archive\n", - " current_error = num_indivs - n_target\n", - " change_rate = current_error - previous_error\n", - " prop_gain = 1 * 10e-6\n", - " l_value = (\n", - " repertoire.l_value\n", - " + (prop_gain * (current_error))\n", - " + (prop_gain * change_rate)\n", - " )\n", - "\n", - " previous_error = current_error\n", - "\n", - " repertoire = repertoire.init(\n", - " genotypes=repertoire.genotypes,\n", - " 
centroids=repertoire.centroids,\n", - " fitnesses=repertoire.fitnesses,\n", - " descriptors=repertoire.descriptors,\n", - " observations=repertoire.observations,\n", - " l_value=l_value,\n", + " repertoire, previous_error = aurora.container_size_control(\n", + " repertoire,\n", + " target_size=n_target,\n", + " previous_error=previous_error,\n", " )\n", "\n", " iteration += 1" diff --git a/examples/cmaes.ipynb b/examples/cmaes.ipynb index 5d6d8756..c8e2a9fe 100644 --- a/examples/cmaes.ipynb +++ b/examples/cmaes.ipynb @@ -33,6 +33,36 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", + "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.3.1\"\n", + " import jumanji\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " import qdax\n", + "\n", "import matplotlib.pyplot as plt\n", "from matplotlib.patches import Ellipse\n", "\n", @@ -193,23 +223,23 @@ "iteration_count = 0\n", "for _ in range(num_iterations):\n", " iteration_count += 1\n", - " \n", + "\n", " # sample\n", " samples, random_key = cmaes.sample(state, random_key)\n", - " \n", + "\n", " # udpate\n", " state = cmaes.update(state, samples)\n", - " \n", + "\n", " # check stop condition\n", " stop_condition = cmaes.stop_condition(state)\n", "\n", " if stop_condition:\n", " break\n", - " \n", + "\n", " # store data for plotting\n", " means.append(state.mean)\n", " covs.append((state.sigma**2) * 
state.cov_matrix)\n", - " \n", + "\n", "print(\"Num iterations before stop condition: \", iteration_count)" ] }, @@ -281,7 +311,7 @@ " ellipse = Ellipse((mean[0], mean[1]), cov[0, 0], cov[1, 1], fill=False, color='k', ls='--')\n", " ax.add_patch(ellipse)\n", " ax.plot(mean[0], mean[1], color='k', marker='x')\n", - " \n", + "\n", "ax.set_title(f\"Optimization trajectory of CMA-ES between step {traj_min} and step {traj_max}\")\n", "plt.show()" ] diff --git a/examples/cmame.ipynb b/examples/cmame.ipynb index 1d7337d4..ec9a641c 100644 --- a/examples/cmame.ipynb +++ b/examples/cmame.ipynb @@ -35,25 +35,31 @@ "import matplotlib.cm as cm\n", "import matplotlib.pyplot as plt\n", "\n", - "import jax \n", + "import jax\n", "import jax.numpy as jnp\n", "\n", "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.4.1 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.3 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", " import chex\n", "\n", "try:\n", " import jumanji\n", "except:\n", - " !pip install \"jumanji==0.2.2\"\n", + " !pip install \"jumanji==0.3.1\"\n", " import jumanji\n", "\n", "try:\n", @@ -91,7 +97,7 @@ "#@markdown ---\n", "num_iterations = 70000 #70000 #10000\n", "num_dimensions = 100 #1000 #@param {type:\"integer\"}\n", - "grid_shape = (500, 500) # (500, 500) \n", + "grid_shape = (500, 500) # (500, 500)\n", "batch_size = 36 #36 #@param {type:\"integer\"}\n", "sigma_g = .5 #@param {type:\"number\"}\n", "minval = -5.12 #@param {type:\"number\"}\n", diff --git a/examples/cmamega.ipynb b/examples/cmamega.ipynb index 
2e00d660..e5749993 100644 --- a/examples/cmamega.ipynb +++ b/examples/cmamega.ipynb @@ -29,25 +29,31 @@ "metadata": {}, "outputs": [], "source": [ - "import jax \n", + "import jax\n", "import jax.numpy as jnp\n", "\n", "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.4.1 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.3 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", " import chex\n", "\n", "try:\n", " import jumanji\n", "except:\n", - " !pip install \"jumanji==0.2.2\"\n", + " !pip install \"jumanji==0.3.1\"\n", " import jumanji\n", "\n", "try:\n", @@ -209,10 +215,10 @@ "initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n", "\n", "centroids, random_key = compute_cvt_centroids(\n", - " num_descriptors=2, \n", - " num_init_cvt_samples=10000, \n", - " num_centroids=num_centroids, \n", - " minval=minval, \n", + " num_descriptors=2,\n", + " num_init_cvt_samples=10000,\n", + " num_centroids=num_centroids,\n", + " minval=minval,\n", " maxval=maxval,\n", " random_key=random_key,\n", ")\n", diff --git a/examples/dads.ipynb b/examples/dads.ipynb index ffa5522a..b3cc43b5 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -45,13 +45,25 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps 
git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", " import jumanji\n", "except:\n", - " !pip install \"jumanji==0.2.2\"\n", + " !pip install \"jumanji==0.3.1\"\n", " import jumanji\n", "\n", "try:\n", @@ -75,7 +87,7 @@ "from qdax.utils.plotting import plot_skills_trajectory\n", "\n", "from IPython.display import HTML\n", - "from brax.io import html\n", + "from brax.v1.io import html\n", "\n", "\n", "\n", @@ -202,7 +214,6 @@ " # SAC config\n", " batch_size=batch_size,\n", " episode_length=episode_length,\n", - " grad_updates_per_step=grad_updates_per_step,\n", " tau=tau,\n", " normalize_observations=normalize_observations,\n", " learning_rate=learning_rate,\n", diff --git a/examples/diayn.ipynb b/examples/diayn.ipynb index 6dcc2c77..0562e7c2 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -41,17 +41,29 @@ "\n", "import jax\n", "import jax.numpy as jnp\n", - " \n", + "\n", "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", " import jumanji\n", "except:\n", - " !pip install \"jumanji==0.2.2\"\n", + " !pip install \"jumanji==0.3.1\"\n", " import jumanji\n", "\n", "try:\n", @@ -75,7 +87,7 @@ "from qdax.utils.plotting import plot_skills_trajectory\n", "\n", "from IPython.display import HTML\n", - "from brax.io import html\n", + "from 
brax.v1.io import html\n", "\n", "\n", "\n", @@ -93,7 +105,7 @@ "source": [ "## Hyperparameters choice\n", "\n", - "Most hyperparameters are similar to those introduced in [SAC paper](https://arxiv.org/abs/1801.01290) and [DIAYN paper](https://arxiv.org/abs/1802.06070). \n", + "Most hyperparameters are similar to those introduced in [SAC paper](https://arxiv.org/abs/1801.01290) and [DIAYN paper](https://arxiv.org/abs/1802.06070).\n", "\n", "The parameter `descriptor_full_state` is less straightforward, it concerns the information used for diversity seeking and discrimination. In DIAYN, one can use the full state for diversity seeking, but one can also use a prior to focus on an interesting aspect of the state. Actually, priors are often used in experiments, for instance, focusing on the x/y position rather than the full position. When `descriptor_full_state` is set to True, it uses the full state, when it is set to False, it uses the 'state descriptor' retrieved by the environment. Hence, it is required that the environment has one. (All the `_uni`, `_omni` do, same for `anttrap`, `antmaze` and `pointmaze`.) In the future, we will add an option to use a prior function direclty on the full state." 
] @@ -200,7 +212,6 @@ " # SAC config\n", " batch_size=batch_size,\n", " episode_length=episode_length,\n", - " grad_updates_per_step=grad_updates_per_step,\n", " tau=tau,\n", " normalize_observations=normalize_observations,\n", " learning_rate=learning_rate,\n", diff --git a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb index 18d4f0f3..574c56a2 100644 --- a/examples/distributed_mapelites.ipynb +++ b/examples/distributed_mapelites.ipynb @@ -49,13 +49,25 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", " import jumanji\n", "except:\n", - " !pip install \"jumanji==0.2.2\"\n", + " !pip install \"jumanji==0.3.1\"\n", " import jumanji\n", "\n", "try:\n", @@ -96,7 +108,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Setup the default platform where the MAP-Elites will be stored and MAP-Elite updates will happen. " + "Setup the default platform where the MAP-Elites will be stored and MAP-Elite updates will happen." ] }, { @@ -223,7 +235,7 @@ " \"\"\"\n", "\n", " actions = policy_network.apply(policy_params, env_state.obs)\n", - " \n", + "\n", " state_desc = env_state.info[\"state_descriptor\"]\n", " next_state = env.step(env_state, actions)\n", "\n", @@ -247,7 +259,7 @@ "source": [ "## Define the scoring function and the way metrics are computed\n", "\n", - "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual. 
" + "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual." ] }, { @@ -296,9 +308,9 @@ " isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma\n", ")\n", "mixing_emitter = MixingEmitter(\n", - " mutation_fn=None, \n", - " variation_fn=variation_fn, \n", - " variation_percentage=1.0, \n", + " mutation_fn=None,\n", + " variation_fn=variation_fn,\n", + " variation_percentage=1.0,\n", " batch_size=batch_size_per_device\n", ")" ] @@ -378,9 +390,9 @@ "\n", "# main loop\n", "for i in tqdm(range(num_loops), total=num_loops):\n", - " \n", + "\n", " start_time = time.time()\n", - " \n", + "\n", " # main iterations\n", " repertoire, emitter_state, random_key, metrics = update_fn(repertoire, emitter_state, random_key)\n", "\n", @@ -409,7 +421,7 @@ "source": [ "## Retrieve the repertoire from the first device\n", "\n", - "All devices have the same duplicated version of the repertoire " + "All devices have the same duplicated version of the repertoire" ] }, { diff --git a/examples/jumanji_snake.ipynb b/examples/jumanji_snake.ipynb index 0b206f34..a6a140fd 100644 --- a/examples/jumanji_snake.ipynb +++ b/examples/jumanji_snake.ipynb @@ -25,7 +25,35 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "import jumanji\n", + "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.3.1\"\n", + " import jumanji\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " !pip install --no-deps 
git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " import qdax\n", "\n", "import functools\n", "\n", @@ -212,7 +240,7 @@ "\n", "# compute observation size from observation spec\n", "obs_spec = env.observation_spec()\n", - "observation_size = np.prod(np.array(env.observation_spec().shape))\n", + "observation_size = np.prod(np.array(obs_spec.grid.shape + obs_spec.step_count.shape + obs_spec.action_mask.shape))\n", "\n", "fake_batch = jnp.zeros(shape=(batch_size, observation_size))\n", "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", @@ -247,13 +275,13 @@ " This function suppose that state descriptor is the feet contact, as it\n", " just computes the mean of the state descriptors given.\n", " \"\"\"\n", - " \n", + "\n", " # pre-process the observation\n", " observation = jax.vmap(jax.vmap(observation_processing))(data.obs)\n", - " \n", + "\n", " # get the mean\n", " mean_observation = jnp.mean(observation, axis=-2)\n", - " \n", + "\n", " # project those in [-1, 1]^2\n", " descriptors = jnp.tanh(mean_observation @ linear_projection.T)\n", "\n", @@ -323,9 +351,9 @@ " isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma\n", ")\n", "mixing_emitter = MixingEmitter(\n", - " mutation_fn=None, \n", - " variation_fn=variation_fn, \n", - " variation_percentage=1.0, \n", + " mutation_fn=None,\n", + " variation_fn=variation_fn,\n", + " variation_percentage=1.0,\n", " batch_size=batch_size\n", ")" ] @@ -551,8 +579,8 @@ " proba_action = policy_network.apply(my_params, network_input)\n", "\n", " action = jnp.argmax(proba_action)\n", - " \n", - " \n", + "\n", + "\n", " state, timestep = jax.jit(env.step)(state, action)" ] } diff --git a/examples/mapelites.ipynb b/examples/mapelites.ipynb index 49765438..b1fea651 100644 --- a/examples/mapelites.ipynb +++ b/examples/mapelites.ipynb @@ -49,13 +49,25 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", + " 
!pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", " import jumanji\n", "except:\n", - " !pip install \"jumanji==0.2.2\"\n", + " !pip install \"jumanji==0.3.1\"\n", " import jumanji\n", "\n", "try:\n", @@ -80,7 +92,7 @@ "from jax.flatten_util import ravel_pytree\n", "\n", "from IPython.display import HTML\n", - "from brax.io import html\n", + "from brax.v1.io import html\n", "\n", "\n", "\n", @@ -173,7 +185,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Define the function to play a step with the policy in the environment\n", + "# Define the function to play a step with the policy in the environment\n", "def play_step_fn(\n", " env_state,\n", " policy_params,\n", @@ -184,7 +196,7 @@ " \"\"\"\n", "\n", " actions = policy_network.apply(policy_params, env_state.obs)\n", - " \n", + "\n", " state_desc = env_state.info[\"state_descriptor\"]\n", " next_state = env.step(env_state, actions)\n", "\n", @@ -208,7 +220,7 @@ "source": [ "## Define the scoring function and the way metrics are computed\n", "\n", - "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual. " + "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual." 
] }, { @@ -257,9 +269,9 @@ " isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma\n", ")\n", "mixing_emitter = MixingEmitter(\n", - " mutation_fn=None, \n", - " variation_fn=variation_fn, \n", - " variation_percentage=1.0, \n", + " mutation_fn=None,\n", + " variation_fn=variation_fn,\n", + " variation_percentage=1.0,\n", " batch_size=batch_size\n", ")" ] diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index 6b4ae0b5..b2de823e 100644 --- a/examples/me_sac_pbt.ipynb +++ b/examples/me_sac_pbt.ipynb @@ -12,8 +12,45 @@ "\n", "import jax\n", "import jax.numpy as jnp\n", + "\n", + "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.3.1\"\n", + " import jumanji\n", + "\n", + "try:\n", + " import haiku\n", + "except:\n", + " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", + " import haiku\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " import qdax\n", + "\n", "import optax\n", - "from brax.io import html\n", + "from brax.v1.io import html\n", "from IPython.display import HTML\n", "from tqdm import tqdm\n", "\n", diff --git a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index ca127e72..238f703c 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -13,8 +13,45 @@ "\n", "import jax\n", "import jax.numpy as jnp\n", + "\n", + "try:\n", + " import brax\n", + "except:\n", + " 
!pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.3.1\"\n", + " import jumanji\n", + "\n", + "try:\n", + " import haiku\n", + "except:\n", + " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", + " import haiku\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " import qdax\n", + "\n", "import matplotlib.pyplot as plt\n", - "from brax.io import html\n", + "from brax.v1.io import html\n", "from IPython.display import HTML\n", "from tqdm import tqdm\n", "\n", diff --git a/examples/mees.ipynb b/examples/mees.ipynb index ab5fad93..ad1a4740 100644 --- a/examples/mees.ipynb +++ b/examples/mees.ipynb @@ -54,22 +54,33 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", " import jumanji\n", "except:\n", - " !pip install \"jumanji==0.2.2\"\n", + " !pip install \"jumanji==0.3.1\"\n", " import jumanji\n", "\n", "try:\n", " import qdax\n", "except:\n", - " !pip install --no-deps 
git+https://github.com/adaptive-intelligent-robotics/QDax@feat/add-algo-mees |tail -n 1\n", + " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", " import qdax\n", "\n", - "\n", "from qdax.core.map_elites import MAPElites\n", "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n", "from qdax import environments\n", @@ -201,7 +212,7 @@ " \"\"\"\n", "\n", " actions = policy_network.apply(policy_params, env_state.obs)\n", - " \n", + "\n", " state_desc = env_state.info[\"state_descriptor\"]\n", " next_state = env.step(env_state, actions)\n", "\n", @@ -227,7 +238,7 @@ "source": [ "## Define the scoring function and the way metrics are computed\n", "\n", - "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual. " + "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual." ] }, { @@ -274,7 +285,7 @@ "source": [ "## Define the emitter: MEES Emitter\n", "\n", - "The emitter is used to evolve the population at each mutation step. In this example, the emitter is the MAP-Elites-ES approximated gradient emitter, the one used in \"Scaling MAP-Elites to Deep Neuroevolution\". \n", + "The emitter is used to evolve the population at each mutation step. In this example, the emitter is the MAP-Elites-ES approximated gradient emitter, the one used in \"Scaling MAP-Elites to Deep Neuroevolution\".\n", "At every generations, it uses samples-approximated gradients to improve the solutions of the archive.Half of the time it approximates the gradient of fitness and half of the time the gradient of novelty." 
] }, @@ -435,7 +446,6 @@ "metadata": { "accelerator": "GPU", "colab": { - "collapsed_sections": [], "provenance": [] }, "gpuClass": "standard", diff --git a/examples/mels.ipynb b/examples/mels.ipynb index 1fcd6c42..bd489ca2 100644 --- a/examples/mels.ipynb +++ b/examples/mels.ipynb @@ -50,13 +50,25 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", " import jumanji\n", "except:\n", - " !pip install \"jumanji==0.2.2\"\n", + " !pip install \"jumanji==0.3.1\"\n", " import jumanji\n", "\n", "try:\n", @@ -82,7 +94,7 @@ "from jax.flatten_util import ravel_pytree\n", "\n", "from IPython.display import HTML\n", - "from brax.io import html\n", + "from brax.v1.io import html\n", "\n", "\n", "\n", @@ -263,9 +275,9 @@ " isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma\n", ")\n", "mixing_emitter = MixingEmitter(\n", - " mutation_fn=None, \n", - " variation_fn=variation_fn, \n", - " variation_percentage=1.0, \n", + " mutation_fn=None,\n", + " variation_fn=variation_fn,\n", + " variation_percentage=1.0,\n", " batch_size=batch_size\n", ")" ] diff --git a/examples/mome.ipynb b/examples/mome.ipynb index 05387158..bf0a5225 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -39,21 +39,27 @@ "from functools import partial\n", "\n", "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps 
git+https://github.com/google/flax.git@v0.4.1 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.3 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", " import chex\n", "\n", "try:\n", " import jumanji\n", "except:\n", - " !pip install \"jumanji==0.2.2\"\n", + " !pip install \"jumanji==0.3.1\"\n", " import jumanji\n", "\n", "try:\n", @@ -65,8 +71,8 @@ "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n", "from qdax.core.mome import MOME\n", "from qdax.core.emitters.mutation_operators import (\n", - " polynomial_mutation, \n", - " polynomial_crossover, \n", + " polynomial_mutation,\n", + " polynomial_crossover,\n", ")\n", "from qdax.core.emitters.standard_emitters import MixingEmitter\n", "from qdax.utils.plotting import plot_2d_map_elites_repertoire, plot_mome_pareto_fronts\n", @@ -233,9 +239,9 @@ "\n", "# Define emitter\n", "mixing_emitter = MixingEmitter(\n", - " mutation_fn=mutation_function, \n", - " variation_fn=crossover_function, \n", - " variation_percentage=crossover_percentage, \n", + " mutation_fn=mutation_function,\n", + " variation_fn=crossover_function,\n", + " variation_percentage=crossover_percentage,\n", " batch_size=batch_size\n", ")" ] @@ -256,10 +262,10 @@ "outputs": [], "source": [ "centroids, random_key = compute_cvt_centroids(\n", - " num_descriptors=2, \n", - " num_init_cvt_samples=20000, \n", - " num_centroids=num_centroids, \n", - " minval=minval, \n", + " num_descriptors=2,\n", + " num_init_cvt_samples=20000,\n", + " num_centroids=num_centroids,\n", + " minval=minval,\n", " maxval=maxval,\n", " random_key=random_key,\n", ")" diff --git a/examples/nsga2_spea2.ipynb b/examples/nsga2_spea2.ipynb index e10c0d91..4e9ab3b0 100644 --- a/examples/nsga2_spea2.ipynb +++ 
b/examples/nsga2_spea2.ipynb @@ -40,21 +40,27 @@ "from functools import partial\n", "\n", "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.4.1 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.3 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", " import chex\n", "\n", "try:\n", " import jumanji\n", "except:\n", - " !pip install \"jumanji==0.2.2\"\n", + " !pip install \"jumanji==0.3.1\"\n", " import jumanji\n", "\n", "try:\n", @@ -71,7 +77,7 @@ ")\n", "\n", "from qdax.core.emitters.mutation_operators import (\n", - " polynomial_crossover, \n", + " polynomial_crossover,\n", " polynomial_mutation\n", ")\n", "from qdax.core.emitters.standard_emitters import MixingEmitter\n", diff --git a/examples/omgmega.ipynb b/examples/omgmega.ipynb index 0a28876a..8d417cc0 100644 --- a/examples/omgmega.ipynb +++ b/examples/omgmega.ipynb @@ -30,26 +30,32 @@ "metadata": {}, "outputs": [], "source": [ - "import jax \n", + "import jax\n", "import jax.numpy as jnp\n", "import math\n", "\n", "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.4.1 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.3 |tail -n 1\n", + " !pip install --no-deps 
git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", " import chex\n", "\n", "try:\n", " import jumanji\n", "except:\n", - " !pip install \"jumanji==0.2.2\"\n", + " !pip install \"jumanji==0.3.1\"\n", " import jumanji\n", "\n", "try:\n", @@ -181,7 +187,7 @@ "source": [ "## Define the initial population, the emitter and the MAP Elites instance\n", "\n", - "The emitter is defined using the OMGMEGA emitter class. This emitter is given to a MAP-Elites instance to create an instance of the OMG-MEGA algorithm. " + "The emitter is defined using the OMGMEGA emitter class. This emitter is given to a MAP-Elites instance to create an instance of the OMG-MEGA algorithm." ] }, { @@ -196,13 +202,13 @@ "random_key, subkey = jax.random.split(random_key)\n", "initial_population = jax.random.normal(subkey, shape=(init_population_size, num_dimensions))\n", "\n", - "sqrt_centroids = int(math.sqrt(num_centroids)) # 2-D grid \n", + "sqrt_centroids = int(math.sqrt(num_centroids)) # 2-D grid\n", "grid_shape = (sqrt_centroids, sqrt_centroids)\n", "centroids = compute_euclidean_centroids(\n", " grid_shape = grid_shape,\n", " minval = minval,\n", " maxval = maxval\n", - ") \n", + ")\n", "\n", "# defines the emitter\n", "emitter = OMGMEGAEmitter(\n", diff --git a/examples/pga_aurora.ipynb b/examples/pga_aurora.ipynb index 11ed6afe..6152ce63 100644 --- a/examples/pga_aurora.ipynb +++ b/examples/pga_aurora.ipynb @@ -49,10 +49,28 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", + 
" import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.3.1\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -62,13 +80,20 @@ "from qdax.core.aurora import AURORA\n", "from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire\n", "from qdax import environments\n", - "from qdax.tasks.brax_envs import scoring_aurora_function\n", - "from qdax.environments.bd_extractors import get_aurora_bd\n", + "from qdax.tasks.brax_envs import (\n", + " create_default_brax_task_components,\n", + " get_aurora_scoring_fn,\n", + ")\n", + "from qdax.environments.bd_extractors import (\n", + " AuroraExtraInfoNormalization,\n", + " get_aurora_encoding,\n", + ")\n", "from qdax.core.neuroevolution.buffers.buffer import QDTransition\n", "from qdax.core.neuroevolution.networks.networks import MLP\n", "from qdax.core.emitters.mutation_operators import isoline_variation\n", "from qdax.core.emitters.pga_me_emitter import PGAMEConfig, PGAMEEmitter\n", "\n", + "from qdax.types import Observation\n", "from qdax.utils import train_seq2seq\n", "\n", "\n", @@ -202,7 +227,7 @@ " \"\"\"\n", "\n", " actions = policy_network.apply(policy_params, env_state.obs)\n", - " \n", + "\n", " state_desc = env_state.info[\"state_descriptor\"]\n", " next_state = env.step(env_state, actions)\n", "\n", @@ -226,7 +251,7 @@ "source": [ "## Define the scoring function and the way metrics are computed\n", "\n", - "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual. " + "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual." 
] }, { @@ -236,19 +261,35 @@ "outputs": [], "source": [ "# Prepare the scoring function\n", - "bd_extraction_fn = functools.partial(\n", - " get_aurora_bd,\n", - " option=observation_option,\n", - " hidden_size=hidden_size,\n", - " traj_sampling_freq=traj_sampling_freq,\n", - " max_observation_size=max_observation_size,\n", + "env, policy_network, scoring_fn, random_key = create_default_brax_task_components(\n", + " env_name=env_name,\n", + " random_key=random_key,\n", ")\n", - "scoring_fn = functools.partial(\n", - " scoring_aurora_function,\n", - " init_states=init_states,\n", - " episode_length=episode_length,\n", - " play_step_fn=play_step_fn,\n", - " behavior_descriptor_extractor=bd_extraction_fn,\n", + "\n", + "def observation_extractor_fn(\n", + " data: QDTransition,\n", + ") -> Observation:\n", + " \"\"\"Extract observation from the state.\"\"\"\n", + " state_obs = data.obs[:, ::traj_sampling_freq, :max_observation_size]\n", + "\n", + " # add the x/y position - (batch_size, traj_length, 2)\n", + " state_desc = data.state_desc[:, ::traj_sampling_freq]\n", + "\n", + " if observation_option == \"full\":\n", + " observations = jnp.concatenate([state_desc, state_obs], axis=-1)\n", + " elif observation_option == \"no_sd\":\n", + " observations = state_obs\n", + " elif observation_option == \"only_sd\":\n", + " observations = state_desc\n", + " else:\n", + " raise ValueError(\"Unknown observation option.\")\n", + "\n", + " return observations\n", + "\n", + "# Prepare the scoring function\n", + "aurora_scoring_fn = get_aurora_scoring_fn(\n", + " scoring_fn=scoring_fn,\n", + " observation_extractor_fn=observation_extractor_fn,\n", ")\n", "\n", "# Get minimum reward value to make sure qd_score are positive\n", @@ -336,13 +377,6 @@ "metadata": {}, "outputs": [], "source": [ - "# Instantiate AURORA\n", - "aurora = AURORA(\n", - " scoring_function=scoring_fn,\n", - " emitter=pg_emitter,\n", - " metrics_function=metrics_fn,\n", - ")\n", - "\n", "aurora_dims = 
hidden_size\n", "centroids = jnp.zeros(shape=(num_centroids, aurora_dims))\n", "\n", @@ -353,9 +387,7 @@ " repertoire,\n", " emitter_state,\n", " random_key,\n", - " model_params,\n", - " mean_observations,\n", - " std_observations,\n", + " aurora_extra_info\n", " ) = carry\n", "\n", " # update\n", @@ -363,13 +395,11 @@ " repertoire,\n", " emitter_state,\n", " random_key,\n", - " model_params,\n", - " mean_observations,\n", - " std_observations,\n", + " aurora_extra_info=aurora_extra_info,\n", " )\n", "\n", " return (\n", - " (repertoire, emitter_state, random_key, model_params, mean_observations, std_observations),\n", + " (repertoire, emitter_state, random_key, aurora_extra_info),\n", " metrics,\n", " )\n", "\n", @@ -391,11 +421,43 @@ "else:\n", " ValueError(\"The chosen option is not correct.\")\n", "\n", - "# define the seq2seq model\n", + "# Define the seq2seq model\n", "model = train_seq2seq.get_model(\n", " observations_dims[-1], True, hidden_size=hidden_size\n", ")\n", "\n", + "# Init the model params\n", + "random_key, subkey = jax.random.split(random_key)\n", + "model_params = train_seq2seq.get_initial_params(\n", + " model, subkey, (1, *observations_dims)\n", + ")\n", + "\n", + "print(jax.tree_map(lambda x: x.shape, model_params))\n", + "\n", + "# Define the encoder function\n", + "encoder_fn = jax.jit(\n", + " functools.partial(\n", + " get_aurora_encoding,\n", + " model=model,\n", + " )\n", + ")\n", + "\n", + "# Define the training function\n", + "train_fn = functools.partial(\n", + " train_seq2seq.lstm_ae_train,\n", + " model=model,\n", + " batch_size=lstm_batch_size,\n", + ")\n", + "\n", + "# Instantiate AURORA\n", + "aurora = AURORA(\n", + " scoring_function=aurora_scoring_fn,\n", + " emitter=pg_emitter,\n", + " metrics_function=metrics_fn,\n", + " encoder_function=encoder_fn,\n", + " training_function=train_fn,\n", + ")\n", + "\n", "# init the model params\n", "random_key, subkey = jax.random.split(random_key)\n", "model_params = 
train_seq2seq.get_initial_params(\n", @@ -408,21 +470,26 @@ "mean_observations = jnp.zeros(observations_dims[-1])\n", "std_observations = jnp.ones(observations_dims[-1])\n", "\n", - "# init step of the aurora algorithm\n", - "repertoire, emitter_state, random_key = aurora.init(\n", - " init_variables,\n", - " centroids,\n", - " random_key,\n", + "# init all the information needed by AURORA to compute encodings\n", + "aurora_extra_info = AuroraExtraInfoNormalization.create(\n", " model_params,\n", " mean_observations,\n", " std_observations,\n", - " l_value_init,\n", + ")\n", + "\n", + "# init step of the aurora algorithm\n", + "repertoire, emitter_state, aurora_extra_info, random_key = aurora.init(\n", + " init_variables,\n", + " aurora_extra_info,\n", + " jnp.asarray(l_value_init),\n", + " max_observation_size,\n", + " random_key,\n", ")\n", "\n", "# initializing means and stds and AURORA\n", "random_key, subkey = jax.random.split(random_key)\n", - "model_params, mean_observations, std_observations = train_seq2seq.lstm_ae_train(\n", - " subkey, repertoire, model_params, 0, hidden_size=hidden_size, batch_size=lstm_batch_size\n", + "repertoire, aurora_extra_info = aurora.train(\n", + " repertoire, model_params, iteration=0, random_key=subkey\n", ")\n", "\n", "# design aurora's schedule\n", @@ -456,11 +523,11 @@ "while iteration < max_iterations:\n", "\n", " (\n", - " (repertoire, emitter_state, random_key, model_params, mean_observations, std_observations),\n", + " (repertoire, emitter_state, random_key, aurora_extra_info),\n", " metrics,\n", " ) = jax.lax.scan(\n", " update_scan_fn,\n", - " (repertoire, emitter_state, random_key, model_params, mean_observations, std_observations),\n", + " (repertoire, emitter_state, random_key, aurora_extra_info),\n", " (),\n", " length=log_freq,\n", " )\n", @@ -474,61 +541,17 @@ " if (iteration + 1) in schedules:\n", " # train the autoencoder\n", " random_key, subkey = jax.random.split(random_key)\n", - " (\n", - " 
model_params,\n", - " mean_observations,\n", - " std_observations,\n", - " ) = train_seq2seq.lstm_ae_train(\n", - " subkey,\n", - " repertoire,\n", - " model_params,\n", - " iteration,\n", - " hidden_size=hidden_size,\n", - " batch_size=lstm_batch_size\n", + " repertoire, aurora_extra_info = aurora.train(\n", + " repertoire, model_params, iteration, subkey\n", " )\n", "\n", - " # re-addition of all the new behavioural descriotpors with the new ae\n", - " normalized_observations = (\n", - " repertoire.observations - mean_observations\n", - " ) / std_observations\n", - "\n", - " new_descriptors = model.apply(\n", - " {\"params\": model_params}, normalized_observations, method=model.encode\n", - " )\n", - " repertoire = repertoire.init(\n", - " genotypes=repertoire.genotypes,\n", - " centroids=repertoire.centroids,\n", - " fitnesses=repertoire.fitnesses,\n", - " descriptors=new_descriptors,\n", - " observations=repertoire.observations,\n", - " l_value=repertoire.l_value,\n", - " )\n", - " num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf)\n", - "\n", " elif iteration % 2 == 0:\n", - " # update the l value\n", - " num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf)\n", - "\n", - " # CVC Implementation to keep a constant number of individuals in the archive\n", - " current_error = num_indivs - n_target\n", - " change_rate = current_error - previous_error\n", - " prop_gain = 1 * 10e-6\n", - " l_value = (\n", - " repertoire.l_value\n", - " + (prop_gain * (current_error))\n", - " + (prop_gain * change_rate)\n", + " repertoire, previous_error = aurora.container_size_control(\n", + " repertoire,\n", + " target_size=n_target,\n", + " previous_error=previous_error,\n", " )\n", "\n", - " previous_error = current_error\n", - "\n", - " repertoire = repertoire.init(\n", - " genotypes=repertoire.genotypes,\n", - " centroids=repertoire.centroids,\n", - " fitnesses=repertoire.fitnesses,\n", - " descriptors=repertoire.descriptors,\n", - " 
observations=repertoire.observations,\n", - " l_value=l_value,\n", - " )\n", "\n", " iteration += 1" ] diff --git a/examples/pgame.ipynb b/examples/pgame.ipynb index ab1ae221..9b638b2d 100644 --- a/examples/pgame.ipynb +++ b/examples/pgame.ipynb @@ -48,13 +48,25 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", " import jumanji\n", "except:\n", - " !pip install \"jumanji==0.2.2\"\n", + " !pip install \"jumanji==0.3.1\"\n", " import jumanji\n", "\n", "try:\n", @@ -191,7 +203,7 @@ " \"\"\"\n", "\n", " actions = policy_network.apply(policy_params, env_state.obs)\n", - " \n", + "\n", " state_desc = env_state.info[\"state_descriptor\"]\n", " next_state = env.step(env_state, actions)\n", "\n", @@ -215,7 +227,7 @@ "source": [ "## Define the scoring function and the way metrics are computed\n", "\n", - "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual. " + "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual." 
] }, { diff --git a/examples/qdpg.ipynb b/examples/qdpg.ipynb index d778ad1d..102d5262 100644 --- a/examples/qdpg.ipynb +++ b/examples/qdpg.ipynb @@ -48,13 +48,25 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", " import jumanji\n", "except:\n", - " !pip install \"jumanji==0.2.2\"\n", + " !pip install \"jumanji==0.3.1\"\n", " import jumanji\n", "\n", "try:\n", @@ -204,7 +216,7 @@ " \"\"\"\n", "\n", " actions = policy_network.apply(policy_params, env_state.obs)\n", - " \n", + "\n", " state_desc = env_state.info[\"state_descriptor\"]\n", " next_state = env.step(env_state, actions)\n", "\n", @@ -228,7 +240,7 @@ "source": [ "## Define the scoring function and the way metrics are computed\n", "\n", - "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual. " + "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual." 
] }, { @@ -278,7 +290,7 @@ " batch_size=transitions_batch_size,\n", " critic_hidden_layer_size=critic_hidden_layer_size,\n", " critic_learning_rate=critic_learning_rate,\n", - " greedy_learning_rate=greedy_learning_rate,\n", + " actor_learning_rate=greedy_learning_rate,\n", " policy_learning_rate=policy_learning_rate,\n", " noise_clip=noise_clip,\n", " policy_noise=policy_noise,\n", @@ -297,7 +309,7 @@ " batch_size=transitions_batch_size,\n", " critic_hidden_layer_size=critic_hidden_layer_size,\n", " critic_learning_rate=critic_learning_rate,\n", - " greedy_learning_rate=greedy_learning_rate,\n", + " actor_learning_rate=greedy_learning_rate,\n", " policy_learning_rate=policy_learning_rate,\n", " noise_clip=noise_clip,\n", " policy_noise=policy_noise,\n", diff --git a/examples/sac_pbt.ipynb b/examples/sac_pbt.ipynb index 4f225667..7762083f 100644 --- a/examples/sac_pbt.ipynb +++ b/examples/sac_pbt.ipynb @@ -18,7 +18,44 @@ "\n", "import jax\n", "import jax.numpy as jnp\n", - "from brax.io import html\n", + "\n", + "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.3.1\"\n", + " import jumanji\n", + "\n", + "try:\n", + " import haiku\n", + "except:\n", + " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", + " import haiku\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " import qdax\n", + "\n", + "from brax.v1.io import 
html\n", "from IPython.display import HTML\n", "from tqdm import tqdm\n", "\n", diff --git a/examples/smerl.ipynb b/examples/smerl.ipynb index a59e6246..fe655fe2 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -45,15 +45,27 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", " import jumanji\n", "except:\n", - " !pip install \"jumanji==0.2.2\"\n", + " !pip install \"jumanji==0.3.1\"\n", " import jumanji\n", - " \n", + "\n", "try:\n", " import haiku\n", "except:\n", @@ -76,7 +88,7 @@ "from qdax.utils.plotting import plot_skills_trajectory\n", "\n", "from IPython.display import HTML\n", - "from brax.io import html\n", + "from brax.v1.io import html\n", "\n", "\n", "\n", @@ -212,7 +224,6 @@ " # SAC config\n", " batch_size=batch_size,\n", " episode_length=episode_length,\n", - " grad_updates_per_step=grad_updates_per_step,\n", " tau=tau,\n", " normalize_observations=normalize_observations,\n", " learning_rate=learning_rate,\n", diff --git a/examples/td3_pbt.ipynb b/examples/td3_pbt.ipynb index 7eba8043..ec98b9da 100644 --- a/examples/td3_pbt.ipynb +++ b/examples/td3_pbt.ipynb @@ -15,6 +15,43 @@ "\n", "import jax\n", "import jax.numpy as jnp\n", + "\n", + "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " import 
flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.3.1\"\n", + " import jumanji\n", + "\n", + "try:\n", + " import haiku\n", + "except:\n", + " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", + " import haiku\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " import qdax\n", + "\n", "from tqdm import tqdm\n", "\n", "from qdax import environments\n",