Skip to content

Commit

Permalink
Add support for PnP.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Jan 25, 2024
1 parent 8aa3a25 commit 798d1d8
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 17 deletions.
2 changes: 1 addition & 1 deletion configs/benchmark.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ device: "cuda"
# numbers of iterations to benchmark
n_iter_range: [5, 10, 20, 50, 100, 200, 300]
# number of files to benchmark
n_files: 200 # null for all files
n_files: null # null for all files
#How much should the image be downsampled
downsample: 2
#algorithm to benchmark
Expand Down
2 changes: 2 additions & 0 deletions configs/defaults_recon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ admm:
mu2: 1e-5
mu3: 4e-5
tau: 0.0001
# PnP
denoiser: null # set to use PnP
#Loading unrolled model
unrolled: false
checkpoint_fp: null
Expand Down
98 changes: 87 additions & 11 deletions lensless/recon/admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def __init__(
psi_gram=None,
pad=False,
norm="backward",
# for PnP
denoiser=None,
**kwargs
):
"""
Expand Down Expand Up @@ -92,6 +94,7 @@ def __init__(
)

# call reset() to initialize matrices
self._proj = self._Psi
super(ADMM, self).__init__(psf, dtype, pad=pad, norm=norm, **kwargs)

# set prior
Expand All @@ -109,6 +112,10 @@ def __init__(
self._PsiT = psi_adj
self._PsiTPsi = psi_gram(self._padded_shape)

# - need to reset with new projector
self._proj = self._Psi
self.reset()

# precompute_R_divmat (self._H computed by constructor with reset())
if self.is_torch:
self._PsiTPsi = self._PsiTPsi.to(self._psf.device)
Expand All @@ -124,6 +131,43 @@ def __init__(
+ self._mu3
).astype(self._complex_dtype)

# check denoiser for PnP
self._denoiser = denoiser
if denoiser is not None:
assert self.is_torch

import lensless.recon.utils

denoiser_model, _ = lensless.recon.utils.create_process_network(
network=denoiser["network"], device=self._psf.device
)

def denoiser_func(x, normalize_image=True):
torch.clip(x, min=0.0, out=x)

x_max = torch.amax(x, dim=(-2, -3), keepdim=True) + 1e-6
denoised = lensless.recon.utils.apply_denoiser(
model=denoiser_model,
# image=x / x_max,
image=x / x_max if normalize_image else x,
noise_level=denoiser["noise_level"],
device=self._psf.device,
)
# denoised = torch.clip(denoised, min=0.0) * x_max.to(self._psf.device)
denoised = torch.clip(denoised, min=0.0)
if normalize_image:
denoised = denoised * x_max.to(self._psf.device)
return denoised

self._denoiser = denoiser_func
self._denoiser_use_dual = denoiser["use_dual"]

# - need to reset with new projector
self._proj = self._denoiser
# identify function
self._PsiT = lambda x: x
self.reset()

def _Psi(self, x):
"""
Operator to map image to space that the image is assumed to be sparse
Expand All @@ -150,7 +194,8 @@ def reset(self):

# self._image_est = torch.zeros_like(self._psf)
self._X = torch.zeros_like(self._image_est)
self._U = torch.zeros_like(self._Psi(self._image_est))
# self._U = torch.zeros_like(self._Psi(self._image_est))
self._U = torch.zeros_like(self._proj(self._image_est))
self._W = torch.zeros_like(self._X)
if self._image_est.max():
# if non-zero
Expand All @@ -177,7 +222,8 @@ def reset(self):

# self._U = np.zeros(np.r_[self._padded_shape, [2]], dtype=self._dtype)
self._X = np.zeros_like(self._image_est)
self._U = np.zeros_like(self._Psi(self._image_est))
# self._U = np.zeros_like(self._Psi(self._image_est))
self._U = np.zeros_like(self._proj(self._image_est))
self._W = np.zeros_like(self._X)
if self._image_est.max():
# if non-zero
Expand All @@ -200,7 +246,18 @@ def reset(self):
def _U_update(self):
"""Total variation update."""
# to avoid computing sparse operator twice
self._U = soft_thresh(self._Psi_out + self._eta / self._mu2, self._tau / self._mu2)
if self._denoiser is not None:
# PnP
if self._denoiser_use_dual:
self._U = self._denoiser(
self._U + self._eta / self._mu2,
)
else:
self._U = self._denoiser(self._image_est)
else:
self._U = soft_thresh(
self._Psi_out + self._eta / self._mu2, thresh=self._tau / self._mu2
)

def _X_update(self):
# to avoid computing forward model twice
Expand All @@ -219,11 +276,22 @@ def _W_update(self):
self._W = np.maximum(self._rho / self._mu3 + self._image_est, 0)

