diff --git a/CHANGELOG.rst b/CHANGELOG.rst index bd0508d0..136545f2 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -40,6 +40,7 @@ Bugfix - Local path for DRUNet download. - APGD input handling (float32). - Multimask handling. +- Passing shape to IRFFT so that it matches shape of input to RFFT. 1.0.6 - (2024-02-21) -------------------- diff --git a/lensless/recon/admm.py b/lensless/recon/admm.py index b5a781d6..bd50d6db 100644 --- a/lensless/recon/admm.py +++ b/lensless/recon/admm.py @@ -282,10 +282,14 @@ def _image_update(self): if self.is_torch: freq_space_result = self._R_divmat * torch.fft.rfft2(rk, dim=(-3, -2)) - self._image_est = torch.fft.irfft2(freq_space_result, dim=(-3, -2)) + self._image_est = torch.fft.irfft2( + freq_space_result, dim=(-3, -2), s=self._convolver._padded_shape[-3:-1] + ) else: freq_space_result = self._R_divmat * fft.rfft2(rk, axes=(-3, -2)) - self._image_est = fft.irfft2(freq_space_result, axes=(-3, -2)) + self._image_est = fft.irfft2( + freq_space_result, axes=(-3, -2), s=self._convolver._padded_shape[-3:-1] + ) # self._image_est = self._convolver._crop(res) diff --git a/lensless/recon/rfft_convolve.py b/lensless/recon/rfft_convolve.py index c5178491..5518a651 100644 --- a/lensless/recon/rfft_convolve.py +++ b/lensless/recon/rfft_convolve.py @@ -139,20 +139,28 @@ def convolve(self, x): if self.is_torch: conv_output = torch.fft.ifftshift( torch.fft.irfft2( - torch.fft.rfft2(self._padded_data, dim=(-3, -2)) * self._H, dim=(-3, -2) + torch.fft.rfft2(self._padded_data, dim=(-3, -2)) * self._H, + dim=(-3, -2), + s=self._padded_shape[-3:-1], ), dim=(-3, -2), ) else: conv_output = fft.ifftshift( - fft.irfft2(fft.rfft2(self._padded_data, axes=(-3, -2)) * self._H, axes=(-3, -2)), + fft.irfft2( + fft.rfft2(self._padded_data, axes=(-3, -2)) * self._H, + axes=(-3, -2), + s=self._padded_shape[-3:-1], + ), axes=(-3, -2), ) if self.pad: - return self._crop(conv_output) - else: - return conv_output + conv_output = self._crop(conv_output) + + # ensure shape stays the same + assert conv_output.shape[-3:-1] == x.shape[-3:-1] + return conv_output def deconvolve(self, y): """ @@ -165,21 +173,30 @@ def deconvolve(self, y): self._padded_data = y # .type(self.dtype).to(self._psf.device) else: self._padded_data[:] = y # .astype(self.dtype) + if self.is_torch: deconv_output = torch.fft.ifftshift( torch.fft.irfft2( - torch.fft.rfft2(self._padded_data, dim=(-3, -2)) * self._Hadj, dim=(-3, -2) + torch.fft.rfft2(self._padded_data, dim=(-3, -2)) * self._Hadj, + dim=(-3, -2), + s=self._padded_shape[-3:-1], ), dim=(-3, -2), ) else: deconv_output = fft.ifftshift( - fft.irfft2(fft.rfft2(self._padded_data, axes=(-3, -2)) * self._Hadj, axes=(-3, -2)), + fft.irfft2( + fft.rfft2(self._padded_data, axes=(-3, -2)) * self._Hadj, + axes=(-3, -2), + s=self._padded_shape[-3:-1], + ), axes=(-3, -2), ) if self.pad: - return self._crop(deconv_output) - else: - return deconv_output + deconv_output = self._crop(deconv_output) + + # ensure shape stays the same + assert deconv_output.shape[-3:-1] == y.shape[-3:-1] + return deconv_output diff --git a/lensless/recon/unrolled_admm.py b/lensless/recon/unrolled_admm.py index b20eaa07..2447aa4b 100644 --- a/lensless/recon/unrolled_admm.py +++ b/lensless/recon/unrolled_admm.py @@ -198,7 +198,9 @@ def _image_update(self, iter): + self._convolver.deconvolve(self._mu1[iter] * self._X - self._xi) ) freq_space_result = self._R_divmat[iter] * torch.fft.rfft2(rk, dim=(-3, -2)) - self._image_est = torch.fft.irfft2(freq_space_result, dim=(-3, -2)) + self._image_est = torch.fft.irfft2( + freq_space_result, dim=(-3, -2), s=self._convolver._padded_shape[-3:-1] + ) def _W_update(self, iter): """Non-negativity update""" diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 15bb91f5..803d4166 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -968,6 +968,7 @@ def __init__( psf="psf.tiff", downsample=2, flip_ud=True, + dtype="float32", **kwargs, ): """ @@ -981,9 +982,6 @@ def __init__( If True, data is flipped up-down, by default ``True``. Otherwise data is upside-down. """ - # fixed parameters - dtype = "float32" - # get dataset self.dataset = load_dataset(repo_id, split=split)