Skip to content

Commit

Permalink
implement power spectrum and bfactor scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
DSilva27 committed Aug 12, 2024
1 parent fc99c96 commit 0ce2df1
Show file tree
Hide file tree
Showing 25 changed files with 500 additions and 214 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main_merge_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ jobs:
if: github.base_ref == 'main' && github.head_ref != 'dev'
run: |
echo "ERROR: You can only merge to main from dev."
exit 1
exit 1
2 changes: 1 addition & 1 deletion config_files/config_distribution_to_distribution.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ cvxpy_solver: ECOS
optimal_q_kl:
n_iter: 100000
break_atol: 0.0001
output_fname: results/distribution_to_distribution_submission_0.pkl
output_fname: results/distribution_to_distribution_submission_0.pkl
10 changes: 5 additions & 5 deletions config_files/config_map_to_map_distance_matrix.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
data:
n_pix: 224
psize: 2.146
psize: 2.146
submission:
fname: data/dataset_2_ground_truth/submission_0.pt
volume_key: volumes
metadata_key: populations
label_key: id
ground_truth:
volumes: data/dataset_2_ground_truth/maps_gt_flat.pt
metadata: data/dataset_2_ground_truth/metadata.csv
mask:
volumes: data/dataset_2_ground_truth/maps_gt_flat.pt
metadata: data/dataset_2_ground_truth/metadata.csv
mask:
do: true
volume: data/dataset_2_ground_truth/mask_dilated_wide_224x224.mrc
analysis:
Expand All @@ -23,4 +23,4 @@ analysis:
normalize:
do: true
method: median_zscore
output: results/map_to_map_distance_matrix_submission_0.pkl
output: results/map_to_map_distance_matrix_submission_0.pkl
4 changes: 3 additions & 1 deletion src/cryo_challenge/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from cryo_challenge.__about__ import __version__
from cryo_challenge.__about__ import __version__

__all__ = ["__version__"]
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import numpy as np
import pickle
from scipy.stats import rankdata
import yaml
import argparse
import torch
import ot

Expand All @@ -14,10 +12,12 @@


def sort_by_transport(cost):
m,n = cost.shape
_, transport = compute_wasserstein_between_distributions_from_weights_and_cost(np.ones(m) / m, np.ones(n)/n, cost)
indices = np.argsort((transport * np.arange(m)[...,None]).sum(0))
return cost[:,indices], indices, transport
m, n = cost.shape
_, transport = compute_wasserstein_between_distributions_from_weights_and_cost(
np.ones(m) / m, np.ones(n) / n, cost
)
indices = np.argsort((transport * np.arange(m)[..., None]).sum(0))
return cost[:, indices], indices, transport


def compute_wasserstein_between_distributions_from_weights_and_cost(
Expand Down Expand Up @@ -65,15 +65,14 @@ def make_assignment_matrix(cost_matrix):


def run(config):

metadata_df = pd.read_csv(config["gt_metadata_fname"])
metadata_df.sort_values("pc1", inplace=True)

with open(config["input_fname"], "rb") as f:
data = pickle.load(f)

# user_submitted_populations = np.ones(80)/80
user_submitted_populations = data["user_submitted_populations"]#.numpy()
user_submitted_populations = data["user_submitted_populations"] # .numpy()
id = torch.load(data["config"]["data"]["submission"]["fname"])["id"]

results_dict = {}
Expand Down Expand Up @@ -213,5 +212,5 @@ def optimal_q_kl(n_iter, x_start, A, Window, prob_gt, break_atol):
DistributionToDistributionResultsValidator.from_dict(results_dict)
with open(config["output_fname"], "wb") as f:
pickle.dump(results_dict, f)

return results_dict
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def run(config):
user_submission_label = submission[label_key]

# n_trunc = 10
metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"])#[:n_trunc]
metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"]) # [:n_trunc]

results_dict = {}
results_dict["config"] = config
Expand Down
5 changes: 3 additions & 2 deletions src/cryo_challenge/_ploting/plotting_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np


def res_at_fsc_threshold(fscs, threshold=0.5):
res_fsc_half = np.argmin(fscs > threshold, axis=-1)
fraction_nyquist = 0.5*res_fsc_half / fscs.shape[-1]
return res_fsc_half, fraction_nyquist
fraction_nyquist = 0.5 * res_fsc_half / fscs.shape[-1]
return res_fsc_half, fraction_nyquist
36 changes: 36 additions & 0 deletions src/cryo_challenge/_preprocessing/bfactor_normalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch
from ..power_spectrum_utils import _centered_fftn, _centered_ifftn


def _compute_bfactor_scaling(b_factor, box_size, voxel_size):
x = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size))
y = x.clone()
z = x.clone()
x, y, z = torch.meshgrid(x, y, z, indexing="ij")

