-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9 from Co-Evolve/develop
v1.0.3
- Loading branch information
Showing
8 changed files
with
149 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import abc | ||
from typing import List, Tuple | ||
|
||
import chex | ||
import numpy as np | ||
from gymnasium.core import RenderFrame | ||
|
||
from moojoco.environment.base import BaseEnvState, BaseEnvironment, SpaceType | ||
|
||
|
||
class PostInitCaller(type): | ||
def __call__(cls, *args, **kwargs): | ||
obj = type.__call__(cls, *args, **kwargs) | ||
obj.__post__init__() | ||
return obj | ||
|
||
|
||
class CombinedMeta(abc.ABCMeta, PostInitCaller): | ||
pass | ||
|
||
|
||
class EnvironmentWrapper(BaseEnvironment, metaclass=CombinedMeta): | ||
def __init__(self, env: BaseEnvironment) -> None: | ||
super().__init__(configuration=env.environment_configuration) | ||
self._env = env | ||
|
||
self._observation_space: SpaceType | None = None | ||
self._action_space: SpaceType | None = None | ||
|
||
def __post__init__(self) -> None: | ||
# Make sure observation space and action space are initialised after creation | ||
# noinspection PyStatementEffect | ||
self.observation_space | ||
# noinspection PyStatementEffect | ||
self.action_space | ||
|
||
@property | ||
def action_space(self) -> SpaceType: | ||
if self._action_space is not None: | ||
return self._action_space | ||
return self._env.action_space | ||
|
||
@property | ||
def actuators(self) -> List[str]: | ||
return self._env.actuators | ||
|
||
@property | ||
def observation_space(self) -> SpaceType: | ||
if self._observation_space is not None: | ||
return self._observation_space | ||
return self._env.observation_space | ||
|
||
def step( | ||
self, state: BaseEnvState, action: chex.Array, *args, **kwargs | ||
) -> BaseEnvState: | ||
return self._env.step(state=state, action=action) | ||
|
||
def reset( | ||
self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs | ||
) -> BaseEnvState: | ||
return self._env.reset(rng=rng) | ||
|
||
def render(self, state: BaseEnvState) -> List[RenderFrame] | None: | ||
return self._env.render(state=state) | ||
|
||
def close(self) -> None: | ||
return self._env.close() | ||
|
||
|
||
class TransformObservationEnvWrapper(EnvironmentWrapper, abc.ABC): | ||
def __init__(self, env: BaseEnvironment) -> None: | ||
super().__init__(env=env) | ||
|
||
@property | ||
@abc.abstractmethod | ||
def observation_space(self) -> SpaceType: | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
def _transform_observations(self, state: BaseEnvState) -> BaseEnvState: | ||
raise NotImplementedError | ||
|
||
def step( | ||
self, state: BaseEnvState, action: chex.Array, *args, **kwargs | ||
) -> BaseEnvState: | ||
state = self._env.step(state=state, action=action) | ||
state = self._transform_observations(state=state) | ||
return state | ||
|
||
def reset( | ||
self, rng: np.random.RandomState | chex.PRNGKey, *args, **kwargs | ||
) -> BaseEnvState: | ||
state = self._env.reset(rng=rng) | ||
state = self._transform_observations(state=state) | ||
return state | ||
|
||
|
||
class TransformActionEnvWrapper(EnvironmentWrapper, abc.ABC): | ||
def __init__(self, env: BaseEnvironment) -> None: | ||
super().__init__(env=env) | ||
|
||
@property | ||
@abc.abstractmethod | ||
def action_space(self) -> SpaceType: | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
def _transform_action( | ||
self, action: chex.Array, state: BaseEnvState | ||
) -> Tuple[chex.Array, BaseEnvState]: | ||
raise NotImplementedError | ||
|
||
def step( | ||
self, state: BaseEnvState, action: chex.Array, *args, **kwargs | ||
) -> BaseEnvState: | ||
action, state = self._transform_action(action=action, state=state) | ||
state = self._env.step(state=state, action=action) | ||
return state |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" | |
|
||
[project] | ||
name = "moojoco" | ||
version = "1.0.2" | ||
version = "1.0.3" | ||
authors = [ | ||
{ name = "Dries Marzougui", email = "[email protected]" }, | ||
] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters