diff --git a/src/cryo_challenge/_preprocessing/bfactor_normalize.py b/src/cryo_challenge/_preprocessing/bfactor_normalize.py index 2af5e98..37d876b 100644 --- a/src/cryo_challenge/_preprocessing/bfactor_normalize.py +++ b/src/cryo_challenge/_preprocessing/bfactor_normalize.py @@ -3,6 +3,25 @@ def _compute_bfactor_scaling(b_factor, box_size, voxel_size): + """ + Compute the B-factor scaling factor for a given B-factor, box size, and voxel size. + The B-factor scaling factor is computed as exp(-B * s^2 / 4), where s is the squared + distance in Fourier space. + + Parameters + ---------- + b_factor: float + B-factor to apply. + box_size: int + Size of the box. + voxel_size: float + Voxel size of the box. + + Returns + ------- + bfactor_scaling_torch: torch.tensor(shape=(box_size, box_size, box_size)) + B-factor scaling factor. + """ x = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) y = x.clone() z = x.clone() @@ -14,7 +33,44 @@ def _compute_bfactor_scaling(b_factor, box_size, voxel_size): return bfactor_scaling_torch -def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=True): +def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=False): + """ + Normalize volumes by applying a B-factor correction. This is done by multiplying + a centered Fourier transform of the volume by the B-factor scaling factor and then + applying the inverse Fourier transform. See _compute_bfactor_scaling for details on the + computation of the B-factor scaling. + + Parameters + ---------- + volumes: torch.tensor + Volumes to normalize. The volumes must have shape (N, N, N) or (n_volumes, N, N, N). + bfactor: float + B-factor to apply. + voxel_size: float + Voxel size of the volumes. + in_place: bool - default: False + Whether to normalize the volumes in place. + + Returns + ------- + volumes: torch.tensor + Normalized volumes. + """ + # assert that volumes have the correct shape + assert volumes.ndim in [ + 3, + 4, + ], "Input volumes must have shape (N, N, N) or (n_volumes, N, N, N)" + + if volumes.ndim == 3: + assert ( + volumes.shape[0] == volumes.shape[1] == volumes.shape[2] + ), "Input volumes must have equal dimensions" + else: + assert ( + volumes.shape[1] == volumes.shape[2] == volumes.shape[3] + ), "Input volumes must have equal dimensions" + if not in_place: volumes = volumes.clone() @@ -30,7 +86,4 @@ def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=True): volumes = volumes * b_factor_scaling[None, ...] volumes = _centered_ifftn(volumes, dim=(1, 2, 3)).real - else: - raise ValueError("Input volumes must have 3 or 4 dimensions.") - return volumes diff --git a/src/cryo_challenge/power_spectrum_utils.py b/src/cryo_challenge/power_spectrum_utils.py index 0a6338f..afae5bc 100644 --- a/src/cryo_challenge/power_spectrum_utils.py +++ b/src/cryo_challenge/power_spectrum_utils.py @@ -21,6 +21,22 @@ def _cart2sph(x, y, z): def _grid_3d(n, dtype=torch.float32): + """ + Generates a centered 3D grid. The grid is given in both cartesian and spherical coordinates. + + Parameters + ---------- + n: int + Size of the grid. + dtype: torch.dtype + Data type of the grid. + + Returns + ------- + grid: dict + Dictionary containing the grid in cartesian and spherical coordinates. + keys: x, y, z, phi, theta, r + """ start = -n // 2 + 1 end = n // 2 @@ -39,33 +55,99 @@ def _grid_3d(n, dtype=torch.float32): def _centered_fftn(x, dim=None): + """ + Wrapper around torch.fft.fftn that centers the Fourier transform. + """ x = torch.fft.fftn(x, dim=dim) x = torch.fft.fftshift(x, dim=dim) return x def _centered_ifftn(x, dim=None): + """ + Wrapper around torch.fft.ifftn that centers the inverse Fourier transform. + """ x = torch.fft.fftshift(x, dim=dim) x = torch.fft.ifftn(x, dim=dim) return x -def _compute_power_spectrum_shell(index, volume, radii, shell_width=0.5): - inner_diameter = shell_width + index - outer_diameter = shell_width + (index + 1) +def _average_over_single_shell(shell_index, volume, radii, shell_width=0.5): + """ + Given a volume in Fourier space, compute the average value of the volume over a shell. + + Parameters + ---------- + shell_index: int + Index of the shell in Fourier space. + volume: torch.tensor + Volume in Fourier space. + radii: torch.tensor + Radii of the Fourier space grid. + shell_width: float + Width of the shell. + + Returns + ------- + average: float + Average value of the volume over the shell. + """ + inner_diameter = shell_width + shell_index + outer_diameter = shell_width + (shell_index + 1) mask = (radii > inner_diameter) & (radii < outer_diameter) return torch.sum(mask * volume) / torch.sum(mask) -def compute_power_spectrum(volume, shell_width=0.5): - L = volume.shape[0] +def _average_over_shells(volume_in_fourier_space, shell_width=0.5): + """ + Vmap wrapper over _average_over_single_shell to compute the average value of a volume in Fourier space over all shells. The input should be a volumetric quantity in Fourier space. + + Parameters + ---------- + volume_in_fourier_space: torch.tensor + Volume in Fourier space. + + Returns + ------- + radial_average: torch.tensor + Average value of the volume over all shells. + """ + L = volume_in_fourier_space.shape[0] dtype = torch.float32 radii = _grid_3d(L, dtype=dtype)["r"] + radial_average = torch.vmap( + _average_over_single_shell, in_dims=(0, None, None, None) + )(torch.arange(0, L // 2), volume_in_fourier_space, radii, shell_width) + + return radial_average + + +def compute_power_spectrum(volume, shell_width=0.5): + """ + Compute the power spectrum of a volume. + + Parameters + ---------- + volume: torch.tensor + Volume for which to compute the power spectrum. + shell_width: float + Width of the shell. + + Returns + ------- + power_spectrum: torch.tensor + Power spectrum of the volume. + + Examples + -------- + volume = mrcfile.open("volume.mrc").data.copy() + volume = torch.tensor(volume, dtype=torch.float32) + power_spectrum = compute_power_spectrum(volume) + """ + # Compute centered Fourier transforms. vol_fft = torch.abs(_centered_fftn(volume)) ** 2 + power_spectrum = _average_over_shells(vol_fft, shell_width=shell_width) - power_spectrum = torch.vmap( - _compute_power_spectrum_shell, in_dims=(0, None, None, None) - )(torch.arange(0, L // 2), vol_fft, radii, shell_width) return power_spectrum diff --git a/tests/test_power_spectrum_and_bfactor.py b/tests/test_power_spectrum_and_bfactor.py index 8496ba2..218632b 100644 --- a/tests/test_power_spectrum_and_bfactor.py +++ b/tests/test_power_spectrum_and_bfactor.py @@ -7,6 +7,12 @@ def test_compute_power_spectrum(): + """ + Test the computation of the power spectrum of a radially symmetric Gaussian volume. + Since the volume is radially symmetric, the power spectrum of the whole volume should be + approximately the power spectrum in a central slice. The computation is not exact as our + averaging over shells is approximated. + """ box_size = 224 volume_shape = (box_size, box_size, box_size) voxel_size = 1.073 * 2 @@ -37,6 +43,13 @@ def test_compute_power_spectrum(): def test_bfactor_normalize_volumes(): + """ + Similarly to the other test, we test the normalization of a radially symmetric volume. + In this case we test with an oscillatory volume, which is a volume with a sinusoidal. + Since both the b-factor correction volume and the volume are radially symmetric, the + power spectrum of the normalized volume should be the same as the power spectrum of + a normalized central slice + """ box_size = 128 volume_shape = (box_size, box_size, box_size) voxel_size = 1.5