From f3a8c4c18d5e15f01661a62dcfa2c3647245bbfe Mon Sep 17 00:00:00 2001
From: WiemKhlifi
Date: Wed, 23 Oct 2024 13:36:40 +0100
Subject: [PATCH] feat: update jumanji version and the env specs

---
 README.md                     |  2 +-
 matrax/env.py                 | 20 ++++++++++----------
 matrax/env_test.py            |  4 ++--
 requirements/requirements.txt |  2 +-
 4 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index 976996f..3d4d3d1 100644
--- a/README.md
+++ b/README.md
@@ -81,7 +81,7 @@ key = jax.random.PRNGKey(0)
 state, timestep = jax.jit(env.reset)(key)
 
 # Interact with the (jit-able) environment
-action = env.action_spec().generate_value()          # Action selection (dummy value here)
+action = env.action_spec.generate_value()            # Action selection (dummy value here)
 state, timestep = jax.jit(env.step)(state, action)   # Take a step and observe the next state and time step
 ```
 
diff --git a/matrax/env.py b/matrax/env.py
index 1f4f4a5..5b0a079 100644
--- a/matrax/env.py
+++ b/matrax/env.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import functools
+from functools import cached_property, partial
 from typing import Tuple
 
 import chex
@@ -25,7 +25,7 @@
 from matrax.types import Observation, State
 
 
-class MatrixGame(Environment[State]):
+class MatrixGame(Environment[State, specs.MultiDiscreteArray, Observation]):
     """JAX implementation of the 2-player matrix game environment:
     https://github.com/uoe-agents/matrix-games
 
@@ -42,7 +42,7 @@ class MatrixGame(Environment[State]):
     env = MatrixGame(payoff_matrix)
     key = jax.random.PRNGKey(0)
     state, timestep = jax.jit(env.reset)(key)
-    action = env.action_spec().generate_value()
+    action = env.action_spec.generate_value()
     state, timestep = jax.jit(env.step)(state, action)
     ```
     """
@@ -92,9 +92,9 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
         dummy_actions = jnp.ones((self.num_agents,), int) * -1
 
         # collect first observations and create timestep
-        agent_obs = jax.vmap(
-            functools.partial(self._make_agent_observation, dummy_actions)
-        )(jnp.arange(self.num_agents))
+        agent_obs = jax.vmap(partial(self._make_agent_observation, dummy_actions))(
+            jnp.arange(self.num_agents)
+        )
         observation = Observation(
             agent_obs=agent_obs,
             step_count=state.step_count,
@@ -124,16 +124,14 @@ def compute_reward(
             reward_idx = tuple(actions)
             return payoff_matrix_per_agent[reward_idx].astype(float)
 
-        rewards = jax.vmap(functools.partial(compute_reward, actions))(
-            self.payoff_matrix
-        )
+        rewards = jax.vmap(partial(compute_reward, actions))(self.payoff_matrix)
 
         # construct timestep and check environment termination
         steps = state.step_count + 1
         done = steps >= self.time_limit
 
         # compute next observation
-        agent_obs = jax.vmap(functools.partial(self._make_agent_observation, actions))(
+        agent_obs = jax.vmap(partial(self._make_agent_observation, actions))(
             jnp.arange(self.num_agents)
         )
         next_observation = Observation(
@@ -169,6 +167,7 @@ def _make_agent_observation(
             lambda: jnp.zeros(self.num_agents, int),
         )
 
+    @cached_property
     def observation_spec(self) -> specs.Spec[Observation]:
         """Specification of the observation of the MatrixGame environment.
         Returns:
@@ -190,6 +189,7 @@ def observation_spec(self) -> specs.Spec[Observation]:
             step_count=step_count,
         )
 
+    @cached_property
     def action_spec(self) -> specs.MultiDiscreteArray:
         """Returns the action spec. Since this is a multi-agent environment, the environment
         expects an array of actions.

diff --git a/matrax/env_test.py b/matrax/env_test.py
index 03676c4..a48a3a6 100644
--- a/matrax/env_test.py
+++ b/matrax/env_test.py
@@ -43,8 +43,8 @@ def matrix_game_env_with_state() -> MatrixGame:
 
 
 def test_matrix_game__specs(matrix_game_env: MatrixGame) -> None:
     """Validate environment specs conform to the expected shapes and values"""
-    action_spec = matrix_game_env.action_spec()
-    observation_spec = matrix_game_env.observation_spec()
+    action_spec = matrix_game_env.action_spec
+    observation_spec = matrix_game_env.observation_spec
     assert observation_spec.agent_obs.shape == (2, 2)  # type: ignore
     assert action_spec.num_values.shape[0] == matrix_game_env.num_agents

diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 95efb21..e808e41 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -1 +1 @@
-jumanji==0.3.1
+jumanji==1.0.1
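
For reference, here is a minimal usage sketch of the property-based spec API this patch moves to (jumanji >= 1.0): `action_spec` and `observation_spec` become cached properties rather than methods. The `matrax.env` import path matches the file changed above, but the 2x2x2 payoff matrix and its values are illustrative assumptions, not taken from the patch.

```python
# Minimal sketch of the spec API after this patch (jumanji >= 1.0).
# Assumption: a (num_agents, num_actions, num_actions) payoff matrix is a valid
# constructor argument; the matching-pennies-style values below are illustrative only.
import jax
import jax.numpy as jnp

from matrax.env import MatrixGame

payoff_matrix = jnp.array(
    [
        [[1, -1], [-1, 1]],  # agent 0's payoffs, indexed by (action_0, action_1)
        [[-1, 1], [1, -1]],  # agent 1's payoffs
    ]
)
env = MatrixGame(payoff_matrix)

key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)

# Specs are now cached properties: no trailing parentheses.
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
```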