From f2fd28ecddc247b2846f6bcd42668f8e381529f6 Mon Sep 17 00:00:00 2001 From: dmarzoug Date: Mon, 25 Mar 2024 10:37:51 +0100 Subject: [PATCH] Fixed argument bug --- moojoco/environment/dual.py | 4 ++-- moojoco/environment/mjc_env.py | 2 +- moojoco/environment/wrapper.py | 10 +++++----- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/moojoco/environment/dual.py b/moojoco/environment/dual.py index 6949ccb..d65a960 100644 --- a/moojoco/environment/dual.py +++ b/moojoco/environment/dual.py @@ -63,12 +63,12 @@ def observation_space(self) -> SpaceType: def step( self, state: BaseEnvState, action: chex.Array, *args, **kwargs ) -> BaseEnvState: - return self._env.step(state=state, action=action, *args, **kwargs) + return self._env.step(state, action, *args, **kwargs) def reset( self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs ) -> BaseEnvState: - return self._env.reset(rng=rng, *args, **kwargs) + return self._env.reset(rng, *args, **kwargs) def render(self, state: BaseEnvState) -> List[RenderFrame] | None: return self._env.render(state=state) diff --git a/moojoco/environment/mjc_env.py b/moojoco/environment/mjc_env.py index 7c2a889..1f5fb03 100644 --- a/moojoco/environment/mjc_env.py +++ b/moojoco/environment/mjc_env.py @@ -265,7 +265,7 @@ def step( ) -> VectorMJCEnvState: self._states = list( self._pool.map( - lambda env, ste, act: env.step(state=ste, action=act, *args, **kwargs), + lambda env, ste, act: env.step(ste, act, *args, **kwargs), self._envs, self._states, action, diff --git a/moojoco/environment/wrapper.py b/moojoco/environment/wrapper.py index defb170..7c6d4ff 100644 --- a/moojoco/environment/wrapper.py +++ b/moojoco/environment/wrapper.py @@ -53,12 +53,12 @@ def observation_space(self) -> SpaceType: def step( self, state: BaseEnvState, action: chex.Array, *args, **kwargs ) -> BaseEnvState: - return self._env.step(state=state, action=action, *args, **kwargs) + return self._env.step(state, action, *args, **kwargs) def reset( self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs ) -> BaseEnvState: - return self._env.reset(rng=rng, *args, **kwargs) + return self._env.reset(rng, *args, **kwargs) def render(self, state: BaseEnvState) -> List[RenderFrame] | None: return self._env.render(state=state) @@ -83,14 +83,14 @@ def _transform_observations(self, state: BaseEnvState) -> BaseEnvState: def step( self, state: BaseEnvState, action: chex.Array, *args, **kwargs ) -> BaseEnvState: - state = self._env.step(state=state, action=action, *args, **kwargs) + state = self._env.step(state, action, *args, **kwargs) state = self._transform_observations(state=state) return state def reset( self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs ) -> BaseEnvState: - state = self._env.reset(rng=rng, *args, **kwargs) + state = self._env.reset(rng, *args, **kwargs) state = self._transform_observations(state=state) return state @@ -114,5 +114,5 @@ def step( self, state: BaseEnvState, action: chex.Array, *args, **kwargs ) -> BaseEnvState: action, state = self._transform_action(action=action, state=state) - state = self._env.step(state=state, action=action, *args, **kwargs) + state = self._env.step(state, action, *args, **kwargs) return state