Skip to content

Commit

Permalink
Add planning to grasp API
Browse files Browse the repository at this point in the history
  • Loading branch information
balakumar-s committed Nov 22, 2024
1 parent 18e9ebd commit 36ea382
Show file tree
Hide file tree
Showing 38 changed files with 941 additions and 537 deletions.
25 changes: 24 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,40 @@ its affiliates is strictly prohibited.
-->
# Changelog

## Latest Commit
## Version 0.7.5

### Changes in Default Behavior
- Remove explicit global seed setting for numpy and random. To enforce deterministic behavior,
use `np.random.seed(2)` and `random.seed(2)` in your program.
- geom.types.VoxelGrid now uses a different algorithm to calculate number of voxels per dimension
and also to compute xyz locations in a grid. This new implementation matches implementation in
nvblox.

### New Features
- Add pose cost metric to MPC to allow for partial pose reaching.
- Update obstacle poses in cpu reference with an optional flag.
- Add planning to grasp API in ``MotionGen.plan_grasp`` that plans a sequence of motions to grasp
an object given grasp poses. This API also provides args to disable collisions during the grasping
phase.
- Constrained planning can now use either goal frame or base frame at runtime.

### BugFixes & Misc.
- Fixed optimize_dt not being correctly set when motion gen is called in reactive mode.
- Add documentation for geom module.
- Add descriptive api for computing kinematics.
- Fix cv2 import order in isaac sim realsense examples.
- Fix attach sphere api mismatch in ``TrajOptSolver``.
- Fix bug in ``get_spline_interpolated_trajectory`` where
numpy array was created instead of torch tensor.
- Fix gradient bug when sphere origin is exactly at face of a cuboid.
- Add support for parsing Yaml 1.2 format with an updated regex for scientific notations.
- Move to yaml `SafeLoader` from `Loader`.
- Graph search checks if a node exists before attempting to find a path.
- Fix `steps_max` becoming 0 when optimized dt has NaN values.
- Clone `MotionGenPlanConfig` instance for every plan api.
- Improve sphere position to voxel location calculation to match nvblox's implementation.
- Add self collision checking support for spheres > 1024 and number of checks > 512 * 1024.
- Fix gradient passthrough in warp batch transform kernels.

## Version 0.7.4