def _image_update(self):
rk = (
(self._mu3 * self._W - self._rho)
+ self._PsiT(self._mu2 * self._U - self._eta)
+ self._convolver.deconvolve(self._mu1 * self._X - self._xi)
)
if self._denoiser is not None:
# PnP
rk = (
(self._mu3 * self._W - self._rho)
# + self._mu2 * self._U
+ self._mu2 * self._U - self._eta
if self._denoiser_use_dual
else self._mu2 * self._U
+ self._convolver.deconvolve(self._mu1 * self._X - self._xi)
)
else:
rk = (
(self._mu3 * self._W - self._rho)
+ self._PsiT(self._mu2 * self._U - self._eta)
+ self._convolver.deconvolve(self._mu1 * self._X - self._xi)
)

# rk = self._convolver._pad(rk)

Expand All @@ -242,7 +310,11 @@ def _xi_update(self):

def _eta_update(self):
# to avoid finite difference operataion again?
self._eta += self._mu2 * (self._Psi_out - self._U)
if self._denoiser is not None:
# PnP
self._eta += self._mu2 * (self._image_est - self._U)
else:
self._eta += self._mu2 * (self._Psi_out - self._U)

def _rho_update(self):
self._rho += self._mu3 * (self._image_est - self._W)
Expand All @@ -255,10 +327,14 @@ def _update(self, iter):

# update forward and sparse operators
self._forward_out = self._convolver.convolve(self._image_est)
self._Psi_out = self._Psi(self._image_est)
if self._denoiser is None:
self._Psi_out = self._Psi(self._image_est)

self._xi_update()
self._eta_update()
if self._denoiser is None:
self._eta_update()
elif self._denoiser_use_dual:
self._eta_update()
self._rho_update()

def _form_image(self):
Expand Down
10 changes: 8 additions & 2 deletions lensless/recon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ def apply_denoiser(model, image, noise_level=10, device="cpu", mode="inference")
image : :py:class:`torch.Tensor`
Reconstructed image.
"""
assert noise_level > 0
assert noise_level <= 255

# convert from NDHWC to NCHW
depth = image.shape[-4]
image = image.movedim(-1, -3)
Expand All @@ -118,6 +121,7 @@ def apply_denoiser(model, image, noise_level=10, device="cpu", mode="inference")
noise_level = noise_level / 255.0
else:
noise_level = torch.tensor([noise_level / 255.0]).to(device)

image = torch.cat(
(
image,
Expand Down Expand Up @@ -194,7 +198,7 @@ def measure_gradient(model):
return total_norm


def create_process_network(network, depth, device="cpu", nc=None):
def create_process_network(network, depth=4, device="cpu", nc=None):
"""
Helper function to create a process network.
Expand Down Expand Up @@ -847,7 +851,8 @@ def save(self, epoch, path="recon", include_optimizer=False):
self.mask._mask.cpu().detach().numpy(),
)

