Skip to content

Commit

Permalink
format c++
Browse files Browse the repository at this point in the history
  • Loading branch information
jefequien committed Sep 11, 2024
1 parent 1ed34f0 commit 4434f09
Showing 1 changed file with 59 additions and 59 deletions.
118 changes: 59 additions & 59 deletions gsplat/cuda/csrc/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,7 @@ inline __device__ void ortho_proj_vjp(
// df/dx = fx * df/dpixx
// df/dy = fy * df/dpixy
// df/dz = 0
v_mean3d += vec3<T>(
fx * v_mean2d[0],
fy * v_mean2d[1],
0.f
);
v_mean3d += vec3<T>(fx * v_mean2d[0], fy * v_mean2d[1], 0.f);
}

template <typename T>
Expand Down Expand Up @@ -393,26 +389,24 @@ inline __device__ void fisheye_proj(
float eps = 0.0000001f;
float xy_len = glm::length(glm::vec2({x, y})) + eps;
float theta = glm::atan(xy_len, z + eps);
mean2d = vec2<T>({
x * fx * theta / xy_len + cx,
y * fy * theta / xy_len + cy
});
mean2d =
vec2<T>({x * fx * theta / xy_len + cx, y * fy * theta / xy_len + cy});

float x2 = x * x + eps;
float y2 = y * y;
float xy = x * y;
float x2y2 = x2 + y2 ;
float x2y2 = x2 + y2;
float x2y2z2_inv = 1.f / (x2y2 + z * z);

float b = glm::atan(xy_len, z) / xy_len / x2y2;
float a = z * x2y2z2_inv / (x2y2);
mat3x2<T> J = mat3x2<T>(
fx * (x2 * a + y2 * b),
fy * xy * (a - b),
fx * xy * (a - b),
fy * (y2 * a + x2 * b),
- fx * x * x2y2z2_inv,
- fy * y * x2y2z2_inv
fy * xy * (a - b),
fx * xy * (a - b),
fy * (y2 * a + x2 * b),
-fx * x * x2y2z2_inv,
-fy * y * x2y2z2_inv
);
cov2d = J * cov3d * glm::transpose(J);
}
Expand Down Expand Up @@ -450,20 +444,20 @@ inline __device__ void fisheye_proj_vjp(
v_mean3d += vec3<T>(
fx * (x2 * a + y2 * b) * v_mean2d[0] + fy * xy * (a - b) * v_mean2d[1],
fx * xy * (a - b) * v_mean2d[0] + fy * (y2 * a + x2 * b) * v_mean2d[1],
- fx * x * x2y2z2_inv * v_mean2d[0] - fy * y * x2y2z2_inv * v_mean2d[1]
-fx * x * x2y2z2_inv * v_mean2d[0] - fy * y * x2y2z2_inv * v_mean2d[1]
);

const float theta = glm::atan(len_xy, z);
const float J_b = theta / len_xy / x2y2;
const float J_a = z * x2y2z2_inv / (x2y2);
const float J_a = z * x2y2z2_inv / (x2y2);
// mat3x2 is 3 columns x 2 rows.
mat3x2<T> J = mat3x2<T>(
fx * (x2 * J_a + y2 * J_b),
fy * xy * (J_a - J_b), // 1st column
fy * xy * (J_a - J_b), // 1st column
fx * xy * (J_a - J_b),
fy * (y2 * J_a + x2 * J_b), // 2nd column
- fx * x * x2y2z2_inv,
- fy * y * x2y2z2_inv // 3rd column
-fx * x * x2y2z2_inv,
-fy * y * x2y2z2_inv // 3rd column
);
v_cov3d += glm::transpose(J) * v_cov2d * J;

Expand All @@ -474,45 +468,51 @@ inline __device__ void fisheye_proj_vjp(
mat3x2<T> v_J = v_cov2d * J * glm::transpose(cov3d) +
glm::transpose(v_cov2d) * J * cov3d;

float l4 = x2y2z2 * x2y2z2;

float E = - l4 * x2y2 * theta + x2y2z2 * x2y2 * len_xy * z;
float F = 3 * l4 * theta - 3 * x2y2z2 * len_xy * z - 2 * x2y2 * len_xy * z;

float A = x * (3 * E + x2 * F);
float B = y * (E + x2 * F);
float C = x * (E + y2 * F);
float D = y * (3 * E + y2 * F);

float S1 = x2 - y2 - z * z;
float S2 = y2 - x2 - z * z;
float inv1 = x2y2z2_inv * x2y2z2_inv;
float inv2 = inv1 / (x2y2 * x2y2 * len_xy);

float dJ_dx00 = fx * A * inv2;
float dJ_dx01 = fx * B * inv2;
float dJ_dx02 = fx * S1 * inv1;
float dJ_dx10 = fy * B * inv2;
float dJ_dx11 = fy * C * inv2;
float dJ_dx12 = 2.f * fy * xy * inv1;

float dJ_dy00 = dJ_dx01;
float dJ_dy01 = fx * C * inv2;
float dJ_dy02 = 2.f * fx * xy * inv1;
float dJ_dy10 = dJ_dx11;
float dJ_dy11 = fy * D * inv2;
float dJ_dy12 = fy * S2 * inv1;

float dJ_dz00 = dJ_dx02;
float dJ_dz01 = dJ_dy02;
float dJ_dz02 = 2.f * fx * x * z * inv1;
float dJ_dz10 = dJ_dx12;
float dJ_dz11 = dJ_dy12;
float dJ_dz12 = 2.f * fy * y * z * inv1;

float dL_dtx_raw = dJ_dx00 * v_J[0][0] + dJ_dx01 * v_J[1][0] + dJ_dx02 * v_J[2][0] + dJ_dx10 * v_J[0][1] + dJ_dx11 * v_J[1][1] + dJ_dx12 * v_J[2][1];
float dL_dty_raw = dJ_dy00 * v_J[0][0] + dJ_dy01 * v_J[1][0] + dJ_dy02 * v_J[2][0] + dJ_dy10 * v_J[0][1] + dJ_dy11 * v_J[1][1] + dJ_dy12 * v_J[2][1];
float dL_dtz_raw = dJ_dz00 * v_J[0][0] + dJ_dz01 * v_J[1][0] + dJ_dz02 * v_J[2][0] + dJ_dz10 * v_J[0][1] + dJ_dz11 * v_J[1][1] + dJ_dz12 * v_J[2][1];
float l4 = x2y2z2 * x2y2z2;

float E = -l4 * x2y2 * theta + x2y2z2 * x2y2 * len_xy * z;
float F = 3 * l4 * theta - 3 * x2y2z2 * len_xy * z - 2 * x2y2 * len_xy * z;

float A = x * (3 * E + x2 * F);
float B = y * (E + x2 * F);
float C = x * (E + y2 * F);
float D = y * (3 * E + y2 * F);

float S1 = x2 - y2 - z * z;
float S2 = y2 - x2 - z * z;
float inv1 = x2y2z2_inv * x2y2z2_inv;
float inv2 = inv1 / (x2y2 * x2y2 * len_xy);

float dJ_dx00 = fx * A * inv2;
float dJ_dx01 = fx * B * inv2;
float dJ_dx02 = fx * S1 * inv1;
float dJ_dx10 = fy * B * inv2;
float dJ_dx11 = fy * C * inv2;
float dJ_dx12 = 2.f * fy * xy * inv1;

float dJ_dy00 = dJ_dx01;
float dJ_dy01 = fx * C * inv2;
float dJ_dy02 = 2.f * fx * xy * inv1;
float dJ_dy10 = dJ_dx11;
float dJ_dy11 = fy * D * inv2;
float dJ_dy12 = fy * S2 * inv1;

float dJ_dz00 = dJ_dx02;
float dJ_dz01 = dJ_dy02;
float dJ_dz02 = 2.f * fx * x * z * inv1;
float dJ_dz10 = dJ_dx12;
float dJ_dz11 = dJ_dy12;
float dJ_dz12 = 2.f * fy * y * z * inv1;

float dL_dtx_raw = dJ_dx00 * v_J[0][0] + dJ_dx01 * v_J[1][0] +
dJ_dx02 * v_J[2][0] + dJ_dx10 * v_J[0][1] +
dJ_dx11 * v_J[1][1] + dJ_dx12 * v_J[2][1];
float dL_dty_raw = dJ_dy00 * v_J[0][0] + dJ_dy01 * v_J[1][0] +
dJ_dy02 * v_J[2][0] + dJ_dy10 * v_J[0][1] +
dJ_dy11 * v_J[1][1] + dJ_dy12 * v_J[2][1];
float dL_dtz_raw = dJ_dz00 * v_J[0][0] + dJ_dz01 * v_J[1][0] +
dJ_dz02 * v_J[2][0] + dJ_dz10 * v_J[0][1] +
dJ_dz11 * v_J[1][1] + dJ_dz12 * v_J[2][1];
v_mean3d.x += dL_dtx_raw;
v_mean3d.y += dL_dty_raw;
v_mean3d.z += dL_dtz_raw;
Expand Down

0 comments on commit 4434f09

Please sign in to comment.