Expand Down
5 changes: 3 additions & 2 deletions benchmark/curobo_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def load_curobo(
trajopt_dt=0.25,
finetune_dt_scale=finetune_dt_scale,
high_precision=args.high_precision,
use_cuda_graph_trajopt_metrics=cuda_graph,
)
mg = MotionGen(motion_gen_config)
mg.warmup(enable_graph=True, warmup_js_trajopt=False, parallel_finetune=parallel_finetune)
Expand Down Expand Up @@ -484,7 +485,7 @@ def benchmark_mb(
start_state,
q_traj,
dt=result.interpolation_dt,
save_path=join_path("benchmark/log/usd/", problem_name) + ".usd",
save_path=join_path("benchmark/log/usd/", problem_name)[1:] + ".usd",
interpolation_steps=1,
write_robot_usd_path="benchmark/log/usd/assets/",
robot_usd_local_reference="assets/",
Expand All @@ -499,7 +500,7 @@ def benchmark_mb(
result.optimized_plan,
result.optimized_dt.item(),
title=problem_name,
save_path=join_path("benchmark/log/plot/", problem_name + ".png"),
save_path=join_path("benchmark/log/plot/", problem_name + ".png")[1:],
)

m_list.append(current_metrics)
Expand Down
25 changes: 15 additions & 10 deletions benchmark/curobo_voxel_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def load_curobo(
"base": {
"dims": [2.4, 2.4, 2.4],
"pose": [0, 0, 0, 1, 0, 0, 0],
"voxel_size": 0.005,
"voxel_size": 0.02,
"feature_dtype": torch.bfloat16,
},
}
Expand Down Expand Up @@ -294,15 +294,15 @@ def benchmark_mb(
world = WorldConfig.from_dict(problem["obstacles"])

# mg.world_coll_checker.clear_cache()
world_coll = WorldConfig.from_dict(problem["obstacles"])
world_coll = WorldConfig.from_dict(problem["obstacles"]).get_collision_check_world()
if args.mesh:
world_coll = world_coll.get_mesh_world(merge_meshes=False)
robot_world.clear_world_cache()
robot_world.update_world(world_coll)

esdf = robot_world.world_model.get_esdf_in_bounding_box(
Cuboid(name="base", pose=[0, 0, 0, 1, 0, 0, 0], dims=[2.4, 2.4, 2.4]),
voxel_size=0.005,
voxel_size=0.02,
dtype=torch.float32,
)
# esdf.feature_tensor[esdf.feature_tensor < -1.0] = -1000.0
Expand Down Expand Up @@ -336,13 +336,16 @@ def benchmark_mb(
world.randomize_color(r=[0.5, 0.9], g=[0.2, 0.5], b=[0.0, 0.2])

coll_mesh = mg.world_coll_checker.get_mesh_in_bounding_box(
curobo_Cuboid(name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[2, 2, 2]),
voxel_size=0.005,
curobo_Cuboid(
name="test", pose=[0, 0, 0, 1, 0, 0, 0], dims=[2.4, 2.4, 2.4]
),
voxel_size=0.02,
)

coll_mesh.color = [0.0, 0.8, 0.8, 0.2]

coll_mesh.name = "voxel_world"

# world = WorldConfig(mesh=[coll_mesh])
world.add_obstacle(coll_mesh)
# get costs:
Expand All @@ -360,7 +363,7 @@ def benchmark_mb(
plot_cost_iteration(
traj_cost,
title=problem_name + "_" + str(success) + "_" + str(dt),
save_path=join_path("benchmark/log/plot/", problem_name + "_cost"),
save_path=join_path("benchmark/log/plot/", problem_name + "_cost")[1:],
log_scale=False,
)
if "finetune_trajopt_result" in result.debug_info:
Expand All @@ -373,7 +376,7 @@ def benchmark_mb(
title=problem_name + "_" + str(success) + "_" + str(dt),
save_path=join_path(
"benchmark/log/plot/", problem_name + "_ft_cost"
),
)[1:],
log_scale=False,
)
if result.success.item():
Expand Down Expand Up @@ -481,7 +484,7 @@ def benchmark_mb(
start_state,
q_traj,
dt=result.interpolation_dt,
save_path=join_path("benchmark/log/usd/", problem_name) + ".usd",
save_path=join_path("benchmark/log/usd/", problem_name)[1:] + ".usd",
interpolation_steps=1,
write_robot_usd_path="benchmark/log/usd/assets/",
robot_usd_local_reference="assets/",
Expand All @@ -499,15 +502,17 @@ def benchmark_mb(
# result.get_interpolated_plan(),
# result.interpolation_dt,
title=problem_name,
save_path=join_path("benchmark/log/plot/", problem_name + ".pdf"),
save_path=join_path("benchmark/log/plot/", problem_name + ".pdf")[1:],
)
plot_traj(
# result.optimized_plan,
# result.optimized_dt.item(),
result.get_interpolated_plan(),
result.interpolation_dt,
title=problem_name,
save_path=join_path("benchmark/log/plot/", problem_name + "_int.pdf"),
save_path=join_path("benchmark/log/plot/", problem_name + "_int.pdf")[
1:
],
)

m_list.append(current_metrics)
Expand Down
69 changes: 0 additions & 69 deletions docker/ros1_x86.dockerfile

This file was deleted.

22 changes: 0 additions & 22 deletions docker/start_docker_arm64.sh

This file was deleted.

24 changes: 0 additions & 24 deletions docker/start_docker_x86_robot.sh

This file was deleted.

38 changes: 27 additions & 11 deletions examples/usd_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from curobo.types.base import TensorDeviceType
from curobo.types.math import Pose
from curobo.types.robot import JointState, RobotConfig
from curobo.util.logger import setup_curobo_logger
from curobo.util.logger import setup_curobo_logger, log_error
from curobo.util.usd_helper import UsdHelper
from curobo.util_file import (
get_assets_path,
Expand All @@ -44,7 +44,7 @@ def save_curobo_world_to_usd():
usd_helper.write_stage_to_file("test.usd")


def get_trajectory(robot_file="franka.yml", dt=1.0 / 60.0):
def get_trajectory(robot_file="franka.yml", dt=1.0 / 60.0, plan_grasp: bool = False):
tensor_args = TensorDeviceType()
world_file = "collision_test.yml"
motion_gen_config = MotionGenConfig.load_from_robot_config(
Expand All @@ -60,32 +60,48 @@ def get_trajectory(robot_file="franka.yml", dt=1.0 / 60.0):
interpolation_dt=dt,
)
motion_gen = MotionGen(motion_gen_config)
motion_gen.warmup()
motion_gen.warmup(n_goalset=2)
robot_cfg = load_yaml(join_path(get_robot_configs_path(), robot_file))["robot_cfg"]
robot_cfg = RobotConfig.from_dict(robot_cfg, tensor_args)
retract_cfg = motion_gen.get_retract_config()
state = motion_gen.rollout_fn.compute_kinematics(
JointState.from_position(retract_cfg.view(1, -1))
)
if plan_grasp:
retract_pose = Pose(
state.ee_pos_seq.view(1, -1, 3), quaternion=state.ee_quat_seq.view(1, -1, 4)
)
start_state = JointState.from_position(retract_cfg.view(1, -1).clone())
start_state.position[..., :-2] += 0.5
m_config = MotionGenPlanConfig(False, True)

result = motion_gen.plan_grasp(start_state, retract_pose, m_config.clone())
if not result.success:
log_error("Failed to plan grasp: " + result.status)
traj = result.grasp_interpolated_trajectory
traj2 = result.retract_interpolated_trajectory
traj = traj.stack(traj2).clone()
# result = motion_gen.plan_single(start_state, retract_pose)
# traj = result.get_interpolated_plan() # optimized plan

retract_pose = Pose(state.ee_pos_seq.squeeze(), quaternion=state.ee_quat_seq.squeeze())
start_state = JointState.from_position(retract_cfg.view(1, -1).clone())
start_state.position[..., :-2] += 0.5
result = motion_gen.plan_single(start_state, retract_pose)
traj = result.get_interpolated_plan() # optimized plan
else:
retract_pose = Pose(state.ee_pos_seq.squeeze(), quaternion=state.ee_quat_seq.squeeze())
start_state = JointState.from_position(retract_cfg.view(1, -1).clone())
start_state.position[..., :-2] += 0.5
result = motion_gen.plan_single(start_state, retract_pose)
traj = result.get_interpolated_plan() # optimized plan
return traj


def save_curobo_robot_world_to_usd(robot_file="franka.yml"):
print(robot_file)
def save_curobo_robot_world_to_usd(robot_file="franka.yml", plan_grasp: bool = False):
tensor_args = TensorDeviceType()
world_file = "collision_test.yml"
world_model = WorldConfig.from_dict(
load_yaml(join_path(get_world_configs_path(), world_file))
).get_obb_world()
dt = 1 / 60.0

q_traj = get_trajectory(robot_file, dt)
q_traj = get_trajectory(robot_file, dt, plan_grasp)
if q_traj is not None:
q_start = q_traj[0]
UsdHelper.write_trajectory_animation_with_robot_usd(
Expand Down
6 changes: 5 additions & 1 deletion src/curobo/cuda_robot_model/cuda_robot_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,11 @@ def _build_collision_model(

if not valid_data:
use_experimental_kernel = False
log_warn("Self Collision checks are greater than 32 * 512, using slower kernel")
log_warn(
"Self Collision checks are greater than 32 * 512, using slower kernel."
+ " Number of spheres: "
+ str(self_collision_distance.shape[0])
)
if use_experimental_kernel:
self_coll_matrix = torch.zeros((2), device=self.tensor_args.device, dtype=torch.uint8)
else:
Expand Down
Loading

0 comments on commit 36ea382

Please sign in to comment.