Skip to content

Commit

Permalink
add docstrings to tests and implemented functions
Browse files Browse the repository at this point in the history
  • Loading branch information
DSilva27 committed Aug 15, 2024
1 parent a06b96d commit 5c1b901
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 12 deletions.
61 changes: 57 additions & 4 deletions src/cryo_challenge/_preprocessing/bfactor_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,25 @@


def _compute_bfactor_scaling(b_factor, box_size, voxel_size):
"""
Compute the B-factor scaling factor for a given B-factor, box size, and voxel size.
The B-factor scaling factor is computed as exp(-B * s^2 / 4), where s is the squared
distance in Fourier space.
Parameters
----------
b_factor: float
B-factor to apply.
box_size: int
Size of the box.
voxel_size: float
Voxel size of the box.
Returns
-------
bfactor_scaling_torch: torch.tensor(shape=(box_size, box_size, box_size))
B-factor scaling factor.
"""
x = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size))
y = x.clone()
z = x.clone()
Expand All @@ -14,7 +33,44 @@ def _compute_bfactor_scaling(b_factor, box_size, voxel_size):
return bfactor_scaling_torch


def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=True):
def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=False):
"""
Normalize volumes by applying a B-factor correction. This is done by multiplying
a centered Fourier transform of the volume by the B-factor scaling factor and then
applying the inverse Fourier transform. See _compute_bfactor_scaling for details on the
computation of the B-factor scaling.
Parameters
----------
volumes: torch.tensor
Volumes to normalize. The volumes must have shape (N, N, N) or (n_volumes, N, N, N).
bfactor: float
B-factor to apply.
voxel_size: float
Voxel size of the volumes.
in_place: bool - default: False
Whether to normalize the volumes in place.
Returns
-------
volumes: torch.tensor
Normalized volumes.
"""
# assert that volumes have the correct shape
assert volumes.ndim in [
3,
4,
], "Input volumes must have shape (N, N, N) or (n_volumes, N, N, N)"

if volumes.ndim == 3:
assert (
volumes.shape[0] == volumes.shape[1] == volumes.shape[2]
), "Input volumes must have equal dimensions"
else:
assert (
volumes.shape[1] == volumes.shape[2] == volumes.shape[3]
), "Input volumes must have equal dimensions"

if not in_place:
volumes = volumes.clone()

Expand All @@ -30,7 +86,4 @@ def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=True):
volumes = volumes * b_factor_scaling[None, ...]
volumes = _centered_ifftn(volumes, dim=(1, 2, 3)).real

else:
raise ValueError("Input volumes must have 3 or 4 dimensions.")

return volumes
98 changes: 90 additions & 8 deletions src/cryo_challenge/power_spectrum_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,22 @@ def _cart2sph(x, y, z):


def _grid_3d(n, dtype=torch.float32):
"""
Generates a centered 3D grid. The grid is given in both cartesian and spherical coordinates.
Parameters
----------
n: int
Size of the grid.
dtype: torch.dtype
Data type of the grid.
Returns
-------
grid: dict
Dictionary containing the grid in cartesian and spherical coordinates.
keys: x, y, z, phi, theta, r
"""
start = -n // 2 + 1
end = n // 2

Expand All @@ -39,33 +55,99 @@ def _grid_3d(n, dtype=torch.float32):


def _centered_fftn(x, dim=None):
"""
Wrapper around torch.fft.fftn that centers the Fourier transform.
"""
x = torch.fft.fftn(x, dim=dim)
x = torch.fft.fftshift(x, dim=dim)
return x


def _centered_ifftn(x, dim=None):
"""
Wrapper around torch.fft.ifftn that centers the inverse Fourier transform.
"""
x = torch.fft.fftshift(x, dim=dim)
x = torch.fft.ifftn(x, dim=dim)
return x


def _compute_power_spectrum_shell(index, volume, radii, shell_width=0.5):
inner_diameter = shell_width + index
outer_diameter = shell_width + (index + 1)
def _average_over_single_shell(shell_index, volume, radii, shell_width=0.5):
"""
Given a volume in Fourier space, compute the average value of the volume over a shell.
Parameters
----------
shell_index: int
Index of the shell in Fourier space.
volume: torch.tensor
Volume in Fourier space.
radii: torch.tensor
Radii of the Fourier space grid.
shell_width: float
Width of the shell.
Returns
-------
average: float
Average value of the volume over the shell.
"""
inner_diameter = shell_width + shell_index
outer_diameter = shell_width + (shell_index + 1)
mask = (radii > inner_diameter) & (radii < outer_diameter)
return torch.sum(mask * volume) / torch.sum(mask)


def compute_power_spectrum(volume, shell_width=0.5):
L = volume.shape[0]
def _average_over_shells(volume_in_fourier_space, shell_width=0.5):
"""
Vmap wrapper over _average_over_single_shell to compute the average value of a volume in Fourier space over all shells. The input should be a volumetric quantity in Fourier space.
Parameters
----------
volume_in_fourier_space: torch.tensor
Volume in Fourier space.
Returns
-------
radial_average: torch.tensor
Average value of the volume over all shells.
"""
L = volume_in_fourier_space.shape[0]
dtype = torch.float32
radii = _grid_3d(L, dtype=dtype)["r"]

radial_average = torch.vmap(
_average_over_single_shell, in_dims=(0, None, None, None)
)(torch.arange(0, L // 2), volume_in_fourier_space, radii, shell_width)

return radial_average


def compute_power_spectrum(volume, shell_width=0.5):
"""
Compute the power spectrum of a volume.
Parameters
----------
volume: torch.tensor
Volume for which to compute the power spectrum.
shell_width: float
Width of the shell.
Returns
-------
power_spectrum: torch.tensor
Power spectrum of the volume.
Examples
--------
volume = mrcfile.open("volume.mrc").data.copy()
volume = torch.tensor(volume, dtype=torch.float32)
power_spectrum = compute_power_spectrum(volume)
"""

# Compute centered Fourier transforms.
vol_fft = torch.abs(_centered_fftn(volume)) ** 2
power_spectrum = _average_over_shells(vol_fft, shell_width=shell_width)

power_spectrum = torch.vmap(
_compute_power_spectrum_shell, in_dims=(0, None, None, None)
)(torch.arange(0, L // 2), vol_fft, radii, shell_width)
return power_spectrum
13 changes: 13 additions & 0 deletions tests/test_power_spectrum_and_bfactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@


def test_compute_power_spectrum():
"""
Test the computation of the power spectrum of a radially symmetric Gaussian volume.
Since the volume is radially symmetric, the power spectrum of the whole volume should be
approximately the power spectrum in a central slice. The computation is not exact as our
averaging over shells is approximated.
"""
box_size = 224
volume_shape = (box_size, box_size, box_size)
voxel_size = 1.073 * 2
Expand Down Expand Up @@ -37,6 +43,13 @@ def test_compute_power_spectrum():


def test_bfactor_normalize_volumes():
"""
Similarly to the other test, we test the normalization of a radially symmetric volume.
In this case we test with an oscillatory volume, which is a volume with a sinusoidal.
Since both the b-factor correction volume and the volume are radially symmetric, the
power spectrum of the normalized volume should be the same as the power spectrum of
a normalized central slice
"""
box_size = 128
volume_shape = (box_size, box_size, box_size)
voxel_size = 1.5
Expand Down

0 comments on commit 5c1b901

Please sign in to comment.