Skip to content

Commit

Permalink
Merge pull request #1 from nkraicer/noa
Browse files Browse the repository at this point in the history
final
  • Loading branch information
noakraicer authored Aug 29, 2024
2 parents 96b292f + 5999d7c commit 47d4a75
Show file tree
Hide file tree
Showing 9 changed files with 245 additions and 31 deletions.
114 changes: 114 additions & 0 deletions configs/benchmark_hyperspectral.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# python scripts/eval/benchmark_recon.py
#Hydra config
hydra:
run:
dir: "benchmark/${now:%Y-%m-%d}/${now:%H-%M-%S}"
job:
chdir: True


dataset: PolarLitis # DiffuserCam, DigiCamCelebA, HFDataset
seed: 0
batchsize: 1 # must be 1 for iterative approaches

huggingface:
repo: "noakraicer/polarlitis"
cache_dir: null # where to read/write dataset. Defaults to `"~/.cache/huggingface/datasets"`.
psf: psf.mat
mask: mask.npy # null for simulating PSF
image_res: [250, 250] # used during measurement
rotate: False # if measurement is upside-down
flipud: False
flip_lensed: False # if rotate or flipud is True, apply to lensed

alignment:
top_left: null
height: null

downsample: 1
downsample_lensed: 2
split_seed: null
single_channel_psf: True

device: "cuda"
# numbers of iterations to benchmark
n_iter_range: [2000]
# number of files to benchmark
n_files: null # null for all files
#How much should the image be downsampled
downsample: 2
#algorithm to benchmark
algorithms: ["HyperSpectralFISTA"] #["ADMM", "ADMM_Monakhova2019", "FISTA", "GradientDescent", "NesterovGradientDescent"]

# baseline from Monakhova et al. 2019, https://arxiv.org/abs/1908.11502
baseline: "MONAKHOVA 100iter"

save_idx: [0, 1, 2, 3, 4] # provide index of files to save e.g. [1, 5, 10]
gamma_psf: 1.5 # gamma factor for PSF


# Hyperparameters
nesterov:
p: 0
mu: 0.9
fista:
tk: 1
admm:
mu1: 1e-6
mu2: 1e-5
mu3: 4e-5
tau: 0.0001


# for DigiCamCelebA
files:
test_size: 0.15
downsample: 1
celeba_root: /scratch/bezzam


# dataset: /scratch/bezzam/celeba_adafruit_random_2mm_20230720_10K
# psf: data/psf/adafruit_random_2mm_20231907.png
# vertical_shift: null
# horizontal_shift: null
# crop: null

dataset: /scratch/bezzam/celeba/celeba_adafruit_random_30cm_2mm_20231004_26K
psf: rpi_hq_adafruit_psf_2mm/raw_data_rgb.png
vertical_shift: -117
horizontal_shift: -25
crop:
vertical: [0, 525]
horizontal: [265, 695]

# for prepping ground truth data
#for simulated dataset
simulation:
grayscale: False
output_dim: null # should be set if no PSF is used
# random variations
object_height: 0.33 # [m], range for random height or scalar
flip: True # change the orientation of the object (from vertical to horizontal)
random_shift: False
random_vflip: 0.5
random_hflip: 0.5
random_rotate: False
# these distance parameters are typically fixed for a given PSF
# for DiffuserCam psf # for tape_rgb psf
# scene2mask: 10e-2 # scene2mask: 40e-2
# mask2sensor: 9e-3 # mask2sensor: 4e-3
# -- for CelebA
scene2mask: 0.25 # [m]
mask2sensor: 0.002 # [m]
deadspace: True # whether to account for deadspace for programmable mask
# see waveprop.devices
use_waveprop: False # for PSF simulation
sensor: "rpi_hq"
snr_db: 10
# simulate different sensor resolution
# output_dim: [24, 32] # [H, W] or null
# Downsampling for PSF
downsample: 8
# max val in simulated measured (quantized 8 bits)
quantize: False # must be False for differentiability
max_val: 255
1 change: 1 addition & 0 deletions lensless/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
NesterovGradientDescent,
FISTA,
GradientDescentUpdate,
HyperSpectralFISTA
)
from .recon.tikhonov import CodedApertureReconstruction
from .hardware.sensor import VirtualSensor, SensorOptions
Expand Down
94 changes: 83 additions & 11 deletions lensless/recon/gd.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class GradientDescent(ReconstructionAlgorithm):
Object for applying projected gradient descent.
"""

def __init__(self, psf, dtype=None, proj=non_neg, **kwargs):
def __init__(self, psf,mask, dtype=None, proj=non_neg, **kwargs):
"""
Parameters
Expand All @@ -83,30 +83,30 @@ def __init__(self, psf, dtype=None, proj=non_neg, **kwargs):

assert callable(proj)
self._proj = proj
super(GradientDescent, self).__init__(psf, dtype, **kwargs)
super(GradientDescent, self).__init__(psf,mask, dtype, **kwargs)

if self._denoiser is not None:
print("Using denoiser in gradient descent.")
# redefine projection function
self._proj = self._denoiser

self.mask=mask
def reset(self):
if self.is_torch:
if self._initial_est is not None:
self._image_est = self._initial_est
else:
# initial guess, half intensity image
psf_flat = self._psf.reshape(-1, self._psf_shape[3])
pixel_start = (
torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values
) / 2
# psf_flat = self._psf.reshape(-1, self._psf_shape[3])
# pixel_start = (
# torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values
# ) / 2
# initialize image estimate as [Batch, Depth, Height, Width, Channels]
self._image_est = torch.ones_like(self._psf[None, ...]) * pixel_start
self._image_est = torch.zeros((1,250,250,3)).to(self._psf.device)

# set step size as < 2 / lipschitz
Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3])
H_flat = self._convolver._H.reshape(-1, self._psf_shape[3])
self._alpha = torch.real(1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values)
self._alpha = 1/4770.13

else:
if self._initial_est is not None:
Expand All @@ -123,8 +123,8 @@ def reset(self):
self._alpha = np.real(1.8 / np.max(Hadj_flat * H_flat, axis=0))

def _grad(self):
diff = self._convolver.convolve(self._image_est) - self._data
return self._convolver.deconvolve(diff)
diff = torch.sum(self.mask * self._convolver.convolve(self._image_est), axis=-1, keepdims=True) - self._data # (H, W, 1)
return self._convolver.deconvolve(diff * self.mask) # (H, W, C) where C is number of hyperspectral channels

def _update(self, iter):
self._image_est -= self._alpha * self._grad()
Expand Down Expand Up @@ -238,6 +238,78 @@ def _update(self, iter):
self._xk = xk


def apply_gradient_descent(psf_fp, data_fp, n_iter, verbose=False, proj=non_neg, **kwargs):

# load data
psf, data = load_data(psf_fp=psf_fp, data_fp=data_fp, plot=False, **kwargs)

# create reconstruction object
recon = GradientDescent(psf, n_iter=n_iter, proj=proj)

# set data
recon.set_data(data)

# perform reconstruction
start_time = time.time()
res = recon.apply(plot=False)
proc_time = time.time() - start_time

if verbose:
print(f"Reconstruction time : {proc_time} s")
print(f"Reconstruction shape: {res.shape}")
return res
class HyperSpectralFISTA(GradientDescent):
"""
Object for applying projected gradient descent with FISTA (Fast Iterative
Shrinkage-Thresholding Algorithm) for acceleration.
Paper: https://www.ceremade.dauphine.fr/~carlier/FISTA
"""

def __init__(self, psf,mask, dtype=None, proj=non_neg, tk=1.0, **kwargs):
"""
Parameters
----------
psf : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor`
Point spread function (PSF) that models forward propagation.
Must be of shape (depth, height, width, channels) even if
depth = 1 and channels = 1. You can use :py:func:`~lensless.io.load_psf`
to load a PSF from a file such that it is in the correct format.
dtype : float32 or float64
Data type to use for optimization. Default is float32.
proj : :py:class:`function`
Projection function to apply at each iteration. Default is
non-negative.
tk : float
Initial step size parameter for FISTA. It is updated at each iteration
according to Eq. 4.2 of paper. By default, initialized to 1.0.
"""
self._initial_tk = tk

super(HyperSpectralFISTA, self).__init__(psf,mask, dtype, proj, **kwargs)

self._tk = tk
self._xk = self._image_est

def reset(self, tk=None):
super(HyperSpectralFISTA, self).reset()
if tk:
self._tk = tk
else:
self._tk = self._initial_tk
self._xk = self._image_est
def _update(self, iter):
self._image_est -= self._alpha * self._grad()
xk = self._form_image()
tk = (1 + np.sqrt(1 + 4 * self._tk**2)) / 2
self._image_est = xk + (self._tk - 1) / tk * (xk - self._xk)
self._tk = tk
self._xk = xk


def apply_gradient_descent(psf_fp, data_fp, n_iter, verbose=False, proj=non_neg, **kwargs):

