Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement power spectrum normalization and include it in SVD pipeline closes #14 #15

Closed
wants to merge 26 commits into from
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b600c62
implement power spectrum normalization
May 30, 2024
6085f13
implement power spectrum normalization
May 30, 2024
8de649d
Merge branch 'main' into iss14
May 30, 2024
d07b45e
implement power spectrum normalization in svd pipeline
May 31, 2024
cb17a94
Merge branch 'main' into iss14
May 31, 2024
efdd80e
implement power spectrum normalization in svd pipeline
May 31, 2024
70105c7
made small change in how svd output files are handled
May 31, 2024
12330f4
made small change in how svd output files are handled
May 31, 2024
8822ade
make preprocessing pipeline request explicitly the file for the popul…
May 31, 2024
7cdd515
fix bug on power spectrum normalization
Jun 27, 2024
f489f14
update template for config and tutorial
DSilva27 Jun 27, 2024
8d0ddaa
Update .pre-commit-config.yaml
DSilva27 Jul 11, 2024
efed76e
Merge pull request #38 from flatironinstitute/DSilva27-patch-1
geoffwoollard Jul 11, 2024
b5b4b05
merge main to iss14 to include tests
Jul 11, 2024
6719dda
merge main to iss14 to include tests
Jul 11, 2024
bf28c54
update .gitignore so it does not include testing data
Jul 11, 2024
401cea8
update .gitignore so it does not include testing data
Jul 11, 2024
761f25a
update .gitignore so it does not include testing data
Jul 11, 2024
7ee6824
update .gitignore so it does not include testing data
Jul 11, 2024
5b00f48
remove hard coded path to populations.txt
Jul 11, 2024
4d57abf
turned on validator in the dataloader class
Jul 11, 2024
705445d
turned on validator in the dataloader class
Jul 11, 2024
5b07377
remove prints for debugging
Jul 11, 2024
22504f5
merge from main to include testing
Jul 11, 2024
fb8cf53
update tutorial with examples of figures
Jul 12, 2024
bd481d5
changed name of figures
Jul 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/cryo_challenge/_commands/run_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def main(args):
config = yaml.safe_load(file)

validate_config_svd(config)
warnexists(config["output_options"]["output_path"])
mkbasedir(config["output_options"]["output_path"])
warnexists(config["output_options"]["output_file"])
mkbasedir(os.path.dirname(config["output_options"]["output_file"]))

if config["experiment_mode"] == "all_vs_all":
run_all_vs_all_pipeline(config)
Expand Down
4 changes: 4 additions & 0 deletions src/cryo_challenge/_preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@
downsample_volume as downsample_volume,
downsample_submission as downsample_submission,
)
from .normalize import (
compute_power_spectrum as compute_power_spectrum,
normalize_power_spectrum as normalize_power_spectrum,
)
6 changes: 4 additions & 2 deletions src/cryo_challenge/_preprocessing/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def validate_submission_config(self):
raise ValueError(f"Pixel size not found for submission {key}")
if "align" not in value.keys():
raise ValueError(f"Align not found for submission {key}")
if "populations_file" not in value.keys():
raise ValueError(f"Population file not found for submission {key}")

if not os.path.exists(value["path"]):
raise ValueError(f"Path {value['path']} does not exist")
Expand Down Expand Up @@ -154,8 +156,8 @@ def __getitem__(self, idx):

assert len(vol_paths) > 0, "No volumes found in submission directory"

