Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
jefequien committed Oct 31, 2024
1 parent d95b929 commit 8f7e92c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
6 changes: 4 additions & 2 deletions examples/benchmarks/mcmc_deblur.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ DATA_FACTOR=4
RENDER_TRAJ_PATH="spiral"
CAP_MAX=250000

RESULT_DIR="results/benchmark_mcmc_deblur"
RESULT_DIR="results/benchmark_mcmc_deblur_wd/1e-3_0.8"
for SCENE in $SCENE_LIST;
do
echo "Running $SCENE"
Expand All @@ -14,10 +14,12 @@ do
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \
--strategy.cap-max $CAP_MAX \
--blur_opt \
--blur_opt_lr 1e-3 \
--blur_a 0.8 \
--blur_mask_reg 0.002 \
--render_traj_path $RENDER_TRAJ_PATH \
--data_dir $SCENE_DIR/$SCENE/ \
--result_dir $RESULT_DIR/$SCENE
done

# Summarize the stats
python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val
12 changes: 4 additions & 8 deletions examples/blur_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
class BlurOptModule(nn.Module):
"""Blur optimization module."""

def __init__(self, n: int, embed_dim: int = 4):
def __init__(self, cfg, n: int, embed_dim: int = 4):
super().__init__()
self.blur_a = cfg.blur_a
self.embeds = torch.nn.Embedding(n, embed_dim)
self.means_encoder = get_encoder(3, 3)
self.depths_encoder = get_encoder(3, 1)
Expand Down Expand Up @@ -39,9 +40,7 @@ def forward(
quats: Tensor,
):
quats = F.normalize(quats, dim=-1)
means_log = log_transform(means)

means_emb = self.means_encoder.encode(means_log)
means_emb = self.means_encoder.encode(log_transform(means))
images_emb = self.embeds(image_ids).repeat(means.shape[0], 1)
mlp_out = self.blur_deltas_mlp(
torch.cat([images_emb, means_emb, scales, quats], dim=-1)
Expand All @@ -61,7 +60,6 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor):
)
grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)
grid_emb = self.grid_encoder.encode(grid_xy)

depths_emb = self.depths_encoder.encode(log_transform(depths))
images_emb = self.embeds(image_ids).repeat(*depths_emb.shape[:-1], 1)
mlp_in = torch.cat([images_emb, grid_emb, depths_emb], dim=-1)
Expand All @@ -78,9 +76,7 @@ def mask_loss(self, blur_mask: Tensor, eps: float = 1e-2):
from collapsing all 0s or 1s. It is biased towards 0 to encourage sparsity.
"""
x = blur_mask.mean()
a = 2.0
b = 0.1
maskloss = a * (1 / (1 - x + eps) - 1) + b * (1 / (x + eps) - 1)
maskloss = self.blur_a * (1 / (1 - x + eps) - 1) + 0.2 * (1 / (x + eps) - 1)
return maskloss


Expand Down
12 changes: 7 additions & 5 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,9 @@ class Config:
blur_opt_lr: float = 1e-3
# Regularization for blur mask
blur_mask_reg: float = 0.001
# Blur start iteration
blur_start_iter: int = 2_000
# Regularization for blur optimization as weight decay
blur_opt_reg: float = 1e-6
blur_a: float = 0.8

# Enable bilateral grid. (experimental)
use_bilateral_grid: bool = False
Expand Down Expand Up @@ -414,12 +415,13 @@ def __init__(

self.blur_optimizers = []
if cfg.blur_opt:
self.blur_module = BlurOptModule(len(self.trainset)).to(self.device)
self.blur_module = BlurOptModule(cfg, len(self.trainset)).to(self.device)
self.blur_module.zero_init()
self.blur_optimizers = [
torch.optim.Adam(
self.blur_module.parameters(),
lr=cfg.blur_opt_lr * math.sqrt(cfg.batch_size),
weight_decay=cfg.blur_opt_reg,
),
]
if world_size > 1:
Expand Down Expand Up @@ -653,7 +655,7 @@ def train(self):
if cfg.random_bkgd:
bkgd = torch.rand(1, 3, device=device)
colors = colors + bkgd * (1.0 - alphas)
if cfg.blur_opt and step >= cfg.blur_start_iter:
if cfg.blur_opt:
blur_mask = self.blur_module.predict_mask(image_ids, depths)
renders_blur, _, _ = self.rasterize_splats(
camtoworlds=camtoworlds,
Expand Down Expand Up @@ -706,7 +708,7 @@ def train(self):
if cfg.use_bilateral_grid:
tvloss = 10 * total_variation_loss(self.bil_grids.grids)
loss += tvloss
if cfg.blur_opt and step >= cfg.blur_start_iter:
if cfg.blur_opt:
loss += cfg.blur_mask_reg * self.blur_module.mask_loss(blur_mask)

# regularizations
Expand Down

0 comments on commit 8f7e92c

Please sign in to comment.