From bae0911e0a871378a040be1fc3627d37ca544694 Mon Sep 17 00:00:00 2001 From: youliang Date: Mon, 19 Feb 2024 05:19:51 -0800 Subject: [PATCH 1/5] examples in sim to save and load rlds data Signed-off-by: youliang --- docs/sim_quick_start.md | 44 +++++++++++ examples/async_drq_sim/async_drq_sim.py | 18 +++-- .../async_rlpd_drq_sim/async_rlpd_drq_sim.py | 15 +++- .../async_sac_state_sim.py | 15 ++-- serl_launcher/requirements.txt | 1 + .../serl_launcher/data/data_store.py | 78 +++++++++++++++++-- serl_launcher/serl_launcher/utils/launcher.py | 69 ++++++++++++++++ 7 files changed, 219 insertions(+), 21 deletions(-) diff --git a/docs/sim_quick_start.md b/docs/sim_quick_start.md index 37acc565..ad4e7e1e 100644 --- a/docs/sim_quick_start.md +++ b/docs/sim_quick_start.md @@ -117,3 +117,47 @@ Run actor node with rendering window: # add --ip x.x.x.x if running on a different machine bash run_actor.sh ``` + +## Use RLDS logger to save and load trajectories + +This provides a way to save and load trajectories for SERL training. [Tensorflow RLDS dataset](https://github.com/google-research/rlds) format is used to save and load trajectories. This standard is compliant with the [RTX datasets](https://robotics-transformer-x.github.io/), which can potentially can be used for other robot learning tasks. + +### Installation + +This requires additional installation of `oxe_envlogger`: +```bash +git clone git@github.com:rail-berkeley/oxe_envlogger.git +cd oxe_envlogger +pip install -e . +``` + +### Usage + +**Save the trajectories** + +With the example above, we can save the data from the replay buffer by providing the `rlds_logger_path` argument. This will save the data to the specified path. + +```bash +./run_learner.sh --log_rlds_path /path/to/save +``` + +This will save the data to the specified path in the following format: + +```bash + - /path/to/save + - dataset_info.json + - features.json + - serl_rlds_dataset-train.tfrecord-00000 + - serl_rlds_dataset-train.tfrecord-00001 + .... +``` + +**Load the trajectories** + +With the example above, we can load the data from the replay buffer by providing the `preload_rlds_path` argument. This will load the data from the specified path. + +```bash +./run_learner.sh --preload_rlds_path /path/to/load +``` + +This is equivalent to the `--demo_path` argument in `examples/async_rlpd_drq_sim/run_learner.sh` script. diff --git a/examples/async_drq_sim/async_drq_sim.py b/examples/async_drq_sim/async_drq_sim.py index 59a4de71..f60dede3 100644 --- a/examples/async_drq_sim/async_drq_sim.py +++ b/examples/async_drq_sim/async_drq_sim.py @@ -24,8 +24,8 @@ make_drq_agent, make_trainer_config, make_wandb_logger, + make_replay_buffer, ) -from serl_launcher.data.data_store import MemoryEfficientReplayBufferDataStore from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper import franka_sim @@ -66,6 +66,9 @@ "debug", False, "Debug mode." ) # debug mode will disable wandb logging +flags.DEFINE_string("log_rlds_path", None, "Path to save RLDS logs.") +flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") + devices = jax.local_devices() num_devices = len(devices) sharding = jax.sharding.PositionalSharding(devices) @@ -138,7 +141,7 @@ def update_params(params): next_observations=next_obs, rewards=reward, masks=1.0 - done, - dones=done, + dones=done or truncated, ) data_store.insert(transition) @@ -291,17 +294,20 @@ def main(_): ) def create_replay_buffer_and_wandb_logger(): - replay_buffer = MemoryEfficientReplayBufferDataStore( - env.observation_space, - env.action_space, + replay_buffer = make_replay_buffer( + env, capacity=FLAGS.replay_buffer_capacity, + rlds_logger_path=FLAGS.log_rlds_path, + type="memory_efficient_replay_buffer", image_keys=image_keys, + preload_rlds_path=FLAGS.preload_rlds_path, ) + # set up wandb and logging wandb_logger = make_wandb_logger( project="serl_dev", description=FLAGS.exp_name or FLAGS.env, - # debug=FLAGS.debug, + debug=FLAGS.debug, ) return replay_buffer, wandb_logger diff --git a/examples/async_rlpd_drq_sim/async_rlpd_drq_sim.py b/examples/async_rlpd_drq_sim/async_rlpd_drq_sim.py index 66d06383..5f22e0da 100644 --- a/examples/async_rlpd_drq_sim/async_rlpd_drq_sim.py +++ b/examples/async_rlpd_drq_sim/async_rlpd_drq_sim.py @@ -25,6 +25,7 @@ make_drq_agent, make_trainer_config, make_wandb_logger, + make_replay_buffer, ) from serl_launcher.data.data_store import MemoryEfficientReplayBufferDataStore from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper @@ -68,6 +69,9 @@ "debug", False, "Debug mode." ) # debug mode will disable wandb logging +flags.DEFINE_string("log_rlds_path", None, "Path to save RLDS logs.") +flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") + devices = jax.local_devices() num_devices = len(devices) sharding = jax.sharding.PositionalSharding(devices) @@ -140,7 +144,7 @@ def update_params(params): next_observations=next_obs, rewards=reward, masks=1.0 - done, - dones=done, + dones=done or truncated, ) data_store.insert(transition) @@ -313,12 +317,15 @@ def main(_): ) def create_replay_buffer_and_wandb_logger(): - replay_buffer = MemoryEfficientReplayBufferDataStore( - env.observation_space, - env.action_space, + replay_buffer = make_replay_buffer( + env, capacity=FLAGS.replay_buffer_capacity, + rlds_logger_path=FLAGS.log_rlds_path, + type="memory_efficient_replay_buffer", image_keys=image_keys, + preload_rlds_path=FLAGS.preload_rlds_path, ) + # set up wandb and logging wandb_logger = make_wandb_logger( project="serl_dev", diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index f44af819..90b96d63 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -17,8 +17,8 @@ make_sac_agent, make_trainer_config, make_wandb_logger, + make_replay_buffer, ) -from serl_launcher.data.data_store import ReplayBufferDataStore from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics from serl_launcher.agents.continuous.sac import SACAgent @@ -61,6 +61,9 @@ "debug", False, "Debug mode." ) # debug mode will disable wandb logging +flags.DEFINE_string("log_rlds_path", None, "Path to save RLDS logs.") +flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") + def print_green(x): return print("\033[92m {}\033[00m".format(x)) @@ -130,7 +133,7 @@ def update_params(params): next_observations=next_obs, rewards=reward, masks=1.0 - done, - dones=done, + dones=done or truncated, ) ) @@ -264,10 +267,12 @@ def main(_): ) def create_replay_buffer_and_wandb_logger(): - replay_buffer = ReplayBufferDataStore( - env.observation_space, - env.action_space, + replay_buffer = make_replay_buffer( + env, capacity=FLAGS.replay_buffer_capacity, + rlds_logger_path=FLAGS.log_rlds_path, + type="replay_buffer", + preload_rlds_path=FLAGS.preload_rlds_path, ) # set up wandb and logging diff --git a/serl_launcher/requirements.txt b/serl_launcher/requirements.txt index 5c333d9f..6b29152c 100644 --- a/serl_launcher/requirements.txt +++ b/serl_launcher/requirements.txt @@ -16,3 +16,4 @@ einops >= 0.6.1 imageio >= 2.31.1 moviepy >= 1.0.3 pre-commit == 3.3.3 +tensorflow_datasets >= 4.4.0 diff --git a/serl_launcher/serl_launcher/data/data_store.py b/serl_launcher/serl_launcher/data/data_store.py index 334d6181..e9f99969 100644 --- a/serl_launcher/serl_launcher/data/data_store.py +++ b/serl_launcher/serl_launcher/data/data_store.py @@ -10,6 +10,18 @@ from agentlace.data.data_store import DataStoreBase +from typing import List, Optional, TypeVar + +# import oxe_envlogger if it is installed +try: + from oxe_envlogger.rlds_logger import RLDSLogger, RLDSStepType +except ImportError: + print( + "rlds logger is not installed, install it if required: " + "https://github.com/rail-berkeley/oxe_envlogger " + ) + RLDSLogger = TypeVar("RLDSLogger") + class ReplayBufferDataStore(ReplayBuffer, DataStoreBase): def __init__( @@ -17,15 +29,42 @@ def __init__( observation_space: gym.Space, action_space: gym.Space, capacity: int, + rlds_logger: Optional[RLDSLogger] = None, ): ReplayBuffer.__init__(self, observation_space, action_space, capacity) DataStoreBase.__init__(self, capacity) self._lock = Lock() + self._logger = None + + if rlds_logger: + self.step_type = RLDSStepType.TERMINATION # to init the state for restart + self._logger = rlds_logger # ensure thread safety - def insert(self, *args, **kwargs): + def insert(self, data): with self._lock: - super(ReplayBufferDataStore, self).insert(*args, **kwargs) + super(ReplayBufferDataStore, self).insert(data) + + # add data to the rlds logger + if self._logger: + if self.step_type in { + RLDSStepType.TERMINATION, + RLDSStepType.TRUNCATION, + }: + self.step_type = RLDSStepType.RESTART + elif not data["masks"]: # 0 is done, 1 is not done + self.step_type = RLDSStepType.TERMINATION + elif data["dones"]: + self.step_type = RLDSStepType.TRUNCATION + else: + self.step_type = RLDSStepType.TRANSITION + + self._logger( + action=data["actions"], + obs=data["next_observations"], # TODO: check if this is correct + reward=data["rewards"], + step_type=self.step_type, + ) # ensure thread safety def sample(self, *args, **kwargs): @@ -48,17 +87,46 @@ def __init__( action_space: gym.Space, capacity: int, image_keys: Iterable[str] = ("image",), + rlds_logger: Optional[RLDSLogger] = None, ): MemoryEfficientReplayBuffer.__init__( self, observation_space, action_space, capacity, pixel_keys=image_keys ) DataStoreBase.__init__(self, capacity) self._lock = Lock() + self._logger = None + + if rlds_logger: + self.step_type = RLDSStepType.TERMINATION # to init the state for restart + self._logger = rlds_logger # ensure thread safety - def insert(self, *args, **kwargs): + def insert(self, data): with self._lock: - super(MemoryEfficientReplayBufferDataStore, self).insert(*args, **kwargs) + super(MemoryEfficientReplayBufferDataStore, self).insert(data) + + if self._logger: + # handle restart when it was done before + if self.step_type in { + RLDSStepType.TERMINATION, + RLDSStepType.TRUNCATION, + }: + self.step_type = RLDSStepType.RESTART + elif self.step_type == RLDSStepType.TRUNCATION: + self.step_type = RLDSStepType.RESTART + elif not data["masks"]: # 0 is done, 1 is not done + self.step_type = RLDSStepType.TERMINATION + elif data["dones"]: + self.step_type = RLDSStepType.TRUNCATION + else: + self.step_type = RLDSStepType.TRANSITION + + self._logger( + action=data["actions"], + obs=data["next_observations"], # TODO: not obs, but next_obs + reward=data["rewards"], + step_type=self.step_type, + ) # ensure thread safety def sample(self, *args, **kwargs): @@ -85,8 +153,6 @@ def populate_data_store( :return data_store """ import pickle as pkl - import numpy as np - from copy import deepcopy for demo_path in demos_path: with open(demo_path, "rb") as f: diff --git a/serl_launcher/serl_launcher/utils/launcher.py b/serl_launcher/serl_launcher/utils/launcher.py index 9961bbe1..79658714 100644 --- a/serl_launcher/serl_launcher/utils/launcher.py +++ b/serl_launcher/serl_launcher/utils/launcher.py @@ -3,7 +3,11 @@ import jax from jax import nn +from typing import Optional +import tensorflow_datasets as tfds + from agentlace.trainer import TrainerConfig +from agentlace.data.tfds import populate_datastore from serl_launcher.common.wandb import WandBLogger from serl_launcher.agents.continuous.bc import BCAgent @@ -11,6 +15,11 @@ from serl_launcher.agents.continuous.drq import DrQAgent from serl_launcher.agents.continuous.vice import VICEAgent +from serl_launcher.data.data_store import ( + MemoryEfficientReplayBufferDataStore, + ReplayBufferDataStore, +) + ############################################################################## @@ -180,3 +189,63 @@ def make_wandb_logger( debug=debug, ) return wandb_logger + + +def make_replay_buffer( + env, + capacity: int = 1000000, + rlds_logger_path: Optional[str] = None, + type: str = "replay_buffer", + image_keys: list = [], # used when type is "memory_efficient_replay_buffer" + preload_rlds_path: Optional[str] = None, +): + """ + This is the high-level helper function to + create a replay buffer and wandb logger + + support only for "replay_buffer" and "memory_efficient_replay_buffer" + """ + print("shape of observation space and action space") + print(env.observation_space) + print(env.action_space) + + # init logger for RLDS + if rlds_logger_path: + # clean this to make this common + from oxe_envlogger.rlds_logger import RLDSLogger + + rlds_logger = RLDSLogger( + observation_space=env.observation_space, + action_space=env.action_space, + dataset_name="serl_rlds_dataset", + directory=rlds_logger_path, + max_episodes_per_file=5, # TODO: arbitrary number + ) + else: + rlds_logger = None + + if type == "replay_buffer": + replay_buffer = ReplayBufferDataStore( + env.observation_space, + env.action_space, + capacity=capacity, + rlds_logger=rlds_logger, + ) + elif type == "memory_efficient_replay_buffer": + replay_buffer = MemoryEfficientReplayBufferDataStore( + env.observation_space, + env.action_space, + capacity=capacity, + rlds_logger=rlds_logger, + image_keys=image_keys, + ) + else: + raise ValueError(f"Unsupported replay_buffer_type: {type}") + + if preload_rlds_path: + print(f" - Preloaded {preload_rlds_path} to replay buffer") + dataset = tfds.builder_from_directory(preload_rlds_path).as_dataset(split="all") + populate_datastore(replay_buffer, dataset, type="with_dones") + print(f" - done populated {len(replay_buffer)} samples to replay buffer") + + return replay_buffer From fd5bc59d5bcf5f72f2c9f6650ef9bb1f0d16ca9a Mon Sep 17 00:00:00 2001 From: youliang Date: Mon, 19 Feb 2024 11:58:10 -0800 Subject: [PATCH 2/5] nit Signed-off-by: youliang --- docs/sim_quick_start.md | 4 +- .../franka_sim/envs/panda_pick_gym_env.py | 51 ++++++++++--------- serl_launcher/setup.py | 2 +- 3 files changed, 31 insertions(+), 26 deletions(-) diff --git a/docs/sim_quick_start.md b/docs/sim_quick_start.md index ad4e7e1e..cc028f57 100644 --- a/docs/sim_quick_start.md +++ b/docs/sim_quick_start.md @@ -1,5 +1,7 @@ # Quick Start with SERL in Sim +This is a minimal mujoco simulation environment for training with SERL. The environment consists of a panda robot arm and a cube. The goal is to lift the cube to a target position. The environment is implemented using `franka_sim` and `gym` interface. + ![](./images/franka_sim.png) ## Installation @@ -160,4 +162,4 @@ With the example above, we can load the data from the replay buffer by providing ./run_learner.sh --preload_rlds_path /path/to/load ``` -This is equivalent to the `--demo_path` argument in `examples/async_rlpd_drq_sim/run_learner.sh` script. +This is similar to the `examples/async_rlpd_drq_sim/run_learner.sh` script, which uses `--demo_path` argument which load .pkl offline demo trajectories. diff --git a/franka_sim/franka_sim/envs/panda_pick_gym_env.py b/franka_sim/franka_sim/envs/panda_pick_gym_env.py index deb6942b..57eb44d1 100644 --- a/franka_sim/franka_sim/envs/panda_pick_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_pick_gym_env.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, Tuple, Dict import gymnasium as gym import mujoco @@ -140,7 +140,8 @@ def __init__( def reset( self, seed=None, **kwargs - ) -> tuple[dict[str, np.ndarray], dict[str, Any]]: + ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: + """Reset the environment.""" mujoco.mj_resetData(self._model, self._data) # Reset arm to home position. @@ -163,22 +164,21 @@ def reset( obs = self._compute_observation() return obs, {} - """ - take a step in the environment. - Params: - action: np.ndarray - - Returns: - observation: dict[str, np.ndarray], - reward: float, - done: bool, - truncated: bool, - info: dict[str, Any] - """ - def step( self, action: np.ndarray - ) -> tuple[dict[str, np.ndarray], float, bool, bool, dict[str, Any]]: + ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]: + """ + take a step in the environment. + Params: + action: np.ndarray + + Returns: + observation: dict[str, np.ndarray], + reward: float, + done: bool, + truncated: bool, + info: dict[str, Any] + """ x, y, z, grasp = action # Set the mocap position. @@ -227,6 +227,17 @@ def _compute_observation(self) -> dict: obs = {} obs["state"] = {} + tcp_pos = self._data.sensor("2f85/pinch_pos").data + obs["state"]["panda/tcp_pos"] = tcp_pos.astype(np.float32) + + tcp_vel = self._data.sensor("2f85/pinch_vel").data + obs["state"]["panda/tcp_vel"] = tcp_vel.astype(np.float32) + + gripper_pos = np.array( + self._data.ctrl[self._gripper_ctrl_id] / 255, dtype=np.float32 + ) + obs["state"]["panda/gripper_pos"] = gripper_pos + # joint_pos = np.stack( # [self._data.sensor(f"panda/joint{i}_pos").data for i in range(1, 8)], # ).ravel() @@ -237,14 +248,6 @@ def _compute_observation(self) -> dict: # ).ravel() # obs["panda/joint_vel"] = joint_vel.astype(np.float32) - tcp_pos = self._data.sensor("2f85/pinch_pos").data - obs["state"]["panda/tcp_pos"] = tcp_pos.astype(np.float32) - tcp_vel = self._data.sensor("2f85/pinch_vel").data - obs["state"]["panda/tcp_vel"] = tcp_vel.astype(np.float32) - gripper_pos = np.array( - self._data.ctrl[self._gripper_ctrl_id] / 255, dtype=np.float32 - ) - obs["state"]["panda/gripper_pos"] = gripper_pos # joint_torque = np.stack( # [self._data.sensor(f"panda/joint{i}_torque").data for i in range(1, 8)], # ).ravel() diff --git a/serl_launcher/setup.py b/serl_launcher/setup.py index f33d0de7..5f7bdddf 100644 --- a/serl_launcher/setup.py +++ b/serl_launcher/setup.py @@ -13,7 +13,7 @@ "typing_extensions", "opencv-python", "lz4", - "agentlace@git+https://github.com/youliangtan/agentlace.git@e35c9c5ef440d3cc053a154c47b842f9c12b4356", + "agentlace@git+https://github.com/youliangtan/agentlace.git@2d5d6bff0778d65aa4a589cef2a2bd6f01c645c7", ], packages=find_packages(), zip_safe=False, From a885bccb122d956b59f45ac4c2303a84ec8b7e61 Mon Sep 17 00:00:00 2001 From: youliang Date: Fri, 23 Feb 2024 07:08:01 -0800 Subject: [PATCH 3/5] update commitid, docstr Signed-off-by: youliang --- serl_launcher/serl_launcher/utils/launcher.py | 14 ++++++++++---- serl_launcher/setup.py | 2 +- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/serl_launcher/serl_launcher/utils/launcher.py b/serl_launcher/serl_launcher/utils/launcher.py index 79658714..44ed7dec 100644 --- a/serl_launcher/serl_launcher/utils/launcher.py +++ b/serl_launcher/serl_launcher/utils/launcher.py @@ -196,14 +196,20 @@ def make_replay_buffer( capacity: int = 1000000, rlds_logger_path: Optional[str] = None, type: str = "replay_buffer", - image_keys: list = [], # used when type is "memory_efficient_replay_buffer" + image_keys: list = [], # used only type=="memory_efficient_replay_buffer" preload_rlds_path: Optional[str] = None, ): """ This is the high-level helper function to - create a replay buffer and wandb logger + create a replay buffer for the given environment. - support only for "replay_buffer" and "memory_efficient_replay_buffer" + Args: + - env: gym or gymasium environment + - capacity: capacity of the replay buffer + - rlds_logger_path: path to save RLDS logs + - type: support only for "replay_buffer" and "memory_efficient_replay_buffer" + - image_keys: list of image keys, used only "memory_efficient_replay_buffer" + - preload_rlds_path: path to preloaded RLDS trajectories """ print("shape of observation space and action space") print(env.observation_space) @@ -211,7 +217,7 @@ def make_replay_buffer( # init logger for RLDS if rlds_logger_path: - # clean this to make this common + # from: https://github.com/rail-berkeley/oxe_envlogger from oxe_envlogger.rlds_logger import RLDSLogger rlds_logger = RLDSLogger( diff --git a/serl_launcher/setup.py b/serl_launcher/setup.py index 5f7bdddf..09557a59 100644 --- a/serl_launcher/setup.py +++ b/serl_launcher/setup.py @@ -13,7 +13,7 @@ "typing_extensions", "opencv-python", "lz4", - "agentlace@git+https://github.com/youliangtan/agentlace.git@2d5d6bff0778d65aa4a589cef2a2bd6f01c645c7", + "agentlace@git+https://github.com/youliangtan/agentlace.git@e61032fbce8a1e6d3dc2aeba21de082e4bf46fe3", ], packages=find_packages(), zip_safe=False, From ec4e4ef3dab1ef1710fa46ea5967448a37866828 Mon Sep 17 00:00:00 2001 From: youliang Date: Sun, 5 May 2024 16:58:05 -0700 Subject: [PATCH 4/5] fix gym and agentlace deps Signed-off-by: youliang --- .../async_drq_randomized.py | 4 ++-- .../record_bc_demos.py | 2 +- .../async_bin_relocation_fwbw_drq/record_demo.py | 2 +- .../record_transitions.py | 2 +- .../test_classifier.py | 2 +- .../train_reward_classifier.py | 2 +- .../async_cable_route_drq/async_drq_randomized.py | 4 ++-- examples/async_cable_route_drq/record_demo.py | 2 +- examples/async_cable_route_drq/test_classifier.py | 2 +- .../train_reward_classifier.py | 2 +- examples/async_drq_sim/async_drq_sim.py | 4 ++-- .../async_pcb_insert_drq/async_drq_randomized.py | 4 ++-- examples/async_pcb_insert_drq/record_demo.py | 2 +- .../async_peg_insert_drq/async_drq_randomized.py | 4 ++-- examples/async_peg_insert_drq/record_demo.py | 2 +- examples/async_rlpd_drq_sim/async_rlpd_drq_sim.py | 4 ++-- .../async_sac_state_sim/async_sac_state_sim.py | 4 ++-- examples/bc_policy.py | 4 ++-- franka_sim/franka_sim/__init__.py | 2 +- franka_sim/franka_sim/envs/panda_pick_gym_env.py | 14 +++++++++++--- franka_sim/franka_sim/mujoco_gym_env.py | 2 +- franka_sim/franka_sim/test/test_gym_env_render.py | 2 +- franka_sim/requirements.txt | 1 + serl_launcher/requirements.txt | 4 ++-- serl_launcher/serl_launcher/common/evaluation.py | 2 +- serl_launcher/serl_launcher/data/data_store.py | 2 +- serl_launcher/serl_launcher/data/dataset.py | 2 +- .../data/memory_efficient_replay_buffer.py | 4 ++-- serl_launcher/serl_launcher/data/replay_buffer.py | 2 +- serl_launcher/serl_launcher/utils/sim_utils.py | 2 +- serl_launcher/serl_launcher/wrappers/chunking.py | 4 ++-- serl_launcher/serl_launcher/wrappers/dmcgym.py | 2 +- .../serl_launcher/wrappers/front_camera_wrapper.py | 4 ++-- serl_launcher/serl_launcher/wrappers/mujoco.py | 2 +- serl_launcher/serl_launcher/wrappers/norm.py | 2 +- serl_launcher/serl_launcher/wrappers/remap.py | 4 ++-- serl_launcher/serl_launcher/wrappers/roboverse.py | 2 +- .../serl_launcher/wrappers/serl_obs_wrappers.py | 4 ++-- .../serl_launcher/wrappers/video_recorder.py | 2 +- serl_launcher/setup.py | 2 +- serl_robot_infra/README.md | 2 +- serl_robot_infra/franka_env/__init__.py | 2 +- .../bin_relocation_env/franka_bin_relocation.py | 2 +- serl_robot_infra/franka_env/envs/franka_env.py | 2 +- .../franka_env/envs/pcb_env/franka_pcb_insert.py | 2 +- .../franka_env/envs/peg_env/franka_peg_insert.py | 2 +- serl_robot_infra/franka_env/envs/relative_env.py | 2 +- serl_robot_infra/franka_env/envs/wrappers.py | 6 +++--- serl_robot_infra/setup.py | 3 ++- 49 files changed, 76 insertions(+), 66 deletions(-) diff --git a/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py b/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py index 9aa48b07..fe1039e5 100644 --- a/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py +++ b/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py @@ -11,8 +11,8 @@ from copy import deepcopy from collections import OrderedDict -import gymnasium as gym -from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics +import gym +from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics from serl_launcher.agents.continuous.drq import DrQAgent from serl_launcher.common.evaluation import evaluate diff --git a/examples/async_bin_relocation_fwbw_drq/record_bc_demos.py b/examples/async_bin_relocation_fwbw_drq/record_bc_demos.py index 1466bd7a..0636cf89 100644 --- a/examples/async_bin_relocation_fwbw_drq/record_bc_demos.py +++ b/examples/async_bin_relocation_fwbw_drq/record_bc_demos.py @@ -1,4 +1,4 @@ -import gymnasium as gym +import gym from tqdm import tqdm import numpy as np import copy diff --git a/examples/async_bin_relocation_fwbw_drq/record_demo.py b/examples/async_bin_relocation_fwbw_drq/record_demo.py index 8d33a3ca..73ccd817 100644 --- a/examples/async_bin_relocation_fwbw_drq/record_demo.py +++ b/examples/async_bin_relocation_fwbw_drq/record_demo.py @@ -1,4 +1,4 @@ -import gymnasium as gym +import gym from tqdm import tqdm import numpy as np import copy diff --git a/examples/async_bin_relocation_fwbw_drq/record_transitions.py b/examples/async_bin_relocation_fwbw_drq/record_transitions.py index e61afd43..8b7f4dbe 100644 --- a/examples/async_bin_relocation_fwbw_drq/record_transitions.py +++ b/examples/async_bin_relocation_fwbw_drq/record_transitions.py @@ -7,7 +7,7 @@ add `--record_failed_only` to only record failed transitions """ -import gymnasium as gym +import gym from tqdm import tqdm import numpy as np import copy diff --git a/examples/async_bin_relocation_fwbw_drq/test_classifier.py b/examples/async_bin_relocation_fwbw_drq/test_classifier.py index f54d74dc..c1236c03 100644 --- a/examples/async_bin_relocation_fwbw_drq/test_classifier.py +++ b/examples/async_bin_relocation_fwbw_drq/test_classifier.py @@ -1,4 +1,4 @@ -import gymnasium as gym +import gym from tqdm import tqdm import numpy as np import copy diff --git a/examples/async_bin_relocation_fwbw_drq/train_reward_classifier.py b/examples/async_bin_relocation_fwbw_drq/train_reward_classifier.py index 77ec14f4..511201c7 100644 --- a/examples/async_bin_relocation_fwbw_drq/train_reward_classifier.py +++ b/examples/async_bin_relocation_fwbw_drq/train_reward_classifier.py @@ -7,7 +7,7 @@ from flax.training import checkpoints import optax from tqdm import tqdm -import gymnasium as gym +import gym import os from absl import app, flags diff --git a/examples/async_cable_route_drq/async_drq_randomized.py b/examples/async_cable_route_drq/async_drq_randomized.py index f53ceb36..fe1c03d2 100644 --- a/examples/async_cable_route_drq/async_drq_randomized.py +++ b/examples/async_cable_route_drq/async_drq_randomized.py @@ -9,8 +9,8 @@ from absl import app, flags from flax.training import checkpoints -import gymnasium as gym -from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics +import gym +from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics from serl_launcher.agents.continuous.drq import DrQAgent from serl_launcher.common.evaluation import evaluate diff --git a/examples/async_cable_route_drq/record_demo.py b/examples/async_cable_route_drq/record_demo.py index 4ca6d5fd..f359b28b 100644 --- a/examples/async_cable_route_drq/record_demo.py +++ b/examples/async_cable_route_drq/record_demo.py @@ -1,4 +1,4 @@ -import gymnasium as gym +import gym from tqdm import tqdm import numpy as np import copy diff --git a/examples/async_cable_route_drq/test_classifier.py b/examples/async_cable_route_drq/test_classifier.py index 30ae47d6..8bec2573 100644 --- a/examples/async_cable_route_drq/test_classifier.py +++ b/examples/async_cable_route_drq/test_classifier.py @@ -1,4 +1,4 @@ -import gymnasium as gym +import gym from tqdm import tqdm import numpy as np import copy diff --git a/examples/async_cable_route_drq/train_reward_classifier.py b/examples/async_cable_route_drq/train_reward_classifier.py index 54e5b95f..8c6ebf32 100644 --- a/examples/async_cable_route_drq/train_reward_classifier.py +++ b/examples/async_cable_route_drq/train_reward_classifier.py @@ -5,7 +5,7 @@ from flax.training import checkpoints import optax from tqdm import tqdm -import gymnasium as gym +import gym from serl_launcher.wrappers.chunking import ChunkingWrapper from serl_launcher.utils.train_utils import concat_batches diff --git a/examples/async_drq_sim/async_drq_sim.py b/examples/async_drq_sim/async_drq_sim.py index f60dede3..5de34e99 100644 --- a/examples/async_drq_sim/async_drq_sim.py +++ b/examples/async_drq_sim/async_drq_sim.py @@ -9,8 +9,8 @@ from absl import app, flags from flax.training import checkpoints -import gymnasium as gym -from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics +import gym +from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics from serl_launcher.agents.continuous.drq import DrQAgent from serl_launcher.common.evaluation import evaluate diff --git a/examples/async_pcb_insert_drq/async_drq_randomized.py b/examples/async_pcb_insert_drq/async_drq_randomized.py index 2798b6ff..8f9607c2 100644 --- a/examples/async_pcb_insert_drq/async_drq_randomized.py +++ b/examples/async_pcb_insert_drq/async_drq_randomized.py @@ -9,8 +9,8 @@ from absl import app, flags from flax.training import checkpoints -import gymnasium as gym -from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics +import gym +from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics from serl_launcher.agents.continuous.drq import DrQAgent from serl_launcher.common.evaluation import evaluate diff --git a/examples/async_pcb_insert_drq/record_demo.py b/examples/async_pcb_insert_drq/record_demo.py index 38dbda14..8413fcba 100644 --- a/examples/async_pcb_insert_drq/record_demo.py +++ b/examples/async_pcb_insert_drq/record_demo.py @@ -1,4 +1,4 @@ -import gymnasium as gym +import gym from tqdm import tqdm import numpy as np import copy diff --git a/examples/async_peg_insert_drq/async_drq_randomized.py b/examples/async_peg_insert_drq/async_drq_randomized.py index 20ac5056..4c6c109e 100644 --- a/examples/async_peg_insert_drq/async_drq_randomized.py +++ b/examples/async_peg_insert_drq/async_drq_randomized.py @@ -9,8 +9,8 @@ from absl import app, flags from flax.training import checkpoints -import gymnasium as gym -from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics +import gym +from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics from serl_launcher.agents.continuous.drq import DrQAgent from serl_launcher.common.evaluation import evaluate diff --git a/examples/async_peg_insert_drq/record_demo.py b/examples/async_peg_insert_drq/record_demo.py index ccdb4070..bf88e8b9 100644 --- a/examples/async_peg_insert_drq/record_demo.py +++ b/examples/async_peg_insert_drq/record_demo.py @@ -1,4 +1,4 @@ -import gymnasium as gym +import gym from tqdm import tqdm import numpy as np import copy diff --git a/examples/async_rlpd_drq_sim/async_rlpd_drq_sim.py b/examples/async_rlpd_drq_sim/async_rlpd_drq_sim.py index 5f22e0da..79a24cac 100644 --- a/examples/async_rlpd_drq_sim/async_rlpd_drq_sim.py +++ b/examples/async_rlpd_drq_sim/async_rlpd_drq_sim.py @@ -9,8 +9,8 @@ from absl import app, flags from flax.training import checkpoints -import gymnasium as gym -from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics +import gym +from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics from serl_launcher.agents.continuous.drq import DrQAgent from serl_launcher.common.evaluation import evaluate diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index 90b96d63..9e93f4e8 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -3,7 +3,7 @@ import time from functools import partial -import gymnasium as gym +import gym import jax import jax.numpy as jnp import numpy as np @@ -20,7 +20,7 @@ make_replay_buffer, ) -from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics +from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics from serl_launcher.agents.continuous.sac import SACAgent from serl_launcher.common.evaluation import evaluate from serl_launcher.utils.timer_utils import Timer diff --git a/examples/bc_policy.py b/examples/bc_policy.py index 007384c3..fe3bf747 100644 --- a/examples/bc_policy.py +++ b/examples/bc_policy.py @@ -8,8 +8,8 @@ from copy import deepcopy import time -import gymnasium as gym -from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics +import gym +from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics from serl_launcher.utils.timer_utils import Timer from serl_launcher.wrappers.chunking import ChunkingWrapper diff --git a/franka_sim/franka_sim/__init__.py b/franka_sim/franka_sim/__init__.py index afef4ca4..967e9e5d 100644 --- a/franka_sim/franka_sim/__init__.py +++ b/franka_sim/franka_sim/__init__.py @@ -5,7 +5,7 @@ "GymRenderingSpec", ] -from gymnasium.envs.registration import register +from gym.envs.registration import register register( id="PandaPickCube-v0", diff --git a/franka_sim/franka_sim/envs/panda_pick_gym_env.py b/franka_sim/franka_sim/envs/panda_pick_gym_env.py index 57eb44d1..c93c7c8f 100644 --- a/franka_sim/franka_sim/envs/panda_pick_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_pick_gym_env.py @@ -1,10 +1,17 @@ from pathlib import Path from typing import Any, Literal, Tuple, Dict -import gymnasium as gym +import gym import mujoco import numpy as np -from gymnasium import spaces +from gym import spaces + +try: + import mujoco_py +except ImportError as e: + MUJOCO_PY_IMPORT_ERROR = e +else: + MUJOCO_PY_IMPORT_ERROR = None from franka_sim.controllers import opspace from franka_sim.mujoco_gym_env import GymRenderingSpec, MujocoGymEnv @@ -130,8 +137,9 @@ def __init__( dtype=np.float32, ) + # NOTE: gymnasium is used here since MujocoRenderer is not available in gym. It + # is possible to add a similar viewer feature with gym, but that can be a future TODO from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer - self._viewer = MujocoRenderer( self.model, self.data, diff --git a/franka_sim/franka_sim/mujoco_gym_env.py b/franka_sim/franka_sim/mujoco_gym_env.py index a0acec74..2c085937 100644 --- a/franka_sim/franka_sim/mujoco_gym_env.py +++ b/franka_sim/franka_sim/mujoco_gym_env.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Literal, Optional -import gymnasium as gym +import gym import mujoco import numpy as np diff --git a/franka_sim/franka_sim/test/test_gym_env_render.py b/franka_sim/franka_sim/test/test_gym_env_render.py index cc059343..bc706f43 100644 --- a/franka_sim/franka_sim/test/test_gym_env_render.py +++ b/franka_sim/franka_sim/test/test_gym_env_render.py @@ -1,6 +1,6 @@ import time -import gymnasium as gym +import gym import mujoco import mujoco.viewer import numpy as np diff --git a/franka_sim/requirements.txt b/franka_sim/requirements.txt index 06dec758..1c7bae24 100644 --- a/franka_sim/requirements.txt +++ b/franka_sim/requirements.txt @@ -1,5 +1,6 @@ dm_env mujoco==2.3.7 +gym >= 0.26 gymnasium dm-robotics-transformations imageio[ffmpeg] diff --git a/serl_launcher/requirements.txt b/serl_launcher/requirements.txt index 6b29152c..48e2d6fa 100644 --- a/serl_launcher/requirements.txt +++ b/serl_launcher/requirements.txt @@ -8,9 +8,9 @@ tqdm >= 4.60.0 chex==0.1.85 optax==0.1.5 absl-py >= 0.12.0 -scipy >= 1.6.0 +scipy <= 1.12.0 wandb >= 0.12.14 -tensorflow==2.15.0 +tensorflow>=2.16.0 tensorflow_probability>=0.23.0 einops >= 0.6.1 imageio >= 2.31.1 diff --git a/serl_launcher/serl_launcher/common/evaluation.py b/serl_launcher/serl_launcher/common/evaluation.py index 7bde12a0..5065a5b1 100644 --- a/serl_launcher/serl_launcher/common/evaluation.py +++ b/serl_launcher/serl_launcher/common/evaluation.py @@ -2,7 +2,7 @@ from collections import defaultdict from typing import Dict -import gymnasium as gym +import gym import jax import numpy as np diff --git a/serl_launcher/serl_launcher/data/data_store.py b/serl_launcher/serl_launcher/data/data_store.py index e9f99969..20e65813 100644 --- a/serl_launcher/serl_launcher/data/data_store.py +++ b/serl_launcher/serl_launcher/data/data_store.py @@ -1,7 +1,7 @@ from threading import Lock from typing import Union, Iterable -import gymnasium as gym +import gym import jax from serl_launcher.data.replay_buffer import ReplayBuffer from serl_launcher.data.memory_efficient_replay_buffer import ( diff --git a/serl_launcher/serl_launcher/data/dataset.py b/serl_launcher/serl_launcher/data/dataset.py index 5fff673c..0760fb02 100644 --- a/serl_launcher/serl_launcher/data/dataset.py +++ b/serl_launcher/serl_launcher/data/dataset.py @@ -5,7 +5,7 @@ import jax.numpy as jnp import numpy as np from flax.core import frozen_dict -from gymnasium.utils import seeding +from gym.utils import seeding DataType = Union[np.ndarray, Dict[str, "DataType"]] DatasetDict = Dict[str, DataType] diff --git a/serl_launcher/serl_launcher/data/memory_efficient_replay_buffer.py b/serl_launcher/serl_launcher/data/memory_efficient_replay_buffer.py index 80f8d461..d94f1143 100644 --- a/serl_launcher/serl_launcher/data/memory_efficient_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/memory_efficient_replay_buffer.py @@ -1,12 +1,12 @@ import copy from typing import Iterable, Optional, Tuple -import gymnasium as gym +import gym import numpy as np from serl_launcher.data.dataset import DatasetDict, _sample from serl_launcher.data.replay_buffer import ReplayBuffer from flax.core import frozen_dict -from gymnasium.spaces import Box +from gym.spaces import Box class MemoryEfficientReplayBuffer(ReplayBuffer): diff --git a/serl_launcher/serl_launcher/data/replay_buffer.py b/serl_launcher/serl_launcher/data/replay_buffer.py index a217ef7b..62acaf09 100644 --- a/serl_launcher/serl_launcher/data/replay_buffer.py +++ b/serl_launcher/serl_launcher/data/replay_buffer.py @@ -1,7 +1,7 @@ import collections from typing import Any, Iterator, Optional, Sequence, Tuple, Union -import gymnasium as gym +import gym import jax import numpy as np from serl_launcher.data.dataset import Dataset, DatasetDict diff --git a/serl_launcher/serl_launcher/utils/sim_utils.py b/serl_launcher/serl_launcher/utils/sim_utils.py index 0ed2c1ea..b32ebb1b 100644 --- a/serl_launcher/serl_launcher/utils/sim_utils.py +++ b/serl_launcher/serl_launcher/utils/sim_utils.py @@ -1,6 +1,6 @@ from typing import Callable, Union -import gymnasium as gym +import gym import numpy as np try: diff --git a/serl_launcher/serl_launcher/wrappers/chunking.py b/serl_launcher/serl_launcher/wrappers/chunking.py index 4c2499fd..7007d744 100644 --- a/serl_launcher/serl_launcher/wrappers/chunking.py +++ b/serl_launcher/serl_launcher/wrappers/chunking.py @@ -1,8 +1,8 @@ from collections import deque from typing import Optional -import gymnasium as gym -import gymnasium.spaces +import gym +import gym.spaces import jax import numpy as np diff --git a/serl_launcher/serl_launcher/wrappers/dmcgym.py b/serl_launcher/serl_launcher/wrappers/dmcgym.py index b8c62a66..f42003a0 100644 --- a/serl_launcher/serl_launcher/wrappers/dmcgym.py +++ b/serl_launcher/serl_launcher/wrappers/dmcgym.py @@ -6,7 +6,7 @@ from typing import OrderedDict import dm_env -import gymnasium as gym +import gym import numpy as np from gym import spaces diff --git a/serl_launcher/serl_launcher/wrappers/front_camera_wrapper.py b/serl_launcher/serl_launcher/wrappers/front_camera_wrapper.py index bcacb438..32e2629e 100644 --- a/serl_launcher/serl_launcher/wrappers/front_camera_wrapper.py +++ b/serl_launcher/serl_launcher/wrappers/front_camera_wrapper.py @@ -1,5 +1,5 @@ -import gymnasium as gym -from gymnasium.core import Env +import gym +from gym.core import Env from copy import deepcopy diff --git a/serl_launcher/serl_launcher/wrappers/mujoco.py b/serl_launcher/serl_launcher/wrappers/mujoco.py index d9c22bc0..7d9bc89a 100644 --- a/serl_launcher/serl_launcher/wrappers/mujoco.py +++ b/serl_launcher/serl_launcher/wrappers/mujoco.py @@ -1,6 +1,6 @@ from typing import Callable, Union -import gymnasium as gym +import gym import numpy as np from scipy.spatial.transform import Rotation diff --git a/serl_launcher/serl_launcher/wrappers/norm.py b/serl_launcher/serl_launcher/wrappers/norm.py index 6b7d54f3..2021ebe4 100644 --- a/serl_launcher/serl_launcher/wrappers/norm.py +++ b/serl_launcher/serl_launcher/wrappers/norm.py @@ -1,4 +1,4 @@ -import gymnasium as gym +import gym class UnnormalizeActionProprio(gym.ActionWrapper, gym.ObservationWrapper): diff --git a/serl_launcher/serl_launcher/wrappers/remap.py b/serl_launcher/serl_launcher/wrappers/remap.py index 9a986d74..7acb2d93 100644 --- a/serl_launcher/serl_launcher/wrappers/remap.py +++ b/serl_launcher/serl_launcher/wrappers/remap.py @@ -1,7 +1,7 @@ from typing import Any -import gymnasium as gym -import gymnasium.spaces +import gym +import gym.spaces import jax diff --git a/serl_launcher/serl_launcher/wrappers/roboverse.py b/serl_launcher/serl_launcher/wrappers/roboverse.py index 0068ce1f..835749b8 100644 --- a/serl_launcher/serl_launcher/wrappers/roboverse.py +++ b/serl_launcher/serl_launcher/wrappers/roboverse.py @@ -1,6 +1,6 @@ from typing import Callable, Union -import gymnasium as gym +import gym import numpy as np diff --git a/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py b/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py index 6e9a1e46..41c169f9 100644 --- a/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py +++ b/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py @@ -1,5 +1,5 @@ -import gymnasium as gym -from gymnasium.spaces import flatten_space, flatten +import gym +from gym.spaces import flatten_space, flatten class SERLObsWrapper(gym.ObservationWrapper): diff --git a/serl_launcher/serl_launcher/wrappers/video_recorder.py b/serl_launcher/serl_launcher/wrappers/video_recorder.py index e4a7763b..658417db 100644 --- a/serl_launcher/serl_launcher/wrappers/video_recorder.py +++ b/serl_launcher/serl_launcher/wrappers/video_recorder.py @@ -1,7 +1,7 @@ import os from typing import List, Optional -import gymnasium as gym +import gym import imageio import numpy as np import tensorflow as tf diff --git a/serl_launcher/setup.py b/serl_launcher/setup.py index 09557a59..e80e8eba 100644 --- a/serl_launcher/setup.py +++ b/serl_launcher/setup.py @@ -13,7 +13,7 @@ "typing_extensions", "opencv-python", "lz4", - "agentlace@git+https://github.com/youliangtan/agentlace.git@e61032fbce8a1e6d3dc2aeba21de082e4bf46fe3", + "agentlace@git+https://github.com/youliangtan/agentlace.git@892d1557264d7bb1d5df04b37638c850c9d36f35", ], packages=find_packages(), zip_safe=False, diff --git a/serl_robot_infra/README.md b/serl_robot_infra/README.md index 044cafd8..fc26c3a3 100644 --- a/serl_robot_infra/README.md +++ b/serl_robot_infra/README.md @@ -83,7 +83,7 @@ Lastly, we use a gym env interface to interact with the robot server, defined in Example Usage ```py -import gymnasium as gym +import gym import franka_env env = gym.make("FrankaEnv-Vision-v0") ``` diff --git a/serl_robot_infra/franka_env/__init__.py b/serl_robot_infra/franka_env/__init__.py index 06d50009..fad5a239 100644 --- a/serl_robot_infra/franka_env/__init__.py +++ b/serl_robot_infra/franka_env/__init__.py @@ -1,4 +1,4 @@ -from gymnasium.envs.registration import register +from gym.envs.registration import register import numpy as np register( diff --git a/serl_robot_infra/franka_env/envs/bin_relocation_env/franka_bin_relocation.py b/serl_robot_infra/franka_env/envs/bin_relocation_env/franka_bin_relocation.py index 86bc6ad8..cc757058 100644 --- a/serl_robot_infra/franka_env/envs/bin_relocation_env/franka_bin_relocation.py +++ b/serl_robot_infra/franka_env/envs/bin_relocation_env/franka_bin_relocation.py @@ -4,7 +4,7 @@ import copy import cv2 import queue -import gymnasium as gym +import gym from franka_env.envs.franka_env import FrankaEnv from franka_env.utils.rotations import euler_2_quat diff --git a/serl_robot_infra/franka_env/envs/franka_env.py b/serl_robot_infra/franka_env/envs/franka_env.py index f21d4826..976534d7 100644 --- a/serl_robot_infra/franka_env/envs/franka_env.py +++ b/serl_robot_infra/franka_env/envs/franka_env.py @@ -1,6 +1,6 @@ """Gym Interface for Franka""" import numpy as np -import gymnasium as gym +import gym import cv2 import copy from scipy.spatial.transform import Rotation diff --git a/serl_robot_infra/franka_env/envs/pcb_env/franka_pcb_insert.py b/serl_robot_infra/franka_env/envs/pcb_env/franka_pcb_insert.py index 87dd569a..46aa11d1 100644 --- a/serl_robot_infra/franka_env/envs/pcb_env/franka_pcb_insert.py +++ b/serl_robot_infra/franka_env/envs/pcb_env/franka_pcb_insert.py @@ -1,5 +1,5 @@ import numpy as np -import gymnasium as gym +import gym import time import requests import copy diff --git a/serl_robot_infra/franka_env/envs/peg_env/franka_peg_insert.py b/serl_robot_infra/franka_env/envs/peg_env/franka_peg_insert.py index 2c1336b7..07f98e48 100644 --- a/serl_robot_infra/franka_env/envs/peg_env/franka_peg_insert.py +++ b/serl_robot_infra/franka_env/envs/peg_env/franka_peg_insert.py @@ -1,5 +1,5 @@ import numpy as np -import gymnasium as gym +import gym import time import requests import copy diff --git a/serl_robot_infra/franka_env/envs/relative_env.py b/serl_robot_infra/franka_env/envs/relative_env.py index 75b15bf6..8cad0dac 100644 --- a/serl_robot_infra/franka_env/envs/relative_env.py +++ b/serl_robot_infra/franka_env/envs/relative_env.py @@ -1,5 +1,5 @@ from scipy.spatial.transform import Rotation as R -import gymnasium as gym +import gym import numpy as np from gym import Env from franka_env.utils.transformations import ( diff --git a/serl_robot_infra/franka_env/envs/wrappers.py b/serl_robot_infra/franka_env/envs/wrappers.py index 38422974..7568348c 100644 --- a/serl_robot_infra/franka_env/envs/wrappers.py +++ b/serl_robot_infra/franka_env/envs/wrappers.py @@ -1,8 +1,8 @@ import time -from gymnasium import Env, spaces -import gymnasium as gym +from gym import Env, spaces +import gym import numpy as np -from gymnasium.spaces import Box +from gym.spaces import Box import copy from franka_env.spacemouse.spacemouse_expert import SpaceMouseExpert from franka_env.utils.rotations import quat_2_euler diff --git a/serl_robot_infra/setup.py b/serl_robot_infra/setup.py index b2e63bb4..a788a6e5 100644 --- a/serl_robot_infra/setup.py +++ b/serl_robot_infra/setup.py @@ -5,7 +5,7 @@ version="0.0.1", packages=find_packages(), install_requires=[ - "gymnasium", + "gym>=0.26", "pyrealsense2", "pymodbus==2.5.3", "opencv-python", @@ -18,5 +18,6 @@ "requests", "flask", "defusedxml", + "pynput", ], ) From e6c0edb97d834552f5aa3f590e04e1aefcbd3917 Mon Sep 17 00:00:00 2001 From: youliang Date: Sun, 5 May 2024 17:03:36 -0700 Subject: [PATCH 5/5] fix style Signed-off-by: youliang --- franka_sim/franka_sim/envs/panda_pick_gym_env.py | 1 + 1 file changed, 1 insertion(+) diff --git a/franka_sim/franka_sim/envs/panda_pick_gym_env.py b/franka_sim/franka_sim/envs/panda_pick_gym_env.py index c93c7c8f..f0a8d25b 100644 --- a/franka_sim/franka_sim/envs/panda_pick_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_pick_gym_env.py @@ -140,6 +140,7 @@ def __init__( # NOTE: gymnasium is used here since MujocoRenderer is not available in gym. It # is possible to add a similar viewer feature with gym, but that can be a future TODO from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer + self._viewer = MujocoRenderer( self.model, self.data,