Skip to content

Commit

Permalink
Merge branch 'new-style-jax-rng-keys' of github.com:miltonllera/QDax …
Browse files Browse the repository at this point in the history
…into miltonllera-new-style-jax-rng-keys
  • Loading branch information
Lookatator committed Sep 13, 2024
2 parents 85b6ba6 + b1d0685 commit 7ba39e3
Show file tree
Hide file tree
Showing 60 changed files with 73 additions and 73 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ min_bd = 0.0
max_bd = 1.0

# Init a random key
random_key = jax.random.PRNGKey(seed)
random_key = jax.random.key(seed)

# Init population of controllers
random_key, subkey = jax.random.split(random_key)
Expand Down
2 changes: 1 addition & 1 deletion examples/aurora.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@
"env = environments.create(env_name, episode_length=episode_length)\n",
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.random.key(seed)\n",
"\n",
"# Init policy network\n",
"policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/cmaes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@
"outputs": [],
"source": [
"state = cmaes.init()\n",
"random_key = jax.random.PRNGKey(0)"
"random_key = jax.random.key(0)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion examples/cmame.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@
"metadata": {},
"outputs": [],
"source": [
"random_key = jax.random.PRNGKey(0)\n",
"random_key = jax.random.key(0)\n",
"# in CMA-ME settings (from the paper), there is no init population\n",
"# we multipy by zero to reproduce this setting\n",
"initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/cmamega.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@
"metadata": {},
"outputs": [],
"source": [
"random_key = jax.random.PRNGKey(0)\n",
"random_key = jax.key(0)\n",
"# no initial population - give all the same value as emitter init value\n",
"initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n",
"\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/dads.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@
" eval_metrics=True,\n",
")\n",
"\n",
"key = jax.random.PRNGKey(seed)\n",
"key = jax.random.key(seed)\n",
"env_state = jax.jit(env.reset)(rng=key)\n",
"eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n",
"\n",
Expand Down Expand Up @@ -499,7 +499,7 @@
"outputs": [],
"source": [
"rollout = []\n",
"random_key = jax.random.PRNGKey(seed=1)\n",
"random_key = jax.random.key(seed=1)\n",
"state = jit_env_reset(rng=random_key)\n",
"while not state.done:\n",
" rollout.append(state)\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/diayn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@
" eval_metrics=True,\n",
")\n",
"\n",
"key = jax.random.PRNGKey(seed)\n",
"key = jax.random.key(seed)\n",
"env_state = jax.jit(env.reset)(rng=key)\n",
"eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n",
"\n",
Expand Down Expand Up @@ -490,7 +490,7 @@
"outputs": [],
"source": [
"rollout = []\n",
"random_key = jax.random.PRNGKey(seed=1)\n",
"random_key = jax.random.key(seed=1)\n",
"state = jit_env_reset(rng=random_key)\n",
"while not state.done:\n",
" rollout.append(state)\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/distributed_mapelites.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@
"env = environments.create(env_name, episode_length=episode_length)\n",
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.random.key(seed)\n",
"\n",
"# Init policy network\n",
"policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/jumanji_snake.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
"env = jumanji.make('Snake-v1')\n",
"\n",
"# Reset your (jit-able) environment\n",
"key = jax.random.PRNGKey(0)\n",
"key = jax.random.key(0)\n",
"state, timestep = jax.jit(env.reset)(key)\n",
"\n",
"# Interact with the (jit-able) environment\n",
Expand All @@ -137,7 +137,7 @@
"outputs": [],
"source": [
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.random.key(seed)\n",
"\n",
"# get number of actions\n",
"num_actions = env.action_spec().maximum + 1\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/mapelites.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@
"env = environments.create(env_name, episode_length=episode_length)\n",
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.Key(seed)\n",
"\n",
"# Init policy network\n",
"policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
Expand Down Expand Up @@ -494,7 +494,7 @@
"outputs": [],
"source": [
"rollout = []\n",
"rng = jax.random.PRNGKey(seed=1)\n",
"rng = jax.random.key(seed=1)\n",
"state = jit_env_reset(rng=rng)\n",
"while not state.done:\n",
" rollout.append(state)\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/me_sac_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@
"outputs": [],
"source": [
"%%time\n",
"key = jax.random.PRNGKey(seed)\n",
"key = jax.random.key(seed)\n",
"key, subkey = jax.random.split(key)\n",
"env_states = jax.jit(env.reset)(rng=subkey)\n",
"eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey)"
Expand Down Expand Up @@ -504,7 +504,7 @@
"%%time\n",
"rollout = []\n",
"\n",
"rng = jax.random.PRNGKey(seed=1)\n",
"rng = jax.random.key(seed=1)\n",
"env_state = jax.jit(env.reset)(rng=rng)\n",
"\n",
"training_state, env_state = jax.tree_map(\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/me_td3_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@
"outputs": [],
"source": [
"%%time\n",
"key = jax.random.PRNGKey(seed)\n",
"key = jax.random.key(seed)\n",
"key, subkey = jax.random.split(key)\n",
"env_states = jax.jit(env.reset)(rng=subkey)\n",
"eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey)"
Expand Down
2 changes: 1 addition & 1 deletion examples/mees.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@
"env = environments.create(env_name, episode_length=episode_length)\n",
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.random.key(seed)\n",
"\n",
"# Init policy network\n",
"policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/mels.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@
"env = environments.create(env_name, episode_length=episode_length)\n",
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.random.key(seed)\n",
"\n",
"# Init policy network\n",
"policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
Expand Down Expand Up @@ -509,7 +509,7 @@
"outputs": [],
"source": [
"rollout = []\n",
"rng = jax.random.PRNGKey(seed=1)\n",
"rng = jax.random.key(seed=1)\n",
"state = jit_env_reset(rng=rng)\n",
"while not state.done:\n",
" rollout.append(state)\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/mome.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@
"outputs": [],
"source": [
"# initial population\n",
"random_key = jax.random.PRNGKey(42)\n",
"random_key = jax.random.key(42)\n",
"random_key, subkey = jax.random.split(random_key)\n",
"genotypes = jax.random.uniform(\n",
" random_key, (batch_size, num_variables), minval=minval, maxval=maxval, dtype=jnp.float32\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/nsga2_spea2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@
"outputs": [],
"source": [
"# Initial population\n",
"random_key = jax.random.PRNGKey(0)\n",
"random_key = jax.random.key(0)\n",
"random_key, subkey = jax.random.split(random_key)\n",
"genotypes = jax.random.uniform(\n",
" subkey, (batch_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/omgmega.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@
"metadata": {},
"outputs": [],
"source": [
"random_key = jax.random.PRNGKey(0)\n",
"random_key = jax.random.key(0)\n",
"\n",
"# defines the population\n",
"random_key, subkey = jax.random.split(random_key)\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/pga_aurora.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@
"env = environments.create(env_name, episode_length=episode_length)\n",
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.random.key(seed)\n",
"\n",
"# Init policy network\n",
"policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/pgame.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@
"env = environments.create(env_name, episode_length=episode_length)\n",
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.random.key(seed)\n",
"\n",
"# Init policy network\n",
"policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/qdpg.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@
"env = environments.create(env_name, episode_length=episode_length)\n",
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.random.key(seed)\n",
"\n",
"# Init policy network\n",
"policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/sac_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@
"outputs": [],
"source": [
"# %%time\n",
"key = jax.random.PRNGKey(seed)\n",
"key = jax.random.key(seed)\n",
"key, *keys = jax.random.split(key, num=1 + num_devices)\n",
"keys = jnp.stack(keys)\n",
"env_states, eval_env_first_states = jax.pmap(\n",
Expand Down Expand Up @@ -518,7 +518,7 @@
"%%time\n",
"rollout = []\n",
"\n",
"rng = jax.random.PRNGKey(seed=1)\n",
"rng = jax.random.key(seed=1)\n",
"env_state = jax.jit(env.reset)(rng=rng)\n",
"\n",
"training_state, env_state = jax.tree_util.tree_map(\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/me_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def run_me() -> None:
max_bd = 1.0

# Init a random key
random_key = jax.random.PRNGKey(seed)
random_key = jax.random.key(seed)

# Init population of controllers
random_key, subkey = jax.random.split(random_key)
Expand Down
4 changes: 2 additions & 2 deletions examples/smerl.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@
" eval_metrics=True,\n",
")\n",
"\n",
"key = jax.random.PRNGKey(seed)\n",
"key = jax.Key(seed)\n",
"env_state = jax.jit(env.reset)(rng=key)\n",
"eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n",
"\n",
Expand Down Expand Up @@ -504,7 +504,7 @@
"outputs": [],
"source": [
"rollout = []\n",
"random_key = jax.random.PRNGKey(seed=1)\n",
"random_key = jax.random.key(seed=1)\n",
"state = jit_env_reset(rng=random_key)\n",
"while not state.done:\n",
" rollout.append(state)\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/td3_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@
"outputs": [],
"source": [
"%%time\n",
"key = jax.random.PRNGKey(seed)\n",
"key = jax.random.key(seed)\n",
"key, *keys = jax.random.split(key, num=1 + num_devices)\n",
"keys = jnp.stack(keys)\n",
"env_states, eval_env_first_states = jax.pmap(\n",
Expand Down
2 changes: 1 addition & 1 deletion qdax/core/containers/mapelites_repertoire.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def compute_cvt_centroids(
init="k-means++",
n_clusters=num_centroids,
n_init=1,
random_state=RandomState(subkey),
random_state=RandomState(jax.random.key_data(subkey)),
)
k_means.fit(x)
centroids = k_means.cluster_centers_
Expand Down
4 changes: 2 additions & 2 deletions qdax/core/neuroevolution/networks/seq2seq_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from flax import linen as nn

Array = Any
PRNGKey = Any
PRNGKey = jax.Array


class EncoderLSTM(nn.Module):
Expand Down Expand Up @@ -52,7 +52,7 @@ def select_carried_state(new_state: Array, old_state: Array) -> Array:
def initialize_carry(batch_size: int, hidden_size: int) -> Tuple[Array, Array]:
# Use a dummy key since the default state init fn is just zeros.
return nn.LSTMCell(hidden_size, parent=None).initialize_carry( # type: ignore
jax.random.PRNGKey(0), (batch_size, hidden_size)
jax.random.key(0), (batch_size, hidden_size)
)


Expand Down
6 changes: 3 additions & 3 deletions qdax/tasks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Notes:
import jax
from qdax.tasks.arm import arm_scoring_function

random_key = jax.random.PRNGKey(0)
random_key = jax.random.key(0)

# Get scoring function
scoring_fn = arm_scoring_function
Expand Down Expand Up @@ -56,7 +56,7 @@ desc_size = 2
import jax
from qdax.tasks.standard_functions import sphere_scoring_function

random_key = jax.random.PRNGKey(0)
random_key = jax.random.key(0)

# Get scoring function
scoring_fn = sphere_scoring_function
Expand Down Expand Up @@ -98,7 +98,7 @@ desc_size = 2
import jax
from qdax.tasks.hypervolume_functions import square_scoring_function

random_key = jax.random.PRNGKey(0)
random_key = jax.random.key(0)

# Get scoring function
scoring_fn = square_scoring_function
Expand Down
2 changes: 1 addition & 1 deletion tests/baselines_test/cmame_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]:
max_fitness = jnp.max(adjusted_fitness)
return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage}

random_key = jax.random.PRNGKey(0)
random_key = jax.random.key(0)
initial_population = (
jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.0
)
Expand Down
2 changes: 1 addition & 1 deletion tests/baselines_test/cmamega_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]:
max_fitness = jnp.max(adjusted_fitness)
return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage}

random_key = jax.random.PRNGKey(0)
random_key = jax.random.key(0)
initial_population = jax.random.uniform(
random_key, shape=(batch_size, num_dimensions)
)
Expand Down
2 changes: 1 addition & 1 deletion tests/baselines_test/dads_smerl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_dads_smerl() -> None:
eval_metrics=True,
)

key = jax.random.PRNGKey(seed)
key = jax.random.key(seed)
env_state = jax.jit(env.reset)(rng=key)
eval_env_first_state = jax.jit(eval_env.reset)(rng=key)

Expand Down
2 changes: 1 addition & 1 deletion tests/baselines_test/dads_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_dads() -> None:
eval_metrics=True,
)

key = jax.random.PRNGKey(seed)
key = jax.random.key(seed)
env_state = jax.jit(env.reset)(rng=key)
eval_env_first_state = jax.jit(eval_env.reset)(rng=key)

Expand Down
2 changes: 1 addition & 1 deletion tests/baselines_test/diayn_smerl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_diayn_smerl() -> None:
eval_metrics=True,
)

key = jax.random.PRNGKey(seed)
key = jax.random.key(seed)
env_state = jax.jit(env.reset)(rng=key)
eval_env_first_state = jax.jit(eval_env.reset)(rng=key)

Expand Down
2 changes: 1 addition & 1 deletion tests/baselines_test/diayn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_diayn() -> None:
eval_metrics=True,
)

key = jax.random.PRNGKey(seed)
key = jax.random.key(seed)
env_state = jax.jit(env.reset)(rng=key)
eval_env_first_state = jax.jit(eval_env.reset)(rng=key)

Expand Down
Loading

0 comments on commit 7ba39e3

Please sign in to comment.