Skip to content

Commit

Permalink
seperate header files
Browse files Browse the repository at this point in the history
  • Loading branch information
JunchenLiu77 committed Nov 7, 2024
1 parent 5293a93 commit 4060f09
Show file tree
Hide file tree
Showing 32 changed files with 732 additions and 686 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
recursive-include gsplat/cuda/csrc *
recursive-include gsplat/cuda/include *
2 changes: 0 additions & 2 deletions gsplat/cuda/csrc/adam.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
#include "bindings.h"
#include "helpers.cuh"
#include "utils.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
Expand Down
1 change: 1 addition & 0 deletions gsplat/cuda/csrc/compute_sh_bwd.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "bindings.h"
#include "helpers.cuh"
#include "spherical_harmonics.cuh"
#include "types.cuh"

Expand Down
3 changes: 2 additions & 1 deletion gsplat/cuda/csrc/fully_fused_projection_2dgs_bwd.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "bindings.h"
#include "helpers.cuh"
#include "utils.cuh"
#include "transform.cuh"
#include "2dgs.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
Expand Down
3 changes: 2 additions & 1 deletion gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "bindings.h"
#include "helpers.cuh"
#include "utils.cuh"
#include "transform.cuh"
#include "2dgs.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
Expand Down
4 changes: 4 additions & 0 deletions gsplat/cuda/csrc/fully_fused_projection_bwd.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#include "bindings.h"
#include "helpers.cuh"
#include "utils.cuh"
#include "quat.cuh"
#include "quat_scale_to_covar_preci.cuh"
#include "proj.cuh"
#include "transform.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
Expand Down
4 changes: 3 additions & 1 deletion gsplat/cuda/csrc/fully_fused_projection_fwd.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "bindings.h"
#include "helpers.cuh"
#include "utils.cuh"
#include "quat_scale_to_covar_preci.cuh"
#include "proj.cuh"
#include "transform.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
Expand Down
3 changes: 2 additions & 1 deletion gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_bwd.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "bindings.h"
#include "helpers.cuh"
#include "utils.cuh"
#include "transform.cuh"
#include "2dgs.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
Expand Down
3 changes: 2 additions & 1 deletion gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_fwd.cu
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@

#include "bindings.h"
#include "helpers.cuh"
#include "utils.cuh"
#include "quat.cuh"
#include "transform.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
Expand Down
4 changes: 4 additions & 0 deletions gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#include "bindings.h"
#include "helpers.cuh"
#include "utils.cuh"
#include "quat.cuh"
#include "quat_scale_to_covar_preci.cuh"
#include "proj.cuh"
#include "transform.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
Expand Down
4 changes: 3 additions & 1 deletion gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "bindings.h"
#include "helpers.cuh"
#include "utils.cuh"
#include "quat_scale_to_covar_preci.cuh"
#include "proj.cuh"
#include "transform.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
Expand Down
1 change: 0 additions & 1 deletion gsplat/cuda/csrc/isect_tiles.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "bindings.h"
#include "helpers.cuh"
#include "types.cuh"
#include <cooperative_groups.h>
#include <cub/cub.cuh>
Expand Down
3 changes: 1 addition & 2 deletions gsplat/cuda/csrc/proj_bwd.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "bindings.h"
#include "helpers.cuh"
#include "utils.cuh"
#include "proj.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
Expand Down
3 changes: 1 addition & 2 deletions gsplat/cuda/csrc/proj_fwd.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "bindings.h"
#include "helpers.cuh"
#include "utils.cuh"
#include "proj.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
Expand Down
4 changes: 2 additions & 2 deletions gsplat/cuda/csrc/quat_scale_to_covar_preci_bwd.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "bindings.h"
#include "helpers.cuh"
#include "utils.cuh"
#include "quat.cuh"
#include "quat_scale_to_covar_preci.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
Expand Down
3 changes: 1 addition & 2 deletions gsplat/cuda/csrc/quat_scale_to_covar_preci_fwd.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "bindings.h"
#include "helpers.cuh"
#include "utils.cuh"
#include "quat_scale_to_covar_preci.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
Expand Down
1 change: 0 additions & 1 deletion gsplat/cuda/csrc/rasterize_to_indices_in_range.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "bindings.h"
#include "helpers.cuh"
#include "types.cuh"
#include <cooperative_groups.h>
#include <cub/cub.cuh>
Expand Down
3 changes: 1 addition & 2 deletions gsplat/cuda/csrc/rasterize_to_indices_in_range_2dgs.cu
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "bindings.h"
#include "helpers.cuh"
#include "types.cuh"
#include "utils.cuh"
#include "2dgs.cuh"
#include <cooperative_groups.h>
#include <cub/cub.cuh>
#include <cuda_runtime.h>
Expand Down
2 changes: 1 addition & 1 deletion gsplat/cuda/csrc/rasterize_to_pixels_2dgs_bwd.cu
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "bindings.h"
#include "helpers.cuh"
#include "types.cuh"
#include "utils.cuh"
#include "2dgs.cuh"
#include <cooperative_groups.h>
#include <cub/cub.cuh>
#include <cuda_runtime.h>
Expand Down
3 changes: 1 addition & 2 deletions gsplat/cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "bindings.h"
#include "helpers.cuh"
#include "types.cuh"
#include "utils.cuh"
#include "2dgs.cuh"
#include <cooperative_groups.h>
#include <cub/cub.cuh>
#include <cuda_runtime.h>
Expand Down
1 change: 0 additions & 1 deletion gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "bindings.h"
#include "helpers.cuh"
#include "types.cuh"
#include <cooperative_groups.h>
#include <cub/cub.cuh>
Expand Down
2 changes: 1 addition & 1 deletion gsplat/cuda/csrc/world_to_cam_bwd.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "bindings.h"
#include "helpers.cuh"
#include "utils.cuh"
#include "transform.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
Expand Down
3 changes: 1 addition & 2 deletions gsplat/cuda/csrc/world_to_cam_fwd.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "bindings.h"
#include "helpers.cuh"
#include "utils.cuh"
#include "transform.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
Expand Down
78 changes: 78 additions & 0 deletions gsplat/cuda/include/2dgs.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#ifndef GSPLAT_CUDA_2DGS_CUH
#define GSPLAT_CUDA_2DGS_CUH

