Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
zombie-einstein authored Nov 22, 2024
2 parents 162a74d + e959bf3 commit 4996869
Show file tree
Hide file tree
Showing 11 changed files with 60 additions and 49 deletions.
12 changes: 11 additions & 1 deletion .github/workflows/docs_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,26 @@ jobs:
docs:
runs-on: ubuntu-latest
container:
image: python:3.8.10
image: python:3.12.6
steps:
- name: Install dependencies for deploy
run: apt-get update && apt-get install -y rsync

- name: Checkout jumanji 🐍
uses: actions/checkout@v3

- name: Install python dependencies 🔧
run: pip install .[dev]

- name: Git trust site directory 🤝
run: |
git config --global --add safe.directory /__w/jumanji/jumanji
git config --global --add safe.directory /docs
git config --global --add safe.directory /docs_public
- name: Build docs 📖
run: mkdocs build --verbose --site-dir docs_public

- name: Deploy 🚀
uses: JamesIves/github-pages-deploy-action@v4
with:
Expand Down
7 changes: 4 additions & 3 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 30
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v1
uses: actions/setup-python@v5
with:
python-version: "3.x"
- name: Install dependencies
Expand All @@ -23,4 +24,4 @@ jobs:
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
run: |
hatch build
twine upload dist/*
twine upload dist/* --verbose
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ combinatorial problems.
- 🍬 **Wrappers**: easily connect to your favourite RL frameworks and libraries such as
[Acme](https://github.com/deepmind/acme),
[Stable Baselines3](https://github.com/DLR-RM/stable-baselines3),
[RLlib](https://docs.ray.io/en/latest/rllib/index.html), [OpenAI Gym](https://github.com/openai/gym)
[RLlib](https://docs.ray.io/en/latest/rllib/index.html), [Gymnasium](https://github.com/Farama-Foundation/Gymnasium)
and [DeepMind-Env](https://github.com/deepmind/dm_env) through our `dm_env` and `gym` wrappers.
- 🎓 **Examples**: guides to facilitate Jumanji's adoption and highlight the added value of
JAX-based environments.
Expand Down
10 changes: 5 additions & 5 deletions docs/guides/wrappers.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@ next_timestep = dm_env.step(action)
...
```

## Jumanji To Gym
We can also convert our Jumanji environments to a [Gym](https://github.com/openai/gym) environment!
Below is an example of how to convert a Jumanji environment into a Gym environment.
## Jumanji To Gymnasium
We can also convert our Jumanji environments to a [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) environment!
Below is an example of how to convert a Jumanji environment into a Gymnasium environment.

```python
import jumanji.wrappers

env = jumanji.make("Snake-6x6-v0")
gym_env = jumanji.wrappers.JumanjiToGymWrapper(env)

obs = gym_env.reset()
obs, info = gym_env.reset()
action = gym_env.action_space.sample()
observation, reward, done, extra = gym_env.step(action)
observation, reward, term, trunc, info = gym_env.step(action)
...
```

Expand Down
2 changes: 1 addition & 1 deletion jumanji/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

import chex
import dm_env.specs
import gym
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
Expand Down
2 changes: 1 addition & 1 deletion jumanji/specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import chex
import dm_env.specs
import gym.spaces
import gymnasium as gym
import jax.numpy as jnp
import numpy as np
import pytest
Expand Down
2 changes: 1 addition & 1 deletion jumanji/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "1.0.1"
__version__ = "1.1.0"
41 changes: 18 additions & 23 deletions jumanji/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from __future__ import annotations

from functools import cached_property
from typing import Any, Callable, ClassVar, Dict, Generic, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeAlias, Union

import chex
import dm_env.specs
import gym
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
Expand All @@ -28,7 +28,7 @@
from jumanji.types import TimeStep

# Type alias that corresponds to ObsType in the Gym API
GymObservation = Any
GymObservation: TypeAlias = chex.ArrayNumpy | Dict[str, Union[chex.ArrayNumpy, "GymObservation"]]


class Wrapper(Environment[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]):
Expand Down Expand Up @@ -584,10 +584,6 @@ def render(self, state: State) -> Any:
class JumanjiToGymWrapper(gym.Env, Generic[State, ActionSpec, Observation]):
"""A wrapper that converts a Jumanji `Environment` to one that follows the `gym.Env` API."""

# Flag that prevents `gym.register` from misinterpreting the `_step` and
# `_reset` as signs of a deprecated gym Env API.
_gym_disable_underscore_compat: ClassVar[bool] = True

def __init__(
self,
env: Environment[State, ActionSpec, Observation],
Expand Down Expand Up @@ -618,21 +614,21 @@ def reset(key: chex.PRNGKey) -> Tuple[State, Observation, Optional[Dict]]:

def step(
state: State, action: chex.Array
) -> Tuple[State, Observation, chex.Array, bool, Optional[Any]]:
) -> Tuple[State, Observation, chex.Array, chex.Array, chex.Array, Optional[Any]]:
"""Step function of a Jumanji environment to be jitted."""
state, timestep = self._env.step(state, action)
done = jnp.bool_(timestep.last())
return state, timestep.observation, timestep.reward, done, timestep.extras
term = timestep.discount.astype(bool)
trunc = timestep.last().astype(bool)
return state, timestep.observation, timestep.reward, term, trunc, timestep.extras

self._step = jax.jit(step, backend=self.backend)

def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
) -> Union[GymObservation, Tuple[GymObservation, Optional[Any]]]:
) -> Tuple[GymObservation, Dict[str, Any]]:
"""Resets the environment to an initial state by starting a new sequence
and returns the first `Observation` of this sequence.
Expand All @@ -648,13 +644,11 @@ def reset(
# Convert the observation to a numpy array or a nested dict thereof
obs = jumanji_to_gym_obs(obs)

if return_info:
info = jax.tree_util.tree_map(np.asarray, extras)
return obs, info
else:
return obs # type: ignore
return obs, jax.device_get(extras)

def step(self, action: chex.ArrayNumpy) -> Tuple[GymObservation, float, bool, Optional[Any]]:
def step(
self, action: chex.ArrayNumpy
) -> Tuple[GymObservation, float, bool, bool, Dict[str, Any]]:
"""Updates the environment according to the action and returns an `Observation`.
Args:
Expand All @@ -667,16 +661,17 @@ def step(self, action: chex.ArrayNumpy) -> Tuple[GymObservation, float, bool, Op
info: contains supplementary information such as metrics.
"""

action = jnp.array(action) # Convert input numpy array to JAX array
self._state, obs, reward, done, extras = self._step(self._state, action)
action_jax = jnp.asarray(action) # Convert input numpy array to JAX array
self._state, obs, reward, term, trunc, extras = self._step(self._state, action_jax)

# Convert to get the correct signature
obs = jumanji_to_gym_obs(obs)
reward = float(reward)
terminated = bool(done)
info = jax.tree_util.tree_map(np.asarray, extras)
terminated = bool(term)
truncated = bool(trunc)
info = jax.device_get(extras)

return obs, reward, terminated, info
return obs, reward, terminated, truncated, info

def seed(self, seed: int = 0) -> None:
"""Function which sets the seed for the environment's random number generator(s).
Expand Down
15 changes: 10 additions & 5 deletions jumanji/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import chex
import dm_env.specs
import gym
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -265,15 +265,18 @@ def test_jumanji_environment_to_gym_env__reset(
self, fake_gym_env: FakeJumanjiToGymWrapper
) -> None:
"""Validates reset function of the wrapped environment."""
observation1 = fake_gym_env.reset()
observation1, info1 = fake_gym_env.reset()
state1 = fake_gym_env._state
observation2 = fake_gym_env.reset()
observation2, info2 = fake_gym_env.reset()
state2 = fake_gym_env._state

# Observation is typically numpy array
assert isinstance(observation1, chex.ArrayNumpy)
assert isinstance(observation2, chex.ArrayNumpy)

assert isinstance(info1, dict)
assert isinstance(info2, dict)

# Check that the observations are equal
chex.assert_trees_all_equal(observation1, observation2)
assert_trees_are_different(state1, state2)
Expand All @@ -282,12 +285,14 @@ def test_jumanji_environment_to_gym_env__step(
self, fake_gym_env: FakeJumanjiToGymWrapper
) -> None:
"""Validates step function of the wrapped environment."""
observation = fake_gym_env.reset()
observation, _ = fake_gym_env.reset()
action = fake_gym_env.action_space.sample()
next_observation, reward, terminated, info = fake_gym_env.step(action)
next_observation, reward, terminated, truncated, info = fake_gym_env.step(action)
assert_trees_are_different(observation, next_observation)
assert isinstance(reward, float)
assert isinstance(terminated, bool)
assert isinstance(truncated, bool)
assert isinstance(info, dict)

def test_jumanji_environment_to_gym_env__observation_space(
self, fake_gym_env: FakeJumanjiToGymWrapper
Expand Down
14 changes: 7 additions & 7 deletions requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
coverage
livereload
mkdocs==1.2.3
mkdocs==1.6.1
mkdocs-git-revision-date-plugin==0.3.2
mkdocs-include-markdown-plugin==4.0.4
mkdocs-material==8.2.7
mkdocs-mermaid2-plugin==0.6.0
mkdocs_autorefs<1.0
mkdocstrings==0.18.0
mknotebooks==0.7.1
mkdocs-include-markdown-plugin==7.1.1
mkdocs-material==9.5.45
mkdocs-mermaid2-plugin==1.1.0
mkdocs_autorefs==1.2.0
mkdocstrings[python]==0.27.0
mknotebooks==0.8.0
mypy
pre-commit
promise
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
chex>=0.1.3
dm-env>=1.5
esquilax>=1.0.3
gym>=0.22.0
gymnasium>=1.0
huggingface-hub
jax>=0.2.26
matplotlib~=3.7.4
Expand Down

0 comments on commit 4996869

Please sign in to comment.