Skip to content

Commit

Permalink
fix observation space in jumanji test script
Browse files Browse the repository at this point in the history
  • Loading branch information
Lookatator committed Aug 28, 2024
1 parent a86495c commit bd03dcb
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions tests/default_tasks_test/jumanji_envs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,13 @@ def test_jumanji_utils() -> None:
def observation_processing(
observation: jumanji.environments.routing.snake.types.Observation,
) -> Observation:
network_input = jnp.ravel(observation.grid)
network_input = jnp.concatenate(
[
jnp.ravel(observation.grid),
jnp.array([observation.step_count]),
observation.action_mask.ravel(),
]
)
return network_input

play_step_fn = make_policy_network_play_step_fn_jumanji(
Expand All @@ -67,7 +73,12 @@ def observation_processing(
keys = jax.random.split(subkey, num=batch_size)

# compute observation size from observation spec
observation_size = np.prod(np.array(env.observation_spec().grid.shape))
obs_spec = env.observation_spec()
observation_size = int(
np.prod(obs_spec.grid.shape)
+ np.prod(obs_spec.step_count.shape)
+ np.prod(obs_spec.action_mask.shape)
)

fake_batch = jnp.zeros(shape=(batch_size, observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)
Expand Down

0 comments on commit bd03dcb

Please sign in to comment.