From f3558773d21fb8b57bddc70f5be5ed005e85aa05 Mon Sep 17 00:00:00 2001
From: Jeffrey Hu
Date: Wed, 13 Nov 2024 11:57:48 -0800
Subject: [PATCH] bounded l1 loss
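
Replace the hand-tuned blur mask regularizer with a bounded L1 loss on
the mask mean x:

    loss(x) = lambda_a * x + lambda_b * (1 / (1 - x + eps) + 1 / (x + eps)) - c

with lambda_a = 10.0, lambda_b = 0.5, eps = 1e-2, and c a constant
chosen so the minimum of the loss is zero. The reciprocal terms diverge
as x approaches 0 or 1, which keeps the mask from collapsing to all 0s
or all 1s, while the L1 term biases the mask towards 0. Also removes
the now-unused blur_a config option and the initialize_last_layer_zeros
path in mlp.py, and switches the MLP activations from ReLU to LeakyReLU.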
---
 examples/benchmarks/mcmc_deblur.sh |  6 ++----
 examples/blur_opt.py               | 32 +++++++++++++++++++++++-------
 examples/mlp.py                    | 27 +++++--------------------
 examples/simple_trainer.py         |  7 +++----
 4 files changed, 35 insertions(+), 37 deletions(-)

diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh
index b9f87c12..da89187e 100644
--- a/examples/benchmarks/mcmc_deblur.sh
+++ b/examples/benchmarks/mcmc_deblur.sh
@@ -4,8 +4,8 @@ SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake def
 DATA_FACTOR=4
 RENDER_TRAJ_PATH="spiral"
 CAP_MAX=250000
+RESULT_DIR="results/benchmark_mcmc_deblur"
 
-RESULT_DIR="results/benchmark_mcmc_deblur_wd/1e-3_0.8"
 for SCENE in $SCENE_LIST;
 do
     echo "Running $SCENE"
@@ -14,12 +14,10 @@ do
     CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \
         --strategy.cap-max $CAP_MAX \
         --blur_opt \
-        --blur_opt_lr 1e-3 \
-        --blur_a 0.8 \
-        --blur_mask_reg 0.002 \
         --render_traj_path $RENDER_TRAJ_PATH \
         --data_dir $SCENE_DIR/$SCENE/ \
         --result_dir $RESULT_DIR/$SCENE
 done
+
 # Summarize the stats
 python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val

diff --git a/examples/blur_opt.py b/examples/blur_opt.py
index e02a5f5f..e2fcc506 100644
--- a/examples/blur_opt.py
+++ b/examples/blur_opt.py
@@ -9,9 +9,8 @@
 class BlurOptModule(nn.Module):
     """Blur optimization module."""
 
-    def __init__(self, cfg, n: int, embed_dim: int = 4):
+    def __init__(self, n: int, embed_dim: int = 4):
         super().__init__()
-        self.blur_a = cfg.blur_a
         self.embeds = torch.nn.Embedding(n, embed_dim)
         self.means_encoder = get_encoder(3, 3)
         self.depths_encoder = get_encoder(3, 1)
@@ -28,6 +27,7 @@ def __init__(self, cfg, n: int, embed_dim: int = 4):
             layer_width=64,
             out_dim=7,
         )
+        self.bounded_l1_loss = bounded_l1_loss(10.0, 0.5)
 
     def zero_init(self):
         torch.nn.init.zeros_(self.embeds.weight)
@@ -69,15 +69,33 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor):
         blur_mask = torch.sigmoid(mlp_out)
         return blur_mask
 
-    def mask_loss(self, blur_mask: Tensor, eps: float = 1e-2):
+    def mask_loss(self, blur_mask: Tensor):
         """Loss function for regularizing the blur mask by controlling its mean.
 
-        The loss function diverges to +infinity at 0 and 1. This prevents the mask
-        from collapsing all 0s or 1s. It is biased towards 0 to encourage sparsity.
+        Uses a bounded L1 loss, which diverges to +infinity near 0 and 1,
+        to prevent the mask from collapsing to all 0s or 1s.
         """
         x = blur_mask.mean()
-        maskloss = self.blur_a * (1 / (1 - x + eps) - 1) + 0.2 * (1 / (x + eps) - 1)
-        return maskloss
+        return self.bounded_l1_loss(x)
+
+
+def bounded_l1_loss(lambda_a: float, lambda_b: float, eps: float = 1e-2):
+    """L1 loss function with barrier terms that diverge near 0 and 1.
+
+    Args:
+        lambda_a (float): Coefficient of the L1 loss.
+        lambda_b (float): Coefficient of the bounded loss.
+        eps (float, optional): Epsilon to prevent divide by zero. Defaults to 1e-2.
+    """
+
+    def loss_fn(x: Tensor):
+        return lambda_a * x + lambda_b * (1 / (1 - x + eps) + 1 / (x + eps))
+
+    # Compute the constant that shifts the minimum of the loss to zero
+    xs = torch.linspace(0, 1, 1000)
+    ys = loss_fn(xs)
+    c = ys.min()
+    return lambda x: loss_fn(x) - c
 
 
 def get_encoder(num_freqs: int, input_dims: int):

diff --git a/examples/mlp.py b/examples/mlp.py
index 4d5987cf..f5bc48ac 100644
--- a/examples/mlp.py
+++ b/examples/mlp.py
@@ -18,7 +18,6 @@
 
 from typing import Union
 
-import torch
 from torch import nn
 
 from examples.external import TCNN_EXISTS, tcnn
@@ -37,7 +36,7 @@ def activation_to_tcnn_string(activation: Union[nn.Module, None]) -> str:
     if isinstance(activation, nn.ReLU):
         return "ReLU"
     if isinstance(activation, nn.LeakyReLU):
-        return "Leaky ReLU"
+        return "LeakyReLU"
     if isinstance(activation, nn.Sigmoid):
         return "Sigmoid"
     if isinstance(activation, nn.Softplus):
@@ -74,16 +73,11 @@ def create_mlp(
     num_layers: int,
     layer_width: int,
     out_dim: int,
-    initialize_last_layer_zeros: bool = False,
 ):
     if TCNN_EXISTS:
-        return _create_mlp_tcnn(
-            in_dim, num_layers, layer_width, out_dim, initialize_last_layer_zeros
-        )
+        return _create_mlp_tcnn(in_dim, num_layers, layer_width, out_dim)
     else:
-        return _create_mlp_torch(
-            in_dim, num_layers, layer_width, out_dim, initialize_last_layer_zeros
-        )
+        return _create_mlp_torch(in_dim, num_layers, layer_width, out_dim)
 
 
 def _create_mlp_tcnn(
@@ -91,11 +85,10 @@ def _create_mlp_tcnn(
     num_layers: int,
     layer_width: int,
     out_dim: int,
-    initialize_last_layer_zeros: bool = False,
 ):
     """Create a fully-connected neural network with tiny-cuda-nn."""
     network_config = get_tcnn_network_config(
-        activation=nn.ReLU(),
+        activation=nn.LeakyReLU(),
         out_activation=None,
         layer_width=layer_width,
         num_layers=num_layers,
@@ -105,12 +98,6 @@ def _create_mlp_tcnn(
         n_output_dims=out_dim,
         network_config=network_config,
     )
-
-    if initialize_last_layer_zeros:
-        # tcnn always pads the output layer's width to a multiple of 16
-        params = tcnn_encoding.state_dict()["params"]
-        params[-1 * (layer_width * 16 * (out_dim // 16 + 1)) :] = 0
-        tcnn_encoding.load_state_dict({"params": params})
     return tcnn_encoding
 
 
@@ -119,7 +106,6 @@ def _create_mlp_torch(
     num_layers: int,
     layer_width: int,
     out_dim: int,
-    initialize_last_layer_zeros: bool = False,
 ):
     """Create a fully-connected neural network with PyTorch."""
     layers = []
@@ -128,9 +114,6 @@ def _create_mlp_torch(
         layer_out = layer_width if i != num_layers - 1 else out_dim
         layers.append(nn.Linear(layer_in, layer_out, bias=False))
         if i != num_layers - 1:
-            layers.append(nn.ReLU())
+            layers.append(nn.LeakyReLU())
         layer_in = layer_width
-
-    if initialize_last_layer_zeros:
-        nn.init.zeros_(layers[-1].weight)
     return nn.Sequential(*layers)

diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py
index 3f0c93ce..49aee986 100644
--- a/examples/simple_trainer.py
+++ b/examples/simple_trainer.py
@@ -156,7 +156,6 @@ class Config:
     blur_mask_reg: float = 0.001
     # Regularization for blur optimization as weight decay
     blur_opt_reg: float = 1e-6
-    blur_a: float = 0.8
 
     # Enable bilateral grid. (experimental)
     use_bilateral_grid: bool = False
@@ -415,7 +414,7 @@ def __init__(
 
         self.blur_optimizers = []
         if cfg.blur_opt:
-            self.blur_module = BlurOptModule(cfg, len(self.trainset)).to(self.device)
+            self.blur_module = BlurOptModule(len(self.trainset)).to(self.device)
             self.blur_module.zero_init()
             self.blur_optimizers = [
                 torch.optim.Adam(
@@ -869,8 +868,8 @@
                     self.eval(step, stage="train")
                 self.eval(step, stage="val")
                 self.render_traj(step)
-                if step % 1000 == 0:
-                    self.eval(step, stage="vis")
+                # if step % 1000 == 0:
+                #     self.eval(step, stage="vis")
 
             # run compression
             if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]:
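
Note (illustration, not part of the patch): the standalone sketch below
mirrors bounded_l1_loss from blur_opt.py and evaluates it at a few mask
means, showing the barrier behavior described in the mask_loss docstring.
The printed values are for demonstration only; the coefficients
(10.0, 0.5) are the ones BlurOptModule passes in.

    import torch

    def bounded_l1_loss(lambda_a, lambda_b, eps=1e-2):
        # L1 term plus reciprocal barrier terms that diverge near 0 and 1.
        def loss_fn(x):
            return lambda_a * x + lambda_b * (1 / (1 - x + eps) + 1 / (x + eps))
        # Constant that shifts the minimum of the loss to zero.
        c = loss_fn(torch.linspace(0, 1, 1000)).min()
        return lambda x: loss_fn(x) - c

    loss = bounded_l1_loss(10.0, 0.5)
    for m in [0.0, 0.05, 0.2, 0.5, 0.95, 1.0]:
        print(f"mask mean {m:.2f} -> loss {loss(torch.tensor(m)).item():.2f}")
    # The loss rises steeply towards both endpoints and is lower for small
    # means than for large ones, matching the bias of the mask towards 0.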