#include "types.cuh"
#include "quat.cuh"

#define FILTER_INV_SQUARE 2.0f

namespace gsplat {

template <typename T>
inline __device__ void compute_ray_transforms_aabb_vjp(
const T *ray_transforms,
const T *v_means2d,
const vec3<T> v_normals,
const mat3<T> W,
const mat3<T> P,
const vec3<T> cam_pos,
const vec3<T> mean_c,
const vec4<T> quat,
const vec2<T> scale,
mat3<T> &_v_ray_transforms,
vec4<T> &v_quat,
vec2<T> &v_scale,
vec3<T> &v_mean
) {
if (v_means2d[0] != 0 || v_means2d[1] != 0) {
const T distance = ray_transforms[6] * ray_transforms[6] + ray_transforms[7] * ray_transforms[7] -
ray_transforms[8] * ray_transforms[8];
const T f = 1 / (distance);
const T dpx_dT00 = f * ray_transforms[6];
const T dpx_dT01 = f * ray_transforms[7];
const T dpx_dT02 = -f * ray_transforms[8];
const T dpy_dT10 = f * ray_transforms[6];
const T dpy_dT11 = f * ray_transforms[7];
const T dpy_dT12 = -f * ray_transforms[8];
const T dpx_dT30 = ray_transforms[0] * (f - 2 * f * f * ray_transforms[6] * ray_transforms[6]);
const T dpx_dT31 = ray_transforms[1] * (f - 2 * f * f * ray_transforms[7] * ray_transforms[7]);
const T dpx_dT32 = -ray_transforms[2] * (f + 2 * f * f * ray_transforms[8] * ray_transforms[8]);
const T dpy_dT30 = ray_transforms[3] * (f - 2 * f * f * ray_transforms[6] * ray_transforms[6]);
const T dpy_dT31 = ray_transforms[4] * (f - 2 * f * f * ray_transforms[7] * ray_transforms[7]);
const T dpy_dT32 = -ray_transforms[5] * (f + 2 * f * f * ray_transforms[8] * ray_transforms[8]);

_v_ray_transforms[0][0] += v_means2d[0] * dpx_dT00;
_v_ray_transforms[0][1] += v_means2d[0] * dpx_dT01;
_v_ray_transforms[0][2] += v_means2d[0] * dpx_dT02;
_v_ray_transforms[1][0] += v_means2d[1] * dpy_dT10;
_v_ray_transforms[1][1] += v_means2d[1] * dpy_dT11;
_v_ray_transforms[1][2] += v_means2d[1] * dpy_dT12;
_v_ray_transforms[2][0] += v_means2d[0] * dpx_dT30 + v_means2d[1] * dpy_dT30;
_v_ray_transforms[2][1] += v_means2d[0] * dpx_dT31 + v_means2d[1] * dpy_dT31;
_v_ray_transforms[2][2] += v_means2d[0] * dpx_dT32 + v_means2d[1] * dpy_dT32;
}

mat3<T> R = quat_to_rotmat(quat);
mat3<T> v_M = P * glm::transpose(_v_ray_transforms);
mat3<T> W_t = glm::transpose(W);
mat3<T> v_RS = W_t * v_M;
vec3<T> v_tn = W_t * v_normals;

// dual visible
vec3<T> tn = W * R[2];
T cos = glm::dot(-tn, mean_c);
T multiplier = cos > 0 ? 1 : -1;
v_tn *= multiplier;

mat3<T> v_R = mat3<T>(v_RS[0] * scale[0], v_RS[1] * scale[1], v_tn);

quat_to_rotmat_vjp<T>(quat, v_R, v_quat);
v_scale[0] += (T)glm::dot(v_RS[0], R[0]);
v_scale[1] += (T)glm::dot(v_RS[1], R[1]);

v_mean += v_RS[2];
}

} // namespace gsplat

#endif // GSPLAT_CUDA_2DGS_CUH
8 changes: 4 additions & 4 deletions gsplat/cuda/include/helpers.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef GSPLAT_CUDA_HELPERS_H
#define GSPLAT_CUDA_HELPERS_H
#ifndef GSPLAT_CUDA_HELPERS_CUH
#define GSPLAT_CUDA_HELPERS_CUH

#include "types.cuh"

Expand All @@ -15,7 +15,7 @@ namespace cg = cooperative_groups;

template <uint32_t DIM, class T, class WarpT>
inline __device__ void warpSum(T *val, WarpT &warp) {
GSPLAT_PRAGMA_UNROLL
#pragma unroll
for (uint32_t i = 0; i < DIM; i++) {
val[i] = cg::reduce(warp, val[i], cg::plus<T>());
}
Expand Down Expand Up @@ -79,4 +79,4 @@ template <typename T> __forceinline__ __device__ T sum(vec3<T> a) {

} // namespace gsplat

#endif // GSPLAT_CUDA_HELPERS_H
#endif // GSPLAT_CUDA_HELPERS_CUH
Loading

0 comments on commit 4060f09

Please sign in to comment.