From f3dc2a767ddf1397b278d13cf84e925b9b84df98 Mon Sep 17 00:00:00 2001 From: Younggyo Seo Date: Tue, 11 Jun 2024 14:23:07 +0100 Subject: [PATCH 1/2] support setting max_velocity and max_acceleration for arms --- rlbench/environment.py | 11 ++++++++++- tools/dataset_generator.py | 6 ++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/rlbench/environment.py b/rlbench/environment.py index b1fb2f07f..4e9cb1a23 100644 --- a/rlbench/environment.py +++ b/rlbench/environment.py @@ -1,4 +1,5 @@ import importlib +from functools import partial from os.path import exists, dirname, abspath, join from typing import Type, List @@ -38,7 +39,9 @@ def __init__(self, visual_randomization_config: VisualRandomizationConfig = None, dynamics_randomization_config: DynamicsRandomizationConfig = None, attach_grasped_objects: bool = True, - shaped_rewards: bool = False + shaped_rewards: bool = False, + arm_max_velocity: float = 1.0, + arm_max_acceleration: float = 4.0, ): self._dataset_root = dataset_root @@ -54,6 +57,8 @@ def __init__(self, self._dynamics_randomization_config = dynamics_randomization_config self._attach_grasped_objects = attach_grasped_objects self._shaped_rewards = shaped_rewards + self._arm_max_velocity = arm_max_velocity + self._arm_max_acceleration = arm_max_acceleration if robot_setup not in SUPPORTED_ROBOTS.keys(): raise ValueError('robot_configuration must be one of %s' % @@ -97,6 +102,10 @@ def launch(self): arm_class, gripper_class, _ = SUPPORTED_ROBOTS[ self._robot_setup] + arm_class = partial( + arm_class, + max_velocity=self._arm_max_velocity, + max_acceleration=self._arm_max_acceleration) # We assume the panda is already loaded in the scene. if self._robot_setup != 'panda': diff --git a/tools/dataset_generator.py b/tools/dataset_generator.py index be9d02b7d..2c619896c 100644 --- a/tools/dataset_generator.py +++ b/tools/dataset_generator.py @@ -38,6 +38,10 @@ 'The number of episodes to collect per task.') flags.DEFINE_integer('variations', -1, 'Number of variations to collect per task. -1 for all.') +flags.DEFINE_float('arm_max_velocity', 1.0, + 'Max arm velocity used for motion planning.') +flags.DEFINE_float('arm_max_acceleration', 4.0, + 'Max arm acceleration used for motion planning.') def check_and_make(dir): @@ -214,6 +218,8 @@ def run(i, lock, task_index, variation_count, results, file_lock, tasks): rlbench_env = Environment( action_mode=MoveArmThenGripper(JointVelocity(), Discrete()), obs_config=obs_config, + arm_max_velocity=FLAGS.arm_max_velocity, + arm_max_acceleration=FLAGS.arm_max_acceleration, headless=True) rlbench_env.launch() From de486abdaf49f2ce9ee9d3faadff22aeda5d7648 Mon Sep 17 00:00:00 2001 From: Younggyo Seo Date: Tue, 2 Jul 2024 09:59:40 +0100 Subject: [PATCH 2/2] fix a conflict in dataset_generator --- rlbench/dataset_generator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rlbench/dataset_generator.py b/rlbench/dataset_generator.py index 85c2cffc4..f05bbe2ae 100644 --- a/rlbench/dataset_generator.py +++ b/rlbench/dataset_generator.py @@ -192,8 +192,8 @@ def run(i, lock, task_index, variation_count, results, file_lock, tasks, args): rlbench_env = Environment( action_mode=MoveArmThenGripper(JointVelocity(), Discrete()), obs_config=obs_config, - arm_max_velocity=FLAGS.arm_max_velocity, - arm_max_acceleration=FLAGS.arm_max_acceleration, + arm_max_velocity=args.arm_max_velocity, + arm_max_acceleration=args.arm_max_acceleration, headless=True) rlbench_env.launch()