s2 = x**2 + y**2 + z**2
bfactor_scaling_torch = torch.exp(-b_factor * s2 / 4)

return bfactor_scaling_torch


def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=True):
if not in_place:
volumes = volumes.clone()

b_factor_scaling = _compute_bfactor_scaling(bfactor, volumes.shape[-1], voxel_size)

if len(volumes.shape) == 3:
volumes = _centered_fftn(volumes, dim=(0, 1, 2))
volumes = volumes * b_factor_scaling
volumes = _centered_ifftn(volumes, dim=(0, 1, 2)).real

elif len(volumes.shape) == 4:
volumes = _centered_fftn(volumes, dim=(1, 2, 3))
volumes = volumes * b_factor_scaling[None, ...]
volumes = _centered_ifftn(volumes, dim=(1, 2, 3)).real

else:
raise ValueError("Input volumes must have 3 or 4 dimensions.")

return volumes
22 changes: 0 additions & 22 deletions src/cryo_challenge/_preprocessing/normalize.py

This file was deleted.

24 changes: 18 additions & 6 deletions src/cryo_challenge/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
from ._validation.config_validators import validate_input_config_disttodist as validate_input_config_disttodist
from ._validation.config_validators import validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl
from cryo_challenge.data._validation.output_validators import DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator
from cryo_challenge.data._validation.output_validators import MetricDistToDistValidator as MetricDistToDistValidator
from cryo_challenge.data._validation.output_validators import ReplicateValidatorEMD as ReplicateValidatorEMD
from cryo_challenge.data._validation.output_validators import ReplicateValidatorKL as ReplicateValidatorKL
from ._validation.config_validators import (
validate_input_config_disttodist as validate_input_config_disttodist,
)
from ._validation.config_validators import (
validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl,
)
from cryo_challenge.data._validation.output_validators import (
DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator,
)
from cryo_challenge.data._validation.output_validators import (
MetricDistToDistValidator as MetricDistToDistValidator,
)
from cryo_challenge.data._validation.output_validators import (
ReplicateValidatorEMD as ReplicateValidatorEMD,
)
from cryo_challenge.data._validation.output_validators import (
ReplicateValidatorKL as ReplicateValidatorKL,
)
10 changes: 6 additions & 4 deletions src/cryo_challenge/data/_io/svd_io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,16 @@ def load_ref_vols(box_size_ds: int, path_to_volumes: str, dtype=torch.float32):

# Reshape volumes to correct size
if volumes.dim() == 2:
box_size = int(round((float(volumes.shape[-1]) ** (1. / 3.))))
box_size = int(round((float(volumes.shape[-1]) ** (1.0 / 3.0))))
volumes = torch.reshape(volumes, (-1, box_size, box_size, box_size))
elif volumes.dim() == 4:
pass
else:
raise ValueError(f"The shape of the volumes stored in {path_to_volumes} have the unexpected shape "
f"{torch.shape}. Please, review the file and regenerate it so that volumes stored hasve the "
f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size).")
raise ValueError(
f"The shape of the volumes stored in {path_to_volumes} have the unexpected shape "
f"{torch.shape}. Please, review the file and regenerate it so that volumes stored hasve the "
f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size)."
)

