diff --git a/src/cryo_sbi/utils/image_utils.py b/src/cryo_sbi/utils/image_utils.py index 4bab630..a4fa084 100644 --- a/src/cryo_sbi/utils/image_utils.py +++ b/src/cryo_sbi/utils/image_utils.py @@ -281,36 +281,68 @@ def __call__(self, image_path: str) -> torch.Tensor: return mrc_to_tensor(image_path) -class WhitenImage: +def estimate_noise_psd(images: torch.Tensor, image_size: int, mask_radius : Union[int, None] = None) -> torch.Tensor: """ - Whiten an image by dividing by the noise PSD. + Estimates the power spectral density (PSD) of the noise in a set of images. Args: - noise_psd (torch.Tensor): Noise PSD of shape (n_pixels, n_pixels). + images (torch.Tensor): A tensor containing the input images. The shape of the tensor should be (N, H, W), + where N is the number of images, H is the height, and W is the width. + + Returns: + torch.Tensor: A tensor containing the estimated PSD of the noise. The shape of the tensor is (H, W), where H is the height + and W is the width of the images. + """ + if mask_radius is None: + mask_radius = image_size // 2 + mask = circular_mask(image_size, mask_radius, inside=False) + denominator = mask.sum() * images.shape[0] + images_masked = images * mask + mean_est = images_masked.sum() / denominator + image_masked_fft = torch.fft.fft2(images_masked) + noise_psd_est = torch.sum(torch.abs(image_masked_fft)**2, dim=[0]) / denominator + noise_psd_est[image_size // 2, image_size // 2] -= mean_est + + return noise_psd_est + - def __init__(self, noise_psd: torch.Tensor) -> None: - self._noise_psd = noise_psd +class WhitenImage: + """ + Whiten an image by dividing by the square root of the noise PSD. + + Args: + image_size (int): Size of image in pixels. + mask_radius (int, optional): Radius of the mask. Defaults to None. + + """ + + def __init__(self, image_size: int, mask_radius: Union[int, None] = None) -> None: + self.image_size = image_size + self.mask_radius = mask_radius + def _estimate_noise_psd(self, images: torch.Tensor) -> torch.Tensor: + """ + Estimates the power spectral density (PSD) of the noise in a set of images. + """ + noise_psd = estimate_noise_psd(images, self.image_size, self.mask_radius) + return noise_psd + def __call__(self, image: torch.Tensor) -> torch.Tensor: """ - Whiten an image by dividing by the noise PSD. - + Whiten an image by dividing by the square root of the noise PSD. + Args: image (torch.Tensor): Image of shape (n_pixels, n_pixels). - + Returns: - reconstructed (torch.Tensor): Whiten image. + image (torch.Tensor): Whitened image. """ - - fft_image = torch.fft.fft2(image) - if image.ndim == 3: - fft_image = fft_image / torch.sqrt(self._noise_psd.unsqueeze(0)) - elif image.ndim == 2: - fft_image = fft_image / torch.sqrt(self._noise_psd) - reconstructed = torch.fft.ifft2(fft_image).real - - return reconstructed + noise_psd = self._estimate_noise_psd(image) ** -0.5 + image_fft = torch.fft.fft2(image) + image_fft = image_fft * noise_psd + image = torch.fft.ifft2(image_fft).real + return image class MRCdataset: