Implement defocus deblurring. #469

Draft: jefequien wants to merge 54 commits into main from jeff/defocus

Changes from all commits (54 commits)
All commits are by jefequien.

cf9a806  gtnet (Oct 5, 2024)
5088b3b  reduce diff (Oct 5, 2024)
214c8f0  refactor (Oct 5, 2024)
d00cb1a  exact same gtnet (Oct 5, 2024)
5aecb73  wip (Oct 12, 2024)
15f8898  cleanup (Oct 13, 2024)
9767f14  simplify (Oct 13, 2024)
e168738  masks works (Oct 14, 2024)
1e338e3  mask is working (Oct 14, 2024)
abad1d1  mlp mask (Oct 15, 2024)
052d272  mlp mask (Oct 15, 2024)
3d4b5c0  no need for quantile (Oct 16, 2024)
649acd0  single mlp (Oct 16, 2024)
fad9b21  cleanup (Oct 16, 2024)
2ae57b9  cleanup (Oct 16, 2024)
904979e  focal embedding (Oct 16, 2024)
c367c0a  init blur (Oct 16, 2024)
ba514ed  reg to prevent collapse (Oct 16, 2024)
6da4119  mlp (Oct 16, 2024)
5714315  deltas mlp (Oct 16, 2024)
291a74c  cleanup (Oct 16, 2024)
ad9c5cd  tcnn works with log_transform (Oct 16, 2024)
fa5957a  cleanup (Oct 17, 2024)
2659a9d  send (Oct 17, 2024)
464249d  log depth (Oct 17, 2024)
af99af2  send (Oct 17, 2024)
e58b2da  init focal (Oct 17, 2024)
e70f474  init strat (Oct 18, 2024)
beac915  log 10 center init around 0.2 (Oct 18, 2024)
56f7937  slower lr (Oct 18, 2024)
a8a8fe8  successfully init tcnn (Oct 18, 2024)
9eb1560  remove warmup and std (Oct 19, 2024)
2a73703  new loss function (Oct 19, 2024)
4f4c101  dialed (Oct 20, 2024)
7676b79  embed init 0 clean lossfn high lr (Oct 21, 2024)
6e00996  test features (Oct 24, 2024)
6c611dc  cleanup (Oct 24, 2024)
9e3bc19  minor (Oct 24, 2024)
63a4bb6  summarize stats (Oct 24, 2024)
ed830d7  less freqs (Oct 24, 2024)
d874ef1  latest (Oct 26, 2024)
8a02e74  latest run avg psnr 23.50 (Oct 28, 2024)
c8fd74d  cleanup (Oct 29, 2024)
834e4e8  cleanup (Oct 29, 2024)
cb826c6  docstring (Oct 29, 2024)
ac0cef5  docstring (Oct 29, 2024)
a3d7d15  rescale (Oct 29, 2024)
10bac1d  minor (Oct 29, 2024)
68289ae  warmup 3 (Oct 29, 2024)
d95b929  delayed start instead of warmup (Oct 29, 2024)
8f7e92c  wip (Oct 31, 2024)
f355877  bounded l1 losos (Nov 13, 2024)
a14bcb7  mlp folder (Nov 13, 2024)
d9db9e0  Merge branch 'main' into jeff/defocus (Nov 13, 2024)
2 changes: 1 addition & 1 deletion examples/benchmarks/compression/mcmc.sh
@@ -49,7 +49,7 @@ done
 if command -v zip &> /dev/null
 then
     echo "Zipping results"
-    python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR
+    python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST
 else
     echo "zip command not found, skipping zipping"
 fi
9 changes: 6 additions & 3 deletions examples/benchmarks/compression/summarize_stats.py
@@ -8,9 +8,8 @@
 import tyro


-def main(results_dir: str, scenes: List[str]):
+def main(results_dir: str, scenes: List[str], stage: str = "compress"):
     print("scenes:", scenes)
-    stage = "compress"

     summary = defaultdict(list)
     for scene in scenes:
@@ -33,7 +32,11 @@ def main(results_dir: str, scenes: List[str]):
                 summary[k].append(v)

     for k, v in summary.items():
-        print(k, np.mean(v))
+        summary[k] = np.mean(v)
+    summary["scenes"] = scenes
+
+    with open(os.path.join(results_dir, f"{stage}_summary.json"), "w") as f:
+        json.dump(summary, f, indent=2)


 if __name__ == "__main__":
23 changes: 23 additions & 0 deletions examples/benchmarks/mcmc_deblur.sh
@@ -0,0 +1,23 @@
SCENE_DIR="data/deblur_dataset/real_defocus_blur"
SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools"

DATA_FACTOR=4
RENDER_TRAJ_PATH="spiral"
CAP_MAX=250000
RESULT_DIR="results/benchmark_mcmc_deblur"

for SCENE in $SCENE_LIST;
do
    echo "Running $SCENE"

    # train and eval
    CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \
        --strategy.cap-max $CAP_MAX \
        --blur_opt \
        --render_traj_path $RENDER_TRAJ_PATH \
        --data_dir $SCENE_DIR/$SCENE/ \
        --result_dir $RESULT_DIR/$SCENE
done

# Summarize the stats
python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val
98 changes: 98 additions & 0 deletions examples/blur_opt.py
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F

from examples.mlp import create_mlp, get_encoder
from gsplat.utils import log_transform


class BlurOptModule(nn.Module):
    """Blur optimization module."""

    def __init__(self, n: int, embed_dim: int = 4):
        super().__init__()
        self.embeds = torch.nn.Embedding(n, embed_dim)
        self.means_encoder = get_encoder(num_freqs=3, input_dims=3)
        self.depths_encoder = get_encoder(num_freqs=3, input_dims=1)
        self.grid_encoder = get_encoder(num_freqs=1, input_dims=2)
        self.blur_mask_mlp = create_mlp(
            in_dim=embed_dim + self.depths_encoder.out_dim + self.grid_encoder.out_dim,
            num_layers=5,
            layer_width=64,
            out_dim=1,
        )
        self.blur_deltas_mlp = create_mlp(
            in_dim=embed_dim + self.means_encoder.out_dim + 7,
            num_layers=5,
            layer_width=64,
            out_dim=7,
        )
        self.bounded_l1_loss = bounded_l1_loss(10.0, 0.5)

    def zero_init(self):
        torch.nn.init.zeros_(self.embeds.weight)

    def forward(
        self,
        image_ids: Tensor,
        means: Tensor,
        scales: Tensor,
        quats: Tensor,
    ):
        quats = F.normalize(quats, dim=-1)
        means_emb = self.means_encoder.encode(log_transform(means))
        images_emb = self.embeds(image_ids).repeat(means.shape[0], 1)
        mlp_out = self.blur_deltas_mlp(
            torch.cat([images_emb, means_emb, scales, quats], dim=-1)
        ).float()
        scales_delta = torch.clamp(mlp_out[:, :3], min=0.0, max=0.1)
        quats_delta = torch.clamp(mlp_out[:, 3:], min=0.0, max=0.1)
        scales = torch.exp(scales + scales_delta)
        quats = quats + quats_delta
        return scales, quats

    def predict_mask(self, image_ids: Tensor, depths: Tensor):
        height, width = depths.shape[1:3]
        grid_y, grid_x = torch.meshgrid(
            (torch.arange(height, device=depths.device) + 0.5) / height,
            (torch.arange(width, device=depths.device) + 0.5) / width,
            indexing="ij",
        )
        grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)
        grid_emb = self.grid_encoder.encode(grid_xy)
        depths_emb = self.depths_encoder.encode(log_transform(depths))
        images_emb = self.embeds(image_ids).repeat(*depths_emb.shape[:-1], 1)
        mlp_in = torch.cat([images_emb, grid_emb, depths_emb], dim=-1)
        mlp_out = self.blur_mask_mlp(mlp_in.reshape(-1, mlp_in.shape[-1])).reshape(
            depths.shape
        )
        blur_mask = torch.sigmoid(mlp_out)
        return blur_mask

    def mask_loss(self, blur_mask: Tensor):
        """Loss function for regularizing the blur mask by controlling its mean.

        Uses a bounded L1 loss that diverges to +infinity at 0 and 1 to prevent
        the mask from collapsing to all 0s or all 1s.
        """
        x = blur_mask.mean()
        return self.bounded_l1_loss(x)


def bounded_l1_loss(lambda_a: float, lambda_b: float, eps: float = 1e-2):
    """L1 loss function with discontinuities at 0 and 1.

    Args:
        lambda_a (float): Coefficient of L1 loss.
        lambda_b (float): Coefficient of bounded loss.
        eps (float, optional): Epsilon to prevent divide by zero. Defaults to 1e-2.
    """

    def loss_fn(x: Tensor):
        return lambda_a * x + lambda_b * (1 / (1 - x + eps) + 1 / (x + eps))

    # Compute the constant that shifts the minimum over [0, 1] to zero.
    xs = torch.linspace(0, 1, 1000)
    ys = loss_fn(xs)
    c = ys.min()
    return lambda x: loss_fn(x) - c
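
For intuition, a minimal sketch of the regularizer's shape (not part of the PR; it assumes blur_opt.py is importable as examples.blur_opt):

import torch

from examples.blur_opt import bounded_l1_loss

# Same coefficients BlurOptModule uses: lambda_a=10.0, lambda_b=0.5.
loss_fn = bounded_l1_loss(10.0, 0.5)
for mean in (0.01, 0.25, 0.5, 0.75, 0.99):
    x = torch.tensor(mean)
    # The loss is shifted so its minimum over [0, 1] is ~0, and it grows
    # steeply as the mask mean approaches either 0 or 1.
    print(f"mask mean {mean:.2f} -> loss {loss_fn(x).item():.3f}")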
7 changes: 6 additions & 1 deletion examples/datasets/colmap.py
@@ -40,6 +40,11 @@ def __init__(
         self.factor = factor
         self.normalize = normalize
         self.test_every = test_every
+        li = os.listdir(data_dir)
+        for l in li:
+            if l.startswith("hold"):
+                self.test_every = int(l.split("=")[-1])
+                break
@@ -134,7 +139,7 @@ def __init__(

         # Load extended metadata. Used by Bilarf dataset.
         self.extconf = {
-            "spiral_radius_scale": 1.0,
+            "spiral_radius_scale": 0.1,
             "no_factor_suffix": False,
         }
         extconf_file = os.path.join(data_dir, "ext_metadata.json")
2 changes: 2 additions & 0 deletions examples/mlp/__init__.py
@@ -0,0 +1,2 @@
from .encoder import get_encoder
from .mlp import create_mlp
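
For reference, a sketch of how these two helpers compose (hypothetical usage; create_mlp's signature is taken from its calls in blur_opt.py above):

from examples.mlp import create_mlp, get_encoder

depths_encoder = get_encoder(num_freqs=3, input_dims=1)  # out_dim = 7
grid_encoder = get_encoder(num_freqs=1, input_dims=2)    # out_dim = 6
embed_dim = 4

# Mirrors BlurOptModule.blur_mask_mlp: 4 + 7 + 6 = 17 input channels.
blur_mask_mlp = create_mlp(
    in_dim=embed_dim + depths_encoder.out_dim + grid_encoder.out_dim,
    num_layers=5,
    layer_width=64,
    out_dim=1,
)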
47 changes: 47 additions & 0 deletions examples/mlp/encoder.py
@@ -0,0 +1,47 @@
import torch


def get_encoder(num_freqs: int, input_dims: int):
    kwargs = {
        "include_input": True,
        "input_dims": input_dims,
        "max_freq_log2": num_freqs - 1,
        "num_freqs": num_freqs,
        "log_sampling": True,
        "periodic_fns": [torch.sin, torch.cos],
    }
    encoder = Encoder(**kwargs)
    return encoder


class Encoder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs["input_dims"]
        out_dim = 0
        if self.kwargs["include_input"]:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs["max_freq_log2"]
        N_freqs = self.kwargs["num_freqs"]

        if self.kwargs["log_sampling"]:
            freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs)
        else:
            freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs["periodic_fns"]:
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def encode(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
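
A quick check of the output width (a sketch, not part of the PR): with include_input=True, each input dimension contributes one identity channel plus 2 * num_freqs sin/cos channels, so out_dim = input_dims * (1 + 2 * num_freqs).

import torch

from examples.mlp import get_encoder

encoder = get_encoder(num_freqs=3, input_dims=3)
assert encoder.out_dim == 3 * (1 + 2 * 3)  # 21, the width of means_emb above

x = torch.rand(4096, 3)  # e.g. log-transformed Gaussian means
feats = encoder.encode(x)
assert feats.shape == (4096, 21)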
58 changes: 58 additions & 0 deletions examples/mlp/external.py
@@ -0,0 +1,58 @@
# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys


class _LazyError:
    def __init__(self, data):
        self.__data = data  # pylint: disable=unused-private-member

    class LazyErrorObj:
        def __init__(self, data):
            self.__data = data  # pylint: disable=unused-private-member

        def __call__(self, *args, **kwds):
            name, exc = object.__getattribute__(self, "__data")
            raise RuntimeError(f"Could not load package {name}.") from exc

        def __getattr__(self, __name: str):
            name, exc = object.__getattribute__(self, "__data")
            raise RuntimeError(f"Could not load package {name}") from exc

    def __getattr__(self, __name: str):
        return _LazyError.LazyErrorObj(object.__getattribute__(self, "__data"))


TCNN_EXISTS = False
tcnn_import_exception = None
tcnn = None
try:
    import tinycudann

    tcnn = tinycudann
    del tinycudann
    TCNN_EXISTS = True
except ModuleNotFoundError as _exp:
    tcnn_import_exception = _exp
except ImportError as _exp:
    tcnn_import_exception = _exp
except EnvironmentError as _exp:
    if "Unknown compute capability" not in _exp.args[0]:
        raise _exp
    print("Could not load tinycudann: " + str(_exp), file=sys.stderr)
    tcnn_import_exception = _exp

if tcnn_import_exception is not None:
    tcnn = _LazyError(tcnn_import_exception)
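
A sketch of how this guard is meant to be consumed downstream (hypothetical helper; the tcnn.Network call follows the public tiny-cuda-nn API): the module always imports, and the original import error only surfaces when tcnn is actually touched.

from examples.mlp.external import TCNN_EXISTS, tcnn


def make_fused_mlp(in_dim: int, out_dim: int):
    # Prefer the fused tinycudann MLP when the extension is available.
    if not TCNN_EXISTS:
        raise RuntimeError("tinycudann unavailable; fall back to a torch MLP")
    return tcnn.Network(
        n_input_dims=in_dim,
        n_output_dims=out_dim,
        network_config={
            "otype": "FullyFusedMLP",
            "activation": "ReLU",
            "output_activation": "None",
            "n_neurons": 64,
            "n_hidden_layers": 4,
        },
    )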