diff --git a/qdax/utils/sampling.py b/qdax/utils/sampling.py index be1d336d..94d4e160 100644 --- a/qdax/utils/sampling.py +++ b/qdax/utils/sampling.py @@ -159,8 +159,8 @@ def multi_sample_scoring_function( # vectorizing over axis 0 vectorizes over the num_samples random keys in_axes=(None, 0), # indicates that the vectorized axis will become axis 1, i.e., the final - # output is shape (batch_size, num_samples, ...) - out_axes=1, + # output is shape (batch_size, num_samples, ...) except for the random key + out_axes=(1, 1, 1, 0), ) all_fitnesses, all_descriptors, all_extra_scores, _ = sample_scoring_fn( policies_params, keys diff --git a/tests/baselines_test/me_pbt_sac_test.py b/tests/baselines_test/me_pbt_sac_test.py index b2365382..c8dcc5af 100644 --- a/tests/baselines_test/me_pbt_sac_test.py +++ b/tests/baselines_test/me_pbt_sac_test.py @@ -163,6 +163,9 @@ def scoring_function(genotypes, random_key): # type: ignore observation_size=env.observation_size, buffer_size=buffer_size, ) + + # Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647 + keys = jax.random.key_data(keys) keys, training_states, _ = jax.pmap(agent_init_fn, axis_name="p", devices=devices)( keys ) diff --git a/tests/baselines_test/me_pbt_td3_test.py b/tests/baselines_test/me_pbt_td3_test.py index 53e2802e..f243725e 100644 --- a/tests/baselines_test/me_pbt_td3_test.py +++ b/tests/baselines_test/me_pbt_td3_test.py @@ -161,6 +161,9 @@ def scoring_function(genotypes, random_key): # type: ignore observation_size=env.observation_size, buffer_size=buffer_size, ) + + # Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647 + keys = jax.random.key_data(keys) keys, training_states, _ = jax.pmap(agent_init_fn, axis_name="p", devices=devices)( keys ) diff --git a/tests/baselines_test/pbt_sac_test.py b/tests/baselines_test/pbt_sac_test.py index 537d4a26..9c4b2c83 100644 --- a/tests/baselines_test/pbt_sac_test.py +++ b/tests/baselines_test/pbt_sac_test.py @@ -104,6 +104,9 @@ def init_environments(random_key): # type: ignore observation_size=env.observation_size, buffer_size=buffer_size, ) + + # Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647 + keys = jax.random.key_data(keys) keys, training_states, replay_buffers = jax.pmap( agent_init_fn, axis_name="p", devices=devices )(keys) diff --git a/tests/baselines_test/pbt_td3_test.py b/tests/baselines_test/pbt_td3_test.py index 813deb60..e45a9701 100644 --- a/tests/baselines_test/pbt_td3_test.py +++ b/tests/baselines_test/pbt_td3_test.py @@ -100,6 +100,9 @@ def init_environments(random_key): # type: ignore observation_size=env.observation_size, buffer_size=buffer_size, ) + + # Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647 + keys = jax.random.key_data(keys) keys, training_states, replay_buffers = jax.pmap( agent_init_fn, axis_name="p", devices=devices )(keys)