if self.mask.color_filter is not None:
# if color_filter is an attribute
if hasattr(self.mask, "color_filter") and self.mask.color_filter is not None:
# save save numpy array
np.save(
os.path.join(path, f"mask_color_filter_epoch{epoch}.npy"),
Expand All @@ -860,6 +865,7 @@ def save(self, epoch, path="recon", include_optimizer=False):

psf_np = self.mask.get_psf().detach().cpu().numpy()[0, ...]
psf_np = psf_np.squeeze() # remove (potential) singleton color channel
np.save(os.path.join(path, f"psf_epoch{epoch}.npy"), psf_np)
save_image(psf_np, os.path.join(path, f"psf_epoch{epoch}.png"))
plot_image(psf_np, gamma=self.gamma)
plt.savefig(os.path.join(path, f"psf_epoch{epoch}_plot.png"))
Expand Down
14 changes: 14 additions & 0 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def __init__(
# background_pix=(0, 15),
downsample=1,
flip=False,
flip_ud=False,
flip_lr=False,
transform_lensless=None,
transform_lensed=None,
input_snr=None,
Expand Down Expand Up @@ -83,6 +85,8 @@ def __init__(
self.input_snr = input_snr
self.downsample = downsample
self.flip = flip
self.flip_ud = flip_ud
self.flip_lr = flip_lr
self.transform_lensless = transform_lensless
self.transform_lensed = transform_lensed

Expand Down Expand Up @@ -161,6 +165,12 @@ def __getitem__(self, idx):
if self.flip:
lensless = torch.rot90(lensless, dims=(-3, -2), k=2)
lensed = torch.rot90(lensed, dims=(-3, -2), k=2)
if self.flip_ud:
lensless = torch.flip(lensless, dims=(-4, -3))
lensed = torch.flip(lensed, dims=(-4, -3))
if self.flip_lr:
lensless = torch.flip(lensless, dims=(-4, -2))
lensed = torch.flip(lensed, dims=(-4, -2))
if self.transform_lensless:
lensless = self.transform_lensless(lensless)
if self.transform_lensed:
Expand Down Expand Up @@ -769,6 +779,8 @@ def __init__(
return_float=True,
return_bg=True,
bg_pix=(0, 15),
flip_ud=True,
flip_lr=False,
)

# transform from BGR to RGB
Expand All @@ -787,6 +799,8 @@ def __init__(
background=background,
downsample=downsample,
flip=False,
flip_ud=True,
flip_lr=False,
transform_lensless=transform_BRG2RGB,
transform_lensed=transform_BRG2RGB,
lensless_fn="diffuser",
Expand Down
10 changes: 10 additions & 0 deletions lensless/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def load_image(
fp,
verbose=False,
flip=False,
flip_ud=False,
flip_lr=False,
bayer=False,
black_level=RPI_HQ_CAMERA_BLACK_LEVEL,
blue_gain=None,
Expand Down Expand Up @@ -157,6 +159,10 @@ def load_image(
if flip:
img = np.flipud(img)
img = np.fliplr(img)
if flip_ud:
img = np.flipud(img)
if flip_lr:
img = np.fliplr(img)

if verbose:
print_image_info(img)
Expand Down Expand Up @@ -206,6 +212,8 @@ def load_psf(
bg_pix=(5, 25),
return_bg=False,
flip=False,
flip_ud=False,
flip_lr=False,
verbose=False,
bayer=False,
blue_gain=None,
Expand Down Expand Up @@ -282,6 +290,8 @@ def load_psf(
fp,
verbose=verbose,
flip=flip,
flip_ud=flip_ud,
flip_lr=flip_lr,
bayer=bayer,
blue_gain=blue_gain,
red_gain=red_gain,
Expand Down
47 changes: 44 additions & 3 deletions scripts/eval/benchmark_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def benchmark_recon(config):
raise ValueError(f"Dataset {dataset} not supported")

print(f"Number of files : {len(benchmark_dataset)}")
print(f"Data shape : {dataset[0][0].shape}")
print(f"Data shape : {benchmark_dataset[0][0].shape}")

model_list = [] # list of algoritms to benchmark
if "ADMM" in config.algorithms:
Expand All @@ -104,6 +104,48 @@ def benchmark_recon(config):
)
if "ADMM_Monakhova2019" in config.algorithms:
model_list.append(("ADMM_Monakhova2019", ADMM(psf, mu1=1e-4, mu2=1e-4, mu3=1e-4, tau=2e-3)))
if "ADMM_PnP_10" in config.algorithms:
model_list.append(
(
"ADMM_PnP_10",
ADMM(
psf,
mu1=config.admm.mu1,
mu2=config.admm.mu2,
mu3=config.admm.mu3,
tau=config.admm.tau,
denoiser={"network": "DruNet", "noise_level": 10, "use_dual": False},
),
)
)
if "ADMM_PnP_25" in config.algorithms:
model_list.append(
(
"ADMM_PnP_25",
ADMM(
psf,
mu1=config.admm.mu1,
mu2=config.admm.mu2,
mu3=config.admm.mu3,
tau=config.admm.tau,
denoiser={"network": "DruNet", "noise_level": 25, "use_dual": False},
),
)
)
if "ADMM_PnP_50" in config.algorithms:
model_list.append(
(
"ADMM_PnP_50",
ADMM(
psf,
mu1=config.admm.mu1,
mu2=config.admm.mu2,
mu3=config.admm.mu3,
tau=config.admm.tau,
denoiser={"network": "DruNet", "noise_level": 50, "use_dual": False},
),
)
)
if "FISTA" in config.algorithms:
model_list.append(("FISTA", FISTA(psf, tk=config.fista.tk)))
if "GradientDescent" in config.algorithms:
Expand Down Expand Up @@ -310,8 +352,7 @@ def benchmark_recon(config):
)
plt.xlabel("Number of iterations", fontsize="12")
plt.ylabel(metric, fontsize="12")
if metric == "ReconstructionError":
plt.legend(fontsize="12")
plt.legend(fontsize="12")
plt.grid()
plt.savefig(f"{metric}.png")

Expand Down
1 change: 1 addition & 0 deletions scripts/recon/admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def admm(config):
fig = plt.gcf()
plt.close(fig)

# load model
start_time = time.time()
if not config.admm.unrolled:
recon = ADMM(psf, **config.admm)
Expand Down

0 comments on commit 798d1d8

Please sign in to comment.