diff --git a/src/cryo_challenge/_preprocessing/bfactor_normalize.py b/src/cryo_challenge/_preprocessing/bfactor_normalize.py new file mode 100644 index 0000000..37d876b --- /dev/null +++ b/src/cryo_challenge/_preprocessing/bfactor_normalize.py @@ -0,0 +1,89 @@ +import torch +from ..power_spectrum_utils import _centered_fftn, _centered_ifftn + + +def _compute_bfactor_scaling(b_factor, box_size, voxel_size): + """ + Compute the B-factor scaling factor for a given B-factor, box size, and voxel size. + The B-factor scaling factor is computed as exp(-B * s^2 / 4), where s is the squared + distance in Fourier space. + + Parameters + ---------- + b_factor: float + B-factor to apply. + box_size: int + Size of the box. + voxel_size: float + Voxel size of the box. + + Returns + ------- + bfactor_scaling_torch: torch.tensor(shape=(box_size, box_size, box_size)) + B-factor scaling factor. + """ + x = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) + y = x.clone() + z = x.clone() + x, y, z = torch.meshgrid(x, y, z, indexing="ij") + + s2 = x**2 + y**2 + z**2 + bfactor_scaling_torch = torch.exp(-b_factor * s2 / 4) + + return bfactor_scaling_torch + + +def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=False): + """ + Normalize volumes by applying a B-factor correction. This is done by multiplying + a centered Fourier transform of the volume by the B-factor scaling factor and then + applying the inverse Fourier transform. See _compute_bfactor_scaling for details on the + computation of the B-factor scaling. + + Parameters + ---------- + volumes: torch.tensor + Volumes to normalize. The volumes must have shape (N, N, N) or (n_volumes, N, N, N). + bfactor: float + B-factor to apply. + voxel_size: float + Voxel size of the volumes. + in_place: bool - default: False + Whether to normalize the volumes in place. + + Returns + ------- + volumes: torch.tensor + Normalized volumes. + """ + # assert that volumes have the correct shape + assert volumes.ndim in [ + 3, + 4, + ], "Input volumes must have shape (N, N, N) or (n_volumes, N, N, N)" + + if volumes.ndim == 3: + assert ( + volumes.shape[0] == volumes.shape[1] == volumes.shape[2] + ), "Input volumes must have equal dimensions" + else: + assert ( + volumes.shape[1] == volumes.shape[2] == volumes.shape[3] + ), "Input volumes must have equal dimensions" + + if not in_place: + volumes = volumes.clone() + + b_factor_scaling = _compute_bfactor_scaling(bfactor, volumes.shape[-1], voxel_size) + + if len(volumes.shape) == 3: + volumes = _centered_fftn(volumes, dim=(0, 1, 2)) + volumes = volumes * b_factor_scaling + volumes = _centered_ifftn(volumes, dim=(0, 1, 2)).real + + elif len(volumes.shape) == 4: + volumes = _centered_fftn(volumes, dim=(1, 2, 3)) + volumes = volumes * b_factor_scaling[None, ...] + volumes = _centered_ifftn(volumes, dim=(1, 2, 3)).real + + return volumes diff --git a/src/cryo_challenge/_preprocessing/normalize.py b/src/cryo_challenge/_preprocessing/normalize.py deleted file mode 100644 index 73449bf..0000000 --- a/src/cryo_challenge/_preprocessing/normalize.py +++ /dev/null @@ -1,22 +0,0 @@ -''' -TODO: Need to implement this properly - -def normalize_mean_std(vols_flat): - """ - vols_flat.shape is (n_vols, n_pix**3) - vols_flat is a torch tensor - """ - return (vols_flat - vols_flat.mean(-1, keepdims=True)) / vols_flat.std( - -1, keepdims=True - ) - - -def normalize_median_std(vols_flat): - """ - vols_flat.shape is (n_vols, n_pix**3) - vols_flat is a torch tensor - """ - return (vols_flat - vols_flat.median(-1, keepdims=True).values) / vols_flat.std( - -1, keepdims=True - ) -''' diff --git a/src/cryo_challenge/power_spectrum_utils.py b/src/cryo_challenge/power_spectrum_utils.py new file mode 100644 index 0000000..afae5bc --- /dev/null +++ b/src/cryo_challenge/power_spectrum_utils.py @@ -0,0 +1,153 @@ +import torch + + +def _cart2sph(x, y, z): + """ + Converts a grid in cartesian coordinates to spherical coordinates. + + Parameters + ---------- + x: torch.tensor + x-coordinate of the grid. + y: torch.tensor + y-coordinate of the grid. + z: torch.tensor + """ + hxy = torch.hypot(x, y) + r = torch.hypot(hxy, z) + el = torch.atan2(z, hxy) + az = torch.atan2(y, x) + return az, el, r + + +def _grid_3d(n, dtype=torch.float32): + """ + Generates a centered 3D grid. The grid is given in both cartesian and spherical coordinates. + + Parameters + ---------- + n: int + Size of the grid. + dtype: torch.dtype + Data type of the grid. + + Returns + ------- + grid: dict + Dictionary containing the grid in cartesian and spherical coordinates. + keys: x, y, z, phi, theta, r + """ + start = -n // 2 + 1 + end = n // 2 + + if n % 2 == 0: + start -= 1 / 2 + end -= 1 / 2 + + grid = torch.linspace(start, end, n, dtype=dtype) + z, x, y = torch.meshgrid(grid, grid, grid, indexing="ij") + + phi, theta, r = _cart2sph(x, y, z) + + theta = torch.pi / 2 - theta + + return {"x": x, "y": y, "z": z, "phi": phi, "theta": theta, "r": r} + + +def _centered_fftn(x, dim=None): + """ + Wrapper around torch.fft.fftn that centers the Fourier transform. + """ + x = torch.fft.fftn(x, dim=dim) + x = torch.fft.fftshift(x, dim=dim) + return x + + +def _centered_ifftn(x, dim=None): + """ + Wrapper around torch.fft.ifftn that centers the inverse Fourier transform. + """ + x = torch.fft.fftshift(x, dim=dim) + x = torch.fft.ifftn(x, dim=dim) + return x + + +def _average_over_single_shell(shell_index, volume, radii, shell_width=0.5): + """ + Given a volume in Fourier space, compute the average value of the volume over a shell. + + Parameters + ---------- + shell_index: int + Index of the shell in Fourier space. + volume: torch.tensor + Volume in Fourier space. + radii: torch.tensor + Radii of the Fourier space grid. + shell_width: float + Width of the shell. + + Returns + ------- + average: float + Average value of the volume over the shell. + """ + inner_diameter = shell_width + shell_index + outer_diameter = shell_width + (shell_index + 1) + mask = (radii > inner_diameter) & (radii < outer_diameter) + return torch.sum(mask * volume) / torch.sum(mask) + + +def _average_over_shells(volume_in_fourier_space, shell_width=0.5): + """ + Vmap wrapper over _average_over_single_shell to compute the average value of a volume in Fourier space over all shells. The input should be a volumetric quantity in Fourier space. + + Parameters + ---------- + volume_in_fourier_space: torch.tensor + Volume in Fourier space. + + Returns + ------- + radial_average: torch.tensor + Average value of the volume over all shells. + """ + L = volume_in_fourier_space.shape[0] + dtype = torch.float32 + radii = _grid_3d(L, dtype=dtype)["r"] + + radial_average = torch.vmap( + _average_over_single_shell, in_dims=(0, None, None, None) + )(torch.arange(0, L // 2), volume_in_fourier_space, radii, shell_width) + + return radial_average + + +def compute_power_spectrum(volume, shell_width=0.5): + """ + Compute the power spectrum of a volume. + + Parameters + ---------- + volume: torch.tensor + Volume for which to compute the power spectrum. + shell_width: float + Width of the shell. + + Returns + ------- + power_spectrum: torch.tensor + Power spectrum of the volume. + + Examples + -------- + volume = mrcfile.open("volume.mrc").data.copy() + volume = torch.tensor(volume, dtype=torch.float32) + power_spectrum = compute_power_spectrum(volume) + """ + + # Compute centered Fourier transforms. + vol_fft = torch.abs(_centered_fftn(volume)) ** 2 + power_spectrum = _average_over_shells(vol_fft, shell_width=shell_width) + + return power_spectrum diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json b/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json index 1fb797d..e7669ab 100644 --- a/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json +++ b/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json @@ -8,13 +8,13 @@ }, "0": { "name": "raw_submission_in_testdata", - "align": 1, "flavor_name": "test flavor", + "path": "tests/data/unprocessed_dataset_2_submissions/submission_x", + "populations_file": "tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt", + "submission_version": "1.0", "box_size": 32, "pixel_size": 15.022, - "path": "tests/data/unprocessed_dataset_2_submissions/submission_x", "flip": 1, - "populations_file": "tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt", - "submission_version": "1.0" + "align": 1 } } diff --git a/tests/test_power_spectrum_and_bfactor.py b/tests/test_power_spectrum_and_bfactor.py new file mode 100644 index 0000000..218632b --- /dev/null +++ b/tests/test_power_spectrum_and_bfactor.py @@ -0,0 +1,85 @@ +import torch +from cryo_challenge.power_spectrum_utils import _centered_ifftn, compute_power_spectrum +from cryo_challenge._preprocessing.bfactor_normalize import ( + _compute_bfactor_scaling, + bfactor_normalize_volumes, +) + + +def test_compute_power_spectrum(): + """ + Test the computation of the power spectrum of a radially symmetric Gaussian volume. + Since the volume is radially symmetric, the power spectrum of the whole volume should be + approximately the power spectrum in a central slice. The computation is not exact as our + averaging over shells is approximated. + """ + box_size = 224 + volume_shape = (box_size, box_size, box_size) + voxel_size = 1.073 * 2 + + freq = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) + x = freq.clone() + y = freq.clone() + z = freq.clone() + x, y, z = torch.meshgrid(x, y, z, indexing="ij") + + s2 = x**2 + y**2 + z**2 + + b_factor = 170 + + gaussian_volume = torch.exp(-b_factor / 4 * s2).reshape(volume_shape) + gaussian_volume = _centered_ifftn(gaussian_volume) + + power_spectrum = compute_power_spectrum(gaussian_volume) + power_spectrum_slice = ( + torch.abs(torch.fft.fftn(gaussian_volume)[: box_size // 2, 0, 0]) ** 2 + ) + + mean_squared_error = torch.mean((power_spectrum - power_spectrum_slice) ** 2) + + assert mean_squared_error < 1e-3 + + return + + +def test_bfactor_normalize_volumes(): + """ + Similarly to the other test, we test the normalization of a radially symmetric volume. + In this case we test with an oscillatory volume, which is a volume with a sinusoidal. + Since both the b-factor correction volume and the volume are radially symmetric, the + power spectrum of the normalized volume should be the same as the power spectrum of + a normalized central slice + """ + box_size = 128 + volume_shape = (box_size, box_size, box_size) + voxel_size = 1.5 + + freq = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) + x = freq.clone() + y = freq.clone() + z = freq.clone() + x, y, z = torch.meshgrid(x, y, z, indexing="ij") + + s2 = x**2 + y**2 + z**2 + + oscillatory_volume = torch.sin(300 * s2).reshape(volume_shape) + oscillatory_volume = _centered_ifftn(oscillatory_volume) + bfactor_scaling_vol = _compute_bfactor_scaling(170, box_size, voxel_size) + + norm_oscillatory_vol = bfactor_normalize_volumes( + oscillatory_volume, 170, voxel_size, in_place=False + ) + + ps_osci = torch.fft.fftn(oscillatory_volume, dim=(-3, -2, -1), norm="backward")[ + : box_size // 2, 0, 0 + ] + ps_norm_osci = torch.fft.fftn( + norm_oscillatory_vol, dim=(-3, -2, -1), norm="backward" + )[: box_size // 2, 0, 0] + ps_bfactor_scaling = torch.fft.fftshift(bfactor_scaling_vol)[: box_size // 2, 0, 0] + + ps_osci = torch.abs(ps_osci) ** 2 + ps_norm_osci = torch.abs(ps_norm_osci) ** 2 + ps_bfactor_scaling = torch.abs(ps_bfactor_scaling) ** 2 + + assert torch.allclose(ps_norm_osci, ps_osci * ps_bfactor_scaling)