Skip to content

Commit

Permalink
IRFFT fix (#126)
Browse files Browse the repository at this point in the history
* Add output shape to irfft.

* Expose fixed parameter.

* Update CHANGELOG.

* Fix assertion.

* Fix typo.
  • Loading branch information
ebezzam authored Apr 23, 2024
1 parent 718cadd commit 95aa24e
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Bugfix
- Local path for DRUNet download.
- APGD input handling (float32).
- Multimask handling.
- Passing shape to IRFFT so that it matches shape of input to RFFT.

1.0.6 - (2024-02-21)
--------------------
Expand Down
8 changes: 6 additions & 2 deletions lensless/recon/admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,14 @@ def _image_update(self):

if self.is_torch:
freq_space_result = self._R_divmat * torch.fft.rfft2(rk, dim=(-3, -2))
self._image_est = torch.fft.irfft2(freq_space_result, dim=(-3, -2))
self._image_est = torch.fft.irfft2(
freq_space_result, dim=(-3, -2), s=self._convolver._padded_shape[-3:-1]
)
else:
freq_space_result = self._R_divmat * fft.rfft2(rk, axes=(-3, -2))
self._image_est = fft.irfft2(freq_space_result, axes=(-3, -2))
self._image_est = fft.irfft2(
freq_space_result, axes=(-3, -2), s=self._convolver._padded_shape[-3:-1]
)

# self._image_est = self._convolver._crop(res)

Expand Down
37 changes: 27 additions & 10 deletions lensless/recon/rfft_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,20 +139,28 @@ def convolve(self, x):
if self.is_torch:
conv_output = torch.fft.ifftshift(
torch.fft.irfft2(
torch.fft.rfft2(self._padded_data, dim=(-3, -2)) * self._H, dim=(-3, -2)
torch.fft.rfft2(self._padded_data, dim=(-3, -2)) * self._H,
dim=(-3, -2),
s=self._padded_shape[-3:-1],
),
dim=(-3, -2),
)

else:
conv_output = fft.ifftshift(
fft.irfft2(fft.rfft2(self._padded_data, axes=(-3, -2)) * self._H, axes=(-3, -2)),
fft.irfft2(
fft.rfft2(self._padded_data, axes=(-3, -2)) * self._H,
axes=(-3, -2),
s=self._padded_shape[-3:-1],
),
axes=(-3, -2),
)
if self.pad:
return self._crop(conv_output)
else:
return conv_output
conv_output = self._crop(conv_output)

# ensure shape stays the same
assert conv_output.shape[-3:-1] == x.shape[-3:-1]
return conv_output

def deconvolve(self, y):
"""
Expand All @@ -165,21 +173,30 @@ def deconvolve(self, y):
self._padded_data = y # .type(self.dtype).to(self._psf.device)
else:
self._padded_data[:] = y # .astype(self.dtype)

if self.is_torch:
deconv_output = torch.fft.ifftshift(
torch.fft.irfft2(
torch.fft.rfft2(self._padded_data, dim=(-3, -2)) * self._Hadj, dim=(-3, -2)
torch.fft.rfft2(self._padded_data, dim=(-3, -2)) * self._Hadj,
dim=(-3, -2),
s=self._padded_shape[-3:-1],
),
dim=(-3, -2),
)

else:
deconv_output = fft.ifftshift(
fft.irfft2(fft.rfft2(self._padded_data, axes=(-3, -2)) * self._Hadj, axes=(-3, -2)),
fft.irfft2(
fft.rfft2(self._padded_data, axes=(-3, -2)) * self._Hadj,
axes=(-3, -2),
s=self._padded_shape[-3:-1],
),
axes=(-3, -2),
)

if self.pad:
return self._crop(deconv_output)
else:
return deconv_output
deconv_output = self._crop(deconv_output)

# ensure shape stays the same
assert deconv_output.shape[-3:-1] == y.shape[-3:-1]
return deconv_output
4 changes: 3 additions & 1 deletion lensless/recon/unrolled_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,9 @@ def _image_update(self, iter):
+ self._convolver.deconvolve(self._mu1[iter] * self._X - self._xi)
)
freq_space_result = self._R_divmat[iter] * torch.fft.rfft2(rk, dim=(-3, -2))
self._image_est = torch.fft.irfft2(freq_space_result, dim=(-3, -2))
self._image_est = torch.fft.irfft2(
freq_space_result, dim=(-3, -2), s=self._convolver._padded_shape[-3:-1]
)

def _W_update(self, iter):
"""Non-negativity update"""
Expand Down
4 changes: 1 addition & 3 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,7 @@ def __init__(
psf="psf.tiff",
downsample=2,
flip_ud=True,
dtype="float32",
**kwargs,
):
"""
Expand All @@ -981,9 +982,6 @@ def __init__(
If True, data is flipped up-down, by default ``True``. Otherwise data is upside-down.
"""

# fixed parameters
dtype = "float32"

# get dataset
self.dataset = load_dataset(repo_id, split=split)

Expand Down

0 comments on commit 95aa24e

Please sign in to comment.