Skip to content

Commit

Permalink
Ensure backward compatability.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Sep 6, 2024
1 parent 47d4a75 commit 68cc8d5
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 83 deletions.
133 changes: 66 additions & 67 deletions lensless/recon/gd.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class GradientDescent(ReconstructionAlgorithm):
Object for applying projected gradient descent.
"""

def __init__(self, psf,mask, dtype=None, proj=non_neg, **kwargs):
def __init__(self, psf, dtype=None, proj=non_neg, **kwargs):
"""
Parameters
Expand All @@ -83,30 +83,30 @@ def __init__(self, psf,mask, dtype=None, proj=non_neg, **kwargs):

assert callable(proj)
self._proj = proj
super(GradientDescent, self).__init__(psf,mask, dtype, **kwargs)
super(GradientDescent, self).__init__(psf, dtype, **kwargs)

if self._denoiser is not None:
print("Using denoiser in gradient descent.")
# redefine projection function
self._proj = self._denoiser
self.mask=mask

def reset(self):
if self.is_torch:
if self._initial_est is not None:
self._image_est = self._initial_est
else:
# initial guess, half intensity image
# psf_flat = self._psf.reshape(-1, self._psf_shape[3])
# pixel_start = (
# torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values
# ) / 2
psf_flat = self._psf.reshape(-1, self._psf_shape[3])
pixel_start = (
torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values
) / 2
# initialize image estimate as [Batch, Depth, Height, Width, Channels]
self._image_est = torch.zeros((1,250,250,3)).to(self._psf.device)
self._image_est = torch.ones_like(self._psf[None, ...]) * pixel_start

# set step size as < 2 / lipschitz
Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3])
H_flat = self._convolver._H.reshape(-1, self._psf_shape[3])
self._alpha = 1/4770.13
self._alpha = torch.real(1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values)

else:
if self._initial_est is not None:
Expand All @@ -123,8 +123,8 @@ def reset(self):
self._alpha = np.real(1.8 / np.max(Hadj_flat * H_flat, axis=0))

def _grad(self):
diff = torch.sum(self.mask * self._convolver.convolve(self._image_est), axis=-1, keepdims=True) - self._data # (H, W, 1)
return self._convolver.deconvolve(diff * self.mask) # (H, W, C) where C is number of hyperspectral channels
diff = self._convolver.convolve(self._image_est) - self._data
return self._convolver.deconvolve(diff)

def _update(self, iter):
self._image_est -= self._alpha * self._grad()
Expand Down Expand Up @@ -238,76 +238,75 @@ def _update(self, iter):
self._xk = xk


def apply_gradient_descent(psf_fp, data_fp, n_iter, verbose=False, proj=non_neg, **kwargs):

# load data
psf, data = load_data(psf_fp=psf_fp, data_fp=data_fp, plot=False, **kwargs)

# create reconstruction object
recon = GradientDescent(psf, n_iter=n_iter, proj=proj)

# set data
recon.set_data(data)

# perform reconstruction
start_time = time.time()
res = recon.apply(plot=False)
proc_time = time.time() - start_time

if verbose:
print(f"Reconstruction time : {proc_time} s")
print(f"Reconstruction shape: {res.shape}")
return res
class HyperSpectralFISTA(GradientDescent):
class HyperSpectralFISTA(FISTA):
"""
Object for applying projected gradient descent with FISTA (Fast Iterative
Shrinkage-Thresholding Algorithm) for acceleration.
Paper: https://www.ceremade.dauphine.fr/~carlier/FISTA
Applying HyperSpectral FISTA as in: https://github.com/Waller-Lab/SpectralDiffuserCam
"""

def __init__(self, psf,mask, dtype=None, proj=non_neg, tk=1.0, **kwargs):
def __init__(self, psf, mask, **kwargs):
"""
Parameters
----------
psf : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor`
Point spread function (PSF) that models forward propagation.
Must be of shape (depth, height, width, channels) even if
depth = 1 and channels = 1. You can use :py:func:`~lensless.io.load_psf`
to load a PSF from a file such that it is in the correct format.
dtype : float32 or float64
Data type to use for optimization. Default is float32.
proj : :py:class:`function`
Projection function to apply at each iteration. Default is
non-negative.
tk : float
Initial step size parameter for FISTA. It is updated at each iteration
according to Eq. 4.2 of paper. By default, initialized to 1.0.
mask :
Hyperspectral mask
"""
self._initial_tk = tk
# same PSF for all hyperspectral channels
assert psf.shape[-1] == 1
assert mask.shape[-3:-1] == psf.shape[-3:-1]
self._mask = mask[None, ...] # adding batch dimension

super(HyperSpectralFISTA, self).__init__(psf,mask, dtype, proj, **kwargs)
super(HyperSpectralFISTA, self).__init__(psf, **kwargs)

self._tk = tk
self._xk = self._image_est
def reset(self):

