diff --git a/configs/benchmark.yaml b/configs/benchmark.yaml index 47915514..384ad2cb 100644 --- a/configs/benchmark.yaml +++ b/configs/benchmark.yaml @@ -14,7 +14,7 @@ device: "cuda" # numbers of iterations to benchmark n_iter_range: [5, 10, 20, 50, 100, 200, 300] # number of files to benchmark -n_files: 200 # null for all files +n_files: null # null for all files #How much should the image be downsampled downsample: 2 #algorithm to benchmark diff --git a/configs/defaults_recon.yaml b/configs/defaults_recon.yaml index 1771ff8a..d2cadc5f 100644 --- a/configs/defaults_recon.yaml +++ b/configs/defaults_recon.yaml @@ -64,6 +64,8 @@ admm: mu2: 1e-5 mu3: 4e-5 tau: 0.0001 + # PnP + denoiser: null # set to use PnP #Loading unrolled model unrolled: false checkpoint_fp: null diff --git a/lensless/recon/admm.py b/lensless/recon/admm.py index b3810563..6ee2792d 100644 --- a/lensless/recon/admm.py +++ b/lensless/recon/admm.py @@ -43,6 +43,8 @@ def __init__( psi_gram=None, pad=False, norm="backward", + # for PnP + denoiser=None, **kwargs ): """ @@ -92,6 +94,7 @@ def __init__( ) # call reset() to initialize matrices + self._proj = self._Psi super(ADMM, self).__init__(psf, dtype, pad=pad, norm=norm, **kwargs) # set prior @@ -109,6 +112,10 @@ def __init__( self._PsiT = psi_adj self._PsiTPsi = psi_gram(self._padded_shape) + # - need to reset with new projector + self._proj = self._Psi + self.reset() + # precompute_R_divmat (self._H computed by constructor with reset()) if self.is_torch: self._PsiTPsi = self._PsiTPsi.to(self._psf.device) @@ -124,6 +131,43 @@ def __init__( + self._mu3 ).astype(self._complex_dtype) + # check denoiser for PnP + self._denoiser = denoiser + if denoiser is not None: + assert self.is_torch + + import lensless.recon.utils + + denoiser_model, _ = lensless.recon.utils.create_process_network( + network=denoiser["network"], device=self._psf.device + ) + + def denoiser_func(x, normalize_image=True): + torch.clip(x, min=0.0, out=x) + + x_max = torch.amax(x, dim=(-2, -3), keepdim=True) + 1e-6 + denoised = lensless.recon.utils.apply_denoiser( + model=denoiser_model, + # image=x / x_max, + image=x / x_max if normalize_image else x, + noise_level=denoiser["noise_level"], + device=self._psf.device, + ) + # denoised = torch.clip(denoised, min=0.0) * x_max.to(self._psf.device) + denoised = torch.clip(denoised, min=0.0) + if normalize_image: + denoised = denoised * x_max.to(self._psf.device) + return denoised + + self._denoiser = denoiser_func + self._denoiser_use_dual = denoiser["use_dual"] + + # - need to reset with new projector + self._proj = self._denoiser + # identify function + self._PsiT = lambda x: x + self.reset() + def _Psi(self, x): """ Operator to map image to space that the image is assumed to be sparse @@ -150,7 +194,8 @@ def reset(self): # self._image_est = torch.zeros_like(self._psf) self._X = torch.zeros_like(self._image_est) - self._U = torch.zeros_like(self._Psi(self._image_est)) + # self._U = torch.zeros_like(self._Psi(self._image_est)) + self._U = torch.zeros_like(self._proj(self._image_est)) self._W = torch.zeros_like(self._X) if self._image_est.max(): # if non-zero @@ -177,7 +222,8 @@ def reset(self): # self._U = np.zeros(np.r_[self._padded_shape, [2]], dtype=self._dtype) self._X = np.zeros_like(self._image_est) - self._U = np.zeros_like(self._Psi(self._image_est)) + # self._U = np.zeros_like(self._Psi(self._image_est)) + self._U = np.zeros_like(self._proj(self._image_est)) self._W = np.zeros_like(self._X) if self._image_est.max(): # if non-zero @@ -200,7 +246,18 @@ def reset(self): def _U_update(self): """Total variation update.""" # to avoid computing sparse operator twice - self._U = soft_thresh(self._Psi_out + self._eta / self._mu2, self._tau / self._mu2) + if self._denoiser is not None: + # PnP + if self._denoiser_use_dual: + self._U = self._denoiser( + self._U + self._eta / self._mu2, + ) + else: + self._U = self._denoiser(self._image_est) + else: + self._U = soft_thresh( + self._Psi_out + self._eta / self._mu2, thresh=self._tau / self._mu2 + ) def _X_update(self): # to avoid computing forward model twice @@ -219,11 +276,22 @@ def _W_update(self): self._W = np.maximum(self._rho / self._mu3 + self._image_est, 0) def _image_update(self): - rk = ( - (self._mu3 * self._W - self._rho) - + self._PsiT(self._mu2 * self._U - self._eta) - + self._convolver.deconvolve(self._mu1 * self._X - self._xi) - ) + if self._denoiser is not None: + # PnP + rk = ( + (self._mu3 * self._W - self._rho) + # + self._mu2 * self._U + + self._mu2 * self._U - self._eta + if self._denoiser_use_dual + else self._mu2 * self._U + + self._convolver.deconvolve(self._mu1 * self._X - self._xi) + ) + else: + rk = ( + (self._mu3 * self._W - self._rho) + + self._PsiT(self._mu2 * self._U - self._eta) + + self._convolver.deconvolve(self._mu1 * self._X - self._xi) + ) # rk = self._convolver._pad(rk) @@ -242,7 +310,11 @@ def _xi_update(self): def _eta_update(self): # to avoid finite difference operataion again? - self._eta += self._mu2 * (self._Psi_out - self._U) + if self._denoiser is not None: + # PnP + self._eta += self._mu2 * (self._image_est - self._U) + else: + self._eta += self._mu2 * (self._Psi_out - self._U) def _rho_update(self): self._rho += self._mu3 * (self._image_est - self._W) @@ -255,10 +327,14 @@ def _update(self, iter): # update forward and sparse operators self._forward_out = self._convolver.convolve(self._image_est) - self._Psi_out = self._Psi(self._image_est) + if self._denoiser is None: + self._Psi_out = self._Psi(self._image_est) self._xi_update() - self._eta_update() + if self._denoiser is None: + self._eta_update() + elif self._denoiser_use_dual: + self._eta_update() self._rho_update() def _form_image(self): diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 918114b4..df98c1fc 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -102,6 +102,9 @@ def apply_denoiser(model, image, noise_level=10, device="cpu", mode="inference") image : :py:class:`torch.Tensor` Reconstructed image. """ + assert noise_level > 0 + assert noise_level <= 255 + # convert from NDHWC to NCHW depth = image.shape[-4] image = image.movedim(-1, -3) @@ -118,6 +121,7 @@ def apply_denoiser(model, image, noise_level=10, device="cpu", mode="inference") noise_level = noise_level / 255.0 else: noise_level = torch.tensor([noise_level / 255.0]).to(device) + image = torch.cat( ( image, @@ -194,7 +198,7 @@ def measure_gradient(model): return total_norm -def create_process_network(network, depth, device="cpu", nc=None): +def create_process_network(network, depth=4, device="cpu", nc=None): """ Helper function to create a process network. @@ -847,7 +851,8 @@ def save(self, epoch, path="recon", include_optimizer=False): self.mask._mask.cpu().detach().numpy(), ) - if self.mask.color_filter is not None: + # if color_filter is an attribute + if hasattr(self.mask, "color_filter") and self.mask.color_filter is not None: # save save numpy array np.save( os.path.join(path, f"mask_color_filter_epoch{epoch}.npy"), @@ -860,6 +865,7 @@ def save(self, epoch, path="recon", include_optimizer=False): psf_np = self.mask.get_psf().detach().cpu().numpy()[0, ...] psf_np = psf_np.squeeze() # remove (potential) singleton color channel + np.save(os.path.join(path, f"psf_epoch{epoch}.npy"), psf_np) save_image(psf_np, os.path.join(path, f"psf_epoch{epoch}.png")) plot_image(psf_np, gamma=self.gamma) plt.savefig(os.path.join(path, f"psf_epoch{epoch}_plot.png")) diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 7fef2e17..851a060c 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -47,6 +47,8 @@ def __init__( # background_pix=(0, 15), downsample=1, flip=False, + flip_ud=False, + flip_lr=False, transform_lensless=None, transform_lensed=None, input_snr=None, @@ -83,6 +85,8 @@ def __init__( self.input_snr = input_snr self.downsample = downsample self.flip = flip + self.flip_ud = flip_ud + self.flip_lr = flip_lr self.transform_lensless = transform_lensless self.transform_lensed = transform_lensed @@ -161,6 +165,12 @@ def __getitem__(self, idx): if self.flip: lensless = torch.rot90(lensless, dims=(-3, -2), k=2) lensed = torch.rot90(lensed, dims=(-3, -2), k=2) + if self.flip_ud: + lensless = torch.flip(lensless, dims=(-4, -3)) + lensed = torch.flip(lensed, dims=(-4, -3)) + if self.flip_lr: + lensless = torch.flip(lensless, dims=(-4, -2)) + lensed = torch.flip(lensed, dims=(-4, -2)) if self.transform_lensless: lensless = self.transform_lensless(lensless) if self.transform_lensed: @@ -769,6 +779,8 @@ def __init__( return_float=True, return_bg=True, bg_pix=(0, 15), + flip_ud=True, + flip_lr=False, ) # transform from BGR to RGB @@ -787,6 +799,8 @@ def __init__( background=background, downsample=downsample, flip=False, + flip_ud=True, + flip_lr=False, transform_lensless=transform_BRG2RGB, transform_lensed=transform_BRG2RGB, lensless_fn="diffuser", diff --git a/lensless/utils/io.py b/lensless/utils/io.py index 750b0e0e..4f25b62a 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -21,6 +21,8 @@ def load_image( fp, verbose=False, flip=False, + flip_ud=False, + flip_lr=False, bayer=False, black_level=RPI_HQ_CAMERA_BLACK_LEVEL, blue_gain=None, @@ -157,6 +159,10 @@ def load_image( if flip: img = np.flipud(img) img = np.fliplr(img) + if flip_ud: + img = np.flipud(img) + if flip_lr: + img = np.fliplr(img) if verbose: print_image_info(img) @@ -206,6 +212,8 @@ def load_psf( bg_pix=(5, 25), return_bg=False, flip=False, + flip_ud=False, + flip_lr=False, verbose=False, bayer=False, blue_gain=None, @@ -282,6 +290,8 @@ def load_psf( fp, verbose=verbose, flip=flip, + flip_ud=flip_ud, + flip_lr=flip_lr, bayer=bayer, blue_gain=blue_gain, red_gain=red_gain, diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index 565864f4..858d2c54 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -86,7 +86,7 @@ def benchmark_recon(config): raise ValueError(f"Dataset {dataset} not supported") print(f"Number of files : {len(benchmark_dataset)}") - print(f"Data shape : {dataset[0][0].shape}") + print(f"Data shape : {benchmark_dataset[0][0].shape}") model_list = [] # list of algoritms to benchmark if "ADMM" in config.algorithms: @@ -104,6 +104,48 @@ def benchmark_recon(config): ) if "ADMM_Monakhova2019" in config.algorithms: model_list.append(("ADMM_Monakhova2019", ADMM(psf, mu1=1e-4, mu2=1e-4, mu3=1e-4, tau=2e-3))) + if "ADMM_PnP_10" in config.algorithms: + model_list.append( + ( + "ADMM_PnP_10", + ADMM( + psf, + mu1=config.admm.mu1, + mu2=config.admm.mu2, + mu3=config.admm.mu3, + tau=config.admm.tau, + denoiser={"network": "DruNet", "noise_level": 10, "use_dual": False}, + ), + ) + ) + if "ADMM_PnP_25" in config.algorithms: + model_list.append( + ( + "ADMM_PnP_25", + ADMM( + psf, + mu1=config.admm.mu1, + mu2=config.admm.mu2, + mu3=config.admm.mu3, + tau=config.admm.tau, + denoiser={"network": "DruNet", "noise_level": 25, "use_dual": False}, + ), + ) + ) + if "ADMM_PnP_50" in config.algorithms: + model_list.append( + ( + "ADMM_PnP_50", + ADMM( + psf, + mu1=config.admm.mu1, + mu2=config.admm.mu2, + mu3=config.admm.mu3, + tau=config.admm.tau, + denoiser={"network": "DruNet", "noise_level": 50, "use_dual": False}, + ), + ) + ) if "FISTA" in config.algorithms: model_list.append(("FISTA", FISTA(psf, tk=config.fista.tk))) if "GradientDescent" in config.algorithms: @@ -310,8 +352,7 @@ def benchmark_recon(config): ) plt.xlabel("Number of iterations", fontsize="12") plt.ylabel(metric, fontsize="12") - if metric == "ReconstructionError": - plt.legend(fontsize="12") + plt.legend(fontsize="12") plt.grid() plt.savefig(f"{metric}.png") diff --git a/scripts/recon/admm.py b/scripts/recon/admm.py index 80657793..1fdd68ba 100644 --- a/scripts/recon/admm.py +++ b/scripts/recon/admm.py @@ -67,6 +67,7 @@ def admm(config): fig = plt.gcf() plt.close(fig) + # load model start_time = time.time() if not config.admm.unrolled: recon = ADMM(psf, **config.admm)