Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/gym agentlace deps #45

Merged
merged 6 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from copy import deepcopy
from collections import OrderedDict

import gymnasium as gym
from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics
import gym
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics

from serl_launcher.agents.continuous.drq import DrQAgent
from serl_launcher.common.evaluation import evaluate
Expand Down
2 changes: 1 addition & 1 deletion examples/async_bin_relocation_fwbw_drq/record_bc_demos.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import gymnasium as gym
import gym
from tqdm import tqdm
import numpy as np
import copy
Expand Down
2 changes: 1 addition & 1 deletion examples/async_bin_relocation_fwbw_drq/record_demo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import gymnasium as gym
import gym
from tqdm import tqdm
import numpy as np
import copy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
add `--record_failed_only` to only record failed transitions
"""

import gymnasium as gym
import gym
from tqdm import tqdm
import numpy as np
import copy
Expand Down
2 changes: 1 addition & 1 deletion examples/async_bin_relocation_fwbw_drq/test_classifier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import gymnasium as gym
import gym
from tqdm import tqdm
import numpy as np
import copy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from flax.training import checkpoints
import optax
from tqdm import tqdm
import gymnasium as gym
import gym
import os
from absl import app, flags

Expand Down
4 changes: 2 additions & 2 deletions examples/async_cable_route_drq/async_drq_randomized.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from absl import app, flags
from flax.training import checkpoints

import gymnasium as gym
from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics
import gym
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics

from serl_launcher.agents.continuous.drq import DrQAgent
from serl_launcher.common.evaluation import evaluate
Expand Down
2 changes: 1 addition & 1 deletion examples/async_cable_route_drq/record_demo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import gymnasium as gym
import gym
from tqdm import tqdm
import numpy as np
import copy
Expand Down
2 changes: 1 addition & 1 deletion examples/async_cable_route_drq/test_classifier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import gymnasium as gym
import gym
from tqdm import tqdm
import numpy as np
import copy
Expand Down
2 changes: 1 addition & 1 deletion examples/async_cable_route_drq/train_reward_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from flax.training import checkpoints
import optax
from tqdm import tqdm
import gymnasium as gym
import gym

from serl_launcher.wrappers.chunking import ChunkingWrapper
from serl_launcher.utils.train_utils import concat_batches
Expand Down
4 changes: 2 additions & 2 deletions examples/async_drq_sim/async_drq_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from absl import app, flags
from flax.training import checkpoints

import gymnasium as gym
from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics
import gym
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics

from serl_launcher.agents.continuous.drq import DrQAgent
from serl_launcher.common.evaluation import evaluate
Expand Down
4 changes: 2 additions & 2 deletions examples/async_pcb_insert_drq/async_drq_randomized.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from flax.training import checkpoints
import threading

import gymnasium as gym
from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics
import gym
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics

from serl_launcher.agents.continuous.drq import DrQAgent
from serl_launcher.common.evaluation import evaluate
Expand Down
2 changes: 1 addition & 1 deletion examples/async_pcb_insert_drq/record_demo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import gymnasium as gym
import gym
from tqdm import tqdm
import numpy as np
import copy
Expand Down
4 changes: 2 additions & 2 deletions examples/async_peg_insert_drq/async_drq_randomized.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from absl import app, flags
from flax.training import checkpoints

import gymnasium as gym
from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics
import gym
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics

from serl_launcher.agents.continuous.drq import DrQAgent
from serl_launcher.common.evaluation import evaluate
Expand Down
2 changes: 1 addition & 1 deletion examples/async_peg_insert_drq/record_demo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import gymnasium as gym
import gym
from tqdm import tqdm
import numpy as np
import copy
Expand Down
4 changes: 2 additions & 2 deletions examples/async_rlpd_drq_sim/async_rlpd_drq_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from absl import app, flags
from flax.training import checkpoints

import gymnasium as gym
from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics
import gym
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics

from serl_launcher.agents.continuous.drq import DrQAgent
from serl_launcher.common.evaluation import evaluate
Expand Down
4 changes: 2 additions & 2 deletions examples/async_sac_state_sim/async_sac_state_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
from functools import partial

import gymnasium as gym
import gym
import jax
import jax.numpy as jnp
import numpy as np
Expand All @@ -20,7 +20,7 @@
make_replay_buffer,
)

from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
from serl_launcher.agents.continuous.sac import SACAgent
from serl_launcher.common.evaluation import evaluate
from serl_launcher.utils.timer_utils import Timer
Expand Down
4 changes: 2 additions & 2 deletions examples/bc_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from copy import deepcopy
import time

import gymnasium as gym
from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics
import gym
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics

from serl_launcher.utils.timer_utils import Timer
from serl_launcher.wrappers.chunking import ChunkingWrapper
Expand Down
2 changes: 1 addition & 1 deletion franka_sim/franka_sim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"GymRenderingSpec",
]

from gymnasium.envs.registration import register
from gym.envs.registration import register

register(
id="PandaPickCube-v0",
Expand Down
13 changes: 11 additions & 2 deletions franka_sim/franka_sim/envs/panda_pick_gym_env.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from pathlib import Path
from typing import Any, Literal, Tuple, Dict

import gymnasium as gym
import gym
import mujoco
import numpy as np
from gymnasium import spaces
from gym import spaces

try:
import mujoco_py
except ImportError as e:
MUJOCO_PY_IMPORT_ERROR = e
else:
MUJOCO_PY_IMPORT_ERROR = None

from franka_sim.controllers import opspace
from franka_sim.mujoco_gym_env import GymRenderingSpec, MujocoGymEnv
Expand Down Expand Up @@ -130,6 +137,8 @@ def __init__(
dtype=np.float32,
)

# NOTE: gymnasium is still imported here because MujocoRenderer is not available in gym.
# TODO: a similar viewer could be implemented on top of gym so this dependency can be dropped.
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer

self._viewer = MujocoRenderer(
Expand Down
2 changes: 1 addition & 1 deletion franka_sim/franka_sim/mujoco_gym_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path
from typing import Literal, Optional

import gymnasium as gym
import gym
import mujoco
import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion franka_sim/franka_sim/test/test_gym_env_render.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time

import gymnasium as gym
import gym
import mujoco
import mujoco.viewer
import numpy as np
Expand Down
1 change: 1 addition & 0 deletions franka_sim/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
dm_env
mujoco==2.3.7
gym >= 0.26
gymnasium
dm-robotics-transformations
imageio[ffmpeg]
Expand Down
4 changes: 2 additions & 2 deletions serl_launcher/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ tqdm >= 4.60.0
chex==0.1.85
optax==0.1.5
absl-py >= 0.12.0
scipy >= 1.6.0
scipy <= 1.12.0
wandb >= 0.12.14
tensorflow==2.15.0
tensorflow>=2.16.0
tensorflow_probability>=0.23.0
einops >= 0.6.1
imageio >= 2.31.1
Expand Down
2 changes: 1 addition & 1 deletion serl_launcher/serl_launcher/common/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict
from typing import Dict

import gymnasium as gym
import gym
import jax
import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion serl_launcher/serl_launcher/data/data_store.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from threading import Lock
from typing import Union, Iterable

import gymnasium as gym
import gym
import jax
from serl_launcher.data.replay_buffer import ReplayBuffer
from serl_launcher.data.memory_efficient_replay_buffer import (
Expand Down
2 changes: 1 addition & 1 deletion serl_launcher/serl_launcher/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax.numpy as jnp
import numpy as np
from flax.core import frozen_dict
from gymnasium.utils import seeding
from gym.utils import seeding

DataType = Union[np.ndarray, Dict[str, "DataType"]]
DatasetDict = Dict[str, DataType]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import copy
from typing import Iterable, Optional, Tuple

import gymnasium as gym
import gym
import numpy as np
from serl_launcher.data.dataset import DatasetDict, _sample
from serl_launcher.data.replay_buffer import ReplayBuffer
from flax.core import frozen_dict
from gymnasium.spaces import Box
from gym.spaces import Box


class MemoryEfficientReplayBuffer(ReplayBuffer):
Expand Down
2 changes: 1 addition & 1 deletion serl_launcher/serl_launcher/data/replay_buffer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import collections
from typing import Any, Iterator, Optional, Sequence, Tuple, Union

import gymnasium as gym
import gym
import jax
import numpy as np
from serl_launcher.data.dataset import Dataset, DatasetDict
Expand Down
2 changes: 1 addition & 1 deletion serl_launcher/serl_launcher/utils/sim_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Callable, Union

import gymnasium as gym
import gym
import numpy as np

try:
Expand Down
4 changes: 2 additions & 2 deletions serl_launcher/serl_launcher/wrappers/chunking.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from collections import deque
from typing import Optional

import gymnasium as gym
import gymnasium.spaces
import gym
import gym.spaces
import jax
import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion serl_launcher/serl_launcher/wrappers/dmcgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import OrderedDict

import dm_env
import gymnasium as gym
import gym
import numpy as np
from gym import spaces

Expand Down
4 changes: 2 additions & 2 deletions serl_launcher/serl_launcher/wrappers/front_camera_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import gymnasium as gym
from gymnasium.core import Env
import gym
from gym.core import Env
from copy import deepcopy


Expand Down
2 changes: 1 addition & 1 deletion serl_launcher/serl_launcher/wrappers/mujoco.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Callable, Union

import gymnasium as gym
import gym
import numpy as np
from scipy.spatial.transform import Rotation

Expand Down
2 changes: 1 addition & 1 deletion serl_launcher/serl_launcher/wrappers/norm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import gymnasium as gym
import gym


class UnnormalizeActionProprio(gym.ActionWrapper, gym.ObservationWrapper):
Expand Down
4 changes: 2 additions & 2 deletions serl_launcher/serl_launcher/wrappers/remap.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any

import gymnasium as gym
import gymnasium.spaces
import gym
import gym.spaces
import jax


Expand Down
2 changes: 1 addition & 1 deletion serl_launcher/serl_launcher/wrappers/roboverse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Callable, Union

import gymnasium as gym
import gym
import numpy as np


Expand Down
4 changes: 2 additions & 2 deletions serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import gymnasium as gym
from gymnasium.spaces import flatten_space, flatten
import gym
from gym.spaces import flatten_space, flatten


class SERLObsWrapper(gym.ObservationWrapper):
Expand Down
2 changes: 1 addition & 1 deletion serl_launcher/serl_launcher/wrappers/video_recorder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from typing import List, Optional

import gymnasium as gym
import gym
import imageio
import numpy as np
import tensorflow as tf
Expand Down
2 changes: 1 addition & 1 deletion serl_launcher/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"typing_extensions",
"opencv-python",
"lz4",
"agentlace@git+https://github.com/youliangtan/agentlace.git@e61032fbce8a1e6d3dc2aeba21de082e4bf46fe3",
"agentlace@git+https://github.com/youliangtan/agentlace.git@892d1557264d7bb1d5df04b37638c850c9d36f35",
],
packages=find_packages(),
zip_safe=False,
Expand Down
2 changes: 1 addition & 1 deletion serl_robot_infra/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ Lastly, we use a gym env interface to interact with the robot server, defined in

Example Usage
```py
import gymnasium as gym
import gym
import franka_env
env = gym.make("FrankaEnv-Vision-v0")
```
Expand Down
2 changes: 1 addition & 1 deletion serl_robot_infra/franka_env/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from gymnasium.envs.registration import register
from gym.envs.registration import register
import numpy as np

register(
Expand Down
Loading
Loading