Skip to content

Commit

Permalink
Fix imports, CI and formatting issues (#163)
Browse files Browse the repository at this point in the history
* fix imports jumpy
* fix formatting issues
* update ci to handle internal and external PRs
* adapt Jumanji test to new Jumanji API
  • Loading branch information
Lookatator authored Dec 10, 2023
1 parent a190a62 commit b221c68
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 21 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
<img src="docs/img/qdax_logo.png" alt="qdax_logo" width="140"></img>
</div>


# QDax: Accelerated Quality-Diversity

[![Documentation Status](https://readthedocs.org/projects/qdax/badge/?version=latest)](https://qdax.readthedocs.io/en/latest/?badge=latest)
Expand Down
2 changes: 1 addition & 1 deletion examples/jumanji_snake.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
"outputs": [],
"source": [
"# Instantiate a Jumanji environment using the registry\n",
"env = jumanji.make('Snake-6x6-v0')\n",
"env = jumanji.make('Snake-v1')\n",
"\n",
"# Reset your (jit-able) environment\n",
"key = jax.random.PRNGKey(0)\n",
Expand Down
10 changes: 7 additions & 3 deletions qdax/environments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import functools
from typing import Any, Callable, List, Optional, Union

from brax.v1.envs import Env
from brax.v1.envs import _envs
from brax.v1.envs.wrappers import EpisodeWrapper, AutoResetWrapper, EvalWrapper, VectorWrapper
from brax.v1.envs import Env, _envs
from brax.v1.envs.wrappers import (
AutoResetWrapper,
EpisodeWrapper,
EvalWrapper,
VectorWrapper,
)

from qdax.environments.base_wrappers import QDEnv, StateDescriptorResetWrapper
from qdax.environments.bd_extractors import (
Expand Down
6 changes: 2 additions & 4 deletions qdax/environments/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Dict

from brax.v1.envs import State, Wrapper
import flax.struct
import jax
from brax.v1 import jumpy as jp
from brax.v1.envs import State, Wrapper


class CompletedEvalMetrics(flax.struct.PyTreeNode):
Expand Down Expand Up @@ -34,9 +34,7 @@ def reset(self, rng: jp.ndarray) -> State:
reset_state.info[self.STATE_INFO_KEY] = eval_metrics
return reset_state

def step(
self, state: State, action: jp.ndarray
) -> State:
def step(self, state: State, action: jp.ndarray) -> State:
state_metrics = state.info[self.STATE_INFO_KEY]
if not isinstance(state_metrics, CompletedEvalMetrics):
raise ValueError(f"Incorrect type for state_metrics: {type(state_metrics)}")
Expand Down
12 changes: 8 additions & 4 deletions tests/default_tasks_test/jumanji_envs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax
import jax.numpy as jnp
import jumanji
import jumanji.environments.routing.snake
import numpy as np
import pytest

Expand All @@ -26,7 +27,7 @@ def test_jumanji_utils() -> None:
batch_size = population_size

# Instantiate a Jumanji environment using the registry
env = jumanji.make("Snake-6x6-v0")
env = jumanji.make("Snake-v1")

# Reset your (jit-able) environment
key = jax.random.PRNGKey(0)
Expand All @@ -49,8 +50,10 @@ def test_jumanji_utils() -> None:
final_activation=jax.nn.softmax,
)

def observation_processing(observation: jumanji.types.Observation) -> Observation:
network_input = jnp.ravel(observation)
def observation_processing(
observation: jumanji.environments.routing.snake.types.Observation,
) -> Observation:
network_input = jnp.ravel(observation.grid)
return network_input

play_step_fn = make_policy_network_play_step_fn_jumanji(
Expand All @@ -64,7 +67,7 @@ def observation_processing(observation: jumanji.types.Observation) -> Observatio
keys = jax.random.split(subkey, num=batch_size)

# compute observation size from observation spec
observation_size = np.prod(np.array(env.observation_spec().shape))
observation_size = np.prod(np.array(env.observation_spec().grid.shape))

fake_batch = jnp.zeros(shape=(batch_size, observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)
Expand Down Expand Up @@ -136,4 +139,5 @@ def bd_extraction(


if __name__ == "__main__":
pytest.assume
test_jumanji_utils()
9 changes: 4 additions & 5 deletions tests/environments_test/pointmaze_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Any, Tuple

import brax
import brax.envs
import jax
import pytest
from brax import jumpy as jp
from brax.v1 import jumpy as jp
from brax.v1.envs import Env

import qdax
from qdax.environments.pointmaze import PointMaze
Expand All @@ -15,7 +14,7 @@ def test_pointmaze() -> None:
# create env with class
qd_env = PointMaze()
# verify class
pytest.assume(isinstance(qd_env, brax.envs.Env))
pytest.assume(isinstance(qd_env, Env))

# check state_descriptor_length
pytest.assume(qd_env.state_descriptor_length == 2)
Expand All @@ -25,7 +24,7 @@ def test_pointmaze() -> None:
qd_env = qdax.environments.create(env_name="pointmaze") # type: ignore

# verify class
pytest.assume(isinstance(qd_env, brax.envs.Env))
pytest.assume(isinstance(qd_env, Env))

# check state_descriptor_length
pytest.assume(qd_env.state_descriptor_length == 2)
Expand Down
8 changes: 4 additions & 4 deletions tests/environments_test/wrapper_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Dict, List, Union

import brax.envs
import brax
import jax
import jax.numpy as jnp
import pytest
from brax import jumpy as jp
from brax.physics.base import vec_to_arr
from brax.physics.config_pb2 import Joint
from brax.v1 import jumpy as jp
from brax.v1.physics.base import vec_to_arr
from brax.v1.physics.config_pb2 import Joint

from qdax import environments

Expand Down

0 comments on commit b221c68

Please sign in to comment.