Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use brax.v1 and update requirements #156

Merged
merged 9 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ version: 2
build:
os: ubuntu-20.04
tools:
python: "3.8"
python: "3.9"
apt_packages:
- swig

Expand Down
12 changes: 6 additions & 6 deletions dev.Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM mambaorg/micromamba:0.22.0 as conda
FROM mambaorg/micromamba:1.5.1 as conda

# Speed up the build, and avoid unnecessary writes to disk
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1 CONDA_DIR=/opt/conda
Expand All @@ -16,7 +16,7 @@ RUN micromamba create -y --file /tmp/environment.yaml \


FROM python as test-image
ENV PATH=/opt/conda/envs/qdaxpy38/bin/:$PATH APP_FOLDER=/app
ENV PATH=/opt/conda/envs/qdaxpy39/bin/:$PATH APP_FOLDER=/app
ENV PYTHONPATH=$APP_FOLDER:$PYTHONPATH

COPY --from=conda /opt/conda/envs/. /opt/conda/envs/
Expand All @@ -25,8 +25,8 @@ COPY requirements-dev.txt ./
RUN pip install -r requirements-dev.txt


FROM nvidia/cuda:11.4.1-cudnn8-devel-ubuntu20.04 as cuda-image
ENV PATH=/opt/conda/envs/qdaxpy38/bin/:$PATH APP_FOLDER=/app
FROM nvidia/cuda:11.5.2-cudnn8-devel-ubuntu20.04 as cuda-image
ENV PATH=/opt/conda/envs/qdaxpy39/bin/:$PATH APP_FOLDER=/app
ENV PYTHONPATH=$APP_FOLDER:$PYTHONPATH


Expand All @@ -40,7 +40,7 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.0/targets/x86_64-linux/l

