diff --git a/tests/utils_test/sampling_test.py b/tests/utils_test/sampling_test.py index f0cb6811..13f65f03 100644 --- a/tests/utils_test/sampling_test.py +++ b/tests/utils_test/sampling_test.py @@ -32,7 +32,6 @@ def test_sampling() -> None: # Init environment env = environments.create(env_name, episode_length=episode_length) - reset_fn = jax.jit(env.reset) # Init a random key key = jax.random.key(seed) @@ -79,12 +78,15 @@ def play_step_fn( return next_state, policy_params, key, transition + key, subkey = jax.random.split(key) + init_state = env.reset(subkey) + # Prepare the scoring function descriptor_extraction_fn = environments.descriptor_extractor[env_name] scoring_fn = functools.partial( scoring_function, episode_length=episode_length, - play_reset_fn=reset_fn, + play_reset_fn=lambda _: init_state, play_step_fn=play_step_fn, descriptor_extractor=descriptor_extraction_fn, )