volumes_ds = torch.empty(
(volumes.shape[0], box_size_ds, box_size_ds, box_size_ds), dtype=dtype
Expand Down
21 changes: 13 additions & 8 deletions src/cryo_challenge/data/_validation/output_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
@dataclass_json
@dataclass
class MapToMapResultsValidator:
'''
"""
Validate the output dictionary of the map-to-map distance matrix computation.
config: dict, input config dictionary.
Expand All @@ -22,7 +22,8 @@ class MapToMapResultsValidator:
l2: dict, L2 results.
bioem: dict, BioEM results.
fsc: dict, FSC results.
'''
"""

config: dict
user_submitted_populations: torch.Tensor
corr: Optional[dict] = None
Expand All @@ -49,7 +50,7 @@ class ReplicateValidatorEMD:
Validate the output dictionary of one EMD in the the distribution-to-distribution pipeline.
q_opt: List[float], optimal user submitted distribution, which sums to 1.
EMD_opt: float, EMD between the ground truth distribution (p) and the (optimized) user submitted distribution (q_opt).
EMD_opt: float, EMD between the ground truth distribution (p) and the (optimized) user submitted distribution (q_opt).
The transport plan is a joint distribution, such that:
summing over the rows gives the (optimized) user submitted distribution, and summing over the columns gives the ground truth distribution.
transport_plan_opt: List[List[float]], transport plan between the ground truth distribution (p, rows) and the (optimized) user submitted distribution (q_opt, columns).
Expand All @@ -61,6 +62,7 @@ class ReplicateValidatorEMD:
The transport plan is a joint distribution, such that:
summing over the rows gives the user submitted distribution, and summing over the columns gives the ground truth distribution.
"""

q_opt: List[float]
EMD_opt: float
transport_plan_opt: List[List[float]]
Expand All @@ -87,8 +89,9 @@ class ReplicateValidatorKL:
iter_stop: int, number of iterations until convergence.
eps_stop: float, stopping criterion.
klpq_submitted: float, KL divergence between the ground truth distribution (p) and the user submitted distribution (q).
klqp_submitted: float, KL divergence between the user submitted distribution (q) and the ground truth distribution (p).
klqp_submitted: float, KL divergence between the user submitted distribution (q) and the ground truth distribution (p).
"""

q_opt: List[float]
klpq_opt: float
klqp_opt: float
Expand All @@ -106,11 +109,12 @@ def __post_init__(self):
@dataclass_json
@dataclass
class MetricDistToDistValidator:
'''
"""
Validate the output dictionary of one map to map metric in the the distribution-to-distribution pipeline.
replicates: dict, dictionary of replicates.
'''
"""

replicates: dict

def validate_replicates(self, n_replicates):
Expand All @@ -126,7 +130,7 @@ def validate_replicates(self, n_replicates):
@dataclass_json
@dataclass
class DistributionToDistributionResultsValidator:
'''
"""
Validate the output dictionary of the distribution-to-distribution pipeline.
config: dict, input config dictionary.
Expand All @@ -136,7 +140,8 @@ class DistributionToDistributionResultsValidator:
bioem: dict, BioEM distance results.
l2: dict, L2 distance results.
corr: dict, correlation distance results.
'''
"""

config: dict
user_submitted_populations: torch.Tensor
id: str
Expand Down
71 changes: 71 additions & 0 deletions src/cryo_challenge/power_spectrum_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import torch


def _cart2sph(x, y, z):
"""
Converts a grid in cartesian coordinates to spherical coordinates.
Parameters
----------
x: torch.tensor
x-coordinate of the grid.
y: torch.tensor
y-coordinate of the grid.
z: torch.tensor
"""
hxy = torch.hypot(x, y)
r = torch.hypot(hxy, z)
el = torch.atan2(z, hxy)
az = torch.atan2(y, x)
return az, el, r


def _grid_3d(n, dtype=torch.float32):
start = -n // 2 + 1
end = n // 2

if n % 2 == 0:
start -= 1 / 2
end -= 1 / 2

grid = torch.linspace(start, end, n, dtype=dtype)
z, x, y = torch.meshgrid(grid, grid, grid, indexing="ij")

phi, theta, r = _cart2sph(x, y, z)

theta = torch.pi / 2 - theta

return {"x": x, "y": y, "z": z, "phi": phi, "theta": theta, "r": r}


def _centered_fftn(x, dim=None):
x = torch.fft.fftn(x, dim=dim)
x = torch.fft.fftshift(x, dim=dim)
return x


def _centered_ifftn(x, dim=None):
x = torch.fft.fftshift(x, dim=dim)
x = torch.fft.ifftn(x, dim=dim)
return x


def _compute_power_spectrum_shell(index, volume, radii, shell_width=0.5):
inner_diameter = shell_width + index
outer_diameter = shell_width + (index + 1)
mask = (radii > inner_diameter) & (radii < outer_diameter)
return torch.sum(mask * volume) / torch.sum(mask)


def compute_power_spectrum(volume, shell_width=0.5):
L = volume.shape[0]
dtype = torch.float32
radii = _grid_3d(L, dtype=dtype)["r"]

# Compute centered Fourier transforms.
vol_fft = torch.abs(_centered_fftn(volume)) ** 2

power_spectrum = torch.vmap(
_compute_power_spectrum_shell, in_dims=(0, None, None, None)
)(torch.arange(0, L // 2), vol_fft, radii, shell_width)
return power_spectrum
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"0": {
"name": "raw_submission_in_testdata",
"align": 1,
"flavor_name": "test flavor",
"flavor_name": "test flavor",
"box_size": 244,
"pixel_size": 2.146,
"path": "tests/data/unprocessed_dataset_2_submissions/submission_x",
Expand Down
8 changes: 5 additions & 3 deletions tests/test_distribution_to_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from cryo_challenge._commands import run_distribution2distribution_pipeline


def test_run_distribution2distribution_pipeline():
args = OmegaConf.create({'config': 'tests/config_files/test_config_distribution_to_distribution.yaml'})
run_distribution2distribution_pipeline.main(args)
def test_run_distribution2distribution_pipeline():
args = OmegaConf.create(
{"config": "tests/config_files/test_config_distribution_to_distribution.yaml"}
)
run_distribution2distribution_pipeline.main(args)
Loading

0 comments on commit 0ce2df1

Please sign in to comment.