Skip to content

Commit

Permalink
fix: fix all notebooks to run with latest develop
Browse files Browse the repository at this point in the history
  • Loading branch information
manon-but-yes committed Jan 16, 2024
1 parent b4125c3 commit 5f74f60
Show file tree
Hide file tree
Showing 22 changed files with 676 additions and 300 deletions.
202 changes: 108 additions & 94 deletions examples/aurora.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,28 @@
"try:\n",
" import brax\n",
"except:\n",
" !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n",
" !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n",
" import brax\n",
"\n",
"try:\n",
" import flax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/google/[email protected] |tail -n 1\n",
" import flax\n",
"\n",
"try:\n",
" import chex\n",
"except:\n",
" !pip install --no-deps git+https://github.com/deepmind/[email protected] |tail -n 1\n",
" import chex\n",
"\n",
"try:\n",
" import jumanji\n",
"except:\n",
" !pip install \"jumanji==0.3.1\"\n",
" import jumanji\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n",
Expand All @@ -62,13 +80,20 @@
"from qdax.core.aurora import AURORA\n",
"from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire\n",
"from qdax import environments\n",
"from qdax.tasks.brax_envs import scoring_aurora_function\n",
"from qdax.environments.bd_extractors import get_aurora_bd\n",
"from qdax.tasks.brax_envs import (\n",
" create_default_brax_task_components,\n",
" get_aurora_scoring_fn,\n",
")\n",
"from qdax.environments.bd_extractors import (\n",
" AuroraExtraInfoNormalization,\n",
" get_aurora_encoding,\n",
")\n",
"from qdax.core.neuroevolution.buffers.buffer import QDTransition\n",
"from qdax.core.neuroevolution.networks.networks import MLP\n",
"from qdax.core.emitters.mutation_operators import isoline_variation\n",
"from qdax.core.emitters.standard_emitters import MixingEmitter\n",
"\n",
"from qdax.types import Observation\n",
"from qdax.utils import train_seq2seq\n",
"\n",
"\n",
Expand Down Expand Up @@ -184,7 +209,7 @@
" \"\"\"\n",
"\n",
" actions = policy_network.apply(policy_params, env_state.obs)\n",
" \n",
"\n",
" state_desc = env_state.info[\"state_descriptor\"]\n",
" next_state = env.step(env_state, actions)\n",
"\n",
Expand All @@ -208,7 +233,7 @@
"source": [
"## Define the scoring function and the way metrics are computed\n",
"\n",
"The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual. "
"The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual."
]
},
{
Expand All @@ -218,19 +243,35 @@
"outputs": [],
"source": [
"# Prepare the scoring function\n",
"bd_extraction_fn = functools.partial(\n",
" get_aurora_bd,\n",
" option=observation_option,\n",
" hidden_size=hidden_size,\n",
" traj_sampling_freq=traj_sampling_freq,\n",
" max_observation_size=max_observation_size,\n",
"env, policy_network, scoring_fn, random_key = create_default_brax_task_components(\n",
" env_name=env_name,\n",
" random_key=random_key,\n",
")\n",
"scoring_fn = functools.partial(\n",
" scoring_aurora_function,\n",
" init_states=init_states,\n",
" episode_length=episode_length,\n",
" play_step_fn=play_step_fn,\n",
" behavior_descriptor_extractor=bd_extraction_fn,\n",
"\n",
"def observation_extractor_fn(\n",
" data: QDTransition,\n",
") -> Observation:\n",
" \"\"\"Extract observation from the state.\"\"\"\n",
" state_obs = data.obs[:, ::traj_sampling_freq, :max_observation_size]\n",
"\n",
" # add the x/y position - (batch_size, traj_length, 2)\n",
" state_desc = data.state_desc[:, ::traj_sampling_freq]\n",
"\n",
" if observation_option == \"full\":\n",
" observations = jnp.concatenate([state_desc, state_obs], axis=-1)\n",
" elif observation_option == \"no_sd\":\n",
" observations = state_obs\n",
" elif observation_option == \"only_sd\":\n",
" observations = state_desc\n",
" else:\n",
" raise ValueError(\"Unknown observation option.\")\n",
"\n",
" return observations\n",
"\n",
"# Prepare the scoring function\n",
"aurora_scoring_fn = get_aurora_scoring_fn(\n",
" scoring_fn=scoring_fn,\n",
" observation_extractor_fn=observation_extractor_fn,\n",
")\n",
"\n",
"# Get minimum reward value to make sure qd_score are positive\n",
Expand Down Expand Up @@ -290,13 +331,6 @@
"metadata": {},
"outputs": [],
"source": [
"# Instantiate AURORA\n",
"aurora = AURORA(\n",
" scoring_function=scoring_fn,\n",
" emitter=mixing_emitter,\n",
" metrics_function=metrics_fn,\n",
")\n",
"\n",
"aurora_dims = hidden_size\n",
"centroids = jnp.zeros(shape=(num_centroids, aurora_dims))\n",
"\n",
Expand All @@ -306,23 +340,19 @@
" (\n",
" repertoire,\n",
" random_key,\n",
" model_params,\n",
" mean_observations,\n",
" std_observations,\n",
" aurora_extra_info\n",
" ) = carry\n",
"\n",
" # update\n",
" (repertoire, _, metrics, random_key,) = aurora.update(\n",
" repertoire,\n",
" None,\n",
" random_key,\n",
" model_params,\n",
" mean_observations,\n",
" std_observations,\n",
" aurora_extra_info=aurora_extra_info,\n",
" )\n",
"\n",
" return (\n",
" (repertoire, random_key, model_params, mean_observations, std_observations),\n",
" (repertoire, random_key, aurora_extra_info),\n",
" metrics,\n",
" )\n",
"\n",
Expand All @@ -344,38 +374,67 @@
"else:\n",
" ValueError(\"The chosen option is not correct.\")\n",
"\n",
"# define the seq2seq model\n",
"# Define the seq2seq model\n",
"model = train_seq2seq.get_model(\n",
" observations_dims[-1], True, hidden_size=hidden_size\n",
")\n",
"\n",
"# init the model params\n",
"# Init the model params\n",
"random_key, subkey = jax.random.split(random_key)\n",
"model_params = train_seq2seq.get_initial_params(\n",
" model, subkey, (1, *observations_dims)\n",
")\n",
"\n",
"print(jax.tree_map(lambda x: x.shape, model_params))\n",
"\n",
"# Define the encoder function\n",
"encoder_fn = jax.jit(\n",
" functools.partial(\n",
" get_aurora_encoding,\n",
" model=model,\n",
" )\n",
")\n",
"\n",
"# Define the training function\n",
"train_fn = functools.partial(\n",
" train_seq2seq.lstm_ae_train,\n",
" model=model,\n",
" batch_size=lstm_batch_size,\n",
")\n",
"\n",
"# Instantiate AURORA\n",
"aurora = AURORA(\n",
" scoring_function=aurora_scoring_fn,\n",
" emitter=mixing_emitter,\n",
" metrics_function=metrics_fn,\n",
" encoder_function=encoder_fn,\n",
" training_function=train_fn,\n",
")\n",
"\n",
"# define arbitrary observation's mean/std\n",
"mean_observations = jnp.zeros(observations_dims[-1])\n",
"std_observations = jnp.ones(observations_dims[-1])\n",
"\n",
"# init step of the aurora algorithm\n",
"repertoire, _, random_key = aurora.init(\n",
" init_variables,\n",
" centroids,\n",
" random_key,\n",
"# init all the information needed by AURORA to compute encodings\n",
"aurora_extra_info = AuroraExtraInfoNormalization.create(\n",
" model_params,\n",
" mean_observations,\n",
" std_observations,\n",
" l_value_init,\n",
")\n",
"\n",
"# init step of the aurora algorithm\n",
"repertoire, emitter_state, aurora_extra_info, random_key = aurora.init(\n",
" init_variables,\n",
" aurora_extra_info,\n",
" jnp.asarray(l_value_init),\n",
" max_observation_size,\n",
" random_key,\n",
")\n",
"\n",
"# initializing means and stds and AURORA\n",
"random_key, subkey = jax.random.split(random_key)\n",
"model_params, mean_observations, std_observations = train_seq2seq.lstm_ae_train(\n",
" subkey, repertoire, model_params, 0, hidden_size=hidden_size, batch_size=lstm_batch_size\n",
"repertoire, aurora_extra_info = aurora.train(\n",
" repertoire, model_params, iteration=0, random_key=subkey\n",
")\n",
"\n",
"# design aurora's schedule\n",
Expand Down Expand Up @@ -409,11 +468,11 @@
"while iteration < max_iterations:\n",
"\n",
" (\n",
" (repertoire, random_key, model_params, mean_observations, std_observations),\n",
" (repertoire, random_key, aurora_extra_info),\n",
" metrics,\n",
" ) = jax.lax.scan(\n",
" update_scan_fn,\n",
" (repertoire, random_key, model_params, mean_observations, std_observations),\n",
" (repertoire, random_key, aurora_extra_info),\n",
" (),\n",
" length=log_freq,\n",
" )\n",
Expand All @@ -427,60 +486,15 @@
" if (iteration + 1) in schedules:\n",
" # train the autoencoder\n",
" random_key, subkey = jax.random.split(random_key)\n",
" (\n",
" model_params,\n",
" mean_observations,\n",
" std_observations,\n",
" ) = train_seq2seq.lstm_ae_train(\n",
" subkey,\n",
" repertoire,\n",
" model_params,\n",
" iteration,\n",
" hidden_size=hidden_size,\n",
" batch_size=lstm_batch_size\n",
" repertoire, aurora_extra_info = aurora.train(\n",
" repertoire, model_params, iteration, subkey\n",
" )\n",
"\n",
" # re-addition of all the new behavioural descriotpors with the new ae\n",
" normalized_observations = (\n",
" repertoire.observations - mean_observations\n",
" ) / std_observations\n",
"\n",
" new_descriptors = model.apply(\n",
" {\"params\": model_params}, normalized_observations, method=model.encode\n",
" )\n",
" repertoire = repertoire.init(\n",
" genotypes=repertoire.genotypes,\n",
" centroids=repertoire.centroids,\n",
" fitnesses=repertoire.fitnesses,\n",
" descriptors=new_descriptors,\n",
" observations=repertoire.observations,\n",
" l_value=repertoire.l_value,\n",
" )\n",
" num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf)\n",
"\n",
" elif iteration % 2 == 0:\n",
" # update the l value\n",
" num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf)\n",
"\n",
" # CVC Implementation to keep a constant number of individuals in the archive\n",
" current_error = num_indivs - n_target\n",
" change_rate = current_error - previous_error\n",
" prop_gain = 1 * 10e-6\n",
" l_value = (\n",
" repertoire.l_value\n",
" + (prop_gain * (current_error))\n",
" + (prop_gain * change_rate)\n",
" )\n",
"\n",
" previous_error = current_error\n",
"\n",
" repertoire = repertoire.init(\n",
" genotypes=repertoire.genotypes,\n",
" centroids=repertoire.centroids,\n",
" fitnesses=repertoire.fitnesses,\n",
" descriptors=repertoire.descriptors,\n",
" observations=repertoire.observations,\n",
" l_value=l_value,\n",
" repertoire, previous_error = aurora.container_size_control(\n",
" repertoire,\n",
" target_size=n_target,\n",
" previous_error=previous_error,\n",
" )\n",
"\n",
" iteration += 1"
Expand Down
42 changes: 36 additions & 6 deletions examples/cmaes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,36 @@
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"try:\n",
" import brax\n",
"except:\n",
" !pip install git+https://github.com/google/[email protected] |tail -n 1\n",
" import brax\n",
"\n",
"try:\n",
" import flax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/google/[email protected] |tail -n 1\n",
" import flax\n",
"\n",
"try:\n",
" import chex\n",
"except:\n",
" !pip install --no-deps git+https://github.com/deepmind/[email protected] |tail -n 1\n",
" import chex\n",
"\n",
"try:\n",
" import jumanji\n",
"except:\n",
" !pip install \"jumanji==0.3.1\"\n",
" import jumanji\n",
"\n",
"try:\n",
" import qdax\n",
"except:\n",
" !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n",
" import qdax\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Ellipse\n",
"\n",
Expand Down Expand Up @@ -193,23 +223,23 @@
"iteration_count = 0\n",
"for _ in range(num_iterations):\n",
" iteration_count += 1\n",
" \n",
"\n",
" # sample\n",
" samples, random_key = cmaes.sample(state, random_key)\n",
" \n",
"\n",
" # udpate\n",
" state = cmaes.update(state, samples)\n",
" \n",
"\n",
" # check stop condition\n",
" stop_condition = cmaes.stop_condition(state)\n",
"\n",
" if stop_condition:\n",
" break\n",
" \n",
"\n",
" # store data for plotting\n",
" means.append(state.mean)\n",
" covs.append((state.sigma**2) * state.cov_matrix)\n",
" \n",
"\n",
"print(\"Num iterations before stop condition: \", iteration_count)"
]
},
Expand Down Expand Up @@ -281,7 +311,7 @@
" ellipse = Ellipse((mean[0], mean[1]), cov[0, 0], cov[1, 1], fill=False, color='k', ls='--')\n",
" ax.add_patch(ellipse)\n",
" ax.plot(mean[0], mean[1], color='k', marker='x')\n",
" \n",
"\n",
"ax.set_title(f\"Optimization trajectory of CMA-ES between step {traj_min} and step {traj_max}\")\n",
"plt.show()"
]
Expand Down
Loading

0 comments on commit 5f74f60

Please sign in to comment.