# load data
Expand Down
17 changes: 11 additions & 6 deletions lensless/recon/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ class ReconstructionAlgorithm(abc.ABC):
def __init__(
self,
psf,
mask,
dtype=None,
pad=True,
n_iter=100,
Expand Down Expand Up @@ -369,12 +370,13 @@ def set_data(self, data):
assert len(data.shape) >= 3, "Data must be at least 3D: [..., width, height, channel]."

# assert same shapes
assert np.all(
self._psf_shape[-3:-1] == np.array(data.shape)[-3:-1]
), "PSF and data shape mismatch"

if len(data.shape) == 3:
self._data = data[None, None, ...]
# assert np.all(
# self._psf_shape[-3:-1] == np.array(data.shape)[-3:-1]
# ), "PSF and data shape mismatch"
if len(data.shape)==3:
self._data = data.unsqueeze(-1)
# if len(data.shape) == 3:
# self._data = data[None, None, ...]
elif len(data.shape) == 4:
self._data = data[None, ...]
else:
Expand Down Expand Up @@ -569,6 +571,9 @@ def apply(

for i in range(n_iter):
self._update(i)
if i%50==0:
img = self._form_image()

if self.compensation_branch is not None and i < self._n_iter - 1:
self.compensation_branch_inputs.append(self._form_image())

Expand Down
6 changes: 3 additions & 3 deletions lensless/recon/rfft_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


class RealFFTConvolve2D:
def __init__(self, psf, dtype=None, pad=True, norm="ortho", rgb=None, **kwargs):
def __init__(self, psf, dtype=None, pad=True, norm=None, rgb=None, **kwargs):
"""
Linear operator that performs convolution in Fourier domain, and assumes
real-valued signals.
Expand Down Expand Up @@ -135,10 +135,10 @@ def convolve(self, x):
Convolve with pre-computed FFT of provided PSF.
"""
if self.pad:
self._padded_data = self._pad(x)
self._padded_data = self._pad(x).to(self._psf.device)
else:
if self.is_torch:
self._padded_data = x # .type(self.dtype).to(self._psf.device)
self._padded_data = x
else:
self._padded_data[:] = x # .astype(self.dtype)

Expand Down
13 changes: 10 additions & 3 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torchvision.transforms import functional as F
from lensless.hardware.trainable_mask import prep_trainable_mask, AdafruitLCD
from lensless.utils.simulation import FarFieldSimulator
from lensless.utils.io import load_image, load_psf, save_image
from lensless.utils.io import load_image, load_psf, save_image,load_mask
from lensless.utils.image import is_grayscale, resize, rgb2gray
import re
from lensless.hardware.utils import capture
Expand Down Expand Up @@ -1271,6 +1271,7 @@ def __init__(
split,
n_files=None,
psf=None,
mask=None,
rotate=False, # just the lensless image
flipud=False,
flip_lensed=False,
Expand Down Expand Up @@ -1409,11 +1410,11 @@ def __init__(
if psf is not None:
# download PSF from huggingface
psf_fp = hf_hub_download(repo_id=huggingface_repo, filename=psf, repo_type="dataset")
psf, _ = load_psf(
psf = load_psf(
psf_fp,
shape=lensless.shape,
return_float=True,
return_bg=True,
return_bg=False,
flip=self.rotate,
flip_ud=flipud,
bg_pix=(0, 15),
Expand All @@ -1424,6 +1425,10 @@ def __init__(
if single_channel_psf:
# replicate across three channels
self.psf = self.psf.repeat(1, 1, 1, 3)
if mask is not None:
mask_fp = hf_hub_download(repo_id=huggingface_repo, filename=mask, repo_type="dataset")
mask = load_mask(mask_fp)
self.mask= torch.from_numpy(mask)

elif "mask_label" in data_0:
self.multimask = True
Expand Down Expand Up @@ -1563,7 +1568,9 @@ def _get_images_pair(self, idx):
# convert to float
if lensless_np.dtype == np.uint8:
lensless_np = lensless_np.astype(np.float32) / 255
lensless_np = lensless_np / np.max(lensless_np)
lensed_np = lensed_np.astype(np.float32) / 255
lensed_np = lensed_np / np.max(lensed_np)
else:
# 16 bit
lensless_np = lensless_np.astype(np.float32) / 65535
Expand Down
2 changes: 1 addition & 1 deletion lensless/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def get_max_val(img, nbits=None):
max_val : int
Maximum pixel value.
"""
assert img.dtype not in FLOAT_DTYPES
# assert img.dtype not in FLOAT_DTYPES
if nbits is None:
nbits = int(np.ceil(np.log2(img.max())))

Expand Down
Loading

0 comments on commit 47d4a75

Please sign in to comment.