-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
implement power spectrum and bfactor scaling
- Loading branch information
DSilva27
committed
Aug 12, 2024
1 parent
fc99c96
commit 0ce2df1
Showing
25 changed files
with
500 additions
and
214 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
from cryo_challenge.__about__ import __version__ | ||
from cryo_challenge.__about__ import __version__ | ||
|
||
__all__ = ["__version__"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
import numpy as np | ||
|
||
|
||
def res_at_fsc_threshold(fscs, threshold=0.5): | ||
res_fsc_half = np.argmin(fscs > threshold, axis=-1) | ||
fraction_nyquist = 0.5*res_fsc_half / fscs.shape[-1] | ||
return res_fsc_half, fraction_nyquist | ||
fraction_nyquist = 0.5 * res_fsc_half / fscs.shape[-1] | ||
return res_fsc_half, fraction_nyquist |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import torch | ||
from ..power_spectrum_utils import _centered_fftn, _centered_ifftn | ||
|
||
|
||
def _compute_bfactor_scaling(b_factor, box_size, voxel_size): | ||
x = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) | ||
y = x.clone() | ||
z = x.clone() | ||
x, y, z = torch.meshgrid(x, y, z, indexing="ij") | ||
|
||
s2 = x**2 + y**2 + z**2 | ||
bfactor_scaling_torch = torch.exp(-b_factor * s2 / 4) | ||
|
||
return bfactor_scaling_torch | ||
|
||
|
||
def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=True): | ||
if not in_place: | ||
volumes = volumes.clone() | ||
|
||
b_factor_scaling = _compute_bfactor_scaling(bfactor, volumes.shape[-1], voxel_size) | ||
|
||
if len(volumes.shape) == 3: | ||
volumes = _centered_fftn(volumes, dim=(0, 1, 2)) | ||
volumes = volumes * b_factor_scaling | ||
volumes = _centered_ifftn(volumes, dim=(0, 1, 2)).real | ||
|
||
elif len(volumes.shape) == 4: | ||
volumes = _centered_fftn(volumes, dim=(1, 2, 3)) | ||
volumes = volumes * b_factor_scaling[None, ...] | ||
volumes = _centered_ifftn(volumes, dim=(1, 2, 3)).real | ||
|
||
else: | ||
raise ValueError("Input volumes must have 3 or 4 dimensions.") | ||
|
||
return volumes |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,18 @@ | ||
from ._validation.config_validators import validate_input_config_disttodist as validate_input_config_disttodist | ||
from ._validation.config_validators import validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl | ||
from cryo_challenge.data._validation.output_validators import DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator | ||
from cryo_challenge.data._validation.output_validators import MetricDistToDistValidator as MetricDistToDistValidator | ||
from cryo_challenge.data._validation.output_validators import ReplicateValidatorEMD as ReplicateValidatorEMD | ||
from cryo_challenge.data._validation.output_validators import ReplicateValidatorKL as ReplicateValidatorKL | ||
from ._validation.config_validators import ( | ||
validate_input_config_disttodist as validate_input_config_disttodist, | ||
) | ||
from ._validation.config_validators import ( | ||
validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl, | ||
) | ||
from cryo_challenge.data._validation.output_validators import ( | ||
DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator, | ||
) | ||
from cryo_challenge.data._validation.output_validators import ( | ||
MetricDistToDistValidator as MetricDistToDistValidator, | ||
) | ||
from cryo_challenge.data._validation.output_validators import ( | ||
ReplicateValidatorEMD as ReplicateValidatorEMD, | ||
) | ||
from cryo_challenge.data._validation.output_validators import ( | ||
ReplicateValidatorKL as ReplicateValidatorKL, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import torch | ||
|
||
|
||
def _cart2sph(x, y, z): | ||
""" | ||
Converts a grid in cartesian coordinates to spherical coordinates. | ||
Parameters | ||
---------- | ||
x: torch.tensor | ||
x-coordinate of the grid. | ||
y: torch.tensor | ||
y-coordinate of the grid. | ||
z: torch.tensor | ||
""" | ||
hxy = torch.hypot(x, y) | ||
r = torch.hypot(hxy, z) | ||
el = torch.atan2(z, hxy) | ||
az = torch.atan2(y, x) | ||
return az, el, r | ||
|
||
|
||
def _grid_3d(n, dtype=torch.float32): | ||
start = -n // 2 + 1 | ||
end = n // 2 | ||
|
||
if n % 2 == 0: | ||
start -= 1 / 2 | ||
end -= 1 / 2 | ||
|
||
grid = torch.linspace(start, end, n, dtype=dtype) | ||
z, x, y = torch.meshgrid(grid, grid, grid, indexing="ij") | ||
|
||
phi, theta, r = _cart2sph(x, y, z) | ||
|
||
theta = torch.pi / 2 - theta | ||
|
||
return {"x": x, "y": y, "z": z, "phi": phi, "theta": theta, "r": r} | ||
|
||
|
||
def _centered_fftn(x, dim=None): | ||
x = torch.fft.fftn(x, dim=dim) | ||
x = torch.fft.fftshift(x, dim=dim) | ||
return x | ||
|
||
|
||
def _centered_ifftn(x, dim=None): | ||
x = torch.fft.fftshift(x, dim=dim) | ||
x = torch.fft.ifftn(x, dim=dim) | ||
return x | ||
|
||
|
||
def _compute_power_spectrum_shell(index, volume, radii, shell_width=0.5): | ||
inner_diameter = shell_width + index | ||
outer_diameter = shell_width + (index + 1) | ||
mask = (radii > inner_diameter) & (radii < outer_diameter) | ||
return torch.sum(mask * volume) / torch.sum(mask) | ||
|
||
|
||
def compute_power_spectrum(volume, shell_width=0.5): | ||
L = volume.shape[0] | ||
dtype = torch.float32 | ||
radii = _grid_3d(L, dtype=dtype)["r"] | ||
|
||
# Compute centered Fourier transforms. | ||
vol_fft = torch.abs(_centered_fftn(volume)) ** 2 | ||
|
||
power_spectrum = torch.vmap( | ||
_compute_power_spectrum_shell, in_dims=(0, None, None, None) | ||
)(torch.arange(0, L // 2), vol_fft, radii, shell_width) | ||
return power_spectrum |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.