Skip to content

Commit

Permalink
Fix brax and jumanji scoring functions by not passing the key to the p…
Browse files Browse the repository at this point in the history
…artial
  • Loading branch information
maxencefaldor committed Sep 22, 2024
1 parent af99afd commit cdb9573
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 11 deletions.
4 changes: 2 additions & 2 deletions qdax/tasks/brax_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ def scoring_function_brax_envs(
generate_unroll,
episode_length=episode_length,
play_step_fn=play_step_fn,
key=key,
)
_, data = jax.vmap(unroll_fn)(init_states, policies_params)
keys = jax.random.split(key, jax.tree.leaves(policies_params)[0].shape[0])
_, data = jax.vmap(unroll_fn)(init_states, policies_params, keys)

# Create a mask to extract data properly
mask = get_mask_from_transitions(data)
Expand Down
14 changes: 5 additions & 9 deletions qdax/tasks/jumanji_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,25 +156,21 @@ def jumanji_scoring_function(
When the init states are different, this is not purely stochastic.
"""

# Perform rollouts with each policy
key, subkey = jax.random.split(key)
# Step environments
unroll_fn = partial(
generate_jumanji_unroll,
episode_length=episode_length,
play_step_fn=play_step_fn,
key=subkey,
)
keys = jax.random.split(key, jax.tree.leaves(policies_params)[0].shape[0])
_, _, data = jax.vmap(unroll_fn)(init_states, init_timesteps, policies_params, keys)

_final_state, _final_timestep, data = jax.vmap(unroll_fn)(
init_states, init_timesteps, policies_params
)

# create a mask to extract data properly
# Create a mask to extract data properly
is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1)
mask = jnp.roll(is_done, 1, axis=1)
mask = mask.at[:, 0].set(0)

# scores
# Evaluate
fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1)
descriptors = descriptor_extractor(data, mask)

Expand Down

0 comments on commit cdb9573

Please sign in to comment.