Skip to content

Commit

Permalink
added gaussian low pass filter to image utils
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jan 24, 2024
1 parent a8ff67b commit a19e388
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions src/cryo_sbi/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,44 @@ def __call__(self, image: torch.Tensor) -> torch.Tensor:
return reconstructed


class GaussianLowPassFilter:
"""
Low pass filter by dampening the outer frequencies with a Gaussian.
"""

def __init__(self, image_size: int, sigma: int):
self._image_size = image_size
self._sigma = sigma
self._grid = torch.linspace(-0.5 * (image_size - 1), 0.5 * (image_size - 1), image_size)
self._r_2d = self._grid[None, :] ** 2 + self._grid[:, None] ** 2
self._mask = torch.exp(-self._r_2d / (2 * sigma ** 2))

def __call__(self, image: torch.Tensor) -> torch.Tensor:
"""
Low pass filter an image by dampening the outer frequencies with a Gaussian.
Args:
image (torch.Tensor): Image of shape (n_pixels, n_pixels) or (n_channels, n_pixels, n_pixels).
Returns:
reconstructed (torch.Tensor): Low pass filtered image.
"""

fft_image = torch.fft.fft2(image)
fft_image = torch.fft.fftshift(fft_image)

if len(image.shape) == 2:
fft_image = fft_image * self._mask
elif len(image.shape) == 3:
fft_image = fft_image * self._mask.unsqueeze(0)
else:
raise NotImplementedError

fft_image = torch.fft.fftshift(fft_image)
reconstructed = torch.fft.ifft2(fft_image).real
return reconstructed


class NormalizeIndividual:
"""
Normalize an image by subtracting the mean and dividing by the standard deviation.
Expand Down

0 comments on commit a19e388

Please sign in to comment.