From d95b9293fcbf75f283632041b77ccb64025e87f1 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 29 Oct 2024 13:35:42 -0700 Subject: [PATCH] delayed start instead of warmup --- examples/blur_opt.py | 16 +++++----------- examples/simple_trainer.py | 19 +++++++++++++------ 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/examples/blur_opt.py b/examples/blur_opt.py index d01a552a..fed26000 100644 --- a/examples/blur_opt.py +++ b/examples/blur_opt.py @@ -11,8 +11,6 @@ class BlurOptModule(nn.Module): def __init__(self, n: int, embed_dim: int = 4): super().__init__() - self.num_warmup_steps = 2000 - self.embeds = torch.nn.Embedding(n, embed_dim) self.means_encoder = get_encoder(3, 3) self.depths_encoder = get_encoder(3, 1) @@ -73,19 +71,15 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor): blur_mask = torch.sigmoid(mlp_out) return blur_mask - def mask_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2): + def mask_loss(self, blur_mask: Tensor, eps: float = 1e-2): """Loss function for regularizing the blur mask by controlling its mean. The loss function diverges to +infinity at 0 and 1. This prevents the mask - from collapsing all 0s or 1s. It is also biased towards 0 to encourage - sparsity. During warmup, the bias is even higher to start with a sparse mask.""" + from collapsing all 0s or 1s. It is biased towards 0 to encourage sparsity. + """ x = blur_mask.mean() - if step <= self.num_warmup_steps: - a = 3 - b = 0.1 - else: - a = 1 - b = 0.1 + a = 2.0 + b = 0.1 maskloss = a * (1 / (1 - x + eps) - 1) + b * (1 / (x + eps) - 1) return maskloss diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 0bed4d6a..e6b054ce 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -153,7 +153,9 @@ class Config: # Learning rate for blur optimization blur_opt_lr: float = 1e-3 # Regularization for blur mask - blur_mask_reg: float = 0.002 + blur_mask_reg: float = 0.001 + # Blur start iteration + blur_start_iter: int = 2_000 # Enable bilateral grid. (experimental) use_bilateral_grid: bool = False @@ -651,7 +653,7 @@ def train(self): if cfg.random_bkgd: bkgd = torch.rand(1, 3, device=device) colors = colors + bkgd * (1.0 - alphas) - if cfg.blur_opt: + if cfg.blur_opt and step >= cfg.blur_start_iter: blur_mask = self.blur_module.predict_mask(image_ids, depths) renders_blur, _, _ = self.rasterize_splats( camtoworlds=camtoworlds, @@ -704,8 +706,8 @@ def train(self): if cfg.use_bilateral_grid: tvloss = 10 * total_variation_loss(self.bil_grids.grids) loss += tvloss - if cfg.blur_opt: - loss += cfg.blur_mask_reg * self.blur_module.mask_loss(blur_mask, step) + if cfg.blur_opt and step >= cfg.blur_start_iter: + loss += cfg.blur_mask_reg * self.blur_module.mask_loss(blur_mask) # regularizations if cfg.opacity_reg > 0.0: @@ -865,6 +867,8 @@ def train(self): self.eval(step, stage="train") self.eval(step, stage="val") self.render_traj(step) + if step % 1000 == 0: + self.eval(step, stage="vis") # run compression if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: @@ -890,7 +894,7 @@ def eval(self, step: int, stage: str = "val"): world_rank = self.world_rank world_size = self.world_size - dataset = self.trainset if stage == "train" else self.valset + dataset = self.valset if stage == "val" else self.trainset dataloader = torch.utils.data.DataLoader( dataset, batch_size=1, shuffle=False, num_workers=1 ) @@ -898,6 +902,9 @@ def eval(self, step: int, stage: str = "val"): ellipse_time = 0 metrics = defaultdict(list) for i, data in enumerate(dataloader): + if stage == "vis": + if i % 5 != 0: + continue camtoworlds = data["camtoworld"].to(device) Ks = data["K"].to(device) pixels = data["image"].to(device) / 255.0 @@ -929,7 +936,7 @@ def eval(self, step: int, stage: str = "val"): colors = torch.clamp(colors, 0.0, 1.0) canvas_list = [pixels, colors] - if self.cfg.blur_opt and stage == "train": + if self.cfg.blur_opt and stage != "val": blur_mask = self.blur_module.predict_mask(image_ids, depths) canvas_list.append(blur_mask.repeat(1, 1, 1, 3)) renders_blur, _, _ = self.rasterize_splats(