diff --git a/examples/benchmarks/compression/mcmc.sh b/examples/benchmarks/compression/mcmc.sh
index 4c7165f3d..a28bfb6f5 100644
--- a/examples/benchmarks/compression/mcmc.sh
+++ b/examples/benchmarks/compression/mcmc.sh
@@ -49,7 +49,7 @@ done
 if command -v zip &> /dev/null
 then
     echo "Zipping results"
-    python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR
+    python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST
 else
     echo "zip command not found, skipping zipping"
 fi
\ No newline at end of file
diff --git a/examples/benchmarks/compression/summarize_stats.py b/examples/benchmarks/compression/summarize_stats.py
index d11dbed6f..5efa729fe 100644
--- a/examples/benchmarks/compression/summarize_stats.py
+++ b/examples/benchmarks/compression/summarize_stats.py
@@ -8,9 +8,8 @@
 import tyro
 
 
-def main(results_dir: str, scenes: List[str]):
+def main(results_dir: str, scenes: List[str], stage: str = "compress"):
     print("scenes:", scenes)
-    stage = "compress"
 
     summary = defaultdict(list)
     for scene in scenes:
@@ -33,7 +32,11 @@ def main(results_dir: str, scenes: List[str]):
             summary[k].append(v)
 
     for k, v in summary.items():
-        print(k, np.mean(v))
+        summary[k] = np.mean(v)
+    summary["scenes"] = scenes
+
+    with open(os.path.join(results_dir, f"{stage}_summary.json"), "w") as f:
+        json.dump(summary, f, indent=2)
 
 
 if __name__ == "__main__":
diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh
new file mode 100644
index 000000000..da89187ef
--- /dev/null
+++ b/examples/benchmarks/mcmc_deblur.sh
@@ -0,0 +1,23 @@
+SCENE_DIR="data/deblur_dataset/real_defocus_blur"
+SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools"
+
+DATA_FACTOR=4
+RENDER_TRAJ_PATH="spiral"
+CAP_MAX=250000
+RESULT_DIR="results/benchmark_mcmc_deblur"
+
+for SCENE in $SCENE_LIST;
+do
+    echo "Running $SCENE"
+
+    # train and eval
+    CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \
+        --strategy.cap-max $CAP_MAX \
+        --blur_opt \
+        --render_traj_path $RENDER_TRAJ_PATH \
+        --data_dir $SCENE_DIR/$SCENE/ \
+        --result_dir $RESULT_DIR/$SCENE
+done
+
+# Summarize the stats
+python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val
diff --git a/examples/blur_opt.py b/examples/blur_opt.py
new file mode 100644
index 000000000..ee529d6c3
--- /dev/null
+++ b/examples/blur_opt.py
@@ -0,0 +1,98 @@
+import torch
+import torch.nn as nn
+from torch import Tensor
+import torch.nn.functional as F
+from examples.mlp import create_mlp, get_encoder
+from gsplat.utils import log_transform
+
+
+class BlurOptModule(nn.Module):
+    """Blur optimization module."""
+
+    def __init__(self, n: int, embed_dim: int = 4):
+        super().__init__()
+        self.embeds = torch.nn.Embedding(n, embed_dim)
+        self.means_encoder = get_encoder(num_freqs=3, input_dims=3)
+        self.depths_encoder = get_encoder(num_freqs=3, input_dims=1)
+        self.grid_encoder = get_encoder(num_freqs=1, input_dims=2)
+        self.blur_mask_mlp = create_mlp(
+            in_dim=embed_dim + self.depths_encoder.out_dim + self.grid_encoder.out_dim,
+            num_layers=5,
+            layer_width=64,
+            out_dim=1,
+        )
+        self.blur_deltas_mlp = create_mlp(
+            in_dim=embed_dim + self.means_encoder.out_dim + 7,
+            num_layers=5,
+            layer_width=64,
+            out_dim=7,
+        )
+        self.bounded_l1_loss = bounded_l1_loss(10.0, 0.5)
+
+    def zero_init(self):
+        torch.nn.init.zeros_(self.embeds.weight)
+
+    def forward(
+        self,
+        image_ids: Tensor,
+        means: Tensor,
+        scales: Tensor,
+        quats: Tensor,
+    ):
+        quats = F.normalize(quats, dim=-1)
+        means_emb = self.means_encoder.encode(log_transform(means))
+        images_emb = self.embeds(image_ids).repeat(means.shape[0], 1)
+        mlp_out = self.blur_deltas_mlp(
+            torch.cat([images_emb, means_emb, scales, quats], dim=-1)
+        ).float()
+        scales_delta = torch.clamp(mlp_out[:, :3], min=0.0, max=0.1)
+        quats_delta = torch.clamp(mlp_out[:, 3:], min=0.0, max=0.1)
+        scales = torch.exp(scales + scales_delta)
+        quats = quats + quats_delta
+        return scales, quats
+
+    def predict_mask(self, image_ids: Tensor, depths: Tensor):
+        height, width = depths.shape[1:3]
+        grid_y, grid_x = torch.meshgrid(
+            (torch.arange(height, device=depths.device) + 0.5) / height,
+            (torch.arange(width, device=depths.device) + 0.5) / width,
+            indexing="ij",
+        )
+        grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)
+        grid_emb = self.grid_encoder.encode(grid_xy)
+        depths_emb = self.depths_encoder.encode(log_transform(depths))
+        images_emb = self.embeds(image_ids).repeat(*depths_emb.shape[:-1], 1)
+        mlp_in = torch.cat([images_emb, grid_emb, depths_emb], dim=-1)
+        mlp_out = self.blur_mask_mlp(mlp_in.reshape(-1, mlp_in.shape[-1])).reshape(
+            depths.shape
+        )
+        blur_mask = torch.sigmoid(mlp_out)
+        return blur_mask
+
+    def mask_loss(self, blur_mask: Tensor):
+        """Loss function for regularizing the blur mask by controlling its mean.
+
+        Uses a bounded L1 loss, which diverges to +infinity at 0 and 1, to prevent the
+        mask from collapsing to all 0s or all 1s.
+        """
+        x = blur_mask.mean()
+        return self.bounded_l1_loss(x)
+
+
+def bounded_l1_loss(lambda_a: float, lambda_b: float, eps: float = 1e-2):
+    """L1 loss function with discontinuities at 0 and 1.
+
+    Args:
+        lambda_a (float): Coefficient of L1 loss.
+        lambda_b (float): Coefficient of bounded loss.
+        eps (float, optional): Epsilon to prevent divide by zero. Defaults to 1e-2.
+    """
+
+    def loss_fn(x: Tensor):
+        return lambda_a * x + lambda_b * (1 / (1 - x + eps) + 1 / (x + eps))
+
+    # Compute constant that sets min to zero
+    xs = torch.linspace(0, 1, 1000)
+    ys = loss_fn(xs)
+    c = ys.min()
+    return lambda x: loss_fn(x) - c
diff --git a/examples/datasets/colmap.py b/examples/datasets/colmap.py
index 938bad265..ba12c2258 100644
--- a/examples/datasets/colmap.py
+++ b/examples/datasets/colmap.py
@@ -40,6 +40,11 @@ def __init__(
         self.factor = factor
         self.normalize = normalize
         self.test_every = test_every
+        li = os.listdir(data_dir)
+        for l in li:
+            if l.startswith("hold"):
+                self.test_every = int(l.split("=")[-1])
+                break
 
         colmap_dir = os.path.join(data_dir, "sparse/0/")
         if not os.path.exists(colmap_dir):
@@ -134,7 +139,7 @@ def __init__(
 
         # Load extended metadata. Used by Bilarf dataset.
         self.extconf = {
-            "spiral_radius_scale": 1.0,
+            "spiral_radius_scale": 0.1,
             "no_factor_suffix": False,
         }
         extconf_file = os.path.join(data_dir, "ext_metadata.json")
diff --git a/examples/mlp/__init__.py b/examples/mlp/__init__.py
new file mode 100644
index 000000000..58f4382b4
--- /dev/null
+++ b/examples/mlp/__init__.py
@@ -0,0 +1,2 @@
+from .encoder import get_encoder
+from .mlp import create_mlp
diff --git a/examples/mlp/encoder.py b/examples/mlp/encoder.py
new file mode 100644
index 000000000..188a25bf8
--- /dev/null
+++ b/examples/mlp/encoder.py
@@ -0,0 +1,47 @@
+import torch
+
+
+def get_encoder(num_freqs: int, input_dims: int):
+    kwargs = {
+        "include_input": True,
+        "input_dims": input_dims,
+        "max_freq_log2": num_freqs - 1,
+        "num_freqs": num_freqs,
+        "log_sampling": True,
+        "periodic_fns": [torch.sin, torch.cos],
+    }
+    encoder = Encoder(**kwargs)
+    return encoder
+
+
+class Encoder:
+    def __init__(self, **kwargs):
+        self.kwargs = kwargs
+        self.create_embedding_fn()
+
+    def create_embedding_fn(self):
+        embed_fns = []
+        d = self.kwargs["input_dims"]
+        out_dim = 0
+        if self.kwargs["include_input"]:
+            embed_fns.append(lambda x: x)
+            out_dim += d
+
+        max_freq = self.kwargs["max_freq_log2"]
+        N_freqs = self.kwargs["num_freqs"]
+
+        if self.kwargs["log_sampling"]:
+            freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs)
+        else:
+            freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs)
+
+        for freq in freq_bands:
+            for p_fn in self.kwargs["periodic_fns"]:
+                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
+                out_dim += d
+
+        self.embed_fns = embed_fns
+        self.out_dim = out_dim
+
+    def encode(self, inputs):
+        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
diff --git a/examples/mlp/external.py b/examples/mlp/external.py
new file mode 100644
index 000000000..f14219bd5
--- /dev/null
+++ b/examples/mlp/external.py
@@ -0,0 +1,58 @@
+# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+
+
+class _LazyError:
+    def __init__(self, data):
+        self.__data = data  # pylint: disable=unused-private-member
+
+    class LazyErrorObj:
+        def __init__(self, data):
+            self.__data = data  # pylint: disable=unused-private-member
+
+        def __call__(self, *args, **kwds):
+            name, exc = object.__getattribute__(self, "__data")
+            raise RuntimeError(f"Could not load package {name}.") from exc
+
+        def __getattr__(self, __name: str):
+            name, exc = object.__getattribute__(self, "__data")
+            raise RuntimeError(f"Could not load package {name}") from exc
+
+    def __getattr__(self, __name: str):
+        return _LazyError.LazyErrorObj(object.__getattribute__(self, "__data"))
+
+
+TCNN_EXISTS = False
+tcnn_import_exception = None
+tcnn = None
+try:
+    import tinycudann
+
+    tcnn = tinycudann
+    del tinycudann
+    TCNN_EXISTS = True
+except ModuleNotFoundError as _exp:
+    tcnn_import_exception = _exp
+except ImportError as _exp:
+    tcnn_import_exception = _exp
+except EnvironmentError as _exp:
+    if "Unknown compute capability" not in _exp.args[0]:
+        raise _exp
+    print("Could not load tinycudann: " + str(_exp), file=sys.stderr)
+    tcnn_import_exception = _exp
+
+if tcnn_import_exception is not None:
+    tcnn = _LazyError(tcnn_import_exception)
diff --git a/examples/mlp/mlp.py b/examples/mlp/mlp.py
new file mode 100644
index 000000000..f68330bee
--- /dev/null
+++ b/examples/mlp/mlp.py
@@ -0,0 +1,119 @@
+# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Multi Layer Perceptron
+"""
+
+from typing import Union
+
+from torch import nn
+
+from examples.mlp.external import TCNN_EXISTS, tcnn
+
+
+def activation_to_tcnn_string(activation: Union[nn.Module, None]) -> str:
+    """Converts a torch.nn activation function to a string that can be used to
+    initialize a TCNN activation function.
+
+    Args:
+        activation: torch.nn activation function
+    Returns:
+        str: TCNN activation function string
+    """
+
+    if isinstance(activation, nn.ReLU):
+        return "ReLU"
+    if isinstance(activation, nn.LeakyReLU):
+        return "LeakyReLU"
+    if isinstance(activation, nn.Sigmoid):
+        return "Sigmoid"
+    if isinstance(activation, nn.Softplus):
+        return "Softplus"
+    if isinstance(activation, nn.Tanh):
+        return "Tanh"
+    if isinstance(activation, type(None)):
+        return "None"
+    tcnn_documentation_url = "https://github.com/NVlabs/tiny-cuda-nn/blob/master/DOCUMENTATION.md#activation-functions"
+    raise ValueError(
+        f"TCNN activation {activation} not supported for now.\nSee {tcnn_documentation_url} for TCNN documentation."
+    )
+
+
+def get_tcnn_network_config(
+    activation, out_activation, layer_width, num_layers
+) -> dict:
+    """Get the network configuration for tcnn if implemented"""
+    activation_str = activation_to_tcnn_string(activation)
+    output_activation_str = activation_to_tcnn_string(out_activation)
+    assert layer_width in [16, 32, 64, 128]
+    network_config = {
+        "otype": "FullyFusedMLP",
+        "activation": activation_str,
+        "output_activation": output_activation_str,
+        "n_neurons": layer_width,
+        "n_hidden_layers": num_layers - 1,
+    }
+    return network_config
+
+
+def create_mlp(
+    in_dim: int,
+    num_layers: int,
+    layer_width: int,
+    out_dim: int,
+):
+    if TCNN_EXISTS:
+        return _create_mlp_tcnn(in_dim, num_layers, layer_width, out_dim)
+    else:
+        return _create_mlp_torch(in_dim, num_layers, layer_width, out_dim)
+
+
+def _create_mlp_tcnn(
+    in_dim: int,
+    num_layers: int,
+    layer_width: int,
+    out_dim: int,
+):
+    """Create a fully-connected neural network with tiny-cuda-nn."""
+    network_config = get_tcnn_network_config(
+        activation=nn.LeakyReLU(),
+        out_activation=None,
+        layer_width=layer_width,
+        num_layers=num_layers,
+    )
+    tcnn_encoding = tcnn.Network(
+        n_input_dims=in_dim,
+        n_output_dims=out_dim,
+        network_config=network_config,
+    )
+    return tcnn_encoding
+
+
+def _create_mlp_torch(
+    in_dim: int,
+    num_layers: int,
+    layer_width: int,
+    out_dim: int,
+):
+    """Create a fully-connected neural network with PyTorch."""
+    layers = []
+    layer_in = in_dim
+    for i in range(num_layers):
+        layer_out = layer_width if i != num_layers - 1 else out_dim
+        layers.append(nn.Linear(layer_in, layer_out, bias=False))
+        if i != num_layers - 1:
+            layers.append(nn.LeakyReLU())
+        layer_in = layer_width
+    return nn.Sequential(*layers)
diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py
index 93e70002f..49aee9860 100644
--- a/examples/simple_trainer.py
+++ b/examples/simple_trainer.py
@@ -28,6 +28,7 @@
 from fused_ssim import fused_ssim
 from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
 from typing_extensions import Literal, assert_never
+from blur_opt import BlurOptModule
 from utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed
 from lib_bilagrid import (
     BilateralGrid,
@@ -82,7 +83,7 @@ class Config:
     # Number of training steps
     max_steps: int = 30_000
     # Steps to evaluate the model
-    eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
+    eval_steps: List[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
     # Steps to save the model
     save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
@@ -147,6 +148,15 @@ class Config:
     # Regularization for appearance optimization as weight decay
     app_opt_reg: float = 1e-6
 
+    # Enable blur optimization. (experimental)
+    blur_opt: bool = False
+    # Learning rate for blur optimization
+    blur_opt_lr: float = 1e-3
+    # Regularization for blur mask
+    blur_mask_reg: float = 0.001
+    # Regularization for blur optimization as weight decay
+    blur_opt_reg: float = 1e-6
+
     # Enable bilateral grid. (experimental)
     use_bilateral_grid: bool = False
     # Shape of the bilateral grid (X, Y, W)
@@ -283,6 +293,9 @@ def __init__(
         self.local_rank = local_rank
         self.world_size = world_size
         self.device = f"cuda:{local_rank}"
+        self.render_mode = "RGB"
+        if cfg.depth_loss or cfg.blur_opt:
+            self.render_mode = "RGB+ED"
 
         # Where to dump results.
         os.makedirs(cfg.result_dir, exist_ok=True)
@@ -399,6 +412,20 @@ def __init__(
             if world_size > 1:
                 self.app_module = DDP(self.app_module)
 
+        self.blur_optimizers = []
+        if cfg.blur_opt:
+            self.blur_module = BlurOptModule(len(self.trainset)).to(self.device)
+            self.blur_module.zero_init()
+            self.blur_optimizers = [
+                torch.optim.Adam(
+                    self.blur_module.parameters(),
+                    lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size),
+                    weight_decay=cfg.blur_opt_reg,
+                ),
+            ]
+            if world_size > 1:
+                self.blur_module = DDP(self.blur_module)
+
         self.bil_grid_optimizers = []
         if cfg.use_bilateral_grid:
             self.bil_grids = BilateralGrid(
@@ -447,13 +474,10 @@ def rasterize_splats(
         width: int,
         height: int,
         masks: Optional[Tensor] = None,
+        blur: bool = False,
         **kwargs,
     ) -> Tuple[Tensor, Tensor, Dict]:
         means = self.splats["means"]  # [N, 3]
-        # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4]
-        # rasterization does normalization internally
-        quats = self.splats["quats"]  # [N, 4]
-        scales = torch.exp(self.splats["scales"])  # [N, 3]
         opacities = torch.sigmoid(self.splats["opacities"])  # [N,]
         image_ids = kwargs.pop("image_ids", None)
@@ -469,6 +493,17 @@ def rasterize_splats(
         else:
             colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1)  # [N, K, 3]
 
+        if self.cfg.blur_opt and blur:
+            scales, quats = self.blur_module(
+                image_ids=image_ids,
+                means=self.splats["means"],
+                scales=self.splats["scales"],
+                quats=self.splats["quats"],
+            )
+        else:
+            scales = torch.exp(self.splats["scales"])  # [N, 3]
+            quats = self.splats["quats"]  # [N, 4]
+
         rasterize_mode = "antialiased" if self.cfg.antialiased else "classic"
         render_colors, render_alphas, info = rasterization(
             means=means,
@@ -599,7 +634,7 @@ def train(self):
                 near_plane=cfg.near_plane,
                 far_plane=cfg.far_plane,
                 image_ids=image_ids,
-                render_mode="RGB+ED" if cfg.depth_loss else "RGB",
+                render_mode=self.render_mode,
                 masks=masks,
             )
             if renders.shape[-1] == 4:
@@ -619,6 +654,22 @@ def train(self):
             if cfg.random_bkgd:
                 bkgd = torch.rand(1, 3, device=device)
                 colors = colors + bkgd * (1.0 - alphas)
+            if cfg.blur_opt:
+                blur_mask = self.blur_module.predict_mask(image_ids, depths)
+                renders_blur, _, _ = self.rasterize_splats(
+                    camtoworlds=camtoworlds,
+                    Ks=Ks,
+                    width=width,
+                    height=height,
+                    sh_degree=sh_degree_to_use,
+                    near_plane=cfg.near_plane,
+                    far_plane=cfg.far_plane,
+                    image_ids=image_ids,
+                    render_mode="RGB",
+                    masks=masks,
+                    blur=True,
+                )
+                colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3]
 
             self.cfg.strategy.step_pre_backward(
                 params=self.splats,
@@ -656,6 +707,8 @@ def train(self):
             if cfg.use_bilateral_grid:
                 tvloss = 10 * total_variation_loss(self.bil_grids.grids)
                 loss += tvloss
+            if cfg.blur_opt:
+                loss += cfg.blur_mask_reg * self.blur_module.mask_loss(blur_mask)
 
             # regularizations
             if cfg.opacity_reg > 0.0:
@@ -732,6 +785,11 @@ def train(self):
                         data["app_module"] = self.app_module.module.state_dict()
                     else:
                         data["app_module"] = self.app_module.state_dict()
+                if cfg.blur_opt:
+                    if world_size > 1:
+                        data["blur_module"] = self.blur_module.module.state_dict()
+                    else:
+                        data["blur_module"] = self.blur_module.state_dict()
                 torch.save(
                     data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt"
                 )
@@ -774,6 +832,9 @@ def train(self):
             for optimizer in self.app_optimizers:
                 optimizer.step()
                 optimizer.zero_grad(set_to_none=True)
+            for optimizer in self.blur_optimizers:
+                optimizer.step()
+                optimizer.zero_grad(set_to_none=True)
             for optimizer in self.bil_grid_optimizers:
                 optimizer.step()
                 optimizer.zero_grad(set_to_none=True)
@@ -804,8 +865,11 @@ def train(self):
 
             # eval the full set
             if step in [i - 1 for i in cfg.eval_steps]:
-                self.eval(step)
+                self.eval(step, stage="train")
+                self.eval(step, stage="val")
                 self.render_traj(step)
+                # if step % 1000 == 0:
+                #     self.eval(step, stage="vis")
 
             # run compression
             if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]:
@@ -831,21 +895,27 @@ def eval(self, step: int, stage: str = "val"):
         world_rank = self.world_rank
         world_size = self.world_size
 
-        valloader = torch.utils.data.DataLoader(
-            self.valset, batch_size=1, shuffle=False, num_workers=1
+        dataset = self.valset if stage == "val" else self.trainset
+        dataloader = torch.utils.data.DataLoader(
+            dataset, batch_size=1, shuffle=False, num_workers=1
         )
+
         ellipse_time = 0
         metrics = defaultdict(list)
-        for i, data in enumerate(valloader):
+        for i, data in enumerate(dataloader):
+            if stage == "vis":
+                if i % 5 != 0:
+                    continue
             camtoworlds = data["camtoworld"].to(device)
             Ks = data["K"].to(device)
             pixels = data["image"].to(device) / 255.0
+            image_ids = data["image_id"].to(device)
             masks = data["mask"].to(device) if "mask" in data else None
             height, width = pixels.shape[1:3]
 
             torch.cuda.synchronize()
             tic = time.time()
-            colors, _, _ = self.rasterize_splats(
+            renders, _, _ = self.rasterize_splats(
                 camtoworlds=camtoworlds,
                 Ks=Ks,
                 width=width,
@@ -853,13 +923,41 @@ def eval(self, step: int, stage: str = "val"):
                 sh_degree=cfg.sh_degree,
                 near_plane=cfg.near_plane,
                 far_plane=cfg.far_plane,
+                image_ids=image_ids,
+                render_mode=self.render_mode,
                 masks=masks,
             )  # [1, H, W, 3]
+            if renders.shape[-1] == 4:
+                colors, depths = renders[..., 0:3], renders[..., 3:4]
+            else:
+                colors, depths = renders, None
+
             torch.cuda.synchronize()
             ellipse_time += time.time() - tic
 
             colors = torch.clamp(colors, 0.0, 1.0)
             canvas_list = [pixels, colors]
+            if self.cfg.blur_opt and stage != "val":
+                blur_mask = self.blur_module.predict_mask(image_ids, depths)
+                canvas_list.append(blur_mask.repeat(1, 1, 1, 3))
+                renders_blur, _, _ = self.rasterize_splats(
+                    camtoworlds=camtoworlds,
+                    Ks=Ks,
+                    width=width,
+                    height=height,
+                    sh_degree=cfg.sh_degree,
+                    near_plane=cfg.near_plane,
+                    far_plane=cfg.far_plane,
+                    image_ids=image_ids,
+                    render_mode="RGB",
+                    masks=masks,
+                    blur=True,
+                )
+                colors_blur = renders_blur[..., 0:3]
+                canvas_list.append(torch.clamp(colors_blur, 0.0, 1.0))
+                colors = (1 - blur_mask) * colors + blur_mask * colors_blur
+                colors = torch.clamp(colors, 0.0, 1.0)
+                canvas_list.append(colors)
 
             if world_rank == 0:
                 # write images
@@ -881,7 +979,7 @@ def eval(self, step: int, stage: str = "val"):
                 metrics["cc_psnr"].append(self.psnr(cc_colors_p, pixels_p))
 
         if world_rank == 0:
-            ellipse_time /= len(valloader)
+            ellipse_time /= len(dataloader)
 
             stats = {k: torch.stack(v).mean().item() for k, v in metrics.items()}
             stats.update(
@@ -904,7 +1002,7 @@ def eval(self, step: int, stage: str = "val"):
             self.writer.flush()
 
     @torch.no_grad()
-    def render_traj(self, step: int):
+    def render_traj(self, step: int, stage: str = "val"):
         """Entry for trajectory rendering."""
         print("Running trajectory rendering...")
         cfg = self.cfg
@@ -948,7 +1046,7 @@ def render_traj(self, step: int):
         # save to video
         video_dir = f"{cfg.result_dir}/videos"
         os.makedirs(video_dir, exist_ok=True)
-        writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30)
+        writer = imageio.get_writer(f"{video_dir}/traj_{stage}_{step}.mp4", fps=30)
         for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"):
             camtoworlds = camtoworlds_all[i : i + 1]
             Ks = K[None]
@@ -973,7 +1071,7 @@ def render_traj(self, step: int):
             canvas = (canvas * 255).astype(np.uint8)
             writer.append_data(canvas)
         writer.close()
-        print(f"Video saved to {video_dir}/traj_{step}.mp4")
+        print(f"Video saved to {video_dir}/traj_{stage}_{step}.mp4")
 
     @torch.no_grad()
     def run_compression(self, step: int):
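For reference, a minimal standalone sketch of how the new BlurOptModule is exercised: a per-Gaussian blur branch (clamped deltas on log-scales and quaternions) plus a per-pixel blur mask predicted from rendered depth. This is illustrative only and not part of the patch; the tensor shapes, device handling, and the direct call pattern below are assumptions — in the patch the module is driven from simple_trainer.py.

# Illustrative sketch, not part of the patch. Assumes the repository root and the
# examples/ directory are importable; if tiny-cuda-nn is installed, a CUDA device is
# required, otherwise create_mlp falls back to the pure-PyTorch MLP.
import torch
from blur_opt import BlurOptModule

device = "cuda" if torch.cuda.is_available() else "cpu"
module = BlurOptModule(n=10).to(device)  # n = number of training images
module.zero_init()                       # start from zero per-image embeddings

# Per-Gaussian blur branch for one image: predicts clamped deltas on scales/quats.
image_ids = torch.tensor([0], device=device)
means = torch.randn(1000, 3, device=device)       # Gaussian centers
log_scales = torch.randn(1000, 3, device=device)  # raw (log) scales, as stored in splats
quats = torch.randn(1000, 4, device=device)       # raw quaternions
scales_blur, quats_blur = module(image_ids, means, log_scales, quats)

# Per-pixel blur mask from a rendered depth map [1, H, W, 1]; values lie in (0, 1)
# and the regularizer keeps the mask mean away from 0 and 1.
depths = torch.rand(1, 64, 64, 1, device=device)
blur_mask = module.predict_mask(image_ids, depths)
reg = module.mask_loss(blur_mask)

In training (see the simple_trainer.py hunks above), the blurred re-render produced with blur=True is composited with the sharp render as (1 - blur_mask) * colors + blur_mask * renders_blur before the photometric loss.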