Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix missing joint_position_action and add gripper action #221

Merged
merged 9 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/task_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,5 @@ jobs:
export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT

pip install ".[dev]"
python3 -m unittest discover tests/demos
pip install "pytest-xdist[psutil]"
pytest -v -n auto tests/unit
3 changes: 2 additions & 1 deletion .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@ jobs:
export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT

pip install ".[dev]"
python3 -m unittest discover tests/unit
pip install "pytest-xdist[psutil]"
pytest -v -n auto tests/unit
6 changes: 0 additions & 6 deletions requirements.txt

This file was deleted.

36 changes: 20 additions & 16 deletions rlbench/backend/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(self,

self._robot_shapes = self.robot.arm.get_objects_in_tree(
object_type=ObjectType.SHAPE)
self._execute_demo_joint_position_action = None
self._joint_position_action = None

def load(self, task: Task) -> None:
"""Loads the task and positions at the centre of the workspace.
Expand Down Expand Up @@ -337,6 +337,8 @@ def get_demo(self, record: bool = True,
demo = []
if record:
self.pyrep.step() # Need this here or get_force doesn't work...
self._joint_position_action = None
gripper_open = 1.0 if self.robot.gripper.get_open_amount()[0] > 0.9 else 0.0
demo.append(self.get_observation())
while True:
success = False
Expand Down Expand Up @@ -366,7 +368,7 @@ def get_demo(self, record: bool = True,
while not done:
done = path.step()
self.step()
self._execute_demo_joint_position_action = path.get_executed_joint_position_action()
self._joint_position_action = np.append(path.get_executed_joint_position_action(), gripper_open)
self._demo_record_step(demo, record, callable_each_step)
success, term = self.task.success()

Expand All @@ -385,9 +387,10 @@ def get_demo(self, record: bool = True,
if not contains_param:
done = False
while not done:
done = gripper.actuate(1.0, 0.04)
self.pyrep.step()
self.task.step()
gripper_open = 1.0
done = gripper.actuate(gripper_open, 0.04)
self.step()
self._joint_position_action = np.append(path.get_executed_joint_position_action(), gripper_open)
if self._obs_config.record_gripper_closing:
self._demo_record_step(
demo, record, callable_each_step)
Expand All @@ -397,9 +400,10 @@ def get_demo(self, record: bool = True,
if not contains_param:
done = False
while not done:
done = gripper.actuate(0.0, 0.04)
self.pyrep.step()
self.task.step()
gripper_open = 0.0
done = gripper.actuate(gripper_open, 0.04)
self.step()
self._joint_position_action = np.append(path.get_executed_joint_position_action(), gripper_open)
if self._obs_config.record_gripper_closing:
self._demo_record_step(
demo, record, callable_each_step)
Expand All @@ -409,9 +413,10 @@ def get_demo(self, record: bool = True,
num = float(rest[:rest.index(')')])
done = False
while not done:
done = gripper.actuate(num, 0.04)
self.pyrep.step()
self.task.step()
gripper_open = num
done = gripper.actuate(gripper_open, 0.04)
self.step()
self._joint_position_action = np.append(path.get_executed_joint_position_action(), gripper_open)
if self._obs_config.record_gripper_closing:
self._demo_record_step(
demo, record, callable_each_step)
Expand All @@ -429,8 +434,8 @@ def get_demo(self, record: bool = True,
# (e.g. ball rowling to goal)
if not success:
for _ in range(10):
self.pyrep.step()
self.task.step()
self.step()
self._joint_position_action = np.append(path.get_executed_joint_position_action(), gripper_open)
self._demo_record_step(demo, record, callable_each_step)
success, term = self.task.success()
if success:
Expand Down Expand Up @@ -545,8 +550,7 @@ def _get_cam_data(cam: VisionSensor, name: str):
misc.update(_get_cam_data(self._cam_front, 'front_camera'))
misc.update(_get_cam_data(self._cam_wrist, 'wrist_camera'))
misc.update({"variation_index": self._variation_index})
if self._execute_demo_joint_position_action is not None:
if self._joint_position_action is not None:
# Store the actual requested joint positions during demo collection
misc.update({"executed_demo_joint_position_action": self._execute_demo_joint_position_action})
self._execute_demo_joint_position_action = None
misc.update({"joint_position_action": self._joint_position_action})
return misc
4 changes: 2 additions & 2 deletions setup.py
stepjam marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_version(rel_path):
raise RuntimeError("Unable to find version string.")

core_requirements = [
"pyrep @ git+https://github.com/stepjam/PyRep.git@076ca15c57f2495a4194da03565891ab1aaa317e",
"pyrep @ git+https://github.com/stepjam/PyRep.git@cd9830b58ef09538562b785fc0c257f528f1762b",
"numpy",
"Pillow",
"pyquaternion",
Expand All @@ -60,7 +60,7 @@ def get_version(rel_path):
'rlbench.gym'
],
extras_require={
"dev": ["html-testRunner", "gym"]
"dev": ["pytest", "html-testRunner", "gym"]
},
package_data={'': ['*.ttm', '*.obj', '**/**/*.ttm', '**/**/*.obj'],
'rlbench': ['task_design.ttt']},
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,35 @@ def test_swap_arm(self):
robot_setup=robot_config)
self.env.launch()
self.env.shutdown()

def test_executed_jp_action(self):
for task_cls in [ReachTarget, TakeLidOffSaucepan]:
with self.subTest(task_cls=task_cls):
task = self.get_task(
task_cls, JointPosition(True))
num_episodes = 20
demos = task.get_demos(num_episodes, live_demos=True)
total_reward = 0.0
# Check if executed joint position action is stored
for demo in demos:
jp_action = []
self.assertTrue("joint_position_action" not in demo[0].misc)
for t, obs in enumerate(demo):
if t == 0:
# First timestep should not have an action
self.assertTrue('joint_position_action' not in obs.misc)
else:
self.assertTrue("joint_position_action" in obs.misc)
jp_action.append(obs.misc["joint_position_action"])

task.reset_to_demo(demo)
for t, action in enumerate(jp_action):
obs, reward, term = task.step(action)
if term:
break
total_reward += reward

success_rate = total_reward / num_episodes
self.assertTrue(success_rate >= 0.9)
self.env.shutdown()

Loading