From 165a051906d1b9e9ff75e23be58346e6f0aab279 Mon Sep 17 00:00:00 2001 From: dmarzoug Date: Mon, 25 Mar 2024 10:03:59 +0100 Subject: [PATCH 1/3] bump to version 1.0.4 --- pyproject.toml | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b9f6ef4..80ac2bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "moojoco" -version = "1.0.3" +version = "1.0.4" authors = [ { name = "Dries Marzougui", email = "dries.marzougui@gmail.com" }, ] diff --git a/setup.py b/setup.py index f5049c0..9f3ce12 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setup( name='moojoco', - version='1.0.3', + version='1.0.4', description='A unified framework for implementing and interfacing with MuJoCo and MuJoCo-XLA simulation ' 'environments.', long_description=readme, From f45144d4a3afdc7ee9b4f3fac53889dad5e7ed0b Mon Sep 17 00:00:00 2001 From: dmarzoug Date: Mon, 25 Mar 2024 10:29:16 +0100 Subject: [PATCH 2/3] Bugfix: pass *args and **kwargs through --- moojoco/environment/dual.py | 4 ++-- moojoco/environment/mjc_env.py | 4 ++-- moojoco/environment/wrapper.py | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/moojoco/environment/dual.py b/moojoco/environment/dual.py index e215d36..6949ccb 100644 --- a/moojoco/environment/dual.py +++ b/moojoco/environment/dual.py @@ -63,12 +63,12 @@ def observation_space(self) -> SpaceType: def step( self, state: BaseEnvState, action: chex.Array, *args, **kwargs ) -> BaseEnvState: - return self._env.step(state=state, action=action) + return self._env.step(state=state, action=action, *args, **kwargs) def reset( self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs ) -> BaseEnvState: - return self._env.reset(rng=rng) + return self._env.reset(rng=rng, *args, **kwargs) def render(self, state: BaseEnvState) -> List[RenderFrame] | None: return self._env.render(state=state) diff --git a/moojoco/environment/mjc_env.py b/moojoco/environment/mjc_env.py index ae29058..a710ab0 100644 --- a/moojoco/environment/mjc_env.py +++ b/moojoco/environment/mjc_env.py @@ -265,7 +265,7 @@ def step( ) -> VectorMJCEnvState: self._states = list( self._pool.map( - lambda env, ste, act: env.step(state=ste, action=act), + lambda env, ste, act: env.step(state=ste, action=act, *args, **kwargs), self._envs, self._states, action, @@ -277,7 +277,7 @@ def reset( self, rng: List[np.random.RandomState], *args, **kwargs ) -> VectorMJCEnvState: self._states = list( - self._pool.map(lambda env, sub_rng: env.reset(sub_rng), self._envs, rng) + self._pool.map(lambda env, sub_rng: env.reset(sub_rng, *args, **kwargs), self._envs, rng) ) return self._merged_states diff --git a/moojoco/environment/wrapper.py b/moojoco/environment/wrapper.py index e7cdd8f..defb170 100644 --- a/moojoco/environment/wrapper.py +++ b/moojoco/environment/wrapper.py @@ -53,12 +53,12 @@ def observation_space(self) -> SpaceType: def step( self, state: BaseEnvState, action: chex.Array, *args, **kwargs ) -> BaseEnvState: - return self._env.step(state=state, action=action) + return self._env.step(state=state, action=action, *args, **kwargs) def reset( self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs ) -> BaseEnvState: - return self._env.reset(rng=rng) + return self._env.reset(rng=rng, *args, **kwargs) def render(self, state: BaseEnvState) -> List[RenderFrame] | None: return self._env.render(state=state) @@ -83,14 +83,14 @@ def _transform_observations(self, state: BaseEnvState) -> BaseEnvState: def step( self, state: BaseEnvState, action: chex.Array, *args, **kwargs ) -> BaseEnvState: - state = self._env.step(state=state, action=action) + state = self._env.step(state=state, action=action, *args, **kwargs) state = self._transform_observations(state=state) return state def reset( self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs ) -> BaseEnvState: - state = self._env.reset(rng=rng) + state = self._env.reset(rng=rng, *args, **kwargs) state = self._transform_observations(state=state) return state @@ -114,5 +114,5 @@ def step( self, state: BaseEnvState, action: chex.Array, *args, **kwargs ) -> BaseEnvState: action, state = self._transform_action(action=action, state=state) - state = self._env.step(state=state, action=action) + state = self._env.step(state=state, action=action, *args, **kwargs) return state From e2b05db54680804e4f7aa1559e84df6d6c67ef4f Mon Sep 17 00:00:00 2001 From: dmarzoug Date: Mon, 25 Mar 2024 10:31:12 +0100 Subject: [PATCH 3/3] lint --- moojoco/environment/mjc_env.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/moojoco/environment/mjc_env.py b/moojoco/environment/mjc_env.py index a710ab0..7c2a889 100644 --- a/moojoco/environment/mjc_env.py +++ b/moojoco/environment/mjc_env.py @@ -277,7 +277,11 @@ def reset( self, rng: List[np.random.RandomState], *args, **kwargs ) -> VectorMJCEnvState: self._states = list( - self._pool.map(lambda env, sub_rng: env.reset(sub_rng, *args, **kwargs), self._envs, rng) + self._pool.map( + lambda env, sub_rng: env.reset(sub_rng, *args, **kwargs), + self._envs, + rng, + ) ) return self._merged_states