Speed up ND rasterization (#130)
* refactor block size to be dynamic, hide tile_bounds from interface, update tests to reflect this change

* allocate as much shared memory as possible based on block size

* revert to old method, replace workspace with shared mem

* speedup on nd backward w/ some warp operations

* add fp16 headers

* 128->3 in simple trainer

* lint

* remove comment

* remove cudaPeekAtLastError
kerrj authored Feb 20, 2024
1 parent 10bc1d0 commit 46bd547
Showing 5 changed files with 196 additions and 144 deletions.
15 changes: 11 additions & 4 deletions examples/simple_trainer.py
@@ -38,7 +38,8 @@ def _init_gaussians(self):
 
         self.means = bd * (torch.rand(self.num_points, 3, device=self.device) - 0.5)
         self.scales = torch.rand(self.num_points, 3, device=self.device)
-        self.rgbs = torch.rand(self.num_points, 3, device=self.device)
+        d = 3
+        self.rgbs = torch.rand(self.num_points, d, device=self.device)
 
         u = torch.rand(self.num_points, 1, device=self.device)
         v = torch.rand(self.num_points, 1, device=self.device)
@@ -64,7 +65,7 @@ def _init_gaussians(self):
             ],
             device=self.device,
         )
-        self.background = torch.zeros(3, device=self.device)
+        self.background = torch.zeros(d, device=self.device)
 
         self.means.requires_grad = True
         self.scales.requires_grad = True
@@ -73,7 +74,13 @@ def _init_gaussians(self):
         self.opacities.requires_grad = True
         self.viewmat.requires_grad = False
 
-    def train(self, iterations: int = 1000, lr: float = 0.01, save_imgs: bool = False):
+    def train(
+        self,
+        iterations: int = 1000,
+        lr: float = 0.01,
+        save_imgs: bool = False,
+        B_SIZE: int = 14,
+    ):
         optimizer = optim.Adam(
             [self.rgbs, self.means, self.scales, self.opacities, self.quats], lr
         )
@@ -121,7 +128,7 @@ def train(self, iterations: int = 1000, lr: float = 0.01, save_imgs: bool = False):
                     self.W,
                     B_SIZE,
                     self.background,
-                )
+                )[..., :3]
                 torch.cuda.synchronize()
                 times[1] += time.time() - start
                 loss = mse_loss(out_img, self.gt_image)
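A note on the trainer changes above: the tile size now reaches train() as an explicit B_SIZE argument and is forwarded to the rasterizer (each CUDA block appears to cover a B_SIZE × B_SIZE pixel tile, so the default of 14 means 196 threads per block), and the rendered output is sliced to its first three channels before the MSE loss so the loop keeps working if the color dimension d is raised above 3.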
157 changes: 80 additions & 77 deletions gsplat/cuda/csrc/backward.cu
@@ -1,9 +1,25 @@
 #include "backward.cuh"
 #include "helpers.cuh"
+#include <cuda_fp16.h>
 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>
 namespace cg = cooperative_groups;
 
+inline __device__ void warpSum3(float3& val, cg::thread_block_tile<32>& tile){
+    val.x = cg::reduce(tile, val.x, cg::plus<float>());
+    val.y = cg::reduce(tile, val.y, cg::plus<float>());
+    val.z = cg::reduce(tile, val.z, cg::plus<float>());
+}
+
+inline __device__ void warpSum2(float2& val, cg::thread_block_tile<32>& tile){
+    val.x = cg::reduce(tile, val.x, cg::plus<float>());
+    val.y = cg::reduce(tile, val.y, cg::plus<float>());
+}
+
+inline __device__ void warpSum(float& val, cg::thread_block_tile<32>& tile){
+    val = cg::reduce(tile, val, cg::plus<float>());
+}
+
 __global__ void nd_rasterize_backward_kernel(
     const dim3 tile_bounds,
     const dim3 img_size,
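The warpSum helpers moved here wrap cooperative-groups reductions so the backward kernel can combine gradients across a warp before touching global memory. A minimal standalone sketch of the pattern they enable (the demo kernel is hypothetical, not part of this commit): each lane holds a partial value, cg::reduce sums the 32 partials, and only lane 0 issues an atomicAdd, turning 32 atomics into one.

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;

__global__ void warp_sum_demo(const float* partials, float* out) {
    auto block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    float v = partials[blockIdx.x * blockDim.x + threadIdx.x];
    // warp-collective: all 32 lanes must reach this call together
    v = cg::reduce(warp, v, cg::plus<float>());
    if (warp.thread_rank() == 0) {
        atomicAdd(out, v);  // one atomic per warp instead of one per thread
    }
}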
@@ -22,49 +38,40 @@ __global__ void nd_rasterize_backward_kernel(
     float2* __restrict__ v_xy,
     float3* __restrict__ v_conic,
     float* __restrict__ v_rgb,
-    float* __restrict__ v_opacity,
-    float* __restrict__ workspace
+    float* __restrict__ v_opacity
 ) {
-    if (channels > MAX_REGISTER_CHANNELS && workspace == nullptr) {
-        return;
-    }
     // current naive implementation where tile data loading is redundant
     // TODO tile data should be shared between tile threads
     auto block = cg::this_thread_block();
+    const int tr = block.thread_rank();
     int32_t tile_id = blockIdx.y * tile_bounds.x + blockIdx.x;
     unsigned i = blockIdx.y * blockDim.y + threadIdx.y;
     unsigned j = blockIdx.x * blockDim.x + threadIdx.x;
     float px = (float)j;
     float py = (float)i;
-    int32_t pix_id = i * img_size.x + j;
-
-    // return if out of bounds
-    if (i >= img_size.y || j >= img_size.x) {
-        return;
-    }
+    const int32_t pix_id = min(i * img_size.x + j, img_size.x * img_size.y - 1);
 
+    // keep non-rasterizing threads around for reading data
+    const bool inside = (i < img_size.y && j < img_size.x);
     // which gaussians get gradients for this pixel
-    int2 range = tile_bins[tile_id];
+    const int2 range = tile_bins[tile_id];
     // df/d_out for this pixel
     const float *v_out = &(v_output[channels * pix_id]);
     const float v_out_alpha = v_output_alpha[pix_id];
     // this is the T AFTER the last gaussian in this pixel
     float T_final = final_Ts[pix_id];
     float T = T_final;
     // the contribution from gaussians behind the current one
-    float buffer[MAX_REGISTER_CHANNELS] = {0.f};
-    float *S;
-    if (channels <= MAX_REGISTER_CHANNELS) {
-        S = &buffer[0];
-    } else {
-        S = &workspace[channels * pix_id];
-    }
-    int bin_final = final_index[pix_id];
+    extern __shared__ half workspace[];
 
     // iterate backward to compute the jacobians wrt rgb, opacity, mean2d, and
     // conic recursively compute T_{n-1} from T_n, where T_i = prod(j < i) (1 -
     // alpha_j), and S_{n-1} from S_n, where S_j = sum_{i > j}(rgb_i * alpha_i *
     // T_i) df/dalpha_i = rgb_i * T_i - S_{i+1} / (1 - alpha_i)
-    for (int idx = bin_final - 1; idx >= range.x; --idx) {
+    half *S = (half*)(&workspace[channels * tr]);
+    for(int c=0; c<channels; ++c){
+        S[c] = __float2half(0.f);
+    }
+    const int bin_final = inside ? final_index[pix_id] : 0;
+    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
+    const int warp_bin_final = cg::reduce(warp, bin_final, cg::greater<int>());
+    for (int idx = warp_bin_final - 1; idx >= range.x; --idx) {
+        int valid = inside && idx < bin_final;
         const int32_t g = gaussians_ids_sorted[idx];
         const float3 conic = conics[g];
         const float2 center = xys[g];
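The global float workspace is gone: the per-pixel running sum S now lives in dynamically sized shared memory as fp16, one channels-long slice per thread. A small sketch of just that layout (illustrative kernel, not the commit's code), for a block launched with blockDim.x * blockDim.y * channels * sizeof(half) bytes of dynamic shared memory:

#include <cuda_fp16.h>

__global__ void workspace_layout_demo(int channels) {
    // sized at launch: blockDim.x * blockDim.y * channels * sizeof(half)
    extern __shared__ half workspace[];
    const int tr = threadIdx.y * blockDim.x + threadIdx.x;  // thread rank
    half* S = &workspace[channels * tr];  // this thread's private slice
    for (int c = 0; c < channels; ++c) {
        S[c] = __float2half(0.f);  // zero the running sum, as in the kernel
    }
}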
@@ -73,68 +80,64 @@ __global__ void nd_rasterize_backward_kernel(
             0.5f * (conic.x * delta.x * delta.x + conic.z * delta.y * delta.y) +
             conic.y * delta.x * delta.y;
         if (sigma < 0.f) {
-            continue;
+            valid = 0;
         }
         const float opac = opacities[g];
         const float vis = __expf(-sigma);
         const float alpha = min(0.99f, opac * vis);
+        if (alpha < 1.f / 255.f) {
+            valid = 0;
+        }
+        if(!warp.any(valid)){
+            continue;
+        }
 
-        // compute the current T for this gaussian
-        const float ra = 1.f / (1.f - alpha);
-        T *= ra;
-        // rgb = rgbs[g];
-        // update v_rgb for this gaussian
-        const float fac = alpha * T;
         float v_alpha = 0.f;
-        for (int c = 0; c < channels; ++c) {
-            // gradient wrt rgb
-            atomicAdd(&(v_rgb[channels * g + c]), fac * v_out[c]);
-            // contribution from this pixel
-            v_alpha += (rgbs[channels * g + c] * T - S[c] * ra) * v_out[c];
-            // contribution from background pixel
-            v_alpha += -T_final * ra * background[c] * v_out[c];
-            // update the running sum
-            S[c] += rgbs[channels * g + c] * fac;
+        float3 v_conic_local = {0.f, 0.f, 0.f};
+        float2 v_xy_local = {0.f, 0.f};
+        float v_opacity_local = 0.f;
+        if(valid){
+            // compute the current T for this gaussian
+            const float ra = 1.f / (1.f - alpha);
+            T *= ra;
+            // update v_rgb for this gaussian
+            const float fac = alpha * T;
+            for (int c = 0; c < channels; ++c) {
+                // gradient wrt rgb
+                atomicAdd(&(v_rgb[channels * g + c]), fac * v_out[c]);
+                // contribution from this pixel
+                v_alpha += (rgbs[channels * g + c] * T - __half2float(S[c]) * ra) * v_out[c];
+                // contribution from background pixel
+                v_alpha += -T_final * ra * background[c] * v_out[c];
+                // update the running sum
+                S[c] = __hadd(S[c], __float2half(rgbs[channels * g + c] * fac));
+            }
+            v_alpha += T_final * ra * v_out_alpha;
+            const float v_sigma = -opac * vis * v_alpha;
+            v_conic_local = {0.5f * v_sigma * delta.x * delta.x,
+                             0.5f * v_sigma * delta.x * delta.y,
+                             0.5f * v_sigma * delta.y * delta.y};
+            v_xy_local = {v_sigma * (conic.x * delta.x + conic.y * delta.y),
+                          v_sigma * (conic.y * delta.x + conic.z * delta.y)};
+            v_opacity_local = vis * v_alpha;
         }
+        warpSum3(v_conic_local, warp);
+        warpSum2(v_xy_local, warp);
+        warpSum(v_opacity_local, warp);
+        if (warp.thread_rank() == 0) {
+            float* v_conic_ptr = (float*)(v_conic);
+            atomicAdd(v_conic_ptr + 3*g + 0, v_conic_local.x);
+            atomicAdd(v_conic_ptr + 3*g + 1, v_conic_local.y);
+            atomicAdd(v_conic_ptr + 3*g + 2, v_conic_local.z);
+
+            float* v_xy_ptr = (float*)(v_xy);
+            atomicAdd(v_xy_ptr + 2*g + 0, v_xy_local.x);
+            atomicAdd(v_xy_ptr + 2*g + 1, v_xy_local.y);
+
+            atomicAdd(v_opacity + g, v_opacity_local);
+        }
-        v_alpha += T_final * ra * v_out_alpha;
-        // update v_opacity for this gaussian
-        atomicAdd(&(v_opacity[g]), vis * v_alpha);
-
-        // compute vjps for conics and means
-        // d_sigma / d_delta = conic * delta
-        // d_sigma / d_conic = delta * delta.T
-        const float v_sigma = -opac * vis * v_alpha;
-
-        atomicAdd(&(v_conic[g].x), 0.5f * v_sigma * delta.x * delta.x);
-        atomicAdd(&(v_conic[g].y), 0.5f * v_sigma * delta.x * delta.y);
-        atomicAdd(&(v_conic[g].z), 0.5f * v_sigma * delta.y * delta.y);
-        atomicAdd(
-            &(v_xy[g].x), v_sigma * (conic.x * delta.x + conic.y * delta.y)
-        );
-        atomicAdd(
-            &(v_xy[g].y), v_sigma * (conic.y * delta.x + conic.z * delta.y)
-        );
     }
 }
 
-inline __device__ void warpSum3(float3& val, cg::thread_block_tile<32>& tile){
-    val.x = cg::reduce(tile, val.x, cg::plus<float>());
-    val.y = cg::reduce(tile, val.y, cg::plus<float>());
-    val.z = cg::reduce(tile, val.z, cg::plus<float>());
-}
-
-inline __device__ void warpSum2(float2& val, cg::thread_block_tile<32>& tile){
-    val.x = cg::reduce(tile, val.x, cg::plus<float>());
-    val.y = cg::reduce(tile, val.y, cg::plus<float>());
-}
-
-inline __device__ void warpSum(float& val, cg::thread_block_tile<32>& tile){
-    val = cg::reduce(tile, val, cg::plus<float>());
-}
-
 __global__ void rasterize_backward_kernel(
     const dim3 tile_bounds,
     const dim3 img_size,
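Two details of the rewritten loop are easy to miss. Out-of-bounds pixels no longer return early: inside is false for them, pix_id is clamped, and valid stays 0, so every lane still reaches the warp-collective reductions, which must be executed by all 32 lanes. And each lane iterates from warp_bin_final, the warp-wide maximum end index computed with cg::greater<int>, rather than from its own bin_final. A condensed, self-contained sketch of that control flow (buffer names and the bounds test are hypothetical, not the commit's kernel):

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;

__global__ void convergent_loop_demo(const int* per_pixel_end, int range_start,
                                     const float* contribs, float* grad_out) {
    auto block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    const bool inside = threadIdx.x < 31;            // stand-in bounds test
    const int bin_final = inside ? per_pixel_end[threadIdx.x] : 0;
    // every lane loops to the warp-wide maximum, not just its own end index
    const int warp_bin_final = cg::reduce(warp, bin_final, cg::greater<int>());
    for (int idx = warp_bin_final - 1; idx >= range_start; --idx) {
        int valid = inside && idx < bin_final;       // does this lane contribute?
        if (!warp.any(valid)) continue;              // whole warp idle: skip
        float v = valid ? contribs[idx] : 0.f;       // idle lanes add zero
        v = cg::reduce(warp, v, cg::plus<float>());  // collective: all lanes
        if (warp.thread_rank() == 0) atomicAdd(&grad_out[idx], v);
    }
}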
3 changes: 1 addition & 2 deletions gsplat/cuda/csrc/backward.cuh
@@ -47,8 +47,7 @@ __global__ void nd_rasterize_backward_kernel(
     float2* __restrict__ v_xy,
     float3* __restrict__ v_conic,
     float* __restrict__ v_rgb,
-    float* __restrict__ v_opacity,
-    float* __restrict__ workspace
+    float* __restrict__ v_opacity
 );
 
 __global__ void rasterize_backward_kernel(
26 changes: 13 additions & 13 deletions gsplat/cuda/csrc/bindings.cu
@@ -454,8 +454,13 @@ nd_rasterize_forward_tensor(
     torch::Tensor final_idx = torch::zeros(
         {img_height, img_width}, xys.options().dtype(torch::kInt32)
     );
+    const int B = block_dim3.x * block_dim3.y;
+    const uint32_t shared_mem = B*sizeof(int) + B*sizeof(float3) + B*sizeof(float3) + B*channels*sizeof(half);
+    if(cudaFuncSetAttribute(nd_rasterize_forward, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem) != cudaSuccess){
+        AT_ERROR("Failed to set maximum shared memory size (requested ", shared_mem, " bytes), try lowering block_size");
+    }
 
-    nd_rasterize_forward<<<tile_bounds_dim3, block_dim3>>>(
+    nd_rasterize_forward<<<tile_bounds_dim3, block_dim3, shared_mem>>>(
         tile_bounds_dim3,
         img_size_dim3,
         channels,
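By default a kernel launch may use at most 48 KB of dynamic shared memory; asking for more requires a per-kernel opt-in, which is what the new cudaFuncSetAttribute call does before shared_mem is passed as the third launch-configuration argument. A minimal host-side sketch of the same pattern (the demo kernel and launch helper are hypothetical):

#include <cuda_fp16.h>
#include <cuda_runtime.h>

__global__ void nd_kernel_demo(int channels) {
    extern __shared__ half workspace[];  // sized at launch time
    (void)channels;
}

void launch(dim3 grid, dim3 block, int channels) {
    const int B = block.x * block.y;
    const int shared_mem = B * channels * sizeof(half);
    // raise this kernel's dynamic shared memory cap (default is 48 KB)
    if (cudaFuncSetAttribute(nd_kernel_demo,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             shared_mem) != cudaSuccess) {
        return;  // request exceeds the device limit: lower the block size
    }
    nd_kernel_demo<<<grid, block, shared_mem>>>(channels);  // 3rd config arg
}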
@@ -527,17 +532,13 @@ std::
         torch::zeros({num_points, channels}, xys.options());
     torch::Tensor v_opacity = torch::zeros({num_points, 1}, xys.options());
 
-    torch::Tensor workspace;
-    if (channels > 3) {
-        workspace = torch::zeros(
-            {img_height, img_width, channels},
-            xys.options().dtype(torch::kFloat32)
-        );
-    } else {
-        workspace = torch::zeros({0}, xys.options().dtype(torch::kFloat32));
-    }
+    const int B = block.x * block.y;
+    // shared mem gives each thread a local workspace for the running sum
+    const uint32_t shared_mem = B*channels*sizeof(half);
+    if(cudaFuncSetAttribute(nd_rasterize_backward_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem) != cudaSuccess){
+        AT_ERROR("Failed to set maximum shared memory size (requested ", shared_mem, " bytes), try lowering block_size");
+    }
 
-    nd_rasterize_backward_kernel<<<tile_bounds, block>>>(
+    nd_rasterize_backward_kernel<<<tile_bounds, block, shared_mem>>>(
         tile_bounds,
         img_size,
         channels,
@@ -555,8 +556,7 @@ std::
         (float2 *)v_xy.contiguous().data_ptr<float>(),
         (float3 *)v_conic.contiguous().data_ptr<float>(),
         v_colors.contiguous().data_ptr<float>(),
-        v_opacity.contiguous().data_ptr<float>(),
-        workspace.data_ptr<float>()
+        v_opacity.contiguous().data_ptr<float>()
     );
 
     return std::make_tuple(v_xy, v_conic, v_colors, v_opacity);
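Net effect of the bindings change: the backward pass no longer allocates an img_height × img_width × channels float workspace tensor in global memory for channel counts above 3. The running color sum instead lives in per-thread fp16 shared memory sized at launch, and the corresponding workspace argument disappears from the kernel signature in backward.cu and backward.cuh.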