Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix imports jumpy and formatting issues #163

Merged
merged 11 commits into from
Dec 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,14 @@ jobs:

steps:
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v3
if: ${{ github.event_name == 'push' }}

- name: Checkout
uses: actions/checkout@v3
if: ${{ github.event_name == 'pull_request_target' }}
with:
ref: ${{ github.event.pull_request.head.sha }}

- name: Set up Docker Buildx
id: buildx
Expand Down Expand Up @@ -99,7 +106,14 @@ jobs:

steps:
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v3
if: ${{ github.event_name == 'push' }}

- name: Checkout
uses: actions/checkout@v3
if: ${{ github.event_name == 'pull_request_target' }}
with:
ref: ${{ github.event.pull_request.head.sha }}

- name: Run pre-commits
run: |
Expand All @@ -117,7 +131,14 @@ jobs:

steps:
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v3
if: ${{ github.event_name == 'push' }}

- name: Checkout
uses: actions/checkout@v3
if: ${{ github.event_name == 'pull_request_target' }}
with:
ref: ${{ github.event.pull_request.head.sha }}

- name: Run pytests
run: |
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
<img src="docs/img/qdax_logo.png" alt="qdax_logo" width="140"></img>
</div>


# QDax: Accelerated Quality-Diversity

[![Documentation Status](https://readthedocs.org/projects/qdax/badge/?version=latest)](https://qdax.readthedocs.io/en/latest/?badge=latest)
Expand Down
2 changes: 1 addition & 1 deletion examples/jumanji_snake.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
"outputs": [],
"source": [
"# Instantiate a Jumanji environment using the registry\n",
"env = jumanji.make('Snake-6x6-v0')\n",
"env = jumanji.make('Snake-v1')\n",
"\n",
"# Reset your (jit-able) environment\n",
"key = jax.random.PRNGKey(0)\n",
Expand Down
10 changes: 7 additions & 3 deletions qdax/environments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import functools
from typing import Any, Callable, List, Optional, Union

from brax.v1.envs import Env
from brax.v1.envs import _envs
from brax.v1.envs.wrappers import EpisodeWrapper, AutoResetWrapper, EvalWrapper, VectorWrapper
from brax.v1.envs import Env, _envs
from brax.v1.envs.wrappers import (
AutoResetWrapper,
EpisodeWrapper,
EvalWrapper,
VectorWrapper,
)

from qdax.environments.base_wrappers import QDEnv, StateDescriptorResetWrapper
from qdax.environments.bd_extractors import (
Expand Down
6 changes: 2 additions & 4 deletions qdax/environments/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Dict

from brax.v1.envs import State, Wrapper
import flax.struct
import jax
from brax.v1 import jumpy as jp
from brax.v1.envs import State, Wrapper


class CompletedEvalMetrics(flax.struct.PyTreeNode):
Expand Down Expand Up @@ -34,9 +34,7 @@ def reset(self, rng: jp.ndarray) -> State:
reset_state.info[self.STATE_INFO_KEY] = eval_metrics
return reset_state

def step(
self, state: State, action: jp.ndarray
) -> State:
def step(self, state: State, action: jp.ndarray) -> State:
state_metrics = state.info[self.STATE_INFO_KEY]
if not isinstance(state_metrics, CompletedEvalMetrics):
raise ValueError(f"Incorrect type for state_metrics: {type(state_metrics)}")
Expand Down
12 changes: 8 additions & 4 deletions tests/default_tasks_test/jumanji_envs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax
import jax.numpy as jnp
import jumanji
import jumanji.environments.routing.snake
import numpy as np
import pytest

Expand All @@ -26,7 +27,7 @@ def test_jumanji_utils() -> None:
batch_size = population_size

# Instantiate a Jumanji environment using the registry
env = jumanji.make("Snake-6x6-v0")
env = jumanji.make("Snake-v1")

# Reset your (jit-able) environment
key = jax.random.PRNGKey(0)
Expand All @@ -49,8 +50,10 @@ def test_jumanji_utils() -> None:
final_activation=jax.nn.softmax,
)

def observation_processing(observation: jumanji.types.Observation) -> Observation:
network_input = jnp.ravel(observation)
def observation_processing(
observation: jumanji.environments.routing.snake.types.Observation,
) -> Observation:
network_input = jnp.ravel(observation.grid)
return network_input

play_step_fn = make_policy_network_play_step_fn_jumanji(
Expand All @@ -64,7 +67,7 @@ def observation_processing(observation: jumanji.types.Observation) -> Observatio
keys = jax.random.split(subkey, num=batch_size)

# compute observation size from observation spec
observation_size = np.prod(np.array(env.observation_spec().shape))
observation_size = np.prod(np.array(env.observation_spec().grid.shape))

fake_batch = jnp.zeros(shape=(batch_size, observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)
Expand Down Expand Up @@ -136,4 +139,5 @@ def bd_extraction(


if __name__ == "__main__":
pytest.assume
test_jumanji_utils()
9 changes: 4 additions & 5 deletions tests/environments_test/pointmaze_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Any, Tuple

import brax
import brax.envs
import jax
import pytest
from brax import jumpy as jp
from brax.v1 import jumpy as jp
from brax.v1.envs import Env

import qdax
from qdax.environments.pointmaze import PointMaze
Expand All @@ -15,7 +14,7 @@ def test_pointmaze() -> None:
# create env with class
qd_env = PointMaze()
# verify class
pytest.assume(isinstance(qd_env, brax.envs.Env))
pytest.assume(isinstance(qd_env, Env))

# check state_descriptor_length
pytest.assume(qd_env.state_descriptor_length == 2)
Expand All @@ -25,7 +24,7 @@ def test_pointmaze() -> None:
qd_env = qdax.environments.create(env_name="pointmaze") # type: ignore

# verify class
pytest.assume(isinstance(qd_env, brax.envs.Env))
pytest.assume(isinstance(qd_env, Env))

# check state_descriptor_length
pytest.assume(qd_env.state_descriptor_length == 2)
Expand Down
8 changes: 4 additions & 4 deletions tests/environments_test/wrapper_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Dict, List, Union

import brax.envs
import brax
import jax
import jax.numpy as jnp
import pytest
from brax import jumpy as jp
from brax.physics.base import vec_to_arr
from brax.physics.config_pb2 import Joint
from brax.v1 import jumpy as jp
from brax.v1.physics.base import vec_to_arr
from brax.v1.physics.config_pb2 import Joint

from qdax import environments

Expand Down
Loading