Skip to content

Commit

Permalink
fixing bugs from new data structure of RNG keys
Browse files: browse the repository at this point in the history
  • Loading branch information
Lookatator committed Sep 20, 2024
1 parent dc31b0c commit 2040636
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 2 deletions.
4 changes: 2 additions & 2 deletions qdax/utils/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def multi_sample_scoring_function(
# vectorizing over axis 0 vectorizes over the num_samples random keys
in_axes=(None, 0),
# indicates that the vectorized axis will become axis 1, i.e., the final
# output is shape (batch_size, num_samples, ...)
out_axes=1,
# output is shape (batch_size, num_samples, ...), except for the random key,
# whose vectorized axis stays at axis 0
out_axes=(1, 1, 1, 0),
)
all_fitnesses, all_descriptors, all_extra_scores, _ = sample_scoring_fn(
policies_params, keys
Expand Down
3 changes: 3 additions & 0 deletions tests/baselines_test/me_pbt_sac_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ def scoring_function(genotypes, random_key): # type: ignore
observation_size=env.observation_size,
buffer_size=buffer_size,
)

# Convert the typed keys to their raw key data (legacy PRNGKey format) before
# pmap, because of https://github.com/jax-ml/jax/issues/23647
keys = jax.random.key_data(keys)
keys, training_states, _ = jax.pmap(agent_init_fn, axis_name="p", devices=devices)(
keys
)
Expand Down
3 changes: 3 additions & 0 deletions tests/baselines_test/me_pbt_td3_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ def scoring_function(genotypes, random_key): # type: ignore
observation_size=env.observation_size,
buffer_size=buffer_size,
)

# Convert the typed keys to their raw key data (legacy PRNGKey format) before
# pmap, because of https://github.com/jax-ml/jax/issues/23647
keys = jax.random.key_data(keys)
keys, training_states, _ = jax.pmap(agent_init_fn, axis_name="p", devices=devices)(
keys
)
Expand Down
3 changes: 3 additions & 0 deletions tests/baselines_test/pbt_sac_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def init_environments(random_key): # type: ignore
observation_size=env.observation_size,
buffer_size=buffer_size,
)

# Convert the typed keys to their raw key data (legacy PRNGKey format) before
# pmap, because of https://github.com/jax-ml/jax/issues/23647
keys = jax.random.key_data(keys)
keys, training_states, replay_buffers = jax.pmap(
agent_init_fn, axis_name="p", devices=devices
)(keys)
Expand Down
3 changes: 3 additions & 0 deletions tests/baselines_test/pbt_td3_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def init_environments(random_key): # type: ignore
observation_size=env.observation_size,
buffer_size=buffer_size,
)

# Convert the typed keys to their raw key data (legacy PRNGKey format) before
# pmap, because of https://github.com/jax-ml/jax/issues/23647
keys = jax.random.key_data(keys)
keys, training_states, replay_buffers = jax.pmap(
agent_init_fn, axis_name="p", devices=devices
)(keys)
Expand Down

0 comments on commit 2040636

Please sign in to comment.