Skip to content

Commit

Permalink
added new method for whitening images
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jun 12, 2024
1 parent e4e4d51 commit 83d90b9
Showing 1 changed file with 50 additions and 18 deletions.
68 changes: 50 additions & 18 deletions src/cryo_sbi/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,36 +281,68 @@ def __call__(self, image_path: str) -> torch.Tensor:
return mrc_to_tensor(image_path)


class WhitenImage:
def estimate_noise_psd(images: torch.Tensor, image_size: int, mask_radius : Union[int, None] = None) -> torch.Tensor:
"""
Whiten an image by dividing by the noise PSD.
Estimates the power spectral density (PSD) of the noise in a set of images.
Args:
noise_psd (torch.Tensor): Noise PSD of shape (n_pixels, n_pixels).
images (torch.Tensor): A tensor containing the input images. The shape of the tensor should be (N, H, W),
where N is the number of images, H is the height, and W is the width.
Returns:
torch.Tensor: A tensor containing the estimated PSD of the noise. The shape of the tensor is (H, W), where H is the height
and W is the width of the images.
"""
if mask_radius is None:
mask_radius = image_size // 2
mask = circular_mask(image_size, mask_radius, inside=False)
denominator = mask.sum() * images.shape[0]
images_masked = images * mask
mean_est = images_masked.sum() / denominator
image_masked_fft = torch.fft.fft2(images_masked)
noise_psd_est = torch.sum(torch.abs(image_masked_fft)**2, dim=[0]) / denominator
noise_psd_est[image_size // 2, image_size // 2] -= mean_est

return noise_psd_est


def __init__(self, noise_psd: torch.Tensor) -> None:
self._noise_psd = noise_psd
class WhitenImage:
"""
Whiten an image by dividing by the square root of the noise PSD.
Args:
image_size (int): Size of image in pixels.
mask_radius (int, optional): Radius of the mask. Defaults to None.
"""

def __init__(self, image_size: int, mask_radius: Union[int, None] = None) -> None:
self.image_size = image_size
self.mask_radius = mask_radius

def _estimate_noise_psd(self, images: torch.Tensor) -> torch.Tensor:
"""
Estimates the power spectral density (PSD) of the noise in a set of images.
"""
noise_psd = estimate_noise_psd(images, self.image_size, self.mask_radius)
return noise_psd

def __call__(self, image: torch.Tensor) -> torch.Tensor:
"""
Whiten an image by dividing by the noise PSD.
Whiten an image by dividing by the square root of the noise PSD.
Args:
image (torch.Tensor): Image of shape (n_pixels, n_pixels).
Returns:
reconstructed (torch.Tensor): Whiten image.
image (torch.Tensor): Whitened image.
"""

fft_image = torch.fft.fft2(image)
if image.ndim == 3:
fft_image = fft_image / torch.sqrt(self._noise_psd.unsqueeze(0))
elif image.ndim == 2:
fft_image = fft_image / torch.sqrt(self._noise_psd)
reconstructed = torch.fft.ifft2(fft_image).real

return reconstructed
noise_psd = self._estimate_noise_psd(image) ** -0.5
image_fft = torch.fft.fft2(image)
image_fft = image_fft * noise_psd
image = torch.fft.ifft2(image_fft).real
return image


class MRCdataset:
Expand Down

0 comments on commit 83d90b9

Please sign in to comment.