diff --git a/EXPLORATION.md b/EXPLORATION.md
index e597e9479..e8698ab75 100644
--- a/EXPLORATION.md
+++ b/EXPLORATION.md
@@ -26,7 +26,7 @@
 | `--absgrad --grow_grad2d 2e-4` | 8m30s | 0.018s/im | 2.21 GB | 0.6251 | 20.68 | 0.587 | 0.89M |
 | `--absgrad --grow_grad2d 2e-4` (30k) | -- | 0.030s/im | 5.25 GB | 0.7442 | 24.12 | 0.291 | 2.62M |
 
-Note: default args means running `python simple_trainer.py --data_dir ` with:
+Note: default args means running `CUDA_VISIBLE_DEVICES=0 python simple_trainer.py --data_dir ` with:
 
 - Garden ([Source](https://jonbarron.info/mipnerf360/)): `--result_dir results/garden`
 - U1 (a.k.a University 1 from [Source](https://localrf.github.io/)): `--result_dir results/u1 --data_factor 1 --grow_scale3d 0.001`
diff --git a/README.md b/README.md
index 5e517f829..40cfdbf7f 100644
--- a/README.md
+++ b/README.md
@@ -42,7 +42,7 @@ pip install -r requirements.txt
 # download mipnerf_360 benchmark data
 python datasets/download_dataset.py
 # run batch evaluation
-bash benchmark.sh
+bash benchmarks/basic.sh
 ```
 
 ## Examples
diff --git a/docs/source/examples/colmap.rst b/docs/source/examples/colmap.rst
index ef6d171e3..13eb344f1 100644
--- a/docs/source/examples/colmap.rst
+++ b/docs/source/examples/colmap.rst
@@ -15,7 +15,7 @@ Simply run the script under `examples/`:
 
 .. code-block:: bash
 
-    python simple_trainer.py \
+    CUDA_VISIBLE_DEVICES=0 python simple_trainer.py \
         --data_dir data/360_v2/garden/ --data_factor 4 \
         --result_dir ./results/garden
diff --git a/docs/source/examples/large_scale.rst b/docs/source/examples/large_scale.rst
index 95f2b1140..46db0bc49 100644
--- a/docs/source/examples/large_scale.rst
+++ b/docs/source/examples/large_scale.rst
@@ -35,7 +35,7 @@ The code for this example can be found under `examples/`:
 
 .. code-block:: bash
 
     # First train a 3DGS model
-    python simple_trainer.py \
+    CUDA_VISIBLE_DEVICES=0 python simple_trainer.py \
         --data_dir data/360_v2/garden/ --data_factor 4 \
         --result_dir ./results/garden
diff --git a/docs/source/index.rst b/docs/source/index.rst
index f0199c08c..c123ef1e1 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -13,13 +13,21 @@ Overview
 Real-Time Rendering of Radiance Fields" :cite:p:`kerbl3Dgaussians`, but we've made *gsplat*
 even faster, more memory efficient, and with a growing list of new features!
 
-* *gsplat* is developed with efficiency in mind. Comparing to the `official implementation `_, *gsplat* enables up to **4x less training memory footprint**, and up to **15% less training time** on Mip-NeRF 360 captures, and potential more on larger scenes. See :doc:`tests/eval` for details.
+* *gsplat* is developed with efficiency in mind. Compared to the `official implementation `_,
+  *gsplat* enables up to **4x less training memory footprint** and up to **15% less training time** on Mip-NeRF 360 captures, and potentially more on larger scenes. See :doc:`tests/eval` for details.
 
-* *gsplat* is designed to **support extremely large scene rendering**, which is magnitudes faster than the official CUDA backend `diff-gaussian-rasterization `_. See :doc:`examples/large_scale` for an example.
+* *gsplat* is designed to **support extremely large scene rendering**, which is orders of magnitude
+  faster than the official CUDA backend `diff-gaussian-rasterization `_. See :doc:`examples/large_scale` for an example.
 
-* *gsplat* offers many extra features, including **batch rasterization**, **N-D feature rendering**, **depth rendering**, **sparse gradient** etc. See :doc:`apis/rasterization` for details.
+* *gsplat* offers many extra features, including **batch rasterization**,
+  **N-D feature rendering**, **depth rendering**, **sparse gradient**,
+  **multi-GPU distributed rasterization**
+  etc. See :doc:`apis/rasterization` for details.
 
-* *gsplat* is equipped with the **latest and greatest** 3D Gaussian Splatting techniques, including `absgrad `_, `anti-aliasing `_ etc. And more to come!
+* *gsplat* is equipped with the **latest and greatest** 3D Gaussian Splatting techniques,
+  including `absgrad `_,
+  `anti-aliasing `_,
+  `3DGS-MCMC `_ etc. And more to come!
 
 .. raw:: html
 
diff --git a/docs/source/tests/eval.rst b/docs/source/tests/eval.rst
index fd5dda081..486cb39b3 100644
--- a/docs/source/tests/eval.rst
+++ b/docs/source/tests/eval.rst
@@ -3,17 +3,19 @@ Evaluation
 
 .. table:: Performance on `Mip-NeRF 360 Captures `_ (Averaged Over 7 Scenes)
 
-   +------------+-------+-------+-------+------------------+------------+
-   |            | PSNR  | SSIM  | LPIPS | Train Mem        | Train Time |
-   +============+=======+=======+=======+==================+============+
-   | inria-7k   | 27.23 | 0.829 | 0.204 | 7.7 GB           | 6m05s      |
-   +------------+-------+-------+-------+------------------+------------+
-   | gsplat-7k  | 27.21 | 0.831 | 0.202 | **4.3GB**        | **5m35s**  |
-   +------------+-------+-------+-------+------------------+------------+
-   | inria-30k  | 28.95 | 0.870 | 0.138 | 9.0 GB           | 37m13s     |
-   +------------+-------+-------+-------+------------------+------------+
-   | gsplat-30k | 28.95 | 0.870 | 0.135 | **5.7 GB**       | **35m49s** |
-   +------------+-------+-------+-------+------------------+------------+
+   +---------------------+-------+-------+-------+------------------+------------+
+   |                     | PSNR  | SSIM  | LPIPS | Train Mem        | Train Time |
+   +=====================+=======+=======+=======+==================+============+
+   | inria-7k            | 27.23 | 0.829 | 0.204 | 7.7 GB           | 6m05s      |
+   +---------------------+-------+-------+-------+------------------+------------+
+   | gsplat-7k           | 27.21 | 0.831 | 0.202 | **4.3GB**        | **5m35s**  |
+   +---------------------+-------+-------+-------+------------------+------------+
+   | inria-30k           | 28.95 | 0.870 | 0.138 | 9.0 GB           | 37m13s     |
+   +---------------------+-------+-------+-------+------------------+------------+
+   | gsplat-30k (1 GPU)  | 28.95 | 0.870 | 0.135 | **5.7 GB**       | **35m49s** |
+   +---------------------+-------+-------+-------+------------------+------------+
+   | gsplat-30k (4 GPUs) | 28.91 | 0.871 | 0.135 | **2.0 GB**       | **11m28s** |
+   +---------------------+-------+-------+-------+------------------+------------+
 
 This repo comes with a standalone script (:code:`examples/simple_trainer.py`) that reproduces
 the `Gaussian Splatting `_ with
@@ -131,7 +133,7 @@
 is different from what's reported in the original paper that uses
 :code:`from lpipsPyTorch import lpips`.
 
 The evaluation of `gsplat-X` can be reproduced with the command
-:code:`cd examples; bash benchmark.sh`
+:code:`cd examples; bash benchmarks/basic.sh`
 within the gsplat repo (commit 6acdce4).
The evaluation of `inria-X` can be
diff --git a/examples/benchmark.sh b/examples/benchmarks/basic.sh
similarity index 72%
rename from examples/benchmark.sh
rename to examples/benchmarks/basic.sh
index 889e9f0c5..e804285dc 100644
--- a/examples/benchmark.sh
+++ b/examples/benchmarks/basic.sh
@@ -11,14 +11,14 @@ do
     echo "Running $SCENE"
 
     # train without eval
-    python simple_trainer.py --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \
+    CUDA_VISIBLE_DEVICES=0 python simple_trainer.py --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \
         --data_dir data/360_v2/$SCENE/ \
         --result_dir $RESULT_DIR/$SCENE/
 
     # run eval and render
     for CKPT in $RESULT_DIR/$SCENE/ckpts/*;
     do
-        python simple_trainer.py --disable_viewer --data_factor $DATA_FACTOR \
+        CUDA_VISIBLE_DEVICES=0 python simple_trainer.py --disable_viewer --data_factor $DATA_FACTOR \
            --data_dir data/360_v2/$SCENE/ \
            --result_dir $RESULT_DIR/$SCENE/ \
            --ckpt $CKPT
@@ -30,7 +30,7 @@ for SCENE in bicycle bonsai counter garden kitchen room stump;
 do
     echo "=== Eval Stats ==="
 
-    for STATS in $RESULT_DIR/$SCENE/stats/val*;
+    for STATS in $RESULT_DIR/$SCENE/stats/val*.json;
     do
         echo $STATS
         cat $STATS;
@@ -39,7 +39,7 @@ do
 
     echo "=== Train Stats ==="
 
-    for STATS in $RESULT_DIR/$SCENE/stats/train*;
+    for STATS in $RESULT_DIR/$SCENE/stats/train*_rank0.json;
    do
        echo $STATS
        cat $STATS;
diff --git a/examples/benchmarks/basic_4gpus.sh b/examples/benchmarks/basic_4gpus.sh
new file mode 100644
index 000000000..3c3ad334e
--- /dev/null
+++ b/examples/benchmarks/basic_4gpus.sh
@@ -0,0 +1,43 @@
+RESULT_DIR=results/benchmark_4gpus
+
+for SCENE in bicycle bonsai counter garden kitchen room stump;
+do
+    if [ "$SCENE" = "bicycle" ] || [ "$SCENE" = "stump" ] || [ "$SCENE" = "garden" ]; then
+        DATA_FACTOR=4
+    else
+        DATA_FACTOR=2
+    fi
+
+    echo "Running $SCENE"
+
+    # train and eval at the last step
+    # 4 GPUs is effectively a 4x batch size, so we scale the number of steps down by 4x as well.
+    # "--packed" reduces the data transfer between GPUs, which leads to faster training.
+    CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \
+ --steps_scaler 0.25 --packed \ + --data_dir data/360_v2/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ + +done + + +for SCENE in bicycle bonsai counter garden kitchen room stump; +do + echo "=== Eval Stats ===" + + for STATS in $RESULT_DIR/$SCENE/stats/val_step7499.json; + do + echo $STATS + cat $STATS; + echo + done + + echo "=== Train Stats ===" + + for STATS in $RESULT_DIR/$SCENE/stats/train_step7499_rank0.json; + do + echo $STATS + cat $STATS; + echo + done +done \ No newline at end of file diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 9d7bb22d0..a96db1214 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -16,11 +16,13 @@ from datasets.colmap import Dataset, Parser from datasets.traj import generate_interpolated_path from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed +from gsplat.distributed import cli from gsplat.rendering import rasterization from gsplat.strategy import DefaultStrategy @@ -166,6 +168,8 @@ def create_splats_with_optimizers( batch_size: int = 1, feature_dim: Optional[int] = None, device: str = "cuda", + world_rank: int = 0, + world_size: int = 1, ) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]: if init_type == "sfm": points = torch.from_numpy(parser.points).float() @@ -176,11 +180,17 @@ def create_splats_with_optimizers( else: raise ValueError("Please specify a correct init_type: sfm or random") - N = points.shape[0] # Initialize the GS size to be the average dist of the 3 nearest neighbors dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] dist_avg = torch.sqrt(dist2_avg) scales = torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3) # [N, 3] + + # Distribute the GSs to different ranks (also works for single rank) + points = points[world_rank::world_size] + rgbs = rgbs[world_rank::world_size] + scales = scales[world_rank::world_size] + + N = points.shape[0] quats = torch.rand((N, 4)) # [N, 4] opacities = torch.logit(torch.full((N,), init_opacity)) # [N,] @@ -210,11 +220,13 @@ def create_splats_with_optimizers( # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ # Note that this would not make the training exactly equivalent, see # https://arxiv.org/pdf/2402.18824v1 + BS = batch_size * world_size optimizers = { name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)( - [{"params": splats[name], "lr": lr * math.sqrt(batch_size)}], - eps=1e-15 / math.sqrt(batch_size), - betas=(1 - batch_size * (1 - 0.9), 1 - batch_size * (1 - 0.999)), + [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}], + eps=1e-15 / math.sqrt(BS), + # TODO: check betas logic when BS is larger than 10 betas[0] will be zero. 
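+            # With BS = batch_size * world_size, beta1 = 1 - 0.1 * BS: it reaches 0 at BS = 10
+            # and would be negative (invalid for Adam) for larger BS.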
+ betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)), ) for name, _, lr in params } @@ -224,11 +236,16 @@ def create_splats_with_optimizers( class Runner: """Engine for training and testing.""" - def __init__(self, cfg: Config) -> None: - set_random_seed(42) + def __init__( + self, local_rank: int, world_rank, world_size: int, cfg: Config + ) -> None: + set_random_seed(42 + local_rank) self.cfg = cfg - self.device = "cuda" + self.world_rank = world_rank + self.local_rank = local_rank + self.world_size = world_size + self.device = f"cuda:{local_rank}" # Where to dump results. os.makedirs(cfg.result_dir, exist_ok=True) @@ -276,6 +293,8 @@ def __init__(self, cfg: Config) -> None: batch_size=cfg.batch_size, feature_dim=feature_dim, device=self.device, + world_rank=world_rank, + world_size=world_size, ) print("Model initialized. Number of GS:", len(self.splats["means"])) @@ -309,10 +328,14 @@ def __init__(self, cfg: Config) -> None: weight_decay=cfg.pose_opt_reg, ) ] + if world_size > 1: + self.pose_adjust = DDP(self.pose_adjust) if cfg.pose_noise > 0.0: self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device) self.pose_perturb.random_init(cfg.pose_noise) + if world_size > 1: + self.pose_perturb = DDP(self.pose_perturb) self.app_optimizers = [] if cfg.app_opt: @@ -333,6 +356,8 @@ def __init__(self, cfg: Config) -> None: lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size), ), ] + if world_size > 1: + self.app_module = DDP(self.app_module) # Losses & Metrics. self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device) @@ -393,6 +418,7 @@ def rasterize_splats( absgrad=self.cfg.absgrad, sparse_grad=self.cfg.sparse_grad, rasterize_mode=rasterize_mode, + distributed=self.world_size > 1, **kwargs, ) return render_colors, render_alphas, info @@ -400,10 +426,13 @@ def rasterize_splats( def train(self): cfg = self.cfg device = self.device + world_rank = self.world_rank + world_size = self.world_size # Dump cfg. 
- with open(f"{cfg.result_dir}/cfg.json", "w") as f: - json.dump(vars(cfg), f) + if world_rank == 0: + with open(f"{cfg.result_dir}/cfg.json", "w") as f: + json.dump(vars(cfg), f) max_steps = cfg.max_steps init_step = 0 @@ -536,7 +565,16 @@ def train(self): desc += f"pose err={pose_err.item():.6f}| " pbar.set_description(desc) - if cfg.tb_every > 0 and step % cfg.tb_every == 0: + # write images (gt and render) + # if world_rank == 0 and step % 800 == 0: + # canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() + # canvas = canvas.reshape(-1, *canvas.shape[2:]) + # imageio.imwrite( + # f"{self.render_dir}/train_rank{self.world_rank}.png", + # (canvas * 255).astype(np.uint8), + # ) + + if world_rank == 0 and cfg.tb_every > 0 and step % cfg.tb_every == 0: mem = torch.cuda.max_memory_allocated() / 1024**3 self.writer.add_scalar("train/loss", loss.item(), step) self.writer.add_scalar("train/l1loss", l1loss.item(), step) @@ -551,12 +589,42 @@ def train(self): self.writer.add_image("train/render", canvas, step) self.writer.flush() + # save checkpoint before updating the model + if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: + mem = torch.cuda.max_memory_allocated() / 1024**3 + stats = { + "mem": mem, + "ellipse_time": time.time() - global_tic, + "num_GS": len(self.splats["means"]), + } + print("Step: ", step, stats) + with open( + f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json", + "w", + ) as f: + json.dump(stats, f) + data = {"step": step, "splats": self.splats.state_dict()} + if cfg.pose_opt: + if world_size > 1: + data["pose_adjust"] = self.pose_adjust.module.state_dict() + else: + data["pose_adjust"] = self.pose_adjust.state_dict() + if cfg.app_opt: + if world_size > 1: + data["app_module"] = self.app_module.module.state_dict() + else: + data["app_module"] = self.app_module.state_dict() + torch.save( + data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" + ) + self.strategy.step_post_backward( params=self.splats, optimizers=self.optimizers, state=self.strategy_state, step=step, info=info, + packed=cfg.packed, ) # Turn Gradients into Sparse Tensor before running optimizer @@ -587,25 +655,6 @@ def train(self): for scheduler in schedulers: scheduler.step() - # save checkpoint - if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: - mem = torch.cuda.max_memory_allocated() / 1024**3 - stats = { - "mem": mem, - "ellipse_time": time.time() - global_tic, - "num_GS": len(self.splats["means"]), - } - print("Step: ", step, stats) - with open(f"{self.stats_dir}/train_step{step:04d}.json", "w") as f: - json.dump(stats, f) - torch.save( - { - "step": step, - "splats": self.splats.state_dict(), - }, - f"{self.ckpt_dir}/ckpt_{step}.pt", - ) - # eval the full set if step in [i - 1 for i in cfg.eval_steps] or step == max_steps - 1: self.eval(step) @@ -628,6 +677,8 @@ def eval(self, step: int): print("Running evaluation...") cfg = self.cfg device = self.device + world_rank = self.world_rank + world_size = self.world_size valloader = torch.utils.data.DataLoader( self.valset, batch_size=1, shuffle=False, num_workers=1 @@ -655,42 +706,45 @@ def eval(self, step: int): torch.cuda.synchronize() ellipse_time += time.time() - tic - # write images - canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy() - imageio.imwrite( - f"{self.render_dir}/val_{i:04d}.png", (canvas * 255).astype(np.uint8) - ) + if world_rank == 0: + # write images + canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy() + imageio.imwrite( + 
f"{self.render_dir}/val_{i:04d}.png", + (canvas * 255).astype(np.uint8), + ) - pixels = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] - colors = colors.permute(0, 3, 1, 2) # [1, 3, H, W] - metrics["psnr"].append(self.psnr(colors, pixels)) - metrics["ssim"].append(self.ssim(colors, pixels)) - metrics["lpips"].append(self.lpips(colors, pixels)) - - ellipse_time /= len(valloader) - - psnr = torch.stack(metrics["psnr"]).mean() - ssim = torch.stack(metrics["ssim"]).mean() - lpips = torch.stack(metrics["lpips"]).mean() - print( - f"PSNR: {psnr.item():.3f}, SSIM: {ssim.item():.4f}, LPIPS: {lpips.item():.3f} " - f"Time: {ellipse_time:.3f}s/image " - f"Number of GS: {len(self.splats['means'])}" - ) - # save stats as json - stats = { - "psnr": psnr.item(), - "ssim": ssim.item(), - "lpips": lpips.item(), - "ellipse_time": ellipse_time, - "num_GS": len(self.splats["means"]), - } - with open(f"{self.stats_dir}/val_step{step:04d}.json", "w") as f: - json.dump(stats, f) - # save stats to tensorboard - for k, v in stats.items(): - self.writer.add_scalar(f"val/{k}", v, step) - self.writer.flush() + pixels = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] + colors = colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["psnr"].append(self.psnr(colors, pixels)) + metrics["ssim"].append(self.ssim(colors, pixels)) + metrics["lpips"].append(self.lpips(colors, pixels)) + + if world_rank == 0: + ellipse_time /= len(valloader) + + psnr = torch.stack(metrics["psnr"]).mean() + ssim = torch.stack(metrics["ssim"]).mean() + lpips = torch.stack(metrics["lpips"]).mean() + print( + f"PSNR: {psnr.item():.3f}, SSIM: {ssim.item():.4f}, LPIPS: {lpips.item():.3f} " + f"Time: {ellipse_time:.3f}s/image " + f"Number of GS: {len(self.splats['means'])}" + ) + # save stats as json + stats = { + "psnr": psnr.item(), + "ssim": ssim.item(), + "lpips": lpips.item(), + "ellipse_time": ellipse_time, + "num_GS": len(self.splats["means"]), + } + with open(f"{self.stats_dir}/val_step{step:04d}.json", "w") as f: + json.dump(stats, f) + # save stats to tensorboard + for k, v in stats.items(): + self.writer.add_scalar(f"val/{k}", v, step) + self.writer.flush() @torch.no_grad() def render_traj(self, step: int): @@ -767,8 +821,13 @@ def _viewer_render_fn( return render_colors[0].cpu().numpy() -def main(cfg: Config): - runner = Runner(cfg) +def main(local_rank: int, world_rank, world_size: int, cfg: Config): + if world_size > 1 and not cfg.disable_viewer: + cfg.disable_viewer = True + if world_rank == 0: + print("Viewer is disabled in distributed training.") + + runner = Runner(local_rank, world_rank, world_size, cfg) if cfg.ckpt is not None: # run eval only @@ -786,6 +845,18 @@ def main(cfg: Config): if __name__ == "__main__": + """ + Usage: + + ```bash + # Single GPU training + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py + + # Distributed training on 4 GPUs: Effectively 4x batch size so run 4x less steps. 
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py --steps_scaler 0.25 + + """ + cfg = tyro.cli(Config) cfg.adjust_steps(cfg.steps_scaler) - main(cfg) + cli(main, cfg, verbose=True) diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index 059ec70d5..483cbd66b 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -17,11 +17,13 @@ from datasets.traj import generate_interpolated_path from simple_trainer import create_splats_with_optimizers from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from utils import AppearanceOptModule, CameraOptModule, set_random_seed +from gsplat.distributed import cli from gsplat.rendering import rasterization from gsplat.strategy import MCMCStrategy @@ -152,11 +154,16 @@ def adjust_steps(self, factor: float): class Runner: """Engine for training and testing.""" - def __init__(self, cfg: Config) -> None: - set_random_seed(42) + def __init__( + self, local_rank: int, world_rank, world_size: int, cfg: Config + ) -> None: + set_random_seed(42 + local_rank) self.cfg = cfg - self.device = "cuda" + self.world_rank = world_rank + self.local_rank = local_rank + self.world_size = world_size + self.device = f"cuda:{local_rank}" # Where to dump results. os.makedirs(cfg.result_dir, exist_ok=True) @@ -204,6 +211,8 @@ def __init__(self, cfg: Config) -> None: batch_size=cfg.batch_size, feature_dim=feature_dim, device=self.device, + world_rank=world_rank, + world_size=world_size, ) print("Model initialized. Number of GS:", len(self.splats["means"])) @@ -230,10 +239,14 @@ def __init__(self, cfg: Config) -> None: weight_decay=cfg.pose_opt_reg, ) ] + if world_size > 1: + self.pose_adjust = DDP(self.pose_adjust) if cfg.pose_noise > 0.0: self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device) self.pose_perturb.random_init(cfg.pose_noise) + if world_size > 1: + self.pose_perturb = DDP(self.pose_perturb) self.app_optimizers = [] if cfg.app_opt: @@ -254,6 +267,8 @@ def __init__(self, cfg: Config) -> None: lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size), ), ] + if world_size > 1: + self.app_module = DDP(self.app_module) # Losses & Metrics. self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device) @@ -314,6 +329,7 @@ def rasterize_splats( absgrad=self.cfg.absgrad, sparse_grad=self.cfg.sparse_grad, rasterize_mode=rasterize_mode, + distributed=self.world_size > 1, **kwargs, ) return render_colors, render_alphas, info @@ -321,10 +337,13 @@ def rasterize_splats( def train(self): cfg = self.cfg device = self.device + world_rank = self.world_rank + world_size = self.world_size # Dump cfg. 
- with open(f"{cfg.result_dir}/cfg.json", "w") as f: - json.dump(vars(cfg), f) + if world_rank == 0: + with open(f"{cfg.result_dir}/cfg.json", "w") as f: + json.dump(vars(cfg), f) max_steps = cfg.max_steps init_step = 0 @@ -459,7 +478,16 @@ def train(self): desc += f"pose err={pose_err.item():.6f}| " pbar.set_description(desc) - if cfg.tb_every > 0 and step % cfg.tb_every == 0: + # write images (gt and render) + # if world_rank == 0 and step % 800 == 0: + # canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() + # canvas = canvas.reshape(-1, *canvas.shape[2:]) + # imageio.imwrite( + # f"{self.render_dir}/train_rank{self.world_rank}.png", + # (canvas * 255).astype(np.uint8), + # ) + + if world_rank == 0 and cfg.tb_every > 0 and step % cfg.tb_every == 0: mem = torch.cuda.max_memory_allocated() / 1024**3 self.writer.add_scalar("train/loss", loss.item(), step) self.writer.add_scalar("train/l1loss", l1loss.item(), step) @@ -474,6 +502,35 @@ def train(self): self.writer.add_image("train/render", canvas, step) self.writer.flush() + # save checkpoint before updating the model + if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: + mem = torch.cuda.max_memory_allocated() / 1024**3 + stats = { + "mem": mem, + "ellipse_time": time.time() - global_tic, + "num_GS": len(self.splats["means"]), + } + print("Step: ", step, stats) + with open( + f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json", + "w", + ) as f: + json.dump(stats, f) + data = {"step": step, "splats": self.splats.state_dict()} + if cfg.pose_opt: + if world_size > 1: + data["pose_adjust"] = self.pose_adjust.module.state_dict() + else: + data["pose_adjust"] = self.pose_adjust.state_dict() + if cfg.app_opt: + if world_size > 1: + data["app_module"] = self.app_module.module.state_dict() + else: + data["app_module"] = self.app_module.state_dict() + torch.save( + data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" + ) + self.strategy.step_post_backward( params=self.splats, optimizers=self.optimizers, @@ -511,25 +568,6 @@ def train(self): for scheduler in schedulers: scheduler.step() - # save checkpoint - if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: - mem = torch.cuda.max_memory_allocated() / 1024**3 - stats = { - "mem": mem, - "ellipse_time": time.time() - global_tic, - "num_GS": len(self.splats["means"]), - } - print("Step: ", step, stats) - with open(f"{self.stats_dir}/train_step{step:04d}.json", "w") as f: - json.dump(stats, f) - torch.save( - { - "step": step, - "splats": self.splats.state_dict(), - }, - f"{self.ckpt_dir}/ckpt_{step}.pt", - ) - # eval the full set if step in [i - 1 for i in cfg.eval_steps] or step == max_steps - 1: self.eval(step) @@ -552,6 +590,8 @@ def eval(self, step: int): print("Running evaluation...") cfg = self.cfg device = self.device + world_rank = self.world_rank + world_size = self.world_size valloader = torch.utils.data.DataLoader( self.valset, batch_size=1, shuffle=False, num_workers=1 @@ -579,42 +619,45 @@ def eval(self, step: int): torch.cuda.synchronize() ellipse_time += time.time() - tic - # write images - canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy() - imageio.imwrite( - f"{self.render_dir}/val_{i:04d}.png", (canvas * 255).astype(np.uint8) - ) + if world_rank == 0: + # write images + canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy() + imageio.imwrite( + f"{self.render_dir}/val_{i:04d}.png", + (canvas * 255).astype(np.uint8), + ) - pixels = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] - 
colors = colors.permute(0, 3, 1, 2) # [1, 3, H, W] - metrics["psnr"].append(self.psnr(colors, pixels)) - metrics["ssim"].append(self.ssim(colors, pixels)) - metrics["lpips"].append(self.lpips(colors, pixels)) - - ellipse_time /= len(valloader) - - psnr = torch.stack(metrics["psnr"]).mean() - ssim = torch.stack(metrics["ssim"]).mean() - lpips = torch.stack(metrics["lpips"]).mean() - print( - f"PSNR: {psnr.item():.3f}, SSIM: {ssim.item():.4f}, LPIPS: {lpips.item():.3f} " - f"Time: {ellipse_time:.3f}s/image " - f"Number of GS: {len(self.splats['means'])}" - ) - # save stats as json - stats = { - "psnr": psnr.item(), - "ssim": ssim.item(), - "lpips": lpips.item(), - "ellipse_time": ellipse_time, - "num_GS": len(self.splats["means"]), - } - with open(f"{self.stats_dir}/val_step{step:04d}.json", "w") as f: - json.dump(stats, f) - # save stats to tensorboard - for k, v in stats.items(): - self.writer.add_scalar(f"val/{k}", v, step) - self.writer.flush() + pixels = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] + colors = colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["psnr"].append(self.psnr(colors, pixels)) + metrics["ssim"].append(self.ssim(colors, pixels)) + metrics["lpips"].append(self.lpips(colors, pixels)) + + if world_rank == 0: + ellipse_time /= len(valloader) + + psnr = torch.stack(metrics["psnr"]).mean() + ssim = torch.stack(metrics["ssim"]).mean() + lpips = torch.stack(metrics["lpips"]).mean() + print( + f"PSNR: {psnr.item():.3f}, SSIM: {ssim.item():.4f}, LPIPS: {lpips.item():.3f} " + f"Time: {ellipse_time:.3f}s/image " + f"Number of GS: {len(self.splats['means'])}" + ) + # save stats as json + stats = { + "psnr": psnr.item(), + "ssim": ssim.item(), + "lpips": lpips.item(), + "ellipse_time": ellipse_time, + "num_GS": len(self.splats["means"]), + } + with open(f"{self.stats_dir}/val_step{step:04d}.json", "w") as f: + json.dump(stats, f) + # save stats to tensorboard + for k, v in stats.items(): + self.writer.add_scalar(f"val/{k}", v, step) + self.writer.flush() @torch.no_grad() def render_traj(self, step: int): @@ -691,8 +734,13 @@ def _viewer_render_fn( return render_colors[0].cpu().numpy() -def main(cfg: Config): - runner = Runner(cfg) +def main(local_rank: int, world_rank, world_size: int, cfg: Config): + if world_size > 1 and not cfg.disable_viewer: + cfg.disable_viewer = True + if world_rank == 0: + print("Viewer is disabled in distributed training.") + + runner = Runner(local_rank, world_rank, world_size, cfg) if cfg.ckpt is not None: # run eval only @@ -712,4 +760,4 @@ def main(cfg: Config): if __name__ == "__main__": cfg = tyro.cli(Config) cfg.adjust_steps(cfg.steps_scaler) - main(cfg) + cli(main, cfg, verbose=True) diff --git a/examples/simple_viewer.py b/examples/simple_viewer.py index 3662433a8..7a8f5b90f 100644 --- a/examples/simple_viewer.py +++ b/examples/simple_viewer.py @@ -16,171 +16,222 @@ import numpy as np import torch import torch.nn.functional as F +import tqdm import viser from gsplat._helper import load_test_data +from gsplat.distributed import cli from gsplat.rendering import rasterization -parser = argparse.ArgumentParser() -parser.add_argument( - "--output_dir", type=str, default="results/", help="where to dump outputs" -) -parser.add_argument( - "--scene_grid", type=int, default=1, help="repeat the scene into a grid of NxN" -) -parser.add_argument("--ckpt", type=str, default=None, help="path to the .pt file") -parser.add_argument("--port", type=int, default=8080, help="port for the viewer server") -parser.add_argument( - "--backend", type=str, 
default="gsplat", help="gsplat, gsplat_legacy, inria" -) -args = parser.parse_args() -assert args.scene_grid % 2 == 1, "scene_grid must be odd" - -torch.manual_seed(42) -device = "cuda" - -if args.ckpt is None: - ( - means, - quats, - scales, - opacities, - colors, - viewmats, - Ks, - width, - height, - ) = load_test_data(device=device, scene_grid=args.scene_grid) - sh_degree = None - C = len(viewmats) - N = len(means) - print("Number of Gaussians:", N) - - # batched render - render_colors, render_alphas, meta = rasterization( - means, # [N, 3] - quats, # [N, 4] - scales, # [N, 3] - opacities, # [N] - colors, # [N, 3] - viewmats, # [C, 4, 4] - Ks, # [C, 3, 3] - width, - height, - render_mode="RGB+D", - ) - assert render_colors.shape == (C, height, width, 4) - assert render_alphas.shape == (C, height, width, 1) - - render_rgbs = render_colors[..., 0:3] - render_depths = render_colors[..., 3:4] - render_depths = render_depths / render_depths.max() - - # dump batch images - os.makedirs(args.output_dir, exist_ok=True) - canvas = ( - torch.cat( - [ - render_rgbs.reshape(C * height, width, 3), - render_depths.reshape(C * height, width, 1).expand(-1, -1, 3), - render_alphas.reshape(C * height, width, 1).expand(-1, -1, 3), - ], - dim=1, + +def main(local_rank: int, world_rank, world_size: int, args): + torch.manual_seed(42) + device = torch.device("cuda", local_rank) + + if args.ckpt is None: + ( + means, + quats, + scales, + opacities, + colors, + viewmats, + Ks, + width, + height, + ) = load_test_data(device=device, scene_grid=args.scene_grid) + + assert world_size <= 2 + means = means[world_rank::world_size].contiguous() + means.requires_grad = True + quats = quats[world_rank::world_size].contiguous() + quats.requires_grad = True + scales = scales[world_rank::world_size].contiguous() + scales.requires_grad = True + opacities = opacities[world_rank::world_size].contiguous() + opacities.requires_grad = True + colors = colors[world_rank::world_size].contiguous() + colors.requires_grad = True + + viewmats = viewmats[world_rank::world_size][:1].contiguous() + Ks = Ks[world_rank::world_size][:1].contiguous() + + sh_degree = None + C = len(viewmats) + N = len(means) + print("rank", world_rank, "Number of Gaussians:", N, "Number of Cameras:", C) + + # batched render + for _ in tqdm.trange(1): + render_colors, render_alphas, meta = rasterization( + means, # [N, 3] + quats, # [N, 4] + scales, # [N, 3] + opacities, # [N] + colors, # [N, 3] + viewmats, # [C, 4, 4] + Ks, # [C, 3, 3] + width, + height, + render_mode="RGB+D", + packed=False, + distributed=world_size > 1, + ) + C = render_colors.shape[0] + assert render_colors.shape == (C, height, width, 4) + assert render_alphas.shape == (C, height, width, 1) + render_colors.sum().backward() + + render_rgbs = render_colors[..., 0:3] + render_depths = render_colors[..., 3:4] + render_depths = render_depths / render_depths.max() + + # dump batch images + os.makedirs(args.output_dir, exist_ok=True) + canvas = ( + torch.cat( + [ + render_rgbs.reshape(C * height, width, 3), + render_depths.reshape(C * height, width, 1).expand(-1, -1, 3), + render_alphas.reshape(C * height, width, 1).expand(-1, -1, 3), + ], + dim=1, + ) + .detach() + .cpu() + .numpy() + ) + imageio.imsave( + f"{args.output_dir}/render_rank{world_rank}.png", + (canvas * 255).astype(np.uint8), + ) + else: + means, quats, scales, opacities, sh0, shN = [], [], [], [], [], [] + for ckpt_path in args.ckpt: + ckpt = torch.load(ckpt_path, map_location=device)["splats"] + means.append(ckpt["means3d"]) + 
quats.append(F.normalize(ckpt["quats"], p=2, dim=-1)) + scales.append(torch.exp(ckpt["scales"])) + opacities.append(torch.sigmoid(ckpt["opacities"])) + sh0.append(ckpt["sh0"]) + shN.append(ckpt["shN"]) + means = torch.cat(means, dim=0) + quats = torch.cat(quats, dim=0) + scales = torch.cat(scales, dim=0) + opacities = torch.cat(opacities, dim=0) + sh0 = torch.cat(sh0, dim=0) + shN = torch.cat(shN, dim=0) + colors = torch.cat([sh0, shN], dim=-2) + sh_degree = int(math.sqrt(colors.shape[-2]) - 1) + + # # crop + # aabb = torch.tensor((-1.0, -1.0, -1.0, 1.0, 1.0, 0.7), device=device) + # edges = aabb[3:] - aabb[:3] + # sel = ((means >= aabb[:3]) & (means <= aabb[3:])).all(dim=-1) + # sel = torch.where(sel)[0] + # means, quats, scales, colors, opacities = ( + # means[sel], + # quats[sel], + # scales[sel], + # colors[sel], + # opacities[sel], + # ) + + # # repeat the scene into a grid (to mimic a large-scale setting) + # repeats = args.scene_grid + # gridx, gridy = torch.meshgrid( + # [ + # torch.arange(-(repeats // 2), repeats // 2 + 1, device=device), + # torch.arange(-(repeats // 2), repeats // 2 + 1, device=device), + # ], + # indexing="ij", + # ) + # grid = torch.stack([gridx, gridy, torch.zeros_like(gridx)], dim=-1).reshape( + # -1, 3 + # ) + # means = means[None, :, :] + grid[:, None, :] * edges[None, None, :] + # means = means.reshape(-1, 3) + # quats = quats.repeat(repeats**2, 1) + # scales = scales.repeat(repeats**2, 1) + # colors = colors.repeat(repeats**2, 1, 1) + # opacities = opacities.repeat(repeats**2) + print("Number of Gaussians:", len(means)) + + # register and open viewer + @torch.no_grad() + def viewer_render_fn(camera_state: nerfview.CameraState, img_wh: Tuple[int, int]): + width, height = img_wh + c2w = camera_state.c2w + K = camera_state.get_K(img_wh) + c2w = torch.from_numpy(c2w).float().to(device) + K = torch.from_numpy(K).float().to(device) + viewmat = c2w.inverse() + + if args.backend == "gsplat": + rasterization_fn = rasterization + elif args.backend == "gsplat_legacy": + from gsplat import rasterization_legacy_wrapper + + rasterization_fn = rasterization_legacy_wrapper + elif args.backend == "inria": + from gsplat import rasterization_inria_wrapper + + rasterization_fn = rasterization_inria_wrapper + else: + raise ValueError + + render_colors, render_alphas, meta = rasterization_fn( + means, # [N, 3] + quats, # [N, 4] + scales, # [N, 3] + opacities, # [N] + colors, # [N, 3] + viewmat[None], # [1, 4, 4] + K[None], # [1, 3, 3] + width, + height, + sh_degree=sh_degree, + render_mode="RGB", + # this is to speedup large-scale rendering by skipping far-away Gaussians. 
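+            # radius_clip skips Gaussians whose projected 2D radius falls below this
+            # threshold (in pixels), trading a little quality for rendering speed.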
+ radius_clip=3, ) - .cpu() - .numpy() + render_rgbs = render_colors[0, ..., 0:3].cpu().numpy() + return render_rgbs + + server = viser.ViserServer(port=args.port, verbose=False) + _ = nerfview.Viewer( + server=server, + render_fn=viewer_render_fn, + mode="rendering", ) - imageio.imsave(f"{args.output_dir}/render.png", (canvas * 255).astype(np.uint8)) -else: - ckpt = torch.load(args.ckpt, map_location=device)["splats"] - means = ckpt["means"] - quats = F.normalize(ckpt["quats"], p=2, dim=-1) - scales = torch.exp(ckpt["scales"]) - opacities = torch.sigmoid(ckpt["opacities"]) - sh0 = ckpt["sh0"] - shN = ckpt["shN"] - colors = torch.cat([sh0, shN], dim=-2) - sh_degree = int(math.sqrt(colors.shape[-2]) - 1) - - # crop - aabb = torch.tensor((-1.0, -1.0, -1.0, 1.0, 1.0, 0.7), device=device) - edges = aabb[3:] - aabb[:3] - sel = ((means >= aabb[:3]) & (means <= aabb[3:])).all(dim=-1) - sel = torch.where(sel)[0] - means, quats, scales, colors, opacities = ( - means[sel], - quats[sel], - scales[sel], - colors[sel], - opacities[sel], + print("Viewer running... Ctrl+C to exit.") + time.sleep(100000) + + +if __name__ == "__main__": + """ + # Use single GPU to view the scene + CUDA_VISIBLE_DEVICES=0 python simple_viewer.py \ + --ckpt results/garden/ckpts/ckpt_3499_rank0.pt results/garden/ckpts/ckpt_3499_rank1.pt \ + --port 8081 + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--output_dir", type=str, default="results/", help="where to dump outputs" ) - - # repeat the scene into a grid (to mimic a large-scale setting) - repeats = args.scene_grid - gridx, gridy = torch.meshgrid( - [ - torch.arange(-(repeats // 2), repeats // 2 + 1, device=device), - torch.arange(-(repeats // 2), repeats // 2 + 1, device=device), - ], - indexing="ij", + parser.add_argument( + "--scene_grid", type=int, default=1, help="repeat the scene into a grid of NxN" ) - grid = torch.stack([gridx, gridy, torch.zeros_like(gridx)], dim=-1).reshape(-1, 3) - means = means[None, :, :] + grid[:, None, :] * edges[None, None, :] - means = means.reshape(-1, 3) - quats = quats.repeat(repeats**2, 1) - scales = scales.repeat(repeats**2, 1) - colors = colors.repeat(repeats**2, 1, 1) - opacities = opacities.repeat(repeats**2) - print("Number of Gaussians:", len(means)) - - -# register and open viewer -@torch.no_grad() -def viewer_render_fn(camera_state: nerfview.CameraState, img_wh: Tuple[int, int]): - width, height = img_wh - c2w = camera_state.c2w - K = camera_state.get_K(img_wh) - c2w = torch.from_numpy(c2w).float().to(device) - K = torch.from_numpy(K).float().to(device) - viewmat = c2w.inverse() - - if args.backend == "gsplat": - rasterization_fn = rasterization - elif args.backend == "gsplat_legacy": - from gsplat import rasterization_legacy_wrapper - - rasterization_fn = rasterization_legacy_wrapper - elif args.backend == "inria": - from gsplat import rasterization_inria_wrapper - - rasterization_fn = rasterization_inria_wrapper - else: - raise ValueError - - render_colors, render_alphas, meta = rasterization_fn( - means, # [N, 3] - quats, # [N, 4] - scales, # [N, 3] - opacities, # [N] - colors, # [N, 3] - viewmat[None], # [1, 4, 4] - K[None], # [1, 3, 3] - width, - height, - sh_degree=sh_degree, - render_mode="RGB", - # this is to speedup large-scale rendering by skipping far-away Gaussians. 
- radius_clip=3, + parser.add_argument( + "--ckpt", type=str, nargs="+", default=None, help="path to the .pt file" + ) + parser.add_argument( + "--port", type=int, default=8080, help="port for the viewer server" ) - render_rgbs = render_colors[0, ..., 0:3].cpu().numpy() - return render_rgbs - - -server = viser.ViserServer(port=args.port, verbose=False) -_ = nerfview.Viewer( - server=server, - render_fn=viewer_render_fn, - mode="rendering", -) -print("Viewer running... Ctrl+C to exit.") -time.sleep(100000) + parser.add_argument( + "--backend", type=str, default="gsplat", help="gsplat, gsplat_legacy, inria" + ) + args = parser.parse_args() + assert args.scene_grid % 2 == 1, "scene_grid must be odd" + + cli(main, args, verbose=True) diff --git a/gsplat/distributed.py b/gsplat/distributed.py new file mode 100644 index 000000000..cab559df5 --- /dev/null +++ b/gsplat/distributed.py @@ -0,0 +1,360 @@ +import os +from typing import Any, Callable, List, Optional, Union + +import torch +import torch.distributed as dist +import torch.distributed.nn.functional as distF +from torch import Tensor + + +def all_gather_int32( + world_size: int, value: Union[int, Tensor], device: Optional[torch.device] = None +) -> List[int]: + """Gather an 32-bit integer from all ranks. + + .. note:: + This implementation is faster than using `torch.distributed.all_gather_object`. + + .. note:: + This function is not differentiable to the input tensor. + + Args: + world_size: The total number of ranks. + value: The integer to gather. Could be a scalar or a tensor. + device: Only required if `value` is a scalar. The device to put the tensor on. + + Returns: + A list of integers, where the i-th element is the value from the i-th rank. + Could be a list of scalars or tensors based on the input `value`. + """ + if world_size == 1: + return [value] + + # move to CUDA + if isinstance(value, int): + assert device is not None, "device is required for scalar input" + value_tensor = torch.tensor(value, dtype=torch.int, device=device) + else: + value_tensor = value + assert value_tensor.is_cuda, "value should be on CUDA" + + # gather + collected = torch.empty( + world_size, dtype=value_tensor.dtype, device=value_tensor.device + ) + dist.all_gather_into_tensor(collected, value_tensor) + + if isinstance(value, int): + # return as list of integers on CPU + return collected.tolist() + else: + # return as list of single-element tensors + return collected.unbind() + + +def all_to_all_int32( + world_size: int, + values: List[Union[int, Tensor]], + device: Optional[torch.device] = None, +) -> List[int]: + """Exchange 32-bit integers between all ranks in a many-to-many fashion. + + .. note:: + This function is not differentiable to the input tensors. + + Args: + world_size: The total number of ranks. + values: A list of integers to exchange. Could be a list of scalars or tensors. + Should have the same length as `world_size`. + device: Only required if `values` contains scalars. The device to put the tensors on. + + Returns: + A list of integers. Could be a list of scalars or tensors based on the input `values`. + Have the same length as `world_size`. 
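+
+    Examples:
+
+    .. code-block:: python
+
+        >>> # a sketch of the exchange semantics with world_size = 2
+        >>> # on rank 0
+        >>> # values = [0, 1]
+        >>> # on rank 1
+        >>> # values = [2, 3]
+        >>> collected = all_to_all_int32(world_size, values, device=device)
+        >>> # on rank 0
+        >>> # [0, 2]
+        >>> # on rank 1
+        >>> # [1, 3]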
+ """ + if world_size == 1: + return values + + assert ( + len(values) == world_size + ), "The length of values should be equal to world_size" + + if any(isinstance(v, int) for v in values): + assert device is not None, "device is required for scalar input" + + # move to CUDA + values_tensor = [ + (torch.tensor(v, dtype=torch.int, device=device) if isinstance(v, int) else v) + for v in values + ] + + # all_to_all + collected = [torch.empty_like(v) for v in values_tensor] + dist.all_to_all(collected, values_tensor) + + # return as a list of integers or tensors, based on the input + return [ + v.item() if isinstance(tensor, int) else v + for v, tensor in zip(collected, values) + ] + + +def all_gather_tensor_list(world_size: int, tensor_list: List[Tensor]) -> List[Tensor]: + """Gather a list of tensors from all ranks. + + .. note:: + This function expects the tensors in the `tensor_list` to have the same shape + and data type across all ranks. + + .. note:: + This function is differentiable to the tensors in `tensor_list`. + + .. note:: + For efficiency, this function internally concatenates the tensors in `tensor_list` + and performs a single gather operation. Thus it requires all tensors in the list + to have the same first-dimension size. + + Args: + world_size: The total number of ranks. + tensor_list: A list of tensors to gather. The size of the first dimension of all + the tensors in this list should be the same. The rest dimensions can be + arbitrary. Shape: [(N, *), (N, *), ...] + + Returns: + A list of tensors gathered from all ranks, where the i-th element is corresponding + to the i-th tensor in `tensor_list`. The returned tensors have the shape + [(N * world_size, *), (N * world_size, *), ...] + + Examples: + + .. code-block:: python + + >>> # on rank 0 + >>> # tensor_list = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])] + >>> # on rank 1 + >>> # tensor_list = [torch.tensor([7, 8, 9]), torch.tensor([10, 11, 12])] + >>> collected = all_gather_tensor_list(world_rank, world_size, tensor_list) + >>> # on both ranks + >>> # [torch.tensor([1, 2, 3, 7, 8, 9]), torch.tensor([4, 5, 6, 10, 11, 12])] + + """ + if world_size == 1: + return tensor_list + + N = len(tensor_list[0]) + for tensor in tensor_list: + assert len(tensor) == N, "All tensors should have the same first dimension size" + + # concatenate tensors and record their sizes + data = torch.cat([t.reshape(N, -1) for t in tensor_list], dim=-1) + sizes = [t.numel() // N for t in tensor_list] + + if data.requires_grad: + # differentiable gather + collected = distF.all_gather(data) + else: + # non-differentiable gather + collected = [torch.empty_like(data) for _ in range(world_size)] + torch.distributed.all_gather(collected, data) + collected = torch.cat(collected, dim=0) + + # split the collected tensor and reshape to the original shape + out_tensor_tuple = torch.split(collected, sizes, dim=-1) + out_tensor_list = [] + for out_tensor, tensor in zip(out_tensor_tuple, tensor_list): + out_tensor = out_tensor.view(-1, *tensor.shape[1:]) # [N * world_size, *] + out_tensor_list.append(out_tensor) + return out_tensor_list + + +def all_to_all_tensor_list( + world_size: int, + tensor_list: List[Tensor], + splits: List[Union[int, Tensor]], + output_splits: Optional[List[Union[int, Tensor]]] = None, +) -> List[Tensor]: + """Split and exchange tensors between all ranks in a many-to-many fashion. + + Args: + world_size: The total number of ranks. + tensor_list: A list of tensors to split and exchange. 
The size of the first + dimension of all the tensors in this list should be the same. The rest + dimensions can be arbitrary. Shape: [(N, *), (N, *), ...] + splits: A list of integers representing the number of elements to send to each + rank. It will be used to split the tensor in the `tensor_list`. + The sum of the elements in this list should be equal to N. The size of this + list should be equal to the `world_size`. + output_splits: Splits of the output tensors. Could be pre-calculated via + `all_to_all_int32(world_size, splits)`. If not provided, it will + be calculated internally. + + Returns: + A list of tensors exchanged between all ranks, where the i-th element is + corresponding to the i-th tensor in `tensor_list`. Note the shape of the + returned tensors might be different from the input tensors, depending on the + splits. + + Examples: + + .. code-block:: python + + >>> # on rank 0 + >>> # tensor_list = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])] + >>> # splits = [2, 1] + + >>> # on rank 1 + >>> # tensor_list = [torch.tensor([7, 8]), torch.tensor([9, 10])] + >>> # splits = [1, 1] + + >>> collected = all_to_all_tensor_list(world_rank, world_size, tensor_list, splits) + + >>> # on rank 0 + >>> # [torch.tensor([1, 2, 7]), torch.tensor([4, 5, 9])] + >>> # on rank 1 + >>> # [torch.tensor([3, 8]), torch.tensor([6, 10])] + + """ + if world_size == 1: + return tensor_list + + N = len(tensor_list[0]) + for tensor in tensor_list: + assert len(tensor) == N, "All tensors should have the same first dimension size" + + assert ( + len(splits) == world_size + ), "The length of splits should be equal to world_size" + + # concatenate tensors and record their sizes + data = torch.cat([t.reshape(N, -1) for t in tensor_list], dim=-1) + sizes = [t.numel() // N for t in tensor_list] + + # all_to_all + if output_splits is not None: + collected_splits = output_splits + else: + collected_splits = all_to_all_int32(world_size, splits, device=data.device) + collected = [ + torch.empty((l, *data.shape[1:]), dtype=data.dtype, device=data.device) + for l in collected_splits + ] + # torch.split requires tuple of integers + splits = [s.item() if isinstance(s, Tensor) else s for s in splits] + if data.requires_grad: + # differentiable all_to_all + distF.all_to_all(collected, data.split(splits, dim=0)) + else: + # non-differentiable all_to_all + torch.distributed.all_to_all(collected, list(data.split(splits, dim=0))) + collected = torch.cat(collected, dim=0) + + # split the collected tensor and reshape to the original shape + out_tensor_tuple = torch.split(collected, sizes, dim=-1) + out_tensor_list = [] + for out_tensor, tensor in zip(out_tensor_tuple, tensor_list): + out_tensor = out_tensor.view(-1, *tensor.shape[1:]) + out_tensor_list.append(out_tensor) + return out_tensor_list + + +def _find_free_port(): + import socket + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Binding to port 0 will cause the OS to find an available port for us + sock.bind(("", 0)) + port = sock.getsockname()[1] + sock.close() + # NOTE: there is still a chance the port could be taken by other processes. 
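+    # (The socket is closed before the caller binds the port, so this is only a
+    # best-effort hint rather than a guarantee.)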
+ return port + + +def _distributed_worker( + world_rank: int, + world_size: int, + fn: Callable, + args: Any, + local_rank: Optional[int] = None, + verbose: bool = False, +) -> bool: + if local_rank is None: # single Node + local_rank = world_rank + if verbose: + print("Distributed worker: %d / %d" % (world_rank + 1, world_size)) + distributed = world_size > 1 + if distributed: + torch.cuda.set_device(local_rank) + torch.distributed.init_process_group( + backend="nccl", world_size=world_size, rank=world_rank + ) + # Dump collection that participates all ranks. + # This initializes the communicator required by `batch_isend_irecv`. + # See: https://github.com/pytorch/pytorch/pull/74701 + _ = [None for _ in range(world_size)] + torch.distributed.all_gather_object(_, 0) + fn(local_rank, world_rank, world_size, args) + if distributed: + torch.distributed.barrier() + torch.distributed.destroy_process_group() + if verbose: + print("Job Done for worker: %d / %d" % (world_rank + 1, world_size)) + return True + + +def cli(fn: Callable, args: Any, verbose: bool = False) -> bool: + """Wrapper to run a function in a distributed environment. + + The function `fn` should have the following signature: + + ```python + def fn(local_rank: int, world_rank: int, world_size: int, args: Any) -> None: + pass + ``` + + Usage: + + ```python + # Launch with "CUDA_VISIBLE_DEVICES=0,1,2,3 python my_script.py" + if __name__ == "__main__": + cli(fn, None, verbose=True) + ``` + """ + assert torch.cuda.is_available(), "CUDA device is required!" + if "OMPI_COMM_WORLD_SIZE" in os.environ: # multi-node + local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) # dist.get_world_size() + world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) # dist.get_rank() + return _distributed_worker( + world_rank, world_size, fn, args, local_rank, verbose + ) + + world_size = torch.cuda.device_count() + distributed = world_size > 1 + + if distributed: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(_find_free_port()) + process_context = torch.multiprocessing.spawn( + _distributed_worker, + args=(world_size, fn, args, None, verbose), + nprocs=world_size, + join=False, + ) + try: + process_context.join() + except KeyboardInterrupt: + # this is important. + # if we do not explicitly terminate all launched subprocesses, + # they would continue living even after this main process ends, + # eventually making the OD machine unusable! + for i, process in enumerate(process_context.processes): + if process.is_alive(): + if verbose: + print("terminating process " + str(i) + "...") + process.terminate() + process.join() + if verbose: + print("process " + str(i) + " finished") + return True + else: + return _distributed_worker(0, 1, fn=fn, args=args) diff --git a/gsplat/profile.py b/gsplat/profile.py new file mode 100644 index 000000000..669d363fb --- /dev/null +++ b/gsplat/profile.py @@ -0,0 +1,59 @@ +import os +import time +from functools import wraps +from typing import Callable, Optional + +import torch + +profiler = {} + + +class timeit(object): + """Profiler that is controled by the TIMEIT environment variable. + + If TIMEIT is set to 1, the profiler will measure the time taken by the decorated function. 
+ + Usage: + + ```python + @timeit() + def my_function(): + pass + + # Or + + with timeit(name="stage1"): + my_function() + + print(profiler) + ``` + """ + + def __init__(self, name: str = "unnamed"): + self.name = name + self.start_time: Optional[float] = None + self.enabled = os.environ.get("TIMEIT", "0") == "1" + + def __enter__(self): + if self.enabled: + torch.cuda.synchronize() + self.start_time = time.perf_counter() + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.enabled: + torch.cuda.synchronize() + end_time = time.perf_counter() + total_time = end_time - self.start_time + if self.name not in profiler: + profiler[self.name] = total_time + else: + profiler[self.name] += total_time + + def __call__(self, f: Callable) -> Callable: + @wraps(f) + def decorated(*args, **kwargs): + with self: + self.name = f.__name__ + return f(*args, **kwargs) + + return decorated diff --git a/gsplat/relocation.py b/gsplat/relocation.py index d593d28bc..921e4afd9 100644 --- a/gsplat/relocation.py +++ b/gsplat/relocation.py @@ -1,5 +1,7 @@ +import math from typing import Tuple +import torch from torch import Tensor from .cuda._wrapper import _make_lazy_cuda_func @@ -31,6 +33,13 @@ def compute_relocation( **new_opacities**: The opacities of the new Gaussians. [N] **new_scales**: The scales of the Gaussians. [N, 3] """ + + N_MAX = 51 + BINOMS = torch.zeros((N_MAX, N_MAX), device=opacities.device) + for n in range(N_MAX): + for k in range(n + 1): + BINOMS[n, k] = math.comb(n, k) + N = opacities.shape[0] n_max, _ = binoms.shape assert scales.shape == (N, 3), scales.shape diff --git a/gsplat/rendering.py b/gsplat/rendering.py index da0d0ed74..a18363beb 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -1,7 +1,8 @@ import math -from typing import Dict, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch +import torch.distributed from torch import Tensor from typing_extensions import Literal @@ -12,6 +13,12 @@ rasterize_to_pixels, spherical_harmonics, ) +from .distributed import ( + all_gather_int32, + all_gather_tensor_list, + all_to_all_int32, + all_to_all_tensor_list, +) def rasterization( @@ -37,6 +44,7 @@ def rasterization( absgrad: bool = False, rasterize_mode: Literal["classic", "antialiased"] = "classic", channel_chunk: int = 32, + distributed: bool = False, ) -> Tuple[Tensor, Tensor, Dict]: """Rasterize a set of 3D Gaussians (N) to a batch of image planes (C). @@ -44,6 +52,18 @@ def rasterization( we detail in the following notes. A complete profiling of the these features can be found in the :ref:`profiling` page. + .. note:: + **Multi-GPU Distributed Rasterization**: This function can be used in a multi-GPU + distributed scenario by setting `distributed` to True. When `distributed` is True, + a subset of total Gaussians could be passed into this function in each rank, and + the function will collaboratively render a set of images using Gaussians from all ranks. Note + to achieve balanced computation, it is recommended (not enforced) to have similar number of + Gaussians in each rank. But we do enforce that the number of cameras to be rendered + in each rank is the same. The function will return the rendered images + corresponds to the input cameras in each rank, and allows for gradients to flow back to the + Gaussians living in other ranks. For the details, please refer to the paper + `On Scaling Up 3D Gaussian Splatting Training `_. + .. 
note:: **Batch Rasterization**: This function allows for rasterizing a set of 3D Gaussians to a batch of images in one go, by simplly providing the batched `viewmats` and `Ks`. @@ -148,6 +168,9 @@ def rasterization( channel_chunk: The number of channels to render in one go. Default is 32. If the required rendering channels are larger than this value, the rendering will be done looply in chunks. + distributed: Whether to use distributed rendering. Default is False. If True, + The input Gaussians are expected to be a subset of scene in each rank, and + the function will collaboratively render the images for all ranks. Returns: A tuple: @@ -188,9 +211,11 @@ def rasterization( 'flatten_ids', 'isect_offsets', 'width', 'height', 'tile_size']) """ + meta = {} N = means.shape[0] C = viewmats.shape[0] + device = means.device assert means.shape == (N, 3), means.shape assert quats.shape == (N, 4), quats.shape assert scales.shape == (N, 3), scales.shape @@ -204,6 +229,10 @@ def rasterization( assert (colors.dim() == 2 and colors.shape[0] == N) or ( colors.dim() == 3 and colors.shape[:2] == (C, N) ), colors.shape + if distributed: + assert ( + colors.dim() == 2 + ), "Distributed mode only supports per-Gaussian colors." else: # treat colors as SH coefficients, should be in shape [N, K, 3] or [C, N, K, 3] # Allowing for activating partial SH bands @@ -213,6 +242,30 @@ def rasterization( colors.dim() == 4 and colors.shape[:2] == (C, N) and colors.shape[3] == 3 ), colors.shape assert (sh_degree + 1) ** 2 <= colors.shape[-2], colors.shape + if distributed: + assert ( + colors.dim() == 3 + ), "Distributed mode only supports per-Gaussian colors." + + if absgrad: + assert not distributed, "AbsGrad is not supported in distributed mode." + + # If in distributed mode, we distribute the projection computation over Gaussians + # and the rasterize computation over cameras. So first we gather the cameras + # from all ranks for projection. + if distributed: + world_rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + # Gather the number of Gaussians in each rank. + N_world = all_gather_int32(world_size, N, device=device) + + # Enforce that the number of cameras is the same across all ranks. + C_world = [C] * world_size + viewmats, Ks = all_gather_tensor_list(world_size, [viewmats, Ks]) + + # Silently change C from local #Cameras to global #Cameras. + C = len(viewmats) # Project Gaussians to 2D. Directly pass in {quats, scales} is faster than precomputing covars. 
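     # (In distributed mode, each rank projects only its local Gaussians, but against
     # the full set of cameras gathered from all ranks above.)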
     proj_results = fully_fused_projection(
@@ -254,22 +307,19 @@ def rasterization(
         if compensations is not None:
             opacities = opacities * compensations
 
-    # Identify intersecting tiles
-    tile_width = math.ceil(width / float(tile_size))
-    tile_height = math.ceil(height / float(tile_size))
-    tiles_per_gauss, isect_ids, flatten_ids = isect_tiles(
-        means2d,
-        radii,
-        depths,
-        tile_size,
-        tile_width,
-        tile_height,
-        packed=packed,
-        n_cameras=C,
-        camera_ids=camera_ids,
-        gaussian_ids=gaussian_ids,
+    meta.update(
+        {
+            # global camera_ids
+            "camera_ids": camera_ids,
+            # local gaussian_ids
+            "gaussian_ids": gaussian_ids,
+            "radii": radii,
+            "means2d": means2d,
+            "depths": depths,
+            "conics": conics,
+            "opacities": opacities,
+        }
     )
-    isect_offsets = isect_offset_encode(isect_ids, C, tile_width, tile_height)
 
     # Turn colors into [C, N, D] or [nnz, D] to pass into rasterize_to_pixels()
     if sh_degree is None:
@@ -305,7 +355,7 @@ def rasterization(
         dirs = means[None, :, :] - camtoworlds[:, None, :3, 3]  # [C, N, 3]
         masks = radii > 0  # [C, N]
         if colors.dim() == 3:
-            # Turn [N, K, 3] into [C, N, 3]
+            # Turn [N, K, 3] into [C, N, K, 3]
             shs = colors.expand(C, -1, -1, -1)  # [C, N, K, 3]
         else:
             # colors is already [C, N, K, 3]
@@ -314,6 +364,92 @@ def rasterization(
         # make it apple-to-apple with Inria's CUDA Backend.
         colors = torch.clamp_min(colors + 0.5, 0.0)
 
+    # If in distributed mode, we need to scatter the GSs to the destination ranks, based
+    # on which cameras they are visible to, which we already figured out in the projection
+    # stage.
+    if distributed:
+        if packed:
+            # count how many elements need to be sent to each rank
+            cnts = torch.bincount(camera_ids, minlength=C)  # all cameras
+            cnts = cnts.split(C_world, dim=0)
+            cnts = [cuts.sum() for cuts in cnts]
+
+            # all to all communication across all ranks. After this step, each rank
+            # would have all the necessary GSs to render its own images.
+            collected_splits = all_to_all_int32(world_size, cnts, device=device)
+            (radii,) = all_to_all_tensor_list(
+                world_size, [radii], cnts, output_splits=collected_splits
+            )
+            (means2d, depths, conics, opacities, colors) = all_to_all_tensor_list(
+                world_size,
+                [means2d, depths, conics, opacities, colors],
+                cnts,
+                output_splits=collected_splits,
+            )
+
+            # before sending the data, we should turn the camera_ids from global to local.
+            # i.e. the camera_ids produced by the projection stage are over all cameras world-wide,
+            # so we need to turn them into camera_ids that are local to each rank.
+            offsets = torch.tensor(
+                [0] + C_world[:-1], device=camera_ids.device, dtype=camera_ids.dtype
+            )
+            offsets = torch.cumsum(offsets, dim=0)
+            offsets = offsets.repeat_interleave(torch.stack(cnts))
+            camera_ids = camera_ids - offsets
+
+            # and turn gaussian ids from local to global.
+            offsets = torch.tensor(
+                [0] + N_world[:-1],
+                device=gaussian_ids.device,
+                dtype=gaussian_ids.dtype,
+            )
+            offsets = torch.cumsum(offsets, dim=0)
+            offsets = offsets.repeat_interleave(torch.stack(cnts))
+            gaussian_ids = gaussian_ids + offsets
+
+            # all to all communication across all ranks.
+            (camera_ids, gaussian_ids) = all_to_all_tensor_list(
+                world_size,
+                [camera_ids, gaussian_ids],
+                cnts,
+                output_splits=collected_splits,
+            )
+
+            # Silently change C from global #Cameras to local #Cameras.
+            C = C_world[world_rank]
+
+        else:
+            # Silently change C from global #Cameras to local #Cameras.
+            C = C_world[world_rank]
+
+            # all to all communication across all ranks. After this step, each rank
+            # would have all the necessary GSs to render its own images.
+            (radii,) = all_to_all_tensor_list(
+                world_size,
+                [radii.flatten(0, 1)],
+                splits=[C_i * N for C_i in C_world],
+                output_splits=[C * N_i for N_i in N_world],
+            )
+            radii = radii.reshape(C, -1)
+
+            (means2d, depths, conics, opacities, colors) = all_to_all_tensor_list(
+                world_size,
+                [
+                    means2d.flatten(0, 1),
+                    depths.flatten(0, 1),
+                    conics.flatten(0, 1),
+                    opacities.flatten(0, 1),
+                    colors.flatten(0, 1),
+                ],
+                splits=[C_i * N for C_i in C_world],
+                output_splits=[C * N_i for N_i in N_world],
+            )
+            means2d = means2d.reshape(C, -1, 2)
+            depths = depths.reshape(C, -1)
+            conics = conics.reshape(C, -1, 3)
+            opacities = opacities.reshape(C, -1)
+            colors = colors.reshape(C, -1, colors.shape[-1])
+
     # Rasterize to pixels
     if render_mode in ["RGB+D", "RGB+ED"]:
         colors = torch.cat((colors, depths[..., None]), dim=-1)
@@ -327,6 +463,41 @@ def rasterization(
             backgrounds = torch.zeros(C, 1, device=backgrounds.device)
     else:  # RGB
         pass
+
+    # Identify intersecting tiles
+    tile_width = math.ceil(width / float(tile_size))
+    tile_height = math.ceil(height / float(tile_size))
+    tiles_per_gauss, isect_ids, flatten_ids = isect_tiles(
+        means2d,
+        radii,
+        depths,
+        tile_size,
+        tile_width,
+        tile_height,
+        packed=packed,
+        n_cameras=C,
+        camera_ids=camera_ids,
+        gaussian_ids=gaussian_ids,
+    )
+    isect_offsets = isect_offset_encode(isect_ids, C, tile_width, tile_height)
+
+    meta.update(
+        {
+            "tile_width": tile_width,
+            "tile_height": tile_height,
+            "tiles_per_gauss": tiles_per_gauss,
+            "isect_ids": isect_ids,
+            "flatten_ids": flatten_ids,
+            "isect_offsets": isect_offsets,
+            "width": width,
+            "height": height,
+            "tile_size": tile_size,
+            "n_cameras": C,
+        }
+    )
+
     if colors.shape[-1] > channel_chunk:
         # slice into chunks
         n_chunks = (colors.shape[-1] + channel_chunk - 1) // channel_chunk
@@ -381,25 +552,6 @@ def rasterization(
         dim=-1,
     )
 
-    meta = {
-        "camera_ids": camera_ids,
-        "gaussian_ids": gaussian_ids,
-        "radii": radii,
-        "means2d": means2d,
-        "depths": depths,
-        "conics": conics,
-        "opacities": opacities,
-        "tile_width": tile_width,
-        "tile_height": tile_height,
-        "tiles_per_gauss": tiles_per_gauss,
-        "isect_ids": isect_ids,
-        "flatten_ids": flatten_ids,
-        "isect_offsets": isect_offsets,
-        "width": width,
-        "height": height,
-        "tile_size": tile_size,
-        "n_cameras": C,
-    }
     return render_colors, render_alphas, meta
diff --git a/profiling/main.py b/profiling/main.py
index 238f3ae66..bc47df3e4 100644
--- a/profiling/main.py
+++ b/profiling/main.py
@@ -12,6 +12,7 @@
 from typing_extensions import Callable, Literal
 
 from gsplat._helper import load_test_data
+from gsplat.distributed import cli
 from gsplat.rendering import rasterization
 
 RESOLUTIONS = {
@@ -46,6 +47,8 @@ def main(
     backend: Literal["gsplat2", "gsplat", "inria"] = "gsplat2",
     repeats: int = 100,
     memory_history: bool = False,
+    world_rank: int = 0,
+    world_size: int = 1,
 ):
     (
         means,
@@ -66,6 +69,13 @@ def main(
     # more channels
     colors = colors[:, :1].repeat(1, channels)
 
+    # distribute the Gaussians across ranks (strided split)
+    means = means[world_rank::world_size].contiguous()
+    quats = quats[world_rank::world_size].contiguous()
+    scales = scales[world_rank::world_size].contiguous()
+    opacities = opacities[world_rank::world_size].contiguous()
+    colors = colors[world_rank::world_size].contiguous()
+
     means.requires_grad = True
     quats.requires_grad = True
     scales.requires_grad = True
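(Editorial aside: the strided slicing above shards the Gaussians round-robin across ranks, so shard sizes differ by at most one. A small self-contained sketch of the resulting split:)

```python
# Round-robin sharding: with N items and W ranks, rank r takes items
# r, r + W, r + 2W, ..., so every shard has ceil(N/W) or floor(N/W) items.
import torch

N, world_size = 10, 4
x = torch.arange(N)
shards = [x[r::world_size] for r in range(world_size)]
print([len(s) for s in shards])  # -> [3, 3, 2, 2]
```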
@@ -112,6 +122,7 @@ def main(
         far_plane=100.0,
         radius_clip=3.0,
         sparse_grad=sparse_grad,
+        distributed=world_size > 1,
     )
     mem_toc_fwd = torch.cuda.max_memory_allocated() / 1024**3 - mem_tic
 
@@ -144,55 +155,9 @@ def backward():
     }
 
 
-if __name__ == "__main__":
-    import argparse
-
+def worker(local_rank: int, world_rank: int, world_size: int, args):
     from tabulate import tabulate
 
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--backends",
-        nargs="+",
-        type=str,
-        default=["gsplat"],
-        help="gsplat, gsplat-legacy, inria",
-    )
-    parser.add_argument(
-        "--repeats",
-        type=int,
-        default=10,
-        help="Number of repeats for profiling",
-    )
-    parser.add_argument(
-        "--batch_size",
-        nargs="+",
-        type=int,
-        default=[1],
-        help="Batch size for profiling",
-    )
-    parser.add_argument(
-        "--scene_grid",
-        nargs="+",
-        type=int,
-        default=[1, 11, 21],
-        help="Scene grid size for profiling",
-    )
-    parser.add_argument(
-        "--channels",
-        nargs="+",
-        type=int,
-        default=[3],
-        help="Number of color channels for profiling",
-    )
-    parser.add_argument(
-        "--memory_history",
-        action="store_true",
-        help="Record memory history and dump a snapshot. Use https://pytorch.org/memory_viz to visualize.",
-    )
-    args = parser.parse_args()
-    if args.memory_history:
-        args.repeats = 1  # only run once for memory history
-
     # Tested on an NVIDIA TITAN RTX (24 GB).
     collection = []
@@ -214,6 +179,8 @@ def backward():
                         repeats=args.repeats,
                         # only care about memory for the packed version implementation
                         memory_history=args.memory_history,
+                        world_rank=world_rank,
+                        world_size=world_size,
                     )
                     collection.append(
                         [
@@ -243,6 +210,8 @@ def backward():
                         packed=True,
                         sparse_grad=False,
                         repeats=args.repeats,
+                        world_rank=world_rank,
+                        world_size=world_size,
                     )
                     collection.append(
                         [
@@ -272,6 +241,8 @@ def backward():
                         packed=False,
                         sparse_grad=False,
                         repeats=args.repeats,
+                        world_rank=world_rank,
+                        world_size=world_size,
                    )
                     collection.append(
                         [
@@ -349,33 +320,84 @@ def backward():
                     )
                     torch.cuda.empty_cache()
 
-    headers = [
-        "Backend",
-        "Packed",
-        "Sparse Grad",
-        # configs
-        "Batch Size",
-        "Channels",
-        "Scene Size",
-        # stats
-        # "Mem[fwd] (GB)",
-        "Mem (GB)",
-        "FPS[fwd]",
-        "FPS[bwd]",
-    ]
-
-    # pop config columns that has only one option
-    if len(args.scene_grid) == 1:
-        headers.pop(5)
-        for row in collection:
-            row.pop(5)
-    if len(args.channels) == 1:
-        headers.pop(4)
-        for row in collection:
-            row.pop(4)
-    if len(args.batch_size) == 1:
-        headers.pop(3)
-        for row in collection:
-            row.pop(3)
-
-    print(tabulate(collection, headers, tablefmt="rst"))
+    if world_rank == 0:
+        headers = [
+            "Backend",
+            "Packed",
+            "Sparse Grad",
+            # configs
+            "Batch Size",
+            "Channels",
+            "Scene Size",
+            # stats
+            # "Mem[fwd] (GB)",
+            "Mem (GB)",
+            "FPS[fwd]",
+            "FPS[bwd]",
+        ]
+
+        # pop config columns that have only one option
+        if len(args.scene_grid) == 1:
+            headers.pop(5)
+            for row in collection:
+                row.pop(5)
+        if len(args.channels) == 1:
+            headers.pop(4)
+            for row in collection:
+                row.pop(4)
+        if len(args.batch_size) == 1:
+            headers.pop(3)
+            for row in collection:
+                row.pop(3)
+
+        print(tabulate(collection, headers, tablefmt="rst"))
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--backends",
+        nargs="+",
+        type=str,
+        default=["gsplat"],
+        help="gsplat, gsplat-legacy, inria",
+    )
+    parser.add_argument(
+        "--repeats",
+        type=int,
+        default=10,
+        help="Number of repeats for profiling",
+    )
+    parser.add_argument(
+        "--batch_size",
+        nargs="+",
+        type=int,
+        default=[1],
+ help="Batch size for profiling", + ) + parser.add_argument( + "--scene_grid", + nargs="+", + type=int, + default=[1, 11, 21], + help="Scene grid size for profiling", + ) + parser.add_argument( + "--channels", + nargs="+", + type=int, + default=[3], + help="Number of color channels for profiling", + ) + parser.add_argument( + "--memory_history", + action="store_true", + help="Record memory history and dump a snapshot. Use https://pytorch.org/memory_viz to visualize.", + ) + args = parser.parse_args() + if args.memory_history: + args.repeats = 1 # only run once for memory history + + cli(worker, args, verbose=True) diff --git a/tests/_test_distributed.py b/tests/_test_distributed.py new file mode 100644 index 000000000..ae03f9c98 --- /dev/null +++ b/tests/_test_distributed.py @@ -0,0 +1,114 @@ +import pytest +import torch + +from gsplat.distributed import ( + all_gather_int32, + all_gather_tensor_list, + all_to_all_int32, + all_to_all_tensor_list, + cli, +) + + +def _main_all_gather_int32(local_rank: int, world_rank: int, world_size: int, _): + device = torch.device("cuda", local_rank) + + value = world_rank + collected = all_gather_int32(world_size, value, device=device) + for i in range(world_size): + assert collected[i] == i + + value = torch.tensor(world_rank, device=device, dtype=torch.int) + collected = all_gather_int32(world_size, value, device=device) + for i in range(world_size): + assert collected[i] == torch.tensor(i, device=device, dtype=torch.int) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") +def test_all_gather_int32(): + cli(_main_all_gather_int32, None, verbose=True) + + +def _main_all_to_all_int32(local_rank: int, world_rank: int, world_size: int, _): + device = torch.device("cuda", local_rank) + + values = list(range(world_size)) + collected = all_to_all_int32(world_size, values, device=device) + for i in range(world_size): + assert collected[i] == world_rank + + values = torch.arange(world_size, device=device, dtype=torch.int) + collected = all_to_all_int32(world_size, values, device=device) + for i in range(world_size): + assert collected[i] == torch.tensor(world_rank, device=device, dtype=torch.int) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") +def test_all_to_all_int32(): + cli(_main_all_to_all_int32, None, verbose=True) + + +def _main_all_gather_tensor_list(local_rank: int, world_rank: int, world_size: int, _): + device = torch.device("cuda", local_rank) + N = 10 + + tensor_list = [ + torch.full((N, 2), world_rank, device=device), + torch.full((N, 3, 3), world_rank, device=device), + ] + + target_list = [ + torch.cat([torch.full((N, 2), i, device=device) for i in range(world_size)]), + torch.cat([torch.full((N, 3, 3), i, device=device) for i in range(world_size)]), + ] + + collected = all_gather_tensor_list(world_size, tensor_list) + for tensor, target in zip(collected, target_list): + assert torch.equal(tensor, target) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") +def test_all_gather_tensor_list(): + cli(_main_all_gather_tensor_list, None, verbose=True) + + +def _main_all_to_all_tensor_list(local_rank: int, world_rank: int, world_size: int, _): + device = torch.device("cuda", local_rank) + splits = torch.arange(0, world_size, device=device) + N = splits.sum().item() + + tensor_list = [ + torch.full((N, 2), world_rank, device=device), + torch.full((N, 3, 3), world_rank, device=device), + ] + + target_list = [ + torch.cat( + [ + torch.full((splits[world_rank], 
diff --git a/tests/_test_distributed.py b/tests/_test_distributed.py
new file mode 100644
index 000000000..ae03f9c98
--- /dev/null
+++ b/tests/_test_distributed.py
@@ -0,0 +1,114 @@
+import pytest
+import torch
+
+from gsplat.distributed import (
+    all_gather_int32,
+    all_gather_tensor_list,
+    all_to_all_int32,
+    all_to_all_tensor_list,
+    cli,
+)
+
+
+def _main_all_gather_int32(local_rank: int, world_rank: int, world_size: int, _):
+    device = torch.device("cuda", local_rank)
+
+    value = world_rank
+    collected = all_gather_int32(world_size, value, device=device)
+    for i in range(world_size):
+        assert collected[i] == i
+
+    value = torch.tensor(world_rank, device=device, dtype=torch.int)
+    collected = all_gather_int32(world_size, value, device=device)
+    for i in range(world_size):
+        assert collected[i] == torch.tensor(i, device=device, dtype=torch.int)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
+def test_all_gather_int32():
+    cli(_main_all_gather_int32, None, verbose=True)
+
+
+def _main_all_to_all_int32(local_rank: int, world_rank: int, world_size: int, _):
+    device = torch.device("cuda", local_rank)
+
+    values = list(range(world_size))
+    collected = all_to_all_int32(world_size, values, device=device)
+    for i in range(world_size):
+        assert collected[i] == world_rank
+
+    values = torch.arange(world_size, device=device, dtype=torch.int)
+    collected = all_to_all_int32(world_size, values, device=device)
+    for i in range(world_size):
+        assert collected[i] == torch.tensor(world_rank, device=device, dtype=torch.int)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
+def test_all_to_all_int32():
+    cli(_main_all_to_all_int32, None, verbose=True)
+
+
+def _main_all_gather_tensor_list(local_rank: int, world_rank: int, world_size: int, _):
+    device = torch.device("cuda", local_rank)
+    N = 10
+
+    tensor_list = [
+        torch.full((N, 2), world_rank, device=device),
+        torch.full((N, 3, 3), world_rank, device=device),
+    ]
+
+    target_list = [
+        torch.cat([torch.full((N, 2), i, device=device) for i in range(world_size)]),
+        torch.cat([torch.full((N, 3, 3), i, device=device) for i in range(world_size)]),
+    ]
+
+    collected = all_gather_tensor_list(world_size, tensor_list)
+    for tensor, target in zip(collected, target_list):
+        assert torch.equal(tensor, target)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
+def test_all_gather_tensor_list():
+    cli(_main_all_gather_tensor_list, None, verbose=True)
+
+
+def _main_all_to_all_tensor_list(local_rank: int, world_rank: int, world_size: int, _):
+    device = torch.device("cuda", local_rank)
+    splits = torch.arange(0, world_size, device=device)
+    N = splits.sum().item()
+
+    tensor_list = [
+        torch.full((N, 2), world_rank, device=device),
+        torch.full((N, 3, 3), world_rank, device=device),
+    ]
+
+    target_list = [
+        torch.cat(
+            [
+                torch.full((splits[world_rank], 2), i, device=device)
+                for i in range(world_size)
+            ]
+        ),
+        torch.cat(
+            [
+                torch.full((splits[world_rank], 3, 3), i, device=device)
+                for i in range(world_size)
+            ]
+        ),
+    ]
+
+    collected = all_to_all_tensor_list(world_size, tensor_list, splits)
+    for tensor, target in zip(collected, target_list):
+        assert torch.equal(tensor, target)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
+def test_all_to_all_tensor_list():
+    cli(_main_all_to_all_tensor_list, None, verbose=True)
+
+
+if __name__ == "__main__":
+    test_all_gather_int32()
+    test_all_to_all_int32()
+    test_all_gather_tensor_list()
+    test_all_to_all_tensor_list()
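(Editorial aside, to make the `splits` convention above concrete: on every rank, `splits[i]` rows are sent to rank i, and each rank receives `splits[world_rank]` rows from every peer, concatenated in rank order. A self-contained sketch under the same assumption that `cli` spawns one process per visible GPU:)

```python
import torch
from gsplat.distributed import all_to_all_tensor_list, cli

def _demo(local_rank: int, world_rank: int, world_size: int, _):
    device = torch.device("cuda", local_rank)
    # Send i + 1 rows to rank i; `splits` sums to the number of rows we hold.
    splits = torch.arange(1, world_size + 1, device=device)
    payload = [torch.full((int(splits.sum()), 2), world_rank, device=device)]
    (out,) = all_to_all_tensor_list(world_size, payload, splits)
    # Each rank now holds splits[world_rank] rows from every rank, in rank
    # order, i.e. out.shape == (world_size * (world_rank + 1), 2).
    print(world_rank, out.shape, out[:, 0].tolist())

if __name__ == "__main__":
    cli(_demo, None, verbose=True)
```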