Skip to content

Commit

Permalink
fix: use brax.v1 and update requirements (#156)
Browse files Browse the repository at this point in the history
* Change brax for brax.v1

* update jax/jaxlib and flax dependencies
  • Loading branch information
DBraun authored Oct 9, 2023
1 parent c111cea commit df954c2
Show file tree
Hide file tree
Showing 13 changed files with 77 additions and 76 deletions.
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ version: 2
build:
os: ubuntu-20.04
tools:
python: "3.8"
python: "3.9"
apt_packages:
- swig

Expand Down
12 changes: 6 additions & 6 deletions dev.Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM mambaorg/micromamba:0.22.0 as conda
FROM mambaorg/micromamba:1.5.1 as conda

# Speed up the build, and avoid unnecessary writes to disk
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1 CONDA_DIR=/opt/conda
Expand All @@ -16,7 +16,7 @@ RUN micromamba create -y --file /tmp/environment.yaml \


FROM python as test-image
ENV PATH=/opt/conda/envs/qdaxpy38/bin/:$PATH APP_FOLDER=/app
ENV PATH=/opt/conda/envs/qdaxpy39/bin/:$PATH APP_FOLDER=/app
ENV PYTHONPATH=$APP_FOLDER:$PYTHONPATH

COPY --from=conda /opt/conda/envs/. /opt/conda/envs/
Expand All @@ -25,8 +25,8 @@ COPY requirements-dev.txt ./
RUN pip install -r requirements-dev.txt


FROM nvidia/cuda:11.4.1-cudnn8-devel-ubuntu20.04 as cuda-image
ENV PATH=/opt/conda/envs/qdaxpy38/bin/:$PATH APP_FOLDER=/app
FROM nvidia/cuda:11.5.2-cudnn8-devel-ubuntu20.04 as cuda-image
ENV PATH=/opt/conda/envs/qdaxpy39/bin/:$PATH APP_FOLDER=/app
ENV PYTHONPATH=$APP_FOLDER:$PYTHONPATH


Expand All @@ -40,7 +40,7 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.0/targets/x86_64-linux/l

ENV TZ=Europe/Paris
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
RUN pip --no-cache-dir install jaxlib==0.3.15+cuda11.cudnn82 \
RUN pip --no-cache-dir install jaxlib==0.4.16+cuda11.cudnn86 \
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
&& rm -rf /tmp/*

Expand Down Expand Up @@ -70,7 +70,7 @@ RUN apt-get update && \
libosmesa6-dev \
patchelf \
python3-opengl \
python3-dev=3.8* \
python3-dev=3.9* \
python3-pip \
screen \
sudo \
Expand Down
2 changes: 1 addition & 1 deletion qdax/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.2"
__version__ = "0.2.4"
25 changes: 13 additions & 12 deletions qdax/environments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import functools
from typing import Any, Callable, List, Optional, Union

import brax
import brax.envs
from brax.v1.envs import Env
from brax.v1.envs import _envs
from brax.v1.envs.wrappers import EpisodeWrapper, AutoResetWrapper, EvalWrapper, VectorWrapper

from qdax.environments.base_wrappers import QDEnv, StateDescriptorResetWrapper
from qdax.environments.bd_extractors import (
Expand Down Expand Up @@ -122,20 +123,20 @@ def create(
fixed_init_state: bool = False,
qdax_wrappers_kwargs: Optional[List] = None,
**kwargs: Any,
) -> Union[brax.envs.env.Env, QDEnv]:
) -> Union[Env, QDEnv]:
"""Creates an Env with a specified brax system.
Please use namespace to avoid confusion between this function and
brax.envs.create.
"""

if env_name in brax.envs._envs.keys():
env = brax.envs._envs[env_name](legacy_spring=True, **kwargs)
if env_name in _envs.keys():
env = _envs[env_name](legacy_spring=True, **kwargs)
elif env_name in _qdax_envs.keys():
env = _qdax_envs[env_name](**kwargs)
elif env_name in _qdax_custom_envs.keys():
base_env_name = _qdax_custom_envs[env_name]["env"]
if base_env_name in brax.envs._envs.keys():
env = brax.envs._envs[base_env_name](legacy_spring=True, **kwargs)
if base_env_name in _envs.keys():
env = _envs[base_env_name](legacy_spring=True, **kwargs)
elif base_env_name in _qdax_envs.keys():
env = _qdax_envs[base_env_name](**kwargs) # type: ignore
else:
Expand All @@ -152,27 +153,27 @@ def create(
env = wrapper(env, base_env_name, **kwargs) # type: ignore

if episode_length is not None:
env = brax.envs.wrappers.EpisodeWrapper(env, episode_length, action_repeat)
env = EpisodeWrapper(env, episode_length, action_repeat)
if batch_size:
env = brax.envs.wrappers.VectorWrapper(env, batch_size)
env = VectorWrapper(env, batch_size)
if fixed_init_state:
# retrieve the base env
if env_name not in _qdax_custom_envs.keys():
base_env_name = env_name
# wrap the env
env = FixedInitialStateWrapper(env, base_env_name=base_env_name) # type: ignore
if auto_reset:
env = brax.envs.wrappers.AutoResetWrapper(env)
env = AutoResetWrapper(env)
if env_name in _qdax_custom_envs.keys():
env = StateDescriptorResetWrapper(env)
if eval_metrics:
env = brax.envs.wrappers.EvalWrapper(env)
env = EvalWrapper(env)
env = CompletedEvalWrapper(env)

return env


def create_fn(env_name: str, **kwargs: Any) -> Callable[..., brax.envs.Env]:
def create_fn(env_name: str, **kwargs: Any) -> Callable[..., Env]:
"""Returns a function that when called, creates an Env.
Please use namespace to avoid confusion between this function and
brax.envs.create_fn.
Expand Down
4 changes: 2 additions & 2 deletions qdax/environments/base_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from abc import abstractmethod
from typing import Any, List, Tuple

from brax import jumpy as jp
from brax.envs.env import Env, State
from brax.v1 import jumpy as jp
from brax.v1.envs import Env, State


class QDEnv(Env):
Expand Down
14 changes: 7 additions & 7 deletions qdax/environments/exploration_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import warnings

import brax
import brax.v1 as brax
import jax.numpy as jnp
from brax import jumpy as jp
from brax.envs import State, env
from brax.v1 import jumpy as jp
from brax.v1.envs import Env, State, Wrapper
from google.protobuf import text_format # type: ignore

from qdax.environments.locomotion_wrappers import COG_NAMES
Expand Down Expand Up @@ -103,7 +103,7 @@
}


class TrapWrapper(env.Wrapper):
class TrapWrapper(Wrapper):
"""Wraps gym environments to add a Trap in the environment.
Utilisation is simple: create an environment with Brax, pass
Expand Down Expand Up @@ -143,7 +143,7 @@ class TrapWrapper(env.Wrapper):
"""

def __init__(self, env: env.Env, env_name: str) -> None:
def __init__(self, env: Env, env_name: str) -> None:
if (
env_name not in ENV_SYSTEM_CONFIG.keys()
or env_name not in COG_NAMES.keys()
Expand Down Expand Up @@ -323,7 +323,7 @@ def step(self, state: State, action: jp.ndarray) -> State:
}


class MazeWrapper(env.Wrapper):
class MazeWrapper(Wrapper):
"""Wraps gym environments to add a maze in the environment
and a new reward (distance to the goal).
Expand Down Expand Up @@ -364,7 +364,7 @@ class MazeWrapper(env.Wrapper):
"""

def __init__(self, env: env.Env, env_name: str) -> None:
def __init__(self, env: Env, env_name: str) -> None:
if (
env_name not in ENV_SYSTEM_CONFIG.keys()
or env_name not in COG_NAMES.keys()
Expand Down
16 changes: 8 additions & 8 deletions qdax/environments/humanoidtrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from typing import Any, Dict

import brax
from brax import jumpy as jp
from brax.envs import env
from brax.physics import bodies
import brax.v1 as brax
from brax.v1 import jumpy as jp
from brax.v1.envs import Env, State
from brax.v1.physics import bodies

TRAP_CONFIG = """bodies {
name: "Trap"
Expand Down Expand Up @@ -58,7 +58,7 @@
"""


class HumanoidTrap(env.Env):
class HumanoidTrap(Env):
"""Trains a humanoid to run in the +x direction.
Note: uses legacy spring from Brax.
Expand All @@ -76,7 +76,7 @@ def __init__(self, **kwargs: Dict[str, Any]) -> None:
self.inertia = body.inertia
self.inertia_matrix = jp.array([jp.diag(a) for a in self.inertia])

def reset(self, rng: jp.ndarray) -> env.State:
def reset(self, rng: jp.ndarray) -> State:
"""Resets the environment to an initial state."""
rng, rng1, rng2 = jp.random_split(rng, 3)
qpos = self.sys.default_angle() + jp.random_uniform(
Expand All @@ -93,9 +93,9 @@ def reset(self, rng: jp.ndarray) -> env.State:
"reward_alive": zero,
"reward_impact": zero,
}
return env.State(qp, obs, reward, done, metrics)
return State(qp, obs, reward, done, metrics)

def step(self, state: env.State, action: jp.ndarray) -> env.State:
def step(self, state: State, action: jp.ndarray) -> State:
"""Run one timestep of the environment's dynamics."""
qp, info = self.sys.step(state.qp, action)
obs = self._get_obs(qp, info, action)
Expand Down
8 changes: 4 additions & 4 deletions qdax/environments/init_state_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Callable, Optional

import brax
from brax import jumpy as jp
from brax.envs import Env, State, Wrapper
import brax.v1 as brax
from brax.v1 import jumpy as jp
from brax.v1.envs import Env, State, Wrapper


class FixedInitialStateWrapper(Wrapper):
Expand Down Expand Up @@ -51,7 +51,7 @@ def reset(self, rng: jp.ndarray) -> State:
# Run the default reset method of parent environment
state = self.env.reset(rng)

# Compute new initial positions and velicities
# Compute new initial positions and velocities
qpos = self.sys.default_angle()
qvel = jp.zeros((self.sys.num_joint_dof,))

Expand Down
10 changes: 5 additions & 5 deletions qdax/environments/locomotion_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any, List, Optional, Sequence, Tuple

import jax.numpy as jnp
from brax import jumpy as jp
from brax.envs import Env, State, Wrapper
from brax.physics import config_pb2
from brax.physics.base import QP, Info
from brax.physics.system import System
from brax.v1 import jumpy as jp
from brax.v1.envs import Env, State, Wrapper
from brax.v1.physics import config_pb2
from brax.v1.physics.base import QP, Info
from brax.v1.physics.system import System

from qdax.environments.base_wrappers import QDEnv

Expand Down
16 changes: 8 additions & 8 deletions qdax/environments/pointmaze.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from typing import Any, Dict, List, Tuple, Union

import brax
from brax import jumpy as jp
from brax.envs import env
import brax.v1 as brax
from brax.v1 import jumpy as jp
from brax.v1.envs import Env, State


class PointMaze(env.Env):
class PointMaze(Env):
"""Jax/Brax implementation of the PointMaze.
Heavily inspired by the old Python implementation of
the PointMaze.
In order to stay within the Brax API, a fake QP is used
at several points in the implementation. This makes it possible to
use the brax.envs.env.State from Brax. To avoid this,
use the brax.envs.State from Brax. To avoid this,
it would be good to ask Brax to enlarge a bit their API
for environments that are not physically simulated.
"""
Expand Down Expand Up @@ -103,7 +103,7 @@ def action_size(self) -> int:
"""The size of the action vector expected by step."""
return 2

def reset(self, rng: jp.ndarray) -> env.State:
def reset(self, rng: jp.ndarray) -> State:
"""Resets the environment to an initial state."""
rng, rng1, rng2 = jp.random_split(rng, 3)
# get initial position - reproduce the old implementation
Expand All @@ -117,9 +117,9 @@ def reset(self, rng: jp.ndarray) -> env.State:
metrics: Dict = {}
# managing state descriptor by our own
info_init = {"state_descriptor": obs_init}
return env.State(fake_qp, obs_init, reward, done, metrics, info_init)
return State(fake_qp, obs_init, reward, done, metrics, info_init)

def step(self, state: env.State, action: jp.ndarray) -> env.State:
def step(self, state: State, action: jp.ndarray) -> State:
"""Run one timestep of the environment's dynamics."""

# clip action taken
Expand Down
12 changes: 6 additions & 6 deletions qdax/environments/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Dict

import brax.envs
from brax.v1.envs import State, Wrapper
import flax.struct
import jax
from brax import jumpy as jp
from brax.v1 import jumpy as jp


class CompletedEvalMetrics(flax.struct.PyTreeNode):
Expand All @@ -13,12 +13,12 @@ class CompletedEvalMetrics(flax.struct.PyTreeNode):
completed_episodes_steps: jp.ndarray


class CompletedEvalWrapper(brax.envs.env.Wrapper):
class CompletedEvalWrapper(Wrapper):
"""Brax env with eval metrics for completed episodes."""

STATE_INFO_KEY = "completed_eval_metrics"

def reset(self, rng: jp.ndarray) -> brax.envs.env.State:
def reset(self, rng: jp.ndarray) -> State:
reset_state = self.env.reset(rng)
reset_state.metrics["reward"] = reset_state.reward
eval_metrics = CompletedEvalMetrics(
Expand All @@ -35,8 +35,8 @@ def reset(self, rng: jp.ndarray) -> brax.envs.env.State:
return reset_state

def step(
self, state: brax.envs.env.State, action: jp.ndarray
) -> brax.envs.env.State:
self, state: State, action: jp.ndarray
) -> State:
state_metrics = state.info[self.STATE_INFO_KEY]
if not isinstance(state_metrics, CompletedEvalMetrics):
raise ValueError(f"Incorrect type for state_metrics: {type(state_metrics)}")
Expand Down
20 changes: 10 additions & 10 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
absl-py==1.0.0
brax==0.0.15
chex==0.1.5
brax==0.9.2
chex==0.1.83
dm-haiku==0.0.9
flax==0.6.0
gym==0.23.1
flax==0.7.4
gym==0.26.2
ipython
jax==0.3.17
jaxlib==0.3.15
jumanji==0.1.3
jax==0.4.16
jaxlib==0.4.16
jumanji==0.3.1
jupyter
numpy==1.22.3
optax==0.1.4
numpy==1.24.1
optax==0.1.7
protobuf==3.19.4
scikit-learn==1.0.2
scipy==1.8.0
seaborn==0.11.2
tensorflow-probability==0.15.0
typing-extensions==4.3.0
typing-extensions==4.3.0
Loading

0 comments on commit df954c2

Please sign in to comment.