diff --git a/moojoco/environment/base.py b/moojoco/environment/base.py index a01fac7..ac23ceb 100644 --- a/moojoco/environment/base.py +++ b/moojoco/environment/base.py @@ -12,8 +12,8 @@ from mujoco import mjx import moojoco.environment.mjx_spaces as mjx_spaces -from moojoco.mjcf.arena import MJCFArena from moojoco.environment.renderer import MujocoRenderer +from moojoco.mjcf.arena import MJCFArena from moojoco.mjcf.morphology import MJCFMorphology @@ -123,11 +123,15 @@ def observation_space(self) -> SpaceType: raise NotImplementedError @abc.abstractmethod - def step(self, state: BaseEnvState, action: chex.Array) -> BaseEnvState: + def step( + self, state: BaseEnvState, action: chex.Array, *args, **kwargs + ) -> BaseEnvState: raise NotImplementedError @abc.abstractmethod - def reset(self, rng: np.random.RandomState | chex.PRNGKey) -> BaseEnvState: + def reset( + self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs + ) -> BaseEnvState: raise NotImplementedError @abc.abstractmethod @@ -225,7 +229,9 @@ def _initialize_mj_model_and_data(self) -> Tuple[mujoco.MjModel, mujoco.MjData]: mj_data = mujoco.MjData(mj_model) return mj_model, mj_data - def step(self, state: BaseEnvState, action: chex.Array) -> BaseEnvState: + def step( + self, state: BaseEnvState, action: chex.Array, *args, **kwargs + ) -> BaseEnvState: previous_state = state state = self._update_simulation(state=state, ctrl=action) @@ -345,7 +351,9 @@ def _finish_reset( raise NotImplementedError @abc.abstractmethod - def reset(self, rng: np.random.RandomState | chex.PRNGKey) -> BaseEnvState: + def reset( + self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs + ) -> BaseEnvState: raise NotImplementedError @abc.abstractmethod diff --git a/moojoco/environment/dual.py b/moojoco/environment/dual.py index f0e792e..e215d36 100644 --- a/moojoco/environment/dual.py +++ b/moojoco/environment/dual.py @@ -60,10 +60,14 @@ def actuators(self) -> List[str]: def observation_space(self) -> SpaceType: return self._env.observation_space - def step(self, state: BaseEnvState, action: chex.Array) -> BaseEnvState: + def step( + self, state: BaseEnvState, action: chex.Array, *args, **kwargs + ) -> BaseEnvState: return self._env.step(state=state, action=action) - def reset(self, rng: np.random.RandomState | chex.PRNGKey) -> BaseEnvState: + def reset( + self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs + ) -> BaseEnvState: return self._env.reset(rng=rng) def render(self, state: BaseEnvState) -> List[RenderFrame] | None: diff --git a/moojoco/environment/mjc_env.py b/moojoco/environment/mjc_env.py index 2f68429..ae29058 100644 --- a/moojoco/environment/mjc_env.py +++ b/moojoco/environment/mjc_env.py @@ -144,7 +144,7 @@ def _create_observables(self) -> List[MJCObservable]: raise NotImplementedError @abc.abstractmethod - def reset(self, rng: np.random.RandomState) -> MJCEnvState: + def reset(self, rng: np.random.RandomState, *args, **kwargs) -> MJCEnvState: raise NotImplementedError @abc.abstractmethod @@ -260,7 +260,9 @@ def actuators(self) -> List[str]: def observation_space(self) -> gymnasium.spaces.Space: return self._observation_space - def step(self, state: VectorMJCEnvState, action: np.ndarray) -> VectorMJCEnvState: + def step( + self, state: VectorMJCEnvState, action: np.ndarray, *args, **kwargs + ) -> VectorMJCEnvState: self._states = list( self._pool.map( lambda env, ste, act: env.step(state=ste, action=act), @@ -271,7 +273,9 @@ def step(self, state: VectorMJCEnvState, action: np.ndarray) -> VectorMJCEnvStat ) return self._merged_states - def reset(self, rng: List[np.random.RandomState]) -> VectorMJCEnvState: + def reset( + self, rng: List[np.random.RandomState], *args, **kwargs + ) -> VectorMJCEnvState: self._states = list( self._pool.map(lambda env, sub_rng: env.reset(sub_rng), self._envs, rng) ) diff --git a/moojoco/environment/mjx_env.py b/moojoco/environment/mjx_env.py index c463b26..513690a 100644 --- a/moojoco/environment/mjx_env.py +++ b/moojoco/environment/mjx_env.py @@ -228,7 +228,7 @@ def _create_observables(self) -> List[MJXObservable]: raise NotImplementedError @abc.abstractmethod - def reset(self, rng: jnp.ndarray) -> MJXEnvState: + def reset(self, rng: jnp.ndarray, *args, **kwargs) -> MJXEnvState: raise NotImplementedError @abc.abstractmethod diff --git a/moojoco/environment/wrapper.py b/moojoco/environment/wrapper.py new file mode 100644 index 0000000..e7cdd8f --- /dev/null +++ b/moojoco/environment/wrapper.py @@ -0,0 +1,118 @@ +import abc +from typing import List, Tuple + +import chex +import numpy as np +from gymnasium.core import RenderFrame + +from moojoco.environment.base import BaseEnvState, BaseEnvironment, SpaceType + + +class PostInitCaller(type): + def __call__(cls, *args, **kwargs): + obj = type.__call__(cls, *args, **kwargs) + obj.__post__init__() + return obj + + +class CombinedMeta(abc.ABCMeta, PostInitCaller): + pass + + +class EnvironmentWrapper(BaseEnvironment, metaclass=CombinedMeta): + def __init__(self, env: BaseEnvironment) -> None: + super().__init__(configuration=env.environment_configuration) + self._env = env + + self._observation_space: SpaceType | None = None + self._action_space: SpaceType | None = None + + def __post__init__(self) -> None: + # Make sure observation space and action space are initialised after creation + # noinspection PyStatementEffect + self.observation_space + # noinspection PyStatementEffect + self.action_space + + @property + def action_space(self) -> SpaceType: + if self._action_space is not None: + return self._action_space + return self._env.action_space + + @property + def actuators(self) -> List[str]: + return self._env.actuators + + @property + def observation_space(self) -> SpaceType: + if self._observation_space is not None: + return self._observation_space + return self._env.observation_space + + def step( + self, state: BaseEnvState, action: chex.Array, *args, **kwargs + ) -> BaseEnvState: + return self._env.step(state=state, action=action) + + def reset( + self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs + ) -> BaseEnvState: + return self._env.reset(rng=rng) + + def render(self, state: BaseEnvState) -> List[RenderFrame] | None: + return self._env.render(state=state) + + def close(self) -> None: + return self._env.close() + + +class TransformObservationEnvWrapper(EnvironmentWrapper, abc.ABC): + def __init__(self, env: BaseEnvironment) -> None: + super().__init__(env=env) + + @property + @abc.abstractmethod + def observation_space(self) -> SpaceType: + raise NotImplementedError + + @abc.abstractmethod + def _transform_observations(self, state: BaseEnvState) -> BaseEnvState: + raise NotImplementedError + + def step( + self, state: BaseEnvState, action: chex.Array, *args, **kwargs + ) -> BaseEnvState: + state = self._env.step(state=state, action=action) + state = self._transform_observations(state=state) + return state + + def reset( + self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs + ) -> BaseEnvState: + state = self._env.reset(rng=rng) + state = self._transform_observations(state=state) + return state + + +class TransformActionEnvWrapper(EnvironmentWrapper, abc.ABC): + def __init__(self, env: BaseEnvironment) -> None: + super().__init__(env=env) + + @property + @abc.abstractmethod + def action_space(self) -> SpaceType: + raise NotImplementedError + + @abc.abstractmethod + def _transform_action( + self, action: chex.Array, state: BaseEnvState + ) -> Tuple[chex.Array, BaseEnvState]: + raise NotImplementedError + + def step( + self, state: BaseEnvState, action: chex.Array, *args, **kwargs + ) -> BaseEnvState: + action, state = self._transform_action(action=action, state=state) + state = self._env.step(state=state, action=action) + return state diff --git a/moojoco/mjcf/component.py b/moojoco/mjcf/component.py index 37bcc20..e016322 100644 --- a/moojoco/mjcf/component.py +++ b/moojoco/mjcf/component.py @@ -7,6 +7,7 @@ import numpy as np from dm_control import mjcf from dm_control.mjcf import export_with_assets +from dm_control.mjcf.element import _AttachmentFrame from scipy.spatial.transform import Rotation @@ -35,7 +36,7 @@ def attach( position: Optional[np.ndarray] = None, euler: Optional[np.ndarray] = None, free_joint: bool = False, - ) -> None: + ) -> _AttachmentFrame: attachment_site = self.mjcf_body.add( "site", name=f"{self.base_name}_attachment_{other.base_name}", diff --git a/pyproject.toml b/pyproject.toml index 18abe7f..b9f6ef4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "moojoco" -version = "1.0.2" +version = "1.0.3" authors = [ { name = "Dries Marzougui", email = "dries.marzougui@gmail.com" }, ] diff --git a/setup.py b/setup.py index 9fc791d..f5049c0 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setup( name='moojoco', - version='1.0.2', + version='1.0.3', description='A unified framework for implementing and interfacing with MuJoCo and MuJoCo-XLA simulation ' 'environments.', long_description=readme,