Skip to content

Commit

Permalink
Merge pull request #82 from flatironinstitute/14-implement-power-spec…
Browse files Browse the repository at this point in the history
…trum-normalization

14 implement power spectrum normalization
  • Loading branch information
DSilva27 authored Aug 21, 2024
2 parents 21ca809 + 6726c47 commit cef000d
Show file tree
Hide file tree
Showing 5 changed files with 331 additions and 26 deletions.
89 changes: 89 additions & 0 deletions src/cryo_challenge/_preprocessing/bfactor_normalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import torch
from ..power_spectrum_utils import _centered_fftn, _centered_ifftn


def _compute_bfactor_scaling(b_factor, box_size, voxel_size):
"""
Compute the B-factor scaling factor for a given B-factor, box size, and voxel size.
The B-factor scaling factor is computed as exp(-B * s^2 / 4), where s is the squared
distance in Fourier space.
Parameters
----------
b_factor: float
B-factor to apply.
box_size: int
Size of the box.
voxel_size: float
Voxel size of the box.
Returns
-------
bfactor_scaling_torch: torch.tensor(shape=(box_size, box_size, box_size))
B-factor scaling factor.
"""
x = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size))
y = x.clone()
z = x.clone()
x, y, z = torch.meshgrid(x, y, z, indexing="ij")

s2 = x**2 + y**2 + z**2
bfactor_scaling_torch = torch.exp(-b_factor * s2 / 4)

return bfactor_scaling_torch


def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=False):
"""
Normalize volumes by applying a B-factor correction. This is done by multiplying
a centered Fourier transform of the volume by the B-factor scaling factor and then
applying the inverse Fourier transform. See _compute_bfactor_scaling for details on the
computation of the B-factor scaling.
Parameters
----------
volumes: torch.tensor
Volumes to normalize. The volumes must have shape (N, N, N) or (n_volumes, N, N, N).
bfactor: float
B-factor to apply.
voxel_size: float
Voxel size of the volumes.
in_place: bool - default: False
Whether to normalize the volumes in place.
Returns
-------
volumes: torch.tensor
Normalized volumes.
"""
# assert that volumes have the correct shape
assert volumes.ndim in [
3,
4,
], "Input volumes must have shape (N, N, N) or (n_volumes, N, N, N)"

if volumes.ndim == 3:
assert (
volumes.shape[0] == volumes.shape[1] == volumes.shape[2]
), "Input volumes must have equal dimensions"
else:
assert (
volumes.shape[1] == volumes.shape[2] == volumes.shape[3]
), "Input volumes must have equal dimensions"

if not in_place:
volumes = volumes.clone()

b_factor_scaling = _compute_bfactor_scaling(bfactor, volumes.shape[-1], voxel_size)

if len(volumes.shape) == 3:
volumes = _centered_fftn(volumes, dim=(0, 1, 2))
volumes = volumes * b_factor_scaling
volumes = _centered_ifftn(volumes, dim=(0, 1, 2)).real

elif len(volumes.shape) == 4:
volumes = _centered_fftn(volumes, dim=(1, 2, 3))
volumes = volumes * b_factor_scaling[None, ...]
volumes = _centered_ifftn(volumes, dim=(1, 2, 3)).real

return volumes
22 changes: 0 additions & 22 deletions src/cryo_challenge/_preprocessing/normalize.py

This file was deleted.

153 changes: 153 additions & 0 deletions src/cryo_challenge/power_spectrum_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import torch


def _cart2sph(x, y, z):
"""
Converts a grid in cartesian coordinates to spherical coordinates.
Parameters
----------
x: torch.tensor
x-coordinate of the grid.
y: torch.tensor
y-coordinate of the grid.
z: torch.tensor
"""
hxy = torch.hypot(x, y)
r = torch.hypot(hxy, z)
el = torch.atan2(z, hxy)
az = torch.atan2(y, x)
return az, el, r


def _grid_3d(n, dtype=torch.float32):
"""
Generates a centered 3D grid. The grid is given in both cartesian and spherical coordinates.
Parameters
----------
n: int
Size of the grid.
dtype: torch.dtype
Data type of the grid.
Returns
-------
grid: dict
Dictionary containing the grid in cartesian and spherical coordinates.
keys: x, y, z, phi, theta, r
"""
start = -n // 2 + 1
end = n // 2

if n % 2 == 0:
start -= 1 / 2
end -= 1 / 2

grid = torch.linspace(start, end, n, dtype=dtype)
z, x, y = torch.meshgrid(grid, grid, grid, indexing="ij")

phi, theta, r = _cart2sph(x, y, z)

theta = torch.pi / 2 - theta

return {"x": x, "y": y, "z": z, "phi": phi, "theta": theta, "r": r}


def _centered_fftn(x, dim=None):
"""
Wrapper around torch.fft.fftn that centers the Fourier transform.
"""
x = torch.fft.fftn(x, dim=dim)
x = torch.fft.fftshift(x, dim=dim)
return x


def _centered_ifftn(x, dim=None):
"""
Wrapper around torch.fft.ifftn that centers the inverse Fourier transform.
"""
x = torch.fft.fftshift(x, dim=dim)
x = torch.fft.ifftn(x, dim=dim)
return x


def _average_over_single_shell(shell_index, volume, radii, shell_width=0.5):
"""
Given a volume in Fourier space, compute the average value of the volume over a shell.
Parameters
----------
shell_index: int
Index of the shell in Fourier space.
volume: torch.tensor
Volume in Fourier space.
radii: torch.tensor
Radii of the Fourier space grid.
shell_width: float
Width of the shell.
Returns
-------
average: float
Average value of the volume over the shell.
"""
inner_diameter = shell_width + shell_index
outer_diameter = shell_width + (shell_index + 1)
mask = (radii > inner_diameter) & (radii < outer_diameter)
return torch.sum(mask * volume) / torch.sum(mask)


