diff --git a/moojoco/environment/dual.py b/moojoco/environment/dual.py index 6949ccb..d65a960 100644 --- a/moojoco/environment/dual.py +++ b/moojoco/environment/dual.py @@ -63,12 +63,12 @@ def observation_space(self) -> SpaceType: def step( self, state: BaseEnvState, action: chex.Array, *args, **kwargs ) -> BaseEnvState: - return self._env.step(state=state, action=action, *args, **kwargs) + return self._env.step(state, action, *args, **kwargs) def reset( self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs ) -> BaseEnvState: - return self._env.reset(rng=rng, *args, **kwargs) + return self._env.reset(rng, *args, **kwargs) def render(self, state: BaseEnvState) -> List[RenderFrame] | None: return self._env.render(state=state) diff --git a/moojoco/environment/mjc_env.py b/moojoco/environment/mjc_env.py index 7c2a889..1f5fb03 100644 --- a/moojoco/environment/mjc_env.py +++ b/moojoco/environment/mjc_env.py @@ -265,7 +265,7 @@ def step( ) -> VectorMJCEnvState: self._states = list( self._pool.map( - lambda env, ste, act: env.step(state=ste, action=act, *args, **kwargs), + lambda env, ste, act: env.step(ste, act, *args, **kwargs), self._envs, self._states, action, diff --git a/moojoco/environment/wrapper.py b/moojoco/environment/wrapper.py index defb170..7c6d4ff 100644 --- a/moojoco/environment/wrapper.py +++ b/moojoco/environment/wrapper.py @@ -53,12 +53,12 @@ def observation_space(self) -> SpaceType: def step( self, state: BaseEnvState, action: chex.Array, *args, **kwargs ) -> BaseEnvState: - return self._env.step(state=state, action=action, *args, **kwargs) + return self._env.step(state, action, *args, **kwargs) def reset( self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs ) -> BaseEnvState: - return self._env.reset(rng=rng, *args, **kwargs) + return self._env.reset(rng, *args, **kwargs) def render(self, state: BaseEnvState) -> List[RenderFrame] | None: return self._env.render(state=state) @@ -83,14 +83,14 @@ def _transform_observations(self, state: BaseEnvState) -> BaseEnvState: def step( self, state: BaseEnvState, action: chex.Array, *args, **kwargs ) -> BaseEnvState: - state = self._env.step(state=state, action=action, *args, **kwargs) + state = self._env.step(state, action, *args, **kwargs) state = self._transform_observations(state=state) return state def reset( self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs ) -> BaseEnvState: - state = self._env.reset(rng=rng, *args, **kwargs) + state = self._env.reset(rng, *args, **kwargs) state = self._transform_observations(state=state) return state @@ -114,5 +114,5 @@ def step( self, state: BaseEnvState, action: chex.Array, *args, **kwargs ) -> BaseEnvState: action, state = self._transform_action(action=action, state=state) - state = self._env.step(state=state, action=action, *args, **kwargs) + state = self._env.step(state, action, *args, **kwargs) return state