From 329fc9fe6bd5305aa63947ea4400fc3672879313 Mon Sep 17 00:00:00 2001 From: Zhuoyang Date: Wed, 4 Oct 2023 16:58:36 -0700 Subject: [PATCH] finish bin_and_sort_gaussians and fix some bugs of _torch_impl --- diff_rast/_torch_impl.py | 22 ++++++- diff_rast/cuda/__init__.py | 1 + diff_rast/cuda/csrc/bindings.cu | 74 ++++++++++++++++++++-- diff_rast/cuda/csrc/bindings.h | 28 +++++++-- diff_rast/cuda/csrc/ext.cpp | 1 + diff_rast/cuda/csrc/forward.cu | 2 +- tests/test_bin_and_sort_gaussians.py | 92 ++++++++++++++++++++++++++++ tests/test_map_gaussians.py | 12 +++- 8 files changed, 214 insertions(+), 18 deletions(-) create mode 100644 tests/test_bin_and_sort_gaussians.py diff --git a/diff_rast/_torch_impl.py b/diff_rast/_torch_impl.py index a2ea58f5b..52fd2ba87 100644 --- a/diff_rast/_torch_impl.py +++ b/diff_rast/_torch_impl.py @@ -306,13 +306,13 @@ def map_gaussian_to_intersects( tile_min, tile_max = get_tile_bbox(xys[idx], radii[idx], tile_bounds) - cur_idx = 0 if idx == 0 else cum_tiles_hit[idx - 1] + cur_idx = 0 if idx == 0 else cum_tiles_hit[idx - 1].item() # Get raw byte representation of the float value at the given index raw_bytes = struct.pack("f", depths[idx]) # Interpret those bytes as an int32_t - depth_id_n = struct.unpack("I", raw_bytes)[0] + depth_id_n = struct.unpack("i", raw_bytes)[0] for i in range(tile_min[1], tile_max[1]): for j in range(tile_min[0], tile_max[0]): @@ -348,3 +348,21 @@ def get_tile_bin_edges(num_intersects, isect_ids_sorted): tile_bins[cur_tile_idx, 0] = idx return tile_bins + + +def bin_and_sort_gaussians( + num_points, num_intersects, xys, depths, radii, cum_tiles_hit, tile_bounds +): + isect_ids, gaussian_ids = map_gaussian_to_intersects( + num_points, xys, depths, radii, cum_tiles_hit, tile_bounds + ) + + # Sorting isect_ids_unsorted + sorted_values, sorted_indices = torch.sort(isect_ids) + + isect_ids_sorted = sorted_values + gaussian_ids_sorted = torch.gather(gaussian_ids, 0, sorted_indices) + + tile_bins = get_tile_bin_edges(num_intersects, isect_ids_sorted) + + return isect_ids, gaussian_ids, isect_ids_sorted, gaussian_ids_sorted, tile_bins diff --git a/diff_rast/cuda/__init__.py b/diff_rast/cuda/__init__.py index e2ca4dc32..df7dcd614 100644 --- a/diff_rast/cuda/__init__.py +++ b/diff_rast/cuda/__init__.py @@ -21,3 +21,4 @@ def call_cuda(*args, **kwargs): compute_cumulative_intersects = _make_lazy_cuda_func("compute_cumulative_intersects") map_gaussian_to_intersects = _make_lazy_cuda_func("map_gaussian_to_intersects") get_tile_bin_edges = _make_lazy_cuda_func("get_tile_bin_edges") +bin_and_sort_gaussians = _make_lazy_cuda_func("bin_and_sort_gaussians") diff --git a/diff_rast/cuda/csrc/bindings.cu b/diff_rast/cuda/csrc/bindings.cu index 1747fff51..63b0752d1 100644 --- a/diff_rast/cuda/csrc/bindings.cu +++ b/diff_rast/cuda/csrc/bindings.cu @@ -258,7 +258,7 @@ project_gaussians_backward_tensor( } std::tuple compute_cumulative_intersects_tensor( - const int num_points, torch::Tensor &num_tiles_hit + const int num_points, const torch::Tensor &num_tiles_hit ) { // ref: // https://nvlabs.github.io/cub/structcub_1_1_device_scan.html#a9416ac1ea26f9fde669d83ddc883795a @@ -287,10 +287,10 @@ std::tuple compute_cumulative_intersects_tensor( std::tuple map_gaussian_to_intersects_tensor( const int num_points, - torch::Tensor &xys, - torch::Tensor &depths, - torch::Tensor &radii, - torch::Tensor &cum_tiles_hit, + const torch::Tensor &xys, + const torch::Tensor &depths, + const torch::Tensor &radii, + const torch::Tensor &cum_tiles_hit, const std::tuple tile_bounds ) { CHECK_INPUT(xys); @@ -329,7 +329,7 @@ std::tuple map_gaussian_to_intersects_tensor( torch::Tensor get_tile_bin_edges_tensor( int num_intersects, - torch::Tensor &isect_ids_sorted + const torch::Tensor &isect_ids_sorted ) { CHECK_INPUT(isect_ids_sorted); torch::Tensor tile_bins = @@ -342,4 +342,66 @@ torch::Tensor get_tile_bin_edges_tensor( (int2 *)tile_bins.contiguous().data_ptr() ); return tile_bins; +} + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +bin_and_sort_gaussians_tensor( + const int num_points, + const int num_intersects, + const torch::Tensor &xys, + const torch::Tensor &depths, + const torch::Tensor &radii, + const torch::Tensor &cum_tiles_hit, + const std::tuple tile_bounds +){ + CHECK_INPUT(xys); + CHECK_INPUT(depths); + CHECK_INPUT(radii); + CHECK_INPUT(cum_tiles_hit); + + dim3 tile_bounds_dim3; + tile_bounds_dim3.x = std::get<0>(tile_bounds); + tile_bounds_dim3.y = std::get<1>(tile_bounds); + tile_bounds_dim3.z = std::get<2>(tile_bounds); + + torch::Tensor gaussian_ids_unsorted = + torch::zeros({num_intersects}, xys.options().dtype(torch::kInt32)); + torch::Tensor gaussian_ids_sorted = + torch::zeros({num_intersects}, xys.options().dtype(torch::kInt32)); + torch::Tensor isect_ids_unsorted = + torch::zeros({num_intersects}, xys.options().dtype(torch::kInt64)); + torch::Tensor isect_ids_sorted = + torch::zeros({num_intersects}, xys.options().dtype(torch::kInt64)); + torch::Tensor tile_bins = + torch::zeros({num_intersects, 2}, xys.options().dtype(torch::kInt32)); + + bin_and_sort_gaussians( + num_points, + num_intersects, + (float2 *)xys.contiguous().data_ptr(), + depths.contiguous().data_ptr(), + radii.contiguous().data_ptr(), + cum_tiles_hit.contiguous().data_ptr(), + tile_bounds_dim3, + // Outputs. + isect_ids_unsorted.contiguous().data_ptr(), + gaussian_ids_unsorted.contiguous().data_ptr(), + isect_ids_sorted.contiguous().data_ptr(), + gaussian_ids_sorted.contiguous().data_ptr(), + (int2 *)tile_bins.contiguous().data_ptr() + ); + + return std::make_tuple( + isect_ids_unsorted, + gaussian_ids_unsorted, + isect_ids_sorted, + gaussian_ids_sorted, + tile_bins + ); + } \ No newline at end of file diff --git a/diff_rast/cuda/csrc/bindings.h b/diff_rast/cuda/csrc/bindings.h index e236d19ab..3b20e62b3 100644 --- a/diff_rast/cuda/csrc/bindings.h +++ b/diff_rast/cuda/csrc/bindings.h @@ -81,19 +81,35 @@ project_gaussians_backward_tensor( ); std::tuple compute_cumulative_intersects_tensor( - const int num_points, torch::Tensor &num_tiles_hit + const int num_points, const torch::Tensor &num_tiles_hit ); std::tuple map_gaussian_to_intersects_tensor( const int num_points, - torch::Tensor &xys, - torch::Tensor &depths, - torch::Tensor &radii, - torch::Tensor &cum_tiles_hit, + const torch::Tensor &xys, + const torch::Tensor &depths, + const torch::Tensor &radii, + const torch::Tensor &cum_tiles_hit, const std::tuple tile_bounds ); torch::Tensor get_tile_bin_edges_tensor( int num_intersects, - torch::Tensor &isect_ids_sorted + const torch::Tensor &isect_ids_sorted +); + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +bin_and_sort_gaussians_tensor( + const int num_points, + const int num_intersects, + const torch::Tensor &xys, + const torch::Tensor &depths, + const torch::Tensor &radii, + const torch::Tensor &cum_tiles_hit, + const std::tuple tile_bounds ); \ No newline at end of file diff --git a/diff_rast/cuda/csrc/ext.cpp b/diff_rast/cuda/csrc/ext.cpp index c733bf55a..0b8ee4b15 100644 --- a/diff_rast/cuda/csrc/ext.cpp +++ b/diff_rast/cuda/csrc/ext.cpp @@ -13,4 +13,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("compute_cumulative_intersects", &compute_cumulative_intersects_tensor); m.def("map_gaussian_to_intersects", &map_gaussian_to_intersects_tensor); m.def("get_tile_bin_edges", &get_tile_bin_edges_tensor); + m.def("bin_and_sort_gaussians", &bin_and_sort_gaussians_tensor); } diff --git a/diff_rast/cuda/csrc/forward.cu b/diff_rast/cuda/csrc/forward.cu index 4f25f9b03..aa12b771b 100644 --- a/diff_rast/cuda/csrc/forward.cu +++ b/diff_rast/cuda/csrc/forward.cu @@ -166,7 +166,7 @@ __global__ void map_gaussian_to_intersects( // update the intersection info for all tiles this gaussian hits int32_t cur_idx = (idx == 0) ? 0 : cum_tiles_hit[idx - 1]; // printf("point %d starting at %d\n", idx, cur_idx); - u_int64_t depth_id = (u_int64_t) * (u_int32_t *)&(depths[idx]); + int64_t depth_id = (int64_t) * (int32_t *)&(depths[idx]); for (int i = tile_min.y; i < tile_max.y; ++i) { for (int j = tile_min.x; j < tile_max.x; ++j) { // isect_id is tile ID and depth as int32 diff --git a/tests/test_bin_and_sort_gaussians.py b/tests/test_bin_and_sort_gaussians.py new file mode 100644 index 000000000..24b2bae31 --- /dev/null +++ b/tests/test_bin_and_sort_gaussians.py @@ -0,0 +1,92 @@ +import pytest +import torch + + +device = torch.device("cuda:0") + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") +def test_bin_and_sort_gaussians(): + from diff_rast import _torch_impl + import diff_rast.cuda as _C + + torch.manual_seed(42) + + num_points = 100 + means3d = torch.randn((num_points, 3), device=device, requires_grad=True) + scales = torch.randn((num_points, 3), device=device) + glob_scale = 0.3 + quats = torch.randn((num_points, 4), device=device) + quats /= torch.linalg.norm(quats, dim=-1, keepdim=True) + viewmat = torch.eye(4, device=device) + projmat = torch.eye(4, device=device) + fx, fy = 3.0, 3.0 + H, W = 512, 512 + clip_thresh = 0.01 + + BLOCK_X, BLOCK_Y = 16, 16 + tile_bounds = (W + BLOCK_X - 1) // BLOCK_X, (H + BLOCK_Y - 1) // BLOCK_Y, 1 + + ( + _cov3d, + _xys, + _depths, + _radii, + _conics, + _num_tiles_hit, + _masks, + ) = _torch_impl.project_gaussians_forward( + means3d, + scales, + glob_scale, + quats, + viewmat, + projmat, + fx, + fy, + (H, W), + tile_bounds, + clip_thresh, + ) + + _xys = _xys[_masks] + _depths = _depths[_masks] + _radii = _radii[_masks] + _conics = _conics[_masks] + _num_tiles_hit = _num_tiles_hit[_masks] + + num_points = num_points - torch.count_nonzero(~_masks).item() + + _cum_tiles_hit = torch.cumsum(_num_tiles_hit, dim=0, dtype=torch.int32) + _num_intersects = _cum_tiles_hit[-1].item() + _depths = _depths.contiguous() + + ( + _isect_ids_unsorted, + _gaussian_ids_unsorted, + _isect_ids_sorted, + _gaussian_ids_sorted, + _tile_bins, + ) = _torch_impl.bin_and_sort_gaussians( + num_points, _num_intersects, _xys, _depths, _radii, _cum_tiles_hit, tile_bounds + ) + + ( + isect_ids_unsorted, + gaussian_ids_unsorted, + isect_ids_sorted, + gaussian_ids_sorted, + tile_bins, + ) = _C.bin_and_sort_gaussians( + num_points, _num_intersects, _xys, _depths, _radii, _cum_tiles_hit, tile_bounds + ) + + torch.testing.assert_close(_isect_ids_unsorted, isect_ids_unsorted) + torch.testing.assert_close(_gaussian_ids_unsorted, gaussian_ids_unsorted) + torch.testing.assert_close(_isect_ids_sorted, isect_ids_sorted) + torch.testing.assert_close(_gaussian_ids_sorted, gaussian_ids_sorted) + torch.testing.assert_close(_tile_bins, tile_bins) + + +if __name__ == "__main__": + test_bin_and_sort_gaussians() diff --git a/tests/test_map_gaussians.py b/tests/test_map_gaussians.py index 394fe7199..912d22226 100644 --- a/tests/test_map_gaussians.py +++ b/tests/test_map_gaussians.py @@ -48,15 +48,21 @@ def test_map_gaussians(): tile_bounds, clip_thresh, ) + _xys = _xys[_masks] + _depths = _depths[_masks] + _radii = _radii[_masks] + _conics = _conics[_masks] + _num_tiles_hit = _num_tiles_hit[_masks] + + num_points = num_points - torch.count_nonzero(~_masks).item() _cum_tiles_hit = torch.cumsum(_num_tiles_hit, dim=0, dtype=torch.int32) _depths = _depths.contiguous() - isect_ids, gaussian_ids = _C.map_gaussian_to_intersects( + _isect_ids, _gaussian_ids = _torch_impl.map_gaussian_to_intersects( num_points, _xys, _depths, _radii, _cum_tiles_hit, tile_bounds ) - - _isect_ids, _gaussian_ids = _torch_impl.map_gaussian_to_intersects( + isect_ids, gaussian_ids = _C.map_gaussian_to_intersects( num_points, _xys, _depths, _radii, _cum_tiles_hit, tile_bounds )