Cub (#103)
- Faster rendering function via nvidia-cub, shipped with CUDA >= 11.0 (requires >= 11.6 for our use). ~10% speedup
- Expose transmittance computation.
liruilong940607 authored Nov 7, 2022
1 parent bca2d4d commit 674424e
Showing 27 changed files with 1,686 additions and 835 deletions.
8 changes: 5 additions & 3 deletions README.md
@@ -87,16 +87,18 @@ def rgb_sigma_fn(
return rgbs, sigmas # (n_samples, 3), (n_samples, 1)

# Efficient Raymarching: Skip empty and occluded space, pack samples from all rays.
# packed_info: (n_rays, 2). t_starts: (n_samples, 1). t_ends: (n_samples, 1).
# ray_indices: (n_samples,). t_starts: (n_samples, 1). t_ends: (n_samples, 1).
with torch.no_grad():
packed_info, t_starts, t_ends = nerfacc.ray_marching(
ray_indices, t_starts, t_ends = nerfacc.ray_marching(
rays_o, rays_d, sigma_fn=sigma_fn, near_plane=0.2, far_plane=1.0,
early_stop_eps=1e-4, alpha_thre=1e-2,
)

# Differentiable Volumetric Rendering.
# colors: (n_rays, 3). opaicity: (n_rays, 1). depth: (n_rays, 1).
color, opacity, depth = nerfacc.rendering(rgb_sigma_fn, packed_info, t_starts, t_ends)
color, opacity, depth = nerfacc.rendering(
t_starts, t_ends, ray_indices, n_rays=rays_o.shape[0], rgb_sigma_fn=rgb_sigma_fn
)

# Optimize: Both the network and rays will receive gradients
optimizer.zero_grad()
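With this change, ray_marching returns a per-sample ray_indices tensor of shape (n_samples,) instead of the (n_rays, 2) packed_info layout, and rendering takes the sample intervals, ray_indices, and n_rays explicitly. For orientation, a minimal sketch of what the ray_indices layout enables with the existing accumulate_along_rays utility; it continues the README snippet above, and the keyword layout is an assumption rather than something taken from this diff.

# Hedged sketch: ray_indices maps each packed sample to its ray, so per-sample
# quantities can be reduced onto rays. rgbs and sigmas would be the per-sample
# outputs of the radiance field at the packed samples (as in rgb_sigma_fn above).
weights = nerfacc.render_weight_from_density(
    t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=rays_o.shape[0]
)  # assumed keyword layout
colors = nerfacc.accumulate_along_rays(
    weights, ray_indices, values=rgbs, n_rays=rays_o.shape[0]
)  # (n_rays, 3)
opacities = nerfacc.accumulate_along_rays(
    weights, ray_indices, values=None, n_rays=rays_o.shape[0]
)  # (n_rays, 1)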
@@ -0,0 +1,6 @@
nerfacc.render\_transmittance\_from\_alpha
==========================================

.. currentmodule:: nerfacc

.. autofunction:: render_transmittance_from_alpha
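A minimal usage sketch for this newly exposed function, assuming it follows the packed-sample conventions of the other render_* utilities (per-sample alphas plus ray_indices/n_rays); the keyword names are assumptions, not taken from this diff.

import torch
import nerfacc

n_rays, n_samples = 128, 4096
ray_indices = torch.randint(0, n_rays, (n_samples,), device="cuda").int().sort().values
alphas = torch.rand((n_samples, 1), device="cuda", requires_grad=True)
# Assumed keyword layout. Transmittance here is the per-ray cumulative product
# of (1 - alpha) over earlier samples along the same ray.
transmittance = nerfacc.render_transmittance_from_alpha(
    alphas, ray_indices=ray_indices, n_rays=n_rays
)  # (n_samples, 1)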
@@ -0,0 +1,6 @@
nerfacc.render\_transmittance\_from\_density
============================================

.. currentmodule:: nerfacc

.. autofunction:: render_transmittance_from_density
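Similarly for the density-based variant; a sketch assuming the (t_starts, t_ends, sigmas) ordering used by render_weight_from_density, with keyword names again as assumptions.

# Continuing the tensors from the alpha example above; sigmas are raw densities.
sigmas = torch.rand((n_samples, 1), device="cuda", requires_grad=True)
t_starts = torch.rand((n_samples, 1), device="cuda")
t_ends = t_starts + 1e-2
transmittance = nerfacc.render_transmittance_from_density(
    t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=n_rays
)  # (n_samples, 1): exp of minus the cumulative sigma * (t_end - t_start) along each ray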
2 changes: 2 additions & 0 deletions docs/source/apis/utils.rst
@@ -11,6 +11,8 @@ Utils
unpack_info

accumulate_along_rays
render_transmittance_from_density
render_transmittance_from_alpha
render_weight_from_density
render_weight_from_alpha
render_visibility
8 changes: 5 additions & 3 deletions docs/source/index.rst
@@ -90,16 +90,18 @@ An simple example is like this:
return rgbs, sigmas # (n_samples, 3), (n_samples, 1)
# Efficient Raymarching: Skip empty and occluded space, pack samples from all rays.
# packed_info: (n_rays, 2). t_starts: (n_samples, 1). t_ends: (n_samples, 1).
# ray_indices: (n_samples,). t_starts: (n_samples, 1). t_ends: (n_samples, 1).
with torch.no_grad():
packed_info, t_starts, t_ends = nerfacc.ray_marching(
ray_indices, t_starts, t_ends = nerfacc.ray_marching(
rays_o, rays_d, sigma_fn=sigma_fn, near_plane=0.2, far_plane=1.0,
early_stop_eps=1e-4, alpha_thre=1e-2,
)
# Differentiable Volumetric Rendering.
# colors: (n_rays, 3). opaicity: (n_rays, 1). depth: (n_rays, 1).
color, opacity, depth = nerfacc.rendering(rgb_sigma_fn, packed_info, t_starts, t_ends)
color, opacity, depth = nerfacc.rendering(
t_starts, t_ends, ray_indices, n_rays=rays_o.shape[0], rgb_sigma_fn=rgb_sigma_fn
)
# Optimize: Both the network and rays will receive gradients
optimizer.zero_grad()
5 changes: 3 additions & 2 deletions examples/utils.py
@@ -85,7 +85,7 @@ def rgb_sigma_fn(t_starts, t_ends, ray_indices):
)
for i in range(0, num_rays, chunk):
chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
packed_info, t_starts, t_ends = ray_marching(
ray_indices, t_starts, t_ends = ray_marching(
chunk_rays.origins,
chunk_rays.viewdirs,
scene_aabb=scene_aabb,
@@ -99,9 +99,10 @@ def rgb_sigma_fn(t_starts, t_ends, ray_indices):
alpha_thre=alpha_thre,
)
rgb, opacity, depth = rendering(
packed_info,
t_starts,
t_ends,
ray_indices,
n_rays=chunk_rays.origins.shape[0],
rgb_sigma_fn=rgb_sigma_fn,
render_bkgd=render_bkgd,
)
7 changes: 6 additions & 1 deletion nerfacc/__init__.py
@@ -8,11 +8,13 @@
from .grid import Grid, OccupancyGrid, query_grid
from .intersection import ray_aabb_intersect
from .losses import distortion as loss_distortion
from .pack import pack_data, unpack_data, unpack_info
from .pack import pack_data, pack_info, unpack_data, unpack_info
from .ray_marching import ray_marching
from .version import __version__
from .vol_rendering import (
accumulate_along_rays,
render_transmittance_from_alpha,
render_transmittance_from_density,
render_visibility,
render_weight_from_alpha,
render_weight_from_density,
@@ -48,7 +50,10 @@ def unpack_to_ray_indices(*args, **kwargs):
"pack_data",
"unpack_data",
"unpack_info",
"pack_info",
"ray_resampling",
"loss_distortion",
"unpack_to_ray_indices",
"render_transmittance_from_density",
"render_transmittance_from_alpha",
]
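pack_info is newly exported next to unpack_info, and the CUDA unpack_info now takes the sample count explicitly (see pack.cu below), which avoids a device sync just to recover n_samples. A round-trip sketch under the assumption that pack_info maps ray_indices to the (n_rays, 2) packed layout and unpack_info inverts it; the exact signatures are assumptions.

import torch
import nerfacc

# ray_indices: (n_samples,) int32 tensor mapping each sample to its ray, sorted by ray.
packed_info = nerfacc.pack_info(ray_indices, n_rays=n_rays)  # assumed -> (n_rays, 2)
ray_indices_back = nerfacc.unpack_info(packed_info, ray_indices.shape[0])
assert torch.equal(ray_indices_back, ray_indices)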
43 changes: 39 additions & 4 deletions nerfacc/cuda/__init__.py
@@ -25,10 +25,45 @@ def call_cuda(*args, **kwargs):
ray_marching = _make_lazy_cuda_func("ray_marching")
ray_resampling = _make_lazy_cuda_func("ray_resampling")

rendering_forward = _make_lazy_cuda_func("rendering_forward")
rendering_backward = _make_lazy_cuda_func("rendering_backward")
rendering_alphas_forward = _make_lazy_cuda_func("rendering_alphas_forward")
rendering_alphas_backward = _make_lazy_cuda_func("rendering_alphas_backward")
is_cub_available = _make_lazy_cuda_func("is_cub_available")
transmittance_from_sigma_forward_cub = _make_lazy_cuda_func(
"transmittance_from_sigma_forward_cub"
)
transmittance_from_sigma_backward_cub = _make_lazy_cuda_func(
"transmittance_from_sigma_backward_cub"
)
transmittance_from_alpha_forward_cub = _make_lazy_cuda_func(
"transmittance_from_alpha_forward_cub"
)
transmittance_from_alpha_backward_cub = _make_lazy_cuda_func(
"transmittance_from_alpha_backward_cub"
)

transmittance_from_sigma_forward_naive = _make_lazy_cuda_func(
"transmittance_from_sigma_forward_naive"
)
transmittance_from_sigma_backward_naive = _make_lazy_cuda_func(
"transmittance_from_sigma_backward_naive"
)
transmittance_from_alpha_forward_naive = _make_lazy_cuda_func(
"transmittance_from_alpha_forward_naive"
)
transmittance_from_alpha_backward_naive = _make_lazy_cuda_func(
"transmittance_from_alpha_backward_naive"
)

weight_from_sigma_forward_naive = _make_lazy_cuda_func(
"weight_from_sigma_forward_naive"
)
weight_from_sigma_backward_naive = _make_lazy_cuda_func(
"weight_from_sigma_backward_naive"
)
weight_from_alpha_forward_naive = _make_lazy_cuda_func(
"weight_from_alpha_forward_naive"
)
weight_from_alpha_backward_naive = _make_lazy_cuda_func(
"weight_from_alpha_backward_naive"
)

unpack_data = _make_lazy_cuda_func("unpack_data")
unpack_info = _make_lazy_cuda_func("unpack_info")
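Each transmittance pass now registers both a CUB-backed kernel and a naive fallback, plus an is_cub_available() probe, so the Python layer can choose the faster path at runtime. A sketch of that dispatch pattern; the wrapper name and argument layout are illustrative, not the library's actual code.

import nerfacc.cuda as _C

def transmittance_from_alpha_forward(packed_info, alphas):
    # Prefer the CUB-based scan (~10% faster end-to-end per the commit message)
    # when the toolkit's CUB is new enough; otherwise use the naive per-ray kernel.
    # Illustrative only: the cub and naive kernels may expect different packed
    # layouts (e.g. ray_indices vs. packed_info) in the real wrapper.
    if _C.is_cub_available():
        return _C.transmittance_from_alpha_forward_cub(packed_info, alphas)
    return _C.transmittance_from_alpha_forward_naive(packed_info, alphas)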
83 changes: 63 additions & 20 deletions nerfacc/cuda/_backend.py
@@ -3,9 +3,14 @@
"""

import glob
import json
import os
import shutil
import urllib.request
import zipfile
from subprocess import DEVNULL, call

from packaging import version
from rich.console import Console
from torch.utils.cpp_extension import _get_build_directory, load

@@ -21,32 +26,70 @@ def cuda_toolkit_available():
return False


def load_extention(name: str):
return load(
name=name,
sources=glob.glob(os.path.join(PATH, "csrc/*.cu")),
extra_cflags=["-O3"],
extra_cuda_cflags=["-O3"],
)
def cuda_toolkit_version():
"""Get the cuda toolkit version."""
cuda_home = os.path.join(os.path.dirname(shutil.which("nvcc")), "..")
if os.path.exists(os.path.join(cuda_home, "version.txt")):
with open(os.path.join(cuda_home, "version.txt")) as f:
cuda_version = f.read().strip().split()[-1]
elif os.path.exists(os.path.join(cuda_home, "version.json")):
with open(os.path.join(cuda_home, "version.json")) as f:
cuda_version = json.load(f)["cuda"]["version"]
else:
raise RuntimeError("Cannot find the cuda version.")
return cuda_version


_C = None
name = "nerfacc_cuda"
if os.listdir(_get_build_directory(name, verbose=False)) != []:
# If the build exists, we assume the extension has been built
# and we can load it.
_C = load_extention(name)
else:
# First time to build the extension
if cuda_toolkit_available():
build_dir = _get_build_directory(name, verbose=False)
extra_include_paths = []
extra_cflags = ["-O3"]
extra_cuda_cflags = ["-O3"]

_C = None
if cuda_toolkit_available():
# # we need cub >= 1.15.0 which is shipped with cuda >= 11.6, so download if
# # necessary. (compling does not garentee to success)
# if version.parse(cuda_toolkit_version()) < version.parse("11.6"):
# target_path = os.path.join(build_dir, "cub-1.17.0")
# if not os.path.exists(target_path):
# zip_path, _ = urllib.request.urlretrieve(
# "https://github.com/NVIDIA/cub/archive/1.17.0.tar.gz",
# os.path.join(build_dir, "cub-1.17.0.tar.gz"),
# )
# shutil.unpack_archive(zip_path, build_dir)
# extra_include_paths.append(target_path)
# extra_cuda_cflags.append("-DTHRUST_IGNORE_CUB_VERSION_CHECK")
# print(
# f"download cub because the cuda version is {cuda_toolkit_version()}"
# )

if os.path.exists(os.path.join(build_dir, f"{name}.so")):
# If the build exists, we assume the extension has been built
# and we can load it.
_C = load(
name=name,
sources=glob.glob(os.path.join(PATH, "csrc/*.cu")),
extra_cflags=extra_cflags,
extra_cuda_cflags=extra_cuda_cflags,
extra_include_paths=extra_include_paths,
)
else:
with Console().status(
"[bold yellow]NerfAcc: Setting up CUDA (This may take a few minutes the first time)",
spinner="bouncingBall",
):
_C = load_extention(name)
else:
Console().print(
"[yellow]NerfAcc: No CUDA toolkit found. NerfAcc will be disabled.[/yellow]"
)
_C = load(
name=name,
sources=glob.glob(os.path.join(PATH, "csrc/*.cu")),
extra_cflags=extra_cflags,
extra_cuda_cflags=extra_cuda_cflags,
extra_include_paths=extra_include_paths,
)
else:
Console().print(
"[yellow]NerfAcc: No CUDA toolkit found. NerfAcc will be disabled.[/yellow]"
)


__all__ = ["_C"]
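cuda_toolkit_version() reads version.txt or version.json next to nvcc, which is what makes toolkit gating possible: the commented-out block above would vendor CUB 1.17.0 whenever the toolkit is older than 11.6 (CUB >= 1.15.0 is required). A minimal sketch of that gate using packaging.version:

from packaging import version

if version.parse(cuda_toolkit_version()) < version.parse("11.6"):
    # The toolkit's bundled CUB is too old for the cub kernels; either vendor a
    # newer CUB (as the commented-out block sketches) or fall back to the naive kernels.
    print(f"CUDA toolkit {cuda_toolkit_version()} is older than 11.6; the cub path may be unavailable.")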
12 changes: 12 additions & 0 deletions nerfacc/cuda/csrc/include/helpers_cuda.h
@@ -6,6 +6,8 @@

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/cub_definitions.cuh>

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
@@ -20,3 +22,13 @@
#define CUDA_N_BLOCKS_NEEDED(Q, CUDA_N_THREADS) ((Q - 1) / CUDA_N_THREADS + 1)
#define DEVICE_GUARD(_ten) \
const at::cuda::OptionalCUDAGuard device_guard(device_of(_ten));

// https://github.com/pytorch/pytorch/blob/233305a852e1cd7f319b15b5137074c9eac455f6/aten/src/ATen/cuda/cub.cuh#L38-L46
#define CUB_WRAPPER(func, ...) do { \
size_t temp_storage_bytes = 0; \
func(nullptr, temp_storage_bytes, __VA_ARGS__); \
auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get(); \
auto temp_storage = caching_allocator.allocate(temp_storage_bytes); \
func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__); \
AT_CUDA_CHECK(cudaGetLastError()); \
} while (false)
4 changes: 2 additions & 2 deletions nerfacc/cuda/csrc/pack.cu
@@ -81,7 +81,7 @@ __global__ void unpack_data_kernel(
return;
}

torch::Tensor unpack_info(const torch::Tensor packed_info)
torch::Tensor unpack_info(const torch::Tensor packed_info, const int n_samples)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
@@ -90,7 +90,7 @@ torch::Tensor unpack_info(const torch::Tensor packed_info)
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);

int n_samples = packed_info[n_rays - 1].sum(0).item<int>();
// int n_samples = packed_info[n_rays - 1].sum(0).item<int>();
torch::Tensor ray_indices = torch::empty(
{n_samples}, packed_info.options().dtype(torch::kInt32));
