From a19e38880de91a6312339670fd82b34e4f630c6c Mon Sep 17 00:00:00 2001 From: Dingel321 Date: Wed, 24 Jan 2024 11:08:39 +0100 Subject: [PATCH] added gaussian low pass filter to image utils --- src/cryo_sbi/utils/image_utils.py | 38 +++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/src/cryo_sbi/utils/image_utils.py b/src/cryo_sbi/utils/image_utils.py index b4b0f12..fb4a6cc 100644 --- a/src/cryo_sbi/utils/image_utils.py +++ b/src/cryo_sbi/utils/image_utils.py @@ -171,6 +171,44 @@ def __call__(self, image: torch.Tensor) -> torch.Tensor: return reconstructed +class GaussianLowPassFilter: + """ + Low pass filter by dampening the outer frequencies with a Gaussian. + """ + + def __init__(self, image_size: int, sigma: int): + self._image_size = image_size + self._sigma = sigma + self._grid = torch.linspace(-0.5 * (image_size - 1), 0.5 * (image_size - 1), image_size) + self._r_2d = self._grid[None, :] ** 2 + self._grid[:, None] ** 2 + self._mask = torch.exp(-self._r_2d / (2 * sigma ** 2)) + + def __call__(self, image: torch.Tensor) -> torch.Tensor: + """ + Low pass filter an image by dampening the outer frequencies with a Gaussian. + + Args: + image (torch.Tensor): Image of shape (n_pixels, n_pixels) or (n_channels, n_pixels, n_pixels). + + Returns: + reconstructed (torch.Tensor): Low pass filtered image. + """ + + fft_image = torch.fft.fft2(image) + fft_image = torch.fft.fftshift(fft_image) + + if len(image.shape) == 2: + fft_image = fft_image * self._mask + elif len(image.shape) == 3: + fft_image = fft_image * self._mask.unsqueeze(0) + else: + raise NotImplementedError + + fft_image = torch.fft.fftshift(fft_image) + reconstructed = torch.fft.ifft2(fft_image).real + return reconstructed + + class NormalizeIndividual: """ Normalize an image by subtracting the mean and dividing by the standard deviation.