diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index fd2ea9b6..b9f87c12 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -5,7 +5,7 @@ DATA_FACTOR=4 RENDER_TRAJ_PATH="spiral" CAP_MAX=250000 -RESULT_DIR="results/benchmark_mcmc_deblur" +RESULT_DIR="results/benchmark_mcmc_deblur_wd/1e-3_0.8" for SCENE in $SCENE_LIST; do echo "Running $SCENE" @@ -14,10 +14,12 @@ do CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ --strategy.cap-max $CAP_MAX \ --blur_opt \ + --blur_opt_lr 1e-3 \ + --blur_a 0.8 \ + --blur_mask_reg 0.002 \ --render_traj_path $RENDER_TRAJ_PATH \ --data_dir $SCENE_DIR/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE done - # Summarize the stats python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val diff --git a/examples/blur_opt.py b/examples/blur_opt.py index fed26000..e02a5f5f 100644 --- a/examples/blur_opt.py +++ b/examples/blur_opt.py @@ -9,8 +9,9 @@ class BlurOptModule(nn.Module): """Blur optimization module.""" - def __init__(self, n: int, embed_dim: int = 4): + def __init__(self, cfg, n: int, embed_dim: int = 4): super().__init__() + self.blur_a = cfg.blur_a self.embeds = torch.nn.Embedding(n, embed_dim) self.means_encoder = get_encoder(3, 3) self.depths_encoder = get_encoder(3, 1) @@ -39,9 +40,7 @@ def forward( quats: Tensor, ): quats = F.normalize(quats, dim=-1) - means_log = log_transform(means) - - means_emb = self.means_encoder.encode(means_log) + means_emb = self.means_encoder.encode(log_transform(means)) images_emb = self.embeds(image_ids).repeat(means.shape[0], 1) mlp_out = self.blur_deltas_mlp( torch.cat([images_emb, means_emb, scales, quats], dim=-1) @@ -61,7 +60,6 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor): ) grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) grid_emb = self.grid_encoder.encode(grid_xy) - depths_emb = self.depths_encoder.encode(log_transform(depths)) images_emb = self.embeds(image_ids).repeat(*depths_emb.shape[:-1], 1) mlp_in = torch.cat([images_emb, grid_emb, depths_emb], dim=-1) @@ -78,9 +76,7 @@ def mask_loss(self, blur_mask: Tensor, eps: float = 1e-2): from collapsing all 0s or 1s. It is biased towards 0 to encourage sparsity. """ x = blur_mask.mean() - a = 2.0 - b = 0.1 - maskloss = a * (1 / (1 - x + eps) - 1) + b * (1 / (x + eps) - 1) + maskloss = self.blur_a * (1 / (1 - x + eps) - 1) + 0.2 * (1 / (x + eps) - 1) return maskloss diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index e6b054ce..3f0c93ce 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -154,8 +154,9 @@ class Config: blur_opt_lr: float = 1e-3 # Regularization for blur mask blur_mask_reg: float = 0.001 - # Blur start iteration - blur_start_iter: int = 2_000 + # Regularization for blur optimization as weight decay + blur_opt_reg: float = 1e-6 + blur_a: float = 0.8 # Enable bilateral grid. (experimental) use_bilateral_grid: bool = False @@ -414,12 +415,13 @@ def __init__( self.blur_optimizers = [] if cfg.blur_opt: - self.blur_module = BlurOptModule(len(self.trainset)).to(self.device) + self.blur_module = BlurOptModule(cfg, len(self.trainset)).to(self.device) self.blur_module.zero_init() self.blur_optimizers = [ torch.optim.Adam( self.blur_module.parameters(), lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size), + weight_decay=cfg.blur_opt_reg, ), ] if world_size > 1: @@ -653,7 +655,7 @@ def train(self): if cfg.random_bkgd: bkgd = torch.rand(1, 3, device=device) colors = colors + bkgd * (1.0 - alphas) - if cfg.blur_opt and step >= cfg.blur_start_iter: + if cfg.blur_opt: blur_mask = self.blur_module.predict_mask(image_ids, depths) renders_blur, _, _ = self.rasterize_splats( camtoworlds=camtoworlds, @@ -706,7 +708,7 @@ def train(self): if cfg.use_bilateral_grid: tvloss = 10 * total_variation_loss(self.bil_grids.grids) loss += tvloss - if cfg.blur_opt and step >= cfg.blur_start_iter: + if cfg.blur_opt: loss += cfg.blur_mask_reg * self.blur_module.mask_loss(blur_mask) # regularizations