Skip to content

Commit

Permalink
reset_fn returns same init_state in sampling_test
Browse files Browse the repository at this point in the history
  • Loading branch information
Lookatator committed Sep 23, 2024
1 parent 83d4201 commit 5e50dc3
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tests/utils_test/sampling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def test_sampling() -> None:

# Init environment
env = environments.create(env_name, episode_length=episode_length)
reset_fn = jax.jit(env.reset)

# Init a random key
key = jax.random.key(seed)
Expand Down Expand Up @@ -79,12 +78,15 @@ def play_step_fn(

return next_state, policy_params, key, transition

key, subkey = jax.random.split(key)
init_state = env.reset(subkey)

# Prepare the scoring function
descriptor_extraction_fn = environments.descriptor_extractor[env_name]
scoring_fn = functools.partial(
scoring_function,
episode_length=episode_length,
play_reset_fn=reset_fn,
play_reset_fn=lambda _: init_state,
play_step_fn=play_step_fn,
descriptor_extractor=descriptor_extraction_fn,
)
Expand Down

0 comments on commit 5e50dc3

Please sign in to comment.