diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py
index 6a56ccd5..25c89f3d 100644
--- a/qdax/tasks/brax_envs.py
+++ b/qdax/tasks/brax_envs.py
@@ -149,9 +149,9 @@ def scoring_function_brax_envs(
         generate_unroll,
         episode_length=episode_length,
         play_step_fn=play_step_fn,
-        key=key,
     )
-    _, data = jax.vmap(unroll_fn)(init_states, policies_params)
+    keys = jax.random.split(key, jax.tree.leaves(policies_params)[0].shape[0])
+    _, data = jax.vmap(unroll_fn)(init_states, policies_params, keys)
 
     # Create a mask to extract data properly
     mask = get_mask_from_transitions(data)
diff --git a/qdax/tasks/jumanji_envs.py b/qdax/tasks/jumanji_envs.py
index d390d67a..ed4c4580 100644
--- a/qdax/tasks/jumanji_envs.py
+++ b/qdax/tasks/jumanji_envs.py
@@ -156,25 +156,21 @@ def jumanji_scoring_function(
     When the init states are different, this is not purely stochastic.
     """
 
-    # Perform rollouts with each policy
-    key, subkey = jax.random.split(key)
+    # Step environments
     unroll_fn = partial(
         generate_jumanji_unroll,
         episode_length=episode_length,
         play_step_fn=play_step_fn,
-        key=subkey,
     )
+    keys = jax.random.split(key, jax.tree.leaves(policies_params)[0].shape[0])
+    _, _, data = jax.vmap(unroll_fn)(init_states, init_timesteps, policies_params, keys)
 
-    _final_state, _final_timestep, data = jax.vmap(unroll_fn)(
-        init_states, init_timesteps, policies_params
-    )
-
-    # create a mask to extract data properly
+    # Create a mask to extract data properly
     is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1)
     mask = jnp.roll(is_done, 1, axis=1)
     mask = mask.at[:, 0].set(0)
 
-    # scores
+    # Evaluate
     fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1)
     descriptors = descriptor_extractor(data, mask)
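
Both hunks apply the same pattern: instead of baking a single `key` into the `partial` (which would make every vmapped rollout share the same random stream), one key per population member is split off and passed through `jax.vmap` alongside the init states and policy params. Below is a minimal standalone sketch of that pattern; the toy `unroll_fn`, shapes, and values are illustrative stand-ins, not QDax code.

```python
# Sketch: split one PRNG key into a batch of keys, then vmap the rollout
# over (states, params, keys) so each rollout gets independent randomness.
import jax
import jax.numpy as jnp


def unroll_fn(init_state, policy_params, key):
    # Toy stand-in for the real unroll: one "step" of noisy dynamics.
    noise = jax.random.normal(key, init_state.shape)
    final_state = init_state + policy_params * noise
    return final_state, jnp.sum(final_state)  # (state, "data")


batch_size = 4
key = jax.random.PRNGKey(0)
init_states = jnp.zeros((batch_size, 3))
policies_params = jnp.ones((batch_size, 3))

# As in the diff, the batch size is recovered from the params pytree:
# jax.tree.leaves(...)[0].shape[0] is the leading (population) axis.
keys = jax.random.split(key, jax.tree.leaves(policies_params)[0].shape[0])
_, data = jax.vmap(unroll_fn)(init_states, policies_params, keys)
```

Reading the batch size off `jax.tree.leaves(policies_params)[0].shape[0]` works for an arbitrary params pytree, since every leaf shares the same leading population axis.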