From 5e50dc359080bc4ef574f054dfbcec672c18b99f Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Mon, 23 Sep 2024 09:29:00 +0000 Subject: [PATCH] reset_fn returns same init_state in sampling_test --- tests/utils_test/sampling_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/utils_test/sampling_test.py b/tests/utils_test/sampling_test.py index f0cb6811..13f65f03 100644 --- a/tests/utils_test/sampling_test.py +++ b/tests/utils_test/sampling_test.py @@ -32,7 +32,6 @@ def test_sampling() -> None: # Init environment env = environments.create(env_name, episode_length=episode_length) - reset_fn = jax.jit(env.reset) # Init a random key key = jax.random.key(seed) @@ -79,12 +78,15 @@ def play_step_fn( return next_state, policy_params, key, transition + key, subkey = jax.random.split(key) + init_state = env.reset(subkey) + # Prepare the scoring function descriptor_extraction_fn = environments.descriptor_extractor[env_name] scoring_fn = functools.partial( scoring_function, episode_length=episode_length, - play_reset_fn=reset_fn, + play_reset_fn=lambda _: init_state, play_step_fn=play_step_fn, descriptor_extractor=descriptor_extraction_fn, )