Commit d95b929: delayed start instead of warmup

jefequien committed Oct 29, 2024
1 parent 68289ae, commit d95b929

Showing 2 changed files with 18 additions and 17 deletions.
examples/blur_opt.py: 5 additions, 11 deletions
```diff
@@ -11,8 +11,6 @@ class BlurOptModule(nn.Module):
 
     def __init__(self, n: int, embed_dim: int = 4):
         super().__init__()
-        self.num_warmup_steps = 2000
-
         self.embeds = torch.nn.Embedding(n, embed_dim)
         self.means_encoder = get_encoder(3, 3)
         self.depths_encoder = get_encoder(3, 1)
```
```diff
@@ -73,19 +71,15 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor):
         blur_mask = torch.sigmoid(mlp_out)
         return blur_mask
 
-    def mask_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2):
+    def mask_loss(self, blur_mask: Tensor, eps: float = 1e-2):
         """Loss function for regularizing the blur mask by controlling its mean.
         The loss function diverges to +infinity at 0 and 1. This prevents the mask
-        from collapsing all 0s or 1s. It is also biased towards 0 to encourage
-        sparsity. During warmup, the bias is even higher to start with a sparse mask."""
+        from collapsing all 0s or 1s. It is biased towards 0 to encourage sparsity.
+        """
         x = blur_mask.mean()
-        if step <= self.num_warmup_steps:
-            a = 3
-            b = 0.1
-        else:
-            a = 1
-            b = 0.1
+        a = 2.0
+        b = 0.1
         maskloss = a * (1 / (1 - x + eps) - 1) + b * (1 / (x + eps) - 1)
         return maskloss
```
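The rewritten loss keeps the divergence at both endpoints of [0, 1] but replaces the warmup schedule (a = 3 for the first 2000 steps, a = 1 after) with a single a = 2.0. Since a > b, the penalty rises much faster as the mask mean approaches 1 than as it approaches 0, which is the sparsity bias the docstring describes. A standalone sketch (not part of the commit; a plain-Python reimplementation for illustration) that locates the minimum numerically:

```python
# maskloss = a * (1/(1 - x + eps) - 1) + b * (1/(x + eps) - 1), evaluated for
# scalar mask means x in (0, 1) with the commit's constants a=2.0, b=0.1.
def mask_loss(x: float, a: float = 2.0, b: float = 0.1, eps: float = 1e-2) -> float:
    return a * (1 / (1 - x + eps) - 1) + b * (1 / (x + eps) - 1)

xs = [i / 1000 for i in range(1, 1000)]              # grid over (0, 1)
x_star = min(xs, key=mask_loss)
print(f"minimum near x = {x_star:.3f}")              # ~0.18, biased toward sparse masks
print(f"loss at x = 0.999: {mask_loss(0.999):.0f}")  # ~180, steep wall near 1
```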
examples/simple_trainer.py: 13 additions, 6 deletions
```diff
@@ -153,7 +153,9 @@ class Config:
     # Learning rate for blur optimization
     blur_opt_lr: float = 1e-3
     # Regularization for blur mask
-    blur_mask_reg: float = 0.002
+    blur_mask_reg: float = 0.001
+    # Blur start iteration
+    blur_start_iter: int = 2_000
 
     # Enable bilateral grid. (experimental)
     use_bilateral_grid: bool = False
```
```diff
@@ -651,7 +653,7 @@ def train(self):
             if cfg.random_bkgd:
                 bkgd = torch.rand(1, 3, device=device)
                 colors = colors + bkgd * (1.0 - alphas)
-            if cfg.blur_opt:
+            if cfg.blur_opt and step >= cfg.blur_start_iter:
                 blur_mask = self.blur_module.predict_mask(image_ids, depths)
                 renders_blur, _, _ = self.rasterize_splats(
                     camtoworlds=camtoworlds,
```
```diff
@@ -704,8 +706,8 @@ def train(self):
             if cfg.use_bilateral_grid:
                 tvloss = 10 * total_variation_loss(self.bil_grids.grids)
                 loss += tvloss
-            if cfg.blur_opt:
-                loss += cfg.blur_mask_reg * self.blur_module.mask_loss(blur_mask, step)
+            if cfg.blur_opt and step >= cfg.blur_start_iter:
+                loss += cfg.blur_mask_reg * self.blur_module.mask_loss(blur_mask)
 
             # regularizations
             if cfg.opacity_reg > 0.0:
```
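Both training-loop call sites now gate on the same predicate, which is why mask_loss no longer takes step: the schedule lives entirely in the loop. A minimal standalone sketch of the delayed-start behavior (hypothetical Cfg stand-in; field names taken from the diff):

```python
# With blur_start_iter = 2_000, the blur branch and its regularizer are
# skipped outright for the first 2_000 steps and run with fixed weights after.
class Cfg:
    blur_opt = True
    blur_mask_reg = 0.001
    blur_start_iter = 2_000

cfg = Cfg()
for step in (0, 1_999, 2_000, 30_000):
    blur_active = cfg.blur_opt and step >= cfg.blur_start_iter
    print(f"step {step:>6}: blur branch {'on' if blur_active else 'off'}")
```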
```diff
@@ -865,6 +867,8 @@ def train(self):
                 self.eval(step, stage="train")
                 self.eval(step, stage="val")
                 self.render_traj(step)
+            if step % 1000 == 0:
+                self.eval(step, stage="vis")
 
             # run compression
             if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]:
```
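The new "vis" stage reuses eval() but on a much denser schedule: every 1000 steps rather than only at the eval_steps checkpoints. It stays cheap because of the every-5th-image filter added in the next hunk; a rough cost estimate (illustrative only, assuming a 30k-step run and a 100-image training set, neither of which is stated in the diff):

```python
# Back-of-the-envelope count of extra renders from the periodic "vis" pass.
max_steps = 30_000       # assumed total training steps
vis_every = 1_000        # schedule added in this commit
train_images = 100       # assumed dataset size
vis_images = sum(1 for i in range(train_images) if i % 5 == 0)  # 20 views kept
print((max_steps // vis_every) * vis_images)  # 600 renders over the whole run
```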
```diff
@@ -890,14 +894,17 @@ def eval(self, step: int, stage: str = "val"):
         world_rank = self.world_rank
         world_size = self.world_size
 
-        dataset = self.trainset if stage == "train" else self.valset
+        dataset = self.valset if stage == "val" else self.trainset
         dataloader = torch.utils.data.DataLoader(
             dataset, batch_size=1, shuffle=False, num_workers=1
         )
 
         ellipse_time = 0
         metrics = defaultdict(list)
         for i, data in enumerate(dataloader):
+            if stage == "vis":
+                if i % 5 != 0:
+                    continue
             camtoworlds = data["camtoworld"].to(device)
             Ks = data["K"].to(device)
             pixels = data["image"].to(device) / 255.0
```
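Two changes make that stage work: the dataset selection is flipped so that every stage except "val" reads from the training set (covering both "train" and the new "vis"), and the loop keeps only every 5th view when stage == "vis". The same filter on a bare DataLoader (toy dataset; torch assumed):

```python
import torch

views = [{"id": i} for i in range(10)]  # stand-in for the image dataset
loader = torch.utils.data.DataLoader(views, batch_size=1, shuffle=False)

stage = "vis"
for i, data in enumerate(loader):
    if stage == "vis" and i % 5 != 0:
        continue                        # keep views 0 and 5 only
    print(data["id"].item())
```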
```diff
@@ -929,7 +936,7 @@ def eval(self, step: int, stage: str = "val"):
 
             colors = torch.clamp(colors, 0.0, 1.0)
             canvas_list = [pixels, colors]
-            if self.cfg.blur_opt and stage == "train":
+            if self.cfg.blur_opt and stage != "val":
                 blur_mask = self.blur_module.predict_mask(image_ids, depths)
                 canvas_list.append(blur_mask.repeat(1, 1, 1, 3))
                 renders_blur, _, _ = self.rasterize_splats(
```
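Relaxing the condition from stage == "train" to stage != "val" means the blur-mask panel also shows up in the new "vis" canvases. A sketch of the panel layout (assuming NHWC tensors, which the repeat(1, 1, 1, 3) call implies, and side-by-side concatenation; both are assumptions, not taken from the diff):

```python
import torch

# Dummy NHWC tensors standing in for pixels, colors, and the predicted mask.
pixels = torch.rand(1, 4, 4, 3)
colors = torch.rand(1, 4, 4, 3)
blur_mask = torch.rand(1, 4, 4, 1)

canvas_list = [pixels, colors]
canvas_list.append(blur_mask.repeat(1, 1, 1, 3))  # broadcast mask to 3 channels
canvas = torch.cat(canvas_list, dim=2)            # tile the panels horizontally
print(canvas.shape)                               # torch.Size([1, 4, 12, 3])
```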