def _average_over_shells(volume_in_fourier_space, shell_width=0.5):
"""
Vmap wrapper over _average_over_single_shell to compute the average value of a volume in Fourier space over all shells. The input should be a volumetric quantity in Fourier space.
Parameters
----------
volume_in_fourier_space: torch.tensor
Volume in Fourier space.
Returns
-------
radial_average: torch.tensor
Average value of the volume over all shells.
"""
L = volume_in_fourier_space.shape[0]
dtype = torch.float32
radii = _grid_3d(L, dtype=dtype)["r"]

radial_average = torch.vmap(
_average_over_single_shell, in_dims=(0, None, None, None)
)(torch.arange(0, L // 2), volume_in_fourier_space, radii, shell_width)

return radial_average


def compute_power_spectrum(volume, shell_width=0.5):
"""
Compute the power spectrum of a volume.
Parameters
----------
volume: torch.tensor
Volume for which to compute the power spectrum.
shell_width: float
Width of the shell.
Returns
-------
power_spectrum: torch.tensor
Power spectrum of the volume.
Examples
--------
volume = mrcfile.open("volume.mrc").data.copy()
volume = torch.tensor(volume, dtype=torch.float32)
power_spectrum = compute_power_spectrum(volume)
"""

# Compute centered Fourier transforms.
vol_fft = torch.abs(_centered_fftn(volume)) ** 2
power_spectrum = _average_over_shells(vol_fft, shell_width=shell_width)

return power_spectrum
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
},
"0": {
"name": "raw_submission_in_testdata",
"align": 1,
"flavor_name": "test flavor",
"path": "tests/data/unprocessed_dataset_2_submissions/submission_x",
"populations_file": "tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt",
"submission_version": "1.0",
"box_size": 32,
"pixel_size": 15.022,
"path": "tests/data/unprocessed_dataset_2_submissions/submission_x",
"flip": 1,
"populations_file": "tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt",
"submission_version": "1.0"
"align": 1
}
}
85 changes: 85 additions & 0 deletions tests/test_power_spectrum_and_bfactor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import torch
from cryo_challenge.power_spectrum_utils import _centered_ifftn, compute_power_spectrum
from cryo_challenge._preprocessing.bfactor_normalize import (
_compute_bfactor_scaling,
bfactor_normalize_volumes,
)


def test_compute_power_spectrum():
"""
Test the computation of the power spectrum of a radially symmetric Gaussian volume.
Since the volume is radially symmetric, the power spectrum of the whole volume should be
approximately the power spectrum in a central slice. The computation is not exact as our
averaging over shells is approximated.
"""
box_size = 224
volume_shape = (box_size, box_size, box_size)
voxel_size = 1.073 * 2

freq = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size))
x = freq.clone()
y = freq.clone()
z = freq.clone()
x, y, z = torch.meshgrid(x, y, z, indexing="ij")

s2 = x**2 + y**2 + z**2

b_factor = 170

gaussian_volume = torch.exp(-b_factor / 4 * s2).reshape(volume_shape)
gaussian_volume = _centered_ifftn(gaussian_volume)

power_spectrum = compute_power_spectrum(gaussian_volume)
power_spectrum_slice = (
torch.abs(torch.fft.fftn(gaussian_volume)[: box_size // 2, 0, 0]) ** 2
)

mean_squared_error = torch.mean((power_spectrum - power_spectrum_slice) ** 2)

assert mean_squared_error < 1e-3

return


def test_bfactor_normalize_volumes():
"""
Similarly to the other test, we test the normalization of a radially symmetric volume.
In this case we test with an oscillatory volume, which is a volume with a sinusoidal.
Since both the b-factor correction volume and the volume are radially symmetric, the
power spectrum of the normalized volume should be the same as the power spectrum of
a normalized central slice
"""
box_size = 128
volume_shape = (box_size, box_size, box_size)
voxel_size = 1.5

freq = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size))
x = freq.clone()
y = freq.clone()
z = freq.clone()
x, y, z = torch.meshgrid(x, y, z, indexing="ij")

s2 = x**2 + y**2 + z**2

oscillatory_volume = torch.sin(300 * s2).reshape(volume_shape)
oscillatory_volume = _centered_ifftn(oscillatory_volume)
bfactor_scaling_vol = _compute_bfactor_scaling(170, box_size, voxel_size)

norm_oscillatory_vol = bfactor_normalize_volumes(
oscillatory_volume, 170, voxel_size, in_place=False
)

ps_osci = torch.fft.fftn(oscillatory_volume, dim=(-3, -2, -1), norm="backward")[
: box_size // 2, 0, 0
]
ps_norm_osci = torch.fft.fftn(
norm_oscillatory_vol, dim=(-3, -2, -1), norm="backward"
)[: box_size // 2, 0, 0]
ps_bfactor_scaling = torch.fft.fftshift(bfactor_scaling_vol)[: box_size // 2, 0, 0]

ps_osci = torch.abs(ps_osci) ** 2
ps_norm_osci = torch.abs(ps_norm_osci) ** 2
ps_bfactor_scaling = torch.abs(ps_bfactor_scaling) ** 2

assert torch.allclose(ps_norm_osci, ps_osci * ps_bfactor_scaling)

0 comments on commit cef000d

Please sign in to comment.