Skip to content

Commit

Permalink
adapt test to new Jumanji API
Browse files Browse the repository at this point in the history
  • Loading branch information
Lookatator committed Dec 10, 2023
1 parent 71e65e9 commit 17bec85
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions tests/default_tasks_test/jumanji_envs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax
import jax.numpy as jnp
import jumanji
import jumanji.environments.routing.snake
import numpy as np
import pytest

Expand Down Expand Up @@ -49,8 +50,10 @@ def test_jumanji_utils() -> None:
final_activation=jax.nn.softmax,
)

def observation_processing(observation: jumanji.types.Observation) -> Observation:
network_input = jnp.ravel(observation)
def observation_processing(
observation: jumanji.environments.routing.snake.types.Observation,
) -> Observation:
network_input = jnp.ravel(observation.grid)
return network_input

play_step_fn = make_policy_network_play_step_fn_jumanji(
Expand All @@ -64,7 +67,7 @@ def observation_processing(observation: jumanji.types.Observation) -> Observatio
keys = jax.random.split(subkey, num=batch_size)

# compute observation size from observation spec
observation_size = np.prod(np.array(env.observation_spec().shape))
observation_size = np.prod(np.array(env.observation_spec().grid.shape))

fake_batch = jnp.zeros(shape=(batch_size, observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)
Expand Down Expand Up @@ -136,4 +139,5 @@ def bd_extraction(


if __name__ == "__main__":
pytest.assume
test_jumanji_utils()

0 comments on commit 17bec85

Please sign in to comment.