diff --git a/README.md b/README.md index 8d9c4506..3949108d 100644 --- a/README.md +++ b/README.md @@ -87,16 +87,18 @@ def rgb_sigma_fn( return rgbs, sigmas # (n_samples, 3), (n_samples, 1) # Efficient Raymarching: Skip empty and occluded space, pack samples from all rays. -# packed_info: (n_rays, 2). t_starts: (n_samples, 1). t_ends: (n_samples, 1). +# ray_indices: (n_samples,). t_starts: (n_samples, 1). t_ends: (n_samples, 1). with torch.no_grad(): - packed_info, t_starts, t_ends = nerfacc.ray_marching( + ray_indices, t_starts, t_ends = nerfacc.ray_marching( rays_o, rays_d, sigma_fn=sigma_fn, near_plane=0.2, far_plane=1.0, early_stop_eps=1e-4, alpha_thre=1e-2, ) # Differentiable Volumetric Rendering. # colors: (n_rays, 3). opaicity: (n_rays, 1). depth: (n_rays, 1). -color, opacity, depth = nerfacc.rendering(rgb_sigma_fn, packed_info, t_starts, t_ends) +color, opacity, depth = nerfacc.rendering( + t_starts, t_ends, ray_indices, n_rays=rays_o.shape[0], rgb_sigma_fn=rgb_sigma_fn +) # Optimize: Both the network and rays will receive gradients optimizer.zero_grad() diff --git a/docs/source/apis/generated/nerfacc.render_transmittance_from_alpha.rst b/docs/source/apis/generated/nerfacc.render_transmittance_from_alpha.rst new file mode 100644 index 00000000..c833e31e --- /dev/null +++ b/docs/source/apis/generated/nerfacc.render_transmittance_from_alpha.rst @@ -0,0 +1,6 @@ +nerfacc.render\_transmittance\_from\_alpha +========================================== + +.. currentmodule:: nerfacc + +.. autofunction:: render_transmittance_from_alpha \ No newline at end of file diff --git a/docs/source/apis/generated/nerfacc.render_transmittance_from_density.rst b/docs/source/apis/generated/nerfacc.render_transmittance_from_density.rst new file mode 100644 index 00000000..5715eca3 --- /dev/null +++ b/docs/source/apis/generated/nerfacc.render_transmittance_from_density.rst @@ -0,0 +1,6 @@ +nerfacc.render\_transmittance\_from\_density +============================================ + +.. currentmodule:: nerfacc + +.. autofunction:: render_transmittance_from_density \ No newline at end of file diff --git a/docs/source/apis/utils.rst b/docs/source/apis/utils.rst index 697e12cc..56b1730c 100644 --- a/docs/source/apis/utils.rst +++ b/docs/source/apis/utils.rst @@ -11,6 +11,8 @@ Utils unpack_info accumulate_along_rays + render_transmittance_from_density + render_transmittance_from_alpha render_weight_from_density render_weight_from_alpha render_visibility diff --git a/docs/source/index.rst b/docs/source/index.rst index ecd83a19..cbdcb941 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -90,16 +90,18 @@ An simple example is like this: return rgbs, sigmas # (n_samples, 3), (n_samples, 1) # Efficient Raymarching: Skip empty and occluded space, pack samples from all rays. - # packed_info: (n_rays, 2). t_starts: (n_samples, 1). t_ends: (n_samples, 1). + # ray_indices: (n_samples,). t_starts: (n_samples, 1). t_ends: (n_samples, 1). with torch.no_grad(): - packed_info, t_starts, t_ends = nerfacc.ray_marching( + ray_indices, t_starts, t_ends = nerfacc.ray_marching( rays_o, rays_d, sigma_fn=sigma_fn, near_plane=0.2, far_plane=1.0, early_stop_eps=1e-4, alpha_thre=1e-2, ) # Differentiable Volumetric Rendering. # colors: (n_rays, 3). opaicity: (n_rays, 1). depth: (n_rays, 1). - color, opacity, depth = nerfacc.rendering(rgb_sigma_fn, packed_info, t_starts, t_ends) + color, opacity, depth = nerfacc.rendering( + t_starts, t_ends, ray_indices, n_rays=rays_o.shape[0], rgb_sigma_fn=rgb_sigma_fn + ) # Optimize: Both the network and rays will receive gradients optimizer.zero_grad() diff --git a/examples/utils.py b/examples/utils.py index e5e6b378..b84789f2 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -85,7 +85,7 @@ def rgb_sigma_fn(t_starts, t_ends, ray_indices): ) for i in range(0, num_rays, chunk): chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays) - packed_info, t_starts, t_ends = ray_marching( + ray_indices, t_starts, t_ends = ray_marching( chunk_rays.origins, chunk_rays.viewdirs, scene_aabb=scene_aabb, @@ -99,9 +99,10 @@ def rgb_sigma_fn(t_starts, t_ends, ray_indices): alpha_thre=alpha_thre, ) rgb, opacity, depth = rendering( - packed_info, t_starts, t_ends, + ray_indices, + n_rays=chunk_rays.origins.shape[0], rgb_sigma_fn=rgb_sigma_fn, render_bkgd=render_bkgd, ) diff --git a/nerfacc/__init__.py b/nerfacc/__init__.py index 620d7e28..61b0404a 100644 --- a/nerfacc/__init__.py +++ b/nerfacc/__init__.py @@ -8,11 +8,13 @@ from .grid import Grid, OccupancyGrid, query_grid from .intersection import ray_aabb_intersect from .losses import distortion as loss_distortion -from .pack import pack_data, unpack_data, unpack_info +from .pack import pack_data, pack_info, unpack_data, unpack_info from .ray_marching import ray_marching from .version import __version__ from .vol_rendering import ( accumulate_along_rays, + render_transmittance_from_alpha, + render_transmittance_from_density, render_visibility, render_weight_from_alpha, render_weight_from_density, @@ -48,7 +50,10 @@ def unpack_to_ray_indices(*args, **kwargs): "pack_data", "unpack_data", "unpack_info", + "pack_info", "ray_resampling", "loss_distortion", "unpack_to_ray_indices", + "render_transmittance_from_density", + "render_transmittance_from_alpha", ] diff --git a/nerfacc/cuda/__init__.py b/nerfacc/cuda/__init__.py index 70aeda7a..c1ff784a 100644 --- a/nerfacc/cuda/__init__.py +++ b/nerfacc/cuda/__init__.py @@ -25,10 +25,45 @@ def call_cuda(*args, **kwargs): ray_marching = _make_lazy_cuda_func("ray_marching") ray_resampling = _make_lazy_cuda_func("ray_resampling") -rendering_forward = _make_lazy_cuda_func("rendering_forward") -rendering_backward = _make_lazy_cuda_func("rendering_backward") -rendering_alphas_forward = _make_lazy_cuda_func("rendering_alphas_forward") -rendering_alphas_backward = _make_lazy_cuda_func("rendering_alphas_backward") +is_cub_available = _make_lazy_cuda_func("is_cub_available") +transmittance_from_sigma_forward_cub = _make_lazy_cuda_func( + "transmittance_from_sigma_forward_cub" +) +transmittance_from_sigma_backward_cub = _make_lazy_cuda_func( + "transmittance_from_sigma_backward_cub" +) +transmittance_from_alpha_forward_cub = _make_lazy_cuda_func( + "transmittance_from_alpha_forward_cub" +) +transmittance_from_alpha_backward_cub = _make_lazy_cuda_func( + "transmittance_from_alpha_backward_cub" +) + +transmittance_from_sigma_forward_naive = _make_lazy_cuda_func( + "transmittance_from_sigma_forward_naive" +) +transmittance_from_sigma_backward_naive = _make_lazy_cuda_func( + "transmittance_from_sigma_backward_naive" +) +transmittance_from_alpha_forward_naive = _make_lazy_cuda_func( + "transmittance_from_alpha_forward_naive" +) +transmittance_from_alpha_backward_naive = _make_lazy_cuda_func( + "transmittance_from_alpha_backward_naive" +) + +weight_from_sigma_forward_naive = _make_lazy_cuda_func( + "weight_from_sigma_forward_naive" +) +weight_from_sigma_backward_naive = _make_lazy_cuda_func( + "weight_from_sigma_backward_naive" +) +weight_from_alpha_forward_naive = _make_lazy_cuda_func( + "weight_from_alpha_forward_naive" +) +weight_from_alpha_backward_naive = _make_lazy_cuda_func( + "weight_from_alpha_backward_naive" +) unpack_data = _make_lazy_cuda_func("unpack_data") unpack_info = _make_lazy_cuda_func("unpack_info") diff --git a/nerfacc/cuda/_backend.py b/nerfacc/cuda/_backend.py index 5c5da4d7..85d2a985 100644 --- a/nerfacc/cuda/_backend.py +++ b/nerfacc/cuda/_backend.py @@ -3,9 +3,14 @@ """ import glob +import json import os +import shutil +import urllib.request +import zipfile from subprocess import DEVNULL, call +from packaging import version from rich.console import Console from torch.utils.cpp_extension import _get_build_directory, load @@ -21,32 +26,70 @@ def cuda_toolkit_available(): return False -def load_extention(name: str): - return load( - name=name, - sources=glob.glob(os.path.join(PATH, "csrc/*.cu")), - extra_cflags=["-O3"], - extra_cuda_cflags=["-O3"], - ) +def cuda_toolkit_version(): + """Get the cuda toolkit version.""" + cuda_home = os.path.join(os.path.dirname(shutil.which("nvcc")), "..") + if os.path.exists(os.path.join(cuda_home, "version.txt")): + with open(os.path.join(cuda_home, "version.txt")) as f: + cuda_version = f.read().strip().split()[-1] + elif os.path.exists(os.path.join(cuda_home, "version.json")): + with open(os.path.join(cuda_home, "version.json")) as f: + cuda_version = json.load(f)["cuda"]["version"] + else: + raise RuntimeError("Cannot find the cuda version.") + return cuda_version -_C = None name = "nerfacc_cuda" -if os.listdir(_get_build_directory(name, verbose=False)) != []: - # If the build exists, we assume the extension has been built - # and we can load it. - _C = load_extention(name) -else: - # First time to build the extension - if cuda_toolkit_available(): +build_dir = _get_build_directory(name, verbose=False) +extra_include_paths = [] +extra_cflags = ["-O3"] +extra_cuda_cflags = ["-O3"] + +_C = None +if cuda_toolkit_available(): + # # we need cub >= 1.15.0 which is shipped with cuda >= 11.6, so download if + # # necessary. (compling does not garentee to success) + # if version.parse(cuda_toolkit_version()) < version.parse("11.6"): + # target_path = os.path.join(build_dir, "cub-1.17.0") + # if not os.path.exists(target_path): + # zip_path, _ = urllib.request.urlretrieve( + # "https://github.com/NVIDIA/cub/archive/1.17.0.tar.gz", + # os.path.join(build_dir, "cub-1.17.0.tar.gz"), + # ) + # shutil.unpack_archive(zip_path, build_dir) + # extra_include_paths.append(target_path) + # extra_cuda_cflags.append("-DTHRUST_IGNORE_CUB_VERSION_CHECK") + # print( + # f"download cub because the cuda version is {cuda_toolkit_version()}" + # ) + + if os.path.exists(os.path.join(build_dir, f"{name}.so")): + # If the build exists, we assume the extension has been built + # and we can load it. + _C = load( + name=name, + sources=glob.glob(os.path.join(PATH, "csrc/*.cu")), + extra_cflags=extra_cflags, + extra_cuda_cflags=extra_cuda_cflags, + extra_include_paths=extra_include_paths, + ) + else: with Console().status( "[bold yellow]NerfAcc: Setting up CUDA (This may take a few minutes the first time)", spinner="bouncingBall", ): - _C = load_extention(name) - else: - Console().print( - "[yellow]NerfAcc: No CUDA toolkit found. NerfAcc will be disabled.[/yellow]" - ) + _C = load( + name=name, + sources=glob.glob(os.path.join(PATH, "csrc/*.cu")), + extra_cflags=extra_cflags, + extra_cuda_cflags=extra_cuda_cflags, + extra_include_paths=extra_include_paths, + ) +else: + Console().print( + "[yellow]NerfAcc: No CUDA toolkit found. NerfAcc will be disabled.[/yellow]" + ) + __all__ = ["_C"] diff --git a/nerfacc/cuda/csrc/include/helpers_cuda.h b/nerfacc/cuda/csrc/include/helpers_cuda.h index b0b0fb59..10271f84 100644 --- a/nerfacc/cuda/csrc/include/helpers_cuda.h +++ b/nerfacc/cuda/csrc/include/helpers_cuda.h @@ -6,6 +6,8 @@ #include #include +#include +#include #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) \ @@ -20,3 +22,13 @@ #define CUDA_N_BLOCKS_NEEDED(Q, CUDA_N_THREADS) ((Q - 1) / CUDA_N_THREADS + 1) #define DEVICE_GUARD(_ten) \ const at::cuda::OptionalCUDAGuard device_guard(device_of(_ten)); + +// https://github.com/pytorch/pytorch/blob/233305a852e1cd7f319b15b5137074c9eac455f6/aten/src/ATen/cuda/cub.cuh#L38-L46 +#define CUB_WRAPPER(func, ...) do { \ + size_t temp_storage_bytes = 0; \ + func(nullptr, temp_storage_bytes, __VA_ARGS__); \ + auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get(); \ + auto temp_storage = caching_allocator.allocate(temp_storage_bytes); \ + func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__); \ + AT_CUDA_CHECK(cudaGetLastError()); \ +} while (false) \ No newline at end of file diff --git a/nerfacc/cuda/csrc/pack.cu b/nerfacc/cuda/csrc/pack.cu index 645373f2..3a455042 100644 --- a/nerfacc/cuda/csrc/pack.cu +++ b/nerfacc/cuda/csrc/pack.cu @@ -81,7 +81,7 @@ __global__ void unpack_data_kernel( return; } -torch::Tensor unpack_info(const torch::Tensor packed_info) +torch::Tensor unpack_info(const torch::Tensor packed_info, const int n_samples) { DEVICE_GUARD(packed_info); CHECK_INPUT(packed_info); @@ -90,7 +90,7 @@ torch::Tensor unpack_info(const torch::Tensor packed_info) const int threads = 256; const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads); - int n_samples = packed_info[n_rays - 1].sum(0).item(); + // int n_samples = packed_info[n_rays - 1].sum(0).item(); torch::Tensor ray_indices = torch::empty( {n_samples}, packed_info.options().dtype(torch::kInt32)); diff --git a/nerfacc/cuda/csrc/pybind.cu b/nerfacc/cuda/csrc/pybind.cu index a4663ee5..a8ba63b3 100644 --- a/nerfacc/cuda/csrc/pybind.cu +++ b/nerfacc/cuda/csrc/pybind.cu @@ -6,24 +6,6 @@ #include "include/helpers_math.h" #include "include/helpers_contraction.h" -std::vector rendering_forward( - torch::Tensor packed_info, - torch::Tensor starts, - torch::Tensor ends, - torch::Tensor sigmas, - float early_stop_eps, - float alpha_thre, - bool compression); - -torch::Tensor rendering_backward( - torch::Tensor weights, - torch::Tensor grad_weights, - torch::Tensor packed_info, - torch::Tensor starts, - torch::Tensor ends, - torch::Tensor sigmas, - float early_stop_eps, - float alpha_thre); std::vector ray_aabb_intersect( const torch::Tensor rays_o, @@ -45,7 +27,7 @@ std::vector ray_marching( const float cone_angle); torch::Tensor unpack_info( - const torch::Tensor packed_info); + const torch::Tensor packed_info, const int n_samples); torch::Tensor unpack_info_to_mask( const torch::Tensor packed_info, const int n_samples); @@ -69,32 +51,82 @@ torch::Tensor contract_inv( const torch::Tensor roi, const ContractionType type); -torch::Tensor rendering_alphas_backward( +std::vector ray_resampling( + torch::Tensor packed_info, + torch::Tensor starts, + torch::Tensor ends, torch::Tensor weights, - torch::Tensor grad_weights, + const int steps); + +torch::Tensor unpack_data( torch::Tensor packed_info, + torch::Tensor data, + int n_samples_per_ray); + +// cub implementations: parallel across samples +bool is_cub_available() { + return (bool) CUB_SUPPORTS_SCAN_BY_KEY(); +} +torch::Tensor transmittance_from_sigma_forward_cub( + torch::Tensor ray_indices, + torch::Tensor starts, + torch::Tensor ends, + torch::Tensor sigmas); +torch::Tensor transmittance_from_sigma_backward_cub( + torch::Tensor ray_indices, + torch::Tensor starts, + torch::Tensor ends, + torch::Tensor transmittance, + torch::Tensor transmittance_grad); +torch::Tensor transmittance_from_alpha_forward_cub( + torch::Tensor ray_indices, torch::Tensor alphas); +torch::Tensor transmittance_from_alpha_backward_cub( + torch::Tensor ray_indices, torch::Tensor alphas, - float early_stop_eps, - float alpha_thre); + torch::Tensor transmittance, + torch::Tensor transmittance_grad); -std::vector rendering_alphas_forward( +// naive implementations: parallel across rays +torch::Tensor transmittance_from_sigma_forward_naive( + torch::Tensor packed_info, + torch::Tensor starts, + torch::Tensor ends, + torch::Tensor sigmas); +torch::Tensor transmittance_from_sigma_backward_naive( + torch::Tensor packed_info, + torch::Tensor starts, + torch::Tensor ends, + torch::Tensor transmittance, + torch::Tensor transmittance_grad); +torch::Tensor transmittance_from_alpha_forward_naive( + torch::Tensor packed_info, + torch::Tensor alphas); +torch::Tensor transmittance_from_alpha_backward_naive( torch::Tensor packed_info, torch::Tensor alphas, - float early_stop_eps, - float alpha_thre, - bool compression); + torch::Tensor transmittance, + torch::Tensor transmittance_grad); -std::vector ray_resampling( +torch::Tensor weight_from_sigma_forward_naive( torch::Tensor packed_info, torch::Tensor starts, torch::Tensor ends, + torch::Tensor sigmas); +torch::Tensor weight_from_sigma_backward_naive( torch::Tensor weights, - const int steps); - -torch::Tensor unpack_data( + torch::Tensor grad_weights, torch::Tensor packed_info, - torch::Tensor data, - int n_samples_per_ray); + torch::Tensor starts, + torch::Tensor ends, + torch::Tensor sigmas); +torch::Tensor weight_from_alpha_forward_naive( + torch::Tensor packed_info, + torch::Tensor alphas); +torch::Tensor weight_from_alpha_backward_naive( + torch::Tensor weights, + torch::Tensor grad_weights, + torch::Tensor packed_info, + torch::Tensor alphas); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { @@ -115,10 +147,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("ray_resampling", &ray_resampling); // rendering - m.def("rendering_forward", &rendering_forward); - m.def("rendering_backward", &rendering_backward); - m.def("rendering_alphas_forward", &rendering_alphas_forward); - m.def("rendering_alphas_backward", &rendering_alphas_backward); + m.def("is_cub_available", is_cub_available); + m.def("transmittance_from_sigma_forward_cub", transmittance_from_sigma_forward_cub); + m.def("transmittance_from_sigma_backward_cub", transmittance_from_sigma_backward_cub); + m.def("transmittance_from_alpha_forward_cub", transmittance_from_alpha_forward_cub); + m.def("transmittance_from_alpha_backward_cub", transmittance_from_alpha_backward_cub); + + m.def("transmittance_from_sigma_forward_naive", transmittance_from_sigma_forward_naive); + m.def("transmittance_from_sigma_backward_naive", transmittance_from_sigma_backward_naive); + m.def("transmittance_from_alpha_forward_naive", transmittance_from_alpha_forward_naive); + m.def("transmittance_from_alpha_backward_naive", transmittance_from_alpha_backward_naive); + + m.def("weight_from_sigma_forward_naive", weight_from_sigma_forward_naive); + m.def("weight_from_sigma_backward_naive", weight_from_sigma_backward_naive); + m.def("weight_from_alpha_forward_naive", weight_from_alpha_forward_naive); + m.def("weight_from_alpha_backward_naive", weight_from_alpha_backward_naive); // pack & unpack m.def("unpack_data", &unpack_data); diff --git a/nerfacc/cuda/csrc/ray_marching.cu b/nerfacc/cuda/csrc/ray_marching.cu index 5cba94f9..708c4d4c 100644 --- a/nerfacc/cuda/csrc/ray_marching.cu +++ b/nerfacc/cuda/csrc/ray_marching.cu @@ -95,6 +95,7 @@ __global__ void ray_marching_kernel( // first round outputs int *num_steps, // second round outputs + int *ray_indices, float *t_starts, float *t_ends) { @@ -118,6 +119,7 @@ __global__ void ray_marching_kernel( int steps = packed_info[i * 2 + 1]; t_starts += base; t_ends += base; + ray_indices += base; } const float3 origin = make_float3(rays_o[0], rays_o[1], rays_o[2]); @@ -148,6 +150,7 @@ __global__ void ray_marching_kernel( { t_starts[j] = t0; t_ends[j] = t1; + ray_indices[j] = i; } ++j; // march to next sample @@ -245,6 +248,7 @@ std::vector ray_marching( nullptr, /* packed_info */ // outputs num_steps.data_ptr(), + nullptr, /* ray_indices */ nullptr, /* t_starts */ nullptr /* t_ends */); @@ -255,6 +259,7 @@ std::vector ray_marching( int total_steps = cum_steps[cum_steps.size(0) - 1].item(); torch::Tensor t_starts = torch::empty({total_steps, 1}, rays_o.options()); torch::Tensor t_ends = torch::empty({total_steps, 1}, rays_o.options()); + torch::Tensor ray_indices = torch::empty({total_steps}, cum_steps.options()); ray_marching_kernel<<>>( // rays @@ -274,10 +279,11 @@ std::vector ray_marching( packed_info.data_ptr(), // outputs nullptr, /* num_steps */ + ray_indices.data_ptr(), t_starts.data_ptr(), t_ends.data_ptr()); - return {packed_info, t_starts, t_ends}; + return {packed_info, ray_indices, t_starts, t_ends}; } // ---------------------------------------------------------------------------- diff --git a/nerfacc/cuda/csrc/render_transmittance.cu b/nerfacc/cuda/csrc/render_transmittance.cu new file mode 100644 index 00000000..55d05818 --- /dev/null +++ b/nerfacc/cuda/csrc/render_transmittance.cu @@ -0,0 +1,294 @@ +/* + * Copyright (c) 2022 Ruilong Li, UC Berkeley. + */ + +#include "include/helpers_cuda.h" + +__global__ void transmittance_from_sigma_forward_kernel( + const uint32_t n_rays, + // inputs + const int *packed_info, + const float *starts, + const float *ends, + const float *sigmas, + // outputs + float *transmittance) +{ + CUDA_GET_THREAD_ID(i, n_rays); + + // locate + const int base = packed_info[i * 2 + 0]; + const int steps = packed_info[i * 2 + 1]; + if (steps == 0) + return; + + starts += base; + ends += base; + sigmas += base; + transmittance += base; + + // accumulation + float cumsum = 0.0f; + for (int j = 0; j < steps; ++j) + { + transmittance[j] = __expf(-cumsum); + cumsum += sigmas[j] * (ends[j] - starts[j]); + } + + // // another way to impl: + // float T = 1.f; + // for (int j = 0; j < steps; ++j) + // { + // const float delta = ends[j] - starts[j]; + // const float alpha = 1.f - __expf(-sigmas[j] * delta); + // transmittance[j] = T; + // T *= (1.f - alpha); + // } + return; +} + +__global__ void transmittance_from_sigma_backward_kernel( + const uint32_t n_rays, + // inputs + const int *packed_info, + const float *starts, + const float *ends, + const float *transmittance, + const float *transmittance_grad, + // outputs + float *sigmas_grad) +{ + CUDA_GET_THREAD_ID(i, n_rays); + + // locate + const int base = packed_info[i * 2 + 0]; + const int steps = packed_info[i * 2 + 1]; + if (steps == 0) + return; + + transmittance += base; + transmittance_grad += base; + starts += base; + ends += base; + sigmas_grad += base; + + // accumulation + float cumsum = 0.0f; + for (int j = steps - 1; j >= 0; --j) + { + sigmas_grad[j] = cumsum * (ends[j] - starts[j]); + cumsum += -transmittance_grad[j] * transmittance[j]; + } + return; +} + +__global__ void transmittance_from_alpha_forward_kernel( + const uint32_t n_rays, + // inputs + const int *packed_info, + const float *alphas, + // outputs + float *transmittance) +{ + CUDA_GET_THREAD_ID(i, n_rays); + + // locate + const int base = packed_info[i * 2 + 0]; + const int steps = packed_info[i * 2 + 1]; + if (steps == 0) + return; + + alphas += base; + transmittance += base; + + // accumulation + float T = 1.0f; + for (int j = 0; j < steps; ++j) + { + transmittance[j] = T; + T *= (1.0f - alphas[j]); + } + return; +} + +__global__ void transmittance_from_alpha_backward_kernel( + const uint32_t n_rays, + // inputs + const int *packed_info, + const float *alphas, + const float *transmittance, + const float *transmittance_grad, + // outputs + float *alphas_grad) +{ + CUDA_GET_THREAD_ID(i, n_rays); + + // locate + const int base = packed_info[i * 2 + 0]; + const int steps = packed_info[i * 2 + 1]; + if (steps == 0) + return; + + alphas += base; + transmittance += base; + transmittance_grad += base; + alphas_grad += base; + + // accumulation + float cumsum = 0.0f; + for (int j = steps - 1; j >= 0; --j) + { + alphas_grad[j] = cumsum / fmax(1.0f - alphas[j], 1e-10f); + cumsum += -transmittance_grad[j] * transmittance[j]; + } + return; +} + +torch::Tensor transmittance_from_sigma_forward_naive( + torch::Tensor packed_info, + torch::Tensor starts, + torch::Tensor ends, + torch::Tensor sigmas) +{ + DEVICE_GUARD(packed_info); + CHECK_INPUT(packed_info); + CHECK_INPUT(starts); + CHECK_INPUT(ends); + CHECK_INPUT(sigmas); + TORCH_CHECK(packed_info.ndimension() == 2); + TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1); + TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1); + TORCH_CHECK(sigmas.ndimension() == 2 & sigmas.size(1) == 1); + + const uint32_t n_samples = sigmas.size(0); + const uint32_t n_rays = packed_info.size(0); + + const int threads = 256; + const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads); + + // outputs + torch::Tensor transmittance = torch::empty_like(sigmas); + + // parallel across rays + transmittance_from_sigma_forward_kernel<<< + blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( + n_rays, + // inputs + packed_info.data_ptr(), + starts.data_ptr(), + ends.data_ptr(), + sigmas.data_ptr(), + // outputs + transmittance.data_ptr()); + return transmittance; +} + +torch::Tensor transmittance_from_sigma_backward_naive( + torch::Tensor packed_info, + torch::Tensor starts, + torch::Tensor ends, + torch::Tensor transmittance, + torch::Tensor transmittance_grad) +{ + DEVICE_GUARD(packed_info); + CHECK_INPUT(packed_info); + CHECK_INPUT(starts); + CHECK_INPUT(ends); + CHECK_INPUT(transmittance); + CHECK_INPUT(transmittance_grad); + TORCH_CHECK(packed_info.ndimension() == 2); + TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1); + TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1); + TORCH_CHECK(transmittance.ndimension() == 2 & transmittance.size(1) == 1); + TORCH_CHECK(transmittance_grad.ndimension() == 2 & transmittance_grad.size(1) == 1); + + const uint32_t n_samples = transmittance.size(0); + const uint32_t n_rays = packed_info.size(0); + + const int threads = 256; + const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads); + + // outputs + torch::Tensor sigmas_grad = torch::empty_like(transmittance); + + // parallel across rays + transmittance_from_sigma_backward_kernel<<< + blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( + n_rays, + // inputs + packed_info.data_ptr(), + starts.data_ptr(), + ends.data_ptr(), + transmittance.data_ptr(), + transmittance_grad.data_ptr(), + // outputs + sigmas_grad.data_ptr()); + return sigmas_grad; +} + +torch::Tensor transmittance_from_alpha_forward_naive( + torch::Tensor packed_info, torch::Tensor alphas) +{ + DEVICE_GUARD(packed_info); + CHECK_INPUT(packed_info); + CHECK_INPUT(alphas); + TORCH_CHECK(alphas.ndimension() == 2 & alphas.size(1) == 1); + TORCH_CHECK(packed_info.ndimension() == 2); + + const uint32_t n_samples = alphas.size(0); + const uint32_t n_rays = packed_info.size(0); + + const int threads = 256; + const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads); + + // outputs + torch::Tensor transmittance = torch::empty_like(alphas); + + // parallel across rays + transmittance_from_alpha_forward_kernel<<< + blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( + n_rays, + // inputs + packed_info.data_ptr(), + alphas.data_ptr(), + // outputs + transmittance.data_ptr()); + return transmittance; +} + +torch::Tensor transmittance_from_alpha_backward_naive( + torch::Tensor packed_info, + torch::Tensor alphas, + torch::Tensor transmittance, + torch::Tensor transmittance_grad) +{ + DEVICE_GUARD(packed_info); + CHECK_INPUT(packed_info); + CHECK_INPUT(transmittance); + CHECK_INPUT(transmittance_grad); + TORCH_CHECK(packed_info.ndimension() == 2); + TORCH_CHECK(transmittance.ndimension() == 2 & transmittance.size(1) == 1); + TORCH_CHECK(transmittance_grad.ndimension() == 2 & transmittance_grad.size(1) == 1); + + const uint32_t n_samples = transmittance.size(0); + const uint32_t n_rays = packed_info.size(0); + + const int threads = 256; + const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads); + + // outputs + torch::Tensor alphas_grad = torch::empty_like(alphas); + + // parallel across rays + transmittance_from_alpha_backward_kernel<<< + blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( + n_rays, + // inputs + packed_info.data_ptr(), + alphas.data_ptr(), + transmittance.data_ptr(), + transmittance_grad.data_ptr(), + // outputs + alphas_grad.data_ptr()); + return alphas_grad; +} diff --git a/nerfacc/cuda/csrc/render_transmittance_cub.cu b/nerfacc/cuda/csrc/render_transmittance_cub.cu new file mode 100644 index 00000000..ae84377f --- /dev/null +++ b/nerfacc/cuda/csrc/render_transmittance_cub.cu @@ -0,0 +1,166 @@ +/* + * Copyright (c) 2022 Ruilong Li, UC Berkeley. + */ +// CUB is supported in CUDA >= 11.0 +// ExclusiveScanByKey is supported in CUB >= 1.15.0 (CUDA >= 11.6) +// See: https://github.com/NVIDIA/cub/tree/main#releases +#include "include/helpers_cuda.h" +#if CUB_SUPPORTS_SCAN_BY_KEY() +#include +#endif + +struct Product +{ + template + __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const { return a * b; } +}; + +#if CUB_SUPPORTS_SCAN_BY_KEY() +template +inline void exclusive_sum_by_key( + KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) +{ + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub ExclusiveSumByKey does not support more than INT_MAX elements"); + CUB_WRAPPER(cub::DeviceScan::ExclusiveSumByKey, keys, input, output, + num_items, cub::Equality(), at::cuda::getCurrentCUDAStream()); +} + +template +inline void exclusive_prod_by_key( + KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) +{ + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub ExclusiveScanByKey does not support more than INT_MAX elements"); + CUB_WRAPPER(cub::DeviceScan::ExclusiveScanByKey, keys, input, output, Product(), 1.0f, + num_items, cub::Equality(), at::cuda::getCurrentCUDAStream()); +} +#endif + +torch::Tensor transmittance_from_sigma_forward_cub( + torch::Tensor ray_indices, + torch::Tensor starts, + torch::Tensor ends, + torch::Tensor sigmas) +{ + DEVICE_GUARD(ray_indices); + CHECK_INPUT(ray_indices); + CHECK_INPUT(starts); + CHECK_INPUT(ends); + CHECK_INPUT(sigmas); + TORCH_CHECK(ray_indices.ndimension() == 1); + TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1); + TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1); + TORCH_CHECK(sigmas.ndimension() == 2 & sigmas.size(1) == 1); + + const uint32_t n_samples = sigmas.size(0); + + // parallel across samples + torch::Tensor sigmas_dt = sigmas * (ends - starts); + torch::Tensor sigmas_dt_cumsum = torch::empty_like(sigmas); +#if CUB_SUPPORTS_SCAN_BY_KEY() + exclusive_sum_by_key( + ray_indices.data_ptr(), + sigmas_dt.data_ptr(), + sigmas_dt_cumsum.data_ptr(), + n_samples); +#else + std::runtime_error("CUB functions are only supported in CUDA >= 11.6."); +#endif + torch::Tensor transmittance = (-sigmas_dt_cumsum).exp(); + return transmittance; +} + +torch::Tensor transmittance_from_sigma_backward_cub( + torch::Tensor ray_indices, + torch::Tensor starts, + torch::Tensor ends, + torch::Tensor transmittance, + torch::Tensor transmittance_grad) +{ + DEVICE_GUARD(ray_indices); + CHECK_INPUT(ray_indices); + CHECK_INPUT(starts); + CHECK_INPUT(ends); + CHECK_INPUT(transmittance); + CHECK_INPUT(transmittance_grad); + TORCH_CHECK(ray_indices.ndimension() == 1); + TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1); + TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1); + TORCH_CHECK(transmittance.ndimension() == 2 & transmittance.size(1) == 1); + TORCH_CHECK(transmittance_grad.ndimension() == 2 & transmittance_grad.size(1) == 1); + + const uint32_t n_samples = transmittance.size(0); + + // parallel across samples + torch::Tensor sigmas_dt_cumsum_grad = -transmittance_grad * transmittance; + torch::Tensor sigmas_dt_grad = torch::empty_like(transmittance_grad); +#if CUB_SUPPORTS_SCAN_BY_KEY() + exclusive_sum_by_key( + thrust::make_reverse_iterator(ray_indices.data_ptr() + n_samples), + thrust::make_reverse_iterator(sigmas_dt_cumsum_grad.data_ptr() + n_samples), + thrust::make_reverse_iterator(sigmas_dt_grad.data_ptr() + n_samples), + n_samples); +#else + std::runtime_error("CUB functions are only supported in CUDA >= 11.6."); +#endif + torch::Tensor sigmas_grad = sigmas_dt_grad * (ends - starts); + return sigmas_grad; +} + +torch::Tensor transmittance_from_alpha_forward_cub( + torch::Tensor ray_indices, torch::Tensor alphas) +{ + DEVICE_GUARD(ray_indices); + CHECK_INPUT(ray_indices); + CHECK_INPUT(alphas); + TORCH_CHECK(alphas.ndimension() == 2 & alphas.size(1) == 1); + TORCH_CHECK(ray_indices.ndimension() == 1); + + const uint32_t n_samples = alphas.size(0); + + // parallel across samples + torch::Tensor transmittance = torch::empty_like(alphas); +#if CUB_SUPPORTS_SCAN_BY_KEY() + exclusive_prod_by_key( + ray_indices.data_ptr(), + (1.0f - alphas).data_ptr(), + transmittance.data_ptr(), + n_samples); +#else + std::runtime_error("CUB functions are only supported in CUDA >= 11.6."); +#endif + return transmittance; +} + +torch::Tensor transmittance_from_alpha_backward_cub( + torch::Tensor ray_indices, + torch::Tensor alphas, + torch::Tensor transmittance, + torch::Tensor transmittance_grad) +{ + DEVICE_GUARD(ray_indices); + CHECK_INPUT(ray_indices); + CHECK_INPUT(transmittance); + CHECK_INPUT(transmittance_grad); + TORCH_CHECK(ray_indices.ndimension() == 1); + TORCH_CHECK(transmittance.ndimension() == 2 & transmittance.size(1) == 1); + TORCH_CHECK(transmittance_grad.ndimension() == 2 & transmittance_grad.size(1) == 1); + + const uint32_t n_samples = transmittance.size(0); + + // parallel across samples + torch::Tensor sigmas_dt_cumsum_grad = -transmittance_grad * transmittance; + torch::Tensor sigmas_dt_grad = torch::empty_like(transmittance_grad); +#if CUB_SUPPORTS_SCAN_BY_KEY() + exclusive_sum_by_key( + thrust::make_reverse_iterator(ray_indices.data_ptr() + n_samples), + thrust::make_reverse_iterator(sigmas_dt_cumsum_grad.data_ptr() + n_samples), + thrust::make_reverse_iterator(sigmas_dt_grad.data_ptr() + n_samples), + n_samples); +#else + std::runtime_error("CUB functions are only supported in CUDA >= 11.6."); +#endif + torch::Tensor alphas_grad = sigmas_dt_grad / (1.0f - alphas).clamp_min(1e-10f); + return alphas_grad; +} diff --git a/nerfacc/cuda/csrc/render_weight.cu b/nerfacc/cuda/csrc/render_weight.cu new file mode 100644 index 00000000..27a22a16 --- /dev/null +++ b/nerfacc/cuda/csrc/render_weight.cu @@ -0,0 +1,307 @@ +/* + * Copyright (c) 2022 Ruilong Li, UC Berkeley. + */ + +#include "include/helpers_cuda.h" + +__global__ void weight_from_sigma_forward_kernel( + const uint32_t n_rays, + const int *packed_info, + const float *starts, + const float *ends, + const float *sigmas, + // outputs + float *weights) +{ + CUDA_GET_THREAD_ID(i, n_rays); + + // locate + const int base = packed_info[i * 2 + 0]; + const int steps = packed_info[i * 2 + 1]; + if (steps == 0) + return; + + starts += base; + ends += base; + sigmas += base; + weights += base; + + // accumulation + float T = 1.f; + for (int j = 0; j < steps; ++j) + { + const float delta = ends[j] - starts[j]; + const float alpha = 1.f - __expf(-sigmas[j] * delta); + weights[j] = alpha * T; + T *= (1.f - alpha); + } + return; +} + +__global__ void weight_from_sigma_backward_kernel( + const uint32_t n_rays, + const int *packed_info, + const float *starts, + const float *ends, + const float *sigmas, + const float *weights, + const float *grad_weights, + // outputs + float *grad_sigmas) +{ + CUDA_GET_THREAD_ID(i, n_rays); + + // locate + const int base = packed_info[i * 2 + 0]; + const int steps = packed_info[i * 2 + 1]; + if (steps == 0) + return; + + starts += base; + ends += base; + sigmas += base; + weights += base; + grad_weights += base; + grad_sigmas += base; + + float accum = 0; + for (int j = 0; j < steps; ++j) + { + accum += grad_weights[j] * weights[j]; + } + + // accumulation + float T = 1.f; + for (int j = 0; j < steps; ++j) + { + const float delta = ends[j] - starts[j]; + const float alpha = 1.f - __expf(-sigmas[j] * delta); + grad_sigmas[j] = (grad_weights[j] * T - accum) * delta; + accum -= grad_weights[j] * weights[j]; + T *= (1.f - alpha); + } + return; +} + +__global__ void weight_from_alpha_forward_kernel( + const uint32_t n_rays, + const int *packed_info, + const float *alphas, + // outputs + float *weights) +{ + CUDA_GET_THREAD_ID(i, n_rays); + + // locate + const int base = packed_info[i * 2 + 0]; + const int steps = packed_info[i * 2 + 1]; + if (steps == 0) + return; + + alphas += base; + weights += base; + + // accumulation + float T = 1.f; + for (int j = 0; j < steps; ++j) + { + const float alpha = alphas[j]; + weights[j] = alpha * T; + T *= (1.f - alpha); + } + return; +} + +__global__ void weight_from_alpha_backward_kernel( + const uint32_t n_rays, + const int *packed_info, + const float *alphas, + const float *weights, + const float *grad_weights, + // outputs + float *grad_alphas) +{ + CUDA_GET_THREAD_ID(i, n_rays); + + // locate + const int base = packed_info[i * 2 + 0]; + const int steps = packed_info[i * 2 + 1]; + if (steps == 0) + return; + + alphas += base; + weights += base; + grad_weights += base; + grad_alphas += base; + + float accum = 0; + for (int j = 0; j < steps; ++j) + { + accum += grad_weights[j] * weights[j]; + } + + // accumulation + float T = 1.f; + for (int j = 0; j < steps; ++j) + { + const float alpha = alphas[j]; + grad_alphas[j] = (grad_weights[j] * T - accum) / fmaxf(1.f - alpha, 1e-10f); + accum -= grad_weights[j] * weights[j]; + T *= (1.f - alpha); + } + return; +} + +torch::Tensor weight_from_sigma_forward_naive( + torch::Tensor packed_info, + torch::Tensor starts, + torch::Tensor ends, + torch::Tensor sigmas) +{ + DEVICE_GUARD(packed_info); + CHECK_INPUT(packed_info); + CHECK_INPUT(starts); + CHECK_INPUT(ends); + CHECK_INPUT(sigmas); + + TORCH_CHECK(packed_info.ndimension() == 2); + TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1); + TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1); + TORCH_CHECK(sigmas.ndimension() == 2 & sigmas.size(1) == 1); + + const uint32_t n_samples = sigmas.size(0); + const uint32_t n_rays = packed_info.size(0); + + const int threads = 256; + const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads); + + // outputs + torch::Tensor weights = torch::empty_like(sigmas); + + weight_from_sigma_forward_kernel<<< + blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( + n_rays, + // inputs + packed_info.data_ptr(), + starts.data_ptr(), + ends.data_ptr(), + sigmas.data_ptr(), + // outputs + weights.data_ptr()); + return weights; +} + +torch::Tensor weight_from_sigma_backward_naive( + torch::Tensor weights, + torch::Tensor grad_weights, + torch::Tensor packed_info, + torch::Tensor starts, + torch::Tensor ends, + torch::Tensor sigmas) +{ + DEVICE_GUARD(packed_info); + CHECK_INPUT(weights); + CHECK_INPUT(grad_weights); + CHECK_INPUT(packed_info); + CHECK_INPUT(starts); + CHECK_INPUT(ends); + CHECK_INPUT(sigmas); + + TORCH_CHECK(packed_info.ndimension() == 2); + TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1); + TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1); + TORCH_CHECK(sigmas.ndimension() == 2 & sigmas.size(1) == 1); + TORCH_CHECK(weights.ndimension() == 2 & weights.size(1) == 1); + TORCH_CHECK(grad_weights.ndimension() == 2 & grad_weights.size(1) == 1); + + const uint32_t n_samples = sigmas.size(0); + const uint32_t n_rays = packed_info.size(0); + + const int threads = 256; + const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads); + + // outputs + torch::Tensor grad_sigmas = torch::empty_like(sigmas); + + weight_from_sigma_backward_kernel<<< + blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( + n_rays, + // inputs + packed_info.data_ptr(), + starts.data_ptr(), + ends.data_ptr(), + sigmas.data_ptr(), + weights.data_ptr(), + grad_weights.data_ptr(), + // outputs + grad_sigmas.data_ptr()); + + return grad_sigmas; +} + +torch::Tensor weight_from_alpha_forward_naive( + torch::Tensor packed_info, torch::Tensor alphas) +{ + DEVICE_GUARD(packed_info); + CHECK_INPUT(packed_info); + CHECK_INPUT(alphas); + TORCH_CHECK(packed_info.ndimension() == 2); + TORCH_CHECK(alphas.ndimension() == 2 & alphas.size(1) == 1); + + const uint32_t n_samples = alphas.size(0); + const uint32_t n_rays = packed_info.size(0); + + const int threads = 256; + const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads); + + // outputs + torch::Tensor weights = torch::empty_like(alphas); + + weight_from_alpha_forward_kernel<<< + blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( + n_rays, + // inputs + packed_info.data_ptr(), + alphas.data_ptr(), + // outputs + weights.data_ptr()); + return weights; +} + +torch::Tensor weight_from_alpha_backward_naive( + torch::Tensor weights, + torch::Tensor grad_weights, + torch::Tensor packed_info, + torch::Tensor alphas) +{ + DEVICE_GUARD(packed_info); + CHECK_INPUT(packed_info); + CHECK_INPUT(alphas); + CHECK_INPUT(weights); + CHECK_INPUT(grad_weights); + TORCH_CHECK(packed_info.ndimension() == 2); + TORCH_CHECK(alphas.ndimension() == 2 & alphas.size(1) == 1); + TORCH_CHECK(weights.ndimension() == 2 & weights.size(1) == 1); + TORCH_CHECK(grad_weights.ndimension() == 2 & grad_weights.size(1) == 1); + + const uint32_t n_samples = alphas.size(0); + const uint32_t n_rays = packed_info.size(0); + + const int threads = 256; + const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads); + + // outputs + torch::Tensor grad_alphas = torch::empty_like(alphas); + + weight_from_alpha_backward_kernel<<< + blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( + n_rays, + // inputs + packed_info.data_ptr(), + alphas.data_ptr(), + weights.data_ptr(), + grad_weights.data_ptr(), + // outputs + grad_alphas.data_ptr()); + return grad_alphas; +} diff --git a/nerfacc/cuda/csrc/rendering.cu b/nerfacc/cuda/csrc/rendering.cu deleted file mode 100644 index 043d2b4a..00000000 --- a/nerfacc/cuda/csrc/rendering.cu +++ /dev/null @@ -1,439 +0,0 @@ -/* - * Copyright (c) 2022 Ruilong Li, UC Berkeley. - */ - -#include "include/helpers_cuda.h" - -template -__global__ void rendering_forward_kernel( - const uint32_t n_rays, - const int *packed_info, // input ray & point indices. - const scalar_t *starts, // input start t - const scalar_t *ends, // input end t - const scalar_t *sigmas, // input density after activation - const scalar_t *alphas, // input alpha (opacity) values. - const scalar_t early_stop_eps, // transmittance threshold for early stop - const scalar_t alpha_thre, // alpha threshold for emtpy space - // outputs: should be all-zero initialized - int *num_steps, // the number of valid steps for each ray - scalar_t *weights, // the number rendering weights for each sample - bool *compact_selector // the samples that we needs to compute the gradients -) -{ - CUDA_GET_THREAD_ID(i, n_rays); - - // locate - const int base = packed_info[i * 2 + 0]; // point idx start. - const int steps = packed_info[i * 2 + 1]; // point idx shift. - if (steps == 0) - return; - - if (alphas != nullptr) - { - // rendering with alpha - alphas += base; - } - else - { - // rendering with density - starts += base; - ends += base; - sigmas += base; - } - - if (num_steps != nullptr) - { - num_steps += i; - } - if (weights != nullptr) - { - weights += base; - } - if (compact_selector != nullptr) - { - compact_selector += base; - } - - // accumulated rendering - scalar_t T = 1.f; - int cnt = 0; - for (int j = 0; j < steps; ++j) - { - if (T < early_stop_eps) - { - break; - } - scalar_t alpha; - if (alphas != nullptr) - { - // rendering with alpha - alpha = alphas[j]; - } - else - { - // rendering with density - scalar_t delta = ends[j] - starts[j]; - alpha = 1.f - __expf(-sigmas[j] * delta); - } - if (alpha < alpha_thre) - { - // empty space - continue; - } - const scalar_t weight = alpha * T; - T *= (1.f - alpha); - if (weights != nullptr) - { - weights[j] = weight; - } - if (compact_selector != nullptr) - { - compact_selector[j] = true; - } - cnt += 1; - } - if (num_steps != nullptr) - { - *num_steps = cnt; - } - return; -} - -template -__global__ void rendering_backward_kernel( - const uint32_t n_rays, - const int *packed_info, // input ray & point indices. - const scalar_t *starts, // input start t - const scalar_t *ends, // input end t - const scalar_t *sigmas, // input density after activation - const scalar_t *alphas, // input alpha (opacity) values. - const scalar_t early_stop_eps, // transmittance threshold for early stop - const scalar_t alpha_thre, // alpha threshold for emtpy space - const scalar_t *weights, // forward output - const scalar_t *grad_weights, // input gradients - // if alphas was given, we compute the gradients for alphas. - // otherwise, we compute the gradients for sigmas. - scalar_t *grad_sigmas, // output gradients - scalar_t *grad_alphas // output gradients -) -{ - CUDA_GET_THREAD_ID(i, n_rays); - - // locate - const int base = packed_info[i * 2 + 0]; // point idx start. - const int steps = packed_info[i * 2 + 1]; // point idx shift. - if (steps == 0) - return; - - if (alphas != nullptr) - { - // rendering with alpha - alphas += base; - grad_alphas += base; - } - else - { - // rendering with density - starts += base; - ends += base; - sigmas += base; - grad_sigmas += base; - } - - weights += base; - grad_weights += base; - - scalar_t accum = 0; - for (int j = 0; j < steps; ++j) - { - accum += grad_weights[j] * weights[j]; - } - - // backward of accumulated rendering - scalar_t T = 1.f; - for (int j = 0; j < steps; ++j) - { - if (T < early_stop_eps) - { - break; - } - scalar_t alpha; - if (alphas != nullptr) - { - // rendering with alpha - alpha = alphas[j]; - if (alpha < alpha_thre) - { - // empty space - continue; - } - grad_alphas[j] = (grad_weights[j] * T - accum) / fmaxf(1.f - alpha, 1e-10f); - } - else - { - // rendering with density - scalar_t delta = ends[j] - starts[j]; - alpha = 1.f - __expf(-sigmas[j] * delta); - if (alpha < alpha_thre) - { - // empty space - continue; - } - grad_sigmas[j] = (grad_weights[j] * T - accum) * delta; - } - - accum -= grad_weights[j] * weights[j]; - T *= (1.f - alpha); - } -} - -std::vector rendering_forward( - torch::Tensor packed_info, - torch::Tensor starts, - torch::Tensor ends, - torch::Tensor sigmas, - float early_stop_eps, - float alpha_thre, - bool compression) -{ - DEVICE_GUARD(packed_info); - - CHECK_INPUT(packed_info); - CHECK_INPUT(starts); - CHECK_INPUT(ends); - CHECK_INPUT(sigmas); - - TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 2); - TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1); - TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1); - TORCH_CHECK(sigmas.ndimension() == 2 & sigmas.size(1) == 1); - - const uint32_t n_rays = packed_info.size(0); - const uint32_t n_samples = sigmas.size(0); - - const int threads = 256; - const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads); - - if (compression) - { - // compress the samples to get rid of invisible ones. - torch::Tensor num_steps = torch::zeros({n_rays}, packed_info.options()); - torch::Tensor compact_selector = torch::zeros( - {n_samples}, sigmas.options().dtype(torch::kBool)); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - sigmas.scalar_type(), - "rendering_forward", - ([&] - { rendering_forward_kernel<<>>( - n_rays, - // inputs - packed_info.data_ptr(), - starts.data_ptr(), - ends.data_ptr(), - sigmas.data_ptr(), - nullptr, // alphas - early_stop_eps, - alpha_thre, - // outputs - num_steps.data_ptr(), - nullptr, - compact_selector.data_ptr()); })); - - torch::Tensor cum_steps = num_steps.cumsum(0, torch::kInt32); - torch::Tensor compact_packed_info = torch::stack({cum_steps - num_steps, num_steps}, 1); - return {compact_packed_info, compact_selector}; - } - else - { - // just do the forward rendering. - torch::Tensor weights = torch::zeros({n_samples}, sigmas.options()); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - sigmas.scalar_type(), - "rendering_forward", - ([&] - { rendering_forward_kernel<<>>( - n_rays, - // inputs - packed_info.data_ptr(), - starts.data_ptr(), - ends.data_ptr(), - sigmas.data_ptr(), - nullptr, // alphas - early_stop_eps, - alpha_thre, - // outputs - nullptr, - weights.data_ptr(), - nullptr); })); - - return {weights}; - } -} - -torch::Tensor rendering_backward( - torch::Tensor weights, - torch::Tensor grad_weights, - torch::Tensor packed_info, - torch::Tensor starts, - torch::Tensor ends, - torch::Tensor sigmas, - float early_stop_eps, - float alpha_thre) -{ - DEVICE_GUARD(packed_info); - const uint32_t n_rays = packed_info.size(0); - const uint32_t n_samples = sigmas.size(0); - - const int threads = 256; - const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads); - - // outputs - torch::Tensor grad_sigmas = torch::zeros(sigmas.sizes(), sigmas.options()); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - sigmas.scalar_type(), - "rendering_backward", - ([&] - { rendering_backward_kernel<<>>( - n_rays, - // inputs - packed_info.data_ptr(), - starts.data_ptr(), - ends.data_ptr(), - sigmas.data_ptr(), - nullptr, // alphas - early_stop_eps, - alpha_thre, - weights.data_ptr(), - grad_weights.data_ptr(), - // outputs - grad_sigmas.data_ptr(), - nullptr // alphas gradients - ); })); - - return grad_sigmas; -} - -// -- rendering with alphas -- // - -std::vector rendering_alphas_forward( - torch::Tensor packed_info, - torch::Tensor alphas, - float early_stop_eps, - float alpha_thre, - bool compression) -{ - DEVICE_GUARD(packed_info); - - CHECK_INPUT(packed_info); - CHECK_INPUT(alphas); - - TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 2); - TORCH_CHECK(alphas.ndimension() == 2 & alphas.size(1) == 1); - - const uint32_t n_rays = packed_info.size(0); - const uint32_t n_samples = alphas.size(0); - - const int threads = 256; - const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads); - - if (compression) - { - // compress the samples to get rid of invisible ones. - torch::Tensor num_steps = torch::zeros({n_rays}, packed_info.options()); - torch::Tensor compact_selector = torch::zeros( - {n_samples}, alphas.options().dtype(torch::kBool)); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - alphas.scalar_type(), - "rendering_alphas_forward", - ([&] - { rendering_forward_kernel<<>>( - n_rays, - // inputs - packed_info.data_ptr(), - nullptr, // starts - nullptr, // ends - nullptr, // sigmas - alphas.data_ptr(), - early_stop_eps, - alpha_thre, - // outputs - num_steps.data_ptr(), - nullptr, - compact_selector.data_ptr()); })); - - torch::Tensor cum_steps = num_steps.cumsum(0, torch::kInt32); - torch::Tensor compact_packed_info = torch::stack({cum_steps - num_steps, num_steps}, 1); - return {compact_selector, compact_packed_info}; - } - else - { - // just do the forward rendering. - torch::Tensor weights = torch::zeros({n_samples}, alphas.options()); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - alphas.scalar_type(), - "rendering_forward", - ([&] - { rendering_forward_kernel<<>>( - n_rays, - // inputs - packed_info.data_ptr(), - nullptr, // starts - nullptr, // ends - nullptr, // sigmas - alphas.data_ptr(), - early_stop_eps, - alpha_thre, - // outputs - nullptr, - weights.data_ptr(), - nullptr); })); - - return {weights}; - } -} - -torch::Tensor rendering_alphas_backward( - torch::Tensor weights, - torch::Tensor grad_weights, - torch::Tensor packed_info, - torch::Tensor alphas, - float early_stop_eps, - float alpha_thre) -{ - DEVICE_GUARD(packed_info); - const uint32_t n_rays = packed_info.size(0); - const uint32_t n_samples = alphas.size(0); - - const int threads = 256; - const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads); - - // outputs - torch::Tensor grad_alphas = torch::zeros(alphas.sizes(), alphas.options()); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - alphas.scalar_type(), - "rendering_alphas_backward", - ([&] - { rendering_backward_kernel<<>>( - n_rays, - // inputs - packed_info.data_ptr(), - nullptr, // starts - nullptr, // ends - nullptr, // sigmas - alphas.data_ptr(), - early_stop_eps, - alpha_thre, - weights.data_ptr(), - grad_weights.data_ptr(), - // outputs - nullptr, // sigma gradients - grad_alphas.data_ptr()); })); - - return grad_alphas; -} diff --git a/nerfacc/pack.py b/nerfacc/pack.py index 2c0d492e..b4cd8adc 100644 --- a/nerfacc/pack.py +++ b/nerfacc/pack.py @@ -44,7 +44,41 @@ def pack_data(data: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: @torch.no_grad() -def unpack_info(packed_info: Tensor) -> Tensor: +def pack_info(ray_indices: Tensor, n_rays: int = None) -> Tensor: + """Pack `ray_indices` to `packed_info`. Useful for converting per sample data to per ray data. + + Note: + this function is not differentiable to any inputs. + + Args: + ray_indices: Ray index of each sample. LongTensor with shape (n_sample). + + Returns: + packed_info: Stores information on which samples belong to the same ray. \ + See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2). + """ + assert ( + ray_indices.dim() == 1 + ), "ray_indices must be a 1D tensor with shape (n_samples)." + if ray_indices.is_cuda: + ray_indices = ray_indices.contiguous().int() + device = ray_indices.device + if n_rays is None: + n_rays = int(ray_indices.max()) + 1 + # else: + # assert n_rays > ray_indices.max() + src = torch.ones_like(ray_indices) + num_steps = torch.zeros((n_rays,), device=device, dtype=torch.int) + num_steps.scatter_add_(0, ray_indices.long(), src) + cum_steps = num_steps.cumsum(dim=0, dtype=torch.int) + packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=-1) + else: + raise NotImplementedError("Only support cuda inputs.") + return packed_info.int() + + +@torch.no_grad() +def unpack_info(packed_info: Tensor, n_samples: int) -> Tensor: """Unpack `packed_info` to `ray_indices`. Useful for converting per ray data to per sample data. Note: @@ -53,6 +87,7 @@ def unpack_info(packed_info: Tensor) -> Tensor: Args: packed_info: Stores information on which samples belong to the same ray. \ See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2). + n_samples: Total number of samples. Returns: Ray index of each sample. LongTensor with shape (n_sample). @@ -71,7 +106,7 @@ def unpack_info(packed_info: Tensor) -> Tensor: # torch.Size([128, 2]) torch.Size([115200, 1]) torch.Size([115200, 1]) print(packed_info.shape, t_starts.shape, t_ends.shape) # Unpack per-ray info to per-sample info. - ray_indices = unpack_info(packed_info) + ray_indices = unpack_info(packed_info, t_starts.shape[0]) # torch.Size([115200]) torch.int64 print(ray_indices.shape, ray_indices.dtype) @@ -80,7 +115,7 @@ def unpack_info(packed_info: Tensor) -> Tensor: packed_info.dim() == 2 and packed_info.shape[-1] == 2 ), "packed_info must be a 2D tensor with shape (n_rays, 2)." if packed_info.is_cuda: - ray_indices = _C.unpack_info(packed_info.contiguous().int()) + ray_indices = _C.unpack_info(packed_info.contiguous().int(), n_samples) else: raise NotImplementedError("Only support cuda inputs.") return ray_indices.long() diff --git a/nerfacc/ray_marching.py b/nerfacc/ray_marching.py index 1f5eecd2..6cd3db99 100644 --- a/nerfacc/ray_marching.py +++ b/nerfacc/ray_marching.py @@ -7,7 +7,6 @@ from .contraction import ContractionType from .grid import Grid from .intersection import ray_aabb_intersect -from .pack import unpack_info from .vol_rendering import render_visibility @@ -82,10 +81,7 @@ def ray_marching( Returns: A tuple of tensors. - - **packed_info**: Stores information on which samples belong to the same ray. \ - Tensor with shape (n_rays, 2). The first column stores the index of the \ - first sample of each ray. The second column stores the number of samples \ - of each ray. + - **ray_indices**: Ray index of each sample. IntTensor with shape (n_samples). - **t_starts**: Per-sample start distance. Tensor with shape (n_samples, 1). - **t_ends**: Per-sample end distance. Tensor with shape (n_samples, 1). @@ -103,32 +99,31 @@ def ray_marching( rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True) # Ray marching with near far plane. - packed_info, t_starts, t_ends = ray_marching( + ray_indices, t_starts, t_ends = ray_marching( rays_o, rays_d, near_plane=0.1, far_plane=1.0, render_step_size=1e-3 ) # Ray marching with aabb. scene_aabb = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0], device=device) - packed_info, t_starts, t_ends = ray_marching( + ray_indices, t_starts, t_ends = ray_marching( rays_o, rays_d, scene_aabb=scene_aabb, render_step_size=1e-3 ) # Ray marching with per-ray t_min and t_max. t_min = torch.zeros((batch_size,), device=device) t_max = torch.ones((batch_size,), device=device) - packed_info, t_starts, t_ends = ray_marching( + ray_indices, t_starts, t_ends = ray_marching( rays_o, rays_d, t_min=t_min, t_max=t_max, render_step_size=1e-3 ) # Ray marching with aabb and skip areas based on occupancy grid. scene_aabb = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0], device=device) grid = OccupancyGrid(roi_aabb=[0.0, 0.0, 0.0, 0.5, 0.5, 0.5]).to(device) - packed_info, t_starts, t_ends = ray_marching( + ray_indices, t_starts, t_ends = ray_marching( rays_o, rays_d, scene_aabb=scene_aabb, grid=grid, render_step_size=1e-3 ) # Convert t_starts and t_ends to sample locations. - ray_indices = unpack_info(packed_info) t_mid = (t_starts + t_ends) / 2.0 sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices] @@ -179,7 +174,7 @@ def ray_marching( contraction_type = ContractionType.AABB.to_cpp_version() # marching with grid-based skipping - packed_info, t_starts, t_ends = _C.ray_marching( + packed_info, ray_indices, t_starts, t_ends = _C.ray_marching( # rays rays_o.contiguous(), rays_d.contiguous(), @@ -197,7 +192,6 @@ def ray_marching( # skip invisible space if sigma_fn is not None or alpha_fn is not None: # Query sigma without gradients - ray_indices = unpack_info(packed_info) if sigma_fn is not None: sigmas = sigma_fn(t_starts, t_ends, ray_indices.long()) assert ( @@ -211,10 +205,16 @@ def ray_marching( ), "alphas must have shape of (N, 1)! Got {}".format(alphas.shape) # Compute visibility of the samples, and filter out invisible samples - visibility, packed_info_visible = render_visibility( - packed_info, alphas, early_stop_eps, alpha_thre + masks = render_visibility( + alphas, + ray_indices=ray_indices, + early_stop_eps=early_stop_eps, + alpha_thre=alpha_thre, + ) + ray_indices, t_starts, t_ends = ( + ray_indices[masks], + t_starts[masks], + t_ends[masks], ) - t_starts, t_ends = t_starts[visibility], t_ends[visibility] - packed_info = packed_info_visible - return packed_info, t_starts, t_ends + return ray_indices, t_starts, t_ends diff --git a/nerfacc/vol_rendering.py b/nerfacc/vol_rendering.py index e392ebdb..7b14d676 100644 --- a/nerfacc/vol_rendering.py +++ b/nerfacc/vol_rendering.py @@ -9,44 +9,43 @@ import nerfacc.cuda as _C -from .pack import unpack_info +from .pack import pack_info def rendering( # ray marching results - packed_info: torch.Tensor, t_starts: torch.Tensor, t_ends: torch.Tensor, + ray_indices: torch.Tensor, + n_rays: int, # radiance field rgb_sigma_fn: Optional[Callable] = None, rgb_alpha_fn: Optional[Callable] = None, # rendering options - early_stop_eps: float = 1e-4, - alpha_thre: float = 0.0, render_bkgd: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Render the rays through the radience field defined by `rgb_sigma_fn`. - This function is differentiable to the outputs of `rgb_sigma_fn` so it can be used for - gradient-based optimization. + This function is differentiable to the outputs of `rgb_sigma_fn` so it can + be used for gradient-based optimization. + + Note: + Either `rgb_sigma_fn` or `rgb_alpha_fn` should be provided. Warning: - This function is not differentiable to `t_starts`, `t_ends`. + This function is not differentiable to `t_starts`, `t_ends` and `ray_indices`. Args: - packed_info: Packed ray marching info. See :func:`ray_marching` for details. t_starts: Per-sample start distance. Tensor with shape (n_samples, 1). t_ends: Per-sample end distance. Tensor with shape (n_samples, 1). + ray_indices: Ray index of each sample. IntTensor with shape (n_samples). + n_rays: Total number of rays. This will decide the shape of the ouputs. rgb_sigma_fn: A function that takes in samples {t_starts (N, 1), t_ends (N, 1), \ ray indices (N,)} and returns the post-activation rgb (N, 3) and density \ - values (N, 1). At least one of `rgb_sigma_fn` and `rgb_alpha_fn` should be \ - specified. + values (N, 1). rgb_alpha_fn: A function that takes in samples {t_starts (N, 1), t_ends (N, 1), \ ray indices (N,)} and returns the post-activation rgb (N, 3) and opacity \ - values (N, 1). At least one of `rgb_sigma_fn` and `rgb_alpha_fn` should be \ - specified. - early_stop_eps: Early stop threshold during trasmittance accumulation. Default: 1e-4. - alpha_thre: Alpha threshold for skipping empty space. Default: 0.0. + values (N, 1). render_bkgd: Optional. Background color. Tensor with shape (3,). Returns: @@ -56,48 +55,27 @@ def rendering( .. code-block:: python - import torch - from nerfacc import OccupancyGrid, ray_marching, rendering - - device = "cuda:0" - batch_size = 128 - rays_o = torch.rand((batch_size, 3), device=device) - rays_d = torch.randn((batch_size, 3), device=device) - rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True) - - # Ray marching. - packed_info, t_starts, t_ends = ray_marching( - rays_o, rays_d, near_plane=0.1, far_plane=1.0, render_step_size=1e-3 - ) - - # Rendering. - def rgb_sigma_fn(t_starts, t_ends, ray_indices): - # This is a dummy function that returns random values. - rgbs = torch.rand((t_starts.shape[0], 3), device=device) - sigmas = torch.rand((t_starts.shape[0], 1), device=device) - return rgbs, sigmas - colors, opacities, depths = rendering(rgb_sigma_fn, packed_info, t_starts, t_ends) - - # torch.Size([128, 3]) torch.Size([128, 1]) torch.Size([128, 1]) - print(colors.shape, opacities.shape, depths.shape) + >>> rays_o = torch.rand((128, 3), device="cuda:0") + >>> rays_d = torch.randn((128, 3), device="cuda:0") + >>> rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True) + >>> ray_indices, t_starts, t_ends = ray_marching( + >>> rays_o, rays_d, near_plane=0.1, far_plane=1.0, render_step_size=1e-3) + >>> def rgb_sigma_fn(t_starts, t_ends, ray_indices): + >>> # This is a dummy function that returns random values. + >>> rgbs = torch.rand((t_starts.shape[0], 3), device="cuda:0") + >>> sigmas = torch.rand((t_starts.shape[0], 1), device="cuda:0") + >>> return rgbs, sigmas + >>> colors, opacities, depths = rendering( + >>> t_starts, t_ends, ray_indices, n_rays=128, rgb_sigma_fn=rgb_sigma_fn) + >>> print(colors.shape, opacities.shape, depths.shape) + torch.Size([128, 3]) torch.Size([128, 1]) torch.Size([128, 1]) """ - if callable(packed_info): - raise RuntimeError( - "You maybe want to use the nerfacc<=0.2.1 version. For nerfacc>0.2.1, " - "The first argument of `rendering` should be the packed ray packed info. " - "See the latest documentation for details: " - "https://www.nerfacc.com/en/latest/apis/rendering.html#nerfacc.rendering" - ) - if rgb_sigma_fn is None and rgb_alpha_fn is None: raise ValueError( "At least one of `rgb_sigma_fn` and `rgb_alpha_fn` should be specified." ) - n_rays = packed_info.shape[0] - ray_indices = unpack_info(packed_info) - # Query sigma/alpha and color with gradients if rgb_sigma_fn is not None: rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices.long()) @@ -107,9 +85,13 @@ def rgb_sigma_fn(t_starts, t_ends, ray_indices): assert ( sigmas.shape == t_starts.shape ), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape) - # Rendering: compute weights and ray indices. + # Rendering: compute weights. weights = render_weight_from_density( - packed_info, t_starts, t_ends, sigmas, early_stop_eps, alpha_thre + t_starts, + t_ends, + sigmas, + ray_indices=ray_indices, + n_rays=n_rays, ) elif rgb_alpha_fn is not None: rgbs, alphas = rgb_alpha_fn(t_starts, t_ends, ray_indices.long()) @@ -119,9 +101,11 @@ def rgb_sigma_fn(t_starts, t_ends, ray_indices): assert ( alphas.shape == t_starts.shape ), "alphas must have shape of (N, 1)! Got {}".format(alphas.shape) - # Rendering: compute weights and ray indices. + # Rendering: compute weights. weights = render_weight_from_alpha( - packed_info, alphas, early_stop_eps, alpha_thre + alphas, + ray_indices=ray_indices, + n_rays=n_rays, ) # Rendering: accumulate rgbs, opacities, and depths along the rays. @@ -159,8 +143,7 @@ def accumulate_along_rays( Args: weights: Volumetric rendering weights for those samples. Tensor with shape \ (n_samples,). - ray_indices: Ray index of each sample. IntTensor with shape (n_samples). \ - It can be obtained from `unpack_info(packed_info)`. + ray_indices: Ray index of each sample. IntTensor with shape (n_samples). values: The values to be accmulated. Tensor with shape (n_samples, D). If \ None, the accumulated values are just weights. Default is None. n_rays: Total number of rays. This will decide the shape of the ouputs. If \ @@ -188,16 +171,16 @@ def accumulate_along_rays( print(colors.shape, opacities.shape, depths.shape) """ - assert ray_indices.dim() == 1 and weights.dim() == 1 + assert ray_indices.dim() == 1 and weights.dim() == 2 if not weights.is_cuda: raise NotImplementedError("Only support cuda inputs.") if values is not None: assert ( values.dim() == 2 and values.shape[0] == weights.shape[0] ), "Invalid shapes: {} vs {}".format(values.shape, weights.shape) - src = weights[:, None] * values + src = weights * values else: - src = weights[:, None] + src = weights if ray_indices.numel() == 0: assert n_rays is not None @@ -205,8 +188,7 @@ def accumulate_along_rays( if n_rays is None: n_rays = int(ray_indices.max()) + 1 - else: - assert n_rays > ray_indices.max() + # assert n_rays > ray_indices.max() ray_indices = ray_indices.int() index = ray_indices[:, None].long().expand(-1, src.shape[-1]) @@ -215,276 +197,473 @@ def accumulate_along_rays( return outputs -def render_weight_from_density( - packed_info, - t_starts, - t_ends, - sigmas, - early_stop_eps: float = 1e-4, - alpha_thre: float = 0.0, -) -> torch.Tensor: - """Compute transmittance weights from density. +def render_transmittance_from_density( + t_starts: Tensor, + t_ends: Tensor, + sigmas: Tensor, + *, + packed_info: Optional[torch.Tensor] = None, + ray_indices: Optional[torch.Tensor] = None, + n_rays: Optional[int] = None, +) -> Tensor: + """Compute transmittance :math:`T_i` from density :math:`\\sigma_i`. + + .. math:: + T_i = exp(-\\sum_{j=1}^{i-1}\\sigma_j\delta_j) + + Note: + Either `ray_indices` or `packed_info` should be provided. If `ray_indices` is + provided, CUB acceleration will be used if available (CUDA >= 11.6). Otherwise, + we will use the naive implementation with `packed_info`. Args: - packed_info: Stores information on which samples belong to the same ray. \ - See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2). t_starts: Where the frustum-shape sample starts along a ray. Tensor with \ shape (n_samples, 1). t_ends: Where the frustum-shape sample ends along a ray. Tensor with \ shape (n_samples, 1). sigmas: The density values of the samples. Tensor with shape (n_samples, 1). - early_stop_eps: The epsilon value for early stopping. Default is 1e-4. - alpha_thre: Alpha threshold for skipping empty space. Default: 0.0. + packed_info: Optional. Stores information on which samples belong to the same ray. \ + See :func:`nerfacc.ray_marching` for details. LongTensor with shape (n_rays, 2). + ray_indices: Optional. Ray index of each sample. LongTensor with shape (n_sample). + n_rays: Optional. Number of rays. Only useful when `ray_indices` is provided yet \ + CUB acceleration is not available. We will implicitly convert `ray_indices` to \ + `packed_info` and use the naive implementation. If not provided, we will infer \ + it from `ray_indices` but it will be slower. + + Returns: + The rendering transmittance. Tensor with shape (n_sample, 1). + + Examples: + + .. code-block:: python + + >>> t_starts = torch.tensor( + >>> [[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [6.0]], device="cuda") + >>> t_ends = torch.tensor( + >>> [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0]], device="cuda") + >>> sigmas = torch.tensor( + >>> [[0.4], [0.8], [0.1], [0.8], [0.1], [0.0], [0.9]], device="cuda") + >>> ray_indices = torch.tensor([0, 0, 0, 1, 1, 2, 2], device="cuda") + >>> transmittance = render_transmittance_from_density( + >>> t_starts, t_ends, sigmas, ray_indices=ray_indices) + [[1.00], [0.67], [0.30], [1.00], [0.45], [1.00], [1.00]] + + """ + assert ( + ray_indices is not None or packed_info is not None + ), "Either ray_indices or packed_info should be provided." + if ray_indices is not None and _C.is_cub_available(): + transmittance = _RenderingTransmittanceFromDensityCUB.apply( + ray_indices, t_starts, t_ends, sigmas + ) + else: + if packed_info is None: + packed_info = pack_info(ray_indices, n_rays=n_rays) + transmittance = _RenderingTransmittanceFromDensityNaive.apply( + packed_info, t_starts, t_ends, sigmas + ) + return transmittance + + +def render_transmittance_from_alpha( + alphas: Tensor, + *, + packed_info: Optional[torch.Tensor] = None, + ray_indices: Optional[torch.Tensor] = None, + n_rays: Optional[int] = None, +) -> Tensor: + """Compute transmittance :math:`T_i` from alpha :math:`\\alpha_i`. + .. math:: + T_i = \\prod_{j=1}^{i-1}(1-\\alpha_j) + + Note: + Either `ray_indices` or `packed_info` should be provided. If `ray_indices` is + provided, CUB acceleration will be used if available (CUDA >= 11.6). Otherwise, + we will use the naive implementation with `packed_info`. + + Args: + alphas: The opacity values of the samples. Tensor with shape (n_samples, 1). + packed_info: Optional. Stores information on which samples belong to the same ray. \ + See :func:`nerfacc.ray_marching` for details. LongTensor with shape (n_rays, 2). + ray_indices: Optional. Ray index of each sample. LongTensor with shape (n_sample). + n_rays: Optional. Number of rays. Only useful when `ray_indices` is provided yet \ + CUB acceleration is not available. We will implicitly convert `ray_indices` to \ + `packed_info` and use the naive implementation. If not provided, we will infer \ + it from `ray_indices` but it will be slower. + Returns: - transmittance weights with shape (n_samples,). + The rendering transmittance. Tensor with shape (n_sample, 1). Examples: .. code-block:: python - rays_o = torch.rand((128, 3), device="cuda:0") - rays_d = torch.randn((128, 3), device="cuda:0") - rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True) + >>> alphas = torch.tensor( + >>> [[0.4], [0.8], [0.1], [0.8], [0.1], [0.0], [0.9]], device="cuda")) + >>> ray_indices = torch.tensor([0, 0, 0, 1, 1, 2, 2], device="cuda") + >>> transmittance = render_transmittance_from_alpha(alphas, ray_indices=ray_indices) + tensor([[1.0], [0.6], [0.12], [1.0], [0.2], [1.0], [1.0]]) - # Ray marching with near far plane. - packed_info, t_starts, t_ends = ray_marching( - rays_o, rays_d, near_plane=0.1, far_plane=1.0, render_step_size=1e-3 + """ + assert ( + ray_indices is not None or packed_info is not None + ), "Either ray_indices or packed_info should be provided." + if ray_indices is not None and _C.is_cub_available(): + transmittance = _RenderingTransmittanceFromAlphaCUB.apply( + ray_indices, alphas ) - # pesudo density - sigmas = torch.rand((t_starts.shape[0], 1), device="cuda:0") - # Rendering: compute weights and ray indices. - weights = render_weight_from_density( - packed_info, t_starts, t_ends, sigmas, early_stop_eps=1e-4 + else: + if packed_info is None: + packed_info = pack_info(ray_indices, n_rays=n_rays) + transmittance = _RenderingTransmittanceFromAlphaNaive.apply( + packed_info, alphas ) - # torch.Size([115200, 1]) torch.Size([115200]) - print(sigmas.shape, weights.shape) + return transmittance + + +def render_weight_from_density( + t_starts: Tensor, + t_ends: Tensor, + sigmas: Tensor, + *, + packed_info: Optional[torch.Tensor] = None, + ray_indices: Optional[torch.Tensor] = None, + n_rays: Optional[int] = None, +) -> torch.Tensor: + """Compute rendering weights :math:`w_i` from density :math:`\\sigma_i` and interval :math:`\\delta_i`. + + .. math:: + w_i = T_i(1 - exp(-\\sigma_i\delta_i)), \\quad\\textrm{where}\\quad T_i = exp(-\\sum_{j=1}^{i-1}\\sigma_j\delta_j) + + Note: + Either `ray_indices` or `packed_info` should be provided. If `ray_indices` is + provided, CUB acceleration will be used if available (CUDA >= 11.6). Otherwise, + we will use the naive implementation with `packed_info`. + Args: + t_starts: Where the frustum-shape sample starts along a ray. Tensor with \ + shape (n_samples, 1). + t_ends: Where the frustum-shape sample ends along a ray. Tensor with \ + shape (n_samples, 1). + sigmas: The density values of the samples. Tensor with shape (n_samples, 1). + packed_info: Optional. Stores information on which samples belong to the same ray. \ + See :func:`nerfacc.ray_marching` for details. LongTensor with shape (n_rays, 2). + ray_indices: Optional. Ray index of each sample. LongTensor with shape (n_sample). + n_rays: Optional. Number of rays. Only useful when `ray_indices` is provided yet \ + CUB acceleration is not available. We will implicitly convert `ray_indices` to \ + `packed_info` and use the naive implementation. If not provided, we will infer \ + it from `ray_indices` but it will be slower. + + Returns: + The rendering weights. Tensor with shape (n_sample, 1). + + Examples: + + .. code-block:: python + + >>> t_starts = torch.tensor( + >>> [[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [6.0]], device="cuda") + >>> t_ends = torch.tensor( + >>> [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0]], device="cuda") + >>> sigmas = torch.tensor( + >>> [[0.4], [0.8], [0.1], [0.8], [0.1], [0.0], [0.9]], device="cuda") + >>> ray_indices = torch.tensor([0, 0, 0, 1, 1, 2, 2], device="cuda") + >>> weights = render_weight_from_density( + >>> t_starts, t_ends, sigmas, ray_indices=ray_indices) + [[0.33], [0.37], [0.03], [0.55], [0.04], [0.00], [0.59]] + """ - if not sigmas.is_cuda: - raise NotImplementedError("Only support cuda inputs.") - weights = _RenderingDensity.apply( - packed_info, t_starts, t_ends, sigmas, early_stop_eps, alpha_thre - ) + assert ( + ray_indices is not None or packed_info is not None + ), "Either ray_indices or packed_info should be provided." + if ray_indices is not None and _C.is_cub_available(): + transmittance = _RenderingTransmittanceFromDensityCUB.apply( + ray_indices, t_starts, t_ends, sigmas + ) + alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts)) + weights = transmittance * alphas + else: + if packed_info is None: + packed_info = pack_info(ray_indices, n_rays=n_rays) + weights = _RenderingWeightFromDensityNaive.apply( + packed_info, t_starts, t_ends, sigmas + ) return weights def render_weight_from_alpha( - packed_info, - alphas, - early_stop_eps: float = 1e-4, - alpha_thre: float = 0.0, + alphas: Tensor, + *, + packed_info: Optional[torch.Tensor] = None, + ray_indices: Optional[torch.Tensor] = None, + n_rays: Optional[int] = None, ) -> torch.Tensor: - """Compute transmittance weights from opacity. + """Compute rendering weights :math:`w_i` from opacity :math:`\\alpha_i`. + + .. math:: + w_i = T_i\\alpha_i, \\quad\\textrm{where}\\quad T_i = \\prod_{j=1}^{i-1}(1-\\alpha_j) + + Note: + Either `ray_indices` or `packed_info` should be provided. If `ray_indices` is + provided, CUB acceleration will be used if available (CUDA >= 11.6). Otherwise, + we will use the naive implementation with `packed_info`. Args: - packed_info: Stores information on which samples belong to the same ray. \ - See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2). alphas: The opacity values of the samples. Tensor with shape (n_samples, 1). - early_stop_eps: The epsilon value for early stopping. Default is 1e-4. - alpha_thre: Alpha threshold for skipping empty space. Default: 0.0. + packed_info: Optional. Stores information on which samples belong to the same ray. \ + See :func:`nerfacc.ray_marching` for details. LongTensor with shape (n_rays, 2). + ray_indices: Optional. Ray index of each sample. LongTensor with shape (n_sample). + n_rays: Optional. Number of rays. Only useful when `ray_indices` is provided yet \ + CUB acceleration is not available. We will implicitly convert `ray_indices` to \ + `packed_info` and use the naive implementation. If not provided, we will infer \ + it from `ray_indices` but it will be slower. Returns: - transmittance weights with shape (n_samples,). + The rendering weights. Tensor with shape (n_sample, 1). Examples: .. code-block:: python - rays_o = torch.rand((128, 3), device="cuda:0") - rays_d = torch.randn((128, 3), device="cuda:0") - rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True) - - # Ray marching with near far plane. - packed_info, t_starts, t_ends = ray_marching( - rays_o, rays_d, near_plane=0.1, far_plane=1.0, render_step_size=1e-3 - ) - # pesudo opacity - alphas = torch.rand((t_starts.shape[0], 1), device="cuda:0") - # Rendering: compute weights and ray indices. - weights = render_weight_from_alpha( - packed_info, alphas, early_stop_eps=1e-4 - ) - # torch.Size([115200, 1]) torch.Size([115200]) - print(alphas.shape, weights.shape) + >>> alphas = torch.tensor( + >>> [[0.4], [0.8], [0.1], [0.8], [0.1], [0.0], [0.9]], device="cuda")) + >>> ray_indices = torch.tensor([0, 0, 0, 1, 1, 2, 2], device="cuda") + >>> weights = render_weight_from_alpha(alphas, ray_indices=ray_indices) + tensor([[0.4], [0.48], [0.012], [0.8], [0.02], [0.0], [0.9]]) """ - if not alphas.is_cuda: - raise NotImplementedError("Only support cuda inputs.") - weights = _RenderingAlpha.apply( - packed_info, alphas, early_stop_eps, alpha_thre - ) + assert ( + ray_indices is not None or packed_info is not None + ), "Either ray_indices or packed_info should be provided." + if ray_indices is not None and _C.is_cub_available(): + transmittance = _RenderingTransmittanceFromAlphaCUB.apply( + ray_indices, alphas + ) + weights = transmittance * alphas + else: + if packed_info is None: + packed_info = pack_info(ray_indices, n_rays=n_rays) + weights = _RenderingWeightFromAlphaNaive.apply(packed_info, alphas) return weights @torch.no_grad() def render_visibility( - packed_info: torch.Tensor, alphas: torch.Tensor, + *, + ray_indices: Optional[torch.Tensor] = None, + packed_info: Optional[torch.Tensor] = None, + n_rays: Optional[int] = None, early_stop_eps: float = 1e-4, alpha_thre: float = 0.0, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Filter out invisible samples given alpha (opacity). +) -> torch.Tensor: + """Filter out transparent and occluded samples. + + In this function, we first compute the transmittance from the sample opacity. The + transmittance is then used to filter out occluded samples. And opacity is used to + filter out transparent samples. The function returns a boolean tensor indicating + which samples are visible (`transmittance > early_stop_eps` and `opacity > alpha_thre`). + + Note: + Either `ray_indices` or `packed_info` should be provided. If `ray_indices` is + provided, CUB acceleration will be used if available (CUDA >= 11.6). Otherwise, + we will use the naive implementation with `packed_info`. Args: - packed_info: Stores information on which samples belong to the same ray. \ - See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2). alphas: The opacity values of the samples. Tensor with shape (n_samples, 1). - early_stop_eps: The epsilon value for early stopping. Default is 1e-4. - alpha_thre: Alpha threshold for skipping empty space. Default: 0.0. + packed_info: Optional. Stores information on which samples belong to the same ray. \ + See :func:`nerfacc.ray_marching` for details. LongTensor with shape (n_rays, 2). + ray_indices: Optional. Ray index of each sample. LongTensor with shape (n_sample). + n_rays: Optional. Number of rays. Only useful when `ray_indices` is provided yet \ + CUB acceleration is not available. We will implicitly convert `ray_indices` to \ + `packed_info` and use the naive implementation. If not provided, we will infer \ + it from `ray_indices` but it will be slower. + early_stop_eps: The early stopping threshold on transmittance. + alpha_thre: The threshold on opacity. Returns: - A tuple of tensors. - - - **visibility**: The visibility mask for samples. Boolen tensor of shape \ - (n_samples,). - - **packed_info_visible**: The new packed_info for visible samples. \ - Tensor shape (n_rays, 2). It should be used if you use the visiblity \ - mask to filter out invisible samples. + The visibility of each sample. Tensor with shape (n_samples, 1). Examples: .. code-block:: python - rays_o = torch.rand((128, 3), device="cuda:0") - rays_d = torch.randn((128, 3), device="cuda:0") - rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True) + >>> alphas = torch.tensor( + >>> [[0.4], [0.8], [0.1], [0.8], [0.1], [0.0], [0.9]], device="cuda") + >>> ray_indices = torch.tensor([0, 0, 0, 1, 1, 2, 2], device="cuda") + >>> transmittance = render_transmittance_from_alpha(alphas, ray_indices=ray_indices) + tensor([[1.0], [0.6], [0.12], [1.0], [0.2], [1.0], [1.0]]) + >>> visibility = render_visibility( + >>> alphas, ray_indices=ray_indices, early_stop_eps=0.3, alpha_thre=0.2) + tensor([True, True, False, True, False, False, True]) - # Ray marching with near far plane. - packed_info, t_starts, t_ends = ray_marching( - rays_o, rays_d, near_plane=0.1, far_plane=1.0, render_step_size=1e-3 + """ + assert ( + ray_indices is not None or packed_info is not None + ), "Either ray_indices or packed_info should be provided." + if ray_indices is not None and _C.is_cub_available(): + transmittance = _RenderingTransmittanceFromAlphaCUB.apply( + ray_indices, alphas ) - # pesudo opacity - alphas = torch.rand((t_starts.shape[0], 1), device="cuda:0") - # Rendering but only for computing visibility of each samples. - visibility, packed_info_visible = render_visibility( - packed_info, alphas, early_stop_eps=1e-4 + else: + if packed_info is None: + packed_info = pack_info(ray_indices, n_rays=n_rays) + transmittance = _RenderingTransmittanceFromAlphaNaive.apply( + packed_info, alphas ) - t_starts_visible = t_starts[visibility] - t_ends_visible = t_ends[visibility] - # torch.Size([115200, 1]) torch.Size([1283, 1]) - print(t_starts.shape, t_starts_visible.shape) + visibility = transmittance >= early_stop_eps + if alpha_thre > 0: + visibility = visibility & (alphas >= alpha_thre) + visibility = visibility.squeeze(-1) + return visibility - """ - visibility, packed_info_visible = _C.rendering_alphas_forward( - packed_info.contiguous(), - alphas.contiguous(), - early_stop_eps, - alpha_thre, - True, # compute visibility instead of weights - ) - return visibility, packed_info_visible + +class _RenderingTransmittanceFromDensityCUB(torch.autograd.Function): + """Rendering transmittance from density with CUB implementation.""" + + @staticmethod + def forward(ctx, ray_indices, t_starts, t_ends, sigmas): + ray_indices = ray_indices.contiguous().int() + t_starts = t_starts.contiguous() + t_ends = t_ends.contiguous() + sigmas = sigmas.contiguous() + transmittance = _C.transmittance_from_sigma_forward_cub( + ray_indices, t_starts, t_ends, sigmas + ) + if ctx.needs_input_grad[3]: + ctx.save_for_backward(ray_indices, t_starts, t_ends, transmittance) + return transmittance + + @staticmethod + def backward(ctx, transmittance_grads): + transmittance_grads = transmittance_grads.contiguous() + ray_indices, t_starts, t_ends, transmittance = ctx.saved_tensors + grad_sigmas = _C.transmittance_from_sigma_backward_cub( + ray_indices, t_starts, t_ends, transmittance, transmittance_grads + ) + return None, None, None, grad_sigmas -class _RenderingDensity(torch.autograd.Function): - """Rendering transmittance weights from density.""" +class _RenderingTransmittanceFromDensityNaive(torch.autograd.Function): + """Rendering transmittance from density with naive forloop.""" @staticmethod - def forward( - ctx, - packed_info, - t_starts, - t_ends, - sigmas, - early_stop_eps: float = 1e-4, - alpha_thre: float = 0.0, - ): - packed_info = packed_info.contiguous() + def forward(ctx, packed_info, t_starts, t_ends, sigmas): + packed_info = packed_info.contiguous().int() t_starts = t_starts.contiguous() t_ends = t_ends.contiguous() sigmas = sigmas.contiguous() - weights = _C.rendering_forward( - packed_info, - t_starts, - t_ends, - sigmas, - early_stop_eps, - alpha_thre, - False, # not doing filtering - )[0] - if ctx.needs_input_grad[3]: # sigmas + transmittance = _C.transmittance_from_sigma_forward_naive( + packed_info, t_starts, t_ends, sigmas + ) + if ctx.needs_input_grad[3]: + ctx.save_for_backward(packed_info, t_starts, t_ends, transmittance) + return transmittance + + @staticmethod + def backward(ctx, transmittance_grads): + transmittance_grads = transmittance_grads.contiguous() + packed_info, t_starts, t_ends, transmittance = ctx.saved_tensors + grad_sigmas = _C.transmittance_from_sigma_backward_naive( + packed_info, t_starts, t_ends, transmittance, transmittance_grads + ) + return None, None, None, grad_sigmas + + +class _RenderingTransmittanceFromAlphaCUB(torch.autograd.Function): + """Rendering transmittance from opacity with CUB implementation.""" + + @staticmethod + def forward(ctx, ray_indices, alphas): + ray_indices = ray_indices.contiguous().int() + alphas = alphas.contiguous() + transmittance = _C.transmittance_from_alpha_forward_cub( + ray_indices, alphas + ) + if ctx.needs_input_grad[1]: + ctx.save_for_backward(ray_indices, transmittance, alphas) + return transmittance + + @staticmethod + def backward(ctx, transmittance_grads): + transmittance_grads = transmittance_grads.contiguous() + ray_indices, transmittance, alphas = ctx.saved_tensors + grad_alphas = _C.transmittance_from_alpha_backward_cub( + ray_indices, alphas, transmittance, transmittance_grads + ) + return None, grad_alphas + + +class _RenderingTransmittanceFromAlphaNaive(torch.autograd.Function): + """Rendering transmittance from opacity with naive forloop.""" + + @staticmethod + def forward(ctx, packed_info, alphas): + packed_info = packed_info.contiguous().int() + alphas = alphas.contiguous() + transmittance = _C.transmittance_from_alpha_forward_naive( + packed_info, alphas + ) + if ctx.needs_input_grad[1]: + ctx.save_for_backward(packed_info, transmittance, alphas) + return transmittance + + @staticmethod + def backward(ctx, transmittance_grads): + transmittance_grads = transmittance_grads.contiguous() + packed_info, transmittance, alphas = ctx.saved_tensors + grad_alphas = _C.transmittance_from_alpha_backward_naive( + packed_info, alphas, transmittance, transmittance_grads + ) + return None, grad_alphas + + +class _RenderingWeightFromDensityNaive(torch.autograd.Function): + """Rendering weight from density with naive forloop.""" + + @staticmethod + def forward(ctx, packed_info, t_starts, t_ends, sigmas): + packed_info = packed_info.contiguous().int() + t_starts = t_starts.contiguous() + t_ends = t_ends.contiguous() + sigmas = sigmas.contiguous() + weights = _C.weight_from_sigma_forward_naive( + packed_info, t_starts, t_ends, sigmas + ) + if ctx.needs_input_grad[3]: ctx.save_for_backward( - packed_info, - t_starts, - t_ends, - sigmas, - weights, + packed_info, t_starts, t_ends, sigmas, weights ) - ctx.early_stop_eps = early_stop_eps - ctx.alpha_thre = alpha_thre return weights @staticmethod def backward(ctx, grad_weights): grad_weights = grad_weights.contiguous() - early_stop_eps = ctx.early_stop_eps - alpha_thre = ctx.alpha_thre - ( - packed_info, - t_starts, - t_ends, - sigmas, - weights, - ) = ctx.saved_tensors - grad_sigmas = _C.rendering_backward( - weights, - grad_weights, - packed_info, - t_starts, - t_ends, - sigmas, - early_stop_eps, - alpha_thre, + packed_info, t_starts, t_ends, sigmas, weights = ctx.saved_tensors + grad_sigmas = _C.weight_from_sigma_backward_naive( + weights, grad_weights, packed_info, t_starts, t_ends, sigmas ) - return None, None, None, grad_sigmas, None, None + return None, None, None, grad_sigmas -class _RenderingAlpha(torch.autograd.Function): - """Rendering transmittance weights from alpha.""" +class _RenderingWeightFromAlphaNaive(torch.autograd.Function): + """Rendering weight from opacity with naive forloop.""" @staticmethod - def forward( - ctx, - packed_info, - alphas, - early_stop_eps: float = 1e-4, - alpha_thre: float = 0.0, - ): - packed_info = packed_info.contiguous() + def forward(ctx, packed_info, alphas): + packed_info = packed_info.contiguous().int() alphas = alphas.contiguous() - weights = _C.rendering_alphas_forward( - packed_info, - alphas, - early_stop_eps, - alpha_thre, - False, # not doing filtering - )[0] - if ctx.needs_input_grad[1]: # alphas - ctx.save_for_backward( - packed_info, - alphas, - weights, - ) - ctx.early_stop_eps = early_stop_eps - ctx.alpha_thre = alpha_thre + weights = _C.weight_from_alpha_forward_naive(packed_info, alphas) + if ctx.needs_input_grad[1]: + ctx.save_for_backward(packed_info, alphas, weights) return weights @staticmethod def backward(ctx, grad_weights): grad_weights = grad_weights.contiguous() - early_stop_eps = ctx.early_stop_eps - alpha_thre = ctx.alpha_thre - ( - packed_info, - alphas, - weights, - ) = ctx.saved_tensors - grad_sigmas = _C.rendering_alphas_backward( - weights, - grad_weights, - packed_info, - alphas, - early_stop_eps, - alpha_thre, + packed_info, alphas, weights = ctx.saved_tensors + grad_alphas = _C.weight_from_alpha_backward_naive( + weights, grad_weights, packed_info, alphas ) - return None, grad_sigmas, None, None + return None, grad_alphas diff --git a/pyproject.toml b/pyproject.toml index 1e1f7064..61194910 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "nerfacc" -version = "0.2.4" +version = "0.3.0" description = "A General NeRF Acceleration Toolbox." readme = "README.md" authors = [{name = "Ruilong", email = "ruilongli94@gmail.com"}] diff --git a/scripts/run_profiler.py b/scripts/run_profiler.py index b4d70049..9b99e291 100644 --- a/scripts/run_profiler.py +++ b/scripts/run_profiler.py @@ -1,9 +1,13 @@ from typing import Callable import torch +import tqdm import nerfacc +# timing +# https://github.com/pytorch/pytorch/commit/d2784c233bfc57a1d836d961694bcc8ec4ed45e4 + class Profiler: def __init__(self, warmup=10, repeat=1000): @@ -30,6 +34,7 @@ def __call__(self, func: Callable): # return events = prof.key_averages() + # print(events.table(sort_by="self_cpu_time_total", row_limit=10)) self_cpu_time_total = ( sum([event.self_cpu_time_total for event in events]) / self.repeat ) @@ -49,15 +54,62 @@ def __call__(self, func: Callable): def main(): device = "cuda:0" torch.manual_seed(42) - profiler = Profiler(warmup=10, repeat=1000) - - # contract - print("* contract") - x = torch.rand([1024, 3], device=device) - roi = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.float32, device=device) - fn = lambda: nerfacc.contract( - x, roi=roi, type=nerfacc.ContractionType.UN_BOUNDED_TANH + profiler = Profiler(warmup=10, repeat=100) + + # # contract + # print("* contract") + # x = torch.rand([1024, 3], device=device) + # roi = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.float32, device=device) + # fn = lambda: nerfacc.contract( + # x, roi=roi, type=nerfacc.ContractionType.UN_BOUNDED_TANH + # ) + # cpu_t, cuda_t, cuda_bytes = profiler(fn) + # print(f"{cpu_t:.2f} us, {cuda_t:.2f} us, {cuda_bytes / 1024 / 1024:.2f} MB") + + # rendering + print("* rendering") + batch_size = 81920 + rays_o = torch.rand((batch_size, 3), device=device) + rays_d = torch.randn((batch_size, 3), device=device) + rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True) + + ray_indices, t_starts, t_ends = nerfacc.ray_marching( + rays_o, + rays_d, + near_plane=0.1, + far_plane=1.0, + render_step_size=1e-1, + ) + sigmas = torch.randn_like(t_starts, requires_grad=True) + fn = ( + lambda: nerfacc.render_weight_from_density( + ray_indices, t_starts, t_ends, sigmas + ) + .sum() + .backward() + ) + fn() + torch.cuda.synchronize() + for _ in tqdm.tqdm(range(100)): + fn() + torch.cuda.synchronize() + + cpu_t, cuda_t, cuda_bytes = profiler(fn) + print(f"{cpu_t:.2f} us, {cuda_t:.2f} us, {cuda_bytes / 1024 / 1024:.2f} MB") + + packed_info = nerfacc.pack_info(ray_indices, n_rays=batch_size).int() + fn = ( + lambda: nerfacc.vol_rendering._RenderingDensity.apply( + packed_info, t_starts, t_ends, sigmas, 0 + ) + .sum() + .backward() ) + fn() + torch.cuda.synchronize() + for _ in tqdm.tqdm(range(100)): + fn() + torch.cuda.synchronize() cpu_t, cuda_t, cuda_bytes = profiler(fn) print(f"{cpu_t:.2f} us, {cuda_t:.2f} us, {cuda_bytes / 1024 / 1024:.2f} MB") diff --git a/tests/test_loss.py b/tests/test_loss.py index b3952768..5c597e8a 100644 --- a/tests/test_loss.py +++ b/tests/test_loss.py @@ -1,7 +1,7 @@ import pytest import torch -from nerfacc import ray_marching +from nerfacc import pack_info, ray_marching from nerfacc.losses import distortion device = "cuda:0" @@ -15,13 +15,14 @@ def test_distortion(): rays_d = torch.randn((batch_size, 3), device=device) rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True) - packed_info, t_starts, t_ends = ray_marching( + ray_indices, t_starts, t_ends = ray_marching( rays_o, rays_d, near_plane=0.1, far_plane=1.0, render_step_size=1e-3, ) + packed_info = pack_info(ray_indices, n_rays=batch_size) weights = torch.rand((t_starts.shape[0],), device=device) loss = distortion(packed_info, weights, t_starts, t_ends) assert loss.shape == (batch_size,) diff --git a/tests/test_pack.py b/tests/test_pack.py index d7e3369c..9f6a0350 100644 --- a/tests/test_pack.py +++ b/tests/test_pack.py @@ -1,7 +1,7 @@ import pytest import torch -from nerfacc import pack_data, unpack_data, unpack_info +from nerfacc import pack_data, pack_info, unpack_data, unpack_info device = "cuda:0" batch_size = 32 @@ -31,7 +31,9 @@ def test_unpack_info(): ray_indices_tgt = torch.tensor( [0, 2, 2, 2, 2], dtype=torch.int64, device=device ) - ray_indices = unpack_info(packed_info) + ray_indices = unpack_info(packed_info, n_samples=5) + packed_info_2 = pack_info(ray_indices, n_rays=packed_info.shape[0]) + assert torch.allclose(packed_info.int(), packed_info_2.int()) assert torch.allclose(ray_indices, ray_indices_tgt) diff --git a/tests/test_ray_marching.py b/tests/test_ray_marching.py index 7c623dc0..28b9b33c 100644 --- a/tests/test_ray_marching.py +++ b/tests/test_ray_marching.py @@ -13,7 +13,7 @@ def test_marching_with_near_far(): rays_d = torch.randn((batch_size, 3), device=device) rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True) - packed_info, t_starts, t_ends = ray_marching( + ray_indices, t_starts, t_ends = ray_marching( rays_o, rays_d, near_plane=0.1, @@ -31,7 +31,7 @@ def test_marching_with_grid(): grid = OccupancyGrid(roi_aabb=[0, 0, 0, 1, 1, 1]).to(device) grid._binary[:] = True - packed_info, t_starts, t_ends = ray_marching( + ray_indices, t_starts, t_ends = ray_marching( rays_o, rays_d, grid=grid, @@ -39,7 +39,7 @@ def test_marching_with_grid(): far_plane=1.0, render_step_size=1e-2, ) - ray_indices = unpack_info(packed_info).long() + ray_indices = ray_indices.long() samples = ( rays_o[ray_indices] + rays_d[ray_indices] * (t_starts + t_ends) / 2.0 ) diff --git a/tests/test_rendering.py b/tests/test_rendering.py index 37437182..9c71a68d 100644 --- a/tests/test_rendering.py +++ b/tests/test_rendering.py @@ -3,6 +3,7 @@ from nerfacc import ( accumulate_along_rays, + render_transmittance_from_density, render_visibility, render_weight_from_alpha, render_weight_from_density, @@ -16,9 +17,9 @@ @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device") def test_render_visibility(): - packed_info = torch.tensor( - [[0, 1], [1, 0], [1, 4]], dtype=torch.int32, device=device - ) # (n_rays, 2) + ray_indices = torch.tensor( + [0, 2, 2, 2, 2], dtype=torch.int32, device=device + ) # (samples,) alphas = torch.tensor( [0.4, 0.3, 0.8, 0.8, 0.5], dtype=torch.float32, device=device ).unsqueeze( @@ -26,37 +27,29 @@ def test_render_visibility(): ) # (n_samples, 1) # transmittance: [1.0, 1.0, 0.7, 0.14, 0.028] - vis, packed_info_vis = render_visibility( - packed_info, alphas, early_stop_eps=0.03, alpha_thre=0.0 + vis = render_visibility( + alphas, ray_indices=ray_indices, early_stop_eps=0.03, alpha_thre=0.0 ) vis_tgt = torch.tensor( [True, True, True, True, False], dtype=torch.bool, device=device ) - packed_info_vis_tgt = torch.tensor( - [[0, 1], [1, 0], [1, 3]], dtype=torch.int32, device=device - ) # (n_rays, 2) assert torch.allclose(vis, vis_tgt) - assert torch.allclose(packed_info_vis, packed_info_vis_tgt) # transmittance: [1.0, 1.0, 1.0, 0.2, 0.04] - vis, packed_info_vis = render_visibility( - packed_info, alphas, early_stop_eps=0.05, alpha_thre=0.35 + vis = render_visibility( + alphas, ray_indices=ray_indices, early_stop_eps=0.05, alpha_thre=0.35 ) vis_tgt = torch.tensor( [True, False, True, True, False], dtype=torch.bool, device=device ) - packed_info_vis_tgt = torch.tensor( - [[0, 1], [1, 0], [1, 2]], dtype=torch.int32, device=device - ) # (n_rays, 2) assert torch.allclose(vis, vis_tgt) - assert torch.allclose(packed_info_vis, packed_info_vis_tgt) @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device") def test_render_weight_from_alpha(): - packed_info = torch.tensor( - [[0, 1], [1, 0], [1, 4]], dtype=torch.int32, device=device - ) # (n_rays, 2) + ray_indices = torch.tensor( + [0, 2, 2, 2, 2], dtype=torch.int32, device=device + ) # (samples,) alphas = torch.tensor( [0.4, 0.3, 0.8, 0.8, 0.5], dtype=torch.float32, device=device ).unsqueeze( @@ -65,64 +58,160 @@ def test_render_weight_from_alpha(): # transmittance: [1.0, 1.0, 0.7, 0.14, 0.028] weights = render_weight_from_alpha( - packed_info, alphas, early_stop_eps=0.03, alpha_thre=0.0 + alphas, ray_indices=ray_indices, n_rays=3 ) weights_tgt = torch.tensor( - [1.0 * 0.4, 1.0 * 0.3, 0.7 * 0.8, 0.14 * 0.8, 0.0 * 0.0], + [1.0 * 0.4, 1.0 * 0.3, 0.7 * 0.8, 0.14 * 0.8, 0.028 * 0.5], dtype=torch.float32, device=device, - ) + ).unsqueeze(-1) assert torch.allclose(weights, weights_tgt) +@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device") def test_render_weight_from_density(): - packed_info = torch.tensor( - [[0, 1], [1, 0], [1, 4]], dtype=torch.int32, device=device - ) # (n_rays, 2) - sigmas = torch.rand((batch_size, 1), device=device) # (n_samples, 1) + ray_indices = torch.tensor( + [0, 2, 2, 2, 2], dtype=torch.int32, device=device + ) # (samples,) + sigmas = torch.rand( + (ray_indices.shape[0], 1), device=device + ) # (n_samples, 1) t_starts = torch.rand_like(sigmas) t_ends = torch.rand_like(sigmas) + 1.0 alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts)) - weights = render_weight_from_density(packed_info, t_starts, t_ends, sigmas) - weights_tgt = render_weight_from_alpha(packed_info, alphas) + weights = render_weight_from_density( + t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=3 + ) + weights_tgt = render_weight_from_alpha( + alphas, ray_indices=ray_indices, n_rays=3 + ) assert torch.allclose(weights, weights_tgt) +@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device") def test_accumulate_along_rays(): ray_indices = torch.tensor( [0, 2, 2, 2, 2], dtype=torch.int32, device=device - ) # (n_rays, 2) + ) # (n_rays,) weights = torch.tensor( [0.4, 0.3, 0.8, 0.8, 0.5], dtype=torch.float32, device=device - ) + ).unsqueeze(-1) values = torch.rand((5, 2), device=device) # (n_samples, 1) ray_values = accumulate_along_rays( weights, ray_indices, values=values, n_rays=3 ) assert ray_values.shape == (3, 2) - assert torch.allclose(ray_values[0, :], weights[0, None] * values[0, :]) + assert torch.allclose(ray_values[0, :], weights[0, :] * values[0, :]) assert (ray_values[1, :] == 0).all() assert torch.allclose( - ray_values[2, :], (weights[1:, None] * values[1:]).sum(dim=0) + ray_values[2, :], (weights[1:, :] * values[1:]).sum(dim=0) ) +@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device") def test_rendering(): def rgb_sigma_fn(t_starts, t_ends, ray_indices): return torch.hstack([t_starts] * 3), t_starts - packed_info = torch.tensor( - [[0, 1], [1, 0], [1, 4]], dtype=torch.int32, device=device - ) # (n_rays, 2) - sigmas = torch.rand((5, 1), device=device) # (n_samples, 1) + ray_indices = torch.tensor( + [0, 2, 2, 2, 2], dtype=torch.int32, device=device + ) # (samples,) + sigmas = torch.rand( + (ray_indices.shape[0], 1), device=device + ) # (n_samples, 1) t_starts = torch.rand_like(sigmas) t_ends = torch.rand_like(sigmas) + 1.0 _, _, _ = rendering( - packed_info, t_starts, t_ends, rgb_sigma_fn=rgb_sigma_fn + t_starts, + t_ends, + ray_indices=ray_indices, + n_rays=3, + rgb_sigma_fn=rgb_sigma_fn, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device") +def test_grads(): + ray_indices = torch.tensor( + [0, 2, 2, 2, 2], dtype=torch.int32, device=device + ) # (samples,) + packed_info = torch.tensor( + [[0, 1], [1, 0], [1, 4]], dtype=torch.int32, device=device + ) + sigmas = torch.tensor([[0.4], [0.8], [0.1], [0.8], [0.1]], device="cuda") + sigmas.requires_grad = True + t_starts = torch.rand_like(sigmas) + t_ends = t_starts + 1.0 + + weights_ref = torch.tensor( + [[0.3297], [0.5507], [0.0428], [0.2239], [0.0174]], device="cuda" + ) + sigmas_grad_ref = torch.tensor( + [[0.6703], [0.1653], [0.1653], [0.1653], [0.1653]], device="cuda" + ) + + # naive impl. trans from sigma + trans = render_transmittance_from_density( + t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=3 + ) + weights = trans * (1.0 - torch.exp(-sigmas * (t_ends - t_starts))) + weights.sum().backward() + sigmas_grad = sigmas.grad.clone() + sigmas.grad.zero_() + assert torch.allclose(weights_ref, weights, atol=1e-4) + assert torch.allclose(sigmas_grad_ref, sigmas_grad, atol=1e-4) + + # naive impl. trans from alpha + trans = render_transmittance_from_density( + t_starts, t_ends, sigmas, packed_info=packed_info, n_rays=3 + ) + weights = trans * (1.0 - torch.exp(-sigmas * (t_ends - t_starts))) + weights.sum().backward() + sigmas_grad = sigmas.grad.clone() + sigmas.grad.zero_() + assert torch.allclose(weights_ref, weights, atol=1e-4) + assert torch.allclose(sigmas_grad_ref, sigmas_grad, atol=1e-4) + + weights = render_weight_from_density( + t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=3 + ) + weights.sum().backward() + sigmas_grad = sigmas.grad.clone() + sigmas.grad.zero_() + assert torch.allclose(weights_ref, weights, atol=1e-4) + assert torch.allclose(sigmas_grad_ref, sigmas_grad, atol=1e-4) + + weights = render_weight_from_density( + t_starts, t_ends, sigmas, packed_info=packed_info, n_rays=3 + ) + weights.sum().backward() + sigmas_grad = sigmas.grad.clone() + sigmas.grad.zero_() + assert torch.allclose(weights_ref, weights, atol=1e-4) + assert torch.allclose(sigmas_grad_ref, sigmas_grad, atol=1e-4) + + alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts)) + weights = render_weight_from_alpha( + alphas, ray_indices=ray_indices, n_rays=3 + ) + weights.sum().backward() + sigmas_grad = sigmas.grad.clone() + sigmas.grad.zero_() + assert torch.allclose(weights_ref, weights, atol=1e-4) + assert torch.allclose(sigmas_grad_ref, sigmas_grad, atol=1e-4) + + alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts)) + weights = render_weight_from_alpha( + alphas, packed_info=packed_info, n_rays=3 ) + weights.sum().backward() + sigmas_grad = sigmas.grad.clone() + sigmas.grad.zero_() + assert torch.allclose(weights_ref, weights, atol=1e-4) + assert torch.allclose(sigmas_grad_ref, sigmas_grad, atol=1e-4) if __name__ == "__main__": @@ -131,3 +220,4 @@ def rgb_sigma_fn(t_starts, t_ends, ray_indices): test_render_weight_from_density() test_accumulate_along_rays() test_rendering() + test_grads() diff --git a/tests/test_resampling.py b/tests/test_resampling.py index 70323c1d..1ac517b7 100644 --- a/tests/test_resampling.py +++ b/tests/test_resampling.py @@ -1,7 +1,7 @@ import pytest import torch -from nerfacc import ray_marching, ray_resampling +from nerfacc import pack_info, ray_marching, ray_resampling device = "cuda:0" batch_size = 128 @@ -13,13 +13,14 @@ def test_resampling(): rays_d = torch.randn((batch_size, 3), device=device) rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True) - packed_info, t_starts, t_ends = ray_marching( + ray_indices, t_starts, t_ends = ray_marching( rays_o, rays_d, near_plane=0.1, far_plane=1.0, render_step_size=1e-3, ) + packed_info = pack_info(ray_indices, n_rays=batch_size) weights = torch.rand((t_starts.shape[0],), device=device) packed_info, t_starts, t_ends = ray_resampling( packed_info, t_starts, t_ends, weights, n_samples=32