Skip to content

Commit

Permalink
formatted with black
Browse files Browse the repository at this point in the history
  • Loading branch information
driesmarzougui committed Feb 9, 2024
1 parent 8cf79b3 commit 2a582ed
Show file tree
Hide file tree
Showing 9 changed files with 694 additions and 988 deletions.
306 changes: 114 additions & 192 deletions mujoco_utils/environment/base.py

Large diffs are not rendered by default.

66 changes: 26 additions & 40 deletions mujoco_utils/environment/dual.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from gymnasium.core import RenderFrame

from mujoco_utils.mjcf.arena import MJCFArena
from mujoco_utils.environment.base import BaseEnvState, BaseEnvironment, MuJoCoEnvironmentConfiguration, SpaceType
from mujoco_utils.environment.base import (
BaseEnvState,
BaseEnvironment,
MuJoCoEnvironmentConfiguration,
SpaceType,
)
from mujoco_utils.environment.mjc_env import MJCEnv
from mujoco_utils.environment.mjx_env import MJXEnv
from mujoco_utils.mjcf.morphology import MJCFMorphology
Expand All @@ -17,71 +22,52 @@ class DualMuJoCoEnvironment(BaseEnvironment):
MJC_ENV_CLASS: type[MJCEnv]
MJX_ENV_CLASS: type[MJXEnv]

def __init__(
self,
env: MJCEnv | MJXEnv,
backend: str
) -> None:
def __init__(self, env: MJCEnv | MJXEnv, backend: str) -> None:
super().__init__(configuration=env.environment_configuration)
self.backend = backend
self._env = env

@classmethod
def from_morphology_and_arena(
cls,
morphology: MJCFMorphology,
arena: MJCFArena,
configuration: MuJoCoEnvironmentConfiguration,
backend: str
) -> DualMuJoCoEnvironment:
assert backend in ["MJC", "MJX"], f"Backend must either be 'MJC' or 'MJX'. {backend} was given."
cls,
morphology: MJCFMorphology,
arena: MJCFArena,
configuration: MuJoCoEnvironmentConfiguration,
backend: str,
) -> DualMuJoCoEnvironment:
assert backend in [
"MJC",
"MJX",
], f"Backend must either be 'MJC' or 'MJX'. {backend} was given."
if backend == "MJC":
env_class = cls.MJC_ENV_CLASS
else:
env_class = cls.MJX_ENV_CLASS
env = env_class.from_morphology_and_arena(
morphology=morphology, arena=arena, configuration=configuration
)
morphology=morphology, arena=arena, configuration=configuration
)
return cls(env=env, backend=backend)

@property
def action_space(
self
) -> SpaceType:
def action_space(self) -> SpaceType:
return self._env.action_space

@property
def actuators(
self
) -> List[str]:
def actuators(self) -> List[str]:
return self._env.actuators

@property
def observation_space(
self
) -> SpaceType:
def observation_space(self) -> SpaceType:
return self._env.observation_space

def step(
self,
state: BaseEnvState,
action: chex.Array
) -> BaseEnvState:
def step(self, state: BaseEnvState, action: chex.Array) -> BaseEnvState:
return self._env.step(state=state, action=action)

def reset(
self,
rng: np.random.RandomState | chex.PRNGKey
) -> BaseEnvState:
def reset(self, rng: np.random.RandomState | chex.PRNGKey) -> BaseEnvState:
return self._env.reset(rng=rng)

def render(
self,
state: BaseEnvState
) -> List[RenderFrame] | None:
def render(self, state: BaseEnvState) -> List[RenderFrame] | None:
return self._env.render(state=state)

def close(
self
) -> None:
def close(self) -> None:
return self._env.close()
Loading

0 comments on commit 2a582ed

Please sign in to comment.