From bd03dcb2adbcadbd4629175d1beab6e8eafac0a3 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Wed, 28 Aug 2024 09:13:37 +0000 Subject: [PATCH] fix observation space in jumanji test script --- tests/default_tasks_test/jumanji_envs_test.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/default_tasks_test/jumanji_envs_test.py b/tests/default_tasks_test/jumanji_envs_test.py index 737a4a17..636a02cf 100644 --- a/tests/default_tasks_test/jumanji_envs_test.py +++ b/tests/default_tasks_test/jumanji_envs_test.py @@ -53,7 +53,13 @@ def test_jumanji_utils() -> None: def observation_processing( observation: jumanji.environments.routing.snake.types.Observation, ) -> Observation: - network_input = jnp.ravel(observation.grid) + network_input = jnp.concatenate( + [ + jnp.ravel(observation.grid), + jnp.array([observation.step_count]), + observation.action_mask.ravel(), + ] + ) return network_input play_step_fn = make_policy_network_play_step_fn_jumanji( @@ -67,7 +73,12 @@ def observation_processing( keys = jax.random.split(subkey, num=batch_size) # compute observation size from observation spec - observation_size = np.prod(np.array(env.observation_spec().grid.shape)) + obs_spec = env.observation_spec() + observation_size = int( + np.prod(obs_spec.grid.shape) + + np.prod(obs_spec.step_count.shape) + + np.prod(obs_spec.action_mask.shape) + ) fake_batch = jnp.zeros(shape=(batch_size, observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch)