Implement power spectrum normalization and include it in SVD pipeline (closes #14) #15

Status: Closed (wanted to merge 26 commits; changes shown from 24 commits).

Commits:
- b600c62: implement power spectrum normalization (May 30, 2024)
- 6085f13: implement power spectrum normalization (May 30, 2024)
- 8de649d: Merge branch 'main' into iss14 (May 30, 2024)
- d07b45e: implement power spectrum normalization in svd pipeline (May 31, 2024)
- cb17a94: Merge branch 'main' into iss14 (May 31, 2024)
- efdd80e: implement power spectrum normalization in svd pipeline (May 31, 2024)
- 70105c7: made small change in how svd output files are handled (May 31, 2024)
- 12330f4: made small change in how svd output files are handled (May 31, 2024)
- 8822ade: make preprocessing pipeline request explicitly the file for the popul… (May 31, 2024)
- 7cdd515: fix bug on power spectrum normalization (Jun 27, 2024)
- f489f14: update template for config and tutorial (DSilva27, Jun 27, 2024)
- 8d0ddaa: Update .pre-commit-config.yaml (DSilva27, Jul 11, 2024)
- efed76e: Merge pull request #38 from flatironinstitute/DSilva27-patch-1 (geoffwoollard, Jul 11, 2024)
- b5b4b05: merge main to iss14 to include tests (Jul 11, 2024)
- 6719dda: merge main to iss14 to include tests (Jul 11, 2024)
- bf28c54: update .gitignore so it does not include testing data (Jul 11, 2024)
- 401cea8: update .gitignore so it does not include testing data (Jul 11, 2024)
- 761f25a: update .gitignore so it does not include testing data (Jul 11, 2024)
- 7ee6824: update .gitignore so it does not include testing data (Jul 11, 2024)
- 5b00f48: remove hard coded path to populations.txt (Jul 11, 2024)
- 4d57abf: turned on validator in the dataloader class (Jul 11, 2024)
- 705445d: turned on validator in the dataloader class (Jul 11, 2024)
- 5b07377: remove prints for debugging (Jul 11, 2024)
- 22504f5: merge from main to include testing (Jul 11, 2024)
- fb8cf53: update tutorial with examples of figures (Jul 12, 2024)
- bd481d5: changed name of figures (Jul 12, 2024)
7 changes: 3 additions & 4 deletions .github/workflows/testing.yml
@@ -18,7 +18,7 @@ jobs:

steps:
- uses: actions/checkout@v3

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
@@ -47,16 +47,15 @@ jobs:
python -m pip install --upgrade pip
pip install .
pip install pytest omegaconf

- name: Get test data from OSF
if: ${{ steps.cache_test_data.outputs.cache-hit != 'true' }}
run: |
sh tests/scripts/fetch_test_data.sh

- name: Test with pytest
run: |
pytest tests/test_preprocessing.py
pytest tests/test_svd.py
pytest tests/test_map_to_map.py
pytest tests/test_distribution_to_distribution.py

6 changes: 6 additions & 0 deletions .gitignore
@@ -1,3 +1,9 @@
tests/data/dataset_2_submissions
tests/data/Ground_truth
tests/results
tests/data/unprocessed_dataset_2_submissions/submission_x/*.mrc
tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
@@ -10,7 +10,6 @@ repos:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.4
8 changes: 4 additions & 4 deletions README.md
@@ -1,7 +1,7 @@
<h1 align='center'>Cryo-EM Heterogeneity Challenge</h1>

<p align="center">

<img alt="Supported Python versions" src="https://img.shields.io/badge/Supported_Python_Versions-3.8_%7C_3.9_%7C_3.10_%7C_3.11-blue">
<img alt="GitHub Downloads (all assets, all releases)" src="https://img.shields.io/github/downloads/flatironinstitute/Cryo-EM-Heterogeneity-Challenge-1/total">
<img alt="GitHub branch check runs" src="https://img.shields.io/github/check-runs/flatironinstitute/Cryo-EM-Heterogeneity-Challenge-1/main">
@@ -10,13 +10,13 @@
</p>

<p align="center">

<img alt="Cryo-EM Heterogeneity Challenge" src="https://simonsfoundation.imgix.net/wp-content/uploads/2023/05/15134456/Screenshot-2023-05-15-at-1.39.07-PM.png?auto=format&q=90">

</p>



This repository contains the code used to analyse the submissions for the [Inaugural Flatiron Cryo-EM Heterogeneity Challenge](https://www.simonsfoundation.org/flatiron/center-for-computational-biology/structural-and-molecular-biophysics-collaboration/heterogeneity-in-cryo-electron-microscopy/).

# Scope
@@ -32,7 +32,7 @@ The data is available via the Open Science Foundation project [The Inaugural Fla

# Installation

## Stable installation
## Stable installation
Installing this repository is simple. We recommend creating a virtual environment (using conda or pyenv), since we have dependencies such as PyTorch or Aspire, which are better dealt with in an isolated environment. After creating your environment, make sure to activate it and run

```bash
2 changes: 1 addition & 1 deletion config_files/config_distribution_to_distribution.yaml
@@ -12,4 +12,4 @@ cvxpy_solver: ECOS
optimal_q_kl:
n_iter: 100000
break_atol: 0.0001
output_fname: results/distribution_to_distribution_submission_0.pkl
output_fname: results/distribution_to_distribution_submission_0.pkl
10 changes: 5 additions & 5 deletions config_files/config_map_to_map_distance_matrix.yaml
@@ -1,15 +1,15 @@
data:
n_pix: 224
psize: 2.146
psize: 2.146
submission:
fname: data/submission_0.pt
volume_key: volumes
metadata_key: populations
label_key: id
ground_truth:
volumes: data/maps_gt_flat.pt
metadata: data/metadata.csv
mask:
volumes: data/maps_gt_flat.pt
metadata: data/metadata.csv
mask:
do: true
volume: data/mask_dilated_wide_224x224.mrc
analysis:
@@ -23,4 +23,4 @@ analysis:
normalize:
do: true
method: median_zscore
output: results/map_to_map_distance_matrix_submission_0.pkl
output: results/map_to_map_distance_matrix_submission_0.pkl
13 changes: 9 additions & 4 deletions config_files/config_svd.yaml
@@ -2,13 +2,18 @@ path_to_volumes: /path/to/volumes
box_size_ds: 32
submission_list: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
experiment_mode: "all_vs_ref" # options are "all_vs_all", "all_vs_ref"

power_spectrum_normalization:
ref_vol_key: "FLAVOR" # which submission should be used
ref_vol_index: 0 # which volume of that submission should be used

# optional unless experiment_mode is "all_vs_ref"
path_to_reference: /path/to/reference
dtype: "float32" # options are "float32", "float64"
output_options:
# path will be created if it does not exist
output_path: /path/to/output
# path to file will be created if it does not exist
output_file: /path/to/output_file.pt
# whether or not to save the processed volumes (downsampled, normalized, etc.)
save_volumes: True
save_volumes: False
# whether or not to save the SVD matrices (U, S, V)
save_svd_matrices: True
save_svd_matrices: False
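The new `power_spectrum_normalization` block selects a reference volume by submission key (`ref_vol_key`) and by index within that submission (`ref_vol_index`). As a rough, hypothetical sketch of how a consumer of this config might resolve the reference volume (the `volumes_by_submission` container and the lookup are illustrative assumptions, not the pipeline's actual code):

```python
import torch

# Hypothetical sketch only: resolve the reference volume named by the new
# power_spectrum_normalization block. Container and shapes are made up.
config = {
    "power_spectrum_normalization": {"ref_vol_key": "FLAVOR", "ref_vol_index": 0},
}

# Stand-in for loaded submissions: submission name -> stack of cubic volumes.
volumes_by_submission = {"FLAVOR": torch.randn(10, 32, 32, 32)}

ps_cfg = config["power_spectrum_normalization"]
ref_volume = volumes_by_submission[ps_cfg["ref_vol_key"]][ps_cfg["ref_vol_index"]]
print(ref_volume.shape)  # torch.Size([32, 32, 32])
```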
4 changes: 2 additions & 2 deletions src/cryo_challenge/_commands/run_svd.py
@@ -36,8 +36,8 @@ def main(args):
config = yaml.safe_load(file)

validate_config_svd(config)
warnexists(config["output_options"]["output_path"])
mkbasedir(config["output_options"]["output_path"])
warnexists(config["output_options"]["output_file"])
mkbasedir(os.path.dirname(config["output_options"]["output_file"]))

if config["experiment_mode"] == "all_vs_all":
run_all_vs_all_pipeline(config)
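The output handling change swaps a directory-based `output_path` for a single `output_file`: the command now warns about the file itself and creates its parent directory. Below is a minimal sketch of the intent, assuming plausible bodies for the repository's `warnexists` and `mkbasedir` helpers (their real implementations are not shown in this diff):

```python
import os

# Assumed helper bodies, for illustration only; the repository defines its own.
def warnexists(path: str) -> None:
    if os.path.exists(path):
        print(f"Warning: {path} already exists and may be overwritten.")

def mkbasedir(dirname: str) -> None:
    if dirname and not os.path.isdir(dirname):
        os.makedirs(dirname, exist_ok=True)

output_file = "results/svd/output_file.pt"  # placeholder path
warnexists(output_file)                      # warn about the file itself
mkbasedir(os.path.dirname(output_file))      # ensure "results/svd" exists
```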
@@ -2,8 +2,6 @@
import numpy as np
import pickle
from scipy.stats import rankdata
import yaml
import argparse
import torch
import ot

@@ -14,10 +12,12 @@


def sort_by_transport(cost):
m,n = cost.shape
_, transport = compute_wasserstein_between_distributions_from_weights_and_cost(np.ones(m) / m, np.ones(n)/n, cost)
indices = np.argsort((transport * np.arange(m)[...,None]).sum(0))
return cost[:,indices], indices, transport
m, n = cost.shape
_, transport = compute_wasserstein_between_distributions_from_weights_and_cost(
np.ones(m) / m, np.ones(n) / n, cost
)
indices = np.argsort((transport * np.arange(m)[..., None]).sum(0))
return cost[:, indices], indices, transport


def compute_wasserstein_between_distributions_from_weights_and_cost(
@@ -65,15 +65,14 @@ def make_assignment_matrix(cost_matrix):


def run(config):

metadata_df = pd.read_csv(config["gt_metadata_fname"])
metadata_df.sort_values("pc1", inplace=True)

with open(config["input_fname"], "rb") as f:
data = pickle.load(f)

# user_submitted_populations = np.ones(80)/80
user_submitted_populations = data["user_submitted_populations"]#.numpy()
user_submitted_populations = data["user_submitted_populations"] # .numpy()
id = torch.load(data["config"]["data"]["submission"]["fname"])["id"]

results_dict = {}
@@ -213,5 +212,5 @@ def optimal_q_kl(n_iter, x_start, A, Window, prob_gt, break_atol):
DistributionToDistributionResultsValidator.from_dict(results_dict)
with open(config["output_fname"], "wb") as f:
pickle.dump(results_dict, f)

return results_dict
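For intuition about the reformatted `sort_by_transport` above, here is a toy run assuming the module's imports (numpy, `ot`, and the Wasserstein helper) are available; the cost values are made up:

```python
import numpy as np

# Toy 3x3 cost matrix: each row has exactly one cheap column, so the optimal
# transport plan between uniform marginals is a permutation matrix scaled by 1/3.
cost = np.array(
    [
        [0.0, 2.0, 1.0],
        [2.0, 1.0, 0.0],
        [1.0, 0.0, 2.0],
    ]
)
sorted_cost, indices, transport = sort_by_transport(cost)
# `indices` reorders the columns so transport mass concentrates on the diagonal;
# here indices == [0, 2, 1] and the cheap entries land at (0,0), (1,1), (2,2).
print(indices)
print(sorted_cost)
```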
@@ -42,7 +42,7 @@ def run(config):
user_submission_label = submission[label_key]

# n_trunc = 10
metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"])#[:n_trunc]
metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"]) # [:n_trunc]

results_dict = {}
results_dict["config"] = config
5 changes: 3 additions & 2 deletions src/cryo_challenge/_ploting/plotting_utils.py
@@ -1,6 +1,7 @@
import numpy as np


def res_at_fsc_threshold(fscs, threshold=0.5):
res_fsc_half = np.argmin(fscs > threshold, axis=-1)
fraction_nyquist = 0.5*res_fsc_half / fscs.shape[-1]
return res_fsc_half, fraction_nyquist
fraction_nyquist = 0.5 * res_fsc_half / fscs.shape[-1]
return res_fsc_half, fraction_nyquist
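A worked example may help here: `fscs > threshold` is a boolean array, and `np.argmin` along the last axis returns the index of its first `False`, i.e. the first shell whose FSC has dropped to the threshold. The curves below are synthetic, assuming the function above is in scope:

```python
import numpy as np

# Two synthetic FSC curves over 10 frequency shells.
fscs = np.array(
    [
        [1.0, 0.9, 0.8, 0.6, 0.4, 0.3, 0.2, 0.1, 0.05, 0.0],
        [1.0, 0.95, 0.9, 0.85, 0.8, 0.7, 0.6, 0.55, 0.4, 0.3],
    ]
)
res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fscs, threshold=0.5)
print(res_fsc_half)      # [4 8]: first shells at or below 0.5
print(fraction_nyquist)  # [0.2 0.4]: those shells as a fraction of Nyquist
# Caveat: a curve that never drops to the threshold yields argmin == 0.
```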
4 changes: 4 additions & 0 deletions src/cryo_challenge/_preprocessing/__init__.py
@@ -3,3 +3,7 @@
downsample_volume as downsample_volume,
downsample_submission as downsample_submission,
)
from .normalize import (
compute_power_spectrum as compute_power_spectrum,
normalize_power_spectrum as normalize_power_spectrum,
)
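These two functions come from a new `normalize` module whose body is not visible in this diff. As a hedged sketch of the standard technique only: compute a radially averaged power spectrum over Fourier shells, then rescale each shell's amplitudes to match a reference spectrum. The bodies below are illustrative assumptions, not the repository's actual implementation:

```python
import torch


def compute_power_spectrum(volume: torch.Tensor) -> torch.Tensor:
    """Radially averaged power spectrum of a cubic volume (sketch)."""
    n = volume.shape[-1]
    fvol = torch.fft.fftshift(torch.fft.fftn(volume))
    power = fvol.abs() ** 2
    # Integer radial shell index of every voxel, relative to the box center.
    coords = torch.arange(n) - n // 2
    z, y, x = torch.meshgrid(coords, coords, coords, indexing="ij")
    radii = torch.sqrt((x**2 + y**2 + z**2).float()).round().long()
    spectrum = torch.zeros(n // 2)
    for r in range(n // 2):
        spectrum[r] = power[radii == r].mean()
    return spectrum


def normalize_power_spectrum(
    volume: torch.Tensor, ref_spectrum: torch.Tensor
) -> torch.Tensor:
    """Rescale Fourier amplitudes shell by shell to match ref_spectrum (sketch)."""
    n = volume.shape[-1]
    fvol = torch.fft.fftshift(torch.fft.fftn(volume))
    coords = torch.arange(n) - n // 2
    z, y, x = torch.meshgrid(coords, coords, coords, indexing="ij")
    radii = torch.sqrt((x**2 + y**2 + z**2).float()).round().long()
    spectrum = compute_power_spectrum(volume)
    for r in range(n // 2):
        # Amplitudes scale with the square root of the power ratio.
        scale = torch.sqrt(ref_spectrum[r] / (spectrum[r] + 1e-12))
        fvol[radii == r] *= scale
    # Shells beyond n // 2 (the box corners) are left untouched in this sketch.
    return torch.fft.ifftn(torch.fft.ifftshift(fvol)).real
```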
35 changes: 22 additions & 13 deletions src/cryo_challenge/_preprocessing/dataloader.py
@@ -25,7 +25,11 @@ class SubmissionPreprocessingDataLoader(Dataset):

def __init__(self, submission_config):
self.submission_config = submission_config
self.submission_paths, self.gt_path = self.extract_submission_paths()
self.validate_submission_config()

self.submission_paths, self.population_files, self.gt_path = (
self.extract_submission_paths()
)
self.subs_index = [int(idx) for idx in list(self.submission_config.keys())[1:]]
path_to_gt_ref = os.path.join(
self.gt_path, self.submission_config["gt"]["ref_align_fname"]
@@ -53,30 +57,35 @@ def validate_submission_config(self):
raise ValueError("Box size not found for ground truth")
if "pixel_size" not in value.keys():
raise ValueError("Pixel size not found for ground truth")
if "ref_align_fname" not in value.keys():
raise ValueError(
"Reference align file name not found for ground truth"
)
continue
else:
if "path" not in value.keys():
raise ValueError(f"Path not found for submission {key}")
if "id" not in value.keys():
raise ValueError(f"ID not found for submission {key}")
if "name" not in value.keys():
raise ValueError(f"Name not found for submission {key}")
if "box_size" not in value.keys():
raise ValueError(f"Box size not found for submission {key}")
if "pixel_size" not in value.keys():
raise ValueError(f"Pixel size not found for submission {key}")
if "align" not in value.keys():
raise ValueError(f"Align not found for submission {key}")
if "populations_file" not in value.keys():
raise ValueError(f"Population file not found for submission {key}")

if not os.path.exists(value["path"]):
raise ValueError(f"Path {value['path']} does not exist")

if not os.path.isdir(value["path"]):
raise ValueError(f"Path {value['path']} is not a directory")

ids = list(self.submission_config.keys())[1:]
if ids != list(range(len(ids))):
raise ValueError(
"Submission IDs should be integers starting from 0 and increasing by 1"
)
if not os.path.exists(value["populations_file"]):
raise ValueError(
f"Population file {value['populations_file']} does not exist"
)

return

@@ -135,13 +144,16 @@ def help(cls):

def extract_submission_paths(self):
submission_paths = []
population_files = []
for key, value in self.submission_config.items():
if key == "gt":
gt_path = value["path"]

else:
submission_paths.append(value["path"])
return submission_paths, gt_path
population_files.append(value["populations_file"])

return submission_paths, population_files, gt_path

def __len__(self):
return len(self.submission_paths)
@@ -154,10 +166,7 @@ def __getitem__(self, idx):

assert len(vol_paths) > 0, "No volumes found in submission directory"

populations = np.loadtxt(
os.path.join(self.submission_paths[idx], "populations.txt")
)
populations = torch.from_numpy(populations)
populations = torch.from_numpy(np.loadtxt(self.population_files[idx]))

vol0 = mrcfile.open(vol_paths[0], mode="r")
volumes = torch.zeros(
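Taken together, the dataloader changes replace the hard-coded `populations.txt` lookup with an explicit `populations_file` entry per submission, and the constructor now runs `validate_submission_config` up front. A hypothetical `submission_config` that would satisfy the validator might look like the following (all values are placeholders; the validator additionally requires that each `path` exists and is a directory and that each `populations_file` exists on disk):

```python
# Hypothetical submission_config, for illustration; values are placeholders.
submission_config = {
    "gt": {
        "path": "/data/ground_truth",
        "box_size": 224,
        "pixel_size": 2.146,
        "ref_align_fname": "ref_align.mrc",
    },
    "0": {
        "path": "/data/submission_0",
        "id": "submission_0",
        "name": "FLAVOR",
        "box_size": 224,
        "pixel_size": 2.146,
        "align": True,
        "populations_file": "/data/submission_0/populations.txt",
    },
}
```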