Skip to content

Commit

Permalink
Finish bin_and_sort_gaussians and fix some bugs in _torch_impl
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhuoyang-Pan committed Oct 4, 2023
1 parent 3021ffa commit 329fc9f
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 18 deletions.
22 changes: 20 additions & 2 deletions diff_rast/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,13 +306,13 @@ def map_gaussian_to_intersects(

tile_min, tile_max = get_tile_bbox(xys[idx], radii[idx], tile_bounds)

cur_idx = 0 if idx == 0 else cum_tiles_hit[idx - 1]
cur_idx = 0 if idx == 0 else cum_tiles_hit[idx - 1].item()

# Get raw byte representation of the float value at the given index
raw_bytes = struct.pack("f", depths[idx])

# Interpret those bytes as an int32_t
depth_id_n = struct.unpack("I", raw_bytes)[0]
depth_id_n = struct.unpack("i", raw_bytes)[0]

for i in range(tile_min[1], tile_max[1]):
for j in range(tile_min[0], tile_max[0]):
Expand Down Expand Up @@ -348,3 +348,21 @@ def get_tile_bin_edges(num_intersects, isect_ids_sorted):
tile_bins[cur_tile_idx, 0] = idx

return tile_bins


def bin_and_sort_gaussians(
    num_points, num_intersects, xys, depths, radii, cum_tiles_hit, tile_bounds
):
    """Map gaussians to tile intersections, sort them, and compute tile bins.

    Chains ``map_gaussian_to_intersects``, a sort of the resulting
    intersection ids, and ``get_tile_bin_edges``.

    Args:
        num_points: forwarded to ``map_gaussian_to_intersects``.
        num_intersects: forwarded to ``get_tile_bin_edges``.
        xys: forwarded to ``map_gaussian_to_intersects``.
        depths: forwarded to ``map_gaussian_to_intersects``.
        radii: forwarded to ``map_gaussian_to_intersects``.
        cum_tiles_hit: forwarded to ``map_gaussian_to_intersects``.
        tile_bounds: forwarded to ``map_gaussian_to_intersects``.

    Returns:
        A 5-tuple ``(isect_ids, gaussian_ids, isect_ids_sorted,
        gaussian_ids_sorted, tile_bins)``: the unsorted intersection ids and
        their gaussian ids, the same two sequences reordered by ascending
        intersection id, and the result of ``get_tile_bin_edges`` on the
        sorted ids.
    """
    isect_ids, gaussian_ids = map_gaussian_to_intersects(
        num_points, xys, depths, radii, cum_tiles_hit, tile_bounds
    )

    # torch.sort is not stable by default, so intersections that share the
    # same id could be returned in any order. A stable sort keeps such ties in
    # their original order, which makes gaussian_ids_sorted deterministic.
    isect_ids_sorted, sorted_indices = torch.sort(isect_ids, stable=True)

    # Apply the same permutation to the gaussian ids.
    gaussian_ids_sorted = torch.gather(gaussian_ids, 0, sorted_indices)

    tile_bins = get_tile_bin_edges(num_intersects, isect_ids_sorted)

    return isect_ids, gaussian_ids, isect_ids_sorted, gaussian_ids_sorted, tile_bins
1 change: 1 addition & 0 deletions diff_rast/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ def call_cuda(*args, **kwargs):
# Module-level callables for the CUDA extension functions of the same name.
# NOTE(review): _make_lazy_cuda_func is defined above this hunk; presumably it
# defers loading the compiled extension until the first call — confirm there.
compute_cumulative_intersects = _make_lazy_cuda_func("compute_cumulative_intersects")
map_gaussian_to_intersects = _make_lazy_cuda_func("map_gaussian_to_intersects")
get_tile_bin_edges = _make_lazy_cuda_func("get_tile_bin_edges")
bin_and_sort_gaussians = _make_lazy_cuda_func("bin_and_sort_gaussians")
74 changes: 68 additions & 6 deletions diff_rast/cuda/csrc/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ project_gaussians_backward_tensor(
}

std::tuple<torch::Tensor, torch::Tensor> compute_cumulative_intersects_tensor(
const int num_points, torch::Tensor &num_tiles_hit
const int num_points, const torch::Tensor &num_tiles_hit
) {
// ref:
// https://nvlabs.github.io/cub/structcub_1_1_device_scan.html#a9416ac1ea26f9fde669d83ddc883795a
Expand Down Expand Up @@ -287,10 +287,10 @@ std::tuple<torch::Tensor, torch::Tensor> compute_cumulative_intersects_tensor(

std::tuple<torch::Tensor, torch::Tensor> map_gaussian_to_intersects_tensor(
const int num_points,
torch::Tensor &xys,
torch::Tensor &depths,
torch::Tensor &radii,
torch::Tensor &cum_tiles_hit,
const torch::Tensor &xys,
const torch::Tensor &depths,
const torch::Tensor &radii,
const torch::Tensor &cum_tiles_hit,
const std::tuple<int, int, int> tile_bounds
) {
CHECK_INPUT(xys);
Expand Down Expand Up @@ -329,7 +329,7 @@ std::tuple<torch::Tensor, torch::Tensor> map_gaussian_to_intersects_tensor(

torch::Tensor get_tile_bin_edges_tensor(
int num_intersects,
torch::Tensor &isect_ids_sorted
const torch::Tensor &isect_ids_sorted
) {
CHECK_INPUT(isect_ids_sorted);
torch::Tensor tile_bins =
Expand All @@ -342,4 +342,66 @@ torch::Tensor get_tile_bin_edges_tensor(
(int2 *)tile_bins.contiguous().data_ptr<int>()
);
return tile_bins;
}

// Torch-facing wrapper around bin_and_sort_gaussians: validates the inputs,
// allocates the five output tensors, and passes raw data pointers through.
//
// Returns, in order: isect_ids_unsorted, gaussian_ids_unsorted,
// isect_ids_sorted, gaussian_ids_sorted, tile_bins. The isect id tensors are
// int64 and the gaussian id tensors int32, each with num_intersects elements;
// tile_bins is int32 with shape {num_intersects, 2}. Every output takes its
// remaining tensor options from xys.
std::tuple<
    torch::Tensor,
    torch::Tensor,
    torch::Tensor,
    torch::Tensor,
    torch::Tensor>
bin_and_sort_gaussians_tensor(
    const int num_points,
    const int num_intersects,
    const torch::Tensor &xys,
    const torch::Tensor &depths,
    const torch::Tensor &radii,
    const torch::Tensor &cum_tiles_hit,
    const std::tuple<int, int, int> tile_bounds
) {
    CHECK_INPUT(xys);
    CHECK_INPUT(depths);
    CHECK_INPUT(radii);
    CHECK_INPUT(cum_tiles_hit);

    // Repack the (x, y, z) tuple into the dim3 that bin_and_sort_gaussians takes.
    dim3 tile_bounds_dim3(
        static_cast<unsigned int>(std::get<0>(tile_bounds)),
        static_cast<unsigned int>(std::get<1>(tile_bounds)),
        static_cast<unsigned int>(std::get<2>(tile_bounds))
    );

    // Outputs share everything with xys except the element type.
    const auto int32_opts = xys.options().dtype(torch::kInt32);
    const auto int64_opts = xys.options().dtype(torch::kInt64);

    torch::Tensor isect_ids_unsorted =
        torch::zeros({num_intersects}, int64_opts);
    torch::Tensor gaussian_ids_unsorted =
        torch::zeros({num_intersects}, int32_opts);
    torch::Tensor isect_ids_sorted =
        torch::zeros({num_intersects}, int64_opts);
    torch::Tensor gaussian_ids_sorted =
        torch::zeros({num_intersects}, int32_opts);
    torch::Tensor tile_bins =
        torch::zeros({num_intersects, 2}, int32_opts);

    // Named handles keep the contiguous views of the inputs in scope for the
    // whole call.
    const torch::Tensor xys_c = xys.contiguous();
    const torch::Tensor depths_c = depths.contiguous();
    const torch::Tensor radii_c = radii.contiguous();
    const torch::Tensor cum_tiles_hit_c = cum_tiles_hit.contiguous();

    bin_and_sort_gaussians(
        num_points,
        num_intersects,
        reinterpret_cast<float2 *>(xys_c.data_ptr<float>()),
        depths_c.data_ptr<float>(),
        radii_c.data_ptr<int32_t>(),
        cum_tiles_hit_c.data_ptr<int32_t>(),
        tile_bounds_dim3,
        // Outputs (freshly allocated above, hence already contiguous).
        isect_ids_unsorted.data_ptr<int64_t>(),
        gaussian_ids_unsorted.data_ptr<int32_t>(),
        isect_ids_sorted.data_ptr<int64_t>(),
        gaussian_ids_sorted.data_ptr<int32_t>(),
        reinterpret_cast<int2 *>(tile_bins.data_ptr<int>())
    );

    return std::make_tuple(
        isect_ids_unsorted,
        gaussian_ids_unsorted,
        isect_ids_sorted,
        gaussian_ids_sorted,
        tile_bins
    );
}
28 changes: 22 additions & 6 deletions diff_rast/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,35 @@ project_gaussians_backward_tensor(
);

std::tuple<torch::Tensor, torch::Tensor> compute_cumulative_intersects_tensor(
const int num_points, torch::Tensor &num_tiles_hit
const int num_points, const torch::Tensor &num_tiles_hit
);

std::tuple<torch::Tensor, torch::Tensor> map_gaussian_to_intersects_tensor(
const int num_points,
torch::Tensor &xys,
torch::Tensor &depths,
torch::Tensor &radii,
torch::Tensor &cum_tiles_hit,
const torch::Tensor &xys,
const torch::Tensor &depths,
const torch::Tensor &radii,
const torch::Tensor &cum_tiles_hit,
const std::tuple<int, int, int> tile_bounds
);

torch::Tensor get_tile_bin_edges_tensor(
int num_intersects,
torch::Tensor &isect_ids_sorted
const torch::Tensor &isect_ids_sorted
);

// Torch binding for bin_and_sort_gaussians (defined in bindings.cu).
// Returns five tensors, in order: isect_ids_unsorted, gaussian_ids_unsorted,
// isect_ids_sorted, gaussian_ids_sorted, tile_bins.
std::tuple<
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor>
bin_and_sort_gaussians_tensor(
const int num_points,
const int num_intersects,
const torch::Tensor &xys,
const torch::Tensor &depths,
const torch::Tensor &radii,
const torch::Tensor &cum_tiles_hit,
const std::tuple<int, int, int> tile_bounds
);
1 change: 1 addition & 0 deletions diff_rast/cuda/csrc/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("compute_cumulative_intersects", &compute_cumulative_intersects_tensor);
m.def("map_gaussian_to_intersects", &map_gaussian_to_intersects_tensor);
m.def("get_tile_bin_edges", &get_tile_bin_edges_tensor);
m.def("bin_and_sort_gaussians", &bin_and_sort_gaussians_tensor);
}
2 changes: 1 addition & 1 deletion diff_rast/cuda/csrc/forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ __global__ void map_gaussian_to_intersects(
// update the intersection info for all tiles this gaussian hits
int32_t cur_idx = (idx == 0) ? 0 : cum_tiles_hit[idx - 1];
// printf("point %d starting at %d\n", idx, cur_idx);
u_int64_t depth_id = (u_int64_t) * (u_int32_t *)&(depths[idx]);
int64_t depth_id = (int64_t) * (int32_t *)&(depths[idx]);
for (int i = tile_min.y; i < tile_max.y; ++i) {
for (int j = tile_min.x; j < tile_max.x; ++j) {
// isect_id is tile ID and depth as int32
Expand Down
92 changes: 92 additions & 0 deletions tests/test_bin_and_sort_gaussians.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import pytest
import torch


# Device on which this test allocates its input tensors.
device = torch.device("cuda:0")


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_bin_and_sort_gaussians():
    """Check the CUDA bin_and_sort_gaussians against the torch implementation.

    Projects a seeded random set of gaussians with the torch implementation,
    keeps the ones selected by the returned mask, then runs both
    implementations on identical arguments and compares all five outputs.
    """
    from diff_rast import _torch_impl
    import diff_rast.cuda as _C

    torch.manual_seed(42)

    # Random draws stay in the order means -> scales -> quats so that the seed
    # reproduces the same inputs.
    total_points = 100
    means3d = torch.randn((total_points, 3), device=device, requires_grad=True)
    scales = torch.randn((total_points, 3), device=device)
    glob_scale = 0.3
    quats = torch.randn((total_points, 4), device=device)
    quats /= torch.linalg.norm(quats, dim=-1, keepdim=True)

    # Camera setup.
    viewmat = torch.eye(4, device=device)
    projmat = torch.eye(4, device=device)
    fx, fy = 3.0, 3.0
    H, W = 512, 512
    clip_thresh = 0.01

    # Number of 16x16 tiles needed to cover the image (rounded up).
    BLOCK_X, BLOCK_Y = 16, 16
    tiles_x = (W + BLOCK_X - 1) // BLOCK_X
    tiles_y = (H + BLOCK_Y - 1) // BLOCK_Y
    tile_bounds = (tiles_x, tiles_y, 1)

    projection = _torch_impl.project_gaussians_forward(
        means3d,
        scales,
        glob_scale,
        quats,
        viewmat,
        projmat,
        fx,
        fy,
        (H, W),
        tile_bounds,
        clip_thresh,
    )
    _cov3d, xys, depths, radii, conics, num_tiles_hit, masks = projection

    # Keep only the entries selected by the projection mask.
    xys, depths, radii, conics, num_tiles_hit = (
        tensor[masks] for tensor in (xys, depths, radii, conics, num_tiles_hit)
    )
    kept_points = total_points - torch.count_nonzero(~masks).item()

    cum_tiles_hit = torch.cumsum(num_tiles_hit, dim=0, dtype=torch.int32)
    num_intersects = cum_tiles_hit[-1].item()
    depths = depths.contiguous()

    # Both implementations receive exactly the same arguments.
    shared_args = (
        kept_points,
        num_intersects,
        xys,
        depths,
        radii,
        cum_tiles_hit,
        tile_bounds,
    )

    (
        ref_isect_ids,
        ref_gaussian_ids,
        ref_isect_ids_sorted,
        ref_gaussian_ids_sorted,
        ref_tile_bins,
    ) = _torch_impl.bin_and_sort_gaussians(*shared_args)

    (
        cuda_isect_ids,
        cuda_gaussian_ids,
        cuda_isect_ids_sorted,
        cuda_gaussian_ids_sorted,
        cuda_tile_bins,
    ) = _C.bin_and_sort_gaussians(*shared_args)

    comparisons = (
        (ref_isect_ids, cuda_isect_ids),
        (ref_gaussian_ids, cuda_gaussian_ids),
        (ref_isect_ids_sorted, cuda_isect_ids_sorted),
        (ref_gaussian_ids_sorted, cuda_gaussian_ids_sorted),
        (ref_tile_bins, cuda_tile_bins),
    )
    for reference, result in comparisons:
        torch.testing.assert_close(reference, result)


# Allow running this test directly as a script, without going through pytest.
if __name__ == "__main__":
    test_bin_and_sort_gaussians()
12 changes: 9 additions & 3 deletions tests/test_map_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,21 @@ def test_map_gaussians():
tile_bounds,
clip_thresh,
)
_xys = _xys[_masks]
_depths = _depths[_masks]
_radii = _radii[_masks]
_conics = _conics[_masks]
_num_tiles_hit = _num_tiles_hit[_masks]

num_points = num_points - torch.count_nonzero(~_masks).item()

_cum_tiles_hit = torch.cumsum(_num_tiles_hit, dim=0, dtype=torch.int32)
_depths = _depths.contiguous()

isect_ids, gaussian_ids = _C.map_gaussian_to_intersects(
_isect_ids, _gaussian_ids = _torch_impl.map_gaussian_to_intersects(
num_points, _xys, _depths, _radii, _cum_tiles_hit, tile_bounds
)

_isect_ids, _gaussian_ids = _torch_impl.map_gaussian_to_intersects(
isect_ids, gaussian_ids = _C.map_gaussian_to_intersects(
num_points, _xys, _depths, _radii, _cum_tiles_hit, tile_bounds
)

Expand Down

0 comments on commit 329fc9f

Please sign in to comment.