Skip to content

Commit

Permalink
Stop returning random keys
Browse files Browse the repository at this point in the history
  • Loading branch information
maxencefaldor committed Sep 3, 2024
1 parent cba9f27 commit b1535ee
Show file tree
Hide file tree
Showing 113 changed files with 1,354 additions and 1,604 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ min_descriptor = 0.0
max_descriptor = 1.0

# Init a random key
random_key = jax.random.PRNGKey(seed)
key = jax.random.key(seed)

# Init population of controllers
random_key, subkey = jax.random.split(random_key)
key, subkey = jax.random.split(key)
init_variables = jax.random.uniform(
subkey,
shape=(init_batch_size, num_param_dimensions),
Expand Down Expand Up @@ -111,14 +111,14 @@ centroids = compute_euclidean_centroids(
)

# Initializes repertoire and emitter state
repertoire, emitter_state, random_key = map_elites.init(init_variables, centroids, random_key)
repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key)

# Run MAP-Elites loop
for i in range(num_iterations):
(repertoire, emitter_state, metrics, random_key,) = map_elites.update(
(repertoire, emitter_state, metrics, key,) = map_elites.update(
repertoire,
emitter_state,
random_key,
key,
)

# Get contents of repertoire
Expand Down
10 changes: 5 additions & 5 deletions docs/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@ More importantly, QDax handles the archive management which is the key idea of Q
## Code Example
```python
# Initializes repertoire and emitter state
repertoire, emitter_state, random_key = map_elites.init(init_variables, centroids, random_key)
repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key)

for i in range(num_iterations):

# generate new population with the emitter
genotypes, random_key = map_elites._emitter.emit(
repertoire, emitter_state, random_key
genotypes, key = map_elites._emitter.emit(
repertoire, emitter_state, key
)

# scores/evaluates the population
fitnesses, descriptors, extra_scores, random_key = map_elites._scoring_function(
genotypes, random_key
fitnesses, descriptors, extra_scores, key = map_elites._scoring_function(
genotypes, key
)

# update repertoire
Expand Down
38 changes: 19 additions & 19 deletions examples/aurora.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@
"env = environments.create(env_name, episode_length=episode_length)\n",
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"key = jax.random.key(seed)\n",
"\n",
"# Init policy network\n",
"policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
Expand All @@ -170,14 +170,14 @@
")\n",
"\n",
"# Init population of controllers\n",
"random_key, subkey = jax.random.split(random_key)\n",
"key, subkey = jax.random.split(key)\n",
"keys = jax.random.split(subkey, num=batch_size)\n",
"fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))\n",
"init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n",
"\n",
"\n",
"# Create the initial environment states\n",
"random_key, subkey = jax.random.split(random_key)\n",
"key, subkey = jax.random.split(key)\n",
"keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)\n",
"reset_fn = jax.jit(jax.vmap(env.reset))\n",
"init_states = reset_fn(keys)"
Expand All @@ -202,7 +202,7 @@
"def play_step_fn(\n",
" env_state,\n",
" policy_params,\n",
" random_key,\n",
" key,\n",
"):\n",
" \"\"\"\n",
" Play an environment step and return the updated state and the transition.\n",
Expand All @@ -224,7 +224,7 @@
" next_state_desc=next_state.info[\"state_descriptor\"],\n",
" )\n",
"\n",
" return next_state, policy_params, random_key, transition"
" return next_state, policy_params, key, transition"
]
},
{
Expand All @@ -243,9 +243,9 @@
"outputs": [],
"source": [
"# Prepare the scoring function\n",
"env, policy_network, scoring_fn, random_key = create_default_brax_task_components(\n",
"env, policy_network, scoring_fn, key = create_default_brax_task_components(\n",
" env_name=env_name,\n",
" random_key=random_key,\n",
" key=key,\n",
")\n",
"\n",
"def observation_extractor_fn(\n",
Expand Down Expand Up @@ -339,20 +339,20 @@
" \"\"\"Scan the udpate function.\"\"\"\n",
" (\n",
" repertoire,\n",
" random_key,\n",
" key,\n",
" aurora_extra_info\n",
" ) = carry\n",
"\n",
" # update\n",
" (repertoire, _, metrics, random_key,) = aurora.update(\n",
" (repertoire, _, metrics, key,) = aurora.update(\n",
" repertoire,\n",
" None,\n",
" random_key,\n",
" key,\n",
" aurora_extra_info=aurora_extra_info,\n",
" )\n",
"\n",
" return (\n",
" (repertoire, random_key, aurora_extra_info),\n",
" (repertoire, key, aurora_extra_info),\n",
" metrics,\n",
" )\n",
"\n",
Expand Down Expand Up @@ -380,7 +380,7 @@
")\n",
"\n",
"# Init the model params\n",
"random_key, subkey = jax.random.split(random_key)\n",
"key, subkey = jax.random.split(key)\n",
"model_params = train_seq2seq.get_initial_params(\n",
" model, subkey, (1, *observations_dims)\n",
")\n",
Expand Down Expand Up @@ -423,18 +423,18 @@
")\n",
"\n",
"# init step of the aurora algorithm\n",
"repertoire, emitter_state, aurora_extra_info, random_key = aurora.init(\n",
"repertoire, emitter_state, aurora_extra_info, key = aurora.init(\n",
" init_variables,\n",
" aurora_extra_info,\n",
" jnp.asarray(l_value_init),\n",
" max_observation_size,\n",
" random_key,\n",
" key,\n",
")\n",
"\n",
"# initializing means and stds and AURORA\n",
"random_key, subkey = jax.random.split(random_key)\n",
"key, subkey = jax.random.split(key)\n",
"repertoire, aurora_extra_info = aurora.train(\n",
" repertoire, model_params, iteration=0, random_key=subkey\n",
" repertoire, model_params, iteration=0, key=subkey\n",
")\n",
"\n",
"# design aurora's schedule\n",
Expand Down Expand Up @@ -468,11 +468,11 @@
"while iteration < max_iterations:\n",
"\n",
" (\n",
" (repertoire, random_key, aurora_extra_info),\n",
" (repertoire, key, aurora_extra_info),\n",
" metrics,\n",
" ) = jax.lax.scan(\n",
" update_scan_fn,\n",
" (repertoire, random_key, aurora_extra_info),\n",
" (repertoire, key, aurora_extra_info),\n",
" (),\n",
" length=log_freq,\n",
" )\n",
Expand All @@ -485,7 +485,7 @@
" # autoencoder steps and CVC\n",
" if (iteration + 1) in schedules:\n",
" # train the autoencoder\n",
" random_key, subkey = jax.random.split(random_key)\n",
" key, subkey = jax.random.split(key)\n",
" repertoire, aurora_extra_info = aurora.train(\n",
" repertoire, model_params, iteration, subkey\n",
" )\n",
Expand Down
6 changes: 3 additions & 3 deletions examples/cmaes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@
"outputs": [],
"source": [
"state = cmaes.init()\n",
"random_key = jax.random.PRNGKey(0)"
"key = jax.random.key(0)"
]
},
{
Expand Down Expand Up @@ -225,7 +225,7 @@
" iteration_count += 1\n",
"\n",
" # sample\n",
" samples, random_key = cmaes.sample(state, random_key)\n",
" samples, key = cmaes.sample(state, key)\n",
"\n",
" # udpate\n",
" state = cmaes.update(state, samples)\n",
Expand Down Expand Up @@ -296,7 +296,7 @@
"fig, ax = plt.subplots(figsize=(12, 6))\n",
"\n",
"# sample points to show fitness landscape\n",
"random_key, subkey = jax.random.split(random_key)\n",
"key, subkey = jax.random.split(key)\n",
"x = jax.random.uniform(subkey, minval=-4, maxval=8, shape=(100000, 2))\n",
"f_x = fitness_fn(x)\n",
"\n",
Expand Down
14 changes: 7 additions & 7 deletions examples/cmame.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@
" scores, descriptors = fitness_scoring(x), _descriptors(x)\n",
" return scores, descriptors, {}\n",
"\n",
"def scoring_fn(x, random_key):\n",
"def scoring_fn(x, key):\n",
" fitnesses, descriptors, extra_scores = jax.vmap(scoring_function)(x)\n",
" return fitnesses, descriptors, extra_scores, random_key"
" return fitnesses, descriptors, extra_scores, key"
]
},
{
Expand Down Expand Up @@ -217,10 +217,10 @@
"metadata": {},
"outputs": [],
"source": [
"random_key = jax.random.PRNGKey(0)\n",
"key = jax.random.key(0)\n",
"# in CMA-ME settings (from the paper), there is no init population\n",
"# we multipy by zero to reproduce this setting\n",
"initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n",
"initial_population = jax.random.uniform(key, shape=(batch_size, num_dimensions)) * 0.\n",
"\n",
"centroids = compute_euclidean_centroids(\n",
" grid_shape=grid_shape,\n",
Expand Down Expand Up @@ -271,7 +271,7 @@
"metadata": {},
"outputs": [],
"source": [
"repertoire, emitter_state, random_key = map_elites.init(initial_population, centroids, random_key)"
"repertoire, emitter_state, key = map_elites.init(initial_population, centroids, key)"
]
},
{
Expand All @@ -289,9 +289,9 @@
"source": [
"%%time\n",
"\n",
"(repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n",
"(repertoire, emitter_state, key,), metrics = jax.lax.scan(\n",
" map_elites.scan_update,\n",
" (repertoire, emitter_state, random_key),\n",
" (repertoire, emitter_state, key),\n",
" (),\n",
" length=num_iterations,\n",
")"
Expand Down
18 changes: 9 additions & 9 deletions examples/cmamega.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@
"\n",
" return scores, descriptors, extra_scores\n",
"\n",
"def scoring_fn(x, random_key):\n",
"def scoring_fn(x, key):\n",
" fitnesses, descriptors, extra_scores = jax.vmap(scoring_function)(x)\n",
" return fitnesses, descriptors, extra_scores, random_key"
" return fitnesses, descriptors, extra_scores, key"
]
},
{
Expand Down Expand Up @@ -210,17 +210,17 @@
"metadata": {},
"outputs": [],
"source": [
"random_key = jax.random.PRNGKey(0)\n",
"key = jax.random.key(0)\n",
"# no initial population - give all the same value as emitter init value\n",
"initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n",
"initial_population = jax.random.uniform(key, shape=(batch_size, num_dimensions)) * 0.\n",
"\n",
"centroids, random_key = compute_cvt_centroids(\n",
"centroids, key = compute_cvt_centroids(\n",
" num_descriptors=2,\n",
" num_init_cvt_samples=10000,\n",
" num_centroids=num_centroids,\n",
" minval=minval,\n",
" maxval=maxval,\n",
" random_key=random_key,\n",
" key=key,\n",
")\n",
"\n",
"emitter = CMAMEGAEmitter(\n",
Expand All @@ -245,7 +245,7 @@
"metadata": {},
"outputs": [],
"source": [
"repertoire, emitter_state, random_key = map_elites.init(initial_population, centroids, random_key)"
"repertoire, emitter_state, key = map_elites.init(initial_population, centroids, key)"
]
},
{
Expand All @@ -256,9 +256,9 @@
"source": [
"%%time\n",
"\n",
"(repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n",
"(repertoire, emitter_state, key,), metrics = jax.lax.scan(\n",
" map_elites.scan_update,\n",
" (repertoire, emitter_state, random_key),\n",
" (repertoire, emitter_state, key),\n",
" (),\n",
" length=num_iterations,\n",
")"
Expand Down
14 changes: 7 additions & 7 deletions examples/dads.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@
" eval_metrics=True,\n",
")\n",
"\n",
"key = jax.random.PRNGKey(seed)\n",
"key = jax.random.key(seed)\n",
"env_state = jax.jit(env.reset)(rng=key)\n",
"eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n",
"\n",
Expand Down Expand Up @@ -492,10 +492,10 @@
"jit_env_step = jax.jit(visual_env.step)\n",
"\n",
"@jax.jit\n",
"def jit_inference_fn(params, observation, random_key):\n",
"def jit_inference_fn(params, observation, key):\n",
" obs = jnp.concatenate([observation, skill], axis=0)\n",
" action, random_key = dads.select_action(obs, params, random_key, deterministic=True)\n",
" return action, random_key"
" action, key = dads.select_action(obs, params, key, deterministic=True)\n",
" return action, key"
]
},
{
Expand All @@ -512,11 +512,11 @@
"outputs": [],
"source": [
"rollout = []\n",
"random_key = jax.random.PRNGKey(seed=1)\n",
"state = jit_env_reset(rng=random_key)\n",
"key = jax.random.key(seed=1)\n",
"state = jit_env_reset(rng=key)\n",
"while not state.done:\n",
" rollout.append(state)\n",
" action, random_key = jit_inference_fn(my_params, state.obs, random_key)\n",
" action, key = jit_inference_fn(my_params, state.obs, key)\n",
" state = jit_env_step(state, action)\n",
"\n",
"print(f\"The trajectory of this individual contains {len(rollout)} transitions.\")"
Expand Down
14 changes: 7 additions & 7 deletions examples/diayn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@
" eval_metrics=True,\n",
")\n",
"\n",
"key = jax.random.PRNGKey(seed)\n",
"key = jax.random.key(seed)\n",
"env_state = jax.jit(env.reset)(rng=key)\n",
"eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n",
"\n",
Expand Down Expand Up @@ -482,10 +482,10 @@
"jit_env_step = jax.jit(visual_env.step)\n",
"\n",
"@jax.jit\n",
"def jit_inference_fn(params, observation, random_key):\n",
"def jit_inference_fn(params, observation, key):\n",
" obs = jnp.concatenate([observation, skill], axis=0)\n",
" action, random_key = diayn.select_action(obs, params, random_key, deterministic=True)\n",
" return action, random_key"
" action, key = diayn.select_action(obs, params, key, deterministic=True)\n",
" return action, key"
]
},
{
Expand All @@ -502,11 +502,11 @@
"outputs": [],
"source": [
"rollout = []\n",
"random_key = jax.random.PRNGKey(seed=1)\n",
"state = jit_env_reset(rng=random_key)\n",
"key = jax.random.key(seed=1)\n",
"state = jit_env_reset(rng=key)\n",
"while not state.done:\n",
" rollout.append(state)\n",
" action, random_key = jit_inference_fn(my_params, state.obs, random_key)\n",
" action, key = jit_inference_fn(my_params, state.obs, key)\n",
" state = jit_env_step(state, action)\n",
"\n",
"print(f\"The trajectory of this individual contains {len(rollout)} transitions.\")"
Expand Down
Loading

0 comments on commit b1535ee

Please sign in to comment.