Skip to content

Commit

Permalink
latest run avg psnr 23.50
Browse files Browse the repository at this point in the history
  • Loading branch information
jefequien committed Oct 28, 2024
1 parent d874ef1 commit 8a02e74
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 25 deletions.
9 changes: 4 additions & 5 deletions examples/benchmarks/mcmc_deblur.sh
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
SCENE_DIR="data/deblur_dataset/real_defocus_blur"
SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools"
SCENE_LIST="defocuscake defocustools defocussausage defocuscupcake defocuscups defocuscoral defocusdaisy defocusseal defocuscaps defocuscisco"

DATA_FACTOR=4
RENDER_TRAJ_PATH="spiral"

RESULT_DIR="results/benchmark_mcmc_deblur"
CAP_MAX=250000

RESULT_DIR="results/benchmark_mcmc_deblur/c0.2_a10"
for SCENE in $SCENE_LIST;
do
echo "Running $SCENE"
Expand All @@ -15,12 +15,11 @@ do
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \
--strategy.cap-max $CAP_MAX \
--blur_opt \
--blur_opt_lr 1e-3 \
--blur_mean_reg 0.005 \
--blur_a 10 \
--blur_c 0.2 \
--render_traj_path $RENDER_TRAJ_PATH \
--data_dir $SCENE_DIR/$SCENE/ \
--result_dir $RESULT_DIR/$SCENE
done

# Summarize the stats
python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val
27 changes: 15 additions & 12 deletions examples/blur_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
class BlurOptModule(nn.Module):
"""Blur optimization module."""

def __init__(self, n: int, embed_dim: int = 4):
def __init__(self, cfg, n: int, embed_dim: int = 4):
super().__init__()
self.a = cfg.blur_a
self.c = cfg.blur_c

self.embeds = torch.nn.Embedding(n, embed_dim)
self.means_encoder = get_encoder(3, 3)
self.depths_encoder = get_encoder(3, 1)
Expand Down Expand Up @@ -75,17 +78,17 @@ def predict_mask(self, image_ids: Tensor, depths: Tensor):
def mask_mean_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2):
"""Mask mean loss."""
x = blur_mask.mean()
a = 0.9
meanloss = a * (1 / (1 - x + eps) - 1) + (1 - a) * (1 / (x + eps) - 1)
return meanloss

def mask_smoothness_loss(self, blur_mask: Tensor):
"""Mask smoothness loss."""
blurred_xy = median_blur(blur_mask.permute(0, 3, 1, 2), (5, 5)).permute(
0, 2, 3, 1
)
smoothloss = F.huber_loss(blur_mask, blurred_xy)
return smoothloss
if step <= 2000:
a = 20
b = 1
c = 0.2
else:
a = self.a
b = 1
c = self.c
print(x.item(), a, b, c)
meanloss = a * (1 / (1 - x + eps) - 1) + b * (1 / (x + eps) - 1)
return c * meanloss


def get_encoder(num_freqs: int, input_dims: int):
Expand Down
19 changes: 11 additions & 8 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ class Config:
# Regularization for blur mask mean
blur_mean_reg: float = 0.001
# Regularization for blur mask smoothness
blur_smoothness_reg: float = 0.0
blur_a: float = 4
blur_c: float = 0.5

# Enable bilateral grid. (experimental)
use_bilateral_grid: bool = False
Expand Down Expand Up @@ -414,7 +415,7 @@ def __init__(

self.blur_optimizers = []
if cfg.blur_opt:
self.blur_module = BlurOptModule(len(self.trainset)).to(self.device)
self.blur_module = BlurOptModule(cfg, len(self.trainset)).to(self.device)
self.blur_module.zero_init()
self.blur_optimizers = [
torch.optim.Adam(
Expand Down Expand Up @@ -671,7 +672,7 @@ def train(self):
renders_blur[..., 0:3],
renders_blur[..., 3:4],
)
blur_mask = self.blur_module.predict_mask(image_ids, depths_blur)
blur_mask = self.blur_module.predict_mask(image_ids, depths)
colors = (1 - blur_mask) * colors + blur_mask * colors_blur

self.cfg.strategy.step_pre_backward(
Expand Down Expand Up @@ -714,9 +715,6 @@ def train(self):
loss += cfg.blur_mean_reg * self.blur_module.mask_mean_loss(
blur_mask, step
)
loss += cfg.blur_smoothness_reg * self.blur_module.mask_smoothness_loss(
blur_mask
)

# regularizations
if cfg.opacity_reg > 0.0:
Expand Down Expand Up @@ -875,6 +873,8 @@ def train(self):
self.eval(step, stage="train")
self.eval(step)
self.render_traj(step)
if (step + 1) % 1000 == 0 or step == 0:
self.eval(step, stage="train", vis_skip=True)

# run compression
if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]:
Expand All @@ -892,7 +892,7 @@ def train(self):
self.viewer.update(step, num_train_rays_per_step)

@torch.no_grad()
def eval(self, step: int, stage: str = "val"):
def eval(self, step: int, stage: str = "val", vis_skip: bool = False):
"""Entry for evaluation."""
print("Running evaluation...")
cfg = self.cfg
Expand All @@ -904,10 +904,13 @@ def eval(self, step: int, stage: str = "val"):
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=1, shuffle=False, num_workers=1
)
train_vis_image_ids = np.linspace(0, len(dataloader) - 1, 7).astype(int)

ellipse_time = 0
metrics = defaultdict(list)
for i, data in enumerate(dataloader):
if vis_skip and stage == "train" and i not in train_vis_image_ids:
continue
camtoworlds = data["camtoworld"].to(device)
Ks = data["K"].to(device)
pixels = data["image"].to(device) / 255.0
Expand Down Expand Up @@ -957,7 +960,7 @@ def eval(self, step: int, stage: str = "val"):
renders_blur[..., 0:3],
renders_blur[..., 3:4],
)
blur_mask = self.blur_module.predict_mask(image_ids, depths_blur)
blur_mask = self.blur_module.predict_mask(image_ids, depths)
canvas_list.append(blur_mask.repeat(1, 1, 1, 3))
canvas_list.append(torch.clamp(colors_blur, 0.0, 1.0))
colors = (1 - blur_mask) * colors + blur_mask * colors_blur
Expand Down

0 comments on commit 8a02e74

Please sign in to comment.