From e6d53183786d866af39c6baedb06db2ac0144517 Mon Sep 17 00:00:00 2001 From: Junchen Liu Date: Fri, 8 Nov 2024 16:49:09 -0800 Subject: [PATCH] Reorganize file structure (#480) * minor fix * create include folder * separate header files --- MANIFEST.in | 1 + gsplat/cuda/_backend.py | 2 +- gsplat/cuda/csrc/adam.cu | 2 - gsplat/cuda/csrc/compute_sh_bwd.cu | 1 + .../csrc/fully_fused_projection_2dgs_bwd.cu | 3 +- .../csrc/fully_fused_projection_2dgs_fwd.cu | 3 +- .../cuda/csrc/fully_fused_projection_bwd.cu | 4 + .../cuda/csrc/fully_fused_projection_fwd.cu | 4 +- .../fully_fused_projection_packed_2dgs_bwd.cu | 3 +- .../fully_fused_projection_packed_2dgs_fwd.cu | 3 +- .../csrc/fully_fused_projection_packed_bwd.cu | 4 + .../csrc/fully_fused_projection_packed_fwd.cu | 4 +- gsplat/cuda/csrc/isect_tiles.cu | 1 - gsplat/cuda/csrc/proj_bwd.cu | 3 +- gsplat/cuda/csrc/proj_fwd.cu | 3 +- .../csrc/quat_scale_to_covar_preci_bwd.cu | 4 +- .../csrc/quat_scale_to_covar_preci_fwd.cu | 3 +- .../csrc/rasterize_to_indices_in_range.cu | 1 - .../rasterize_to_indices_in_range_2dgs.cu | 3 +- .../cuda/csrc/rasterize_to_pixels_2dgs_bwd.cu | 2 +- .../cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu | 3 +- gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu | 1 - gsplat/cuda/csrc/third_party/glm | 2 +- gsplat/cuda/csrc/utils.cuh | 715 ------------------ gsplat/cuda/csrc/world_to_cam_bwd.cu | 2 +- gsplat/cuda/csrc/world_to_cam_fwd.cu | 3 +- gsplat/cuda/include/2dgs.cuh | 78 ++ gsplat/cuda/{csrc => include}/bindings.h | 0 gsplat/cuda/{csrc => include}/helpers.cuh | 8 +- gsplat/cuda/include/proj.cuh | 347 +++++++++ gsplat/cuda/include/quat.cuh | 61 ++ .../include/quat_scale_to_covar_preci.cuh | 128 ++++ .../{csrc => include}/spherical_harmonics.cuh | 4 - gsplat/cuda/include/transform.cuh | 73 ++ gsplat/cuda/{csrc => include}/types.cuh | 12 +- gsplat/cuda/include/utils.cuh | 77 ++ setup.py | 21 +- tests/test_basic.py | 8 +- 38 files changed, 824 insertions(+), 773 deletions(-) delete mode 100644 
gsplat/cuda/csrc/utils.cuh create mode 100644 gsplat/cuda/include/2dgs.cuh rename gsplat/cuda/{csrc => include}/bindings.h (100%) rename gsplat/cuda/{csrc => include}/helpers.cuh (95%) create mode 100644 gsplat/cuda/include/proj.cuh create mode 100644 gsplat/cuda/include/quat.cuh create mode 100644 gsplat/cuda/include/quat_scale_to_covar_preci.cuh rename gsplat/cuda/{csrc => include}/spherical_harmonics.cuh (99%) create mode 100644 gsplat/cuda/include/transform.cuh rename gsplat/cuda/{csrc => include}/types.cuh (82%) create mode 100644 gsplat/cuda/include/utils.cuh diff --git a/MANIFEST.in b/MANIFEST.in index 34427bf24..7c58b11c6 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,2 @@ recursive-include gsplat/cuda/csrc * +recursive-include gsplat/cuda/include * \ No newline at end of file diff --git a/gsplat/cuda/_backend.py b/gsplat/cuda/_backend.py index b646c0bcb..c4b281c4a 100644 --- a/gsplat/cuda/_backend.py +++ b/gsplat/cuda/_backend.py @@ -89,7 +89,7 @@ def cuda_toolkit_version(): current_dir = os.path.dirname(os.path.abspath(__file__)) glm_path = os.path.join(current_dir, "csrc", "third_party", "glm") - extra_include_paths = [os.path.join(PATH, "csrc/"), glm_path] + extra_include_paths = [os.path.join(PATH, "include/"), glm_path] extra_cflags = ["-O3"] if NO_FAST_MATH: extra_cuda_cflags = ["-O3"] diff --git a/gsplat/cuda/csrc/adam.cu b/gsplat/cuda/csrc/adam.cu index 3018cfd95..7462432ea 100644 --- a/gsplat/cuda/csrc/adam.cu +++ b/gsplat/cuda/csrc/adam.cu @@ -1,6 +1,4 @@ #include "bindings.h" -#include "helpers.cuh" -#include "utils.cuh" #include #include diff --git a/gsplat/cuda/csrc/compute_sh_bwd.cu b/gsplat/cuda/csrc/compute_sh_bwd.cu index 5b22cbdbe..623e788bb 100644 --- a/gsplat/cuda/csrc/compute_sh_bwd.cu +++ b/gsplat/cuda/csrc/compute_sh_bwd.cu @@ -1,4 +1,5 @@ #include "bindings.h" +#include "helpers.cuh" #include "spherical_harmonics.cuh" #include "types.cuh" diff --git a/gsplat/cuda/csrc/fully_fused_projection_2dgs_bwd.cu 
b/gsplat/cuda/csrc/fully_fused_projection_2dgs_bwd.cu index ec7eb1126..ea6a0e13b 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_2dgs_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_2dgs_bwd.cu @@ -1,6 +1,7 @@ #include "bindings.h" #include "helpers.cuh" -#include "utils.cuh" +#include "transform.cuh" +#include "2dgs.cuh" #include #include diff --git a/gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu index d9beedc36..bc56528ff 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu @@ -1,6 +1,7 @@ #include "bindings.h" #include "helpers.cuh" -#include "utils.cuh" +#include "transform.cuh" +#include "2dgs.cuh" #include #include diff --git a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu index b5757ff40..1de56b182 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu @@ -1,6 +1,10 @@ #include "bindings.h" #include "helpers.cuh" #include "utils.cuh" +#include "quat.cuh" +#include "quat_scale_to_covar_preci.cuh" +#include "proj.cuh" +#include "transform.cuh" #include #include diff --git a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu index c651e803d..16600eec7 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu @@ -1,6 +1,8 @@ #include "bindings.h" -#include "helpers.cuh" #include "utils.cuh" +#include "quat_scale_to_covar_preci.cuh" +#include "proj.cuh" +#include "transform.cuh" #include #include diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_bwd.cu index 564eb70bd..b969db910 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_bwd.cu @@ -1,6 +1,7 @@ #include 
"bindings.h" #include "helpers.cuh" -#include "utils.cuh" +#include "transform.cuh" +#include "2dgs.cuh" #include #include diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_fwd.cu index 34ac796fb..7024803e9 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_fwd.cu @@ -1,7 +1,8 @@ #include "bindings.h" #include "helpers.cuh" -#include "utils.cuh" +#include "quat.cuh" +#include "transform.cuh" #include #include diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu index e5a0172fe..9a5361486 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu @@ -1,6 +1,10 @@ #include "bindings.h" #include "helpers.cuh" #include "utils.cuh" +#include "quat.cuh" +#include "quat_scale_to_covar_preci.cuh" +#include "proj.cuh" +#include "transform.cuh" #include #include diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu index 4d8609f05..46d616bed 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu @@ -1,6 +1,8 @@ #include "bindings.h" -#include "helpers.cuh" #include "utils.cuh" +#include "quat_scale_to_covar_preci.cuh" +#include "proj.cuh" +#include "transform.cuh" #include #include diff --git a/gsplat/cuda/csrc/isect_tiles.cu b/gsplat/cuda/csrc/isect_tiles.cu index f61d8f2ca..ed45db980 100644 --- a/gsplat/cuda/csrc/isect_tiles.cu +++ b/gsplat/cuda/csrc/isect_tiles.cu @@ -1,5 +1,4 @@ #include "bindings.h" -#include "helpers.cuh" #include "types.cuh" #include #include diff --git a/gsplat/cuda/csrc/proj_bwd.cu b/gsplat/cuda/csrc/proj_bwd.cu index 66557f679..d602548bb 100644 --- a/gsplat/cuda/csrc/proj_bwd.cu +++ b/gsplat/cuda/csrc/proj_bwd.cu @@ -1,6 
+1,5 @@ #include "bindings.h" -#include "helpers.cuh" -#include "utils.cuh" +#include "proj.cuh" #include #include diff --git a/gsplat/cuda/csrc/proj_fwd.cu b/gsplat/cuda/csrc/proj_fwd.cu index 861f60479..161702cc3 100644 --- a/gsplat/cuda/csrc/proj_fwd.cu +++ b/gsplat/cuda/csrc/proj_fwd.cu @@ -1,6 +1,5 @@ #include "bindings.h" -#include "helpers.cuh" -#include "utils.cuh" +#include "proj.cuh" #include #include diff --git a/gsplat/cuda/csrc/quat_scale_to_covar_preci_bwd.cu b/gsplat/cuda/csrc/quat_scale_to_covar_preci_bwd.cu index 946355216..283b714aa 100644 --- a/gsplat/cuda/csrc/quat_scale_to_covar_preci_bwd.cu +++ b/gsplat/cuda/csrc/quat_scale_to_covar_preci_bwd.cu @@ -1,6 +1,6 @@ #include "bindings.h" -#include "helpers.cuh" -#include "utils.cuh" +#include "quat.cuh" +#include "quat_scale_to_covar_preci.cuh" #include #include diff --git a/gsplat/cuda/csrc/quat_scale_to_covar_preci_fwd.cu b/gsplat/cuda/csrc/quat_scale_to_covar_preci_fwd.cu index f2d613d56..7fbf1d871 100644 --- a/gsplat/cuda/csrc/quat_scale_to_covar_preci_fwd.cu +++ b/gsplat/cuda/csrc/quat_scale_to_covar_preci_fwd.cu @@ -1,6 +1,5 @@ #include "bindings.h" -#include "helpers.cuh" -#include "utils.cuh" +#include "quat_scale_to_covar_preci.cuh" #include #include diff --git a/gsplat/cuda/csrc/rasterize_to_indices_in_range.cu b/gsplat/cuda/csrc/rasterize_to_indices_in_range.cu index a773e2e17..895699c52 100644 --- a/gsplat/cuda/csrc/rasterize_to_indices_in_range.cu +++ b/gsplat/cuda/csrc/rasterize_to_indices_in_range.cu @@ -1,5 +1,4 @@ #include "bindings.h" -#include "helpers.cuh" #include "types.cuh" #include #include diff --git a/gsplat/cuda/csrc/rasterize_to_indices_in_range_2dgs.cu b/gsplat/cuda/csrc/rasterize_to_indices_in_range_2dgs.cu index 1432a1666..85afdd0f3 100644 --- a/gsplat/cuda/csrc/rasterize_to_indices_in_range_2dgs.cu +++ b/gsplat/cuda/csrc/rasterize_to_indices_in_range_2dgs.cu @@ -1,7 +1,6 @@ #include "bindings.h" -#include "helpers.cuh" #include "types.cuh" -#include "utils.cuh" 
+#include "2dgs.cuh" #include #include #include diff --git a/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_bwd.cu b/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_bwd.cu index fd77480bd..eb61c62cd 100644 --- a/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_bwd.cu +++ b/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_bwd.cu @@ -1,7 +1,7 @@ #include "bindings.h" #include "helpers.cuh" #include "types.cuh" -#include "utils.cuh" +#include "2dgs.cuh" #include #include #include diff --git a/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu b/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu index c36d7605c..e16416223 100644 --- a/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu +++ b/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu @@ -1,7 +1,6 @@ #include "bindings.h" -#include "helpers.cuh" #include "types.cuh" -#include "utils.cuh" +#include "2dgs.cuh" #include #include #include diff --git a/gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu b/gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu index 6c7400703..d1c323424 100644 --- a/gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu +++ b/gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu @@ -1,5 +1,4 @@ #include "bindings.h" -#include "helpers.cuh" #include "types.cuh" #include #include diff --git a/gsplat/cuda/csrc/third_party/glm b/gsplat/cuda/csrc/third_party/glm index 45008b225..33b4a621a 160000 --- a/gsplat/cuda/csrc/third_party/glm +++ b/gsplat/cuda/csrc/third_party/glm @@ -1 +1 @@ -Subproject commit 45008b225e28eb700fa0f7d3ff69b7c1db94fadf +Subproject commit 33b4a621a697a305bc3a7610d290677b96beb181 diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh deleted file mode 100644 index d6879ea6c..000000000 --- a/gsplat/cuda/csrc/utils.cuh +++ /dev/null @@ -1,715 +0,0 @@ -#ifndef GSPLAT_CUDA_UTILS_H -#define GSPLAT_CUDA_UTILS_H - -#include "helpers.cuh" - -#include -#include - -#define FILTER_INV_SQUARE 2.0f - -namespace gsplat { - -template -inline __device__ mat3 quat_to_rotmat(const vec4 quat) { - T w = quat[0], x = quat[1], y = quat[2], z = 
quat[3]; - // normalize - T inv_norm = rsqrt(x * x + y * y + z * z + w * w); - x *= inv_norm; - y *= inv_norm; - z *= inv_norm; - w *= inv_norm; - T x2 = x * x, y2 = y * y, z2 = z * z; - T xy = x * y, xz = x * z, yz = y * z; - T wx = w * x, wy = w * y, wz = w * z; - return mat3( - (1.f - 2.f * (y2 + z2)), - (2.f * (xy + wz)), - (2.f * (xz - wy)), // 1st col - (2.f * (xy - wz)), - (1.f - 2.f * (x2 + z2)), - (2.f * (yz + wx)), // 2nd col - (2.f * (xz + wy)), - (2.f * (yz - wx)), - (1.f - 2.f * (x2 + y2)) // 3rd col - ); -} - -template -inline __device__ void -quat_to_rotmat_vjp(const vec4 quat, const mat3 v_R, vec4 &v_quat) { - T w = quat[0], x = quat[1], y = quat[2], z = quat[3]; - // normalize - T inv_norm = rsqrt(x * x + y * y + z * z + w * w); - x *= inv_norm; - y *= inv_norm; - z *= inv_norm; - w *= inv_norm; - vec4 v_quat_n = vec4( - 2.f * (x * (v_R[1][2] - v_R[2][1]) + y * (v_R[2][0] - v_R[0][2]) + - z * (v_R[0][1] - v_R[1][0])), - 2.f * - (-2.f * x * (v_R[1][1] + v_R[2][2]) + y * (v_R[0][1] + v_R[1][0]) + - z * (v_R[0][2] + v_R[2][0]) + w * (v_R[1][2] - v_R[2][1])), - 2.f * (x * (v_R[0][1] + v_R[1][0]) - 2.f * y * (v_R[0][0] + v_R[2][2]) + - z * (v_R[1][2] + v_R[2][1]) + w * (v_R[2][0] - v_R[0][2])), - 2.f * (x * (v_R[0][2] + v_R[2][0]) + y * (v_R[1][2] + v_R[2][1]) - - 2.f * z * (v_R[0][0] + v_R[1][1]) + w * (v_R[0][1] - v_R[1][0])) - ); - - vec4 quat_n = vec4(w, x, y, z); - v_quat += (v_quat_n - glm::dot(v_quat_n, quat_n) * quat_n) * inv_norm; -} - -template -inline __device__ void quat_scale_to_covar_preci( - const vec4 quat, - const vec3 scale, - // optional outputs - mat3 *covar, - mat3 *preci -) { - mat3 R = quat_to_rotmat(quat); - if (covar != nullptr) { - // C = R * S * S * Rt - mat3 S = - mat3(scale[0], 0.f, 0.f, 0.f, scale[1], 0.f, 0.f, 0.f, scale[2]); - mat3 M = R * S; - *covar = M * glm::transpose(M); - } - if (preci != nullptr) { - // P = R * S^-1 * S^-1 * Rt - mat3 S = mat3( - 1.0f / scale[0], - 0.f, - 0.f, - 0.f, - 1.0f / scale[1], - 0.f, - 
0.f, - 0.f, - 1.0f / scale[2] - ); - mat3 M = R * S; - *preci = M * glm::transpose(M); - } -} - -template -inline __device__ void quat_scale_to_covar_vjp( - // fwd inputs - const vec4 quat, - const vec3 scale, - // precompute - const mat3 R, - // grad outputs - const mat3 v_covar, - // grad inputs - vec4 &v_quat, - vec3 &v_scale -) { - T w = quat[0], x = quat[1], y = quat[2], z = quat[3]; - T sx = scale[0], sy = scale[1], sz = scale[2]; - - // M = R * S - mat3 S = mat3(sx, 0.f, 0.f, 0.f, sy, 0.f, 0.f, 0.f, sz); - mat3 M = R * S; - - // https://math.stackexchange.com/a/3850121 - // for D = W * X, G = df/dD - // df/dW = G * XT, df/dX = WT * G - // so - // for D = M * Mt, - // df/dM = df/dM + df/dMt = G * M + (Mt * G)t = G * M + Gt * M - mat3 v_M = (v_covar + glm::transpose(v_covar)) * M; - mat3 v_R = v_M * S; - - // grad for (quat, scale) from covar - quat_to_rotmat_vjp(quat, v_R, v_quat); - - v_scale[0] += - R[0][0] * v_M[0][0] + R[0][1] * v_M[0][1] + R[0][2] * v_M[0][2]; - v_scale[1] += - R[1][0] * v_M[1][0] + R[1][1] * v_M[1][1] + R[1][2] * v_M[1][2]; - v_scale[2] += - R[2][0] * v_M[2][0] + R[2][1] * v_M[2][1] + R[2][2] * v_M[2][2]; -} - -template -inline __device__ void quat_scale_to_preci_vjp( - // fwd inputs - const vec4 quat, - const vec3 scale, - // precompute - const mat3 R, - // grad outputs - const mat3 v_preci, - // grad inputs - vec4 &v_quat, - vec3 &v_scale -) { - T w = quat[0], x = quat[1], y = quat[2], z = quat[3]; - T sx = 1.0f / scale[0], sy = 1.0f / scale[1], sz = 1.0f / scale[2]; - - // M = R * S - mat3 S = mat3(sx, 0.f, 0.f, 0.f, sy, 0.f, 0.f, 0.f, sz); - mat3 M = R * S; - - // https://math.stackexchange.com/a/3850121 - // for D = W * X, G = df/dD - // df/dW = G * XT, df/dX = WT * G - // so - // for D = M * Mt, - // df/dM = df/dM + df/dMt = G * M + (Mt * G)t = G * M + Gt * M - mat3 v_M = (v_preci + glm::transpose(v_preci)) * M; - mat3 v_R = v_M * S; - - // grad for (quat, scale) from preci - quat_to_rotmat_vjp(quat, v_R, v_quat); - - v_scale[0] 
+= - -sx * sx * - (R[0][0] * v_M[0][0] + R[0][1] * v_M[0][1] + R[0][2] * v_M[0][2]); - v_scale[1] += - -sy * sy * - (R[1][0] * v_M[1][0] + R[1][1] * v_M[1][1] + R[1][2] * v_M[1][2]); - v_scale[2] += - -sz * sz * - (R[2][0] * v_M[2][0] + R[2][1] * v_M[2][1] + R[2][2] * v_M[2][2]); -} - -template -inline __device__ void ortho_proj( - // inputs - const vec3 mean3d, - const mat3 cov3d, - const T fx, - const T fy, - const T cx, - const T cy, - const uint32_t width, - const uint32_t height, - // outputs - mat2 &cov2d, - vec2 &mean2d -) { - T x = mean3d[0], y = mean3d[1], z = mean3d[2]; - - // mat3x2 is 3 columns x 2 rows. - mat3x2 J = mat3x2( - fx, - 0.f, // 1st column - 0.f, - fy, // 2nd column - 0.f, - 0.f // 3rd column - ); - cov2d = J * cov3d * glm::transpose(J); - mean2d = vec2({fx * x + cx, fy * y + cy}); -} - -template -inline __device__ void ortho_proj_vjp( - // fwd inputs - const vec3 mean3d, - const mat3 cov3d, - const T fx, - const T fy, - const T cx, - const T cy, - const uint32_t width, - const uint32_t height, - // grad outputs - const mat2 v_cov2d, - const vec2 v_mean2d, - // grad inputs - vec3 &v_mean3d, - mat3 &v_cov3d -) { - T x = mean3d[0], y = mean3d[1], z = mean3d[2]; - - // mat3x2 is 3 columns x 2 rows. 
- mat3x2 J = mat3x2( - fx, - 0.f, // 1st column - 0.f, - fy, // 2nd column - 0.f, - 0.f // 3rd column - ); - - // cov = J * V * Jt; G = df/dcov = v_cov - // -> df/dV = Jt * G * J - // -> df/dJ = G * J * Vt + Gt * J * V - v_cov3d += glm::transpose(J) * v_cov2d * J; - - // df/dx = fx * df/dpixx - // df/dy = fy * df/dpixy - // df/dz = 0 - v_mean3d += vec3(fx * v_mean2d[0], fy * v_mean2d[1], 0.f); -} - -template -inline __device__ void persp_proj( - // inputs - const vec3 mean3d, - const mat3 cov3d, - const T fx, - const T fy, - const T cx, - const T cy, - const uint32_t width, - const uint32_t height, - // outputs - mat2 &cov2d, - vec2 &mean2d -) { - T x = mean3d[0], y = mean3d[1], z = mean3d[2]; - - T tan_fovx = 0.5f * width / fx; - T tan_fovy = 0.5f * height / fy; - T lim_x_pos = (width - cx) / fx + 0.3f * tan_fovx; - T lim_x_neg = cx / fx + 0.3f * tan_fovx; - T lim_y_pos = (height - cy) / fy + 0.3f * tan_fovy; - T lim_y_neg = cy / fy + 0.3f * tan_fovy; - - T rz = 1.f / z; - T rz2 = rz * rz; - T tx = z * min(lim_x_pos, max(-lim_x_neg, x * rz)); - T ty = z * min(lim_y_pos, max(-lim_y_neg, y * rz)); - - // mat3x2 is 3 columns x 2 rows. 
- mat3x2 J = mat3x2( - fx * rz, - 0.f, // 1st column - 0.f, - fy * rz, // 2nd column - -fx * tx * rz2, - -fy * ty * rz2 // 3rd column - ); - cov2d = J * cov3d * glm::transpose(J); - mean2d = vec2({fx * x * rz + cx, fy * y * rz + cy}); -} - -template -inline __device__ void persp_proj_vjp( - // fwd inputs - const vec3 mean3d, - const mat3 cov3d, - const T fx, - const T fy, - const T cx, - const T cy, - const uint32_t width, - const uint32_t height, - // grad outputs - const mat2 v_cov2d, - const vec2 v_mean2d, - // grad inputs - vec3 &v_mean3d, - mat3 &v_cov3d -) { - T x = mean3d[0], y = mean3d[1], z = mean3d[2]; - - T tan_fovx = 0.5f * width / fx; - T tan_fovy = 0.5f * height / fy; - T lim_x_pos = (width - cx) / fx + 0.3f * tan_fovx; - T lim_x_neg = cx / fx + 0.3f * tan_fovx; - T lim_y_pos = (height - cy) / fy + 0.3f * tan_fovy; - T lim_y_neg = cy / fy + 0.3f * tan_fovy; - - T rz = 1.f / z; - T rz2 = rz * rz; - T tx = z * min(lim_x_pos, max(-lim_x_neg, x * rz)); - T ty = z * min(lim_y_pos, max(-lim_y_neg, y * rz)); - - // mat3x2 is 3 columns x 2 rows. 
- mat3x2 J = mat3x2( - fx * rz, - 0.f, // 1st column - 0.f, - fy * rz, // 2nd column - -fx * tx * rz2, - -fy * ty * rz2 // 3rd column - ); - - // cov = J * V * Jt; G = df/dcov = v_cov - // -> df/dV = Jt * G * J - // -> df/dJ = G * J * Vt + Gt * J * V - v_cov3d += glm::transpose(J) * v_cov2d * J; - - // df/dx = fx * rz * df/dpixx - // df/dy = fy * rz * df/dpixy - // df/dz = - fx * mean.x * rz2 * df/dpixx - fy * mean.y * rz2 * df/dpixy - v_mean3d += vec3( - fx * rz * v_mean2d[0], - fy * rz * v_mean2d[1], - -(fx * x * v_mean2d[0] + fy * y * v_mean2d[1]) * rz2 - ); - - // df/dx = -fx * rz2 * df/dJ_02 - // df/dy = -fy * rz2 * df/dJ_12 - // df/dz = -fx * rz2 * df/dJ_00 - fy * rz2 * df/dJ_11 - // + 2 * fx * tx * rz3 * df/dJ_02 + 2 * fy * ty * rz3 - T rz3 = rz2 * rz; - mat3x2 v_J = v_cov2d * J * glm::transpose(cov3d) + - glm::transpose(v_cov2d) * J * cov3d; - - // fov clipping - if (x * rz <= lim_x_pos && x * rz >= -lim_x_neg) { - v_mean3d.x += -fx * rz2 * v_J[2][0]; - } else { - v_mean3d.z += -fx * rz3 * v_J[2][0] * tx; - } - if (y * rz <= lim_y_pos && y * rz >= -lim_y_neg) { - v_mean3d.y += -fy * rz2 * v_J[2][1]; - } else { - v_mean3d.z += -fy * rz3 * v_J[2][1] * ty; - } - v_mean3d.z += -fx * rz2 * v_J[0][0] - fy * rz2 * v_J[1][1] + - 2.f * fx * tx * rz3 * v_J[2][0] + - 2.f * fy * ty * rz3 * v_J[2][1]; -} - -template -inline __device__ void fisheye_proj( - // inputs - const vec3 mean3d, - const mat3 cov3d, - const T fx, - const T fy, - const T cx, - const T cy, - const uint32_t width, - const uint32_t height, - // outputs - mat2 &cov2d, - vec2 &mean2d -) { - T x = mean3d[0], y = mean3d[1], z = mean3d[2]; - - T eps = 0.0000001f; - T xy_len = glm::length(glm::vec2({x, y})) + eps; - T theta = glm::atan(xy_len, z + eps); - mean2d = - vec2({x * fx * theta / xy_len + cx, y * fy * theta / xy_len + cy}); - - T x2 = x * x + eps; - T y2 = y * y; - T xy = x * y; - T x2y2 = x2 + y2; - T x2y2z2_inv = 1.f / (x2y2 + z * z); - - T b = glm::atan(xy_len, z) / xy_len / x2y2; - T a = z * 
x2y2z2_inv / (x2y2); - mat3x2 J = mat3x2( - fx * (x2 * a + y2 * b), - fy * xy * (a - b), - fx * xy * (a - b), - fy * (y2 * a + x2 * b), - -fx * x * x2y2z2_inv, - -fy * y * x2y2z2_inv - ); - cov2d = J * cov3d * glm::transpose(J); -} - -template -inline __device__ void fisheye_proj_vjp( - // fwd inputs - const vec3 mean3d, - const mat3 cov3d, - const T fx, - const T fy, - const T cx, - const T cy, - const uint32_t width, - const uint32_t height, - // grad outputs - const mat2 v_cov2d, - const vec2 v_mean2d, - // grad inputs - vec3 &v_mean3d, - mat3 &v_cov3d -) { - T x = mean3d[0], y = mean3d[1], z = mean3d[2]; - - const T eps = 0.0000001f; - T x2 = x * x + eps; - T y2 = y * y; - T xy = x * y; - T x2y2 = x2 + y2; - T len_xy = length(glm::vec2({x, y})) + eps; - const T x2y2z2 = x2y2 + z * z; - T x2y2z2_inv = 1.f / x2y2z2; - T b = glm::atan(len_xy, z) / len_xy / x2y2; - T a = z * x2y2z2_inv / (x2y2); - v_mean3d += vec3( - fx * (x2 * a + y2 * b) * v_mean2d[0] + fy * xy * (a - b) * v_mean2d[1], - fx * xy * (a - b) * v_mean2d[0] + fy * (y2 * a + x2 * b) * v_mean2d[1], - -fx * x * x2y2z2_inv * v_mean2d[0] - fy * y * x2y2z2_inv * v_mean2d[1] - ); - - const T theta = glm::atan(len_xy, z); - const T J_b = theta / len_xy / x2y2; - const T J_a = z * x2y2z2_inv / (x2y2); - // mat3x2 is 3 columns x 2 rows. 
- mat3x2 J = mat3x2( - fx * (x2 * J_a + y2 * J_b), - fy * xy * (J_a - J_b), // 1st column - fx * xy * (J_a - J_b), - fy * (y2 * J_a + x2 * J_b), // 2nd column - -fx * x * x2y2z2_inv, - -fy * y * x2y2z2_inv // 3rd column - ); - v_cov3d += glm::transpose(J) * v_cov2d * J; - - mat3x2 v_J = v_cov2d * J * glm::transpose(cov3d) + - glm::transpose(v_cov2d) * J * cov3d; - T l4 = x2y2z2 * x2y2z2; - - T E = -l4 * x2y2 * theta + x2y2z2 * x2y2 * len_xy * z; - T F = 3 * l4 * theta - 3 * x2y2z2 * len_xy * z - 2 * x2y2 * len_xy * z; - - T A = x * (3 * E + x2 * F); - T B = y * (E + x2 * F); - T C = x * (E + y2 * F); - T D = y * (3 * E + y2 * F); - - T S1 = x2 - y2 - z * z; - T S2 = y2 - x2 - z * z; - T inv1 = x2y2z2_inv * x2y2z2_inv; - T inv2 = inv1 / (x2y2 * x2y2 * len_xy); - - T dJ_dx00 = fx * A * inv2; - T dJ_dx01 = fx * B * inv2; - T dJ_dx02 = fx * S1 * inv1; - T dJ_dx10 = fy * B * inv2; - T dJ_dx11 = fy * C * inv2; - T dJ_dx12 = 2.f * fy * xy * inv1; - - T dJ_dy00 = dJ_dx01; - T dJ_dy01 = fx * C * inv2; - T dJ_dy02 = 2.f * fx * xy * inv1; - T dJ_dy10 = dJ_dx11; - T dJ_dy11 = fy * D * inv2; - T dJ_dy12 = fy * S2 * inv1; - - T dJ_dz00 = dJ_dx02; - T dJ_dz01 = dJ_dy02; - T dJ_dz02 = 2.f * fx * x * z * inv1; - T dJ_dz10 = dJ_dx12; - T dJ_dz11 = dJ_dy12; - T dJ_dz12 = 2.f * fy * y * z * inv1; - - T dL_dtx_raw = dJ_dx00 * v_J[0][0] + dJ_dx01 * v_J[1][0] + - dJ_dx02 * v_J[2][0] + dJ_dx10 * v_J[0][1] + - dJ_dx11 * v_J[1][1] + dJ_dx12 * v_J[2][1]; - T dL_dty_raw = dJ_dy00 * v_J[0][0] + dJ_dy01 * v_J[1][0] + - dJ_dy02 * v_J[2][0] + dJ_dy10 * v_J[0][1] + - dJ_dy11 * v_J[1][1] + dJ_dy12 * v_J[2][1]; - T dL_dtz_raw = dJ_dz00 * v_J[0][0] + dJ_dz01 * v_J[1][0] + - dJ_dz02 * v_J[2][0] + dJ_dz10 * v_J[0][1] + - dJ_dz11 * v_J[1][1] + dJ_dz12 * v_J[2][1]; - v_mean3d.x += dL_dtx_raw; - v_mean3d.y += dL_dty_raw; - v_mean3d.z += dL_dtz_raw; -} - -template -inline __device__ void pos_world_to_cam( - // [R, t] is the world-to-camera transformation - const mat3 R, - const vec3 t, - const vec3 p, - 
vec3 &p_c -) { - p_c = R * p + t; -} - -template -inline __device__ void pos_world_to_cam_vjp( - // fwd inputs - const mat3 R, - const vec3 t, - const vec3 p, - // grad outputs - const vec3 v_p_c, - // grad inputs - mat3 &v_R, - vec3 &v_t, - vec3 &v_p -) { - // for D = W * X, G = df/dD - // df/dW = G * XT, df/dX = WT * G - v_R += glm::outerProduct(v_p_c, p); - v_t += v_p_c; - v_p += glm::transpose(R) * v_p_c; -} - -template -inline __device__ void covar_world_to_cam( - // [R, t] is the world-to-camera transformation - const mat3 R, - const mat3 covar, - mat3 &covar_c -) { - covar_c = R * covar * glm::transpose(R); -} - -template -inline __device__ void covar_world_to_cam_vjp( - // fwd inputs - const mat3 R, - const mat3 covar, - // grad outputs - const mat3 v_covar_c, - // grad inputs - mat3 &v_R, - mat3 &v_covar -) { - // for D = W * X * WT, G = df/dD - // df/dX = WT * G * W - // df/dW - // = G * (X * WT)T + ((W * X)T * G)T - // = G * W * XT + (XT * WT * G)T - // = G * W * XT + GT * W * X - v_R += v_covar_c * R * glm::transpose(covar) + - glm::transpose(v_covar_c) * R * covar; - v_covar += glm::transpose(R) * v_covar_c * R; -} - -template -inline __device__ T inverse(const mat2 M, mat2 &Minv) { - T det = M[0][0] * M[1][1] - M[0][1] * M[1][0]; - if (det <= 0.f) { - return det; - } - T invDet = 1.f / det; - Minv[0][0] = M[1][1] * invDet; - Minv[0][1] = -M[0][1] * invDet; - Minv[1][0] = Minv[0][1]; - Minv[1][1] = M[0][0] * invDet; - return det; -} - -template -inline __device__ void inverse_vjp(const T Minv, const T v_Minv, T &v_M) { - // P = M^-1 - // df/dM = -P * df/dP * P - v_M += -Minv * v_Minv * Minv; -} - -template -inline __device__ T add_blur(const T eps2d, mat2 &covar, T &compensation) { - T det_orig = covar[0][0] * covar[1][1] - covar[0][1] * covar[1][0]; - covar[0][0] += eps2d; - covar[1][1] += eps2d; - T det_blur = covar[0][0] * covar[1][1] - covar[0][1] * covar[1][0]; - compensation = sqrt(max(0.f, det_orig / det_blur)); - return det_blur; -} - -template 
-inline __device__ void add_blur_vjp( - const T eps2d, - const mat2 conic_blur, - const T compensation, - const T v_compensation, - mat2 &v_covar -) { - // comp = sqrt(det(covar) / det(covar_blur)) - - // d [det(M)] / d M = adj(M) - // d [det(M + aI)] / d M = adj(M + aI) = adj(M) + a * I - // d [det(M) / det(M + aI)] / d M - // = (det(M + aI) * adj(M) - det(M) * adj(M + aI)) / (det(M + aI))^2 - // = adj(M) / det(M + aI) - adj(M + aI) / det(M + aI) * comp^2 - // = (adj(M) - adj(M + aI) * comp^2) / det(M + aI) - // given that adj(M + aI) = adj(M) + a * I - // = (adj(M + aI) - aI - adj(M + aI) * comp^2) / det(M + aI) - // given that adj(M) / det(M) = inv(M) - // = (1 - comp^2) * inv(M + aI) - aI / det(M + aI) - // given det(inv(M)) = 1 / det(M) - // = (1 - comp^2) * inv(M + aI) - aI * det(inv(M + aI)) - // = (1 - comp^2) * conic_blur - aI * det(conic_blur) - - T det_conic_blur = conic_blur[0][0] * conic_blur[1][1] - - conic_blur[0][1] * conic_blur[1][0]; - T v_sqr_comp = v_compensation * 0.5 / (compensation + 1e-6); - T one_minus_sqr_comp = 1 - compensation * compensation; - v_covar[0][0] += v_sqr_comp * (one_minus_sqr_comp * conic_blur[0][0] - - eps2d * det_conic_blur); - v_covar[0][1] += v_sqr_comp * (one_minus_sqr_comp * conic_blur[0][1]); - v_covar[1][0] += v_sqr_comp * (one_minus_sqr_comp * conic_blur[1][0]); - v_covar[1][1] += v_sqr_comp * (one_minus_sqr_comp * conic_blur[1][1] - - eps2d * det_conic_blur); -} - -template -inline __device__ void compute_ray_transforms_aabb_vjp( - const T *ray_transforms, - const T *v_means2d, - const vec3 v_normals, - const mat3 W, - const mat3 P, - const vec3 cam_pos, - const vec3 mean_c, - const vec4 quat, - const vec2 scale, - mat3 &_v_ray_transforms, - vec4 &v_quat, - vec2 &v_scale, - vec3 &v_mean -) { - if (v_means2d[0] != 0 || v_means2d[1] != 0) { - const T distance = ray_transforms[6] * ray_transforms[6] + ray_transforms[7] * ray_transforms[7] - - ray_transforms[8] * ray_transforms[8]; - const T f = 1 / (distance); - const 
T dpx_dT00 = f * ray_transforms[6]; - const T dpx_dT01 = f * ray_transforms[7]; - const T dpx_dT02 = -f * ray_transforms[8]; - const T dpy_dT10 = f * ray_transforms[6]; - const T dpy_dT11 = f * ray_transforms[7]; - const T dpy_dT12 = -f * ray_transforms[8]; - const T dpx_dT30 = ray_transforms[0] * (f - 2 * f * f * ray_transforms[6] * ray_transforms[6]); - const T dpx_dT31 = ray_transforms[1] * (f - 2 * f * f * ray_transforms[7] * ray_transforms[7]); - const T dpx_dT32 = -ray_transforms[2] * (f + 2 * f * f * ray_transforms[8] * ray_transforms[8]); - const T dpy_dT30 = ray_transforms[3] * (f - 2 * f * f * ray_transforms[6] * ray_transforms[6]); - const T dpy_dT31 = ray_transforms[4] * (f - 2 * f * f * ray_transforms[7] * ray_transforms[7]); - const T dpy_dT32 = -ray_transforms[5] * (f + 2 * f * f * ray_transforms[8] * ray_transforms[8]); - - _v_ray_transforms[0][0] += v_means2d[0] * dpx_dT00; - _v_ray_transforms[0][1] += v_means2d[0] * dpx_dT01; - _v_ray_transforms[0][2] += v_means2d[0] * dpx_dT02; - _v_ray_transforms[1][0] += v_means2d[1] * dpy_dT10; - _v_ray_transforms[1][1] += v_means2d[1] * dpy_dT11; - _v_ray_transforms[1][2] += v_means2d[1] * dpy_dT12; - _v_ray_transforms[2][0] += v_means2d[0] * dpx_dT30 + v_means2d[1] * dpy_dT30; - _v_ray_transforms[2][1] += v_means2d[0] * dpx_dT31 + v_means2d[1] * dpy_dT31; - _v_ray_transforms[2][2] += v_means2d[0] * dpx_dT32 + v_means2d[1] * dpy_dT32; - } - - mat3 R = quat_to_rotmat(quat); - mat3 v_M = P * glm::transpose(_v_ray_transforms); - mat3 W_t = glm::transpose(W); - mat3 v_RS = W_t * v_M; - vec3 v_tn = W_t * v_normals; - - // dual visible - vec3 tn = W * R[2]; - T cos = glm::dot(-tn, mean_c); - T multiplier = cos > 0 ? 
1 : -1; - v_tn *= multiplier; - - mat3 v_R = mat3(v_RS[0] * scale[0], v_RS[1] * scale[1], v_tn); - - quat_to_rotmat_vjp(quat, v_R, v_quat); - v_scale[0] += (T)glm::dot(v_RS[0], R[0]); - v_scale[1] += (T)glm::dot(v_RS[1], R[1]); - - v_mean += v_RS[2]; -} - -} // namespace gsplat - -#endif // GSPLAT_CUDA_UTILS_H diff --git a/gsplat/cuda/csrc/world_to_cam_bwd.cu b/gsplat/cuda/csrc/world_to_cam_bwd.cu index 297d7b9f5..956309186 100644 --- a/gsplat/cuda/csrc/world_to_cam_bwd.cu +++ b/gsplat/cuda/csrc/world_to_cam_bwd.cu @@ -1,6 +1,6 @@ #include "bindings.h" #include "helpers.cuh" -#include "utils.cuh" +#include "transform.cuh" #include #include diff --git a/gsplat/cuda/csrc/world_to_cam_fwd.cu b/gsplat/cuda/csrc/world_to_cam_fwd.cu index ae10e06e1..0e05ef1cf 100644 --- a/gsplat/cuda/csrc/world_to_cam_fwd.cu +++ b/gsplat/cuda/csrc/world_to_cam_fwd.cu @@ -1,6 +1,5 @@ #include "bindings.h" -#include "helpers.cuh" -#include "utils.cuh" +#include "transform.cuh" #include #include diff --git a/gsplat/cuda/include/2dgs.cuh b/gsplat/cuda/include/2dgs.cuh new file mode 100644 index 000000000..2a0f0823f --- /dev/null +++ b/gsplat/cuda/include/2dgs.cuh @@ -0,0 +1,78 @@ +#ifndef GSPLAT_CUDA_2DGS_CUH +#define GSPLAT_CUDA_2DGS_CUH + +#include "types.cuh" +#include "quat.cuh" + +#define FILTER_INV_SQUARE 2.0f + +namespace gsplat { + +template +inline __device__ void compute_ray_transforms_aabb_vjp( + const T *ray_transforms, + const T *v_means2d, + const vec3 v_normals, + const mat3 W, + const mat3 P, + const vec3 cam_pos, + const vec3 mean_c, + const vec4 quat, + const vec2 scale, + mat3 &_v_ray_transforms, + vec4 &v_quat, + vec2 &v_scale, + vec3 &v_mean +) { + if (v_means2d[0] != 0 || v_means2d[1] != 0) { + const T distance = ray_transforms[6] * ray_transforms[6] + ray_transforms[7] * ray_transforms[7] - + ray_transforms[8] * ray_transforms[8]; + const T f = 1 / (distance); + const T dpx_dT00 = f * ray_transforms[6]; + const T dpx_dT01 = f * ray_transforms[7]; + const T dpx_dT02 = 
-f * ray_transforms[8]; + const T dpy_dT10 = f * ray_transforms[6]; + const T dpy_dT11 = f * ray_transforms[7]; + const T dpy_dT12 = -f * ray_transforms[8]; + const T dpx_dT30 = ray_transforms[0] * (f - 2 * f * f * ray_transforms[6] * ray_transforms[6]); + const T dpx_dT31 = ray_transforms[1] * (f - 2 * f * f * ray_transforms[7] * ray_transforms[7]); + const T dpx_dT32 = -ray_transforms[2] * (f + 2 * f * f * ray_transforms[8] * ray_transforms[8]); + const T dpy_dT30 = ray_transforms[3] * (f - 2 * f * f * ray_transforms[6] * ray_transforms[6]); + const T dpy_dT31 = ray_transforms[4] * (f - 2 * f * f * ray_transforms[7] * ray_transforms[7]); + const T dpy_dT32 = -ray_transforms[5] * (f + 2 * f * f * ray_transforms[8] * ray_transforms[8]); + + _v_ray_transforms[0][0] += v_means2d[0] * dpx_dT00; + _v_ray_transforms[0][1] += v_means2d[0] * dpx_dT01; + _v_ray_transforms[0][2] += v_means2d[0] * dpx_dT02; + _v_ray_transforms[1][0] += v_means2d[1] * dpy_dT10; + _v_ray_transforms[1][1] += v_means2d[1] * dpy_dT11; + _v_ray_transforms[1][2] += v_means2d[1] * dpy_dT12; + _v_ray_transforms[2][0] += v_means2d[0] * dpx_dT30 + v_means2d[1] * dpy_dT30; + _v_ray_transforms[2][1] += v_means2d[0] * dpx_dT31 + v_means2d[1] * dpy_dT31; + _v_ray_transforms[2][2] += v_means2d[0] * dpx_dT32 + v_means2d[1] * dpy_dT32; + } + + mat3 R = quat_to_rotmat(quat); + mat3 v_M = P * glm::transpose(_v_ray_transforms); + mat3 W_t = glm::transpose(W); + mat3 v_RS = W_t * v_M; + vec3 v_tn = W_t * v_normals; + + // dual visible + vec3 tn = W * R[2]; + T cos = glm::dot(-tn, mean_c); + T multiplier = cos > 0 ? 
1 : -1; + v_tn *= multiplier; + + mat3 v_R = mat3(v_RS[0] * scale[0], v_RS[1] * scale[1], v_tn); + + quat_to_rotmat_vjp(quat, v_R, v_quat); + v_scale[0] += (T)glm::dot(v_RS[0], R[0]); + v_scale[1] += (T)glm::dot(v_RS[1], R[1]); + + v_mean += v_RS[2]; +} + +} // namespace gsplat + +#endif // GSPLAT_CUDA_2DGS_CUH \ No newline at end of file diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/include/bindings.h similarity index 100% rename from gsplat/cuda/csrc/bindings.h rename to gsplat/cuda/include/bindings.h diff --git a/gsplat/cuda/csrc/helpers.cuh b/gsplat/cuda/include/helpers.cuh similarity index 95% rename from gsplat/cuda/csrc/helpers.cuh rename to gsplat/cuda/include/helpers.cuh index 93ff6bd45..09500aca9 100644 --- a/gsplat/cuda/csrc/helpers.cuh +++ b/gsplat/cuda/include/helpers.cuh @@ -1,5 +1,5 @@ -#ifndef GSPLAT_CUDA_HELPERS_H -#define GSPLAT_CUDA_HELPERS_H +#ifndef GSPLAT_CUDA_HELPERS_CUH +#define GSPLAT_CUDA_HELPERS_CUH #include "types.cuh" @@ -15,7 +15,7 @@ namespace cg = cooperative_groups; template inline __device__ void warpSum(T *val, WarpT &warp) { - GSPLAT_PRAGMA_UNROLL +#pragma unroll for (uint32_t i = 0; i < DIM; i++) { val[i] = cg::reduce(warp, val[i], cg::plus()); } @@ -79,4 +79,4 @@ template __forceinline__ __device__ T sum(vec3 a) { } // namespace gsplat -#endif // GSPLAT_CUDA_HELPERS_H +#endif // GSPLAT_CUDA_HELPERS_CUH diff --git a/gsplat/cuda/include/proj.cuh b/gsplat/cuda/include/proj.cuh new file mode 100644 index 000000000..ad44f298c --- /dev/null +++ b/gsplat/cuda/include/proj.cuh @@ -0,0 +1,347 @@ +#ifndef GSPLAT_CUDA_UTILS_H +#define GSPLAT_CUDA_UTILS_H + +#include "types.cuh" + +namespace gsplat { + +template +inline __device__ void ortho_proj( + // inputs + const vec3 mean3d, + const mat3 cov3d, + const T fx, + const T fy, + const T cx, + const T cy, + const uint32_t width, + const uint32_t height, + // outputs + mat2 &cov2d, + vec2 &mean2d +) { + T x = mean3d[0], y = mean3d[1], z = mean3d[2]; + + // mat3x2 is 3 columns x 2 
rows. + mat3x2 J = mat3x2( + fx, + 0.f, // 1st column + 0.f, + fy, // 2nd column + 0.f, + 0.f // 3rd column + ); + cov2d = J * cov3d * glm::transpose(J); + mean2d = vec2({fx * x + cx, fy * y + cy}); +} + +template +inline __device__ void ortho_proj_vjp( + // fwd inputs + const vec3 mean3d, + const mat3 cov3d, + const T fx, + const T fy, + const T cx, + const T cy, + const uint32_t width, + const uint32_t height, + // grad outputs + const mat2 v_cov2d, + const vec2 v_mean2d, + // grad inputs + vec3 &v_mean3d, + mat3 &v_cov3d +) { + T x = mean3d[0], y = mean3d[1], z = mean3d[2]; + + // mat3x2 is 3 columns x 2 rows. + mat3x2 J = mat3x2( + fx, + 0.f, // 1st column + 0.f, + fy, // 2nd column + 0.f, + 0.f // 3rd column + ); + + // cov = J * V * Jt; G = df/dcov = v_cov + // -> df/dV = Jt * G * J + // -> df/dJ = G * J * Vt + Gt * J * V + v_cov3d += glm::transpose(J) * v_cov2d * J; + + // df/dx = fx * df/dpixx + // df/dy = fy * df/dpixy + // df/dz = 0 + v_mean3d += vec3(fx * v_mean2d[0], fy * v_mean2d[1], 0.f); +} + +template +inline __device__ void persp_proj( + // inputs + const vec3 mean3d, + const mat3 cov3d, + const T fx, + const T fy, + const T cx, + const T cy, + const uint32_t width, + const uint32_t height, + // outputs + mat2 &cov2d, + vec2 &mean2d +) { + T x = mean3d[0], y = mean3d[1], z = mean3d[2]; + + T tan_fovx = 0.5f * width / fx; + T tan_fovy = 0.5f * height / fy; + T lim_x_pos = (width - cx) / fx + 0.3f * tan_fovx; + T lim_x_neg = cx / fx + 0.3f * tan_fovx; + T lim_y_pos = (height - cy) / fy + 0.3f * tan_fovy; + T lim_y_neg = cy / fy + 0.3f * tan_fovy; + + T rz = 1.f / z; + T rz2 = rz * rz; + T tx = z * min(lim_x_pos, max(-lim_x_neg, x * rz)); + T ty = z * min(lim_y_pos, max(-lim_y_neg, y * rz)); + + // mat3x2 is 3 columns x 2 rows. 
+ mat3x2 J = mat3x2( + fx * rz, + 0.f, // 1st column + 0.f, + fy * rz, // 2nd column + -fx * tx * rz2, + -fy * ty * rz2 // 3rd column + ); + cov2d = J * cov3d * glm::transpose(J); + mean2d = vec2({fx * x * rz + cx, fy * y * rz + cy}); +} + +template +inline __device__ void persp_proj_vjp( + // fwd inputs + const vec3 mean3d, + const mat3 cov3d, + const T fx, + const T fy, + const T cx, + const T cy, + const uint32_t width, + const uint32_t height, + // grad outputs + const mat2 v_cov2d, + const vec2 v_mean2d, + // grad inputs + vec3 &v_mean3d, + mat3 &v_cov3d +) { + T x = mean3d[0], y = mean3d[1], z = mean3d[2]; + + T tan_fovx = 0.5f * width / fx; + T tan_fovy = 0.5f * height / fy; + T lim_x_pos = (width - cx) / fx + 0.3f * tan_fovx; + T lim_x_neg = cx / fx + 0.3f * tan_fovx; + T lim_y_pos = (height - cy) / fy + 0.3f * tan_fovy; + T lim_y_neg = cy / fy + 0.3f * tan_fovy; + + T rz = 1.f / z; + T rz2 = rz * rz; + T tx = z * min(lim_x_pos, max(-lim_x_neg, x * rz)); + T ty = z * min(lim_y_pos, max(-lim_y_neg, y * rz)); + + // mat3x2 is 3 columns x 2 rows. 
+ mat3x2 J = mat3x2( + fx * rz, + 0.f, // 1st column + 0.f, + fy * rz, // 2nd column + -fx * tx * rz2, + -fy * ty * rz2 // 3rd column + ); + + // cov = J * V * Jt; G = df/dcov = v_cov + // -> df/dV = Jt * G * J + // -> df/dJ = G * J * Vt + Gt * J * V + v_cov3d += glm::transpose(J) * v_cov2d * J; + + // df/dx = fx * rz * df/dpixx + // df/dy = fy * rz * df/dpixy + // df/dz = - fx * mean.x * rz2 * df/dpixx - fy * mean.y * rz2 * df/dpixy + v_mean3d += vec3( + fx * rz * v_mean2d[0], + fy * rz * v_mean2d[1], + -(fx * x * v_mean2d[0] + fy * y * v_mean2d[1]) * rz2 + ); + + // df/dx = -fx * rz2 * df/dJ_02 + // df/dy = -fy * rz2 * df/dJ_12 + // df/dz = -fx * rz2 * df/dJ_00 - fy * rz2 * df/dJ_11 + // + 2 * fx * tx * rz3 * df/dJ_02 + 2 * fy * ty * rz3 + T rz3 = rz2 * rz; + mat3x2 v_J = v_cov2d * J * glm::transpose(cov3d) + + glm::transpose(v_cov2d) * J * cov3d; + + // fov clipping + if (x * rz <= lim_x_pos && x * rz >= -lim_x_neg) { + v_mean3d.x += -fx * rz2 * v_J[2][0]; + } else { + v_mean3d.z += -fx * rz3 * v_J[2][0] * tx; + } + if (y * rz <= lim_y_pos && y * rz >= -lim_y_neg) { + v_mean3d.y += -fy * rz2 * v_J[2][1]; + } else { + v_mean3d.z += -fy * rz3 * v_J[2][1] * ty; + } + v_mean3d.z += -fx * rz2 * v_J[0][0] - fy * rz2 * v_J[1][1] + + 2.f * fx * tx * rz3 * v_J[2][0] + + 2.f * fy * ty * rz3 * v_J[2][1]; +} + +template +inline __device__ void fisheye_proj( + // inputs + const vec3 mean3d, + const mat3 cov3d, + const T fx, + const T fy, + const T cx, + const T cy, + const uint32_t width, + const uint32_t height, + // outputs + mat2 &cov2d, + vec2 &mean2d +) { + T x = mean3d[0], y = mean3d[1], z = mean3d[2]; + + T eps = 0.0000001f; + T xy_len = glm::length(glm::vec2({x, y})) + eps; + T theta = glm::atan(xy_len, z + eps); + mean2d = + vec2({x * fx * theta / xy_len + cx, y * fy * theta / xy_len + cy}); + + T x2 = x * x + eps; + T y2 = y * y; + T xy = x * y; + T x2y2 = x2 + y2; + T x2y2z2_inv = 1.f / (x2y2 + z * z); + + T b = glm::atan(xy_len, z) / xy_len / x2y2; + T a = z * 
x2y2z2_inv / (x2y2); + mat3x2 J = mat3x2( + fx * (x2 * a + y2 * b), + fy * xy * (a - b), + fx * xy * (a - b), + fy * (y2 * a + x2 * b), + -fx * x * x2y2z2_inv, + -fy * y * x2y2z2_inv + ); + cov2d = J * cov3d * glm::transpose(J); +} + +template +inline __device__ void fisheye_proj_vjp( + // fwd inputs + const vec3 mean3d, + const mat3 cov3d, + const T fx, + const T fy, + const T cx, + const T cy, + const uint32_t width, + const uint32_t height, + // grad outputs + const mat2 v_cov2d, + const vec2 v_mean2d, + // grad inputs + vec3 &v_mean3d, + mat3 &v_cov3d +) { + T x = mean3d[0], y = mean3d[1], z = mean3d[2]; + + const T eps = 0.0000001f; + T x2 = x * x + eps; + T y2 = y * y; + T xy = x * y; + T x2y2 = x2 + y2; + T len_xy = length(glm::vec2({x, y})) + eps; + const T x2y2z2 = x2y2 + z * z; + T x2y2z2_inv = 1.f / x2y2z2; + T b = glm::atan(len_xy, z) / len_xy / x2y2; + T a = z * x2y2z2_inv / (x2y2); + v_mean3d += vec3( + fx * (x2 * a + y2 * b) * v_mean2d[0] + fy * xy * (a - b) * v_mean2d[1], + fx * xy * (a - b) * v_mean2d[0] + fy * (y2 * a + x2 * b) * v_mean2d[1], + -fx * x * x2y2z2_inv * v_mean2d[0] - fy * y * x2y2z2_inv * v_mean2d[1] + ); + + const T theta = glm::atan(len_xy, z); + const T J_b = theta / len_xy / x2y2; + const T J_a = z * x2y2z2_inv / (x2y2); + // mat3x2 is 3 columns x 2 rows. 
+ mat3x2 J = mat3x2( + fx * (x2 * J_a + y2 * J_b), + fy * xy * (J_a - J_b), // 1st column + fx * xy * (J_a - J_b), + fy * (y2 * J_a + x2 * J_b), // 2nd column + -fx * x * x2y2z2_inv, + -fy * y * x2y2z2_inv // 3rd column + ); + v_cov3d += glm::transpose(J) * v_cov2d * J; + + mat3x2 v_J = v_cov2d * J * glm::transpose(cov3d) + + glm::transpose(v_cov2d) * J * cov3d; + T l4 = x2y2z2 * x2y2z2; + + T E = -l4 * x2y2 * theta + x2y2z2 * x2y2 * len_xy * z; + T F = 3 * l4 * theta - 3 * x2y2z2 * len_xy * z - 2 * x2y2 * len_xy * z; + + T A = x * (3 * E + x2 * F); + T B = y * (E + x2 * F); + T C = x * (E + y2 * F); + T D = y * (3 * E + y2 * F); + + T S1 = x2 - y2 - z * z; + T S2 = y2 - x2 - z * z; + T inv1 = x2y2z2_inv * x2y2z2_inv; + T inv2 = inv1 / (x2y2 * x2y2 * len_xy); + + T dJ_dx00 = fx * A * inv2; + T dJ_dx01 = fx * B * inv2; + T dJ_dx02 = fx * S1 * inv1; + T dJ_dx10 = fy * B * inv2; + T dJ_dx11 = fy * C * inv2; + T dJ_dx12 = 2.f * fy * xy * inv1; + + T dJ_dy00 = dJ_dx01; + T dJ_dy01 = fx * C * inv2; + T dJ_dy02 = 2.f * fx * xy * inv1; + T dJ_dy10 = dJ_dx11; + T dJ_dy11 = fy * D * inv2; + T dJ_dy12 = fy * S2 * inv1; + + T dJ_dz00 = dJ_dx02; + T dJ_dz01 = dJ_dy02; + T dJ_dz02 = 2.f * fx * x * z * inv1; + T dJ_dz10 = dJ_dx12; + T dJ_dz11 = dJ_dy12; + T dJ_dz12 = 2.f * fy * y * z * inv1; + + T dL_dtx_raw = dJ_dx00 * v_J[0][0] + dJ_dx01 * v_J[1][0] + + dJ_dx02 * v_J[2][0] + dJ_dx10 * v_J[0][1] + + dJ_dx11 * v_J[1][1] + dJ_dx12 * v_J[2][1]; + T dL_dty_raw = dJ_dy00 * v_J[0][0] + dJ_dy01 * v_J[1][0] + + dJ_dy02 * v_J[2][0] + dJ_dy10 * v_J[0][1] + + dJ_dy11 * v_J[1][1] + dJ_dy12 * v_J[2][1]; + T dL_dtz_raw = dJ_dz00 * v_J[0][0] + dJ_dz01 * v_J[1][0] + + dJ_dz02 * v_J[2][0] + dJ_dz10 * v_J[0][1] + + dJ_dz11 * v_J[1][1] + dJ_dz12 * v_J[2][1]; + v_mean3d.x += dL_dtx_raw; + v_mean3d.y += dL_dty_raw; + v_mean3d.z += dL_dtz_raw; +} + +} // namespace gsplat + +#endif // GSPLAT_CUDA_UTILS_H diff --git a/gsplat/cuda/include/quat.cuh b/gsplat/cuda/include/quat.cuh new file mode 100644 
index 000000000..2fa0e941c --- /dev/null +++ b/gsplat/cuda/include/quat.cuh @@ -0,0 +1,61 @@ +#ifndef GSPLAT_CUDA_QUAT_CUH +#define GSPLAT_CUDA_QUAT_CUH + +#include "types.cuh" + +namespace gsplat { + +template +inline __device__ mat3 quat_to_rotmat(const vec4 quat) { + T w = quat[0], x = quat[1], y = quat[2], z = quat[3]; + // normalize + T inv_norm = rsqrt(x * x + y * y + z * z + w * w); + x *= inv_norm; + y *= inv_norm; + z *= inv_norm; + w *= inv_norm; + T x2 = x * x, y2 = y * y, z2 = z * z; + T xy = x * y, xz = x * z, yz = y * z; + T wx = w * x, wy = w * y, wz = w * z; + return mat3( + (1.f - 2.f * (y2 + z2)), + (2.f * (xy + wz)), + (2.f * (xz - wy)), // 1st col + (2.f * (xy - wz)), + (1.f - 2.f * (x2 + z2)), + (2.f * (yz + wx)), // 2nd col + (2.f * (xz + wy)), + (2.f * (yz - wx)), + (1.f - 2.f * (x2 + y2)) // 3rd col + ); +} + +template +inline __device__ void +quat_to_rotmat_vjp(const vec4 quat, const mat3 v_R, vec4 &v_quat) { + T w = quat[0], x = quat[1], y = quat[2], z = quat[3]; + // normalize + T inv_norm = rsqrt(x * x + y * y + z * z + w * w); + x *= inv_norm; + y *= inv_norm; + z *= inv_norm; + w *= inv_norm; + vec4 v_quat_n = vec4( + 2.f * (x * (v_R[1][2] - v_R[2][1]) + y * (v_R[2][0] - v_R[0][2]) + + z * (v_R[0][1] - v_R[1][0])), + 2.f * + (-2.f * x * (v_R[1][1] + v_R[2][2]) + y * (v_R[0][1] + v_R[1][0]) + + z * (v_R[0][2] + v_R[2][0]) + w * (v_R[1][2] - v_R[2][1])), + 2.f * (x * (v_R[0][1] + v_R[1][0]) - 2.f * y * (v_R[0][0] + v_R[2][2]) + + z * (v_R[1][2] + v_R[2][1]) + w * (v_R[2][0] - v_R[0][2])), + 2.f * (x * (v_R[0][2] + v_R[2][0]) + y * (v_R[1][2] + v_R[2][1]) - + 2.f * z * (v_R[0][0] + v_R[1][1]) + w * (v_R[0][1] - v_R[1][0])) + ); + + vec4 quat_n = vec4(w, x, y, z); + v_quat += (v_quat_n - glm::dot(v_quat_n, quat_n) * quat_n) * inv_norm; +} + +} // namespace gsplat + +#endif // GSPLAT_CUDA_QUAT_CUH diff --git a/gsplat/cuda/include/quat_scale_to_covar_preci.cuh b/gsplat/cuda/include/quat_scale_to_covar_preci.cuh new file mode 100644 index 
000000000..5c976912f --- /dev/null +++ b/gsplat/cuda/include/quat_scale_to_covar_preci.cuh @@ -0,0 +1,128 @@ +#ifndef GSPLAT_CUDA_QUAT_SCALE_TO_COVAR_PRECI_CUH +#define GSPLAT_CUDA_QUAT_SCALE_TO_COVAR_PRECI_CUH + +#include "types.cuh" +#include "quat.cuh" + +namespace gsplat { + +template +inline __device__ void quat_scale_to_covar_preci( + const vec4 quat, + const vec3 scale, + // optional outputs + mat3 *covar, + mat3 *preci +) { + mat3 R = quat_to_rotmat(quat); + if (covar != nullptr) { + // C = R * S * S * Rt + mat3 S = + mat3(scale[0], 0.f, 0.f, 0.f, scale[1], 0.f, 0.f, 0.f, scale[2]); + mat3 M = R * S; + *covar = M * glm::transpose(M); + } + if (preci != nullptr) { + // P = R * S^-1 * S^-1 * Rt + mat3 S = mat3( + 1.0f / scale[0], + 0.f, + 0.f, + 0.f, + 1.0f / scale[1], + 0.f, + 0.f, + 0.f, + 1.0f / scale[2] + ); + mat3 M = R * S; + *preci = M * glm::transpose(M); + } +} + +template +inline __device__ void quat_scale_to_covar_vjp( + // fwd inputs + const vec4 quat, + const vec3 scale, + // precompute + const mat3 R, + // grad outputs + const mat3 v_covar, + // grad inputs + vec4 &v_quat, + vec3 &v_scale +) { + T w = quat[0], x = quat[1], y = quat[2], z = quat[3]; + T sx = scale[0], sy = scale[1], sz = scale[2]; + + // M = R * S + mat3 S = mat3(sx, 0.f, 0.f, 0.f, sy, 0.f, 0.f, 0.f, sz); + mat3 M = R * S; + + // https://math.stackexchange.com/a/3850121 + // for D = W * X, G = df/dD + // df/dW = G * XT, df/dX = WT * G + // so + // for D = M * Mt, + // df/dM = df/dM + df/dMt = G * M + (Mt * G)t = G * M + Gt * M + mat3 v_M = (v_covar + glm::transpose(v_covar)) * M; + mat3 v_R = v_M * S; + + // grad for (quat, scale) from covar + quat_to_rotmat_vjp(quat, v_R, v_quat); + + v_scale[0] += + R[0][0] * v_M[0][0] + R[0][1] * v_M[0][1] + R[0][2] * v_M[0][2]; + v_scale[1] += + R[1][0] * v_M[1][0] + R[1][1] * v_M[1][1] + R[1][2] * v_M[1][2]; + v_scale[2] += + R[2][0] * v_M[2][0] + R[2][1] * v_M[2][1] + R[2][2] * v_M[2][2]; +} + +template +inline __device__ void 
quat_scale_to_preci_vjp( + // fwd inputs + const vec4 quat, + const vec3 scale, + // precompute + const mat3 R, + // grad outputs + const mat3 v_preci, + // grad inputs + vec4 &v_quat, + vec3 &v_scale +) { + T w = quat[0], x = quat[1], y = quat[2], z = quat[3]; + T sx = 1.0f / scale[0], sy = 1.0f / scale[1], sz = 1.0f / scale[2]; + + // M = R * S + mat3 S = mat3(sx, 0.f, 0.f, 0.f, sy, 0.f, 0.f, 0.f, sz); + mat3 M = R * S; + + // https://math.stackexchange.com/a/3850121 + // for D = W * X, G = df/dD + // df/dW = G * XT, df/dX = WT * G + // so + // for D = M * Mt, + // df/dM = df/dM + df/dMt = G * M + (Mt * G)t = G * M + Gt * M + mat3 v_M = (v_preci + glm::transpose(v_preci)) * M; + mat3 v_R = v_M * S; + + // grad for (quat, scale) from preci + quat_to_rotmat_vjp(quat, v_R, v_quat); + + v_scale[0] += + -sx * sx * + (R[0][0] * v_M[0][0] + R[0][1] * v_M[0][1] + R[0][2] * v_M[0][2]); + v_scale[1] += + -sy * sy * + (R[1][0] * v_M[1][0] + R[1][1] * v_M[1][1] + R[1][2] * v_M[1][2]); + v_scale[2] += + -sz * sz * + (R[2][0] * v_M[2][0] + R[2][1] * v_M[2][1] + R[2][2] * v_M[2][2]); +} + +} // namespace gsplat + +#endif // GSPLAT_CUDA_QUAT_SCALE_TO_COVAR_PRECI_CUH diff --git a/gsplat/cuda/csrc/spherical_harmonics.cuh b/gsplat/cuda/include/spherical_harmonics.cuh similarity index 99% rename from gsplat/cuda/csrc/spherical_harmonics.cuh rename to gsplat/cuda/include/spherical_harmonics.cuh index d36f2ff5a..0e039bf8b 100644 --- a/gsplat/cuda/csrc/spherical_harmonics.cuh +++ b/gsplat/cuda/include/spherical_harmonics.cuh @@ -1,11 +1,7 @@ #ifndef GSPLAT_SPHERICAL_HARMONICS_CUH #define GSPLAT_SPHERICAL_HARMONICS_CUH -#include "bindings.h" #include "types.cuh" -#include "utils.cuh" - -#include namespace gsplat { diff --git a/gsplat/cuda/include/transform.cuh b/gsplat/cuda/include/transform.cuh new file mode 100644 index 000000000..a79ba93d9 --- /dev/null +++ b/gsplat/cuda/include/transform.cuh @@ -0,0 +1,73 @@ +#ifndef GSPLAT_CUDA_TRANSFORM_CUH +#define GSPLAT_CUDA_TRANSFORM_CUH + 
+#include "types.cuh" + +namespace gsplat { + +template +inline __device__ void pos_world_to_cam( + // [R, t] is the world-to-camera transformation + const mat3 R, + const vec3 t, + const vec3 p, + vec3 &p_c +) { + p_c = R * p + t; +} + +template +inline __device__ void pos_world_to_cam_vjp( + // fwd inputs + const mat3 R, + const vec3 t, + const vec3 p, + // grad outputs + const vec3 v_p_c, + // grad inputs + mat3 &v_R, + vec3 &v_t, + vec3 &v_p +) { + // for D = W * X, G = df/dD + // df/dW = G * XT, df/dX = WT * G + v_R += glm::outerProduct(v_p_c, p); + v_t += v_p_c; + v_p += glm::transpose(R) * v_p_c; +} + +template +inline __device__ void covar_world_to_cam( + // [R, t] is the world-to-camera transformation + const mat3 R, + const mat3 covar, + mat3 &covar_c +) { + covar_c = R * covar * glm::transpose(R); +} + +template +inline __device__ void covar_world_to_cam_vjp( + // fwd inputs + const mat3 R, + const mat3 covar, + // grad outputs + const mat3 v_covar_c, + // grad inputs + mat3 &v_R, + mat3 &v_covar +) { + // for D = W * X * WT, G = df/dD + // df/dX = WT * G * W + // df/dW + // = G * (X * WT)T + ((W * X)T * G)T + // = G * W * XT + (XT * WT * G)T + // = G * W * XT + GT * W * X + v_R += v_covar_c * R * glm::transpose(covar) + + glm::transpose(v_covar_c) * R * covar; + v_covar += glm::transpose(R) * v_covar_c * R; +} + +} // namespace gsplat + +#endif // GSPLAT_CUDA_TRANSFORM_CUH diff --git a/gsplat/cuda/csrc/types.cuh b/gsplat/cuda/include/types.cuh similarity index 82% rename from gsplat/cuda/csrc/types.cuh rename to gsplat/cuda/include/types.cuh index 70f0472da..8aaceaa94 100644 --- a/gsplat/cuda/csrc/types.cuh +++ b/gsplat/cuda/include/types.cuh @@ -1,13 +1,7 @@ -#ifndef GSPLAT_CUDA_TYPES_H -#define GSPLAT_CUDA_TYPES_H +#ifndef GSPLAT_CUDA_TYPES_CUH +#define GSPLAT_CUDA_TYPES_CUH -#include -#include -#include - -#include #include - #include namespace gsplat { @@ -48,4 +42,4 @@ template <> struct OpType { } // namespace gsplat -#endif // GSPLAT_CUDA_TYPES_H 
\ No newline at end of file +#endif // GSPLAT_CUDA_TYPES_CUH \ No newline at end of file diff --git a/gsplat/cuda/include/utils.cuh b/gsplat/cuda/include/utils.cuh new file mode 100644 index 000000000..f07c4e13c --- /dev/null +++ b/gsplat/cuda/include/utils.cuh @@ -0,0 +1,77 @@ +#ifndef GSPLAT_CUDA_UTILS_CUH +#define GSPLAT_CUDA_UTILS_CUH + +#include "types.cuh" + +namespace gsplat { + +template +inline __device__ T inverse(const mat2 M, mat2 &Minv) { + T det = M[0][0] * M[1][1] - M[0][1] * M[1][0]; + if (det <= 0.f) { + return det; + } + T invDet = 1.f / det; + Minv[0][0] = M[1][1] * invDet; + Minv[0][1] = -M[0][1] * invDet; + Minv[1][0] = Minv[0][1]; + Minv[1][1] = M[0][0] * invDet; + return det; +} + +template +inline __device__ void inverse_vjp(const T Minv, const T v_Minv, T &v_M) { + // P = M^-1 + // df/dM = -P * df/dP * P + v_M += -Minv * v_Minv * Minv; +} + +template +inline __device__ T add_blur(const T eps2d, mat2 &covar, T &compensation) { + T det_orig = covar[0][0] * covar[1][1] - covar[0][1] * covar[1][0]; + covar[0][0] += eps2d; + covar[1][1] += eps2d; + T det_blur = covar[0][0] * covar[1][1] - covar[0][1] * covar[1][0]; + compensation = sqrt(max(0.f, det_orig / det_blur)); + return det_blur; +} + +template +inline __device__ void add_blur_vjp( + const T eps2d, + const mat2 conic_blur, + const T compensation, + const T v_compensation, + mat2 &v_covar +) { + // comp = sqrt(det(covar) / det(covar_blur)) + + // d [det(M)] / d M = adj(M) + // d [det(M + aI)] / d M = adj(M + aI) = adj(M) + a * I + // d [det(M) / det(M + aI)] / d M + // = (det(M + aI) * adj(M) - det(M) * adj(M + aI)) / (det(M + aI))^2 + // = adj(M) / det(M + aI) - adj(M + aI) / det(M + aI) * comp^2 + // = (adj(M) - adj(M + aI) * comp^2) / det(M + aI) + // given that adj(M + aI) = adj(M) + a * I + // = (adj(M + aI) - aI - adj(M + aI) * comp^2) / det(M + aI) + // given that adj(M) / det(M) = inv(M) + // = (1 - comp^2) * inv(M + aI) - aI / det(M + aI) + // given det(inv(M)) = 1 / det(M) + // = 
(1 - comp^2) * inv(M + aI) - aI * det(inv(M + aI)) + // = (1 - comp^2) * conic_blur - aI * det(conic_blur) + + T det_conic_blur = conic_blur[0][0] * conic_blur[1][1] - + conic_blur[0][1] * conic_blur[1][0]; + T v_sqr_comp = v_compensation * 0.5 / (compensation + 1e-6); + T one_minus_sqr_comp = 1 - compensation * compensation; + v_covar[0][0] += v_sqr_comp * (one_minus_sqr_comp * conic_blur[0][0] - + eps2d * det_conic_blur); + v_covar[0][1] += v_sqr_comp * (one_minus_sqr_comp * conic_blur[0][1]); + v_covar[1][0] += v_sqr_comp * (one_minus_sqr_comp * conic_blur[1][0]); + v_covar[1][1] += v_sqr_comp * (one_minus_sqr_comp * conic_blur[1][1] - + eps2d * det_conic_blur); +} + +} // namespace gsplat + +#endif // GSPLAT_CUDA_UTILS_CUH diff --git a/setup.py b/setup.py index ff0dd8e00..dbb14b45a 100644 --- a/setup.py +++ b/setup.py @@ -34,11 +34,11 @@ def get_extensions(): from torch.__config__ import parallel_info from torch.utils.cpp_extension import CUDAExtension - extensions_dir_v2 = osp.join("gsplat", "cuda", "csrc") - sources_v2 = glob.glob(osp.join(extensions_dir_v2, "*.cu")) + glob.glob( - osp.join(extensions_dir_v2, "*.cpp") + extensions_dir = osp.join("gsplat", "cuda") + sources = glob.glob(osp.join(extensions_dir, "csrc", "*.cu")) + glob.glob( + osp.join(extensions_dir, "csrc", "*.cpp") ) - sources_v2 = [path for path in sources_v2 if "hip" not in path] + sources = [path for path in sources if "hip" not in path] undef_macros = [] define_macros = [] @@ -90,18 +90,19 @@ def get_extensions(): extra_compile_args["nvcc"] += ["-DWIN32_LEAN_AND_MEAN"] current_dir = pathlib.Path(__file__).parent.resolve() - glm_path = os.path.join(current_dir, "gsplat", "cuda", "csrc", "third_party", "glm") - extension_v2 = CUDAExtension( + glm_path = osp.join(current_dir, "gsplat", "cuda", "csrc", "third_party", "glm") + include_dirs = [glm_path, osp.join(current_dir, "gsplat", "cuda", "include")] + + extension = CUDAExtension( "gsplat.csrc", - sources_v2, - 
include_dirs=[extensions_dir_v2, glm_path], # glm lives in v2. + sources, + include_dirs=include_dirs, define_macros=define_macros, undef_macros=undef_macros, extra_compile_args=extra_compile_args, extra_link_args=extra_link_args, ) - - return [extension_v2] + return [extension] setup( diff --git a/tests/test_basic.py b/tests/test_basic.py index 22d2ee227..30dc27759 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -11,6 +11,7 @@ import pytest import torch +import os from gsplat._helper import load_test_data @@ -30,7 +31,10 @@ def test_data(): Ks, width, height, - ) = load_test_data(device=device) + ) = load_test_data( + device=device, + data_path=os.path.join(os.path.dirname(__file__), "../assets/test_garden.npz"), + ) colors = colors[None].repeat(len(viewmats), 1, 1) return { "means": means, @@ -270,7 +274,7 @@ def test_projection( ) torch.testing.assert_close(v_viewmats, _v_viewmats, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(v_quats, _v_quats, rtol=2e-1, atol=1e-2) + torch.testing.assert_close(v_quats, _v_quats, rtol=2e-1, atol=2e-2) torch.testing.assert_close(v_scales, _v_scales, rtol=1e-1, atol=2e-1) torch.testing.assert_close(v_means, _v_means, rtol=1e-2, atol=6e-2)