ENV TZ=Europe/Paris
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
RUN pip --no-cache-dir install jaxlib==0.3.15+cuda11.cudnn82 \
RUN pip --no-cache-dir install jaxlib==0.4.16+cuda11.cudnn86 \
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
&& rm -rf /tmp/*

Expand Down Expand Up @@ -70,7 +70,7 @@ RUN apt-get update && \
libosmesa6-dev \
patchelf \
python3-opengl \
python3-dev=3.8* \
python3-dev=3.9* \
python3-pip \
screen \
sudo \
Expand Down
2 changes: 1 addition & 1 deletion qdax/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.2"
__version__ = "0.2.4"
25 changes: 13 additions & 12 deletions qdax/environments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import functools
from typing import Any, Callable, List, Optional, Union

import brax
import brax.envs
from brax.v1.envs import Env
from brax.v1.envs import _envs
from brax.v1.envs.wrappers import EpisodeWrapper, AutoResetWrapper, EvalWrapper, VectorWrapper

from qdax.environments.base_wrappers import QDEnv, StateDescriptorResetWrapper
from qdax.environments.bd_extractors import (
Expand Down Expand Up @@ -122,20 +123,20 @@ def create(
fixed_init_state: bool = False,
qdax_wrappers_kwargs: Optional[List] = None,
**kwargs: Any,
) -> Union[brax.envs.env.Env, QDEnv]:
) -> Union[Env, QDEnv]:
"""Creates an Env with a specified brax system.
Please use namespace to avoid confusion between this function and
brax.envs.create.
"""

if env_name in brax.envs._envs.keys():
env = brax.envs._envs[env_name](legacy_spring=True, **kwargs)
if env_name in _envs.keys():
env = _envs[env_name](legacy_spring=True, **kwargs)
elif env_name in _qdax_envs.keys():
env = _qdax_envs[env_name](**kwargs)
elif env_name in _qdax_custom_envs.keys():
base_env_name = _qdax_custom_envs[env_name]["env"]
if base_env_name in brax.envs._envs.keys():
env = brax.envs._envs[base_env_name](legacy_spring=True, **kwargs)
if base_env_name in _envs.keys():
env = _envs[base_env_name](legacy_spring=True, **kwargs)
elif base_env_name in _qdax_envs.keys():
env = _qdax_envs[base_env_name](**kwargs) # type: ignore
else:
Expand All @@ -152,27 +153,27 @@ def create(
env = wrapper(env, base_env_name, **kwargs) # type: ignore

if episode_length is not None:
env = brax.envs.wrappers.EpisodeWrapper(env, episode_length, action_repeat)
env = EpisodeWrapper(env, episode_length, action_repeat)
if batch_size:
env = brax.envs.wrappers.VectorWrapper(env, batch_size)
env = VectorWrapper(env, batch_size)
if fixed_init_state:
# retrieve the base env
if env_name not in _qdax_custom_envs.keys():
base_env_name = env_name
# wrap the env
env = FixedInitialStateWrapper(env, base_env_name=base_env_name) # type: ignore
if auto_reset:
env = brax.envs.wrappers.AutoResetWrapper(env)
env = AutoResetWrapper(env)
if env_name in _qdax_custom_envs.keys():
env = StateDescriptorResetWrapper(env)
if eval_metrics:
env = brax.envs.wrappers.EvalWrapper(env)
env = EvalWrapper(env)
env = CompletedEvalWrapper(env)

return env


def create_fn(env_name: str, **kwargs: Any) -> Callable[..., brax.envs.Env]:
def create_fn(env_name: str, **kwargs: Any) -> Callable[..., Env]:
"""Returns a function that when called, creates an Env.
Please use namespace to avoid confusion between this function and
brax.envs.create_fn.
Expand Down
4 changes: 2 additions & 2 deletions qdax/environments/base_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from abc import abstractmethod
from typing import Any, List, Tuple

from brax import jumpy as jp
from brax.envs.env import Env, State
from brax.v1 import jumpy as jp
from brax.v1.envs import Env, State


class QDEnv(Env):
Expand Down
14 changes: 7 additions & 7 deletions qdax/environments/exploration_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import warnings

import brax
import brax.v1 as brax
import jax.numpy as jnp
from brax import jumpy as jp
from brax.envs import State, env
from brax.v1 import jumpy as jp
from brax.v1.envs import Env, State, Wrapper
from google.protobuf import text_format # type: ignore

from qdax.environments.locomotion_wrappers import COG_NAMES
Expand Down Expand Up @@ -103,7 +103,7 @@
}


class TrapWrapper(env.Wrapper):
class TrapWrapper(Wrapper):
"""Wraps gym environments to add a Trap in the environment.

Utilisation is simple: create an environment with Brax, pass
Expand Down Expand Up @@ -143,7 +143,7 @@ class TrapWrapper(env.Wrapper):

"""

def __init__(self, env: env.Env, env_name: str) -> None:
def __init__(self, env: Env, env_name: str) -> None:
if (
env_name not in ENV_SYSTEM_CONFIG.keys()
or env_name not in COG_NAMES.keys()
Expand Down Expand Up @@ -323,7 +323,7 @@ def step(self, state: State, action: jp.ndarray) -> State:
}


class MazeWrapper(env.Wrapper):
class MazeWrapper(Wrapper):
"""Wraps gym environments to add a maze in the environment
and a new reward (distance to the goal).

Expand Down Expand Up @@ -364,7 +364,7 @@ class MazeWrapper(env.Wrapper):

"""

def __init__(self, env: env.Env, env_name: str) -> None:
def __init__(self, env: Env, env_name: str) -> None:
if (
env_name not in ENV_SYSTEM_CONFIG.keys()
or env_name not in COG_NAMES.keys()
Expand Down
16 changes: 8 additions & 8 deletions qdax/environments/humanoidtrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from typing import Any, Dict

import brax
from brax import jumpy as jp
from brax.envs import env
from brax.physics import bodies
import brax.v1 as brax
from brax.v1 import jumpy as jp
from brax.v1.envs import Env, State
from brax.v1.physics import bodies

TRAP_CONFIG = """bodies {
name: "Trap"
Expand Down Expand Up @@ -58,7 +58,7 @@
"""


class HumanoidTrap(env.Env):
class HumanoidTrap(Env):
"""Trains a humanoid to run in the +x direction.

Note: uses legacy spring from Brax.
Expand All @@ -76,7 +76,7 @@ def __init__(self, **kwargs: Dict[str, Any]) -> None:
self.inertia = body.inertia
self.inertia_matrix = jp.array([jp.diag(a) for a in self.inertia])

def reset(self, rng: jp.ndarray) -> env.State:
def reset(self, rng: jp.ndarray) -> State:
"""Resets the environment to an initial state."""
rng, rng1, rng2 = jp.random_split(rng, 3)
qpos = self.sys.default_angle() + jp.random_uniform(
Expand All @@ -93,9 +93,9 @@ def reset(self, rng: jp.ndarray) -> env.State:
"reward_alive": zero,
"reward_impact": zero,
}
return env.State(qp, obs, reward, done, metrics)
return State(qp, obs, reward, done, metrics)

def step(self, state: env.State, action: jp.ndarray) -> env.State:
def step(self, state: State, action: jp.ndarray) -> State:
"""Run one timestep of the environment's dynamics."""
qp, info = self.sys.step(state.qp, action)
obs = self._get_obs(qp, info, action)
Expand Down
8 changes: 4 additions & 4 deletions qdax/environments/init_state_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Callable, Optional

import brax
from brax import jumpy as jp
from brax.envs import Env, State, Wrapper
import brax.v1 as brax
from brax.v1 import jumpy as jp
from brax.v1.envs import Env, State, Wrapper


class FixedInitialStateWrapper(Wrapper):
Expand Down Expand Up @@ -51,7 +51,7 @@ def reset(self, rng: jp.ndarray) -> State:
# Run the default reset method of parent environment
state = self.env.reset(rng)

# Compute new initial positions and velicities
# Compute new initial positions and velocities
qpos = self.sys.default_angle()
qvel = jp.zeros((self.sys.num_joint_dof,))

Expand Down
10 changes: 5 additions & 5 deletions qdax/environments/locomotion_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any, List, Optional, Sequence, Tuple

import jax.numpy as jnp
from brax import jumpy as jp
from brax.envs import Env, State, Wrapper
from brax.physics import config_pb2
from brax.physics.base import QP, Info
from brax.physics.system import System
from brax.v1 import jumpy as jp
from brax.v1.envs import Env, State, Wrapper
from brax.v1.physics import config_pb2
from brax.v1.physics.base import QP, Info
from brax.v1.physics.system import System

from qdax.environments.base_wrappers import QDEnv

Expand Down
16 changes: 8 additions & 8 deletions qdax/environments/pointmaze.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from typing import Any, Dict, List, Tuple, Union

import brax
from brax import jumpy as jp
from brax.envs import env
import brax.v1 as brax
from brax.v1 import jumpy as jp
from brax.v1.envs import Env, State


class PointMaze(env.Env):
class PointMaze(Env):
"""Jax/Brax implementation of the PointMaze.
Highly inspired from the old python implementation of
the PointMaze.

In order to stay in the Brax API, I will use a fake QP
at several points in the implementation. This makes it possible to
use the brax.envs.env.State from Brax. To avoid this,
use the brax.envs.State from Brax. To avoid this,
it would be good to ask Brax to extend their API a bit
for environments that are not physically simulated.
"""
Expand Down Expand Up @@ -103,7 +103,7 @@ def action_size(self) -> int:
"""The size of the action vector expected by step."""
return 2

def reset(self, rng: jp.ndarray) -> env.State:
def reset(self, rng: jp.ndarray) -> State:
"""Resets the environment to an initial state."""
rng, rng1, rng2 = jp.random_split(rng, 3)
# get initial position - reproduce the old implementation
Expand All @@ -117,9 +117,9 @@ def reset(self, rng: jp.ndarray) -> env.State:
metrics: Dict = {}
# managing state descriptor by our own
info_init = {"state_descriptor": obs_init}
return env.State(fake_qp, obs_init, reward, done, metrics, info_init)
return State(fake_qp, obs_init, reward, done, metrics, info_init)

def step(self, state: env.State, action: jp.ndarray) -> env.State:
def step(self, state: State, action: jp.ndarray) -> State:
"""Run one timestep of the environment's dynamics."""

# clip action taken
Expand Down
12 changes: 6 additions & 6 deletions qdax/environments/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Dict

import brax.envs
from brax.v1.envs import State, Wrapper
import flax.struct
import jax
from brax import jumpy as jp
from brax.v1 import jumpy as jp


class CompletedEvalMetrics(flax.struct.PyTreeNode):
Expand All @@ -13,12 +13,12 @@ class CompletedEvalMetrics(flax.struct.PyTreeNode):
completed_episodes_steps: jp.ndarray


class CompletedEvalWrapper(brax.envs.env.Wrapper):
class CompletedEvalWrapper(Wrapper):
"""Brax env with eval metrics for completed episodes."""

STATE_INFO_KEY = "completed_eval_metrics"

def reset(self, rng: jp.ndarray) -> brax.envs.env.State:
def reset(self, rng: jp.ndarray) -> State:
reset_state = self.env.reset(rng)
reset_state.metrics["reward"] = reset_state.reward
eval_metrics = CompletedEvalMetrics(
Expand All @@ -35,8 +35,8 @@ def reset(self, rng: jp.ndarray) -> brax.envs.env.State:
return reset_state

def step(
self, state: brax.envs.env.State, action: jp.ndarray
) -> brax.envs.env.State:
self, state: State, action: jp.ndarray
) -> State:
state_metrics = state.info[self.STATE_INFO_KEY]
if not isinstance(state_metrics, CompletedEvalMetrics):
raise ValueError(f"Incorrect type for state_metrics: {type(state_metrics)}")
Expand Down
20 changes: 10 additions & 10 deletions requirements.txt
Lookatator marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
absl-py==1.0.0
brax==0.0.15
chex==0.1.5
brax==0.9.2
chex==0.1.83
dm-haiku==0.0.9
flax==0.6.0
gym==0.23.1
flax==0.7.4
gym==0.26.2
ipython
jax==0.3.17
jaxlib==0.3.15
jumanji==0.1.3
jax==0.4.16
jaxlib==0.4.16
jumanji==0.3.1
jupyter
numpy==1.22.3
optax==0.1.4
numpy==1.24.1
optax==0.1.7
protobuf==3.19.4
scikit-learn==1.0.2
scipy==1.8.0
seaborn==0.11.2
tensorflow-probability==0.15.0
typing-extensions==4.3.0
typing-extensions==4.3.0
Loading
Loading