diff --git a/configs/benchmark_hyperspectral.yaml b/configs/benchmark_hyperspectral.yaml new file mode 100644 index 00000000..53671cf4 --- /dev/null +++ b/configs/benchmark_hyperspectral.yaml @@ -0,0 +1,114 @@ +# python scripts/eval/benchmark_recon.py +#Hydra config +hydra: + run: + dir: "benchmark/${now:%Y-%m-%d}/${now:%H-%M-%S}" + job: + chdir: True + + +dataset: PolarLitis # DiffuserCam, DigiCamCelebA, HFDataset +seed: 0 +batchsize: 1 # must be 1 for iterative approaches + +huggingface: + repo: "noakraicer/polarlitis" + cache_dir: null # where to read/write dataset. Defaults to `"~/.cache/huggingface/datasets"`. + psf: psf.mat + mask: mask.npy # null for simulating PSF + image_res: [250, 250] # used during measurement + rotate: False # if measurement is upside-down + flipud: False + flip_lensed: False # if rotate or flipud is True, apply to lensed + + alignment: + top_left: null + height: null + + downsample: 1 + downsample_lensed: 2 + split_seed: null + single_channel_psf: True + +device: "cuda" +# numbers of iterations to benchmark +n_iter_range: [2000] +# number of files to benchmark +n_files: null # null for all files +#How much should the image be downsampled +downsample: 2 +#algorithm to benchmark +algorithms: ["HyperSpectralFISTA"] #["ADMM", "ADMM_Monakhova2019", "FISTA", "GradientDescent", "NesterovGradientDescent"] + +# baseline from Monakhova et al. 2019, https://arxiv.org/abs/1908.11502 +baseline: "MONAKHOVA 100iter" + +save_idx: [0, 1, 2, 3, 4] # provide index of files to save e.g. [1, 5, 10] +gamma_psf: 1.5 # gamma factor for PSF + + +# Hyperparameters +nesterov: + p: 0 + mu: 0.9 +fista: + tk: 1 +admm: + mu1: 1e-6 + mu2: 1e-5 + mu3: 4e-5 + tau: 0.0001 + + +# for DigiCamCelebA +files: + test_size: 0.15 + downsample: 1 + celeba_root: /scratch/bezzam + + + # dataset: /scratch/bezzam/celeba_adafruit_random_2mm_20230720_10K + # psf: data/psf/adafruit_random_2mm_20231907.png + # vertical_shift: null + # horizontal_shift: null + # crop: null + + dataset: /scratch/bezzam/celeba/celeba_adafruit_random_30cm_2mm_20231004_26K + psf: rpi_hq_adafruit_psf_2mm/raw_data_rgb.png + vertical_shift: -117 + horizontal_shift: -25 + crop: + vertical: [0, 525] + horizontal: [265, 695] + +# for prepping ground truth data +#for simulated dataset +simulation: + grayscale: False + output_dim: null # should be set if no PSF is used + # random variations + object_height: 0.33 # [m], range for random height or scalar + flip: True # change the orientation of the object (from vertical to horizontal) + random_shift: False + random_vflip: 0.5 + random_hflip: 0.5 + random_rotate: False + # these distance parameters are typically fixed for a given PSF + # for DiffuserCam psf # for tape_rgb psf + # scene2mask: 10e-2 # scene2mask: 40e-2 + # mask2sensor: 9e-3 # mask2sensor: 4e-3 + # -- for CelebA + scene2mask: 0.25 # [m] + mask2sensor: 0.002 # [m] + deadspace: True # whether to account for deadspace for programmable mask + # see waveprop.devices + use_waveprop: False # for PSF simulation + sensor: "rpi_hq" + snr_db: 10 + # simulate different sensor resolution + # output_dim: [24, 32] # [H, W] or null + # Downsampling for PSF + downsample: 8 + # max val in simulated measured (quantized 8 bits) + quantize: False # must be False for differentiability + max_val: 255 diff --git a/lensless/__init__.py b/lensless/__init__.py index 70990774..4d67f179 100644 --- a/lensless/__init__.py +++ b/lensless/__init__.py @@ -20,6 +20,7 @@ NesterovGradientDescent, FISTA, GradientDescentUpdate, + HyperSpectralFISTA ) from .recon.tikhonov import CodedApertureReconstruction from .hardware.sensor import VirtualSensor, SensorOptions diff --git a/lensless/recon/gd.py b/lensless/recon/gd.py index dc61e809..c5af193c 100644 --- a/lensless/recon/gd.py +++ b/lensless/recon/gd.py @@ -64,7 +64,7 @@ class GradientDescent(ReconstructionAlgorithm): Object for applying projected gradient descent. """ - def __init__(self, psf, dtype=None, proj=non_neg, **kwargs): + def __init__(self, psf,mask, dtype=None, proj=non_neg, **kwargs): """ Parameters @@ -83,30 +83,30 @@ def __init__(self, psf, dtype=None, proj=non_neg, **kwargs): assert callable(proj) self._proj = proj - super(GradientDescent, self).__init__(psf, dtype, **kwargs) + super(GradientDescent, self).__init__(psf,mask, dtype, **kwargs) if self._denoiser is not None: print("Using denoiser in gradient descent.") # redefine projection function self._proj = self._denoiser - + self.mask=mask def reset(self): if self.is_torch: if self._initial_est is not None: self._image_est = self._initial_est else: # initial guess, half intensity image - psf_flat = self._psf.reshape(-1, self._psf_shape[3]) - pixel_start = ( - torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values - ) / 2 + # psf_flat = self._psf.reshape(-1, self._psf_shape[3]) + # pixel_start = ( + # torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values + # ) / 2 # initialize image estimate as [Batch, Depth, Height, Width, Channels] - self._image_est = torch.ones_like(self._psf[None, ...]) * pixel_start + self._image_est = torch.zeros((1,250,250,3)).to(self._psf.device) # set step size as < 2 / lipschitz Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3]) H_flat = self._convolver._H.reshape(-1, self._psf_shape[3]) - self._alpha = torch.real(1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values) + self._alpha = 1/4770.13 else: if self._initial_est is not None: @@ -123,8 +123,8 @@ def reset(self): self._alpha = np.real(1.8 / np.max(Hadj_flat * H_flat, axis=0)) def _grad(self): - diff = self._convolver.convolve(self._image_est) - self._data - return self._convolver.deconvolve(diff) + diff = torch.sum(self.mask * self._convolver.convolve(self._image_est), axis=-1, keepdims=True) - self._data # (H, W, 1) + return self._convolver.deconvolve(diff * self.mask) # (H, W, C) where C is number of hyperspectral channels def _update(self, iter): self._image_est -= self._alpha * self._grad() @@ -238,6 +238,78 @@ def _update(self, iter): self._xk = xk +def apply_gradient_descent(psf_fp, data_fp, n_iter, verbose=False, proj=non_neg, **kwargs): + + # load data + psf, data = load_data(psf_fp=psf_fp, data_fp=data_fp, plot=False, **kwargs) + + # create reconstruction object + recon = GradientDescent(psf, n_iter=n_iter, proj=proj) + + # set data + recon.set_data(data) + + # perform reconstruction + start_time = time.time() + res = recon.apply(plot=False) + proc_time = time.time() - start_time + + if verbose: + print(f"Reconstruction time : {proc_time} s") + print(f"Reconstruction shape: {res.shape}") + return res +class HyperSpectralFISTA(GradientDescent): + """ + Object for applying projected gradient descent with FISTA (Fast Iterative + Shrinkage-Thresholding Algorithm) for acceleration. + + Paper: https://www.ceremade.dauphine.fr/~carlier/FISTA + + """ + + def __init__(self, psf,mask, dtype=None, proj=non_neg, tk=1.0, **kwargs): + """ + + Parameters + ---------- + psf : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` + Point spread function (PSF) that models forward propagation. + Must be of shape (depth, height, width, channels) even if + depth = 1 and channels = 1. You can use :py:func:`~lensless.io.load_psf` + to load a PSF from a file such that it is in the correct format. + dtype : float32 or float64 + Data type to use for optimization. Default is float32. + proj : :py:class:`function` + Projection function to apply at each iteration. Default is + non-negative. + tk : float + Initial step size parameter for FISTA. It is updated at each iteration + according to Eq. 4.2 of paper. By default, initialized to 1.0. + + """ + self._initial_tk = tk + + super(HyperSpectralFISTA, self).__init__(psf,mask, dtype, proj, **kwargs) + + self._tk = tk + self._xk = self._image_est + + def reset(self, tk=None): + super(HyperSpectralFISTA, self).reset() + if tk: + self._tk = tk + else: + self._tk = self._initial_tk + self._xk = self._image_est + def _update(self, iter): + self._image_est -= self._alpha * self._grad() + xk = self._form_image() + tk = (1 + np.sqrt(1 + 4 * self._tk**2)) / 2 + self._image_est = xk + (self._tk - 1) / tk * (xk - self._xk) + self._tk = tk + self._xk = xk + + def apply_gradient_descent(psf_fp, data_fp, n_iter, verbose=False, proj=non_neg, **kwargs): # load data diff --git a/lensless/recon/recon.py b/lensless/recon/recon.py index ff1fc55c..b17b165f 100644 --- a/lensless/recon/recon.py +++ b/lensless/recon/recon.py @@ -203,6 +203,7 @@ class ReconstructionAlgorithm(abc.ABC): def __init__( self, psf, + mask, dtype=None, pad=True, n_iter=100, @@ -369,12 +370,13 @@ def set_data(self, data): assert len(data.shape) >= 3, "Data must be at least 3D: [..., width, height, channel]." # assert same shapes - assert np.all( - self._psf_shape[-3:-1] == np.array(data.shape)[-3:-1] - ), "PSF and data shape mismatch" - - if len(data.shape) == 3: - self._data = data[None, None, ...] + # assert np.all( + # self._psf_shape[-3:-1] == np.array(data.shape)[-3:-1] + # ), "PSF and data shape mismatch" + if len(data.shape)==3: + self._data = data.unsqueeze(-1) + # if len(data.shape) == 3: + # self._data = data[None, None, ...] elif len(data.shape) == 4: self._data = data[None, ...] else: @@ -569,6 +571,9 @@ def apply( for i in range(n_iter): self._update(i) + if i%50==0: + img = self._form_image() + if self.compensation_branch is not None and i < self._n_iter - 1: self.compensation_branch_inputs.append(self._form_image()) diff --git a/lensless/recon/rfft_convolve.py b/lensless/recon/rfft_convolve.py index e7b9be74..a4848e45 100644 --- a/lensless/recon/rfft_convolve.py +++ b/lensless/recon/rfft_convolve.py @@ -24,7 +24,7 @@ class RealFFTConvolve2D: - def __init__(self, psf, dtype=None, pad=True, norm="ortho", rgb=None, **kwargs): + def __init__(self, psf, dtype=None, pad=True, norm=None, rgb=None, **kwargs): """ Linear operator that performs convolution in Fourier domain, and assumes real-valued signals. @@ -135,10 +135,10 @@ def convolve(self, x): Convolve with pre-computed FFT of provided PSF. """ if self.pad: - self._padded_data = self._pad(x) + self._padded_data = self._pad(x).to(self._psf.device) else: if self.is_torch: - self._padded_data = x # .type(self.dtype).to(self._psf.device) + self._padded_data = x else: self._padded_data[:] = x # .astype(self.dtype) diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 5ce95f7a..d58eca05 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -17,7 +17,7 @@ from torchvision.transforms import functional as F from lensless.hardware.trainable_mask import prep_trainable_mask, AdafruitLCD from lensless.utils.simulation import FarFieldSimulator -from lensless.utils.io import load_image, load_psf, save_image +from lensless.utils.io import load_image, load_psf, save_image,load_mask from lensless.utils.image import is_grayscale, resize, rgb2gray import re from lensless.hardware.utils import capture @@ -1271,6 +1271,7 @@ def __init__( split, n_files=None, psf=None, + mask=None, rotate=False, # just the lensless image flipud=False, flip_lensed=False, @@ -1409,11 +1410,11 @@ def __init__( if psf is not None: # download PSF from huggingface psf_fp = hf_hub_download(repo_id=huggingface_repo, filename=psf, repo_type="dataset") - psf, _ = load_psf( + psf = load_psf( psf_fp, shape=lensless.shape, return_float=True, - return_bg=True, + return_bg=False, flip=self.rotate, flip_ud=flipud, bg_pix=(0, 15), @@ -1424,6 +1425,10 @@ def __init__( if single_channel_psf: # replicate across three channels self.psf = self.psf.repeat(1, 1, 1, 3) + if mask is not None: + mask_fp = hf_hub_download(repo_id=huggingface_repo, filename=mask, repo_type="dataset") + mask = load_mask(mask_fp) + self.mask= torch.from_numpy(mask) elif "mask_label" in data_0: self.multimask = True @@ -1563,7 +1568,9 @@ def _get_images_pair(self, idx): # convert to float if lensless_np.dtype == np.uint8: lensless_np = lensless_np.astype(np.float32) / 255 + lensless_np = lensless_np / np.max(lensless_np) lensed_np = lensed_np.astype(np.float32) / 255 + lensed_np = lensed_np / np.max(lensed_np) else: # 16 bit lensless_np = lensless_np.astype(np.float32) / 65535 diff --git a/lensless/utils/image.py b/lensless/utils/image.py index eed00121..e3ebe0c2 100644 --- a/lensless/utils/image.py +++ b/lensless/utils/image.py @@ -264,7 +264,7 @@ def get_max_val(img, nbits=None): max_val : int Maximum pixel value. """ - assert img.dtype not in FLOAT_DTYPES + # assert img.dtype not in FLOAT_DTYPES if nbits is None: nbits = int(np.ceil(np.log2(img.max()))) diff --git a/lensless/utils/io.py b/lensless/utils/io.py index 5596befd..1d5e4d1a 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -8,7 +8,7 @@ import os.path import warnings - +import scipy import cv2 import numpy as np from PIL import Image @@ -17,6 +17,10 @@ from lensless.utils.image import bayer2rgb_cc, print_image_info, resize, rgb2gray, get_max_val from lensless.utils.plot import plot_image +def load_mask(fp): + mask = np.load(fp) + return np.expand_dims(mask, axis=0) + def load_image( fp, @@ -121,6 +125,9 @@ def load_image( black_level = np.array(raw.black_level_per_channel[:3]).astype(np.float32) elif "npy" in fp or "npz" in fp: img = np.load(fp) + elif "mat" in fp: + mat = scipy.io.loadmat(fp) + img = mat['psf'][:,:,0] else: img = cv2.imread(fp, cv2.IMREAD_UNCHANGED) @@ -202,7 +209,10 @@ def load_image( else: if dtype is None: dtype = original_dtype - img = img.astype(dtype) + img = img.astype(np.float64) + + img = img[10:260, 35:320-35] + img = img / np.linalg.norm(img) return img @@ -380,7 +390,7 @@ def load_psf( if return_bg: return psf, bg else: - return psf + return psf.astype(np.float64) def load_data( diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index 76fbc367..b7571698 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -25,7 +25,7 @@ import pathlib as plib from lensless.eval.benchmark import benchmark import matplotlib.pyplot as plt -from lensless import ADMM, FISTA, GradientDescent, NesterovGradientDescent +from lensless import ADMM, FISTA, GradientDescent, NesterovGradientDescent,HyperSpectralFISTA from lensless.utils.dataset import DiffuserCamTestDataset, DigiCamCelebA, HFDataset from lensless.utils.io import save_image from lensless.utils.image import gamma_correction @@ -35,7 +35,7 @@ from torch.utils.data import Subset -@hydra.main(version_base=None, config_path="../../configs", config_name="benchmark") +@hydra.main(version_base=None, config_path="../../configs", config_name="benchmark_hyperspectral") def benchmark_recon(config): # set seed @@ -86,7 +86,7 @@ def benchmark_recon(config): _, benchmark_dataset = torch.utils.data.random_split( dataset, [train_size, test_size], generator=generator ) - elif dataset == "HFDataset": + elif dataset == "PolarLitis": split_test = "test" if config.huggingface.split_seed is not None: @@ -120,6 +120,7 @@ def benchmark_recon(config): huggingface_repo=config.huggingface.repo, cache_dir=config.huggingface.cache_dir, psf=config.huggingface.psf, + mask = config.huggingface.mask, n_files=n_files, split=split_test, display_res=config.huggingface.image_res, @@ -138,6 +139,8 @@ def benchmark_recon(config): psf = benchmark_dataset.psf[first_psf_key].to(device) else: psf = benchmark_dataset.psf.to(device) + mask = benchmark_dataset.mask.to(device) + else: raise ValueError(f"Dataset {dataset} not supported") @@ -190,6 +193,8 @@ def benchmark_recon(config): ) if algo == "FISTA": model_list.append(("FISTA", FISTA(psf, tk=config.fista.tk))) + if algo == "HyperSpectralFISTA": + model_list.append(("HyperSpectralFISTA", HyperSpectralFISTA(psf,mask, tk=config.fista.tk))) if algo == "GradientDescent": model_list.append(("GradientDescent", GradientDescent(psf))) if algo == "NesterovGradientDescent": @@ -243,7 +248,7 @@ def benchmark_recon(config): :2 ] # take first two in case multimask dataset ground_truth_np = ground_truth.cpu().numpy()[0] - lensless_np = lensless.cpu().numpy()[0] + lensless_np = lensless.cpu().numpy() if crop is not None: ground_truth_np = ground_truth_np[