Skip to content

Commit

Permalink
Merge pull request #9 from Co-Evolve/develop
Browse files Browse the repository at this point in the history
v1.0.3
  • Loading branch information
driesmarzougui authored Mar 25, 2024
2 parents 73a5c53 + 3d15ca7 commit e1da263
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 14 deletions.
18 changes: 13 additions & 5 deletions moojoco/environment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from mujoco import mjx

import moojoco.environment.mjx_spaces as mjx_spaces
from moojoco.mjcf.arena import MJCFArena
from moojoco.environment.renderer import MujocoRenderer
from moojoco.mjcf.arena import MJCFArena
from moojoco.mjcf.morphology import MJCFMorphology


Expand Down Expand Up @@ -123,11 +123,15 @@ def observation_space(self) -> SpaceType:
raise NotImplementedError

@abc.abstractmethod
def step(self, state: BaseEnvState, action: chex.Array) -> BaseEnvState:
def step(
self, state: BaseEnvState, action: chex.Array, *args, **kwargs
) -> BaseEnvState:
raise NotImplementedError

@abc.abstractmethod
def reset(self, rng: np.random.RandomState | chex.PRNGKey) -> BaseEnvState:
def reset(
self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs
) -> BaseEnvState:
raise NotImplementedError

@abc.abstractmethod
Expand Down Expand Up @@ -225,7 +229,9 @@ def _initialize_mj_model_and_data(self) -> Tuple[mujoco.MjModel, mujoco.MjData]:
mj_data = mujoco.MjData(mj_model)
return mj_model, mj_data

def step(self, state: BaseEnvState, action: chex.Array) -> BaseEnvState:
def step(
self, state: BaseEnvState, action: chex.Array, *args, **kwargs
) -> BaseEnvState:
previous_state = state
state = self._update_simulation(state=state, ctrl=action)

