bounded l1 loss
jefequien committed Nov 13, 2024
1 parent 8f7e92c commit f355877
Showing 4 changed files with 35 additions and 37 deletions.
6 changes: 2 additions & 4 deletions examples/benchmarks/mcmc_deblur.sh
@@ -4,8 +4,8 @@ SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake def
DATA_FACTOR=4
RENDER_TRAJ_PATH="spiral"
CAP_MAX=250000
RESULT_DIR="results/benchmark_mcmc_deblur"

RESULT_DIR="results/benchmark_mcmc_deblur_wd/1e-3_0.8"
for SCENE in $SCENE_LIST;
do
echo "Running $SCENE"
@@ -14,12 +14,10 @@ do
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \
--strategy.cap-max $CAP_MAX \
--blur_opt \
--blur_opt_lr 1e-3 \
--blur_a 0.8 \
--blur_mask_reg 0.002 \
--render_traj_path $RENDER_TRAJ_PATH \
--data_dir $SCENE_DIR/$SCENE/ \
--result_dir $RESULT_DIR/$SCENE
done

# Summarize the stats
python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val
32 changes: 25 additions & 7 deletions examples/blur_opt.py
@@ -9,9 +9,8 @@
class BlurOptModule(nn.Module):
"""Blur optimization module."""

def __init__(self, cfg, n: int, embed_dim: int = 4):
def __init__(self, n: int, embed_dim: int = 4):
super().__init__()
self.blur_a = cfg.blur_a
self.embeds = torch.nn.Embedding(n, embed_dim)
self.means_encoder = get_encoder(3, 3)
self.depths_encoder = get_encoder(3, 1)
@@ -28,6 +27,7 @@ def __init__(self, cfg, n: int, embed_dim: int = 4):
layer_width=64,
out_dim=7,
)
self.bounded_l1_loss = bounded_l1_loss(10.0, 0.5)

def zero_init(self):
torch.nn.init.zeros_(self.embeds.weight)
@@ -69,15 +69,33 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor):
blur_mask = torch.sigmoid(mlp_out)
return blur_mask

def mask_loss(self, blur_mask: Tensor, eps: float = 1e-2):
def mask_loss(self, blur_mask: Tensor):
"""Loss function for regularizing the blur mask by controlling its mean.
The loss function diverges to +infinity at 0 and 1. This prevents the mask
from collapsing all 0s or 1s. It is biased towards 0 to encourage sparsity.
Uses a bounded L1 loss, which diverges to +infinity at 0 and 1, to prevent
the mask from collapsing to all 0s or 1s.
"""
x = blur_mask.mean()
maskloss = self.blur_a * (1 / (1 - x + eps) - 1) + 0.2 * (1 / (x + eps) - 1)
return maskloss
return self.bounded_l1_loss(x)


def bounded_l1_loss(lambda_a: float, lambda_b: float, eps: float = 1e-2):
"""L1 loss function with discontinuities at 0 and 1.
Args:
lambda_a (float): Coefficient of L1 loss.
lambda_b (float): Coefficient of bounded loss.
eps (float, optional): Epsilon to prevent divide by zero. Defaults to 1e-2.
"""

def loss_fn(x: Tensor):
return lambda_a * x + lambda_b * (1 / (1 - x + eps) + 1 / (x + eps))

# Compute constant that sets min to zero
xs = torch.linspace(0, 1, 1000)
ys = loss_fn(xs)
c = ys.min()
return lambda x: loss_fn(x) - c


def get_encoder(num_freqs: int, input_dims: int):
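To see what the new regularizer does in isolation, here is a minimal standalone sketch. The helper is reproduced from the diff above; the probe values and the print loop are illustrative only. With the coefficients (10.0, 0.5) that BlurOptModule passes in, the penalty blows up as the mask mean approaches 0 or 1, and its minimum sits near a mean of roughly 0.2, so the mask is pushed toward sparsity without being allowed to collapse.

import torch
from torch import Tensor

def bounded_l1_loss(lambda_a: float, lambda_b: float, eps: float = 1e-2):
    # Same construction as the committed helper: a linear L1 term plus two
    # barrier terms, shifted so the minimum over [0, 1] is zero.
    def loss_fn(x: Tensor):
        return lambda_a * x + lambda_b * (1 / (1 - x + eps) + 1 / (x + eps))

    xs = torch.linspace(0, 1, 1000)
    c = loss_fn(xs).min()  # constant that sets the minimum to zero
    return lambda x: loss_fn(x) - c

loss = bounded_l1_loss(10.0, 0.5)  # coefficients used by BlurOptModule
for mean in [0.01, 0.2, 0.5, 0.99]:
    print(f"mask mean {mean:.2f} -> loss {loss(torch.tensor(mean)).item():.2f}")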
27 changes: 5 additions & 22 deletions examples/mlp.py
@@ -18,7 +18,6 @@

from typing import Union

import torch
from torch import nn

from examples.external import TCNN_EXISTS, tcnn
@@ -37,7 +36,7 @@ def activation_to_tcnn_string(activation: Union[nn.Module, None]) -> str:
if isinstance(activation, nn.ReLU):
return "ReLU"
if isinstance(activation, nn.LeakyReLU):
return "Leaky ReLU"
return "LeakyReLU"
if isinstance(activation, nn.Sigmoid):
return "Sigmoid"
if isinstance(activation, nn.Softplus):
@@ -74,28 +73,22 @@ def create_mlp(
num_layers: int,
layer_width: int,
out_dim: int,
initialize_last_layer_zeros: bool = False,
):
if TCNN_EXISTS:
return _create_mlp_tcnn(
in_dim, num_layers, layer_width, out_dim, initialize_last_layer_zeros
)
return _create_mlp_tcnn(in_dim, num_layers, layer_width, out_dim)
else:
return _create_mlp_torch(
in_dim, num_layers, layer_width, out_dim, initialize_last_layer_zeros
)
return _create_mlp_torch(in_dim, num_layers, layer_width, out_dim)


def _create_mlp_tcnn(
in_dim: int,
num_layers: int,
layer_width: int,
out_dim: int,
initialize_last_layer_zeros: bool = False,
):
"""Create a fully-connected neural network with tiny-cuda-nn."""
network_config = get_tcnn_network_config(
activation=nn.ReLU(),
activation=nn.LeakyReLU(),
out_activation=None,
layer_width=layer_width,
num_layers=num_layers,
@@ -105,12 +98,6 @@ def _create_mlp_tcnn(
n_output_dims=out_dim,
network_config=network_config,
)

if initialize_last_layer_zeros:
# tcnn always pads the output layer's width to a multiple of 16
params = tcnn_encoding.state_dict()["params"]
params[-1 * (layer_width * 16 * (out_dim // 16 + 1)) :] = 0
tcnn_encoding.load_state_dict({"params": params})
return tcnn_encoding


@@ -119,7 +106,6 @@ def _create_mlp_torch(
num_layers: int,
layer_width: int,
out_dim: int,
initialize_last_layer_zeros: bool = False,
):
"""Create a fully-connected neural network with PyTorch."""
layers = []
@@ -128,9 +114,6 @@
layer_out = layer_width if i != num_layers - 1 else out_dim
layers.append(nn.Linear(layer_in, layer_out, bias=False))
if i != num_layers - 1:
layers.append(nn.ReLU())
layers.append(nn.LeakyReLU())
layer_in = layer_width

if initialize_last_layer_zeros:
nn.init.zeros_(layers[-1].weight)
return nn.Sequential(*layers)
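As a quick usage sketch of the changed factory: both backends now use LeakyReLU between hidden layers and no longer zero-initialize the last layer. The dimensions below are placeholders (BlurOptModule passes layer_width=64 and out_dim=7; its in_dim and num_layers are not visible in this diff), and the printed shape assumes the pure-PyTorch fallback path.

import torch
from examples.mlp import create_mlp

mlp = create_mlp(in_dim=32, num_layers=5, layer_width=64, out_dim=7)  # placeholder dims
print(mlp(torch.randn(8, 32)).shape)  # torch.Size([8, 7])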
7 changes: 3 additions & 4 deletions examples/simple_trainer.py
@@ -156,7 +156,6 @@ class Config:
blur_mask_reg: float = 0.001
# Regularization for blur optimization as weight decay
blur_opt_reg: float = 1e-6
blur_a: float = 0.8

# Enable bilateral grid. (experimental)
use_bilateral_grid: bool = False
@@ -415,7 +414,7 @@ def __init__(

self.blur_optimizers = []
if cfg.blur_opt:
self.blur_module = BlurOptModule(cfg, len(self.trainset)).to(self.device)
self.blur_module = BlurOptModule(len(self.trainset)).to(self.device)
self.blur_module.zero_init()
self.blur_optimizers = [
torch.optim.Adam(
@@ -869,8 +868,8 @@ def train(self):
self.eval(step, stage="train")
self.eval(step, stage="val")
self.render_traj(step)
if step % 1000 == 0:
self.eval(step, stage="vis")
# if step % 1000 == 0:
# self.eval(step, stage="vis")

# run compression
if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]:
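For orientation, a hypothetical sketch of the trainer wiring after this commit. Only BlurOptModule(len(trainset)), zero_init(), and the Adam optimizer come from the diff; n_train_images and the learning rate are placeholders, since the Adam hyperparameters are cut off above.

import torch
from examples.blur_opt import BlurOptModule

n_train_images = 100  # stands in for len(self.trainset)
blur_module = BlurOptModule(n_train_images)
blur_module.zero_init()
blur_optimizer = torch.optim.Adam(blur_module.parameters(), lr=1e-3)  # lr is a placeholder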
