Skip to content

Commit

Permalink
Merge pull request #393 from YuhengZhi/fix-quaternion
Browse files Browse the repository at this point in the history
fix quaternion's gradients in PoseInverse, and a few other warp kernels
  • Loading branch information
balakumar-s authored Nov 17, 2024
2 parents 56fabe5 + a22a2fd commit a1c1106
Showing 1 changed file with 9 additions and 27 deletions.
36 changes: 9 additions & 27 deletions src/curobo/geom/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,11 +427,7 @@ def compute_pose_inverse(
# write pt:
out_q = wp.transform_get_rotation(t_3)

out_v = wp.vec4()
out_v[0] = out_q[3] # out_q[3]
out_v[1] = out_q[0] # [0]
out_v[2] = out_q[1] # wp.extract(out_q, 1)
out_v[3] = out_q[2] # wp.extract(out_q, 2)
out_v = wp.vec4(out_q[3], out_q[0], out_q[1], out_q[2])

out_position[b_idx] = wp.transform_get_translation(t_3)
out_quat[b_idx] = out_v
Expand All @@ -453,11 +449,7 @@ def compute_matrix_to_quat(
# create a transform from a vector/quaternion:
out_q = wp.quat_from_matrix(in_m)

out_v = wp.vec4()
out_v[0] = out_q[3] # wp.extract(out_q, 3)
out_v[1] = out_q[0] # wp.extract(out_q, 0)
out_v[2] = out_q[1] # wp.extract(out_q, 1)
out_v[3] = out_q[2] # wp.extract(out_q, 2)
out_v = wp.vec4(out_q[3], out_q[0], out_q[1], out_q[2])
# write pt:
out_quat[b_idx] = out_v

Expand Down Expand Up @@ -562,11 +554,7 @@ def compute_batch_pose_multipy(
# write pt:
out_q = wp.transform_get_rotation(t_3)

out_v = wp.vec4()
out_v[0] = out_q[3]
out_v[1] = out_q[0]
out_v[2] = out_q[1]
out_v[3] = out_q[2]
out_v = wp.vec4(out_q[3], out_q[0], out_q[1], out_q[2])

out_position[b_idx] = wp.transform_get_translation(t_3)
out_quat[b_idx] = out_v
Expand Down Expand Up @@ -626,11 +614,7 @@ def compute_pose_multipy(
# write pt:
out_q = wp.transform_get_rotation(t_3)

out_v = wp.vec4()
out_v[0] = out_q[3]
out_v[1] = out_q[0]
out_v[2] = out_q[1]
out_v[3] = out_q[2]
out_v = wp.vec4(out_q[3], out_q[0], out_q[1], out_q[2])

out_position[b_idx] = wp.transform_get_translation(t_3)
out_quat[b_idx] = out_v
Expand Down Expand Up @@ -850,7 +834,7 @@ def forward(
adj_position2: torch.Tensor,
adj_quaternion2: torch.Tensor,
):
b, _ = position.shape
b, _ = position.view(-1, 3).shape

if out_position is None:
out_position = torch.zeros_like(position2)
Expand Down Expand Up @@ -977,7 +961,7 @@ def backward(ctx, grad_out_position, grad_out_quaternion):
g_p2 = adj_position2
if ctx.needs_input_grad[3]:
g_q2 = adj_quaternion2
return g_p1, g_q1, g_p2, g_q2, None, None, None, None
return g_p1, g_q1, g_p2, g_q2, None, None, None, None, None, None


class TransformPose(torch.autograd.Function):
Expand All @@ -997,7 +981,7 @@ def forward(
adj_position2: torch.Tensor,
adj_quaternion2: torch.Tensor,
):
b, _ = position2.shape
b, _ = position2.view(-1, 3).shape
init_warp()
if out_position is None:
out_position = torch.zeros_like(position2)
Expand Down Expand Up @@ -1123,7 +1107,7 @@ def backward(ctx, grad_out_position, grad_out_quaternion):
g_p2 = adj_position2
if ctx.needs_input_grad[3]:
g_q2 = adj_quaternion2
return g_p1, g_q1, g_p2, g_q2, None, None, None, None
return g_p1, g_q1, g_p2, g_q2, None, None, None, None, None, None


class PoseInverse(torch.autograd.Function):
Expand Down Expand Up @@ -1223,8 +1207,6 @@ def backward(ctx, grad_out_position, grad_out_quaternion):
adj_inputs=[
None,
None,
None,
None,
],
adj_outputs=[
None,
Expand All @@ -1239,7 +1221,7 @@ def backward(ctx, grad_out_position, grad_out_quaternion):
if ctx.needs_input_grad[1]:
g_q1 = adj_quaternion

return g_p1, g_q1, None, None
return g_p1, g_q1, None, None, None, None


class QuatToMatrix(torch.autograd.Function):
Expand Down

0 comments on commit a1c1106

Please sign in to comment.