populations = np.loadtxt(
os.path.join(self.submission_paths[idx], "populations.txt")
populations = np.loadtxt(self.submission_config["populations_file"]).astype(
float
)
populations = torch.from_numpy(populations)

Expand Down
115 changes: 99 additions & 16 deletions src/cryo_challenge/_preprocessing/normalize.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,105 @@
"""
Power spectrum normalization and required utility functions.
"""

def normalize_mean_std(vols_flat):
"""
vols_flat.shape is (n_vols, n_pix**3)
vols_flat is a torch tensor
"""
return (vols_flat - vols_flat.mean(-1, keepdims=True)) / vols_flat.std(
-1, keepdims=True
)
import torch


def normalize_median_std(vols_flat):
def _cart2sph(x, y, z):
"""
vols_flat.shape is (n_vols, n_pix**3)
vols_flat is a torch tensor
Converts a grid in cartesian coordinates to spherical coordinates.

Parameters
----------
x: torch.tensor
x-coordinate of the grid.
y: torch.tensor
y-coordinate of the grid.
z: torch.tensor
"""
return (vols_flat - vols_flat.median(-1, keepdims=True).values) / vols_flat.std(
-1, keepdims=True
hxy = torch.hypot(x, y)
r = torch.hypot(hxy, z)
el = torch.atan2(z, hxy)
az = torch.atan2(y, x)
return az, el, r


def _grid_3d(n, dtype=torch.float32):
    """
    Build an n x n x n centered coordinate grid and its spherical counterpart.

    Axis values are symmetric about zero: for odd n they run over the
    integers -(n//2) .. n//2, for even n over the half-integers
    -(n//2 - 0.5) .. (n//2 - 0.5).

    Parameters
    ----------
    n: int
        Number of grid points per axis.
    dtype: torch.dtype
        dtype of the returned coordinate tensors.

    Returns
    -------
    dict
        Cartesian coordinates under "x", "y", "z" and spherical
        coordinates under "phi" (azimuth), "theta" (polar angle from the
        +z axis), and "r" (radius).
    """
    half = n // 2
    if n % 2 == 0:
        start, end = 0.5 - half, half - 0.5
    else:
        start, end = -half, half

    axis = torch.linspace(start, end, n, dtype=dtype)
    # NOTE: the (z, x, y) unpacking order is deliberate and matches the
    # original axis convention of this module.
    z, x, y = torch.meshgrid(axis, axis, axis, indexing="ij")

    phi, theta, r = _cart2sph(x, y, z)

    # _cart2sph returns elevation from the x-y plane; convert it to the
    # polar angle measured from the +z axis.
    theta = torch.pi / 2 - theta

    return {"x": x, "y": y, "z": z, "phi": phi, "theta": theta, "r": r}


def _centered_fftn(x, dim=None):
x = torch.fft.fftn(x, dim=dim)
x = torch.fft.fftshift(x, dim=dim)
return x


def _centered_ifftn(x, dim=None):
x = torch.fft.fftshift(x, dim=dim)
x = torch.fft.ifftn(x, dim=dim)
return x


def _compute_power_spectrum_shell(index, volume, radii):
inner_diameter = 0.5 + index
outer_diameter = 0.5 + (index + 1)
mask = (radii > inner_diameter) & (radii < outer_diameter)
return torch.norm(mask * volume) ** 2


def compute_power_spectrum(volume):
    """
    Compute the radial power spectrum of a cubic volume.

    The volume is Fourier-transformed (centered), and the squared Fourier
    magnitudes are summed over concentric spherical shells of unit width.

    Parameters
    ----------
    volume: torch.tensor
        Tensor of shape (L, L, L).

    Returns
    -------
    torch.tensor
        Squared Fourier shell norms, one value per shell, shape (L // 2,).
    """
    # NOTE(review): the scraped diff contained a stray `'''` inside this
    # function body (deleted-code residue) that broke the module; removed.
    L = volume.shape[0]
    radii = _grid_3d(L, dtype=torch.float32)["r"]

    # Center the Fourier transform so it aligns with the centered radial grid.
    vol_fft = _centered_fftn(volume)

    power_spectrum = torch.vmap(
        _compute_power_spectrum_shell, in_dims=(0, None, None)
    )(torch.arange(0, L // 2), vol_fft, radii)
    return power_spectrum


def normalize_power_spectrum(volumes, power_spectrum_ref):
    """
    Rescale each volume's Fourier shells to match a reference power spectrum.

    Parameters
    ----------
    volumes: torch.tensor
        Tensor of shape (n_vols, L, L, L).
    power_spectrum_ref: torch.tensor
        Reference power spectrum (squared shell norms, as returned by
        compute_power_spectrum), one value per shell, length L // 2.

    Returns
    -------
    torch.tensor
        Real-valued normalized volumes, shape (n_vols, L, L, L).
    """
    L = volumes.shape[-1]
    n_vols = volumes.shape[0]
    radii = _grid_3d(L, dtype=torch.float32)["r"]

    # Compute centered Fourier transforms over the spatial dims.
    vols_fft = _centered_fftn(volumes, dim=(1, 2, 3))

    inner_radius = 0.5
    for i in range(0, L // 2):
        # Unit-width spherical shell mask (boundaries excluded).
        outer_radius = 0.5 + (i + 1)
        shell_mask = (radii > inner_radius) & (radii < outer_radius)

        # Per-volume Fourier norm over the current shell.
        shell_norm = torch.norm(
            (shell_mask[None, ...] * vols_fft).reshape(n_vols, -1), dim=1
        )
        # compute_power_spectrum returns *squared* shell norms, so the target
        # shell amplitude is sqrt(power_spectrum_ref[i]). (The previous code
        # multiplied by power_spectrum_ref[i] directly, which set the shell
        # norm to the squared reference value.) The epsilon guards empty or
        # all-zero shells.
        vols_fft[:, shell_mask] = (
            vols_fft[:, shell_mask]
            / (shell_norm[:, None] + 1e-5)
            * torch.sqrt(power_spectrum_ref[i])
        )

        # Advance to the next shell.
        inner_radius = outer_radius

    return _centered_ifftn(vols_fft, dim=(1, 2, 3)).real
56 changes: 42 additions & 14 deletions src/cryo_challenge/_svd/svd_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
from typing import Tuple
import yaml
import argparse
import os

from .svd_utils import get_vols_svd, project_vols_to_svd
from ..data._io.svd_io_utils import load_volumes, load_ref_vols
from ..data._io.svd_io_utils import (
load_volumes,
load_ref_vols,
remove_mean_volumes,
normalize_power_spectrum_sub,
)
from ..data._validation.config_validators import validate_config_svd
from .._preprocessing.normalize import compute_power_spectrum, normalize_power_spectrum


def run_svd_with_ref(
Expand Down Expand Up @@ -115,19 +120,29 @@ def run_all_vs_all_pipeline(config: dict):
""" # noqa: E501

dtype = torch.float32 if config["dtype"] == "float32" else torch.float64
volumes, mean_volumes, metadata = load_volumes(
volumes, metadata = load_volumes(
box_size_ds=config["box_size_ds"],
submission_list=config["submission_list"],
path_to_submissions=config["path_to_volumes"],
dtype=dtype,
)

volumes = normalize_power_spectrum_sub(
volumes,
metadata,
config["ref_vol_key"],
config["ref_vol_index"],
)

volumes, mean_volumes = remove_mean_volumes(volumes, metadata)

U, S, V, coeffs = run_svd_all_vs_all(volumes=volumes)

output_dict = {
"coeffs": coeffs,
"metadata": metadata,
"config": config,
"sing_vals": S,
}

if config["output_options"]["save_volumes"]:
Expand All @@ -136,13 +151,10 @@ def run_all_vs_all_pipeline(config: dict):

if config["output_options"]["save_svd_matrices"]:
output_dict["U"] = U
output_dict["S"] = S
output_dict["V"] = V
output_dict["S"] = S

output_file = os.path.join(
config["output_options"]["output_path"], "svd_results.pt"
)
torch.save(output_dict, output_file)
torch.save(output_dict, config["output_options"]["output_file"])

return output_dict

Expand Down Expand Up @@ -178,19 +190,37 @@ def run_all_vs_ref_pipeline(config: dict):

dtype = torch.float32 if config["dtype"] == "float32" else torch.float64

ref_volumes, mean_volume = load_ref_vols(
ref_volumes = load_ref_vols(
box_size_ds=config["box_size_ds"],
path_to_volumes=config["path_to_reference"],
dtype=dtype,
)

volumes, mean_volumes, metadata = load_volumes(
volumes, metadata = load_volumes(
box_size_ds=config["box_size_ds"],
submission_list=config["submission_list"],
path_to_submissions=config["path_to_volumes"],
dtype=dtype,
)

# Normalize Power spectrums
idx_ref_vol = (
metadata[config["ref_vol_key"]]["indices"][0] + config["ref_vol_index"]
)
ref_power_spectrum = compute_power_spectrum(volumes[idx_ref_vol])
ref_volumes = normalize_power_spectrum(ref_volumes, ref_power_spectrum)

volumes = normalize_power_spectrum_sub(
volumes,
metadata,
config["ref_vol_key"],
config["ref_vol_index"],
)

# Remove mean volumes
volumes, mean_volumes = remove_mean_volumes(volumes, metadata)
ref_volumes, mean_volume = remove_mean_volumes(ref_volumes)

U, S, V, coeffs, coeffs_ref = run_svd_with_ref(
volumes=volumes, ref_volumes=ref_volumes
)
Expand All @@ -200,6 +230,7 @@ def run_all_vs_ref_pipeline(config: dict):
"coeffs_ref": coeffs_ref,
"metadata": metadata,
"config": config,
"sing_vals": S,
}

if config["output_options"]["save_volumes"]:
Expand All @@ -213,10 +244,7 @@ def run_all_vs_ref_pipeline(config: dict):
output_dict["S"] = S
output_dict["V"] = V

output_file = os.path.join(
config["output_options"]["output_path"], "svd_results.pt"
)
torch.save(output_dict, output_file)
torch.save(output_dict, config["output_options"]["output_file"])

return output_dict

Expand Down
66 changes: 52 additions & 14 deletions src/cryo_challenge/data/_io/svd_io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,49 @@
from typing import Tuple

from ..._preprocessing.fourier_utils import downsample_volume
from ..._preprocessing.normalize import compute_power_spectrum, normalize_power_spectrum


def _remove_mean_volumes_sub(volumes, metadata):
box_size = volumes.shape[-1]
n_subs = len(list(metadata.keys()))
mean_volumes = torch.zeros((n_subs, box_size, box_size, box_size))

for i, key in enumerate(metadata.keys()):
indices = metadata[key]["indices"]

mean_volumes[i] = torch.mean(volumes[indices[0] : indices[1]], dim=0)
volumes[indices[0] : indices[1]] = (
volumes[indices[0] : indices[1]] - mean_volumes[i][None, ...]
)

return volumes, mean_volumes


def remove_mean_volumes(volumes, metadata=None):
    """
    Return mean-centered copies of the volumes.

    With `metadata`, each submission's slice is centered on its own mean
    and the means are stacked into one tensor of shape (n_subs, ...).
    Without it, the whole stack is centered on a single mean volume of
    shape (box, box, box). The input tensor is never modified.
    """
    centered = volumes.clone()

    if metadata is not None:
        return _remove_mean_volumes_sub(centered, metadata)

    mean_volume = centered.mean(dim=0)
    centered = centered - mean_volume[None, ...]
    return centered, mean_volume


def normalize_power_spectrum_sub(volumes, metadata, ref_vol_key, ref_vol_index):
    """
    Normalize every submission's volumes to one reference volume's power spectrum.

    The reference is volume `ref_vol_index` within submission `ref_vol_key`
    (located via that submission's "indices" entry in `metadata`). The input
    tensor is left unmodified; a normalized copy is returned.
    """
    normalized = volumes.clone()

    ref_idx = metadata[ref_vol_key]["indices"][0] + ref_vol_index
    target_spectrum = compute_power_spectrum(normalized[ref_idx])

    for entry in metadata.values():
        start, end = entry["indices"]
        normalized[start:end] = normalize_power_spectrum(
            normalized[start:end], target_spectrum
        )

    return normalized


def load_volumes(
Expand All @@ -28,10 +71,10 @@ def load_volumes(
-------
volumes: torch.tensor
Tensor of shape (n_volumes, n_x, n_y, n_z) containing the volumes.
populations: dict
Dictionary containing the populations of each submission.
vols_per_submission: dict
Dictionary containing the number of volumes per submission.
metadata: dict
Dictionary containing the metadata for each submission.
The keys are the id (ice cream name) of each submission.
The values are dictionaries containing the number of volumes, the populations, and the indices of the volumes in the volumes tensor.

Examples
--------
Expand All @@ -43,12 +86,10 @@ def load_volumes(

metadata = {}
volumes = torch.empty((0, box_size_ds, box_size_ds, box_size_ds), dtype=dtype)
mean_volumes = torch.empty(
(len(submission_list), box_size_ds, box_size_ds, box_size_ds), dtype=dtype
)

counter = 0

for i, idx in enumerate(submission_list):
for idx in submission_list:
submission = torch.load(f"{path_to_submissions}/submission_{idx}.pt")
vols = submission["volumes"]
pops = submission["populations"]
Expand All @@ -68,11 +109,9 @@ def load_volumes(
"indices": (counter_start, counter),
}

mean_volumes[i] = vols_tmp.mean(dim=0)
vols_tmp = vols_tmp - mean_volumes[i][None, :, :, :]
volumes = torch.cat((volumes, vols_tmp), dim=0)

return volumes, mean_volumes, metadata
return volumes, metadata


def load_ref_vols(box_size_ds: int, path_to_volumes: str, dtype=torch.float32):
Expand Down Expand Up @@ -122,7 +161,6 @@ def load_ref_vols(box_size_ds: int, path_to_volumes: str, dtype=torch.float32):
volumes_ds[i] = downsample_volume(vol, box_size_ds)
volumes_ds[i] = volumes_ds[i] / volumes_ds[i].sum()

mean_volume = volumes_ds.mean(dim=0)
volumes_ds = volumes_ds - mean_volume[None, :, :, :]
volumes_ds = volumes_ds

return volumes_ds, mean_volume
return volumes_ds
Loading