From 46bd547d9ee29bf209f3fb1106775ae1caca5ca6 Mon Sep 17 00:00:00 2001
From: Justin Kerr
Date: Tue, 20 Feb 2024 13:40:20 -0800
Subject: [PATCH] Speed up ND rasterization (#130)

* refactor block size to be dynamic, hide tile_bounds from interface, update tests to reflect this change

* allocate as much shared memory as possible based on block size

* revert to old method, replace workspace with shared mem

* speedup on nd backward w/ some warp operations

* add fp16 headers

* 128->3 in simple trainer

* lint

* remove comment

* remove cudaPeekAtLastError
---
 examples/simple_trainer.py    |  15 +++-
 gsplat/cuda/csrc/backward.cu  | 157 +++++++++++++++++-----------------
 gsplat/cuda/csrc/backward.cuh |   3 +-
 gsplat/cuda/csrc/bindings.cu  |  26 +++---
 gsplat/cuda/csrc/forward.cu   | 139 +++++++++++++++++++-----------
 5 files changed, 196 insertions(+), 144 deletions(-)

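Note on the warp-level reductions in backward.cu below: each pixel's gradient contributions are first summed across the 32-thread warp (warpSum / warpSum2 / warpSum3), and only lane 0 issues an atomicAdd per Gaussian, so the number of global atomics drops from one per thread to one per warp. The following is a minimal, self-contained sketch of that pattern; the kernel and buffer names (warp_coalesced_accumulate, per_thread_vals, global_sum) are illustrative only and are not part of this patch.

    #include <cooperative_groups.h>
    #include <cooperative_groups/reduce.h>
    namespace cg = cooperative_groups;

    // Sum one float per thread across a 32-thread warp, then issue a single
    // atomicAdd per warp instead of one per thread.
    __global__ void warp_coalesced_accumulate(const float* per_thread_vals,
                                              float* global_sum, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        // out-of-range threads contribute 0.f but still take part in the reduction
        float v = (i < n) ? per_thread_vals[i] : 0.f;
        cg::thread_block block = cg::this_thread_block();
        cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
        float warp_total = cg::reduce(warp, v, cg::plus<float>());
        if (warp.thread_rank() == 0) {
            atomicAdd(global_sum, warp_total); // one atomic per 32 threads
        }
    }

In the patch itself this is exactly what warpSum3, warpSum2 and warpSum do for the v_conic, v_xy and v_opacity gradients.
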
diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py
index 9484e951b..0f2e3d5e0 100644
--- a/examples/simple_trainer.py
+++ b/examples/simple_trainer.py
@@ -38,7 +38,8 @@ def _init_gaussians(self):
         self.means = bd * (torch.rand(self.num_points, 3, device=self.device) - 0.5)
         self.scales = torch.rand(self.num_points, 3, device=self.device)
-        self.rgbs = torch.rand(self.num_points, 3, device=self.device)
+        d = 3
+        self.rgbs = torch.rand(self.num_points, d, device=self.device)

         u = torch.rand(self.num_points, 1, device=self.device)
         v = torch.rand(self.num_points, 1, device=self.device)
@@ -64,7 +65,7 @@ def _init_gaussians(self):
             ],
             device=self.device,
         )
-        self.background = torch.zeros(3, device=self.device)
+        self.background = torch.zeros(d, device=self.device)

         self.means.requires_grad = True
         self.scales.requires_grad = True
@@ -73,7 +74,13 @@ def _init_gaussians(self):
         self.opacities.requires_grad = True
         self.viewmat.requires_grad = False

-    def train(self, iterations: int = 1000, lr: float = 0.01, save_imgs: bool = False):
+    def train(
+        self,
+        iterations: int = 1000,
+        lr: float = 0.01,
+        save_imgs: bool = False,
+        B_SIZE: int = 14,
+    ):
         optimizer = optim.Adam(
             [self.rgbs, self.means, self.scales, self.opacities, self.quats], lr
         )
@@ -121,7 +128,7 @@ def train(self, iterations: int = 1000, lr: float = 0.01, save_imgs: bool = Fals
                 self.W,
                 B_SIZE,
                 self.background,
-            )
+            )[..., :3]
             torch.cuda.synchronize()
             times[1] += time.time() - start
             loss = mse_loss(out_img, self.gt_image)
diff --git a/gsplat/cuda/csrc/backward.cu b/gsplat/cuda/csrc/backward.cu
index 353972dff..1e0c3fa98 100644
--- a/gsplat/cuda/csrc/backward.cu
+++ b/gsplat/cuda/csrc/backward.cu
@@ -1,9 +1,25 @@
 #include "backward.cuh"
 #include "helpers.cuh"
+#include <cuda_fp16.h>
 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>
 namespace cg = cooperative_groups;

+inline __device__ void warpSum3(float3& val, cg::thread_block_tile<32>& tile){
+    val.x = cg::reduce(tile, val.x, cg::plus<float>());
+    val.y = cg::reduce(tile, val.y, cg::plus<float>());
+    val.z = cg::reduce(tile, val.z, cg::plus<float>());
+}
+
+inline __device__ void warpSum2(float2& val, cg::thread_block_tile<32>& tile){
+    val.x = cg::reduce(tile, val.x, cg::plus<float>());
+    val.y = cg::reduce(tile, val.y, cg::plus<float>());
+}
+
+inline __device__ void warpSum(float& val, cg::thread_block_tile<32>& tile){
+    val = cg::reduce(tile, val, cg::plus<float>());
+}
+
 __global__ void nd_rasterize_backward_kernel(
     const dim3 tile_bounds,
     const dim3 img_size,
@@ -22,28 +38,21 @@ __global__ void nd_rasterize_backward_kernel(
     float2* __restrict__ v_xy,
     float3* __restrict__ v_conic,
     float* __restrict__ v_rgb,
-    float* __restrict__ v_opacity,
-    float* __restrict__ workspace
+    float* __restrict__ v_opacity
 ) {
-    if (channels > MAX_REGISTER_CHANNELS && workspace == nullptr) {
-        return;
-    }
-    // current naive implementation where tile data loading is redundant
-    // TODO tile data should be shared between tile threads
+    auto block = cg::this_thread_block();
+    const int tr = block.thread_rank();
     int32_t tile_id = blockIdx.y * tile_bounds.x + blockIdx.x;
     unsigned i = blockIdx.y * blockDim.y + threadIdx.y;
     unsigned j = blockIdx.x * blockDim.x + threadIdx.x;
     float px = (float)j;
     float py = (float)i;
-    int32_t pix_id = i * img_size.x + j;
-
-    // return if out of bounds
-    if (i >= img_size.y || j >= img_size.x) {
-        return;
-    }
+    const int32_t pix_id = min(i * img_size.x + j, img_size.x * img_size.y - 1);
+    // keep not rasterizing threads around for reading data
+    const bool inside = (i < img_size.y && j < img_size.x);

     // which gaussians get gradients for this pixel
-    int2 range = tile_bins[tile_id];
+    const int2 range = tile_bins[tile_id];
     // df/d_out for this pixel
     const float *v_out = &(v_output[channels * pix_id]);
     const float v_out_alpha = v_output_alpha[pix_id];
@@ -51,20 +60,18 @@ __global__ void nd_rasterize_backward_kernel(
     float T_final = final_Ts[pix_id];
     float T = T_final;
     // the contribution from gaussians behind the current one
-    float buffer[MAX_REGISTER_CHANNELS] = {0.f};
-    float *S;
-    if (channels <= MAX_REGISTER_CHANNELS) {
-        S = &buffer[0];
-    } else {
-        S = &workspace[channels * pix_id];
-    }
-    int bin_final = final_index[pix_id];
+
+    extern __shared__ half workspace[];

-    // iterate backward to compute the jacobians wrt rgb, opacity, mean2d, and
-    // conic recursively compute T_{n-1} from T_n, where T_i = prod(j < i) (1 -
-    // alpha_j), and S_{n-1} from S_n, where S_j = sum_{i > j}(rgb_i * alpha_i *
-    // T_i) df/dalpha_i = rgb_i * T_i - S_{i+1| / (1 - alpha_i)
-    for (int idx = bin_final - 1; idx >= range.x; --idx) {
+    half *S = (half*)(&workspace[channels * tr]);
+    for(int c=0; c<channels; ++c){
+        S[c] = __float2half(0.f);
+    }
+    const int bin_final = inside ? final_index[pix_id] : 0;
+    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
+    const int warp_bin_final = cg::reduce(warp, bin_final, cg::greater<int>());
+    for (int idx = warp_bin_final - 1; idx >= range.x; --idx) {
+        int valid = inside && idx < bin_final;
         const int32_t g = gaussians_ids_sorted[idx];
         const float3 conic = conics[g];
         const float2 center = xys[g];
@@ -73,68 +80,64 @@ __global__ void nd_rasterize_backward_kernel(
             0.5f * (conic.x * delta.x * delta.x + conic.z * delta.y * delta.y) +
             conic.y * delta.x * delta.y;
         if (sigma < 0.f) {
-            continue;
+            valid = 0;
         }
         const float opac = opacities[g];
         const float vis = __expf(-sigma);
         const float alpha = min(0.99f, opac * vis);
         if (alpha < 1.f / 255.f) {
+            valid = 0;
+        }
+        if(!warp.any(valid)){
             continue;
         }
-
-        // compute the current T for this gaussian
-        const float ra = 1.f / (1.f - alpha);
-        T *= ra;
-        // rgb = rgbs[g];
-        // update v_rgb for this gaussian
-        const float fac = alpha * T;
         float v_alpha = 0.f;
-        for (int c = 0; c < channels; ++c) {
-            // gradient wrt rgb
-            atomicAdd(&(v_rgb[channels * g + c]), fac * v_out[c]);
-            // contribution from this pixel
-            v_alpha += (rgbs[channels * g + c] * T - S[c] * ra) * v_out[c];
-            // contribution from background pixel
-            v_alpha += -T_final * ra * background[c] * v_out[c];
-            // update the running sum
-            S[c] += rgbs[channels * g + c] * fac;
+        float3 v_conic_local = {0.f, 0.f, 0.f};
+        float2 v_xy_local = {0.f, 0.f};
+        float v_opacity_local = 0.f;
+        if(valid){
+            // compute the current T for this gaussian
+            const float ra = 1.f / (1.f - alpha);
+            T *= ra;
+            // update v_rgb for this gaussian
+            const float fac = alpha * T;
+            for (int c = 0; c < channels; ++c) {
+                // gradient wrt rgb
+                atomicAdd(&(v_rgb[channels * g + c]), fac * v_out[c]);
+                // contribution from this pixel
+                v_alpha += (rgbs[channels * g + c] * T - __half2float(S[c]) * ra) * v_out[c];
+                // contribution from background pixel
+                v_alpha += -T_final * ra * background[c] * v_out[c];
+                // update the running sum
+                S[c] = __hadd(S[c], __float2half(rgbs[channels * g + c] * fac));
+            }
+            v_alpha += T_final * ra * v_out_alpha;
+            const float v_sigma = -opac * vis * v_alpha;
+            v_conic_local = {0.5f * v_sigma * delta.x * delta.x,
+                             0.5f * v_sigma * delta.x * delta.y,
+                             0.5f * v_sigma * delta.y * delta.y};
+            v_xy_local = {v_sigma * (conic.x * delta.x + conic.y * delta.y),
+                          v_sigma * (conic.y * delta.x + conic.z * delta.y)};
+            v_opacity_local = vis * v_alpha;
+        }
+        warpSum3(v_conic_local, warp);
+        warpSum2(v_xy_local, warp);
+        warpSum(v_opacity_local, warp);
+        if (warp.thread_rank() == 0) {
+            float* v_conic_ptr = (float*)(v_conic);
+            atomicAdd(v_conic_ptr + 3*g + 0, v_conic_local.x);
+            atomicAdd(v_conic_ptr + 3*g + 1, v_conic_local.y);
+            atomicAdd(v_conic_ptr + 3*g + 2, v_conic_local.z);
+
+            float* v_xy_ptr = (float*)(v_xy);
+            atomicAdd(v_xy_ptr + 2*g + 0, v_xy_local.x);
+            atomicAdd(v_xy_ptr + 2*g + 1, v_xy_local.y);
+
+            atomicAdd(v_opacity + g, v_opacity_local);
         }
-        v_alpha += T_final * ra * v_out_alpha;
-        // update v_opacity for this gaussian
-        atomicAdd(&(v_opacity[g]), vis * v_alpha);
-
-        // compute vjps for conics and means
-        // d_sigma / d_delta = conic * delta
-        // d_sigma / d_conic = delta * delta.T
-        const float v_sigma = -opac * vis * v_alpha;
-
-        atomicAdd(&(v_conic[g].x), 0.5f * v_sigma * delta.x * delta.x);
-        atomicAdd(&(v_conic[g].y), 0.5f * v_sigma * delta.x * delta.y);
-        atomicAdd(&(v_conic[g].z), 0.5f * v_sigma * delta.y * delta.y);
-        atomicAdd(
-            &(v_xy[g].x), v_sigma * (conic.x * delta.x + conic.y * delta.y)
-        );
-        atomicAdd(
-            &(v_xy[g].y), v_sigma * (conic.y * delta.x + conic.z * delta.y)
-        );
     }
 }

-inline __device__ void warpSum3(float3& val, cg::thread_block_tile<32>& tile){
-    val.x = cg::reduce(tile, val.x, cg::plus<float>());
-    val.y = cg::reduce(tile, val.y, cg::plus<float>());
-    val.z = cg::reduce(tile, val.z, cg::plus<float>());
-}
-
-inline __device__ void warpSum2(float2& val, cg::thread_block_tile<32>& tile){
-    val.x = cg::reduce(tile, val.x, cg::plus<float>());
-    val.y = cg::reduce(tile, val.y, cg::plus<float>());
-}
-
-inline __device__ void warpSum(float& val, cg::thread_block_tile<32>& tile){
-    val = cg::reduce(tile, val, cg::plus<float>());
-}
-
 __global__ void rasterize_backward_kernel(
     const dim3 tile_bounds,
     const dim3 img_size,
diff --git a/gsplat/cuda/csrc/backward.cuh b/gsplat/cuda/csrc/backward.cuh
index 95c53f538..f45d2d3a7 100644
--- a/gsplat/cuda/csrc/backward.cuh
+++ b/gsplat/cuda/csrc/backward.cuh
@@ -47,8 +47,7 @@ __global__ void nd_rasterize_backward_kernel(
     float2* __restrict__ v_xy,
     float3* __restrict__ v_conic,
     float* __restrict__ v_rgb,
-    float* __restrict__ v_opacity,
-    float* __restrict__ workspace
+    float* __restrict__ v_opacity
 );

 __global__ void rasterize_backward_kernel(
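For reference, the ND backward kernel's per-pixel running colour sum S now lives as fp16 in dynamically sized shared memory, one channels-wide slice per thread, rather than in a global float workspace. A stripped-down sketch of that layout follows, assuming a launch with block_size * channels * sizeof(half) bytes of dynamic shared memory; the kernel and parameter names (half_workspace_sum, vals, out, n) are placeholders, not part of this patch.

    #include <cuda_fp16.h>
    #include <cooperative_groups.h>
    namespace cg = cooperative_groups;

    // Launch with dynamic shared memory = blockDim.x * blockDim.y * channels * sizeof(half).
    __global__ void half_workspace_sum(const float* vals, float* out, int channels, int n) {
        extern __shared__ half workspace[];
        cg::thread_block block = cg::this_thread_block();
        // each thread owns a contiguous channels-wide slice of the shared buffer
        half* S = &workspace[block.thread_rank() * channels];
        for (int c = 0; c < channels; ++c) {
            S[c] = __float2half(0.f);
        }
        int gid = blockIdx.x * blockDim.x + threadIdx.x;
        // accumulate fp32 values into the fp16 running sum
        if (gid < n) {
            for (int c = 0; c < channels; ++c) {
                S[c] = __hadd(S[c], __float2half(vals[gid * channels + c]));
            }
        }
        // out has one channels-wide row per thread; reading back promotes to fp32
        for (int c = 0; c < channels; ++c) {
            out[gid * channels + c] = __half2float(S[c]);
        }
    }

The real kernel does the same promotion with __half2float whenever the running sum feeds the alpha gradient.
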
diff --git a/gsplat/cuda/csrc/bindings.cu b/gsplat/cuda/csrc/bindings.cu
index b9e005345..16bea7f1a 100644
--- a/gsplat/cuda/csrc/bindings.cu
+++ b/gsplat/cuda/csrc/bindings.cu
@@ -454,8 +454,13 @@ nd_rasterize_forward_tensor(
     torch::Tensor final_idx = torch::zeros(
         {img_height, img_width}, xys.options().dtype(torch::kInt32)
     );
+    const int B = block_dim3.x * block_dim3.y;
+    const uint32_t shared_mem = B*sizeof(int) + B*sizeof(float3) + B*sizeof(float3) + B*channels*sizeof(half);
+    if(cudaFuncSetAttribute(nd_rasterize_forward, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem) != cudaSuccess){
+        AT_ERROR("Failed to set maximum shared memory size (requested ", shared_mem, " bytes), try lowering block_size");
+    }

-    nd_rasterize_forward<<<tile_bounds_dim3, block_dim3>>>(
+    nd_rasterize_forward<<<tile_bounds_dim3, block_dim3, shared_mem>>>(
         tile_bounds_dim3,
         img_size_dim3,
         channels,
@@ -527,17 +532,13 @@ std::
         torch::zeros({num_points, channels}, xys.options());
     torch::Tensor v_opacity = torch::zeros({num_points, 1}, xys.options());

-    torch::Tensor workspace;
-    if (channels > 3) {
-        workspace = torch::zeros(
-            {img_height, img_width, channels},
-            xys.options().dtype(torch::kFloat32)
-        );
-    } else {
-        workspace = torch::zeros({0}, xys.options().dtype(torch::kFloat32));
+    const int B = block.x * block.y;
+    //shared mem accounts for each thread having a local shared memory workspace for running sum
+    const uint32_t shared_mem = B*channels*sizeof(half);
+    if(cudaFuncSetAttribute(nd_rasterize_backward_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem) != cudaSuccess){
+        AT_ERROR("Failed to set maximum shared memory size (requested ", shared_mem, " bytes), try lowering block_size");
     }
-
-    nd_rasterize_backward_kernel<<<tile_bounds, block>>>(
+    nd_rasterize_backward_kernel<<<tile_bounds, block, shared_mem>>>(
         tile_bounds,
         img_size,
         channels,
@@ -555,8 +556,7 @@ std::
         (float2 *)v_xy.contiguous().data_ptr<float>(),
         (float3 *)v_conic.contiguous().data_ptr<float>(),
         v_colors.contiguous().data_ptr<float>(),
-        v_opacity.contiguous().data_ptr<float>(),
-        workspace.data_ptr<float>()
+        v_opacity.contiguous().data_ptr<float>()
     );

     return std::make_tuple(v_xy, v_conic, v_colors, v_opacity);
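Both ND kernels are now launched with a dynamic shared-memory request derived from the block size, and the bindings first raise cudaFuncAttributeMaxDynamicSharedMemorySize so that requests above the default 48 KB opt-in limit are honoured. Below is a minimal host-side sketch of that call sequence outside of the PyTorch bindings; the kernel here (nd_kernel_stub) and the launcher name are stand-ins, not gsplat symbols.

    #include <cuda_runtime.h>
    #include <cuda_fp16.h>
    #include <cstdint>
    #include <cstdio>

    // Stand-in kernel: the real gsplat kernels carve this buffer into id,
    // xy/opacity and conic batches plus a per-thread half workspace.
    __global__ void nd_kernel_stub(int channels) {
        extern __shared__ unsigned char smem[]; // dynamic shared memory, laid out by the host
        (void)smem;
        (void)channels;
    }

    void launch_with_dynamic_smem(dim3 grid, dim3 block, int channels) {
        const int B = block.x * block.y;
        // per-thread staging: one int id, two float3s, and `channels` halfs
        const uint32_t shared_mem =
            B * sizeof(int) + 2 * B * sizeof(float3) + B * channels * sizeof(half);
        // kernels only receive more than the 48 KB default after raising this attribute
        if (cudaFuncSetAttribute(nd_kernel_stub,
                                 cudaFuncAttributeMaxDynamicSharedMemorySize,
                                 shared_mem) != cudaSuccess) {
            std::fprintf(stderr, "requested %u bytes of shared memory; try a smaller block\n",
                         shared_mem);
            return;
        }
        nd_kernel_stub<<<grid, block, shared_mem>>>(channels);
    }

The bindings follow the same sequence, reporting the failure through AT_ERROR instead of stderr.
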
diff --git a/gsplat/cuda/csrc/forward.cu b/gsplat/cuda/csrc/forward.cu
index 9cb349847..ac13b9309 100644
--- a/gsplat/cuda/csrc/forward.cu
+++ b/gsplat/cuda/csrc/forward.cu
@@ -4,6 +4,7 @@
 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>
 #include <iostream>
+#include <cuda_fp16.h>

 namespace cg = cooperative_groups;
@@ -187,70 +188,112 @@ __global__ void nd_rasterize_forward(
     float* __restrict__ out_img,
     const float* __restrict__ background
 ) {
-    // current naive implementation where tile data loading is redundant
-    // TODO tile data should be shared between tile threads
-    int32_t tile_id = blockIdx.y * tile_bounds.x + blockIdx.x;
-    unsigned i = blockIdx.y * blockDim.y + threadIdx.y;
-    unsigned j = blockIdx.x * blockDim.x + threadIdx.x;
+    auto block = cg::this_thread_block();
+    int32_t tile_id =
+        block.group_index().y * tile_bounds.x + block.group_index().x;
+    unsigned i =
+        block.group_index().y * block.group_dim().y + block.thread_index().y;
+    unsigned j =
+        block.group_index().x * block.group_dim().x + block.thread_index().x;
+
     float px = (float)j;
     float py = (float)i;
     int32_t pix_id = i * img_size.x + j;

     // return if out of bounds
-    if (i >= img_size.y || j >= img_size.x) {
-        return;
-    }
+    // keep not rasterizing threads around for reading data
+    bool inside = (i < img_size.y && j < img_size.x);
+    bool done = !inside;

+    // have all threads in tile process the same gaussians in batches
+    // first collect gaussians between range.x and range.y in batches
     // which gaussians to look through in this tile
     int2 range = tile_bins[tile_id];
+    const int block_size = block.size();
+    int num_batches = (range.y - range.x + block_size - 1) / block_size;
+
+    extern __shared__ int s[];
+    int32_t* id_batch = (int32_t*)s;
+    float3* xy_opacity_batch = (float3*)&id_batch[block_size];
+    float3* conic_batch = (float3*)&xy_opacity_batch[block_size];
+    __half* color_out_batch = (__half*)&conic_batch[block_size];
+    for(int c = 0; c < channels; ++c)
+        color_out_batch[block.thread_rank() * channels + c] = __float2half(0.f);
+
+    // current visibility left to render
     float T = 1.f;
+    // index of most recent gaussian to write to this thread's pixel
+    int cur_idx = 0;

-    // iterate over all gaussians and apply rendering EWA equation (e.q. 2 from
-    // paper)
-    int idx;
-    for (idx = range.x; idx < range.y; ++idx) {
-        const int32_t g = gaussian_ids_sorted[idx];
-        const float3 conic = conics[g];
-        const float2 center = xys[g];
-        const float2 delta = {center.x - px, center.y - py};
-
-        // Mahalanobis distance (here referred to as sigma) measures how many
-        // standard deviations away distance delta is. sigma = -0.5(d.T * conic
-        // * d)
-        const float sigma =
-            0.5f * (conic.x * delta.x * delta.x + conic.z * delta.y * delta.y) +
-            conic.y * delta.x * delta.y;
-        if (sigma < 0.f) {
-            continue;
+    // collect and process batches of gaussians
+    // each thread loads one gaussian at a time before rasterizing its
+    // designated pixel
+    int tr = block.thread_rank();
+    __half* pix_out = &color_out_batch[block.thread_rank() * channels];
+    // float* pix_out = out_img + pix_id * channels;
+    for (int b = 0; b < num_batches; ++b) {
+        // resync all threads before beginning next batch
+        // end early if entire tile is done
+        if (__syncthreads_count(done) >= block_size) {
+            break;
+        }
+        // each thread fetch 1 gaussian from front to back
+        // index of gaussian to load
+        int batch_start = range.x + block_size * b;
+        int idx = batch_start + tr;
+        if (idx < range.y) {
+            int32_t g_id = gaussian_ids_sorted[idx];
+            id_batch[tr] = g_id;
+            const float2 xy = xys[g_id];
+            const float opac = opacities[g_id];
+            xy_opacity_batch[tr] = {xy.x, xy.y, opac};
+            conic_batch[tr] = conics[g_id];
         }
-        const float opac = opacities[g];
-        const float alpha = min(0.999f, opac * __expf(-sigma));
+        // wait for other threads to collect the gaussians in batch
+        block.sync();

-        // break out conditions
-        if (alpha < 1.f / 255.f) {
-            continue;
-        }
-        const float next_T = T * (1.f - alpha);
-        if (next_T <= 1e-4f) {
-            // we want to render the last gaussian that contributes and note
-            // that here idx > range.x so we don't underflow
-            idx -= 1;
-            break;
+        // process gaussians in the current batch for this pixel
+        int batch_size = min(block_size, range.y - batch_start);
+        for (int t = 0; (t < batch_size) && !done; ++t) {
+            const float3 conic = conic_batch[t];
+            const float3 xy_opac = xy_opacity_batch[t];
+            const float opac = xy_opac.z;
+            const float2 delta = {xy_opac.x - px, xy_opac.y - py};
+            const float sigma = 0.5f * (conic.x * delta.x * delta.x +
+                                        conic.z * delta.y * delta.y) +
+                                conic.y * delta.x * delta.y;
+            const float alpha = min(0.999f, opac * __expf(-sigma));
+            if (sigma < 0.f || alpha < 1.f / 255.f) {
+                continue;
+            }
+
+            const float next_T = T * (1.f - alpha);
+            if (next_T <= 1e-4f) { // this pixel is done
+                // we want to render the last gaussian that contributes and note
+                // that here idx > range.x so we don't underflow
+                done = true;
+                break;
+            }
+
+            int32_t g = id_batch[t];
+            const float vis = alpha * T;
+            for (int c = 0; c < channels; ++c) {
+                pix_out[c] = __hadd(pix_out[c], __float2half(colors[channels * g + c] * vis));
+            }
+            T = next_T;
+            cur_idx = batch_start + t;
         }
-        const float vis = alpha * T;
+    }
+
+    if (inside) {
+        // add background
+        final_Ts[pix_id] = T; // transmittance at last gaussian in this pixel
+        final_index[pix_id] =
+            cur_idx; // index of in bin of last gaussian in this pixel
         for (int c = 0; c < channels; ++c) {
-            out_img[channels * pix_id + c] += colors[channels * g + c] * vis;
+            out_img[pix_id * channels + c] = __half2float(pix_out[c]) + T * background[c];
         }
-        T = next_T;
     }
-    final_Ts[pix_id] = T; // transmittance at last gaussian in this pixel
-    final_index[pix_id] =
-        (idx == range.y)
-            ? idx - 1
-            : idx; // index of in bin of last gaussian in this pixel
-    for (int c = 0; c < channels; ++c) {
-        out_img[channels * pix_id + c] += T * background[c];
-    }
 }
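
For reference, the control flow the reworked forward kernel follows: each block walks its tile's Gaussian range in block-sized batches, every thread stages one element into shared memory, the block synchronizes, and then each thread consumes the whole staged batch for its own pixel, with __syncthreads_count used to stop once every pixel in the tile is finished. A generic sketch of that access pattern follows; the item, buffer and result names (items, staged, acc, batched_gather_process) are made up for illustration and the accumulation is a stand-in for alpha compositing, not the gsplat math.

    #include <cooperative_groups.h>
    namespace cg = cooperative_groups;

    // Each block cooperatively walks [range_start, range_end) in batches of block size.
    // Launch with block_size * sizeof(float) bytes of dynamic shared memory.
    __global__ void batched_gather_process(const float* items, int range_start,
                                           int range_end, float* out) {
        auto block = cg::this_thread_block();
        extern __shared__ float staged[];      // one item per thread in the block
        const int block_size = block.size();
        const int tr = block.thread_rank();
        int num_batches = (range_end - range_start + block_size - 1) / block_size;

        float acc = 0.f;
        bool done = false;
        for (int b = 0; b < num_batches; ++b) {
            // barrier + vote: stop early once every thread in the block is done
            if (__syncthreads_count(done) >= block_size) {
                break;
            }
            int batch_start = range_start + b * block_size;
            int idx = batch_start + tr;
            if (idx < range_end) {
                staged[tr] = items[idx];       // coalesced load, one element per thread
            }
            block.sync();                      // batch is now visible to the whole block

            int batch_size = min(block_size, range_end - batch_start);
            for (int t = 0; t < batch_size && !done; ++t) {
                acc += staged[t];              // every thread reads the shared batch
                if (acc > 1e6f) {              // stand-in for the transmittance cutoff
                    done = true;
                }
            }
        }
        out[blockIdx.x * block_size + tr] = acc;
    }

The ND forward kernel sizes its shared buffer the same way, from the id, xy/opacity and conic batches plus the per-thread half colour workspace (see the shared_mem computation in bindings.cu above).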