diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh index 0ef02ad6..d497515f 100644 --- a/examples/benchmarks/mcmc_deblur.sh +++ b/examples/benchmarks/mcmc_deblur.sh @@ -1,12 +1,12 @@ SCENE_DIR="data/deblur_dataset/real_defocus_blur" SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools" +SCENE_LIST="defocuscake defocustools defocussausage defocuscupcake defocuscups defocuscoral defocusdaisy defocusseal defocuscaps defocuscisco" DATA_FACTOR=4 RENDER_TRAJ_PATH="spiral" - -RESULT_DIR="results/benchmark_mcmc_deblur" CAP_MAX=250000 +RESULT_DIR="results/benchmark_mcmc_deblur/c0.2_a10" for SCENE in $SCENE_LIST; do echo "Running $SCENE" @@ -15,12 +15,11 @@ do CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ --strategy.cap-max $CAP_MAX \ --blur_opt \ - --blur_opt_lr 1e-3 \ - --blur_mean_reg 0.005 \ + --blur_a 10 \ + --blur_c 0.2 \ --render_traj_path $RENDER_TRAJ_PATH \ --data_dir $SCENE_DIR/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE done -# Summarize the stats python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py index b6daf971..a28f3244 100644 --- a/examples/blur_kernel.py +++ b/examples/blur_kernel.py @@ -10,8 +10,11 @@ class BlurOptModule(nn.Module): """Blur optimization module.""" - def __init__(self, n: int, embed_dim: int = 4): + def __init__(self, cfg, n: int, embed_dim: int = 4): super().__init__() + self.a = cfg.blur_a + self.c = cfg.blur_c + self.embeds = torch.nn.Embedding(n, embed_dim) self.means_encoder = get_encoder(3, 3) self.depths_encoder = get_encoder(3, 1) @@ -75,17 +78,17 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor): def mask_mean_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2): """Mask mean loss.""" x = blur_mask.mean() - a = 0.9 - meanloss = a * (1 / (1 - x + eps) - 1) + (1 - a) * (1 / (x + eps) - 1) - return meanloss - - def mask_smoothness_loss(self, blur_mask: Tensor): - """Mask smoothness loss.""" - blurred_xy = median_blur(blur_mask.permute(0, 3, 1, 2), (5, 5)).permute( - 0, 2, 3, 1 - ) - smoothloss = F.huber_loss(blur_mask, blurred_xy) - return smoothloss + if step <= 2000: + a = 20 + b = 1 + c = 0.2 + else: + a = self.a + b = 1 + c = self.c + print(x.item(), a, b, c) + meanloss = a * (1 / (1 - x + eps) - 1) + b * (1 / (x + eps) - 1) + return c * meanloss def get_encoder(num_freqs: int, input_dims: int): diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 4e73510d..1cd99769 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -155,7 +155,8 @@ class Config: # Regularization for blur mask mean blur_mean_reg: float = 0.001 # Regularization for blur mask smoothness - blur_smoothness_reg: float = 0.0 + blur_a: float = 4 + blur_c: float = 0.5 # Enable bilateral grid. (experimental) use_bilateral_grid: bool = False @@ -414,7 +415,7 @@ def __init__( self.blur_optimizers = [] if cfg.blur_opt: - self.blur_module = BlurOptModule(len(self.trainset)).to(self.device) + self.blur_module = BlurOptModule(cfg, len(self.trainset)).to(self.device) self.blur_module.zero_init() self.blur_optimizers = [ torch.optim.Adam( @@ -671,7 +672,7 @@ def train(self): renders_blur[..., 0:3], renders_blur[..., 3:4], ) - blur_mask = self.blur_module.predict_mask(image_ids, depths_blur) + blur_mask = self.blur_module.predict_mask(image_ids, depths) colors = (1 - blur_mask) * colors + blur_mask * colors_blur self.cfg.strategy.step_pre_backward( @@ -714,9 +715,6 @@ def train(self): loss += cfg.blur_mean_reg * self.blur_module.mask_mean_loss( blur_mask, step ) - loss += cfg.blur_smoothness_reg * self.blur_module.mask_smoothness_loss( - blur_mask - ) # regularizations if cfg.opacity_reg > 0.0: @@ -875,6 +873,8 @@ def train(self): self.eval(step, stage="train") self.eval(step) self.render_traj(step) + if (step + 1) % 1000 == 0 or step == 0: + self.eval(step, stage="train", vis_skip=True) # run compression if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: @@ -892,7 +892,7 @@ def train(self): self.viewer.update(step, num_train_rays_per_step) @torch.no_grad() - def eval(self, step: int, stage: str = "val"): + def eval(self, step: int, stage: str = "val", vis_skip: bool = False): """Entry for evaluation.""" print("Running evaluation...") cfg = self.cfg @@ -904,10 +904,13 @@ def eval(self, step: int, stage: str = "val"): dataloader = torch.utils.data.DataLoader( dataset, batch_size=1, shuffle=False, num_workers=1 ) + train_vis_image_ids = np.linspace(0, len(dataloader) - 1, 7).astype(int) ellipse_time = 0 metrics = defaultdict(list) for i, data in enumerate(dataloader): + if vis_skip and stage == "train" and i not in train_vis_image_ids: + continue camtoworlds = data["camtoworld"].to(device) Ks = data["K"].to(device) pixels = data["image"].to(device) / 255.0 @@ -957,7 +960,7 @@ def eval(self, step: int, stage: str = "val"): renders_blur[..., 0:3], renders_blur[..., 3:4], ) - blur_mask = self.blur_module.predict_mask(image_ids, depths_blur) + blur_mask = self.blur_module.predict_mask(image_ids, depths) canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) canvas_list.append(torch.clamp(colors_blur, 0.0, 1.0)) colors = (1 - blur_mask) * colors + blur_mask * colors_blur