# TODO set lipschitz constant correctly/differently?

if self.is_torch:
if self._initial_est is not None:
self._image_est = self._initial_est
else:
# initial guess, half intensity image
psf_flat = self._psf.reshape(-1, self._psf_shape[3])
pixel_start = (
torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values
) / 2
# initialize image estimate as [Batch, Depth, Height, Width, Channels]
self._image_est = torch.ones_like(self._mask) * pixel_start

# set step size as < 2 / lipschitz
Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3])
H_flat = self._convolver._H.reshape(-1, self._psf_shape[3])
self._alpha = torch.real(1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values)

def reset(self, tk=None):
super(HyperSpectralFISTA, self).reset()
if tk:
self._tk = tk
else:
self._tk = self._initial_tk
self._xk = self._image_est
def _update(self, iter):
self._image_est -= self._alpha * self._grad()
xk = self._form_image()
tk = (1 + np.sqrt(1 + 4 * self._tk**2)) / 2
self._image_est = xk + (self._tk - 1) / tk * (xk - self._xk)
self._tk = tk
self._xk = xk
if self._initial_est is not None:
self._image_est = self._initial_est
else:
psf_flat = self._psf.reshape(-1, self._psf_shape[3])
pixel_start = (np.max(psf_flat, axis=0) + np.min(psf_flat, axis=0)) / 2
# initialize image estimate as [Batch, Depth, Height, Width, Channels]
self._image_est = np.ones_like(self._mask) * pixel_start

# set step size as < 2 / lipschitz
Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3])
H_flat = self._convolver._H.reshape(-1, self._psf_shape[3])
self._alpha = np.real(1.8 / np.max(Hadj_flat * H_flat, axis=0))

# TODO how was his value determined?
self._alpha = 1 / 4770.13

def _grad(self):
# make sure to sum on correct axis, and apply mask on correct dimensions
diff = (
np.sum(self._mask * self._convolver.convolve(self._image_est), -1, keepdims=True)
- self._data
) # (B, D, H, W, 1)
return self._convolver.deconvolve(
diff * self._mask
) # (H, W, C) where C is number of hyperspectral channels


def apply_gradient_descent(psf_fp, data_fp, n_iter, verbose=False, proj=non_neg, **kwargs):
Expand Down
17 changes: 6 additions & 11 deletions lensless/recon/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ class ReconstructionAlgorithm(abc.ABC):
def __init__(
self,
psf,
mask,
dtype=None,
pad=True,
n_iter=100,
Expand Down Expand Up @@ -370,13 +369,11 @@ def set_data(self, data):
assert len(data.shape) >= 3, "Data must be at least 3D: [..., width, height, channel]."

# assert same shapes
# assert np.all(
# self._psf_shape[-3:-1] == np.array(data.shape)[-3:-1]
# ), "PSF and data shape mismatch"
if len(data.shape)==3:
self._data = data.unsqueeze(-1)
# if len(data.shape) == 3:
# self._data = data[None, None, ...]
assert np.all(
self._psf_shape[-3:-1] == np.array(data.shape)[-3:-1]
), "PSF and data shape mismatch"
if len(data.shape) == 3:
self._data = data[None, None, ...]
elif len(data.shape) == 4:
self._data = data[None, ...]
else:
Expand Down Expand Up @@ -571,9 +568,7 @@ def apply(

for i in range(n_iter):
self._update(i)
if i%50==0:
img = self._form_image()


if self.compensation_branch is not None and i < self._n_iter - 1:
self.compensation_branch_inputs.append(self._form_image())

Expand Down
16 changes: 11 additions & 5 deletions lensless/recon/rfft_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,24 @@ def _crop(self, x):
]

def _pad(self, v):

shape = self._padded_shape.copy()
if v.shape[-1] != self._padded_shape[-1]:
# different number of channels in PSF and data
assert v.shape[-1] == 1 or self._padded_shape[-1] == 1
shape[-1] = v.shape[-1]

if len(v.shape) == 5:
batch_size = v.shape[0]
shape = [batch_size] + self._padded_shape
elif len(v.shape) == 4:
shape = self._padded_shape
else:
shape = [batch_size] + shape
elif len(v.shape) != 4:
raise ValueError("Expected 4D or 5D tensor")

if self.is_torch:
vpad = torch.zeros(size=shape, dtype=v.dtype, device=v.device)
else:
vpad = np.zeros(shape).astype(v.dtype)

vpad[
..., self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1], :
] = v
Expand Down Expand Up @@ -135,7 +141,7 @@ def convolve(self, x):
Convolve with pre-computed FFT of provided PSF.
"""
if self.pad:
self._padded_data = self._pad(x).to(self._psf.device)
self._padded_data = self._pad(x)
else:
if self.is_torch:
self._padded_data = x
Expand Down
Loading

0 comments on commit 68cc8d5

Please sign in to comment.