From 36ea382dabbde2a2431951fc9048f7e5ffd8106b Mon Sep 17 00:00:00 2001 From: Balakumar Sundaralingam Date: Fri, 22 Nov 2024 14:15:18 -0800 Subject: [PATCH] Add planning to grasp API --- CHANGELOG.md | 25 +- benchmark/curobo_benchmark.py | 5 +- benchmark/curobo_voxel_benchmark.py | 25 +- docker/ros1_x86.dockerfile | 69 ----- docker/start_docker_arm64.sh | 22 -- docker/start_docker_x86_robot.sh | 24 -- examples/usd_example.py | 38 ++- .../cuda_robot_model/cuda_robot_generator.py | 6 +- src/curobo/curobolib/cpp/geom_cuda.cpp | 39 +-- .../curobolib/cpp/pose_distance_kernel.cu | 190 +++--------- .../curobolib/cpp/self_collision_kernel.cu | 75 +++-- src/curobo/curobolib/cpp/sphere_obb_kernel.cu | 135 ++++----- src/curobo/curobolib/geom.py | 14 +- src/curobo/curobolib/kinematics.py | 13 +- src/curobo/curobolib/ls.py | 2 - src/curobo/geom/sdf/world.py | 29 +- src/curobo/geom/sdf/world_voxel.py | 28 +- src/curobo/geom/types.py | 33 ++- src/curobo/graph/graph_nx.py | 15 +- src/curobo/opt/newton/newton_base.py | 19 +- src/curobo/rollout/arm_reacher.py | 45 ++- src/curobo/rollout/cost/pose_cost.py | 44 ++- .../rollout/dynamics_model/kinematic_model.py | 9 +- src/curobo/types/camera.py | 2 +- src/curobo/types/math.py | 50 ++-- src/curobo/util/helpers.py | 9 + src/curobo/util/sample_lib.py | 2 +- src/curobo/util/torch_utils.py | 6 + src/curobo/util/trajectory.py | 11 +- src/curobo/util/xrdf_utils.py | 24 +- src/curobo/util_file.py | 18 +- src/curobo/wrap/reacher/motion_gen.py | 280 +++++++++++++++++- src/curobo/wrap/reacher/trajopt.py | 13 +- tests/goal_test.py | 7 +- tests/motion_gen_goalset_test.py | 41 ++- tests/self_collision_test.py | 73 +++++ tests/voxelization_test.py | 6 +- tests/xrdf_test.py | 32 ++ 38 files changed, 941 insertions(+), 537 deletions(-) delete mode 100644 docker/ros1_x86.dockerfile delete mode 100644 docker/start_docker_arm64.sh delete mode 100644 docker/start_docker_x86_robot.sh diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e3cf25c..7547a722 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,17 +10,40 @@ its affiliates is strictly prohibited. --> # Changelog -## Latest Commit +## Version 0.7.5 + +### Changes in Default Behavior +- Remove explicit global seed setting for numpy and random. To enforce deterministic behavior, +use `np.random.seed(2)` and `random.seed(2)` in your program. +- geom.types.VoxelGrid now uses a different algorithm to calculate number of voxels per dimension +and also to compute xyz locations in a grid. This new implementation matches implementation in +nvblox. ### New Features - Add pose cost metric to MPC to allow for partial pose reaching. - Update obstacle poses in cpu reference with an optional flag. +- Add planning to grasp API in ``MotionGen.plan_grasp`` that plans a sequence of motions to grasp +an object given grasp poses. This API also provides args to disable collisions during the grasping +phase. +- Constrained planning can now use either goal frame or base frame at runtime. ### BugFixes & Misc. - Fixed optimize_dt not being correctly set when motion gen is called in reactive mode. - Add documentation for geom module. - Add descriptive api for computing kinematics. - Fix cv2 import order in isaac sim realsense examples. +- Fix attach sphere api mismatch in ``TrajOptSolver``. +- Fix bug in ``get_spline_interpolated_trajectory`` where +numpy array was created instead of torch tensor. +- Fix gradient bug when sphere origin is exactly at face of a cuboid. +- Add support for parsing Yaml 1.2 format with an updated regex for scientific notations. 
+- Move to yaml `SafeLoader` from `Loader`. +- Graph search checks if a node exists before attempting to find a path. +- Fix `steps_max` becoming 0 when optimized dt has NaN values. +- Clone `MotionGenPlanConfig` instance for every plan api. +- Improve sphere position to voxel location calculation to match nvblox's implementation. +- Add self collision checking support for spheres > 1024 and number of checks > 512 * 1024. +- Fix gradient passthrough in warp batch transform kernels. ## Version 0.7.4 diff --git a/benchmark/curobo_benchmark.py b/benchmark/curobo_benchmark.py index 84999699..2b96546c 100644 --- a/benchmark/curobo_benchmark.py +++ b/benchmark/curobo_benchmark.py @@ -210,6 +210,7 @@ def load_curobo( trajopt_dt=0.25, finetune_dt_scale=finetune_dt_scale, high_precision=args.high_precision, + use_cuda_graph_trajopt_metrics=cuda_graph, ) mg = MotionGen(motion_gen_config) mg.warmup(enable_graph=True, warmup_js_trajopt=False, parallel_finetune=parallel_finetune) @@ -484,7 +485,7 @@ def benchmark_mb( start_state, q_traj, dt=result.interpolation_dt, - save_path=join_path("benchmark/log/usd/", problem_name) + ".usd", + save_path=join_path("benchmark/log/usd/", problem_name)[1:] + ".usd", interpolation_steps=1, write_robot_usd_path="benchmark/log/usd/assets/", robot_usd_local_reference="assets/", @@ -499,7 +500,7 @@ def benchmark_mb( result.optimized_plan, result.optimized_dt.item(), title=problem_name, - save_path=join_path("benchmark/log/plot/", problem_name + ".png"), + save_path=join_path("benchmark/log/plot/", problem_name + ".png")[1:], ) m_list.append(current_metrics) diff --git a/benchmark/curobo_voxel_benchmark.py b/benchmark/curobo_voxel_benchmark.py index 191d500f..93aba488 100644 --- a/benchmark/curobo_voxel_benchmark.py +++ b/benchmark/curobo_voxel_benchmark.py @@ -149,7 +149,7 @@ def load_curobo( "base": { "dims": [2.4, 2.4, 2.4], "pose": [0, 0, 0, 1, 0, 0, 0], - "voxel_size": 0.005, + "voxel_size": 0.02, "feature_dtype": torch.bfloat16, }, } @@ -294,7 +294,7 @@ def benchmark_mb( world = WorldConfig.from_dict(problem["obstacles"]) # mg.world_coll_checker.clear_cache() - world_coll = WorldConfig.from_dict(problem["obstacles"]) + world_coll = WorldConfig.from_dict(problem["obstacles"]).get_collision_check_world() if args.mesh: world_coll = world_coll.get_mesh_world(merge_meshes=False) robot_world.clear_world_cache() @@ -302,7 +302,7 @@ def benchmark_mb( esdf = robot_world.world_model.get_esdf_in_bounding_box( Cuboid(name="base", pose=[0, 0, 0, 1, 0, 0, 0], dims=[2.4, 2.4, 2.4]), - voxel_size=0.005, + voxel_size=0.02, dtype=torch.float32, ) # esdf.feature_tensor[esdf.feature_tensor < -1.0] = -1000.0 @@ -336,13 +336,16 @@ def benchmark_mb( world.randomize_color(r=[0.5, 0.9], g=[0.2, 0.5], b=[0.0, 0.2]) coll_mesh = mg.world_coll_checker.get_mesh_in_bounding_box( - curobo_Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[2, 2, 2]), - voxel_size=0.005, + curobo_Cuboid( + name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[2.4, 2.4, 2.4] + ), + voxel_size=0.02, ) coll_mesh.color = [0.0, 0.8, 0.8, 0.2] coll_mesh.name = "voxel_world" + # world = WorldConfig(mesh=[coll_mesh]) world.add_obstacle(coll_mesh) # get costs: @@ -360,7 +363,7 @@ def benchmark_mb( plot_cost_iteration( traj_cost, title=problem_name + "_" + str(success) + "_" + str(dt), - save_path=join_path("benchmark/log/plot/", problem_name + "_cost"), + save_path=join_path("benchmark/log/plot/", problem_name + "_cost")[1:], log_scale=False, ) if "finetune_trajopt_result" in result.debug_info: @@ -373,7 +376,7 @@ def 
benchmark_mb( title=problem_name + "_" + str(success) + "_" + str(dt), save_path=join_path( "benchmark/log/plot/", problem_name + "_ft_cost" - ), + )[1:], log_scale=False, ) if result.success.item(): @@ -481,7 +484,7 @@ def benchmark_mb( start_state, q_traj, dt=result.interpolation_dt, - save_path=join_path("benchmark/log/usd/", problem_name) + ".usd", + save_path=join_path("benchmark/log/usd/", problem_name)[1:] + ".usd", interpolation_steps=1, write_robot_usd_path="benchmark/log/usd/assets/", robot_usd_local_reference="assets/", @@ -499,7 +502,7 @@ def benchmark_mb( # result.get_interpolated_plan(), # result.interpolation_dt, title=problem_name, - save_path=join_path("benchmark/log/plot/", problem_name + ".pdf"), + save_path=join_path("benchmark/log/plot/", problem_name + ".pdf")[1:], ) plot_traj( # result.optimized_plan, @@ -507,7 +510,9 @@ def benchmark_mb( result.get_interpolated_plan(), result.interpolation_dt, title=problem_name, - save_path=join_path("benchmark/log/plot/", problem_name + "_int.pdf"), + save_path=join_path("benchmark/log/plot/", problem_name + "_int.pdf")[ + 1: + ], ) m_list.append(current_metrics) diff --git a/docker/ros1_x86.dockerfile b/docker/ros1_x86.dockerfile deleted file mode 100644 index c4e00311..00000000 --- a/docker/ros1_x86.dockerfile +++ /dev/null @@ -1,69 +0,0 @@ -## -## Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -## -## NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -## property and proprietary rights in and to this material, related -## documentation and any modifications thereto. Any use, reproduction, -## disclosure or distribution of this material and related documentation -## without an express license agreement from NVIDIA CORPORATION or -## its affiliates is strictly prohibited. 
-## -FROM nvcr.io/nvidia/pytorch:23.02-py3 - - -RUN echo 'debconf debconf/frontend select Noninteractive' | debconf-set-selections - -RUN apt-get update && apt-get install -y \ - tzdata \ - && rm -rf /var/lib/apt/lists/* \ - && ln -fs /usr/share/zoneinfo/America/Los_Angeles /etc/localtime \ - && echo "America/Los_Angeles" > /etc/timezone \ - && dpkg-reconfigure -f noninteractive tzdata - -RUN apt-get update &&\ - apt-get install -y sudo git bash software-properties-common graphviz &&\ - rm -rf /var/lib/apt/lists/* - - - -RUN python -m pip install --upgrade pip && python3 -m pip install graphviz - -# Install ROS noetic -RUN sh -c 'echo "deb http://packages.ros.org/ros/ubuntu focal main" > /etc/apt/sources.list.d/ros-latest.list' \ -&& apt-key adv --keyserver 'hkp://keyserver.ubuntu.com:80' --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654 \ -&& apt-get update && apt-get install -y \ - ros-noetic-desktop-full git build-essential python3-rosdep \ - && rm -rf /var/lib/apt/lists/* - - -# install realsense and azure kinect -# Install the RealSense library (https://github.com/IntelRealSense/librealsense/blob/master/doc/distribution_linux.md#installing-the-packages) -#RUN sudo apt-key adv --keyserver keys.gnupg.net --recv-key F6E65AC044F831AC80A06380C8B3A55A6F3EFCDE || sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv-key F6E65AC044F831AC80A06380C8B3A55A6F3EFCDE -#RUN sudo add-apt-repository "deb https://librealsense.intel.com/Debian/apt-repo $(lsb_release -cs) main" -u -#RUN apt-get update && apt-get install -y \ -# librealsense2-dkms \ -# software-properties-common \ -# librealsense2-utils \ -# && rm -rf /var/lib/apt/lists/* - - -# install moveit from source for all algos: -ARG ROS_DISTRO=noetic -RUN apt-get update && apt-get install -y \ - ros-$ROS_DISTRO-apriltag-ros \ - ros-$ROS_DISTRO-realsense2-camera \ - ros-$ROS_DISTRO-ros-numpy \ - ros-$ROS_DISTRO-vision-msgs \ - ros-$ROS_DISTRO-franka-ros \ - ros-$ROS_DISTRO-moveit-resources \ - ros-$ROS_DISTRO-rosparam-shortcuts \ - libglfw3-dev \ - ros-$ROS_DISTRO-collada-urdf \ - ros-$ROS_DISTRO-ur-msgs \ - swig \ - && rm -rf /var/lib/apt/lists/* - - -RUN apt-get update && rosdep init && rosdep update && apt-get install -y ros-noetic-moveit-ros-visualization && rm -rf /var/lib/apt/lists/* -RUN pip3 install netifaces - diff --git a/docker/start_docker_arm64.sh b/docker/start_docker_arm64.sh deleted file mode 100644 index 96461aee..00000000 --- a/docker/start_docker_arm64.sh +++ /dev/null @@ -1,22 +0,0 @@ -## -## Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -## -## NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -## property and proprietary rights in and to this material, related -## documentation and any modifications thereto. Any use, reproduction, -## disclosure or distribution of this material and related documentation -## without an express license agreement from NVIDIA CORPORATION or -## its affiliates is strictly prohibited. 
-## -docker run --rm -it \ ---runtime nvidia \ ---mount type=bind,src=/home/$USER/code,target=/home/$USER/code \ ---hostname ros1-docker \ ---add-host ros1-docker:127.0.0.1 \ ---network host \ ---gpus all \ ---env ROS_HOSTNAME=localhost \ ---env DISPLAY=$DISPLAY \ ---volume /tmp/.X11-unix:/tmp/.X11-unix \ ---volume /dev/input:/dev/input \ -curobo_user_docker:latest diff --git a/docker/start_docker_x86_robot.sh b/docker/start_docker_x86_robot.sh deleted file mode 100644 index a4d7cbfc..00000000 --- a/docker/start_docker_x86_robot.sh +++ /dev/null @@ -1,24 +0,0 @@ -## -## Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -## -## NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -## property and proprietary rights in and to this material, related -## documentation and any modifications thereto. Any use, reproduction, -## disclosure or distribution of this material and related documentation -## without an express license agreement from NVIDIA CORPORATION or -## its affiliates is strictly prohibited. -## -docker run --rm -it \ ---privileged --mount type=bind,src=/home/$USER/code,target=/home/$USER/code \ --e NVIDIA_DISABLE_REQUIRE=1 \ --e NVIDIA_DRIVER_CAPABILITIES=all --device /dev/dri \ ---hostname ros1-docker \ ---add-host ros1-docker:127.0.0.1 \ ---gpus all \ ---network host \ ---env ROS_MASTER_URI=http://127.0.0.1:11311 \ ---env ROS_IP=127.0.0.1 \ ---env DISPLAY=unix$DISPLAY \ ---volume /tmp/.X11-unix:/tmp/.X11-unix \ ---volume /dev/input:/dev/input \ -curobo_user_docker:latest \ No newline at end of file diff --git a/examples/usd_example.py b/examples/usd_example.py index d4e1dcc9..93dba3e2 100644 --- a/examples/usd_example.py +++ b/examples/usd_example.py @@ -19,7 +19,7 @@ from curobo.types.base import TensorDeviceType from curobo.types.math import Pose from curobo.types.robot import JointState, RobotConfig -from curobo.util.logger import setup_curobo_logger +from curobo.util.logger import setup_curobo_logger, log_error from curobo.util.usd_helper import UsdHelper from curobo.util_file import ( get_assets_path, @@ -44,7 +44,7 @@ def save_curobo_world_to_usd(): usd_helper.write_stage_to_file("test.usd") -def get_trajectory(robot_file="franka.yml", dt=1.0 / 60.0): +def get_trajectory(robot_file="franka.yml", dt=1.0 / 60.0, plan_grasp: bool = False): tensor_args = TensorDeviceType() world_file = "collision_test.yml" motion_gen_config = MotionGenConfig.load_from_robot_config( @@ -60,24 +60,40 @@ def get_trajectory(robot_file="franka.yml", dt=1.0 / 60.0): interpolation_dt=dt, ) motion_gen = MotionGen(motion_gen_config) - motion_gen.warmup() + motion_gen.warmup(n_goalset=2) robot_cfg = load_yaml(join_path(get_robot_configs_path(), robot_file))["robot_cfg"] robot_cfg = RobotConfig.from_dict(robot_cfg, tensor_args) retract_cfg = motion_gen.get_retract_config() state = motion_gen.rollout_fn.compute_kinematics( JointState.from_position(retract_cfg.view(1, -1)) ) + if plan_grasp: + retract_pose = Pose( + state.ee_pos_seq.view(1, -1, 3), quaternion=state.ee_quat_seq.view(1, -1, 4) + ) + start_state = JointState.from_position(retract_cfg.view(1, -1).clone()) + start_state.position[..., :-2] += 0.5 + m_config = MotionGenPlanConfig(False, True) + + result = motion_gen.plan_grasp(start_state, retract_pose, m_config.clone()) + if not result.success: + log_error("Failed to plan grasp: " + result.status) + traj = result.grasp_interpolated_trajectory + traj2 = result.retract_interpolated_trajectory + traj = traj.stack(traj2).clone() + # result = 
motion_gen.plan_single(start_state, retract_pose) + # traj = result.get_interpolated_plan() # optimized plan - retract_pose = Pose(state.ee_pos_seq.squeeze(), quaternion=state.ee_quat_seq.squeeze()) - start_state = JointState.from_position(retract_cfg.view(1, -1).clone()) - start_state.position[..., :-2] += 0.5 - result = motion_gen.plan_single(start_state, retract_pose) - traj = result.get_interpolated_plan() # optimized plan + else: + retract_pose = Pose(state.ee_pos_seq.squeeze(), quaternion=state.ee_quat_seq.squeeze()) + start_state = JointState.from_position(retract_cfg.view(1, -1).clone()) + start_state.position[..., :-2] += 0.5 + result = motion_gen.plan_single(start_state, retract_pose) + traj = result.get_interpolated_plan() # optimized plan return traj -def save_curobo_robot_world_to_usd(robot_file="franka.yml"): - print(robot_file) +def save_curobo_robot_world_to_usd(robot_file="franka.yml", plan_grasp: bool = False): tensor_args = TensorDeviceType() world_file = "collision_test.yml" world_model = WorldConfig.from_dict( @@ -85,7 +101,7 @@ def save_curobo_robot_world_to_usd(robot_file="franka.yml"): ).get_obb_world() dt = 1 / 60.0 - q_traj = get_trajectory(robot_file, dt) + q_traj = get_trajectory(robot_file, dt, plan_grasp) if q_traj is not None: q_start = q_traj[0] UsdHelper.write_trajectory_animation_with_robot_usd( diff --git a/src/curobo/cuda_robot_model/cuda_robot_generator.py b/src/curobo/cuda_robot_model/cuda_robot_generator.py index 9e43db79..2d461351 100644 --- a/src/curobo/cuda_robot_model/cuda_robot_generator.py +++ b/src/curobo/cuda_robot_model/cuda_robot_generator.py @@ -852,7 +852,11 @@ def _build_collision_model( if not valid_data: use_experimental_kernel = False - log_warn("Self Collision checks are greater than 32 * 512, using slower kernel") + log_warn( + "Self Collision checks are greater than 32 * 512, using slower kernel." 
+ + " Number of spheres: " + + str(self_collision_distance.shape[0]) + ) if use_experimental_kernel: self_coll_matrix = torch.zeros((2), device=self.tensor_args.device, dtype=torch.uint8) else: diff --git a/src/curobo/curobolib/cpp/geom_cuda.cpp b/src/curobo/curobolib/cpp/geom_cuda.cpp index a83337fd..fe566fa5 100644 --- a/src/curobo/curobolib/cpp/geom_cuda.cpp +++ b/src/curobo/curobolib/cpp/geom_cuda.cpp @@ -108,21 +108,21 @@ swept_sphere_voxel_clpt(const torch::Tensor sphere_position, // batch_size, 3 torch::Tensor sparsity_idx, const torch::Tensor weight, const torch::Tensor activation_distance, const torch::Tensor max_distance, - const torch::Tensor speed_dt, + const torch::Tensor speed_dt, const torch::Tensor grid_features, // n_boxes, 4, 4 const torch::Tensor grid_params, // n_boxes, 3 const torch::Tensor grid_pose, // n_boxes, 4, 4 const torch::Tensor grid_enable, // n_boxes, 4, 4 const torch::Tensor n_env_grid, const torch::Tensor env_query_idx, // n_boxes, 4, 4 - const int max_nobs, - const int batch_size, + const int max_nobs, + const int batch_size, const int horizon, - const int n_spheres, + const int n_spheres, const int sweep_steps, const bool enable_speed_metric, const bool transform_back, - const bool compute_distance, + const bool compute_distance, const bool use_batch_env, const bool sum_collisions); @@ -145,14 +145,15 @@ std::vectorpose_distance( const torch::Tensor offset_waypoint, const torch::Tensor offset_tstep_fraction, const torch::Tensor batch_pose_idx, + const torch::Tensor project_distance, const int batch_size, const int horizon, const int mode, const int num_goals = 1, const bool compute_grad = false, const bool write_distance = true, - const bool use_metric = false, - const bool project_distance = true); + const bool use_metric = false + ); std::vector backward_pose_distance(torch::Tensor out_grad_p, @@ -202,7 +203,7 @@ std::vectorsphere_obb_clpt_wrapper( torch::Tensor closest_point, // batch size, 3 torch::Tensor sparsity_idx, const torch::Tensor weight, const torch::Tensor activation_distance, - const torch::Tensor max_distance, + const torch::Tensor max_distance, const torch::Tensor obb_accel, // n_boxes, 4, 4 const torch::Tensor obb_bounds, // n_boxes, 3 const torch::Tensor obb_pose, // n_boxes, 4, 4 @@ -210,9 +211,9 @@ std::vectorsphere_obb_clpt_wrapper( const torch::Tensor n_env_obb, // n_boxes, 4, 4 const torch::Tensor env_query_idx, // n_boxes, 4, 4 const int max_nobs, const int batch_size, const int horizon, - const int n_spheres, + const int n_spheres, const bool transform_back, const bool compute_distance, - const bool use_batch_env, const bool sum_collisions = true, + const bool use_batch_env, const bool sum_collisions = true, const bool compute_esdf = false) { const at::cuda::OptionalCUDAGuard guard(sphere_position.device()); @@ -305,10 +306,11 @@ std::vectorpose_distance_wrapper( const torch::Tensor weight, const torch::Tensor vec_convergence, const torch::Tensor run_weight, const torch::Tensor run_vec_weight, const torch::Tensor offset_waypoint, const torch::Tensor offset_tstep_fraction, - const torch::Tensor batch_pose_idx, const int batch_size, const int horizon, + const torch::Tensor batch_pose_idx, + const torch::Tensor project_distance, + const int batch_size, const int horizon, const int mode, const int num_goals = 1, const bool compute_grad = false, - const bool write_distance = false, const bool use_metric = false, - const bool project_distance = true) + const bool write_distance = false, const bool use_metric = false) { // 
at::cuda::DeviceGuard guard(angle.device()); CHECK_INPUT(out_distance); @@ -323,6 +325,7 @@ std::vectorpose_distance_wrapper( CHECK_INPUT(batch_pose_idx); CHECK_INPUT(offset_waypoint); CHECK_INPUT(offset_tstep_fraction); + CHECK_INPUT(project_distance); const at::cuda::OptionalCUDAGuard guard(current_position.device()); return pose_distance( @@ -332,8 +335,10 @@ std::vectorpose_distance_wrapper( vec_convergence, run_weight, run_vec_weight, offset_waypoint, offset_tstep_fraction, - batch_pose_idx, batch_size, - horizon, mode, num_goals, compute_grad, write_distance, use_metric, project_distance); + batch_pose_idx, + project_distance, + batch_size, + horizon, mode, num_goals, compute_grad, write_distance, use_metric); } std::vectorbackward_pose_distance_wrapper( @@ -372,8 +377,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) "Closest Point Voxel(curobolib)"); m.def("swept_closest_point_voxel", &swept_sphere_voxel_clpt, "Swpet Closest Point Voxel(curobolib)"); - - + + m.def("self_collision_distance", &self_collision_distance_wrapper, "Self Collision Distance (curobolib)"); diff --git a/src/curobo/curobolib/cpp/pose_distance_kernel.cu b/src/curobo/curobolib/cpp/pose_distance_kernel.cu index 506504b1..16169931 100644 --- a/src/curobo/curobolib/cpp/pose_distance_kernel.cu +++ b/src/curobo/curobolib/cpp/pose_distance_kernel.cu @@ -146,7 +146,6 @@ namespace Curobo } } - template __device__ __forceinline__ void compute_pose_distance_vector(float *result_vec, const float3 goal_position, @@ -156,7 +155,8 @@ namespace Curobo const float *vec_weight, const float3 offset_position, const float3 offset_rotation, - const bool reach_offset) + const bool reach_offset, + const bool project_distance) { // project current position to goal frame: float3 error_position = make_float3(0, 0, 0); @@ -253,7 +253,7 @@ namespace Curobo } } - template + template __device__ __forceinline__ void compute_pose_distance(float *distance_vec, float& distance, float& position_distance, float& rotation_distance, const float3 current_position, @@ -265,17 +265,19 @@ namespace Curobo const float r_alpha, const float3 offset_position, const float3 offset_rotation, - const bool reach_offset) + const bool reach_offset, + const bool project_distance) { - compute_pose_distance_vector(&distance_vec[0], - goal_position, - goal_quat, - current_position, - current_quat, - &vec_weight[0], - offset_position, - offset_rotation, - reach_offset); + compute_pose_distance_vector(&distance_vec[0], + goal_position, + goal_quat, + current_position, + current_quat, + &vec_weight[0], + offset_position, + offset_rotation, + reach_offset, + project_distance); position_distance = 0; rotation_distance = 0; @@ -394,7 +396,7 @@ namespace Curobo *(float3 *)&out_grad_q[batch_idx * 4 + 1] = g_q; } - template + template __global__ void goalset_pose_distance_kernel( scalar_t *out_distance, scalar_t *out_position_distance, scalar_t *out_rotation_distance, scalar_t *out_p_vec, scalar_t *out_q_vec, @@ -405,7 +407,9 @@ namespace Curobo const scalar_t *run_weight, const scalar_t *run_vec_weight, const scalar_t *offset_waypoint, const scalar_t *offset_tstep_fraction, - const int32_t *batch_pose_idx, const int mode, const int num_goals, + const int32_t *batch_pose_idx, + const uint8_t *project_distance_tensor, + const int mode, const int num_goals, const int batch_size, const int horizon, const bool write_grad = false) { const int t_idx = (blockDim.x * blockIdx.x + threadIdx.x); @@ -416,7 +420,7 @@ namespace Curobo { return; } - + const bool project_distance = 
project_distance_tensor[0]; // read current pose: float3 position = *(float3 *)¤t_position[batch_idx * horizon * 3 + h_idx * 3]; @@ -511,7 +515,7 @@ namespace Curobo float4 gq4 = *(float4 *)&goal_quat[(offset + k) * 4]; l_goal_quat = make_float4(gq4.y, gq4.z, gq4.w, gq4.x); - compute_pose_distance(&distance_vec[0], + compute_pose_distance(&distance_vec[0], pose_distance, position_distance, rotation_distance, @@ -531,7 +535,8 @@ namespace Curobo r_w_alpha, offset_position, offset_rotation, - reach_offset); + reach_offset, + project_distance); if (pose_distance <= best_distance) { @@ -657,10 +662,10 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance, const torch::Tensor offset_waypoint, const torch::Tensor offset_tstep_fraction, const torch::Tensor batch_pose_idx, // batch_size, 1 + const torch::Tensor project_distance, const int batch_size, const int horizon, const int mode, const int num_goals = 1, const bool compute_grad = false, - const bool write_distance = true, const bool use_metric = false, - const bool project_distance = true) + const bool write_distance = true, const bool use_metric = false) { using namespace Curobo::Pose; @@ -684,123 +689,7 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance, cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (project_distance) - { - if (use_metric) - { - if (write_distance) - { - AT_DISPATCH_FLOATING_TYPES( - current_position.scalar_type(), "batch_pose_distance", ([&] { - goalset_pose_distance_kernel - << < blocksPerGrid, threadsPerBlock, 0, - stream >> > ( - out_distance.data_ptr(), - out_position_distance.data_ptr(), - out_rotation_distance.data_ptr(), - distance_p_vector.data_ptr(), - distance_q_vector.data_ptr(), - out_gidx.data_ptr(), - current_position.data_ptr(), - goal_position.data_ptr(), - current_quat.data_ptr(), - goal_quat.data_ptr(), - vec_weight.data_ptr(), weight.data_ptr(), - vec_convergence.data_ptr(), - run_weight.data_ptr(), - run_vec_weight.data_ptr(), - offset_waypoint.data_ptr(), - offset_tstep_fraction.data_ptr(), - batch_pose_idx.data_ptr(), mode, num_goals, - batch_size, horizon, compute_grad); - })); - } - else - { - AT_DISPATCH_FLOATING_TYPES( - current_position.scalar_type(), "batch_pose_distance", ([&] { - goalset_pose_distance_kernel - << < blocksPerGrid, threadsPerBlock, 0, - stream >> > ( - out_distance.data_ptr(), - out_position_distance.data_ptr(), - out_rotation_distance.data_ptr(), - distance_p_vector.data_ptr(), - distance_q_vector.data_ptr(), - out_gidx.data_ptr(), - current_position.data_ptr(), - goal_position.data_ptr(), - current_quat.data_ptr(), - goal_quat.data_ptr(), - vec_weight.data_ptr(), weight.data_ptr(), - vec_convergence.data_ptr(), - run_weight.data_ptr(), - run_vec_weight.data_ptr(), - offset_waypoint.data_ptr(), - offset_tstep_fraction.data_ptr(), - batch_pose_idx.data_ptr(), mode, num_goals, - batch_size, horizon, compute_grad); - })); - } - } - else - { - if (write_distance) - { - AT_DISPATCH_FLOATING_TYPES( - current_position.scalar_type(), "batch_pose_distance", ([&] { - goalset_pose_distance_kernel - << < blocksPerGrid, threadsPerBlock, 0, stream >> > ( - out_distance.data_ptr(), - out_position_distance.data_ptr(), - out_rotation_distance.data_ptr(), - distance_p_vector.data_ptr(), - distance_q_vector.data_ptr(), - out_gidx.data_ptr(), - current_position.data_ptr(), - goal_position.data_ptr(), - current_quat.data_ptr(), - goal_quat.data_ptr(), - vec_weight.data_ptr(), weight.data_ptr(), - vec_convergence.data_ptr(), - 
run_weight.data_ptr(), - run_vec_weight.data_ptr(), - offset_waypoint.data_ptr(), - offset_tstep_fraction.data_ptr(), - batch_pose_idx.data_ptr(), mode, num_goals, - batch_size, horizon, compute_grad); - })); - } - else - { - AT_DISPATCH_FLOATING_TYPES( - current_position.scalar_type(), "batch_pose_distance", ([&] { - goalset_pose_distance_kernel - << < blocksPerGrid, threadsPerBlock, 0, stream >> > ( - out_distance.data_ptr(), - out_position_distance.data_ptr(), - out_rotation_distance.data_ptr(), - distance_p_vector.data_ptr(), - distance_q_vector.data_ptr(), - out_gidx.data_ptr(), - current_position.data_ptr(), - goal_position.data_ptr(), - current_quat.data_ptr(), - goal_quat.data_ptr(), - vec_weight.data_ptr(), weight.data_ptr(), - vec_convergence.data_ptr(), - run_weight.data_ptr(), - run_vec_weight.data_ptr(), - offset_waypoint.data_ptr(), - offset_tstep_fraction.data_ptr(), - batch_pose_idx.data_ptr(), mode, num_goals, - batch_size, horizon, compute_grad); - })); - } - } - } - else - { + if (use_metric) { if (write_distance) @@ -808,7 +697,7 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance, AT_DISPATCH_FLOATING_TYPES( current_position.scalar_type(), "batch_pose_distance", ([&] { goalset_pose_distance_kernel - << < blocksPerGrid, threadsPerBlock, 0, + << < blocksPerGrid, threadsPerBlock, 0, stream >> > ( out_distance.data_ptr(), out_position_distance.data_ptr(), @@ -826,7 +715,9 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance, run_vec_weight.data_ptr(), offset_waypoint.data_ptr(), offset_tstep_fraction.data_ptr(), - batch_pose_idx.data_ptr(), mode, num_goals, + batch_pose_idx.data_ptr(), + project_distance.data_ptr(), + mode, num_goals, batch_size, horizon, compute_grad); })); } @@ -835,7 +726,7 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance, AT_DISPATCH_FLOATING_TYPES( current_position.scalar_type(), "batch_pose_distance", ([&] { goalset_pose_distance_kernel - << < blocksPerGrid, threadsPerBlock, 0, + << < blocksPerGrid, threadsPerBlock, 0, stream >> > ( out_distance.data_ptr(), out_position_distance.data_ptr(), @@ -853,7 +744,9 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance, run_vec_weight.data_ptr(), offset_waypoint.data_ptr(), offset_tstep_fraction.data_ptr(), - batch_pose_idx.data_ptr(), mode, num_goals, + batch_pose_idx.data_ptr(), + project_distance.data_ptr(), + mode, num_goals, batch_size, horizon, compute_grad); })); } @@ -864,7 +757,7 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance, { AT_DISPATCH_FLOATING_TYPES( current_position.scalar_type(), "batch_pose_distance", ([&] { - goalset_pose_distance_kernel + goalset_pose_distance_kernel << < blocksPerGrid, threadsPerBlock, 0, stream >> > ( out_distance.data_ptr(), out_position_distance.data_ptr(), @@ -882,7 +775,9 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance, run_vec_weight.data_ptr(), offset_waypoint.data_ptr(), offset_tstep_fraction.data_ptr(), - batch_pose_idx.data_ptr(), mode, num_goals, + batch_pose_idx.data_ptr(), + project_distance.data_ptr(), + mode, num_goals, batch_size, horizon, compute_grad); })); } @@ -890,7 +785,7 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance, { AT_DISPATCH_FLOATING_TYPES( current_position.scalar_type(), "batch_pose_distance", ([&] { - goalset_pose_distance_kernel + goalset_pose_distance_kernel << < blocksPerGrid, threadsPerBlock, 0, stream >> > ( 
out_distance.data_ptr(), out_position_distance.data_ptr(), @@ -908,12 +803,15 @@ pose_distance(torch::Tensor out_distance, torch::Tensor out_position_distance, run_vec_weight.data_ptr(), offset_waypoint.data_ptr(), offset_tstep_fraction.data_ptr(), - batch_pose_idx.data_ptr(), mode, num_goals, + batch_pose_idx.data_ptr(), + project_distance.data_ptr(), + mode, num_goals, batch_size, horizon, compute_grad); })); } } - } + + C10_CUDA_KERNEL_LAUNCH_CHECK(); return { out_distance, out_position_distance, out_rotation_distance, diff --git a/src/curobo/curobolib/cpp/self_collision_kernel.cu b/src/curobo/curobolib/cpp/self_collision_kernel.cu index af6d9e88..4ba6cc23 100644 --- a/src/curobo/curobolib/cpp/self_collision_kernel.cu +++ b/src/curobo/curobolib/cpp/self_collision_kernel.cu @@ -41,7 +41,9 @@ namespace Curobo scalar_t *out_distance, // batch x 1 scalar_t *out_vec, // batch x nspheres x 4 const scalar_t *robot_spheres, // batch x nspheres x 4 - const scalar_t *collision_threshold, const int batch_size, + const scalar_t *offsets, + const uint8_t *coll_matrix, + const int batch_size, const int nspheres, const scalar_t *weight, const bool write_grad = false) { const int batch_idx = blockDim.x * blockIdx.x + threadIdx.x; @@ -52,37 +54,35 @@ namespace Curobo } float r_diff, distance; float max_penetration = 0; - float3 sph1, sph2, dist_vec; + float4 sph1, sph2; int sph1_idx = -1; int sph2_idx = -1; // iterate over spheres: for (int i = 0; i < nspheres; i++) { - sph1 = *(float3 *)&robot_spheres[batch_idx * nspheres * 4 + i * 4]; + sph1 = *(float4 *)&robot_spheres[batch_idx * nspheres * 4 + i * 4]; + sph1.w += offsets[i]; for (int j = i + 1; j < nspheres; j++) { - r_diff = collision_threshold[i * nspheres + j]; - - if (isinf(r_diff)) + if(coll_matrix[i * nspheres + j] == 1) { - continue; - } - sph2 = *(float3 *)&robot_spheres[batch_idx * nspheres * 4 + j * 4]; - - // compute sphere distance: - distance = relu(r_diff - length(sph1 - sph2)); + sph2 = *(float4 *)&robot_spheres[batch_idx * nspheres * 4 + j * 4]; + sph2.w += offsets[j]; - if (distance > max_penetration) - { - max_penetration = distance; - sph1_idx = i; - sph2_idx = j; + // compute sphere distance: + r_diff = sph1.w + sph2.w; + float d = sqrt((sph1.x - sph2.x) * (sph1.x - sph2.x) + + (sph1.y - sph2.y) * (sph1.y - sph2.y) + + (sph1.z - sph2.z) * (sph1.z - sph2.z)); + distance = (r_diff - d); - if (write_grad) + if (distance > max_penetration) { - dist_vec = normalize(sph1 - sph2);// / distance; + max_penetration = distance; + sph1_idx = i; + sph2_idx = j; } } } @@ -95,6 +95,11 @@ namespace Curobo if (write_grad) { + float3 sph1_g = + *(float3 *)&robot_spheres[4 * (batch_idx * nspheres + sph1_idx)]; + float3 sph2_g = + *(float3 *)&robot_spheres[4 * (batch_idx * nspheres + sph2_idx)]; + float3 dist_vec = normalize(sph1_g - sph2_g); *(float3 *)&out_vec[batch_idx * nspheres * 4 + sph1_idx * 4] = weight[0] * -1 * dist_vec; *(float3 *)&out_vec[batch_idx * nspheres * 4 + sph2_idx * 4] = @@ -131,7 +136,7 @@ namespace Curobo int i = ndpt * (warp_idx / nwpr); // starting row number for this warp int j = (warp_idx % nwpr) * 32; // starting column number for this warp - dist_t max_d = {0.0, 0.0, 0.0 };// .d, .i, .j + dist_t max_d = { 0.0, 0, 0 };// .d, .i, .j __shared__ dist_t max_darr[32]; // Optimization: About 1/3 of the warps will have no work. @@ -354,7 +359,7 @@ namespace Curobo // in registers (max_d). // Each thread computes upto ndpt distances. 
////////////////////////////////////////////////////// - dist_t max_d[NBPB] = {{ 0.0, 0.0, 0.0}}; + dist_t max_d[NBPB] = {{ 0.0, 0, 0}}; int16_t indices[ndpt * 2]; for (uint8_t i = 0; i < ndpt * 2; i++) @@ -698,7 +703,7 @@ std::vectorself_collision_distance( } else { - assert(false); // only ndpt of 32 or 64 is currently supported. + assert(false); } } @@ -713,6 +718,8 @@ std::vectorself_collision_distance( assert(collision_matrix.size(0) == nspheres * nspheres); int smemSize = nspheres * sizeof(float4); + if (nspheres < 1024 && threadsPerBlock < 1024) + { AT_DISPATCH_FLOATING_TYPES( robot_spheres.scalar_type(), "self_collision_distance", ([&] { @@ -726,6 +733,30 @@ std::vectorself_collision_distance( ndpt_n, nwpr, weight.data_ptr(), sparse_index.data_ptr(), compute_grad); })); + } + else + { + threadsPerBlock = batch_size; + if (threadsPerBlock > 128) + { + threadsPerBlock = 128; + } + blocksPerGrid = (batch_size + threadsPerBlock - 1) / threadsPerBlock; + + AT_DISPATCH_FLOATING_TYPES( + robot_spheres.scalar_type(), "self_collision_distance", ([&] { + self_collision_distance_kernel + << < blocksPerGrid, threadsPerBlock, smemSize, stream >> > ( + out_distance.data_ptr(), + out_vec.data_ptr(), + robot_spheres.data_ptr(), + collision_offset.data_ptr(), + collision_matrix.data_ptr(), + batch_size, nspheres, + weight.data_ptr(), + compute_grad); + })); + } } C10_CUDA_KERNEL_LAUNCH_CHECK(); diff --git a/src/curobo/curobolib/cpp/sphere_obb_kernel.cu b/src/curobo/curobolib/cpp/sphere_obb_kernel.cu index 608e0677..908829cc 100644 --- a/src/curobo/curobolib/cpp/sphere_obb_kernel.cu +++ b/src/curobo/curobolib/cpp/sphere_obb_kernel.cu @@ -52,6 +52,28 @@ namespace Curobo return max(0.0f, sphere_length(v1, v2) - v1.w - v2.w); } + __device__ __forceinline__ int3 robust_floor(const float3 f_grid, const float threshold=1e-04) + { + float3 nearest_grid = make_float3(round(f_grid.x), round(f_grid.y), round(f_grid.z)); + + float3 abs_diff = (f_grid - nearest_grid); + + if (abs_diff.x >= threshold) + { + nearest_grid.x = floorf(f_grid.x); + } + if (abs_diff.y >= threshold) + { + nearest_grid.y = floorf(f_grid.y); + } + + if (abs_diff.z >= threshold) + { + nearest_grid.z = floorf(f_grid.z); + } + return make_int3(nearest_grid); + + } #if CHECK_FP8 __device__ __forceinline__ float @@ -487,21 +509,20 @@ namespace Curobo delta = make_float3(pt.x - sphere.x, pt.y - sphere.y, pt.z - sphere.z); distance = length(delta); - - if (!inside) + if (distance == 0.0) + { + delta = -1.0 * make_float3(pt.x, pt.y, pt.z); + } + if (!inside) // outside { distance *= -1.0; } - else + else // inside { delta = -1 * delta; } - if (distance != 0.0) - { - delta = normalize(delta); - - } + delta = normalize(delta); sph_distance = distance + sphere.w; // @@ -685,29 +706,48 @@ float4 &sum_pt) template __device__ __forceinline__ void - compute_voxel_location_params( + compute_voxel_index( const grid_scalar_t *grid_features, const float4& loc_grid_params, const float4& loc_sphere, + int &voxel_idx, int3 &xyz_loc, int3 &xyz_grid, - float &interpolated_distance, - int &voxel_idx) + float &interpolated_distance) { - // convert location to index: can use floor to cast to int. - // to account for negative values, add 0.5 * bounds. 
- const float3 loc_grid = make_float3(loc_grid_params.x, loc_grid_params.y, loc_grid_params.z); + + + const float3 loc_grid = make_float3(loc_grid_params.x, loc_grid_params.y, loc_grid_params.z);// - loc_grid_params.w; const float3 sphere = make_float3(loc_sphere.x, loc_sphere.y, loc_sphere.z); - //xyz_loc = make_int3(floorf((sphere + 0.5 * loc_grid) / loc_grid_params.w)); - const float inv_voxel_size = 1.0/loc_grid_params.w; - //xyz_loc = make_int3(sphere * inv_voxel_size) + make_int3(0.5 * loc_grid * inv_voxel_size); + const float inv_voxel_size = 1.0f / loc_grid_params.w; + + float3 f_grid = (loc_grid) * inv_voxel_size; + + + xyz_grid = robust_floor(f_grid) + 1; + + + xyz_loc = make_int3(((sphere.x + 0.5f * loc_grid.x) * inv_voxel_size), + ((sphere.y + 0.5f * loc_grid.y)* inv_voxel_size), + ((sphere.z + 0.5f * loc_grid.z) * inv_voxel_size)); + + + // check grid bounds: + // 2 to catch numerical precision errors. 1 can be used when exact. + // We need at least 1 as we + // look at neighbouring voxels for finite difference + const int offset = 2; + if (xyz_loc.x >= xyz_grid.x - offset || xyz_loc.y >= xyz_grid.y - offset || xyz_loc.z >= xyz_grid.z - offset + || xyz_loc.x <= offset || xyz_loc.y <= offset || xyz_loc.z <= offset + ) + { + voxel_idx = -1; + return; + } - xyz_loc = make_int3((sphere + 0.5 * loc_grid) * inv_voxel_size); - //xyz_loc = make_int3(sphere / loc_grid_params.w) + make_int3(floorf(0.5 * loc_grid/loc_grid_params.w)); - xyz_grid = make_int3((loc_grid * inv_voxel_size)) + 1; // find next nearest voxel to current point and then do weighted interpolation: voxel_idx = xyz_loc.x * xyz_grid.y * xyz_grid.z + xyz_loc.y * xyz_grid.z + xyz_loc.z; @@ -715,10 +755,6 @@ float4 &sum_pt) // compute interpolation distance between voxel origin and sphere location: get_array_value(grid_features, voxel_idx, interpolated_distance); - if(!INTERPOLATION) - { - interpolated_distance += 0.5 * loc_grid_params.w;//max(0.0, (0.3 * loc_grid_params.w) - loc_sphere.w); - } if(INTERPOLATION) { // @@ -739,41 +775,6 @@ float4 &sum_pt) - } - - template - __device__ __forceinline__ void - compute_voxel_index( - const grid_scalar_t *grid_features, - const float4& loc_grid_params, - const float4& loc_sphere, - int &voxel_idx, - int3 &xyz_loc, - int3 &xyz_grid, - float &interpolated_distance) - { - // check if sphere is out of bounds - // loc_grid_params.x contains bounds - float4 local_bounds = 0.5*loc_grid_params - 2*loc_grid_params.w; - - if (loc_sphere.x <= -1 * (local_bounds.x) || - loc_sphere.x >= (local_bounds.x) || - loc_sphere.y <= -1 * (local_bounds.y) || - loc_sphere.y >= (local_bounds.y) || - loc_sphere.z <= -1 * (local_bounds.z) || - loc_sphere.z >= (local_bounds.z)) - { - voxel_idx = -1; - return; - } - - compute_voxel_location_params(grid_features, loc_grid_params, loc_sphere, xyz_loc, xyz_grid, interpolated_distance, voxel_idx); - // convert location to index: can use floor to cast to int. - // to account for negative values, add 0.5 * bounds. 
- - - - } @@ -979,7 +980,7 @@ float4 &sum_pt) // Load sphere_position input float4 sphere_cache = *(float4 *)&sphere_position[bn_sph_idx * 4]; - if (sphere_cache.w <= 0.0) + if (sphere_cache.w < 0.0) { // write zeros for cost: out_distance[bn_sph_idx] = 0; @@ -1044,7 +1045,7 @@ float4 &sum_pt) // Load sphere_position input float4 sphere_cache = *(float4 *)&sphere_position[bn_sph_idx * 4]; - if (sphere_cache.w <= 0.0) + if (sphere_cache.w < 0.0) { // write zeros for cost: out_distance[bn_sph_idx] = 0; @@ -1173,7 +1174,7 @@ float4 &sum_pt) // Load sphere_position input float4 sphere_cache = *(float4 *)&sphere_position[bn_sph_idx * 4]; - if (sphere_cache.w <= 0.0) + if (sphere_cache.w < 0.0) { // write zeros for cost: out_distance[bn_sph_idx] = 0; @@ -1275,7 +1276,7 @@ float4 &sum_pt) // Load sphere_position input float4 sphere_cache = *(float4 *)&sphere_position[bn_sph_idx * 4]; - if (sphere_cache.w <= 0.0) + if (sphere_cache.w < 0.0) { // write zeros for cost: out_distance[bn_sph_idx] = 0; @@ -1452,7 +1453,7 @@ float4 &sum_pt) // Load sphere_position input float4 sphere_1_cache = *(float4 *)&sphere_position[bhs_idx * 4]; - if (sphere_1_cache.w <= 0.0) + if (sphere_1_cache.w < 0.0) { // write zeros for cost: out_distance[bhs_idx] = 0; @@ -1888,7 +1889,7 @@ float4 &sum_pt) // Load sphere_position input float4 sphere_cache = *(float4 *)&sphere_position[bn_sph_idx * 4]; - if (sphere_cache.w <= 0.0) + if (sphere_cache.w < 0.0) { // write zeros for cost: out_distance[bn_sph_idx] = 0; @@ -2001,7 +2002,7 @@ float4 &sum_pt) float4 sphere_1_cache = *(float4 *)&sphere_position[bhs_idx * 4]; - if (sphere_1_cache.w <= 0.0) + if (sphere_1_cache.w < 0.0) { // write zeros for cost: out_distance[bhs_idx] = 0; @@ -2303,7 +2304,7 @@ float4 &sum_pt) // if h_idx == horizon -1, we just read the same index float4 sphere_1_cache = *(float4 *)&sphere_position[bhs_idx * 4]; - if (sphere_1_cache.w <= 0.0) + if (sphere_1_cache.w < 0.0) { out_distance[b_addrs + h_idx * nspheres + sph_idx] = 0.0; return; diff --git a/src/curobo/curobolib/geom.py b/src/curobo/curobolib/geom.py index a675069a..b237cd0b 100644 --- a/src/curobo/curobolib/geom.py +++ b/src/curobo/curobolib/geom.py @@ -157,6 +157,7 @@ def get_pose_distance( offset_waypoint, offset_tstep_fraction, batch_pose_idx, + project_distance, batch_size, horizon, mode=1, @@ -164,7 +165,6 @@ def get_pose_distance( write_grad=False, write_distance=False, use_metric=False, - project_distance=True, ): if batch_pose_idx.shape[0] != batch_size: raise ValueError("Index buffer size is different from batch size") @@ -188,6 +188,7 @@ def get_pose_distance( offset_waypoint, offset_tstep_fraction, batch_pose_idx, + project_distance, batch_size, horizon, mode, @@ -195,7 +196,6 @@ def get_pose_distance( write_grad, write_distance, use_metric, - project_distance, ) out_distance = r[0] @@ -272,6 +272,7 @@ def forward( offset_waypoint, offset_tstep_fraction, batch_pose_idx, + project_distance, out_distance, out_position_distance, out_rotation_distance, @@ -284,8 +285,7 @@ def forward( horizon, mode, # =PoseErrorType.BATCH_GOAL.value, num_goals, - use_metric, # =False, - project_distance, # =True, + use_metric, ): # out_distance = current_position[..., 0].detach().clone() * 0.0 # out_position_distance = out_distance.detach().clone() @@ -322,6 +322,7 @@ def forward( offset_waypoint, offset_tstep_fraction, batch_pose_idx, + project_distance, batch_size, horizon, mode, @@ -329,7 +330,6 @@ def forward( current_position.requires_grad, True, use_metric, - project_distance, ) 
ctx.save_for_backward(out_p_vec, out_r_vec, weight, out_p_grad, out_q_grad) return out_distance, out_position_distance, out_rotation_distance, out_idx # .view(-1,1) @@ -406,6 +406,7 @@ def forward( offset_waypoint, offset_tstep_fraction, batch_pose_idx, + project_distance, out_distance, out_position_distance, out_rotation_distance, @@ -419,7 +420,6 @@ def forward( mode, num_goals, use_metric, - project_distance, return_loss, ): """Compute error in pose @@ -494,6 +494,7 @@ def forward( offset_waypoint, offset_tstep_fraction, batch_pose_idx, + project_distance, batch_size, horizon, mode, @@ -501,7 +502,6 @@ def forward( current_position.requires_grad, False, use_metric, - project_distance, ) ctx.save_for_backward(out_p_vec, out_r_vec) return out_distance diff --git a/src/curobo/curobolib/kinematics.py b/src/curobo/curobolib/kinematics.py index 1c7927a2..e385b539 100644 --- a/src/curobo/curobolib/kinematics.py +++ b/src/curobo/curobolib/kinematics.py @@ -113,7 +113,6 @@ def forward( @staticmethod def backward(ctx, grad_out_link_pos, grad_out_link_quat, grad_out_spheres): grad_joint = None - if ctx.needs_input_grad[4]: ( joint_seq, @@ -193,10 +192,14 @@ def _call_backward_cuda( b_size = b_shape[0] n_spheres = robot_sphere_out.shape[1] n_joints = angle.shape[-1] - if grad_out.is_contiguous(): - grad_out = grad_out.view(-1) - else: - grad_out = grad_out.reshape(-1) + grad_out = grad_out.contiguous() + link_pos_out = link_pos_out.contiguous() + link_quat_out = link_quat_out.contiguous() + # if grad_out.is_contiguous(): + # grad_out = grad_out.view(-1) + # else: + # grad_out = grad_out.reshape(-1) + r = kinematics_fused_cu.backward( grad_out, link_pos_out, diff --git a/src/curobo/curobolib/ls.py b/src/curobo/curobolib/ls.py index 4743ef31..af8dc6fc 100644 --- a/src/curobo/curobolib/ls.py +++ b/src/curobo/curobolib/ls.py @@ -58,7 +58,6 @@ def wolfe_line_search( l1 = g_x.shape[1] l2 = g_x.shape[2] r = line_search_cu.line_search( - # m_idx, best_x, best_c, best_grad, @@ -76,7 +75,6 @@ def wolfe_line_search( l2, batchsize, ) - # print("batchsize:" + str(batchsize)) return (r[0], r[1], r[2]) diff --git a/src/curobo/geom/sdf/world.py b/src/curobo/geom/sdf/world.py index 5a43d85d..a7065520 100644 --- a/src/curobo/geom/sdf/world.py +++ b/src/curobo/geom/sdf/world.py @@ -633,6 +633,7 @@ def get_mesh_in_bounding_box( self, cuboid: Cuboid = Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[1, 1, 1]), voxel_size: float = 0.02, + run_marching_cubes: bool = True, ) -> Mesh: """Get a mesh representation of the world obstacles based on occupancy in a bounding box. @@ -642,19 +643,31 @@ def get_mesh_in_bounding_box( Args: cuboid: Bounding box to get the mesh representation. voxel_size: Size of the voxels in meters. + run_marching_cubes: Runs marching cubes over occupied voxels to generate a mesh. If + set to False, then all occupied voxels are merged into a mesh and returned. Returns: Mesh representation of the world obstacles in the bounding box. 
""" voxels = self.get_voxels_in_bounding_box(cuboid, voxel_size) - # voxels = voxels.cpu().numpy() - # cuboids = [Cuboid(name="c_"+str(x), pose=[voxels[x,0],voxels[x,1],voxels[x,2], 1,0,0,0], - # dims=[voxel_size, voxel_size, voxel_size]) for x in range(voxels.shape[0])] - # mesh = WorldConfig(cuboid=cuboids).get_mesh_world(True).mesh[0] - mesh = Mesh.from_pointcloud( - voxels[:, :3].detach().cpu().numpy(), - pitch=voxel_size * 1.1, - ) + voxels = voxels.cpu().numpy() + + if run_marching_cubes: + mesh = Mesh.from_pointcloud( + voxels[:, :3].detach().cpu().numpy(), + pitch=voxel_size * 1.1, + ) + else: + cuboids = [ + Cuboid( + name="c_" + str(x), + pose=[voxels[x, 0], voxels[x, 1], voxels[x, 2], 1, 0, 0, 0], + dims=[voxel_size, voxel_size, voxel_size], + ) + for x in range(voxels.shape[0]) + ] + mesh = WorldConfig(cuboid=cuboids).get_mesh_world(True).mesh[0] + return mesh def get_obstacle_names(self, env_idx: int = 0) -> List[str]: diff --git a/src/curobo/geom/sdf/world_voxel.py b/src/curobo/geom/sdf/world_voxel.py index 11d55545..e6b3411c 100644 --- a/src/curobo/geom/sdf/world_voxel.py +++ b/src/curobo/geom/sdf/world_voxel.py @@ -358,12 +358,21 @@ def update_voxel_data(self, new_voxel: VoxelGrid, env_idx: int = 0): env_idx: Environment index to update voxel grid in. """ obs_idx = self.get_voxel_idx(new_voxel.name, env_idx) - self._voxel_tensor_list[3][env_idx, obs_idx, :, :] = new_voxel.feature_tensor.view( - new_voxel.feature_tensor.shape[0], -1 - ).to(dtype=self._voxel_tensor_list[3].dtype) - self._voxel_tensor_list[0][env_idx, obs_idx, :3] = self.tensor_args.to_device( - new_voxel.dims - ) + + feature_tensor = new_voxel.feature_tensor.view(new_voxel.feature_tensor.shape[0], -1) + if ( + feature_tensor.shape[0] != self._voxel_tensor_list[3][env_idx, obs_idx, :, :].shape[0] + or feature_tensor.shape[1] + != self._voxel_tensor_list[3][env_idx, obs_idx, :, :].shape[1] + ): + log_error( + "Feature tensor shape mismatch, existing shape: " + + str(self._voxel_tensor_list[3][env_idx, obs_idx, :, :].shape) + + " New shape: " + + str(feature_tensor.shape) + ) + self._voxel_tensor_list[3][env_idx, obs_idx, :, :].copy_(feature_tensor) + self._voxel_tensor_list[0][env_idx, obs_idx, :3].copy_(torch.as_tensor(new_voxel.dims)) self._voxel_tensor_list[0][env_idx, obs_idx, 3] = new_voxel.voxel_size self._voxel_tensor_list[1][env_idx, obs_idx, :7] = ( Pose.from_list(new_voxel.pose, self.tensor_args).inverse().get_pose_vector() @@ -876,14 +885,19 @@ def clear_cache(self): self._env_n_voxels[:] = 0 super().clear_cache() - def get_voxel_grid_shape(self, env_idx: int = 0, obs_idx: int = 0) -> torch.Size: + def get_voxel_grid_shape( + self, env_idx: int = 0, obs_idx: int = 0, name: Optional[str] = None + ) -> torch.Size: """Get dimensions of the voxel grid. Args: env_idx: Environment index. obs_idx: Obstacle index. + name: Name of obstacle. When provided, obs_idx is ignored. Returns: Shape of the voxel grid. 
""" + if name is not None: + obs_idx = self.get_voxel_idx(name, env_idx) return self._voxel_tensor_list[3][env_idx, obs_idx].shape diff --git a/src/curobo/geom/types.py b/src/curobo/geom/types.py index 0679fffc..e9ad9fa3 100644 --- a/src/curobo/geom/types.py +++ b/src/curobo/geom/types.py @@ -13,7 +13,6 @@ from __future__ import annotations # Standard Library -import math from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Sequence, Tuple, Union @@ -28,6 +27,7 @@ from curobo.types.base import TensorDeviceType from curobo.types.camera import CameraObservation from curobo.types.math import Pose +from curobo.util.helpers import robust_floor from curobo.util.logger import log_error, log_warn from curobo.util_file import get_assets_path, join_path @@ -723,12 +723,16 @@ def get_grid_shape(self) -> Tuple[List[int], List[float], List[float]]: """Get shape of voxel grid.""" bounds = self.dims + + grid_shape = [bounds[0], bounds[1], bounds[2]] + + inv_voxel_size = 1.0 / self.voxel_size + + grid_shape = [1 + robust_floor(x * inv_voxel_size) for x in grid_shape] + low = [-bounds[0] / 2, -bounds[1] / 2, -bounds[2] / 2] high = [bounds[0] / 2, bounds[1] / 2, bounds[2] / 2] - grid_shape = [ - 1 + int(high[i] / self.voxel_size) - (int(low[i] / self.voxel_size)) - for i in range(len(low)) - ] + return grid_shape, low, high def create_xyzr_tensor( @@ -745,10 +749,19 @@ def create_xyzr_tensor( """ trange, low, high = self.get_grid_shape() - - x = torch.linspace(low[0], high[0], trange[0], device=tensor_args.device) - y = torch.linspace(low[1], high[1], trange[1], device=tensor_args.device) - z = torch.linspace(low[2], high[2], trange[2], device=tensor_args.device) + inv_voxel_size = 1.0 / self.voxel_size + x = torch.linspace(1, trange[0], trange[0], device=tensor_args.device) - round( + (0.5 * self.dims[0]) * inv_voxel_size + ) + y = torch.linspace(1, trange[1], trange[1], device=tensor_args.device) - round( + (0.5 * self.dims[1]) * inv_voxel_size + ) + z = torch.linspace(1, trange[2], trange[2], device=tensor_args.device) - round( + (0.5 * self.dims[2]) * inv_voxel_size + ) + x = x * self.voxel_size - 0.5 * self.voxel_size + y = y * self.voxel_size - 0.5 * self.voxel_size + z = z * self.voxel_size - 0.5 * self.voxel_size w, l, h = x.shape[0], y.shape[0], z.shape[0] xyz = ( torch.stack(torch.meshgrid(x, y, z, indexing="ij")).permute((1, 2, 3, 0)).reshape(-1, 3) @@ -757,7 +770,7 @@ def create_xyzr_tensor( if transform_to_origin: pose = Pose.from_list(self.pose, tensor_args=tensor_args) xyz = pose.transform_points(xyz.contiguous()) - r = torch.zeros_like(xyz[:, 0:1]) + (self.voxel_size * 0.5) + r = torch.zeros_like(xyz[:, 0:1]) xyzr = torch.cat([xyz, r], dim=1) return xyzr diff --git a/src/curobo/graph/graph_nx.py b/src/curobo/graph/graph_nx.py index 28cb3d11..84ddb70a 100644 --- a/src/curobo/graph/graph_nx.py +++ b/src/curobo/graph/graph_nx.py @@ -10,19 +10,10 @@ # -# Standard Library -import random - # Third Party import networkx as nx -import numpy as np import torch -# This is needed to get deterministic results from networkx. -# Note: it has to be set in global space. 
-np.random.seed(2) -random.seed(2) - class NetworkxGraph(object): def __init__(self): @@ -63,7 +54,11 @@ def get_edges(self, attribue="weight"): def path_exists(self, start_node_idx, goal_node_idx): self.update_graph() - return nx.has_path(self.graph, start_node_idx, goal_node_idx) + # check if nodes exist in the graph + if self.graph.has_node(start_node_idx) and self.graph.has_node(goal_node_idx): + return nx.has_path(self.graph, start_node_idx, goal_node_idx) + else: + return False def get_shortest_path(self, start_node_idx, goal_node_idx, return_length=False): self.update_graph() diff --git a/src/curobo/opt/newton/newton_base.py b/src/curobo/opt/newton/newton_base.py index 36e7a310..9498bfea 100644 --- a/src/curobo/opt/newton/newton_base.py +++ b/src/curobo/opt/newton/newton_base.py @@ -603,13 +603,30 @@ def _wolfe_search_tail_jit(c, g_x, x_set, m, d_opt: int): @get_torch_jit_decorator() -def scale_action(dx, action_step_max): +def scale_action_old(dx, action_step_max): scale_value = torch.max(torch.abs(dx) / action_step_max, dim=-1, keepdim=True)[0] scale_value = torch.clamp(scale_value, 1.0) dx_scaled = dx / scale_value return dx_scaled +@get_torch_jit_decorator() +def scale_action(dx, action_step_max): + + # get largest dx scaled by bounds across optimization variables + scale_value = torch.max(torch.abs(dx) / action_step_max, dim=-1, keepdim=True)[0] + + # scale dx to bring all dx within bounds: + # only perfom for dx that are greater than 1: + + new_scale = torch.where(scale_value <= 1.0, 1.0, scale_value) + dx_scaled = dx / new_scale + + # scale_value = torch.clamp(scale_value, 1.0) + # dx_scaled = dx / scale_value + return dx_scaled + + @get_torch_jit_decorator() def check_convergence( best_iteration: torch.Tensor, current_iteration: torch.Tensor, last_best: int diff --git a/src/curobo/rollout/arm_reacher.py b/src/curobo/rollout/arm_reacher.py index ca06f795..379b91ae 100644 --- a/src/curobo/rollout/arm_reacher.py +++ b/src/curobo/rollout/arm_reacher.py @@ -435,7 +435,14 @@ def enable_cspace_cost(self, enable: bool = True): self.dist_cost.disable_cost() self.cspace_convergence.disable_cost() - def get_pose_costs(self, include_link_pose: bool = False, include_convergence: bool = True): + def get_pose_costs( + self, + include_link_pose: bool = False, + include_convergence: bool = True, + only_convergence: bool = False, + ): + if only_convergence: + return [self.pose_convergence] pose_costs = [self.goal_cost] if include_convergence: pose_costs += [self.pose_convergence] @@ -447,33 +454,15 @@ def update_pose_cost_metric( self, metric: PoseCostMetric, ): - pose_costs = self.get_pose_costs(include_link_pose=metric.include_link_pose) - if metric.hold_partial_pose: - if metric.hold_vec_weight is None: - log_error("hold_vec_weight is required") - [x.hold_partial_pose(metric.hold_vec_weight) for x in pose_costs] - if metric.release_partial_pose: - [x.release_partial_pose() for x in pose_costs] - if metric.reach_partial_pose: - if metric.reach_vec_weight is None: - log_error("reach_vec_weight is required") - [x.reach_partial_pose(metric.reach_vec_weight) for x in pose_costs] - if metric.reach_full_pose: - [x.reach_full_pose() for x in pose_costs] - - pose_costs = self.get_pose_costs(include_convergence=False) - if metric.remove_offset_waypoint: - [x.remove_offset_waypoint() for x in pose_costs] - - if metric.offset_position is not None or metric.offset_rotation is not None: - [ - x.update_offset_waypoint( - offset_position=metric.offset_position, - offset_rotation=metric.offset_rotation, - 
offset_tstep_fraction=metric.offset_tstep_fraction, - ) - for x in pose_costs - ] + pose_costs = self.get_pose_costs( + include_link_pose=metric.include_link_pose, include_convergence=False + ) + for p in pose_costs: + p.update_metric(metric, update_offset_waypoint=True) + + pose_costs = self.get_pose_costs(only_convergence=True) + for p in pose_costs: + p.update_metric(metric, update_offset_waypoint=False) @get_torch_jit_decorator() diff --git a/src/curobo/rollout/cost/pose_cost.py b/src/curobo/rollout/cost/pose_cost.py index 46390bdb..b2e32007 100644 --- a/src/curobo/rollout/cost/pose_cost.py +++ b/src/curobo/rollout/cost/pose_cost.py @@ -86,6 +86,7 @@ class PoseCostMetric: offset_tstep_fraction: float = -1.0 remove_offset_waypoint: bool = False include_link_pose: bool = False + project_to_goal_frame: Optional[bool] = None def clone(self): @@ -102,6 +103,8 @@ def clone(self): offset_rotation=None if self.offset_rotation is None else self.offset_rotation.clone(), offset_tstep_fraction=self.offset_tstep_fraction, remove_offset_waypoint=self.remove_offset_waypoint, + include_link_pose=self.include_link_pose, + project_to_goal_frame=self.project_to_goal_frame, ) @classmethod @@ -110,6 +113,7 @@ def create_grasp_approach_metric( offset_position: float = 0.1, linear_axis: int = 2, tstep_fraction: float = 0.8, + project_to_goal_frame: Optional[bool] = None, tensor_args: TensorDeviceType = TensorDeviceType(), ) -> PoseCostMetric: """Enables moving to a pregrasp and then locked orientation movement to final grasp. @@ -121,6 +125,8 @@ def create_grasp_approach_metric( offset_position: offset in meters. linear_axis: specifies the x y or z axis. tstep_fraction: specifies the timestep fraction to start activating this constraint. + project_to_goal_frame: compute distance w.r.t. to goal frame instead of robot base + frame. If None, it will use value set in PoseCostConfig. tensor_args: cuda device. 
Returns: @@ -150,12 +156,17 @@ class PoseCost(CostBase, PoseCostConfig): def __init__(self, config: PoseCostConfig): PoseCostConfig.__init__(self, **vars(config)) CostBase.__init__(self) + self.project_distance_tensor = torch.tensor( + [self.project_distance], + device=self.tensor_args.device, + dtype=torch.uint8, + ) self.rot_weight = self.vec_weight[0:3] self.pos_weight = self.vec_weight[3:6] self._vec_convergence = self.tensor_args.to_device(self.vec_convergence) self._batch_size = 0 - def update_metric(self, metric: PoseCostMetric): + def update_metric(self, metric: PoseCostMetric, update_offset_waypoint: bool = True): if metric.hold_partial_pose: if metric.hold_vec_weight is None: log_error("hold_vec_weight is required") @@ -168,19 +179,22 @@ def update_metric(self, metric: PoseCostMetric): self.reach_partial_pose(metric.reach_vec_weight) if metric.reach_full_pose: self.reach_full_pose() - - if metric.remove_offset_waypoint: - self.remove_offset_waypoint() - - if metric.offset_position is not None or metric.offset_rotation is not None: - self.update_offset_waypoint( - offset_position=self.offset_position, - offset_rotation=self.offset_rotation, - offset_tstep_fraction=self.offset_tstep_fraction, - ) + if metric.project_to_goal_frame is not None: + self.project_distance_tensor[:] = metric.project_to_goal_frame + else: + self.project_distance_tensor[:] = self.project_distance + if update_offset_waypoint: + if metric.remove_offset_waypoint: + self.remove_offset_waypoint() + + if metric.offset_position is not None or metric.offset_rotation is not None: + self.update_offset_waypoint( + offset_position=metric.offset_position, + offset_rotation=metric.offset_rotation, + offset_tstep_fraction=metric.offset_tstep_fraction, + ) def hold_partial_pose(self, run_vec_weight: torch.Tensor): - self.run_vec_weight.copy_(run_vec_weight) def release_partial_pose(self): @@ -391,6 +405,7 @@ def forward_out_distance( self.offset_waypoint, self.offset_tstep_fraction, goal.batch_pose_idx, + self.project_distance_tensor, self.out_distance, self.out_position_distance, self.out_rotation_distance, @@ -404,7 +419,6 @@ def forward_out_distance( self.cost_type.value, num_goals, self.use_metric, - self.project_distance, ) # print(self.out_idx.shape, self.out_idx[:,-1]) # print(goal.batch_pose_idx.shape) @@ -444,6 +458,7 @@ def forward(self, ee_pos_batch, ee_rot_batch, goal: Goal, link_name: Optional[st self.offset_waypoint, self.offset_tstep_fraction, goal.batch_pose_idx, + self.project_distance_tensor, self.out_distance, self.out_position_distance, self.out_rotation_distance, @@ -457,7 +472,6 @@ def forward(self, ee_pos_batch, ee_rot_batch, goal: Goal, link_name: Optional[st self.cost_type.value, num_goals, self.use_metric, - self.project_distance, self.return_loss, ) @@ -498,6 +512,7 @@ def forward_pose( self.offset_waypoint, self.offset_tstep_fraction, batch_pose_idx, + self.project_distance_tensor, self.out_distance, self.out_position_distance, self.out_rotation_distance, @@ -511,7 +526,6 @@ def forward_pose( self.cost_type.value, num_goals, self.use_metric, - self.project_distance, self.return_loss, ) return distance diff --git a/src/curobo/rollout/dynamics_model/kinematic_model.py b/src/curobo/rollout/dynamics_model/kinematic_model.py index 9c9b6c33..f8a1bacf 100644 --- a/src/curobo/rollout/dynamics_model/kinematic_model.py +++ b/src/curobo/rollout/dynamics_model/kinematic_model.py @@ -449,7 +449,14 @@ def forward( state_seq = self.state_seq curr_batch_size = self.batch_size num_traj_points = self.horizon - + 
if not state_seq.position.is_contiguous(): + state_seq.position = state_seq.position.contiguous() + if not state_seq.velocity.is_contiguous(): + state_seq.velocity = state_seq.velocity.contiguous() + if not state_seq.acceleration.is_contiguous(): + state_seq.acceleration = state_seq.acceleration.contiguous() + if not state_seq.jerk.is_contiguous(): + state_seq.jerk = state_seq.jerk.contiguous() with profiler.record_function("tensor_step"): # forward step with step matrix: state_seq = self.tensor_step(start_state_shaped, act_seq, state_seq, start_state_idx) diff --git a/src/curobo/types/camera.py b/src/curobo/types/camera.py index c2387229..007daac9 100644 --- a/src/curobo/types/camera.py +++ b/src/curobo/types/camera.py @@ -121,7 +121,7 @@ def get_pointcloud(self, project_to_pose: bool = False): point_cloud = project_depth_using_rays(depth_image, self.projection_rays) if project_to_pose and self.pose is not None: - point_cloud = self.pose.batch_transform(point_cloud) + point_cloud = self.pose.batch_transform_points(point_cloud) return point_cloud diff --git a/src/curobo/types/math.py b/src/curobo/types/math.py index 97f11dd4..a2c84599 100644 --- a/src/curobo/types/math.py +++ b/src/curobo/types/math.py @@ -507,37 +507,25 @@ def angular_distance_phi3(goal_quat, current_quat): class OrientationError(Function): @staticmethod def geodesic_distance(goal_quat, current_quat, quat_res): - conjugate_quat = current_quat.clone() - conjugate_quat[..., 1:] *= -1.0 - quat_res = quat_multiply(goal_quat, conjugate_quat, quat_res) - - quat_res = -1.0 * quat_res * torch.sign(quat_res[..., 0]).unsqueeze(-1) - quat_res[..., 0] = 0.0 - # quat_res = conjugate_quat * 0.0 - return quat_res + quat_grad, rot_error = geodesic_distance(goal_quat, current_quat, quat_res) + return quat_grad, rot_error @staticmethod def forward(ctx, goal_quat, current_quat, quat_res): - quat_res = OrientationError.geodesic_distance(goal_quat, current_quat, quat_res) - rot_error = torch.norm(quat_res, dim=-1, keepdim=True) - ctx.save_for_backward(quat_res, rot_error) + quat_grad, rot_error = OrientationError.geodesic_distance(goal_quat, current_quat, quat_res) + ctx.save_for_backward(quat_grad) return rot_error @staticmethod def backward(ctx, grad_out): - grad_mul = None - if ctx.needs_input_grad[1]: - (quat_error, r_err) = ctx.saved_tensors - scale = 1 / r_err - scale = torch.nan_to_num(scale, 0, 0, 0) + grad_mul = grad_mul1 = None + (quat_grad,) = ctx.saved_tensors - grad_mul = grad_out * scale * quat_error - # print(grad_out.shape) - # if grad_out.shape[0] == 6: - # #print(grad_out.view(-1)) - # #print(grad_mul.view(-1)[-6:]) - # #exit() - return None, grad_mul, None + if ctx.needs_input_grad[1]: + grad_mul = grad_out * quat_grad + if ctx.needs_input_grad[0]: + grad_mul1 = -1.0 * grad_out * quat_grad + return grad_mul1, grad_mul, None @get_torch_jit_decorator() @@ -549,3 +537,19 @@ def normalize_quaternion(in_quaternion: torch.Tensor) -> torch.Tensor: # normalize quaternion in_q = k2 * in_quaternion return in_q + + +@get_torch_jit_decorator() +def geodesic_distance(goal_quat, current_quat, quat_res): + conjugate_quat = current_quat.detach().clone() + conjugate_quat[..., 1:] *= -1.0 + quat_res = quat_multiply(goal_quat, conjugate_quat, quat_res) + sign = torch.sign(quat_res[..., 0]) + sign = torch.where(sign == 0, 1.0, sign) + quat_res = -1.0 * quat_res * sign.unsqueeze(-1) + quat_res[..., 0] = 0.0 + rot_error = torch.norm(quat_res, dim=-1, keepdim=True) + scale = 1.0 / rot_error + scale = torch.nan_to_num(scale, 0.0, 0.0, 0.0) + 
quat_res = quat_res * scale + return quat_res, rot_error diff --git a/src/curobo/util/helpers.py b/src/curobo/util/helpers.py index 6ff31b93..4eb1a12a 100644 --- a/src/curobo/util/helpers.py +++ b/src/curobo/util/helpers.py @@ -9,6 +9,7 @@ # its affiliates is strictly prohibited. # # Standard Library +import math from collections import defaultdict from typing import List @@ -27,3 +28,11 @@ def list_idx_if_not_none(d_list: List, idx: int): else: idx_list.append(None) return idx_list + + +def robust_floor(x: float, threshold: float = 1e-04) -> int: + nearest_int = round(x) + if abs(x - nearest_int) < threshold: + return nearest_int + else: + return int(math.floor(x)) diff --git a/src/curobo/util/sample_lib.py b/src/curobo/util/sample_lib.py index 9c2140f0..4397e4ba 100644 --- a/src/curobo/util/sample_lib.py +++ b/src/curobo/util/sample_lib.py @@ -137,7 +137,7 @@ def get_samples(self, sample_shape, base_seed=None, filter_smooth=False, **kwarg return self.samples -def bspline(c_arr, t_arr=None, n=100, degree=3): +def bspline(c_arr: torch.Tensor, t_arr=None, n=100, degree=3): sample_device = c_arr.device sample_dtype = c_arr.dtype cv = c_arr.cpu().numpy() diff --git a/src/curobo/util/torch_utils.py b/src/curobo/util/torch_utils.py index 60414120..ae8790c6 100644 --- a/src/curobo/util/torch_utils.py +++ b/src/curobo/util/torch_utils.py @@ -178,3 +178,9 @@ def get_cache_fn_decorator(maxsize: Optional[int] = None): def empty_decorator(function): return function + + +@get_torch_jit_decorator() +def round_away_from_zero(x: torch.Tensor) -> torch.Tensor: + y = torch.trunc(x + 0.5 * torch.sign(x)) + return y diff --git a/src/curobo/util/trajectory.py b/src/curobo/util/trajectory.py index 2fe61aa7..c4423eac 100644 --- a/src/curobo/util/trajectory.py +++ b/src/curobo/util/trajectory.py @@ -99,7 +99,7 @@ def get_linear_traj( return trajectory -def get_smooth_trajectory(raw_traj, degree=5): +def get_smooth_trajectory(raw_traj: torch.Tensor, degree: int = 5): cpu_traj = raw_traj.cpu() smooth_traj = torch.zeros_like(cpu_traj) @@ -108,11 +108,10 @@ def get_smooth_trajectory(raw_traj, degree=5): return smooth_traj.to(dtype=raw_traj.dtype, device=raw_traj.device) -def get_spline_interpolated_trajectory(raw_traj, des_horizon, degree=5): +def get_spline_interpolated_trajectory(raw_traj: torch.Tensor, des_horizon: int, degree: int = 5): retimed_traj = torch.zeros((des_horizon, raw_traj.shape[-1])) tensor_args = TensorDeviceType(device=raw_traj.device, dtype=raw_traj.dtype) - cpu_traj = raw_traj.cpu().numpy() - + cpu_traj = raw_traj.cpu() for i in range(cpu_traj.shape[-1]): retimed_traj[:, i] = bspline(cpu_traj[:, i], n=des_horizon, degree=degree) retimed_traj = retimed_traj.to(**(tensor_args.as_torch_dict())) @@ -179,7 +178,7 @@ def get_batch_interpolated_trajectory( opt_dt[:] = raw_dt # traj_steps contains the tsteps for each trajectory if steps_max <= 0: - log_error("Steps max is less than 0") + log_error("Steps max is less than 1, with a value: " + str(steps_max)) if out_traj_state is not None and out_traj_state.position.shape[1] < steps_max: log_warn( @@ -610,5 +609,7 @@ def calculate_tsteps( ) if not optimize_dt: opt_dt[:] = raw_dt + # check for nan: + opt_dt = torch.nan_to_num(opt_dt, nan=min_dt) traj_steps, steps_max = calculate_traj_steps(opt_dt, interpolation_dt, horizon) return traj_steps, steps_max, opt_dt diff --git a/src/curobo/util/xrdf_utils.py b/src/curobo/util/xrdf_utils.py index fc5d0c2a..f001be7f 100644 --- a/src/curobo/util/xrdf_utils.py +++ b/src/curobo/util/xrdf_utils.py @@ -8,6 +8,7 @@ # 
its affiliates is strictly prohibited. # Standard Library +from copy import deepcopy from typing import Any, Dict, Optional # CuRobo @@ -17,9 +18,13 @@ from curobo.util_file import load_yaml -def return_value_if_exists(input_dict: Dict, key: str, suffix: str = "xrdf") -> Any: +def return_value_if_exists( + input_dict: Dict, key: str, suffix: str = "xrdf", raise_error: bool = True +) -> Any: if key not in input_dict: - log_error(key + " key not found in " + suffix) + if raise_error: + log_error(key + " key not found in " + suffix) + return None return input_dict[key] @@ -42,7 +47,6 @@ def convert_xrdf_to_curobo( if return_value_if_exists(input_xrdf_dict, "format") != "xrdf": log_error("format is not xrdf") - raise ValueError("format is not xrdf") if return_value_if_exists(input_xrdf_dict, "format_version") > 1.0: log_warn("format_version is greater than 1.0") @@ -63,7 +67,11 @@ def convert_xrdf_to_curobo( coll_spheres = return_value_if_exists(input_xrdf_dict["geometry"][coll_name], "spheres") output_dict["collision_spheres"] = coll_spheres - buffer_distance = return_value_if_exists(input_xrdf_dict["collision"], "buffer_distance") + buffer_distance = return_value_if_exists( + input_xrdf_dict["collision"], "buffer_distance", raise_error=False + ) + if buffer_distance is None: + buffer_distance = 0.0 output_dict["collision_sphere_buffer"] = buffer_distance output_dict["collision_link_names"] = list(coll_spheres.keys()) @@ -82,8 +90,10 @@ def convert_xrdf_to_curobo( self_collision_buffer = return_value_if_exists( input_xrdf_dict["self_collision"], "buffer_distance", + raise_error=False, ) - + if self_collision_buffer is None: + self_collision_buffer = {} output_dict["self_collision_ignore"] = self_collision_ignore output_dict["self_collision_buffer"] = self_collision_buffer else: @@ -92,10 +102,10 @@ def convert_xrdf_to_curobo( log_warn("collision key not found in xrdf, collision avoidance is disabled") tool_frames = return_value_if_exists(input_xrdf_dict, "tool_frames") + output_dict["ee_link"] = tool_frames[0] - output_dict["link_names"] = None if len(tool_frames) > 1: - output_dict["link_names"] = input_xrdf_dict["tool_frames"] + output_dict["link_names"] = deepcopy(tool_frames) # cspace: cspace_dict = return_value_if_exists(input_xrdf_dict, "cspace") diff --git a/src/curobo/util_file.py b/src/curobo/util_file.py index d366f31b..eb62e5ef 100644 --- a/src/curobo/util_file.py +++ b/src/curobo/util_file.py @@ -11,17 +11,33 @@ """Contains helper functions for interacting with file systems.""" # Standard Library import os +import re import shutil import sys from typing import Any, Dict, List, Union # Third Party import yaml -from yaml import Loader +from yaml import SafeLoader as Loader # CuRobo from curobo.util.logger import log_warn +Loader.add_implicit_resolver( + "tag:yaml.org,2002:float", + re.compile( + """^(?: + [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? + |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) + |\\.[0-9_]+(?:[eE][-+][0-9]+)? 
+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* + |[-+]?\\.(?:inf|Inf|INF) + |\\.(?:nan|NaN|NAN))$""", + re.X, + ), + list("-+0123456789."), +) + # get paths def get_module_path() -> str: diff --git a/src/curobo/wrap/reacher/motion_gen.py b/src/curobo/wrap/reacher/motion_gen.py index be15672f..07a7cfdf 100644 --- a/src/curobo/wrap/reacher/motion_gen.py +++ b/src/curobo/wrap/reacher/motion_gen.py @@ -242,6 +242,7 @@ def load_from_robot_config( ik_seed: int = 1531, graph_seed: int = 1531, high_precision: bool = False, + use_cuda_graph_trajopt_metrics: bool = False, ): """Create a motion generation configuration from robot and world configuration. @@ -473,6 +474,10 @@ def load_from_robot_config( the number of iterations for optimization solvers and reduce the thresholds for position to 1mm and rotation to 0.025. Default of False is recommended for most cases as standard motion generation settings reach within 0.5mm on most problems. + use_cuda_graph_trajopt_metrics: Flag to enable cuda_graph when evaluating interpolated + trajectories after trajectory optimization. If interpolation_buffer is smaller + than interpolated trajectory, then the buffers will be re-created. This can cause + existing cuda graph to be invalid. Returns: MotionGenConfig: Instance of motion generation configuration. @@ -722,6 +727,7 @@ def load_from_robot_config( minimize_jerk=minimize_jerk, optimize_dt=optimize_dt, project_pose_to_goal_frame=project_pose_to_goal_frame, + use_cuda_graph_metrics=use_cuda_graph_trajopt_metrics, ) trajopt_solver = TrajOptSolver(trajopt_cfg) @@ -763,6 +769,7 @@ def load_from_robot_config( filter_robot_command=filter_robot_command, optimize_dt=optimize_dt, num_seeds=num_trajopt_noisy_seeds, + use_cuda_graph_metrics=use_cuda_graph_trajopt_metrics, ) js_trajopt_solver = TrajOptSolver(js_trajopt_cfg) @@ -805,6 +812,7 @@ def load_from_robot_config( filter_robot_command=filter_robot_command, optimize_dt=optimize_dt, project_pose_to_goal_frame=project_pose_to_goal_frame, + use_cuda_graph_metrics=use_cuda_graph_trajopt_metrics, ) finetune_trajopt_solver = TrajOptSolver(finetune_trajopt_cfg) @@ -847,6 +855,7 @@ def load_from_robot_config( filter_robot_command=filter_robot_command, optimize_dt=optimize_dt, num_seeds=num_trajopt_noisy_seeds, + use_cuda_graph_metrics=use_cuda_graph_trajopt_metrics, ) finetune_js_trajopt_solver = TrajOptSolver(finetune_js_trajopt_cfg) @@ -1379,6 +1388,26 @@ def _check_none_and_copy_idx( return current_tensor +@dataclass +class GraspPlanResult: + success: Optional[torch.Tensor] = None + grasp_trajectory: Optional[JointState] = None + grasp_trajectory_dt: Optional[torch.Tensor] = None + grasp_interpolated_trajectory: Optional[JointState] = None + grasp_interpolation_dt: Optional[torch.Tensor] = None + retract_trajectory: Optional[JointState] = None + retract_trajectory_dt: Optional[torch.Tensor] = None + retract_interpolated_trajectory: Optional[JointState] = None + retract_interpolation_dt: Optional[torch.Tensor] = None + approach_result: Optional[MotionGenResult] = None + grasp_result: Optional[MotionGenResult] = None + retract_result: Optional[MotionGenResult] = None + status: Optional[str] = None + goalset_result: Optional[MotionGenResult] = None + planning_time: float = 0.0 + goalset_index: Optional[torch.Tensor] = None + + class MotionGen(MotionGenConfig): """Motion generation wrapper for generating collision-free trajectories. @@ -2167,10 +2196,16 @@ def update_pose_cost_metric( Returns: bool: True if the constraint can be added, False otherwise. 
""" + + rollouts = self.get_all_pose_rollout_instances() + # check if constraint is valid: if metric.hold_partial_pose and metric.offset_tstep_fraction < 0.0: start_pose = self.compute_kinematics(start_state).ee_pose.clone() - if self.project_pose_to_goal_frame: + project_distance = metric.project_to_goal_frame + if project_distance is None: + project_distance = rollouts[0].goal_cost.project_distance + if project_distance: # project start pose to goal frame: projected_pose = goal_pose.compute_local_pose(start_pose) if torch.count_nonzero(metric.hold_vec_weight[:3] > 0.0) > 0: @@ -2208,7 +2243,6 @@ def update_pose_cost_metric( log_warn("Partial position between start and goal is not equal.") return False - rollouts = self.get_all_pose_rollout_instances() [ rollout.update_pose_cost_metric(metric) for rollout in rollouts @@ -2955,6 +2989,7 @@ def _plan_attempts( """ start_time = time.time() valid_query = True + plan_config = plan_config.clone() if plan_config.check_start_validity: valid_query, status = self.check_start_state(start_state) if not valid_query: @@ -3094,6 +3129,7 @@ def _plan_batch_attempts( MotionGenResult: Result of batched planning. """ start_time = time.time() + plan_config = plan_config.clone() goal_pose = goal_pose.clone() if plan_config.pose_cost_metric is not None: valid_query = self.update_pose_cost_metric( @@ -4135,3 +4171,243 @@ def batch_plan( ) result = self.plan_batch(start_state, goal_pose, plan_config) return result + + def toggle_link_collision(self, collision_link_names: List[str], enable_flag: bool): + if len(collision_link_names) > 0: + if enable_flag: + for k in collision_link_names: + self.kinematics.kinematics_config.enable_link_spheres(k) + else: + for k in collision_link_names: + self.kinematics.kinematics_config.disable_link_spheres(k) + + def plan_grasp( + self, + start_state: JointState, + grasp_poses: Pose, + plan_config: MotionGenPlanConfig, + grasp_approach_offset: Pose = Pose.from_list([0, 0, -0.15, 1, 0, 0, 0]), + grasp_approach_path_constraint: Union[None, List[float]] = [0.1, 0.1, 0.1, 0.1, 0.1, 0.0], + retract_offset: Pose = Pose.from_list([0, 0, -0.15, 1, 0, 0, 0]), + retract_path_constraint: Union[None, List[float]] = [0.1, 0.1, 0.1, 0.1, 0.1, 0.0], + disable_collision_links: List[str] = [], + plan_approach_to_grasp: bool = True, + plan_grasp_to_retract: bool = True, + grasp_approach_constraint_in_goal_frame: bool = True, + retract_constraint_in_goal_frame: bool = True, + ) -> GraspPlanResult: + """Plan a sequence of motions to grasp an object, given a set of grasp poses. + + This function plans three motions, first approaches the object with an offset, then + moves with linear constraints to the grasp pose, and finally retracts the arm base to + offset with linear constraints. During the linear constrained motions, collision between + disable_collision_links and the world is disabled. This disabling is useful to enable + contact between a robot's gripper links and the object. + + This method takes a set of grasp poses and finds the best grasp pose to reach based on a + goal set trajectory optimization problem. In this problem, the robot needs to reach one + of the poses in the grasp_poses set at the terminal state. To allow for in-contact grasps, + collision between disable_collision_links and world is disabled during the optimization. + The best grasp pose is then used to plan the three motions. + + Args: + start_state: Start joint state for planning. 
+ grasp_poses: Set of grasp poses, represented with :class:~curobo.math.types.Pose, of + shape (1, num_grasps, 7). + plan_config: Planning parameters for motion generation. + grasp_approach_offset: Offset pose from the grasp pose. Reference frame is the grasp + pose frame if grasp_approach_constraint_in_goal_frame is True, otherwise the + reference frame is the robot base frame. + grasp_approach_path_constraint: Path constraint for the approach to grasp pose and + grasp to retract path. This is a list of 6 values, where each value is a weight + for each Cartesian dimension. The first three are for orientation and the last + three are for position. If None, no path constraint is applied. + retract_offset: Retract offset pose from grasp pose. Reference frame is the grasp pose + frame if retract_constraint_in_goal_frame is True, otherwise the reference frame is + the robot base frame. + retract_path_constraint: Path constraint for the retract path. This is a list of 6 + values, where each value is a weight for each Cartesian dimension. The first three + are for orientation and the last three are for position. If None, no path + constraint is applied. + disable_collision_links: Name of links to disable collision with the world during + the approach to grasp and grasp to retract path. + plan_approach_to_grasp: If True, planning also includes moving from approach to + grasp. If False, a plan to reach offset of the best grasp pose is returned. + plan_grasp_to_retract: If True, planning also includes moving from grasp to retract. + If False, only a plan to reach the best grasp pose is returned. + grasp_approach_constraint_in_goal_frame: If True, the grasp approach offset is in the + grasp pose frame. If False, the grasp approach offset is in the robot base frame. + Also applies to grasp_approach_path_constraint. + retract_constraint_in_goal_frame: If True, the retract offset is in the grasp pose + frame. If False, the retract offset is in the robot base frame. Also applies to + retract_path_constraint. + + Returns: + GraspPlanResult: Result of planning. Use :meth:`GraspPlanResult.grasp_trajectory` to + get the trajectory to reach the grasp pose and + :meth:`GraspPlanResult.retract_trajectory` to get the trajectory to retract from + the grasp pose. + """ + if plan_config.pose_cost_metric is not None: + log_error("plan_config.pose_cost_metric should be None") + self.toggle_link_collision(disable_collision_links, False) + result = GraspPlanResult() + goalset_motion_gen_result = self.plan_goalset( + start_state, + grasp_poses, + plan_config, + ) + self.toggle_link_collision(disable_collision_links, True) + result.success = goalset_motion_gen_result.success.clone() + result.success[:] = False + result.goalset_result = goalset_motion_gen_result + if not goalset_motion_gen_result.success.item(): + result.status = "No grasp in goal set was reachable." 
+ return result + result.goalset_index = goalset_motion_gen_result.goalset_index.clone() + + # plan to offset: + goal_index = goalset_motion_gen_result.goalset_index.item() + goal_pose = grasp_poses.get_index(0, goal_index).clone() + if grasp_approach_constraint_in_goal_frame: + offset_goal_pose = goal_pose.clone().multiply(grasp_approach_offset) + else: + offset_goal_pose = grasp_approach_offset.clone().multiply(goal_pose.clone()) + + reach_offset_mg_result = self.plan_single( + start_state, + offset_goal_pose, + plan_config.clone(), + ) + result.approach_result = reach_offset_mg_result + if not reach_offset_mg_result.success.item(): + result.status = f"Planning to Approach pose failed: {reach_offset_mg_result.status}" + return result + + if not plan_approach_to_grasp: + result.grasp_trajectory = reach_offset_mg_result.optimized_plan + result.grasp_trajectory_dt = reach_offset_mg_result.optimized_dt + result.grasp_interpolated_trajectory = reach_offset_mg_result.get_interpolated_plan() + result.grasp_interpolation_dt = reach_offset_mg_result.interpolation_dt + return result + # plan to final grasp + if grasp_approach_path_constraint is not None: + hold_pose_cost_metric = PoseCostMetric( + hold_partial_pose=True, + hold_vec_weight=self.tensor_args.to_device(grasp_approach_path_constraint), + project_to_goal_frame=grasp_approach_constraint_in_goal_frame, + ) + plan_config.pose_cost_metric = hold_pose_cost_metric + + offset_start_state = reach_offset_mg_result.optimized_plan[-1].unsqueeze(0) + + self.toggle_link_collision(disable_collision_links, False) + + reach_grasp_mg_result = self.plan_single( + offset_start_state, + goal_pose, + plan_config, + ) + self.toggle_link_collision(disable_collision_links, True) + result.grasp_result = reach_grasp_mg_result + if not reach_grasp_mg_result.success.item(): + result.status = ( + f"Planning from Approach to Grasp Failed: {reach_grasp_mg_result.status}" + ) + return result + + # Get stitched trajectory: + + offset_dt = reach_offset_mg_result.optimized_dt + grasp_dt = reach_grasp_mg_result.optimized_dt + if offset_dt > grasp_dt: + # retime grasp trajectory to match offset trajectory: + grasp_time_dilation = grasp_dt / offset_dt + + reach_grasp_mg_result.retime_trajectory( + grasp_time_dilation, + interpolate_trajectory=True, + ) + else: + offset_time_dilation = offset_dt / grasp_dt + + reach_offset_mg_result.retime_trajectory( + offset_time_dilation, + interpolate_trajectory=True, + ) + + if (reach_offset_mg_result.optimized_dt - reach_grasp_mg_result.optimized_dt).abs() > 0.01: + reach_offset_mg_result.success[:] = False + if reach_offset_mg_result.debug_info is None: + reach_offset_mg_result.debug_info = {} + reach_offset_mg_result.debug_info["plan_single_grasp_status"] = ( + "Stitching Trajectories Failed" + ) + return reach_offset_mg_result, None + + result.grasp_trajectory = reach_offset_mg_result.optimized_plan.stack( + reach_grasp_mg_result.optimized_plan + ).clone() + + result.grasp_trajectory_dt = reach_offset_mg_result.optimized_dt + + result.grasp_interpolated_trajectory = ( + reach_offset_mg_result.get_interpolated_plan() + .stack(reach_grasp_mg_result.get_interpolated_plan()) + .clone() + ) + result.grasp_interpolation_dt = reach_offset_mg_result.interpolation_dt + + # update trajectories in results: + result.planning_time = ( + reach_offset_mg_result.total_time + + reach_grasp_mg_result.total_time + + goalset_motion_gen_result.total_time + ) + + # check if retract path is required: + result.success[:] = True + if not 
plan_grasp_to_retract: + return result + + result.success[:] = False + self.toggle_link_collision(disable_collision_links, False) + grasp_start_state = result.grasp_trajectory[-1].unsqueeze(0) + + # compute retract goal pose: + if retract_constraint_in_goal_frame: + retract_goal_pose = goal_pose.clone().multiply(retract_offset) + else: + retract_goal_pose = retract_offset.clone().multiply(goal_pose.clone()) + + # add path constraint for retract: + plan_config.pose_cost_metric = None + + if retract_path_constraint is not None: + hold_pose_cost_metric = PoseCostMetric( + hold_partial_pose=True, + hold_vec_weight=self.tensor_args.to_device(retract_path_constraint), + project_to_goal_frame=retract_constraint_in_goal_frame, + ) + plan_config.pose_cost_metric = hold_pose_cost_metric + + # plan from grasp pose to retract: + retract_grasp_mg_result = self.plan_single( + grasp_start_state, + retract_goal_pose, + plan_config, + ) + self.toggle_link_collision(disable_collision_links, True) + result.planning_time += retract_grasp_mg_result.total_time + if not retract_grasp_mg_result.success.item(): + result.status = f"Retract from Grasp failed: {retract_grasp_mg_result.status}" + result.retract_result = retract_grasp_mg_result + return result + result.success[:] = True + + result.retract_trajectory = retract_grasp_mg_result.optimized_plan + result.retract_trajectory_dt = retract_grasp_mg_result.optimized_dt + result.retract_interpolated_trajectory = retract_grasp_mg_result.get_interpolated_plan() + result.retract_interpolation_dt = retract_grasp_mg_result.interpolation_dt + + return result diff --git a/src/curobo/wrap/reacher/trajopt.py b/src/curobo/wrap/reacher/trajopt.py index d64b5fb8..7a7f1a10 100644 --- a/src/curobo/wrap/reacher/trajopt.py +++ b/src/curobo/wrap/reacher/trajopt.py @@ -148,6 +148,7 @@ def load_from_robot_config( filter_robot_command: bool = False, optimize_dt: bool = True, project_pose_to_goal_frame: bool = True, + use_cuda_graph_metrics: bool = False, ): """Load TrajOptSolver configuration from robot configuration. @@ -290,6 +291,10 @@ def load_from_robot_config( project_pose_to_goal_frame: Project pose to goal frame when calculating distance between reached and goal pose. Use this to constrain motion to specific axes either in the global frame or the goal frame. + use_cuda_graph_metrics: Flag to enable cuda_graph when evaluating interpolated + trajectories after trajectory optimization. If interpolation_buffer is smaller + than interpolated trajectory, then the buffers will be re-created. This can cause + existing cuda graph to be invalid. Returns: TrajOptSolverConfig: Trajectory optimization configuration. @@ -508,7 +513,7 @@ def load_from_robot_config( safety_rollout=arm_rollout_safety, optimizers=opt_list, compute_metrics=True, - use_cuda_graph_metrics=use_cuda_graph, + use_cuda_graph_metrics=use_cuda_graph_metrics, sync_cuda_time=sync_cuda_time, ) trajopt = WrapBase(cfg) @@ -539,7 +544,7 @@ def load_from_robot_config( tensor_args=tensor_args, sync_cuda_time=sync_cuda_time, interpolate_rollout=interpolate_rollout, - use_cuda_graph_metrics=use_cuda_graph, + use_cuda_graph_metrics=use_cuda_graph_metrics, trim_steps=trim_steps, store_debug_in_result=store_debug_in_result, optimize_dt=optimize_dt, @@ -720,7 +725,7 @@ def attach_spheres_to_robot( link_name: Name of the link to attach the spheres to. Note that this link should already have pre-allocated spheres. 
""" - self.kinematics.attach_object( + self.kinematics.kinematics_config.attach_object( sphere_radius=sphere_radius, sphere_tensor=sphere_tensor, link_name=link_name ) @@ -730,7 +735,7 @@ def detach_spheres_from_robot(self, link_name: str = "attached_object") -> None: Args: link_name: Name of the link to detach the spheres from. """ - self.kinematics.detach_object(link_name) + self.kinematics.kinematics_config.detach_object(link_name) def _update_solve_state_and_goal_buffer( self, diff --git a/tests/goal_test.py b/tests/goal_test.py index b95ac1b1..51538f42 100644 --- a/tests/goal_test.py +++ b/tests/goal_test.py @@ -53,6 +53,7 @@ def test_repeat_seeds(): weight = tensor_args.to_device([1, 1, 1, 1]) vec_convergence = tensor_args.to_device([0, 0]) run_weight = tensor_args.to_device([1]) + project_distance = torch.tensor([True], device=tensor_args.device, dtype=torch.uint8) r = get_pose_distance( out_d, out_d.clone(), @@ -72,6 +73,7 @@ def test_repeat_seeds(): offset_waypoint, offset_tstep_fraction, g.batch_pose_idx, + project_distance, start_pose.position.shape[0], 1, 1, @@ -79,7 +81,6 @@ def test_repeat_seeds(): False, False, True, - True, ) assert torch.sum(r[0]).item() <= 1e-5 @@ -105,6 +106,8 @@ def test_horizon_repeat_seeds(): batch_pose_idx = torch.arange(0, b, 1, device=tensor_args.device, dtype=torch.int32).unsqueeze( -1 ) + project_distance = torch.tensor([True], device=tensor_args.device, dtype=torch.uint8) + goal = Goal(goal_pose=goal_pose, batch_pose_idx=batch_pose_idx, current_state=current_state) g = goal # .repeat_seeds(4) @@ -144,6 +147,7 @@ def test_horizon_repeat_seeds(): offset_waypoint, offset_tstep_fraction, g.batch_pose_idx, + project_distance, start_pose.position.shape[0], h, 1, @@ -151,6 +155,5 @@ def test_horizon_repeat_seeds(): True, False, False, - True, ) assert torch.sum(r[0]).item() < 1e-5 diff --git a/tests/motion_gen_goalset_test.py b/tests/motion_gen_goalset_test.py index a2bddbf7..9c9a9dfe 100644 --- a/tests/motion_gen_goalset_test.py +++ b/tests/motion_gen_goalset_test.py @@ -20,11 +20,11 @@ from curobo.wrap.reacher.motion_gen import MotionGen, MotionGenConfig, MotionGenPlanConfig -@pytest.fixture(scope="function") +@pytest.fixture(scope="module") def motion_gen(): tensor_args = TensorDeviceType() world_file = "collision_table.yml" - robot_file = "franka.yml" + robot_file = "ur5e_robotiq_2f_140.yml" motion_gen_config = MotionGenConfig.load_from_robot_config( robot_file, world_file, @@ -202,3 +202,40 @@ def test_batch_goalset_padded(motion_gen_batch): q = result.optimized_plan.trim_trajectory(-1).squeeze(1) reached_state = motion_gen.compute_kinematics(q) assert torch.norm(goal_pose.position - reached_state.ee_pos_seq) < 0.005 + + +def test_grasp_goalset(motion_gen): + motion_gen.reset() + m_config = MotionGenPlanConfig(False, True) + + retract_cfg = motion_gen.get_retract_config() + + state = motion_gen.compute_kinematics(JointState.from_position(retract_cfg.view(1, -1))) + + goal_pose = Pose( + state.ee_pos_seq.repeat(10, 1).view(1, -1, 3), + quaternion=state.ee_quat_seq.repeat(10, 1).view(1, -1, 4), + ) + goal_pose.position[0, 0, 0] += 0.2 + + start_state = JointState.from_position(retract_cfg.view(1, -1) + 0.3) + + result = motion_gen.plan_grasp( + start_state, + goal_pose, + m_config.clone(), + disable_collision_links=[ + "left_outer_knuckle", + "left_inner_knuckle", + "left_outer_finger", + "left_inner_finger", + "left_inner_finger_pad", + "right_outer_knuckle", + "right_inner_knuckle", + "right_outer_finger", + "right_inner_finger", + 
"right_inner_finger_pad", + ], + ) + + assert torch.count_nonzero(result.success) == 1 diff --git a/tests/self_collision_test.py b/tests/self_collision_test.py index 42bd8492..600553cd 100644 --- a/tests/self_collision_test.py +++ b/tests/self_collision_test.py @@ -9,6 +9,9 @@ # its affiliates is strictly prohibited. # +# Standard Library +import copy + # Third Party import pytest import torch @@ -108,3 +111,73 @@ def test_self_collision_franka(): cost_fn._out_distance[:] = 0.0 out = cost_fn.forward(in_spheres) assert out.sum().item() > 0.0 + + +def test_self_collision_10k_spheres_franka(): + tensor_args = TensorDeviceType() + + robot_cfg = load_yaml(join_path(get_robot_configs_path(), "franka.yml"))["robot_cfg"] + robot_cfg["kinematics"]["debug"] = {"self_collision_experimental": False} + + robot_cfg = RobotConfig.from_dict(robot_cfg, tensor_args) + robot_cfg.kinematics.self_collision_config.experimental_kernel = True + kinematics = CudaRobotModel(robot_cfg.kinematics) + self_collision_data = kinematics.get_self_collision_config() + self_collision_config = SelfCollisionCostConfig( + **{"weight": 1.0, "classify": False, "self_collision_kin_config": self_collision_data}, + tensor_args=tensor_args + ) + cost_fn = SelfCollisionCost(self_collision_config) + cost_fn.self_collision_kin_config.experimental_kernel = True + + b = 10 + h = 1 + + q = torch.rand( + (b * h, kinematics.get_dof()), device=tensor_args.device, dtype=tensor_args.dtype + ) + + test_q = tensor_args.to_device([2.7735, -1.6737, 0.4998, -2.9865, 0.3386, 0.8413, 0.4371]) + q[0, :] = test_q + kin_state = kinematics.get_state(q) + + in_spheres = kin_state.link_spheres_tensor + in_spheres = in_spheres.view(b, h, -1, 4).contiguous() + + out = cost_fn.forward(in_spheres) + assert out.sum().item() > 0.0 + + # create a franka robot with 10k spheres: + tensor_args = TensorDeviceType() + + robot_cfg = load_yaml(join_path(get_robot_configs_path(), "franka.yml"))["robot_cfg"] + robot_cfg["kinematics"]["debug"] = {"self_collision_experimental": False} + + sphere_cfg = load_yaml( + join_path(get_robot_configs_path(), robot_cfg["kinematics"]["collision_spheres"]) + )["collision_spheres"] + n_times = 10 + for k in sphere_cfg.keys(): + sphere_cfg[k] = [copy.deepcopy(x) for x in sphere_cfg[k] for _ in range(n_times)] + + robot_cfg["kinematics"]["collision_spheres"] = sphere_cfg + robot_cfg = RobotConfig.from_dict(robot_cfg, tensor_args) + robot_cfg.kinematics.self_collision_config.experimental_kernel = False + + kinematics = CudaRobotModel(robot_cfg.kinematics) + self_collision_data = kinematics.get_self_collision_config() + self_collision_config = SelfCollisionCostConfig( + **{"weight": 1.0, "classify": False, "self_collision_kin_config": self_collision_data}, + tensor_args=tensor_args + ) + cost_fn = SelfCollisionCost(self_collision_config) + cost_fn.self_collision_kin_config.experimental_kernel = False + + kin_state = kinematics.get_state(q) + + in_spheres = kin_state.link_spheres_tensor + in_spheres = in_spheres.view(b, h, -1, 4).contiguous() + + out_10k = cost_fn.forward(in_spheres) + assert out_10k.sum().item() > 0.0 + assert torch.linalg.norm(out - out_10k) < 1e-3 diff --git a/tests/voxelization_test.py b/tests/voxelization_test.py index c79de055..e4ac786d 100644 --- a/tests/voxelization_test.py +++ b/tests/voxelization_test.py @@ -122,7 +122,7 @@ def test_esdf_from_world(world_collision): world_collision.clear_voxelization_cache() esdf = world_collision.get_esdf_in_bounding_box(voxel_size=voxel_size).clone() - occupied = 
esdf.get_occupied_voxels() + occupied = esdf.get_occupied_voxels(feature_threshold=0.0) assert voxels.shape == occupied.shape @@ -136,7 +136,7 @@ def test_esdf_from_world(world_collision): indirect=True, ) def test_voxels_prim_mesh(world_collision, world_collision_primitive): - voxel_size = 0.1 + voxel_size = 0.05 voxels = world_collision.get_voxels_in_bounding_box(voxel_size=voxel_size).clone() voxels_prim = world_collision_primitive.get_voxels_in_bounding_box( voxel_size=voxel_size @@ -172,7 +172,7 @@ def test_esdf_prim_mesh(world_collision, world_collision_primitive): indirect=True, ) def test_marching_cubes_from_world(world_collision): - voxel_size = 0.1 + voxel_size = 0.05 voxels = world_collision.get_voxels_in_bounding_box(voxel_size=voxel_size) mesh = Mesh.from_pointcloud(voxels[:, :3].detach().cpu().numpy(), pitch=voxel_size * 0.1) diff --git a/tests/xrdf_test.py b/tests/xrdf_test.py index 5ca97b0c..2fa8f306 100644 --- a/tests/xrdf_test.py +++ b/tests/xrdf_test.py @@ -12,6 +12,9 @@ from curobo.cuda_robot_model.cuda_robot_model import CudaRobotModel, CudaRobotModelConfig from curobo.cuda_robot_model.util import load_robot_yaml from curobo.types.file_path import ContentPath +from curobo.types.math import Pose +from curobo.types.robot import JointState +from curobo.wrap.reacher.motion_gen import MotionGen, MotionGenConfig def test_xrdf_kinematics(): @@ -31,3 +34,32 @@ def test_xrdf_kinematics(): error = kin_pose.ee_pose.position - expected_position assert error.norm() < 0.01 + + assert "link_names" not in robot_data["robot_cfg"]["kinematics"] + + +def test_xrdf_motion_gen(): + robot_file = "ur10e.xrdf" + urdf_file = "robot/ur_description/ur10e.urdf" + content_path = ContentPath(robot_xrdf_file=robot_file, robot_urdf_file=urdf_file) + robot_data = load_robot_yaml(content_path) + robot_data["robot_cfg"]["kinematics"]["ee_link"] = "wrist_3_link" + + motion_gen_config = MotionGenConfig.load_from_robot_config( + robot_data, + "collision_table.yml", + use_cuda_graph=True, + ee_link_name="tool0", + ) + motion_gen = MotionGen(motion_gen_config) + motion_gen.warmup(warmup_js_trajopt=False) + retract_cfg = motion_gen.get_retract_config() + state = motion_gen.rollout_fn.compute_kinematics( + JointState.from_position(retract_cfg.view(1, -1)) + ) + + retract_pose = Pose(state.ee_pos_seq.squeeze(), quaternion=state.ee_quat_seq.squeeze()) + start_state = JointState.from_position(retract_cfg.view(1, -1) + 0.3) + result = motion_gen.plan_single(start_state, retract_pose) + + assert result.success.item()
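
The rewritten ``scale_action`` above scales the Newton step so every optimization variable stays within its per-dimension bound, and it only divides when some component actually exceeds its bound. A quick numeric check of that logic with made-up step values (this snippet re-implements the two core lines rather than importing the jitted function):

import torch

# Illustrative step and per-dimension bounds (values are made up).
dx = torch.tensor([[0.2, -0.05]])
action_step_max = torch.tensor([0.1, 0.1])

# Largest ratio of |dx| to its bound across the optimization variables.
scale = torch.max(torch.abs(dx) / action_step_max, dim=-1, keepdim=True)[0]
# Shrink uniformly only when some component is out of bounds (ratio > 1).
scale = torch.where(scale <= 1.0, torch.ones_like(scale), scale)
print(dx / scale)  # tensor([[ 0.1000, -0.0250]])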
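
The ``OrientationError`` refactor above moves the quaternion geodesic distance into a jitted helper that returns both the error magnitude and the already-normalized gradient direction, so the backward pass no longer recomputes 1/||error||. The quantity computed is the vector part of q_err = q_goal * conjugate(q_current), sign-corrected so the shorter geodesic is used. A standalone sketch that mirrors that logic with plain torch (wxyz convention, unit quaternions assumed; not imported from curobo):

import torch

def quat_mul(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    # Hamilton product for wxyz quaternions.
    w1, x1, y1, z1 = q1.unbind(-1)
    w2, x2, y2, z2 = q2.unbind(-1)
    return torch.stack(
        (
            w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
            w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
            w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
            w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
        ),
        dim=-1,
    )

def geodesic_error(goal_quat: torch.Tensor, current_quat: torch.Tensor):
    conj = current_quat.clone()
    conj[..., 1:] *= -1.0                      # conjugate of the current quaternion
    q_err = quat_mul(goal_quat, conj)          # q_err = q_goal * q_current^-1
    sign = torch.sign(q_err[..., :1])
    sign = torch.where(sign == 0, torch.ones_like(sign), sign)
    q_err = -1.0 * q_err * sign                # pick the shorter geodesic
    q_err[..., 0] = 0.0                        # keep only the vector part
    rot_error = torch.norm(q_err, dim=-1, keepdim=True)
    grad = q_err * torch.nan_to_num(1.0 / rot_error, 0.0, 0.0, 0.0)
    return grad, rot_error

q_identity = torch.tensor([1.0, 0.0, 0.0, 0.0])
q_90z = torch.tensor([0.7071068, 0.0, 0.0, 0.7071068])  # 90 degree rotation about z
_, err = geodesic_error(q_90z, q_identity)
print(err)  # ~0.7071, i.e. |sin(theta/2)| for a 90 degree orientation error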
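
The new ``robust_floor`` and ``round_away_from_zero`` helpers above exist because plain ``floor`` and ``torch.round`` behave badly right at integer boundaries, which matters when converting grid extents into voxel counts. Two small checks; the 0.9 / 0.3 example is illustrative and not taken from the voxel code:

import math

import torch

from curobo.util.helpers import robust_floor
from curobo.util.torch_utils import round_away_from_zero

x = 0.9 / 0.3                    # 2.9999999999999996 due to floating point
print(math.floor(x))             # 2
print(robust_floor(x))           # 3, since x is within 1e-4 of the nearest integer

t = torch.tensor([0.5, -0.5, 2.5])
print(torch.round(t))            # [0., -0., 2.], torch rounds half to even
print(round_away_from_zero(t))   # [1., -1., 3.]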
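
The switch to ``SafeLoader`` above also registers a wider implicit float resolver because PyYAML follows YAML 1.1, where scientific notation without a decimal point (for example ``1e-3``, common in YAML 1.2 style configs) is resolved as a string. A minimal reproduction of the underlying issue with stock PyYAML; the key name is only an example:

import yaml

data = yaml.safe_load("collision_sphere_buffer: 1e-3")
print(type(data["collision_sphere_buffer"]))  # <class 'str'> with the default YAML 1.1 resolver
# With the float resolver registered in util_file.py, the same field loads as the float 0.001.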
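
The new ``MotionGen.plan_grasp`` API is exercised end-to-end in ``tests/motion_gen_goalset_test.py``; the sketch below condenses that test into a usage example. The robot file, the synthetic grasp poses, and the disabled gripper links are illustrative assumptions rather than values required by the API:

from curobo.types.math import Pose
from curobo.types.robot import JointState
from curobo.wrap.reacher.motion_gen import MotionGen, MotionGenConfig, MotionGenPlanConfig

motion_gen_config = MotionGenConfig.load_from_robot_config(
    "ur5e_robotiq_2f_140.yml",  # robot with a gripper whose finger links can be disabled
    "collision_table.yml",
)
motion_gen = MotionGen(motion_gen_config)
motion_gen.warmup(warmup_js_trajopt=False)

# Build a goal set of candidate grasp poses around the retract end-effector pose
# (purely synthetic; real grasp poses would come from a grasp generator).
retract_cfg = motion_gen.get_retract_config()
kin_state = motion_gen.compute_kinematics(JointState.from_position(retract_cfg.view(1, -1)))
grasp_poses = Pose(
    kin_state.ee_pos_seq.repeat(10, 1).view(1, -1, 3),
    quaternion=kin_state.ee_quat_seq.repeat(10, 1).view(1, -1, 4),
)
start_state = JointState.from_position(retract_cfg.view(1, -1) + 0.3)

result = motion_gen.plan_grasp(
    start_state,
    grasp_poses,
    MotionGenPlanConfig(),
    disable_collision_links=["left_inner_finger_pad", "right_inner_finger_pad"],
)
if bool(result.success.item()):
    # Approach plus constrained grasp motion, then constrained retract motion.
    grasp_plan = result.grasp_interpolated_trajectory
    retract_plan = result.retract_interpolated_trajectory
else:
    print("plan_grasp failed:", result.status)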
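
The ``project_to_goal_frame`` field added to ``PoseCostMetric`` lets a plan choose at runtime whether a partial-pose constraint is expressed in the goal frame or the robot base frame, which is what ``plan_grasp`` relies on internally. Continuing the sketch above, the same constraint can also be used directly with ``plan_single`` through ``create_grasp_approach_metric``; the offset, axis, and timestep fraction below are illustrative:

from curobo.rollout.cost.pose_cost import PoseCostMetric

# Approach the goal from a 10 cm pregrasp offset, holding all axes except the
# tool z-axis during the last 20% of the trajectory, with the constraint
# expressed in the goal frame.
pose_metric = PoseCostMetric.create_grasp_approach_metric(
    offset_position=0.1,
    linear_axis=2,
    tstep_fraction=0.8,
    project_to_goal_frame=True,
)
goal_pose = grasp_poses.get_index(0, 0).clone()  # one pose from the goal set above
plan_config = MotionGenPlanConfig(pose_cost_metric=pose_metric)
single_result = motion_gen.plan_single(start_state, goal_pose, plan_config)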