-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #82 from flatironinstitute/14-implement-power-spec…
…trum-normalization 14 implement power spectrum normalization
- Loading branch information
Showing
5 changed files
with
331 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import torch | ||
from ..power_spectrum_utils import _centered_fftn, _centered_ifftn | ||
|
||
|
||
def _compute_bfactor_scaling(b_factor, box_size, voxel_size): | ||
""" | ||
Compute the B-factor scaling factor for a given B-factor, box size, and voxel size. | ||
The B-factor scaling factor is computed as exp(-B * s^2 / 4), where s is the squared | ||
distance in Fourier space. | ||
Parameters | ||
---------- | ||
b_factor: float | ||
B-factor to apply. | ||
box_size: int | ||
Size of the box. | ||
voxel_size: float | ||
Voxel size of the box. | ||
Returns | ||
------- | ||
bfactor_scaling_torch: torch.tensor(shape=(box_size, box_size, box_size)) | ||
B-factor scaling factor. | ||
""" | ||
x = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) | ||
y = x.clone() | ||
z = x.clone() | ||
x, y, z = torch.meshgrid(x, y, z, indexing="ij") | ||
|
||
s2 = x**2 + y**2 + z**2 | ||
bfactor_scaling_torch = torch.exp(-b_factor * s2 / 4) | ||
|
||
return bfactor_scaling_torch | ||
|
||
|
||
def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=False): | ||
""" | ||
Normalize volumes by applying a B-factor correction. This is done by multiplying | ||
a centered Fourier transform of the volume by the B-factor scaling factor and then | ||
applying the inverse Fourier transform. See _compute_bfactor_scaling for details on the | ||
computation of the B-factor scaling. | ||
Parameters | ||
---------- | ||
volumes: torch.tensor | ||
Volumes to normalize. The volumes must have shape (N, N, N) or (n_volumes, N, N, N). | ||
bfactor: float | ||
B-factor to apply. | ||
voxel_size: float | ||
Voxel size of the volumes. | ||
in_place: bool - default: False | ||
Whether to normalize the volumes in place. | ||
Returns | ||
------- | ||
volumes: torch.tensor | ||
Normalized volumes. | ||
""" | ||
# assert that volumes have the correct shape | ||
assert volumes.ndim in [ | ||
3, | ||
4, | ||
], "Input volumes must have shape (N, N, N) or (n_volumes, N, N, N)" | ||
|
||
if volumes.ndim == 3: | ||
assert ( | ||
volumes.shape[0] == volumes.shape[1] == volumes.shape[2] | ||
), "Input volumes must have equal dimensions" | ||
else: | ||
assert ( | ||
volumes.shape[1] == volumes.shape[2] == volumes.shape[3] | ||
), "Input volumes must have equal dimensions" | ||
|
||
if not in_place: | ||
volumes = volumes.clone() | ||
|
||
b_factor_scaling = _compute_bfactor_scaling(bfactor, volumes.shape[-1], voxel_size) | ||
|
||
if len(volumes.shape) == 3: | ||
volumes = _centered_fftn(volumes, dim=(0, 1, 2)) | ||
volumes = volumes * b_factor_scaling | ||
volumes = _centered_ifftn(volumes, dim=(0, 1, 2)).real | ||
|
||
elif len(volumes.shape) == 4: | ||
volumes = _centered_fftn(volumes, dim=(1, 2, 3)) | ||
volumes = volumes * b_factor_scaling[None, ...] | ||
volumes = _centered_ifftn(volumes, dim=(1, 2, 3)).real | ||
|
||
return volumes |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import torch | ||
|
||
|
||
def _cart2sph(x, y, z): | ||
""" | ||
Converts a grid in cartesian coordinates to spherical coordinates. | ||
Parameters | ||
---------- | ||
x: torch.tensor | ||
x-coordinate of the grid. | ||
y: torch.tensor | ||
y-coordinate of the grid. | ||
z: torch.tensor | ||
""" | ||
hxy = torch.hypot(x, y) | ||
r = torch.hypot(hxy, z) | ||
el = torch.atan2(z, hxy) | ||
az = torch.atan2(y, x) | ||
return az, el, r | ||
|
||
|
||
def _grid_3d(n, dtype=torch.float32): | ||
""" | ||
Generates a centered 3D grid. The grid is given in both cartesian and spherical coordinates. | ||
Parameters | ||
---------- | ||
n: int | ||
Size of the grid. | ||
dtype: torch.dtype | ||
Data type of the grid. | ||
Returns | ||
------- | ||
grid: dict | ||
Dictionary containing the grid in cartesian and spherical coordinates. | ||
keys: x, y, z, phi, theta, r | ||
""" | ||
start = -n // 2 + 1 | ||
end = n // 2 | ||
|
||
if n % 2 == 0: | ||
start -= 1 / 2 | ||
end -= 1 / 2 | ||
|
||
grid = torch.linspace(start, end, n, dtype=dtype) | ||
z, x, y = torch.meshgrid(grid, grid, grid, indexing="ij") | ||
|
||
phi, theta, r = _cart2sph(x, y, z) | ||
|
||
theta = torch.pi / 2 - theta | ||
|
||
return {"x": x, "y": y, "z": z, "phi": phi, "theta": theta, "r": r} | ||
|
||
|
||
def _centered_fftn(x, dim=None): | ||
""" | ||
Wrapper around torch.fft.fftn that centers the Fourier transform. | ||
""" | ||
x = torch.fft.fftn(x, dim=dim) | ||
x = torch.fft.fftshift(x, dim=dim) | ||
return x | ||
|
||
|
||
def _centered_ifftn(x, dim=None): | ||
""" | ||
Wrapper around torch.fft.ifftn that centers the inverse Fourier transform. | ||
""" | ||
x = torch.fft.fftshift(x, dim=dim) | ||
x = torch.fft.ifftn(x, dim=dim) | ||
return x | ||
|
||
|
||
def _average_over_single_shell(shell_index, volume, radii, shell_width=0.5): | ||
""" | ||
Given a volume in Fourier space, compute the average value of the volume over a shell. | ||
Parameters | ||
---------- | ||
shell_index: int | ||
Index of the shell in Fourier space. | ||
volume: torch.tensor | ||
Volume in Fourier space. | ||
radii: torch.tensor | ||
Radii of the Fourier space grid. | ||
shell_width: float | ||
Width of the shell. | ||
Returns | ||
------- | ||
average: float | ||
Average value of the volume over the shell. | ||
""" | ||
inner_diameter = shell_width + shell_index | ||
outer_diameter = shell_width + (shell_index + 1) | ||
mask = (radii > inner_diameter) & (radii < outer_diameter) | ||
return torch.sum(mask * volume) / torch.sum(mask) | ||
|
||
|
||
def _average_over_shells(volume_in_fourier_space, shell_width=0.5): | ||
""" | ||
Vmap wrapper over _average_over_single_shell to compute the average value of a volume in Fourier space over all shells. The input should be a volumetric quantity in Fourier space. | ||
Parameters | ||
---------- | ||
volume_in_fourier_space: torch.tensor | ||
Volume in Fourier space. | ||
Returns | ||
------- | ||
radial_average: torch.tensor | ||
Average value of the volume over all shells. | ||
""" | ||
L = volume_in_fourier_space.shape[0] | ||
dtype = torch.float32 | ||
radii = _grid_3d(L, dtype=dtype)["r"] | ||
|
||
radial_average = torch.vmap( | ||
_average_over_single_shell, in_dims=(0, None, None, None) | ||
)(torch.arange(0, L // 2), volume_in_fourier_space, radii, shell_width) | ||
|
||
return radial_average | ||
|
||
|
||
def compute_power_spectrum(volume, shell_width=0.5): | ||
""" | ||
Compute the power spectrum of a volume. | ||
Parameters | ||
---------- | ||
volume: torch.tensor | ||
Volume for which to compute the power spectrum. | ||
shell_width: float | ||
Width of the shell. | ||
Returns | ||
------- | ||
power_spectrum: torch.tensor | ||
Power spectrum of the volume. | ||
Examples | ||
-------- | ||
volume = mrcfile.open("volume.mrc").data.copy() | ||
volume = torch.tensor(volume, dtype=torch.float32) | ||
power_spectrum = compute_power_spectrum(volume) | ||
""" | ||
|
||
# Compute centered Fourier transforms. | ||
vol_fft = torch.abs(_centered_fftn(volume)) ** 2 | ||
power_spectrum = _average_over_shells(vol_fft, shell_width=shell_width) | ||
|
||
return power_spectrum |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import torch | ||
from cryo_challenge.power_spectrum_utils import _centered_ifftn, compute_power_spectrum | ||
from cryo_challenge._preprocessing.bfactor_normalize import ( | ||
_compute_bfactor_scaling, | ||
bfactor_normalize_volumes, | ||
) | ||
|
||
|
||
def test_compute_power_spectrum(): | ||
""" | ||
Test the computation of the power spectrum of a radially symmetric Gaussian volume. | ||
Since the volume is radially symmetric, the power spectrum of the whole volume should be | ||
approximately the power spectrum in a central slice. The computation is not exact as our | ||
averaging over shells is approximated. | ||
""" | ||
box_size = 224 | ||
volume_shape = (box_size, box_size, box_size) | ||
voxel_size = 1.073 * 2 | ||
|
||
freq = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) | ||
x = freq.clone() | ||
y = freq.clone() | ||
z = freq.clone() | ||
x, y, z = torch.meshgrid(x, y, z, indexing="ij") | ||
|
||
s2 = x**2 + y**2 + z**2 | ||
|
||
b_factor = 170 | ||
|
||
gaussian_volume = torch.exp(-b_factor / 4 * s2).reshape(volume_shape) | ||
gaussian_volume = _centered_ifftn(gaussian_volume) | ||
|
||
power_spectrum = compute_power_spectrum(gaussian_volume) | ||
power_spectrum_slice = ( | ||
torch.abs(torch.fft.fftn(gaussian_volume)[: box_size // 2, 0, 0]) ** 2 | ||
) | ||
|
||
mean_squared_error = torch.mean((power_spectrum - power_spectrum_slice) ** 2) | ||
|
||
assert mean_squared_error < 1e-3 | ||
|
||
return | ||
|
||
|
||
def test_bfactor_normalize_volumes(): | ||
""" | ||
Similarly to the other test, we test the normalization of a radially symmetric volume. | ||
In this case we test with an oscillatory volume, which is a volume with a sinusoidal. | ||
Since both the b-factor correction volume and the volume are radially symmetric, the | ||
power spectrum of the normalized volume should be the same as the power spectrum of | ||
a normalized central slice | ||
""" | ||
box_size = 128 | ||
volume_shape = (box_size, box_size, box_size) | ||
voxel_size = 1.5 | ||
|
||
freq = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) | ||
x = freq.clone() | ||
y = freq.clone() | ||
z = freq.clone() | ||
x, y, z = torch.meshgrid(x, y, z, indexing="ij") | ||
|
||
s2 = x**2 + y**2 + z**2 | ||
|
||
oscillatory_volume = torch.sin(300 * s2).reshape(volume_shape) | ||
oscillatory_volume = _centered_ifftn(oscillatory_volume) | ||
bfactor_scaling_vol = _compute_bfactor_scaling(170, box_size, voxel_size) | ||
|
||
norm_oscillatory_vol = bfactor_normalize_volumes( | ||
oscillatory_volume, 170, voxel_size, in_place=False | ||
) | ||
|
||
ps_osci = torch.fft.fftn(oscillatory_volume, dim=(-3, -2, -1), norm="backward")[ | ||
: box_size // 2, 0, 0 | ||
] | ||
ps_norm_osci = torch.fft.fftn( | ||
norm_oscillatory_vol, dim=(-3, -2, -1), norm="backward" | ||
)[: box_size // 2, 0, 0] | ||
ps_bfactor_scaling = torch.fft.fftshift(bfactor_scaling_vol)[: box_size // 2, 0, 0] | ||
|
||
ps_osci = torch.abs(ps_osci) ** 2 | ||
ps_norm_osci = torch.abs(ps_norm_osci) ** 2 | ||
ps_bfactor_scaling = torch.abs(ps_bfactor_scaling) ** 2 | ||
|
||
assert torch.allclose(ps_norm_osci, ps_osci * ps_bfactor_scaling) |