From 8e147800bc044822302299f0f43c28f18767fb48 Mon Sep 17 00:00:00 2001 From: niujinshuchong Date: Tue, 9 Jul 2024 15:03:06 +0000 Subject: [PATCH 1/5] init mip-splatting --- examples/datasets/colmap.py | 5 +- examples/simple_trainer_mip_splatting.py | 1069 ++++++++++++++++++++++ 2 files changed, 1073 insertions(+), 1 deletion(-) create mode 100644 examples/simple_trainer_mip_splatting.py diff --git a/examples/datasets/colmap.py b/examples/datasets/colmap.py index 22eacc2ac..f3fd8eb46 100644 --- a/examples/datasets/colmap.py +++ b/examples/datasets/colmap.py @@ -182,6 +182,7 @@ def __init__( self.image_names = image_names # List[str], (num_images,) self.image_paths = image_paths # List[str], (num_images,) self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4) + self.worldtocams = np.linalg.inv(camtoworlds) # np.ndarray, (num_images, 4, 4) self.camera_ids = camera_ids # List[int], (num_images,) self.Ks_dict = Ks_dict # Dict of camera_id -> K self.params_dict = params_dict # Dict of camera_id -> params @@ -254,7 +255,8 @@ def __getitem__(self, item: int) -> Dict[str, Any]: K = self.parser.Ks_dict[camera_id].copy() # undistorted K params = self.parser.params_dict[camera_id] camtoworlds = self.parser.camtoworlds[index] - + worldtocams = self.parser.worldtocams[index] + if len(params) > 0: # Images are distorted. Undistort them. mapx, mapy = ( @@ -277,6 +279,7 @@ def __getitem__(self, item: int) -> Dict[str, Any]: data = { "K": torch.from_numpy(K).float(), "camtoworld": torch.from_numpy(camtoworlds).float(), + "worldtocam": torch.from_numpy(worldtocams).float(), "image": torch.from_numpy(image).float(), "image_id": item, # the index of the image in the dataset } diff --git a/examples/simple_trainer_mip_splatting.py b/examples/simple_trainer_mip_splatting.py new file mode 100644 index 000000000..f49670c18 --- /dev/null +++ b/examples/simple_trainer_mip_splatting.py @@ -0,0 +1,1069 @@ +import json +import math +import os +import time +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple + +import imageio +import numpy as np +import torch +import torch.nn.functional as F +import tqdm +import tyro +import viser +import nerfview +from datasets.colmap import Dataset, Parser +from datasets.traj import generate_interpolated_path +from torch import Tensor +from torch.utils.tensorboard import SummaryWriter +from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +from utils import ( + AppearanceOptModule, + CameraOptModule, + knn, + normalized_quat_to_rotmat, + rgb_to_sh, + set_random_seed, +) + +from gsplat.rendering import rasterization + + +@dataclass +class Config: + # Disable viewer + disable_viewer: bool = False + # Path to the .pt file. If provide, it will skip training and render a video + ckpt: Optional[str] = None + + # Path to the Mip-NeRF 360 dataset + data_dir: str = "data/360_v2/garden" + # Downsample factor for the dataset + data_factor: int = 4 + # Directory to save results + result_dir: str = "results/garden" + # Every N images there is a test image + test_every: int = 8 + # Random crop size for training (experimental) + patch_size: Optional[int] = None + # A global scaler that applies to the scene size related parameters + global_scale: float = 1.0 + + # Port for the viewer server + port: int = 8080 + + # Batch size for training. Learning rates are scaled automatically + batch_size: int = 1 + # A global factor to scale the number of training steps + steps_scaler: float = 1.0 + + # Number of training steps + max_steps: int = 30_000 + # Steps to evaluate the model + eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Steps to save the model + save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + + # Initialization strategy + init_type: str = "sfm" + # Initial number of GSs. Ignored if using sfm + init_num_pts: int = 100_000 + # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm + init_extent: float = 3.0 + # Degree of spherical harmonics + sh_degree: int = 3 + # Turn on another SH degree every this steps + sh_degree_interval: int = 1000 + # Initial opacity of GS + init_opa: float = 0.1 + # Initial scale of GS + init_scale: float = 1.0 + # Weight for SSIM loss + ssim_lambda: float = 0.2 + + # Near plane clipping distance + near_plane: float = 0.01 + # Far plane clipping distance + far_plane: float = 1e10 + + # GSs with opacity below this value will be pruned + prune_opa: float = 0.005 + # GSs with image plane gradient above this value will be split/duplicated + grow_grad2d: float = 0.0002 + # GSs with scale below this value will be duplicated. Above will be split + grow_scale3d: float = 0.01 + # GSs with scale above this value will be pruned. + prune_scale3d: float = 0.1 + + # Start refining GSs after this iteration + refine_start_iter: int = 500 + # Stop refining GSs after this iteration + refine_stop_iter: int = 15_000 + # Reset opacities every this steps + reset_every: int = 3000 + # Refine GSs every this steps + refine_every: int = 100 + + # Use packed mode for rasterization, this leads to less memory usage but slightly slower. + packed: bool = False + # Use sparse gradients for optimization. (experimental) + sparse_grad: bool = False + # Use absolute gradient for pruning. This typically requires larger --grow_grad2d, e.g., 0.0008 or 0.0006 + absgrad: bool = False + # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics. + antialiased: bool = False + # kernel size for the low-pass filter in rasterization. 0.1 should be a better value since it better approximates a 2D box filter of single pixel size. + kernel_size: float = 0.3 + + # Use random background for training to discourage transparency + random_bkgd: bool = False + + # Enable camera optimization. + pose_opt: bool = False + # Learning rate for camera optimization + pose_opt_lr: float = 1e-5 + # Regularization for camera optimization as weight decay + pose_opt_reg: float = 1e-6 + # Add noise to camera extrinsics. This is only to test the camera pose optimization. + pose_noise: float = 0.0 + + # Enable appearance optimization. (experimental) + app_opt: bool = False + # Appearance embedding dimension + app_embed_dim: int = 16 + # Learning rate for appearance optimization + app_opt_lr: float = 1e-3 + # Regularization for appearance optimization as weight decay + app_opt_reg: float = 1e-6 + + # Enable depth loss. (experimental) + depth_loss: bool = False + # Weight for depth loss + depth_lambda: float = 1e-2 + + # Dump information to tensorboard every this steps + tb_every: int = 100 + # Save training images to tensorboard + tb_save_image: bool = False + + def adjust_steps(self, factor: float): + self.eval_steps = [int(i * factor) for i in self.eval_steps] + self.save_steps = [int(i * factor) for i in self.save_steps] + self.max_steps = int(self.max_steps * factor) + self.sh_degree_interval = int(self.sh_degree_interval * factor) + self.refine_start_iter = int(self.refine_start_iter * factor) + self.refine_stop_iter = int(self.refine_stop_iter * factor) + self.reset_every = int(self.reset_every * factor) + self.refine_every = int(self.refine_every * factor) + + +def create_splats_with_optimizers( + parser: Parser, + init_type: str = "sfm", + init_num_pts: int = 100_000, + init_extent: float = 3.0, + init_opacity: float = 0.1, + init_scale: float = 1.0, + scene_scale: float = 1.0, + sh_degree: int = 3, + sparse_grad: bool = False, + batch_size: int = 1, + feature_dim: Optional[int] = None, + device: str = "cuda", +) -> Tuple[torch.nn.ParameterDict, torch.optim.Optimizer]: + if init_type == "sfm": + points = torch.from_numpy(parser.points).float() + rgbs = torch.from_numpy(parser.points_rgb / 255.0).float() + elif init_type == "random": + points = init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1) + rgbs = torch.rand((init_num_pts, 3)) + else: + raise ValueError("Please specify a correct init_type: sfm or random") + + N = points.shape[0] + # Initialize the GS size to be the average dist of the 3 nearest neighbors + dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] + dist_avg = torch.sqrt(dist2_avg) + scales = torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3) # [N, 3] + quats = torch.rand((N, 4)) # [N, 4] + opacities = torch.logit(torch.full((N,), init_opacity)) # [N,] + + params = [ + # name, value, lr + ("means3d", torch.nn.Parameter(points), 1.6e-4 * scene_scale), + ("scales", torch.nn.Parameter(scales), 5e-3), + ("quats", torch.nn.Parameter(quats), 1e-3), + ("opacities", torch.nn.Parameter(opacities), 5e-2), + # 3D smoothing filter, setting lr to 0.0 to disable optimization + ("filters", torch.nn.Parameter(torch.ones_like(opacities)), 0.0), + ] + + if feature_dim is None: + # color is SH coefficients. + colors = torch.zeros((N, (sh_degree + 1) ** 2, 3)) # [N, K, 3] + colors[:, 0, :] = rgb_to_sh(rgbs) + params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), 2.5e-3)) + params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), 2.5e-3 / 20)) + else: + # features will be used for appearance and view-dependent shading + features = torch.rand(N, feature_dim) # [N, feature_dim] + params.append(("features", torch.nn.Parameter(features), 2.5e-3)) + colors = torch.logit(rgbs) # [N, 3] + params.append(("colors", torch.nn.Parameter(colors), 2.5e-3)) + + splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device) + # Scale learning rate based on batch size, reference: + # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ + # Note that this would not make the training exactly equivalent, see + # https://arxiv.org/pdf/2402.18824v1 + optimizers = [ + (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)( + [{"params": splats[name], "lr": lr * math.sqrt(batch_size), "name": name}], + eps=1e-15 / math.sqrt(batch_size), + betas=(1 - batch_size * (1 - 0.9), 1 - batch_size * (1 - 0.999)), + ) + for name, _, lr in params + ] + return splats, optimizers + + +class Runner: + """Engine for training and testing.""" + + def __init__(self, cfg: Config) -> None: + set_random_seed(42) + + self.cfg = cfg + self.device = "cuda" + + # Where to dump results. + os.makedirs(cfg.result_dir, exist_ok=True) + + # Setup output directories. + self.ckpt_dir = f"{cfg.result_dir}/ckpts" + os.makedirs(self.ckpt_dir, exist_ok=True) + self.stats_dir = f"{cfg.result_dir}/stats" + os.makedirs(self.stats_dir, exist_ok=True) + self.render_dir = f"{cfg.result_dir}/renders" + os.makedirs(self.render_dir, exist_ok=True) + + # Tensorboard + self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb") + + # Load data: Training data should contain initial points and colors. + self.parser = Parser( + data_dir=cfg.data_dir, + factor=cfg.data_factor, + normalize=True, + test_every=cfg.test_every, + ) + self.trainset = Dataset( + self.parser, + split="train", + patch_size=cfg.patch_size, + load_depths=cfg.depth_loss, + ) + self.valset = Dataset(self.parser, split="val") + self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale + print("Scene scale:", self.scene_scale) + + # Model + feature_dim = 32 if cfg.app_opt else None + self.splats, self.optimizers = create_splats_with_optimizers( + self.parser, + init_type=cfg.init_type, + init_num_pts=cfg.init_num_pts, + init_extent=cfg.init_extent, + init_opacity=cfg.init_opa, + init_scale=cfg.init_scale, + scene_scale=self.scene_scale, + sh_degree=cfg.sh_degree, + sparse_grad=cfg.sparse_grad, + batch_size=cfg.batch_size, + feature_dim=feature_dim, + device=self.device, + ) + print("Model initialized. Number of GS:", len(self.splats["means3d"])) + + self.pose_optimizers = [] + if cfg.pose_opt: + self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_adjust.zero_init() + self.pose_optimizers = [ + torch.optim.Adam( + self.pose_adjust.parameters(), + lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size), + weight_decay=cfg.pose_opt_reg, + ) + ] + + if cfg.pose_noise > 0.0: + self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_perturb.random_init(cfg.pose_noise) + + self.app_optimizers = [] + if cfg.app_opt: + self.app_module = AppearanceOptModule( + len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree + ).to(self.device) + # initialize the last layer to be zero so that the initial output is zero. + torch.nn.init.zeros_(self.app_module.color_head[-1].weight) + torch.nn.init.zeros_(self.app_module.color_head[-1].bias) + self.app_optimizers = [ + torch.optim.Adam( + self.app_module.embeds.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0, + weight_decay=cfg.app_opt_reg, + ), + torch.optim.Adam( + self.app_module.color_head.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size), + ), + ] + + # Losses & Metrics. + self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device) + self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device) + self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to( + self.device + ) + + # Viewer + if not self.cfg.disable_viewer: + self.server = viser.ViserServer(port=cfg.port, verbose=False) + self.viewer = nerfview.Viewer( + server=self.server, + render_fn=self._viewer_render_fn, + mode="training", + ) + + # Running stats for prunning & growing. + n_gauss = len(self.splats["means3d"]) + self.running_stats = { + "grad2d": torch.zeros(n_gauss, device=self.device), # norm of the gradient + "count": torch.zeros(n_gauss, device=self.device, dtype=torch.int), + } + + def rasterize_splats( + self, + camtoworlds: Tensor, + Ks: Tensor, + width: int, + height: int, + **kwargs, + ) -> Tuple[Tensor, Tensor, Dict]: + means = self.splats["means3d"] # [N, 3] + # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] + # rasterization does normalization internally + quats = self.splats["quats"] # [N, 4] + scales = torch.exp(self.splats["scales"]) # [N, 3] + opacities = torch.sigmoid(self.splats["opacities"]) # [N,] + filters = self.splats["filters"] # [N,] + + # apply 3D smoothing filter to scales and opacities + scales_square = torch.square(scales) # [N, 3] + det1 = scales_square.prod(dim=1) # [N, ] + + scales_after_square = scales_square + torch.square(filters) [:, None] # [N, 1] + det2 = scales_after_square.prod(dim=1) # [N,] + coef = torch.sqrt(det1 / det2) # [N,] + opacities = opacities * coef + + scales = torch.square(scales) + torch.square(filters)[:, None] # [N, 3] + scales = torch.sqrt(scales) + + image_ids = kwargs.pop("image_ids", None) + if self.cfg.app_opt: + colors = self.app_module( + features=self.splats["features"], + embed_ids=image_ids, + dirs=means[None, :, :] - camtoworlds[:, None, :3, 3], + sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree), + ) + colors = colors + self.splats["colors"] + colors = torch.sigmoid(colors) + else: + colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] + + rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" + render_colors, render_alphas, info = rasterization( + means=means, + quats=quats, + scales=scales, + opacities=opacities, + colors=colors, + viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] + Ks=Ks, # [C, 3, 3] + width=width, + height=height, + packed=self.cfg.packed, + absgrad=self.cfg.absgrad, + sparse_grad=self.cfg.sparse_grad, + rasterize_mode=rasterize_mode, + eps2d=self.cfg.kernel_size, + **kwargs, + ) + return render_colors, render_alphas, info + + def train(self): + cfg = self.cfg + device = self.device + + # Dump cfg. + with open(f"{cfg.result_dir}/cfg.json", "w") as f: + json.dump(vars(cfg), f) + + max_steps = cfg.max_steps + init_step = 0 + + schedulers = [ + # means3d has a learning rate schedule, that end at 0.01 of the initial value + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ), + ] + if cfg.pose_opt: + # pose optimization has a learning rate schedule + schedulers.append( + torch.optim.lr_scheduler.ExponentialLR( + self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ) + ) + + trainloader = torch.utils.data.DataLoader( + self.trainset, + batch_size=cfg.batch_size, + shuffle=True, + num_workers=4, + persistent_workers=True, + pin_memory=True, + ) + trainloader_iter = iter(trainloader) + + # determine the 3D smoothing filter before training + self.compute_3D_smoothing_filter() + + # Training loop. + global_tic = time.time() + pbar = tqdm.tqdm(range(init_step, max_steps)) + for step in pbar: + if not cfg.disable_viewer: + while self.viewer.state.status == "paused": + time.sleep(0.01) + self.viewer.lock.acquire() + tic = time.time() + + try: + data = next(trainloader_iter) + except StopIteration: + trainloader_iter = iter(trainloader) + data = next(trainloader_iter) + + camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) # [1, 4, 4] + Ks = data["K"].to(device) # [1, 3, 3] + pixels = data["image"].to(device) / 255.0 # [1, H, W, 3] + num_train_rays_per_step = ( + pixels.shape[0] * pixels.shape[1] * pixels.shape[2] + ) + image_ids = data["image_id"].to(device) + if cfg.depth_loss: + points = data["points"].to(device) # [1, M, 2] + depths_gt = data["depths"].to(device) # [1, M] + + height, width = pixels.shape[1:3] + + if cfg.pose_noise: + camtoworlds = self.pose_perturb(camtoworlds, image_ids) + + if cfg.pose_opt: + camtoworlds = self.pose_adjust(camtoworlds, image_ids) + + # sh schedule + sh_degree_to_use = min(step // cfg.sh_degree_interval, cfg.sh_degree) + + # forward + renders, alphas, info = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=sh_degree_to_use, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + image_ids=image_ids, + render_mode="RGB+ED" if cfg.depth_loss else "RGB", + ) + if renders.shape[-1] == 4: + colors, depths = renders[..., 0:3], renders[..., 3:4] + else: + colors, depths = renders, None + + if cfg.random_bkgd: + bkgd = torch.rand(1, 3, device=device) + colors = colors + bkgd * (1.0 - alphas) + + info["means2d"].retain_grad() # used for running stats + + # loss + l1loss = F.l1_loss(colors, pixels) + ssimloss = 1.0 - self.ssim( + pixels.permute(0, 3, 1, 2), colors.permute(0, 3, 1, 2) + ) + loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda + if cfg.depth_loss: + # query depths from depth map + points = torch.stack( + [ + points[:, :, 0] / (width - 1) * 2 - 1, + points[:, :, 1] / (height - 1) * 2 - 1, + ], + dim=-1, + ) # normalize to [-1, 1] + grid = points.unsqueeze(2) # [1, M, 1, 2] + depths = F.grid_sample( + depths.permute(0, 3, 1, 2), grid, align_corners=True + ) # [1, 1, M, 1] + depths = depths.squeeze(3).squeeze(1) # [1, M] + # calculate loss in disparity space + disp = torch.where(depths > 0.0, 1.0 / depths, torch.zeros_like(depths)) + disp_gt = 1.0 / depths_gt # [1, M] + depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale + loss += depthloss * cfg.depth_lambda + + loss.backward() + + desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " + if cfg.depth_loss: + desc += f"depth loss={depthloss.item():.6f}| " + if cfg.pose_opt and cfg.pose_noise: + # monitor the pose error if we inject noise + pose_err = F.l1_loss(camtoworlds_gt, camtoworlds) + desc += f"pose err={pose_err.item():.6f}| " + pbar.set_description(desc) + + if cfg.tb_every > 0 and step % cfg.tb_every == 0: + mem = torch.cuda.max_memory_allocated() / 1024**3 + self.writer.add_scalar("train/loss", loss.item(), step) + self.writer.add_scalar("train/l1loss", l1loss.item(), step) + self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) + self.writer.add_scalar( + "train/num_GS", len(self.splats["means3d"]), step + ) + self.writer.add_scalar("train/mem", mem, step) + if cfg.depth_loss: + self.writer.add_scalar("train/depthloss", depthloss.item(), step) + if cfg.tb_save_image: + canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() + canvas = canvas.reshape(-1, *canvas.shape[2:]) + self.writer.add_image("train/render", canvas, step) + self.writer.flush() + + # update running stats for prunning & growing + if step < cfg.refine_stop_iter: + self.update_running_stats(info) + + if step > cfg.refine_start_iter and step % cfg.refine_every == 0: + grads = self.running_stats["grad2d"] / self.running_stats[ + "count" + ].clamp_min(1) + + # grow GSs + is_grad_high = grads >= cfg.grow_grad2d + is_small = ( + torch.exp(self.splats["scales"]).max(dim=-1).values + <= cfg.grow_scale3d * self.scene_scale + ) + is_dupli = is_grad_high & is_small + n_dupli = is_dupli.sum().item() + self.refine_duplicate(is_dupli) + + is_split = is_grad_high & ~is_small + is_split = torch.cat( + [ + is_split, + # new GSs added by duplication will not be split + torch.zeros(n_dupli, device=device, dtype=torch.bool), + ] + ) + n_split = is_split.sum().item() + self.refine_split(is_split) + print( + f"Step {step}: {n_dupli} GSs duplicated, {n_split} GSs split. " + f"Now having {len(self.splats['means3d'])} GSs." + ) + + # prune GSs + is_prune = torch.sigmoid(self.splats["opacities"]) < cfg.prune_opa + if step > cfg.reset_every: + # The official code also implements sreen-size pruning but + # it's actually not being used due to a bug: + # https://github.com/graphdeco-inria/gaussian-splatting/issues/123 + is_too_big = ( + torch.exp(self.splats["scales"]).max(dim=-1).values + > cfg.prune_scale3d * self.scene_scale + ) + is_prune = is_prune | is_too_big + n_prune = is_prune.sum().item() + self.refine_keep(~is_prune) + print( + f"Step {step}: {n_prune} GSs pruned. " + f"Now having {len(self.splats['means3d'])} GSs." + ) + self.compute_3D_smoothing_filter() + + # reset running stats + self.running_stats["grad2d"].zero_() + self.running_stats["count"].zero_() + + if step % cfg.reset_every == 0: + self.reset_opa(cfg.prune_opa * 2.0) + + # Turn Gradients into Sparse Tensor before running optimizer + if cfg.sparse_grad: + assert cfg.packed, "Sparse gradients only work with packed mode." + gaussian_ids = info["gaussian_ids"] + for k in self.splats.keys(): + grad = self.splats[k].grad + if grad is None or grad.is_sparse: + continue + self.splats[k].grad = torch.sparse_coo_tensor( + indices=gaussian_ids[None], # [1, nnz] + values=grad[gaussian_ids], # [nnz, ...] + size=self.splats[k].size(), # [N, ...] + is_coalesced=len(Ks) == 1, + ) + + # optimize + for optimizer in self.optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.pose_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.app_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for scheduler in schedulers: + scheduler.step() + + # save checkpoint + if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: + mem = torch.cuda.max_memory_allocated() / 1024**3 + stats = { + "mem": mem, + "ellipse_time": time.time() - global_tic, + "num_GS": len(self.splats["means3d"]), + } + print("Step: ", step, stats) + with open(f"{self.stats_dir}/train_step{step:04d}.json", "w") as f: + json.dump(stats, f) + torch.save( + { + "step": step, + "splats": self.splats.state_dict(), + }, + f"{self.ckpt_dir}/ckpt_{step}.pt", + ) + + # eval the full set + if step in [i - 1 for i in cfg.eval_steps] or step == max_steps - 1: + self.eval(step) + self.render_traj(step) + + if not cfg.disable_viewer: + self.viewer.lock.release() + num_train_steps_per_sec = 1.0 / (time.time() - tic) + num_train_rays_per_sec = ( + num_train_rays_per_step * num_train_steps_per_sec + ) + # Update the viewer state. + self.viewer.state.num_train_rays_per_sec = num_train_rays_per_sec + # Update the scene. + self.viewer.update(step, num_train_rays_per_step) + + @torch.no_grad() + def compute_3D_smoothing_filter(self): + print("Computing 3D filter") + cfg = self.cfg + device = self.device + xyz = self.splats["means3d"] + print("xyz", xyz.shape, xyz.device) + + distance = torch.ones((xyz.shape[0]), device=xyz.device) * 100000.0 + valid_points = torch.zeros((xyz.shape[0]), device=xyz.device, dtype=torch.bool) + focal_length = 0. + + for data in self.trainset: + worldtocam = data["worldtocam"].to(device) # [4, 4] + K = data["K"].to(device) # [3, 3] + height, width = data["image"].shape[:2] + R = worldtocam[:3, :3] + T = worldtocam[:3, 3] + + xyz_cam = xyz @ R.transpose(1, 0) + T[None, :] + + # project to screen space + valid_depth = xyz_cam[:, 2] > cfg.near_plane + + x, y, z = xyz_cam[:, 0], xyz_cam[:, 1], xyz_cam[:, 2] + z = torch.clamp(z, min=0.001) + + x = x / z * K[0, 0] + K[0, 2] + y = y / z * K[1, 1] + K[1, 2] + + # use similar tangent space filtering as in 3DGS, + # TODO check gsplat's implementation + in_screen = torch.logical_and(torch.logical_and(x >= -0.15 * width, x <= width * 1.15), torch.logical_and(y >= -0.15 * height, y <= 1.15 * height)) + valid = torch.logical_and(valid_depth, in_screen) + + distance[valid] = torch.min(distance[valid], z[valid]) + valid_points = torch.logical_or(valid_points, valid) + if focal_length < K[0, 0]: + focal_length = K[0, 0] + + distance[~valid_points] = distance[valid_points].max() + + filter_3D = distance / focal_length * (0.2 ** 0.5) + self.splats["filters"] = torch.nn.Parameter(filter_3D) + + @torch.no_grad() + def update_running_stats(self, info: Dict): + """Update running stats.""" + cfg = self.cfg + + # normalize grads to [-1, 1] screen space + if cfg.absgrad: + grads = info["means2d"].absgrad.clone() + else: + grads = info["means2d"].grad.clone() + grads[..., 0] *= info["width"] / 2.0 * cfg.batch_size + grads[..., 1] *= info["height"] / 2.0 * cfg.batch_size + if cfg.packed: + # grads is [nnz, 2] + gs_ids = info["gaussian_ids"] # [nnz] or None + self.running_stats["grad2d"].index_add_(0, gs_ids, grads.norm(dim=-1)) + self.running_stats["count"].index_add_( + 0, gs_ids, torch.ones_like(gs_ids).int() + ) + else: + # grads is [C, N, 2] + sel = info["radii"] > 0.0 # [C, N] + gs_ids = torch.where(sel)[1] # [nnz] + self.running_stats["grad2d"].index_add_(0, gs_ids, grads[sel].norm(dim=-1)) + self.running_stats["count"].index_add_( + 0, gs_ids, torch.ones_like(gs_ids).int() + ) + + @torch.no_grad() + def reset_opa(self, value: float = 0.01): + """Utility function to reset opacities.""" + # opacities = torch.clamp( + # self.splats["opacities"], max=torch.logit(torch.tensor(value)).item() + # ) + + scales = torch.exp(self.splats["scales"]) # [N, 3] + opacities = torch.sigmoid(self.splats["opacities"]) # [N,] + filters = self.splats["filters"] # [N,] + + # apply 3D smoothing filter to scales and opacities + print("apply 3D smoothing filter in reset opacities") + scales_square = torch.square(scales) # [N, 3] + det1 = scales_square.prod(dim=1) # [N, ] + + scales_after_square = scales_square + torch.square(filters) [:, None] # [N, 1] + det2 = scales_after_square.prod(dim=1) # [N,] + coef = torch.sqrt(det1 / det2) # [N,] + opacities = opacities * coef + + opacities_reset = torch.min(opacities, torch.ones_like(opacities)*value) + opacities_reset = opacities_reset / coef + opacities = torch.logit(opacities_reset) + + for optimizer in self.optimizers: + for i, param_group in enumerate(optimizer.param_groups): + if param_group["name"] != "opacities": + continue + p = param_group["params"][0] + p_state = optimizer.state[p] + del optimizer.state[p] + for key in p_state.keys(): + if key != "step": + p_state[key] = torch.zeros_like(p_state[key]) + p_new = torch.nn.Parameter(opacities) + optimizer.param_groups[i]["params"] = [p_new] + optimizer.state[p_new] = p_state + self.splats[param_group["name"]] = p_new + torch.cuda.empty_cache() + + @torch.no_grad() + def refine_split(self, mask: Tensor): + """Utility function to grow GSs.""" + device = self.device + + sel = torch.where(mask)[0] + rest = torch.where(~mask)[0] + + scales = torch.exp(self.splats["scales"][sel]) # [N, 3] + quats = F.normalize(self.splats["quats"][sel], dim=-1) # [N, 4] + rotmats = normalized_quat_to_rotmat(quats) # [N, 3, 3] + samples = torch.einsum( + "nij,nj,bnj->bni", + rotmats, + scales, + torch.randn(2, len(scales), 3, device=device), + ) # [2, N, 3] + + for optimizer in self.optimizers: + for i, param_group in enumerate(optimizer.param_groups): + p = param_group["params"][0] + name = param_group["name"] + # create new params + if name == "means3d": + p_split = (p[sel] + samples).reshape(-1, 3) # [2N, 3] + elif name == "scales": + p_split = torch.log(scales / 1.6).repeat(2, 1) # [2N, 3] + else: + repeats = [2] + [1] * (p.dim() - 1) + p_split = p[sel].repeat(repeats) + p_new = torch.cat([p[rest], p_split]) + p_new = torch.nn.Parameter(p_new) + # update optimizer + p_state = optimizer.state[p] + del optimizer.state[p] + for key in p_state.keys(): + if key == "step": + continue + v = p_state[key] + # new params are assigned with zero optimizer states + # (worth investigating it) + v_split = torch.zeros((2 * len(sel), *v.shape[1:]), device=device) + p_state[key] = torch.cat([v[rest], v_split]) + optimizer.param_groups[i]["params"] = [p_new] + optimizer.state[p_new] = p_state + self.splats[name] = p_new + for k, v in self.running_stats.items(): + if v is None: + continue + repeats = [2] + [1] * (v.dim() - 1) + v_new = v[sel].repeat(repeats) + self.running_stats[k] = torch.cat((v[rest], v_new)) + torch.cuda.empty_cache() + + @torch.no_grad() + def refine_duplicate(self, mask: Tensor): + """Unility function to duplicate GSs.""" + sel = torch.where(mask)[0] + for optimizer in self.optimizers: + for i, param_group in enumerate(optimizer.param_groups): + p = param_group["params"][0] + name = param_group["name"] + p_state = optimizer.state[p] + del optimizer.state[p] + for key in p_state.keys(): + if key != "step": + # new params are assigned with zero optimizer states + # (worth investigating it as it will lead to a lot more GS.) + v = p_state[key] + v_new = torch.zeros( + (len(sel), *v.shape[1:]), device=self.device + ) + # v_new = v[sel] + p_state[key] = torch.cat([v, v_new]) + p_new = torch.nn.Parameter(torch.cat([p, p[sel]])) + optimizer.param_groups[i]["params"] = [p_new] + optimizer.state[p_new] = p_state + self.splats[name] = p_new + for k, v in self.running_stats.items(): + self.running_stats[k] = torch.cat((v, v[sel])) + torch.cuda.empty_cache() + + @torch.no_grad() + def refine_keep(self, mask: Tensor): + """Unility function to prune GSs.""" + sel = torch.where(mask)[0] + for optimizer in self.optimizers: + for i, param_group in enumerate(optimizer.param_groups): + p = param_group["params"][0] + name = param_group["name"] + p_state = optimizer.state[p] + del optimizer.state[p] + for key in p_state.keys(): + if key != "step": + p_state[key] = p_state[key][sel] + p_new = torch.nn.Parameter(p[sel]) + optimizer.param_groups[i]["params"] = [p_new] + optimizer.state[p_new] = p_state + self.splats[name] = p_new + for k, v in self.running_stats.items(): + self.running_stats[k] = v[sel] + torch.cuda.empty_cache() + + @torch.no_grad() + def eval(self, step: int): + """Entry for evaluation.""" + print("Running evaluation...") + cfg = self.cfg + device = self.device + + valloader = torch.utils.data.DataLoader( + self.valset, batch_size=1, shuffle=False, num_workers=1 + ) + ellipse_time = 0 + metrics = {"psnr": [], "ssim": [], "lpips": []} + for i, data in enumerate(valloader): + camtoworlds = data["camtoworld"].to(device) + Ks = data["K"].to(device) + pixels = data["image"].to(device) / 255.0 + height, width = pixels.shape[1:3] + + torch.cuda.synchronize() + tic = time.time() + colors, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + ) # [1, H, W, 3] + colors = torch.clamp(colors, 0.0, 1.0) + torch.cuda.synchronize() + ellipse_time += time.time() - tic + + # write images + canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy() + imageio.imwrite( + f"{self.render_dir}/val_{i:04d}.png", (canvas * 255).astype(np.uint8) + ) + + pixels = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] + colors = colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["psnr"].append(self.psnr(colors, pixels)) + metrics["ssim"].append(self.ssim(colors, pixels)) + metrics["lpips"].append(self.lpips(colors, pixels)) + + ellipse_time /= len(valloader) + + psnr = torch.stack(metrics["psnr"]).mean() + ssim = torch.stack(metrics["ssim"]).mean() + lpips = torch.stack(metrics["lpips"]).mean() + print( + f"PSNR: {psnr.item():.3f}, SSIM: {ssim.item():.4f}, LPIPS: {lpips.item():.3f} " + f"Time: {ellipse_time:.3f}s/image " + f"Number of GS: {len(self.splats['means3d'])}" + ) + # save stats as json + stats = { + "psnr": psnr.item(), + "ssim": ssim.item(), + "lpips": lpips.item(), + "ellipse_time": ellipse_time, + "num_GS": len(self.splats["means3d"]), + } + with open(f"{self.stats_dir}/val_step{step:04d}.json", "w") as f: + json.dump(stats, f) + # save stats to tensorboard + for k, v in stats.items(): + self.writer.add_scalar(f"val/{k}", v, step) + self.writer.flush() + + @torch.no_grad() + def render_traj(self, step: int): + """Entry for trajectory rendering.""" + print("Running trajectory rendering...") + cfg = self.cfg + device = self.device + + camtoworlds = self.parser.camtoworlds[5:-5] + camtoworlds = generate_interpolated_path(camtoworlds, 1) # [N, 3, 4] + camtoworlds = np.concatenate( + [ + camtoworlds, + np.repeat(np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds), axis=0), + ], + axis=1, + ) # [N, 4, 4] + + camtoworlds = torch.from_numpy(camtoworlds).float().to(device) + K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device) + width, height = list(self.parser.imsize_dict.values())[0] + + canvas_all = [] + for i in tqdm.trange(len(camtoworlds), desc="Rendering trajectory"): + renders, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds[i : i + 1], + Ks=K[None], + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + render_mode="RGB+ED", + ) # [1, H, W, 4] + colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0) # [H, W, 3] + depths = renders[0, ..., 3:4] # [H, W, 1] + depths = (depths - depths.min()) / (depths.max() - depths.min()) + + # write images + canvas = torch.cat( + [colors, depths.repeat(1, 1, 3)], dim=0 if width > height else 1 + ) + canvas = (canvas.cpu().numpy() * 255).astype(np.uint8) + canvas_all.append(canvas) + + # save to video + video_dir = f"{cfg.result_dir}/videos" + os.makedirs(video_dir, exist_ok=True) + writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30) + for canvas in canvas_all: + writer.append_data(canvas) + writer.close() + print(f"Video saved to {video_dir}/traj_{step}.mp4") + + @torch.no_grad() + def _viewer_render_fn( + self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int] + ): + """Callable function for the viewer.""" + W, H = img_wh + c2w = camera_state.c2w + K = camera_state.get_K(img_wh) + c2w = torch.from_numpy(c2w).float().to(self.device) + K = torch.from_numpy(K).float().to(self.device) + + render_colors, _, _ = self.rasterize_splats( + camtoworlds=c2w[None], + Ks=K[None], + width=W, + height=H, + sh_degree=self.cfg.sh_degree, # active all SH degrees + radius_clip=3.0, # skip GSs that have small image radius (in pixels) + ) # [1, H, W, 3] + return render_colors[0].cpu().numpy() + + +def main(cfg: Config): + runner = Runner(cfg) + + if cfg.ckpt is not None: + # run eval only + ckpt = torch.load(cfg.ckpt, map_location=runner.device) + for k in runner.splats.keys(): + runner.splats[k].data = ckpt["splats"][k] + runner.eval(step=ckpt["step"]) + runner.render_traj(step=ckpt["step"]) + else: + runner.train() + + if not cfg.disable_viewer: + print("Viewer running... Ctrl+C to exit.") + time.sleep(1000000) + + +if __name__ == "__main__": + cfg = tyro.cli(Config) + cfg.adjust_steps(cfg.steps_scaler) + main(cfg) From f3599574a2277ae76a7a15672203aa8ba0e7c5b9 Mon Sep 17 00:00:00 2001 From: niujinshuchong Date: Wed, 10 Jul 2024 12:06:55 +0000 Subject: [PATCH 2/5] format --- examples/datasets/colmap.py | 2 +- examples/datasets/download_dataset.py | 4 +- examples/simple_trainer_mip_splatting.py | 77 ++++++++++++------------ 3 files changed, 42 insertions(+), 41 deletions(-) diff --git a/examples/datasets/colmap.py b/examples/datasets/colmap.py index f3fd8eb46..9019c82bb 100644 --- a/examples/datasets/colmap.py +++ b/examples/datasets/colmap.py @@ -256,7 +256,7 @@ def __getitem__(self, item: int) -> Dict[str, Any]: params = self.parser.params_dict[camera_id] camtoworlds = self.parser.camtoworlds[index] worldtocams = self.parser.worldtocams[index] - + if len(params) > 0: # Images are distorted. Undistort them. mapx, mapy = ( diff --git a/examples/datasets/download_dataset.py b/examples/datasets/download_dataset.py index 8366ae979..5cdb99ffd 100755 --- a/examples/datasets/download_dataset.py +++ b/examples/datasets/download_dataset.py @@ -9,9 +9,7 @@ import tyro # dataset names -dataset_names = Literal[ - "mipnerf360", -] +dataset_names = Literal["mipnerf360",] # dataset urls urls = {"mipnerf360": "http://storage.googleapis.com/gresearch/refraw360/360_v2.zip"} diff --git a/examples/simple_trainer_mip_splatting.py b/examples/simple_trainer_mip_splatting.py index f49670c18..27eb87188 100644 --- a/examples/simple_trainer_mip_splatting.py +++ b/examples/simple_trainer_mip_splatting.py @@ -361,19 +361,19 @@ def rasterize_splats( scales = torch.exp(self.splats["scales"]) # [N, 3] opacities = torch.sigmoid(self.splats["opacities"]) # [N,] filters = self.splats["filters"] # [N,] - + # apply 3D smoothing filter to scales and opacities - scales_square = torch.square(scales) # [N, 3] - det1 = scales_square.prod(dim=1) # [N, ] - - scales_after_square = scales_square + torch.square(filters) [:, None] # [N, 1] - det2 = scales_after_square.prod(dim=1) # [N,] - coef = torch.sqrt(det1 / det2) # [N,] + scales_square = torch.square(scales) # [N, 3] + det1 = scales_square.prod(dim=1) # [N, ] + + scales_after_square = scales_square + torch.square(filters)[:, None] # [N, 1] + det2 = scales_after_square.prod(dim=1) # [N,] + coef = torch.sqrt(det1 / det2) # [N,] opacities = opacities * coef - + scales = torch.square(scales) + torch.square(filters)[:, None] # [N, 3] scales = torch.sqrt(scales) - + image_ids = kwargs.pop("image_ids", None) if self.cfg.app_opt: colors = self.app_module( @@ -444,7 +444,7 @@ def train(self): # determine the 3D smoothing filter before training self.compute_3D_smoothing_filter() - + # Training loop. global_tic = time.time() pbar = tqdm.tqdm(range(init_step, max_steps)) @@ -690,44 +690,47 @@ def compute_3D_smoothing_filter(self): device = self.device xyz = self.splats["means3d"] print("xyz", xyz.shape, xyz.device) - + distance = torch.ones((xyz.shape[0]), device=xyz.device) * 100000.0 valid_points = torch.zeros((xyz.shape[0]), device=xyz.device, dtype=torch.bool) - focal_length = 0. - + focal_length = 0.0 + for data in self.trainset: - worldtocam = data["worldtocam"].to(device) # [4, 4] + worldtocam = data["worldtocam"].to(device) # [4, 4] K = data["K"].to(device) # [3, 3] height, width = data["image"].shape[:2] R = worldtocam[:3, :3] T = worldtocam[:3, 3] - + xyz_cam = xyz @ R.transpose(1, 0) + T[None, :] - + # project to screen space valid_depth = xyz_cam[:, 2] > cfg.near_plane - + x, y, z = xyz_cam[:, 0], xyz_cam[:, 1], xyz_cam[:, 2] z = torch.clamp(z, min=0.001) - + x = x / z * K[0, 0] + K[0, 2] y = y / z * K[1, 1] + K[1, 2] - - # use similar tangent space filtering as in 3DGS, + + # use similar tangent space filtering as in 3DGS, # TODO check gsplat's implementation - in_screen = torch.logical_and(torch.logical_and(x >= -0.15 * width, x <= width * 1.15), torch.logical_and(y >= -0.15 * height, y <= 1.15 * height)) + in_screen = torch.logical_and( + torch.logical_and(x >= -0.15 * width, x <= width * 1.15), + torch.logical_and(y >= -0.15 * height, y <= 1.15 * height), + ) valid = torch.logical_and(valid_depth, in_screen) - + distance[valid] = torch.min(distance[valid], z[valid]) valid_points = torch.logical_or(valid_points, valid) if focal_length < K[0, 0]: focal_length = K[0, 0] - + distance[~valid_points] = distance[valid_points].max() - - filter_3D = distance / focal_length * (0.2 ** 0.5) + + filter_3D = distance / focal_length * (0.2**0.5) self.splats["filters"] = torch.nn.Parameter(filter_3D) - + @torch.no_grad() def update_running_stats(self, info: Dict): """Update running stats.""" @@ -762,25 +765,25 @@ def reset_opa(self, value: float = 0.01): # opacities = torch.clamp( # self.splats["opacities"], max=torch.logit(torch.tensor(value)).item() # ) - + scales = torch.exp(self.splats["scales"]) # [N, 3] opacities = torch.sigmoid(self.splats["opacities"]) # [N,] filters = self.splats["filters"] # [N,] - + # apply 3D smoothing filter to scales and opacities print("apply 3D smoothing filter in reset opacities") - scales_square = torch.square(scales) # [N, 3] - det1 = scales_square.prod(dim=1) # [N, ] - - scales_after_square = scales_square + torch.square(filters) [:, None] # [N, 1] - det2 = scales_after_square.prod(dim=1) # [N,] - coef = torch.sqrt(det1 / det2) # [N,] + scales_square = torch.square(scales) # [N, 3] + det1 = scales_square.prod(dim=1) # [N, ] + + scales_after_square = scales_square + torch.square(filters)[:, None] # [N, 1] + det2 = scales_after_square.prod(dim=1) # [N,] + coef = torch.sqrt(det1 / det2) # [N,] opacities = opacities * coef - - opacities_reset = torch.min(opacities, torch.ones_like(opacities)*value) + + opacities_reset = torch.min(opacities, torch.ones_like(opacities) * value) opacities_reset = opacities_reset / coef opacities = torch.logit(opacities_reset) - + for optimizer in self.optimizers: for i, param_group in enumerate(optimizer.param_groups): if param_group["name"] != "opacities": From 11990317087253a0d4c211a46f8ee1fa004b7190 Mon Sep 17 00:00:00 2001 From: niujinshuchong Date: Wed, 10 Jul 2024 12:09:12 +0000 Subject: [PATCH 3/5] add eps in sqrt --- examples/simple_trainer_mip_splatting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/simple_trainer_mip_splatting.py b/examples/simple_trainer_mip_splatting.py index 27eb87188..bbe5e0f65 100644 --- a/examples/simple_trainer_mip_splatting.py +++ b/examples/simple_trainer_mip_splatting.py @@ -368,7 +368,7 @@ def rasterize_splats( scales_after_square = scales_square + torch.square(filters)[:, None] # [N, 1] det2 = scales_after_square.prod(dim=1) # [N,] - coef = torch.sqrt(det1 / det2) # [N,] + coef = torch.sqrt(det1 / det2 + 1e-7) # [N,] opacities = opacities * coef scales = torch.square(scales) + torch.square(filters)[:, None] # [N, 3] @@ -777,11 +777,11 @@ def reset_opa(self, value: float = 0.01): scales_after_square = scales_square + torch.square(filters)[:, None] # [N, 1] det2 = scales_after_square.prod(dim=1) # [N,] - coef = torch.sqrt(det1 / det2) # [N,] + coef = torch.sqrt(det1 / det2 + 1e-7) # [N,] opacities = opacities * coef opacities_reset = torch.min(opacities, torch.ones_like(opacities) * value) - opacities_reset = opacities_reset / coef + opacities_reset = opacities_reset / (coef + 1e-7) opacities = torch.logit(opacities_reset) for optimizer in self.optimizers: From 5630e127d94412f39543d24c8f151a71621e8a10 Mon Sep 17 00:00:00 2001 From: niujinshuchong Date: Wed, 10 Jul 2024 18:31:44 +0000 Subject: [PATCH 4/5] add cuda implementation of compute_3D_smoothing_filter --- examples/simple_trainer_mip_splatting.py | 24 +++- gsplat/__init__.py | 2 + gsplat/cuda/_wrapper.py | 53 +++++++++ gsplat/cuda/csrc/bindings.h | 7 ++ .../csrc/compute_3D_smoothing_filter_fwd.cu | 110 ++++++++++++++++++ gsplat/cuda/csrc/ext.cpp | 2 + 6 files changed, 197 insertions(+), 1 deletion(-) create mode 100644 gsplat/cuda/csrc/compute_3D_smoothing_filter_fwd.cu diff --git a/examples/simple_trainer_mip_splatting.py b/examples/simple_trainer_mip_splatting.py index bbe5e0f65..57a75ce9a 100644 --- a/examples/simple_trainer_mip_splatting.py +++ b/examples/simple_trainer_mip_splatting.py @@ -29,6 +29,7 @@ ) from gsplat.rendering import rasterization +from gsplat import compute_3D_smoothing_filter @dataclass @@ -685,6 +686,27 @@ def train(self): @torch.no_grad() def compute_3D_smoothing_filter(self): + cfg = self.cfg + device = self.device + xyz = self.splats["means3d"] + print("xyz", xyz.shape, xyz.device) + + worldtocams = ( + torch.from_numpy(self.trainset.parser.worldtocams).float().to(device) + ) + # TODO, currently use K, H, W of the first image + data = self.trainset[0] + K = data["K"].to(device) # [3, 3] + height, width = data["image"].shape[:2] + Ks = torch.stack([K] * len(worldtocams), dim=0) # [C, 3, 3] + filter_3D = compute_3D_smoothing_filter( + xyz, worldtocams, Ks, width, height, cfg.near_plane + ) + filter_3D = filter_3D * (0.2**0.5) + self.splats["filters"] = torch.nn.Parameter(filter_3D) + + @torch.no_grad() + def compute_3D_smoothing_filter_torch(self): print("Computing 3D filter") cfg = self.cfg device = self.device @@ -1057,7 +1079,7 @@ def main(cfg: Config): for k in runner.splats.keys(): runner.splats[k].data = ckpt["splats"][k] runner.eval(step=ckpt["step"]) - runner.render_traj(step=ckpt["step"]) + # runner.render_traj(step=ckpt["step"]) else: runner.train() diff --git a/gsplat/__init__.py b/gsplat/__init__.py index cf238ec3b..ecce1692c 100644 --- a/gsplat/__init__.py +++ b/gsplat/__init__.py @@ -11,6 +11,7 @@ rasterize_to_pixels, spherical_harmonics, world_to_cam, + compute_3D_smoothing_filter, ) from .rendering import ( rasterization, @@ -120,4 +121,5 @@ def get_tile_bin_edges(*args, **kwargs): "compute_cumulative_intersects", "compute_cov2d_bounds", "get_tile_bin_edges", + "compute_3D_smoothing_filter", ] diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 9043181ce..111920788 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -572,6 +572,34 @@ def rasterize_to_indices_in_range( return out_gauss_ids, out_pixel_ids, out_camera_ids +def compute_3D_smoothing_filter( + means: Tensor, # [N, 3] + viewmats: Tensor, # [C, 4, 4] + Ks: Tensor, # [C, 3, 3] + width: int, + height: int, + near_plane: float = 0.01, +) -> Tensor: + """Compute 3D smoothing filter.""" + C = viewmats.size(0) + N = means.size(0) + assert means.size() == (N, 3), means.size() + assert viewmats.size() == (C, 4, 4), viewmats.size() + assert Ks.size() == (C, 3, 3), Ks.size() + means = means.contiguous() + viewmats = viewmats.contiguous() + Ks = Ks.contiguous() + + return _Compute3DSmoothingFilter.apply( + means, + viewmats, + Ks, + width, + height, + near_plane, + ) + + class _QuatScaleToCovarPreci(torch.autograd.Function): """Converts quaternions and scales to covariance and precision matrices.""" @@ -1143,3 +1171,28 @@ def backward(ctx, v_colors: Tensor): if not compute_v_dirs: v_dirs = None return None, v_dirs, v_coeffs, None + + +class _Compute3DSmoothingFilter(torch.autograd.Function): + """Compute 3D Smoothing filter.""" + + @staticmethod + def forward( + ctx, + means: Tensor, # [N, 3] + viewmats: Tensor, # [C, 4, 4] + Ks: Tensor, # [C, 3, 3] + width: int, + height: int, + near_plane: float, + ) -> Tensor: + filter_3D = _make_lazy_cuda_func("compute_3D_smoothing_filter_fwd")( + means, + viewmats, + Ks, + width, + height, + near_plane, + ) + + return filter_3D diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index c983f461e..d38544e2a 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -175,6 +175,13 @@ compute_sh_bwd_tensor(const uint32_t K, const uint32_t degrees_to_use, torch::Tensor &v_colors, // [..., 3] bool compute_v_dirs); +torch::Tensor compute_3D_smoothing_filter_fwd_tensor( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, const uint32_t image_height, + const float near_plane); + /**************************************************************************************** * Packed Version ****************************************************************************************/ diff --git a/gsplat/cuda/csrc/compute_3D_smoothing_filter_fwd.cu b/gsplat/cuda/csrc/compute_3D_smoothing_filter_fwd.cu new file mode 100644 index 000000000..a5f337aa0 --- /dev/null +++ b/gsplat/cuda/csrc/compute_3D_smoothing_filter_fwd.cu @@ -0,0 +1,110 @@ +#include "bindings.h" +#include "helpers.cuh" +#include "utils.cuh" + +#include +#include +#include +#include +#include + +namespace cg = cooperative_groups; + + +/**************************************************************************** + * Compute the 3D smoothing filter size of 3D Gaussians Forward Pass + ****************************************************************************/ + +template +__global__ void +compute_3D_smoothing_filter_fwd_kernel(const uint32_t C, const uint32_t N, + const T *__restrict__ means, // [N, 3] + const T *__restrict__ viewmats, // [C, 4, 4] + const T *__restrict__ Ks, // [C, 3, 3] + const int32_t image_width, const int32_t image_height, + const T near_plane, + // outputs + T *__restrict__ filter // [N, ] +) { + // parallelize over C * N. + uint32_t idx = cg::this_grid().thread_rank(); + if (idx >= C * N) { + return; + } + const uint32_t cid = idx / N; // camera id + const uint32_t gid = idx % N; // gaussian id + + // shift pointers to the current camera and gaussian + means += gid * 3; + viewmats += cid * 16; + Ks += cid * 9; + + // glm is column-major but input is row-major + mat3 R = mat3(viewmats[0], viewmats[4], viewmats[8], // 1st column + viewmats[1], viewmats[5], viewmats[9], // 2nd column + viewmats[2], viewmats[6], viewmats[10] // 3rd column + ); + vec3 t = vec3(viewmats[3], viewmats[7], viewmats[11]); + + // transform Gaussian center to camera space + vec3 mean_c; + pos_world_to_cam(R, t, glm::make_vec3(means), mean_c); + if (mean_c.z < near_plane ) { + return; + } + + // project the point to image plane + vec2 mean2d; + + const T fx = Ks[0]; + const T fy = Ks[4]; + const T cx = Ks[2]; + const T cy = Ks[5]; + + T x = mean_c[0], y = mean_c[1], z = mean_c[2]; + T rz = 1.f / z; + mean2d = vec2({fx * x * rz + cx, fy * y * rz + cy}); + + // mask out gaussians outside the image region + if (mean2d.x <= 0 || mean2d.x >= image_width || + mean2d.y <= 0 || mean2d.y >= image_height) { + return; + } + + T filter_size = z / fx; + + // write to outputs + // atomicMin(&filter[gid], filter_size); + + // atomicMin is not supported for float, so we use __float_as_int + // refer to https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda/51549250#51549250 + atomicMin((int *)&filter[gid], __float_as_int(filter_size)); +} + + +torch::Tensor compute_3D_smoothing_filter_fwd_tensor( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, const uint32_t image_height, + const float near_plane) { + DEVICE_GUARD(means); + CHECK_INPUT(means); + CHECK_INPUT(viewmats); + CHECK_INPUT(Ks); + + uint32_t N = means.size(0); // number of gaussians + uint32_t C = viewmats.size(0); // number of cameras + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + + torch::Tensor filter = torch::full({N}, 1000000, means.options()); + + if (C && N) { + compute_3D_smoothing_filter_fwd_kernel<<<(C * N + N_THREADS - 1) / N_THREADS, N_THREADS, 0, stream>>>( + C, N, means.data_ptr(), + viewmats.data_ptr(), Ks.data_ptr(), image_width, image_height, + near_plane, + filter.data_ptr()); + } + return filter; +} diff --git a/gsplat/cuda/csrc/ext.cpp b/gsplat/cuda/csrc/ext.cpp index 6ef34e7e1..f6b0c28a4 100644 --- a/gsplat/cuda/csrc/ext.cpp +++ b/gsplat/cuda/csrc/ext.cpp @@ -25,6 +25,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("rasterize_to_indices_in_range", &rasterize_to_indices_in_range_tensor); + m.def("compute_3D_smoothing_filter_fwd", &compute_3D_smoothing_filter_fwd_tensor); + // packed version m.def("fully_fused_projection_packed_fwd", &fully_fused_projection_packed_fwd_tensor); m.def("fully_fused_projection_packed_bwd", &fully_fused_projection_packed_bwd_tensor); From 49160072d16165d447a22a7e0e7d67c5e130f5a8 Mon Sep 17 00:00:00 2001 From: niujinshuchong Date: Fri, 12 Jul 2024 13:02:14 +0000 Subject: [PATCH 5/5] add training scripts --- examples/benchmark_mipnerf360.py | 93 +++++++++++++++++++++ examples/benchmark_mipnerf360_stmt.py | 115 ++++++++++++++++++++++++++ examples/show_mipnerf360.py | 51 ++++++++++++ examples/show_mipnerf360_allscales.py | 45 ++++++++++ 4 files changed, 304 insertions(+) create mode 100644 examples/benchmark_mipnerf360.py create mode 100644 examples/benchmark_mipnerf360_stmt.py create mode 100644 examples/show_mipnerf360.py create mode 100644 examples/show_mipnerf360_allscales.py diff --git a/examples/benchmark_mipnerf360.py b/examples/benchmark_mipnerf360.py new file mode 100644 index 000000000..62d39c5ea --- /dev/null +++ b/examples/benchmark_mipnerf360.py @@ -0,0 +1,93 @@ +# Training script for the Mip-NeRF 360 dataset + +import os +import GPUtil +from concurrent.futures import ThreadPoolExecutor +import time +import glob + +# 9 scenes +# scenes = ["bicycle", "bonsai", "counter", "flowers", "garden", "stump", "treehill", "kitchen", "room"] +# factors = [4, 2, 2, 4, 4, 4, 4, 2, 2] + +# 7 scenes +scenes = ["bicycle", "bonsai", "counter", "garden", "stump", "kitchen", "room"] +factors = [4, 2, 2, 4, 4, 2, 2] + +excluded_gpus = set([]) + +result_dir = "results/benchmark_mipsplatting_cuda3D" + +dry_run = False + +jobs = list(zip(scenes, factors)) + + +def train_scene(gpu, scene, factor): + # train without eval + cmd = f"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} python simple_trainer_mip_splatting.py --eval_steps -1 --disable_viewer --data_factor {factor} --data_dir data/360_v2/{scene} --result_dir {result_dir}/{scene} --antialiased --kernel_size 0.1" + print(cmd) + if not dry_run: + os.system(cmd) + + # eval and render for all the ckpts + ckpts = glob.glob(f"{result_dir}/{scene}/ckpts/*.pt") + for ckpt in ckpts: + cmd = f"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} python simple_trainer_mip_splatting.py --disable_viewer --data_factor {factor} --data_dir data/360_v2/{scene} --result_dir {result_dir}/{scene} --ckpt {ckpt} --antialiased --kernel_size 0.1" + print(cmd) + if not dry_run: + os.system(cmd) + + return True + + +def worker(gpu, scene, factor): + print(f"Starting job on GPU {gpu} with scene {scene}\n") + train_scene(gpu, scene, factor) + print(f"Finished job on GPU {gpu} with scene {scene}\n") + # This worker function starts a job and returns when it's done. + + +def dispatch_jobs(jobs, executor): + future_to_job = {} + reserved_gpus = set() # GPUs that are slated for work but may not be active yet + + while jobs or future_to_job: + # Get the list of available GPUs, not including those that are reserved. + all_available_gpus = set( + GPUtil.getAvailable(order="first", limit=10, maxMemory=0.1, maxLoad=0.1) + ) + # all_available_gpus = set([0,1,2,3]) + available_gpus = list(all_available_gpus - reserved_gpus - excluded_gpus) + + # Launch new jobs on available GPUs + while available_gpus and jobs: + gpu = available_gpus.pop(0) + job = jobs.pop(0) + future = executor.submit( + worker, gpu, *job + ) # Unpacking job as arguments to worker + future_to_job[future] = (gpu, job) + + reserved_gpus.add(gpu) # Reserve this GPU until the job starts processing + + # Check for completed jobs and remove them from the list of running jobs. + # Also, release the GPUs they were using. + done_futures = [future for future in future_to_job if future.done()] + for future in done_futures: + job = future_to_job.pop( + future + ) # Remove the job associated with the completed future + gpu = job[0] # The GPU is the first element in each job tuple + reserved_gpus.discard(gpu) # Release this GPU + print(f"Job {job} has finished., rellasing GPU {gpu}") + # (Optional) You might want to introduce a small delay here to prevent this loop from spinning very fast + # when there are no GPUs available. + time.sleep(5) + + print("All jobs have been processed.") + + +# Using ThreadPoolExecutor to manage the thread pool +with ThreadPoolExecutor(max_workers=8) as executor: + dispatch_jobs(jobs, executor) diff --git a/examples/benchmark_mipnerf360_stmt.py b/examples/benchmark_mipnerf360_stmt.py new file mode 100644 index 000000000..28952adad --- /dev/null +++ b/examples/benchmark_mipnerf360_stmt.py @@ -0,0 +1,115 @@ +# Training script for the Mip-NeRF 360 dataset +# The model is trained with downsampling factor 8 and rendered with downsampling factor 1, 2, 4, 8 + +import os +import GPUtil +from concurrent.futures import ThreadPoolExecutor +import time +import glob + +# 9 scenes +# scenes = ["bicycle", "bonsai", "counter", "flowers", "garden", "stump", "treehill", "kitchen", "room"] +# factors = [4, 2, 2, 4, 4, 4, 4, 2, 2] + +# 7 scenes +scenes = ["bicycle", "bonsai", "counter", "garden", "stump", "kitchen", "room"] +factors = [8] * len(scenes) + +excluded_gpus = set([]) + +# classic +result_dir = "results/benchmark_stmt" +# antialiased +result_dir = "results/benchmark_antialiased_stmt" +# mip-splatting +# result_dir = "results/benchmark_mipsplatting_stmt" + +dry_run = False + +jobs = list(zip(scenes, factors)) + + +def train_scene(gpu, scene, factor): + # train without eval + # classic + # cmd = f"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} python simple_trainer.py --eval_steps -1 --disable_viewer --data_factor {factor} --data_dir data/360_v2/{scene} --result_dir {result_dir}/{scene}" + + # anti-aliased + # cmd = f"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} python simple_trainer.py --eval_steps -1 --disable_viewer --data_factor {factor} --data_dir data/360_v2/{scene} --result_dir {result_dir}/{scene} --antialiased" + + # mip-splatting + cmd = f"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} python simple_trainer_mip_splatting.py --eval_steps -1 --disable_viewer --data_factor {factor} --data_dir data/360_v2/{scene} --result_dir {result_dir}/{scene} --antialiased --kernel_size 0.1" + print(cmd) + if not dry_run: + os.system(cmd) + + # eval and render for all the ckpts + ckpts = glob.glob(f"{result_dir}/{scene}/ckpts/*.pt") + for ckpt in ckpts: + for test_factor in [1, 2, 4, 8]: + # classic + # cmd = f"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} python simple_trainer.py --disable_viewer --data_factor {test_factor} --data_dir data/360_v2/{scene} --result_dir {result_dir}/{scene}_{test_factor} --ckpt {ckpt}" + + # anti-aliased + # cmd = f"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} python simple_trainer.py --disable_viewer --data_factor {test_factor} --data_dir data/360_v2/{scene} --result_dir {result_dir}/{scene}_{test_factor} --ckpt {ckpt} --antialiased" + + # mip-splatting + cmd = f"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} python simple_trainer_mip_splatting.py --disable_viewer --data_factor {test_factor} --data_dir data/360_v2/{scene} --result_dir {result_dir}/{scene}_{test_factor} --ckpt {ckpt} --antialiased --kernel_size 0.1" + print(cmd) + if not dry_run: + os.system(cmd) + + return True + + +def worker(gpu, scene, factor): + print(f"Starting job on GPU {gpu} with scene {scene}\n") + train_scene(gpu, scene, factor) + print(f"Finished job on GPU {gpu} with scene {scene}\n") + # This worker function starts a job and returns when it's done. + + +def dispatch_jobs(jobs, executor): + future_to_job = {} + reserved_gpus = set() # GPUs that are slated for work but may not be active yet + + while jobs or future_to_job: + # Get the list of available GPUs, not including those that are reserved. + all_available_gpus = set( + GPUtil.getAvailable(order="first", limit=10, maxMemory=0.1, maxLoad=0.1) + ) + # all_available_gpus = set([0,1,2,3]) + available_gpus = list(all_available_gpus - reserved_gpus - excluded_gpus) + + # Launch new jobs on available GPUs + while available_gpus and jobs: + gpu = available_gpus.pop(0) + job = jobs.pop(0) + future = executor.submit( + worker, gpu, *job + ) # Unpacking job as arguments to worker + future_to_job[future] = (gpu, job) + + reserved_gpus.add(gpu) # Reserve this GPU until the job starts processing + time.sleep(2) + + # Check for completed jobs and remove them from the list of running jobs. + # Also, release the GPUs they were using. + done_futures = [future for future in future_to_job if future.done()] + for future in done_futures: + job = future_to_job.pop( + future + ) # Remove the job associated with the completed future + gpu = job[0] # The GPU is the first element in each job tuple + reserved_gpus.discard(gpu) # Release this GPU + print(f"Job {job} has finished., rellasing GPU {gpu}") + # (Optional) You might want to introduce a small delay here to prevent this loop from spinning very fast + # when there are no GPUs available. + time.sleep(5) + + print("All jobs have been processed.") + + +# Using ThreadPoolExecutor to manage the thread pool +with ThreadPoolExecutor(max_workers=8) as executor: + dispatch_jobs(jobs, executor) diff --git a/examples/show_mipnerf360.py b/examples/show_mipnerf360.py new file mode 100644 index 000000000..c89fa124d --- /dev/null +++ b/examples/show_mipnerf360.py @@ -0,0 +1,51 @@ +import json +import numpy as np +import glob + +# 9 scenes +# scenes = ['bicycle', 'flowers', 'garden', 'stump', 'treehill', 'room', 'counter', 'kitchen', 'bonsai'] + +# outdoor scenes +# scenes = scenes[:5] +# indoor scenes +# scenes = scenes[5:] + +# 7 scenes +scenes = ["bicycle", "bonsai", "counter", "garden", "stump", "kitchen", "room"] + +result_dirs = ["results/benchmark"] +result_dirs = ["results/benchmark_antialiased"] +result_dirs = ["results/benchmark_mipsplatting"] +result_dirs = ["results/benchmark_mipsplatting_cuda3D"] + +all_metrics = {"psnr": [], "ssim": [], "lpips": [], "num_GS": []} +print(result_dirs) + +for scene in scenes: + print(scene, end=" ") + for result_dir in result_dirs: + json_files = glob.glob(f"{result_dir}/{scene}/stats/val_step29999.json") + for json_file in json_files: + # print(json_file) + data = json.load(open(json_file)) + # print(data) + + for k in ["psnr", "ssim", "lpips", "num_GS"]: + all_metrics[k].append(data[k]) + print(f"{data[k]:.3f}", end=" ") + print() + +latex = [] +for k in ["psnr", "ssim", "lpips", "num_GS"]: + numbers = np.asarray(all_metrics[k]).mean(axis=0).tolist() + print(numbers) + numbers = [numbers] + if k == "PSNR": + numbers = [f"{x:.2f}" for x in numbers] + elif k == "num_GS": + num = numbers[0] / 1e6 + numbers = [f"{num:.2f}"] + else: + numbers = [f"{x:.3f}" for x in numbers] + latex.extend(numbers) +print(" | ".join(latex)) diff --git a/examples/show_mipnerf360_allscales.py b/examples/show_mipnerf360_allscales.py new file mode 100644 index 000000000..dcbf166ad --- /dev/null +++ b/examples/show_mipnerf360_allscales.py @@ -0,0 +1,45 @@ +import json +import numpy as np +import glob + +# 9 scenes +# scenes = ['bicycle', 'flowers', 'garden', 'stump', 'treehill', 'room', 'counter', 'kitchen', 'bonsai'] + +# outdoor scenes +# scenes = scenes[:5] +# indoor scenes +# scenes = scenes[5:] + +# 7 scenes +scenes = ["bicycle", "bonsai", "counter", "garden", "stump", "kitchen", "room"] + +result_dirs = ["results/benchmark_stmt"] +# result_dirs = ["results/benchmark_antialiased_stmt"] +# result_dirs = ["results/benchmark_mipsplatting_stmt"] + +all_metrics = {"psnr": [], "ssim": [], "lpips": [], "num_GS": []} +print(result_dirs) + +for scene in scenes: + print(scene) + for result_dir in result_dirs: + for scale in ["8", "4", "2", "1"]: + json_files = glob.glob(f"{result_dir}/{scene}_{scale}/stats/val_step29999.json") + for json_file in json_files: + data = json.load(open(json_file)) + for k in ["psnr", "ssim", "lpips", "num_GS"]: + all_metrics[k].append(data[k]) + print(f"{data[k]:.3f}", end=" ") + print() + +latex = [] +for k in ["psnr", "ssim", "lpips"]: + numbers = np.asarray(all_metrics[k]).reshape(-1, 4).mean(axis=0).tolist() + numbers = numbers + [np.mean(numbers)] + print(numbers) + if k == "psnr": + numbers = [f"{x:.2f}" for x in numbers] + else: + numbers = [f"{x:.3f}" for x in numbers] + latex.extend(numbers) +print(" | ".join(latex))