Expand Down Expand Up @@ -345,7 +351,9 @@ def _finish_reset(
raise NotImplementedError

@abc.abstractmethod
def reset(self, rng: np.random.RandomState | chex.PRNGKey) -> BaseEnvState:
def reset(
self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs
) -> BaseEnvState:
raise NotImplementedError

@abc.abstractmethod
Expand Down
8 changes: 6 additions & 2 deletions moojoco/environment/dual.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,14 @@ def actuators(self) -> List[str]:
def observation_space(self) -> SpaceType:
return self._env.observation_space

def step(self, state: BaseEnvState, action: chex.Array) -> BaseEnvState:
def step(
self, state: BaseEnvState, action: chex.Array, *args, **kwargs
) -> BaseEnvState:
return self._env.step(state=state, action=action)

def reset(self, rng: np.random.RandomState | chex.PRNGKey) -> BaseEnvState:
def reset(
self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs
) -> BaseEnvState:
return self._env.reset(rng=rng)

def render(self, state: BaseEnvState) -> List[RenderFrame] | None:
Expand Down
10 changes: 7 additions & 3 deletions moojoco/environment/mjc_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _create_observables(self) -> List[MJCObservable]:
raise NotImplementedError

@abc.abstractmethod
def reset(self, rng: np.random.RandomState) -> MJCEnvState:
def reset(self, rng: np.random.RandomState, *args, **kwargs) -> MJCEnvState:
raise NotImplementedError

@abc.abstractmethod
Expand Down Expand Up @@ -260,7 +260,9 @@ def actuators(self) -> List[str]:
def observation_space(self) -> gymnasium.spaces.Space:
return self._observation_space

def step(self, state: VectorMJCEnvState, action: np.ndarray) -> VectorMJCEnvState:
def step(
self, state: VectorMJCEnvState, action: np.ndarray, *args, **kwargs
) -> VectorMJCEnvState:
self._states = list(
self._pool.map(
lambda env, ste, act: env.step(state=ste, action=act),
Expand All @@ -271,7 +273,9 @@ def step(self, state: VectorMJCEnvState, action: np.ndarray) -> VectorMJCEnvStat
)
return self._merged_states

def reset(self, rng: List[np.random.RandomState]) -> VectorMJCEnvState:
def reset(
self, rng: List[np.random.RandomState], *args, **kwargs
) -> VectorMJCEnvState:
self._states = list(
self._pool.map(lambda env, sub_rng: env.reset(sub_rng), self._envs, rng)
)
Expand Down
2 changes: 1 addition & 1 deletion moojoco/environment/mjx_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def _create_observables(self) -> List[MJXObservable]:
raise NotImplementedError

@abc.abstractmethod
def reset(self, rng: jnp.ndarray) -> MJXEnvState:
def reset(self, rng: jnp.ndarray, *args, **kwargs) -> MJXEnvState:
raise NotImplementedError

@abc.abstractmethod
Expand Down
118 changes: 118 additions & 0 deletions moojoco/environment/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import abc
from typing import List, Tuple

import chex
import numpy as np
from gymnasium.core import RenderFrame

from moojoco.environment.base import BaseEnvState, BaseEnvironment, SpaceType


class PostInitCaller(type):
def __call__(cls, *args, **kwargs):
obj = type.__call__(cls, *args, **kwargs)
obj.__post__init__()
return obj


class CombinedMeta(abc.ABCMeta, PostInitCaller):
pass


class EnvironmentWrapper(BaseEnvironment, metaclass=CombinedMeta):
def __init__(self, env: BaseEnvironment) -> None:
super().__init__(configuration=env.environment_configuration)
self._env = env

self._observation_space: SpaceType | None = None
self._action_space: SpaceType | None = None

def __post__init__(self) -> None:
# Make sure observation space and action space are initialised after creation
# noinspection PyStatementEffect
self.observation_space
# noinspection PyStatementEffect
self.action_space

@property
def action_space(self) -> SpaceType:
if self._action_space is not None:
return self._action_space
return self._env.action_space

@property
def actuators(self) -> List[str]:
return self._env.actuators

@property
def observation_space(self) -> SpaceType:
if self._observation_space is not None:
return self._observation_space
return self._env.observation_space

def step(
self, state: BaseEnvState, action: chex.Array, *args, **kwargs
) -> BaseEnvState:
return self._env.step(state=state, action=action)

def reset(
self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs
) -> BaseEnvState:
return self._env.reset(rng=rng)

def render(self, state: BaseEnvState) -> List[RenderFrame] | None:
return self._env.render(state=state)

def close(self) -> None:
return self._env.close()


class TransformObservationEnvWrapper(EnvironmentWrapper, abc.ABC):
def __init__(self, env: BaseEnvironment) -> None:
super().__init__(env=env)

@property
@abc.abstractmethod
def observation_space(self) -> SpaceType:
raise NotImplementedError

@abc.abstractmethod
def _transform_observations(self, state: BaseEnvState) -> BaseEnvState:
raise NotImplementedError

def step(
self, state: BaseEnvState, action: chex.Array, *args, **kwargs
) -> BaseEnvState:
state = self._env.step(state=state, action=action)
state = self._transform_observations(state=state)
return state

def reset(
self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs
) -> BaseEnvState:
state = self._env.reset(rng=rng)
state = self._transform_observations(state=state)
return state


class TransformActionEnvWrapper(EnvironmentWrapper, abc.ABC):
def __init__(self, env: BaseEnvironment) -> None:
super().__init__(env=env)

@property
@abc.abstractmethod
def action_space(self) -> SpaceType:
raise NotImplementedError

@abc.abstractmethod
def _transform_action(
self, action: chex.Array, state: BaseEnvState
) -> Tuple[chex.Array, BaseEnvState]:
raise NotImplementedError

def step(
self, state: BaseEnvState, action: chex.Array, *args, **kwargs
) -> BaseEnvState:
action, state = self._transform_action(action=action, state=state)
state = self._env.step(state=state, action=action)
return state
3 changes: 2 additions & 1 deletion moojoco/mjcf/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
from dm_control import mjcf
from dm_control.mjcf import export_with_assets
from dm_control.mjcf.element import _AttachmentFrame
from scipy.spatial.transform import Rotation


Expand Down Expand Up @@ -35,7 +36,7 @@ def attach(
position: Optional[np.ndarray] = None,
euler: Optional[np.ndarray] = None,
free_joint: bool = False,
) -> None:
) -> _AttachmentFrame:
attachment_site = self.mjcf_body.add(
"site",
name=f"{self.base_name}_attachment_{other.base_name}",
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "moojoco"
version = "1.0.2"
version = "1.0.3"
authors = [
{ name = "Dries Marzougui", email = "[email protected]" },
]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

setup(
name='moojoco',
version='1.0.2',
version='1.0.3',
description='A unified framework for implementing and interfacing with MuJoCo and MuJoCo-XLA simulation '
'environments.',
long_description=readme,
Expand Down

0 comments on commit e1da263

Please sign in to comment.