implement submission version in preprocessing pipeline
DSilva27 committed Aug 5, 2024
2 parents 1b21567 + 26151d3 commit 1dc79e4
Showing 46 changed files with 508 additions and 283 deletions.
4 changes: 4 additions & 0 deletions .gitattributes
@@ -0,0 +1,4 @@
*.ipynb linguist-vendored
*.mrc filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
2 changes: 1 addition & 1 deletion .github/workflows/main_merge_check.yml
@@ -11,4 +11,4 @@ jobs:
if: github.base_ref == 'main' && github.head_ref != 'dev'
run: |
echo "ERROR: You can only merge to main from dev."
exit 1
exit 1
12 changes: 12 additions & 0 deletions .github/workflows/ruff.yml
@@ -0,0 +1,12 @@
# Runs the Ruff linter and formatter.

name: Lint

on: [push, pull_request]

jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1
25 changes: 9 additions & 16 deletions .github/workflows/testing.yml
@@ -15,6 +15,7 @@ jobs:
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
fail-fast: false


steps:
@@ -25,30 +26,22 @@ jobs:
python-version: ${{ matrix.python-version }}
cache: 'pip' # caching pip dependencies

- name: Cache test data
id: cache_test_data
uses: actions/cache@v3
with:
path: |
tests/data
data
key: venv-${{ runner.os }}-${{ env.pythonLocation }}-${{ hashFiles('**/tests/scripts/fetch_test_data.sh') }}

- name: Install Git LFS
run: |
sudo apt-get update
sudo apt-get install git-lfs
git lfs install
- name: Pull LFS Files
run: git lfs pull
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .
pip install pytest omegaconf
- name: Get test data from OSF
if: ${{ steps.cache_test_data.outputs.cache-hit != 'true' }}
run: |
sh tests/scripts/fetch_test_data.sh
- name: Test with pytest
run: |
pytest tests/test_preprocessing.py
pytest tests/test_svd.py
pytest tests/test_map_to_map.py
pytest tests/test_distribution_to_distribution.py
3 changes: 0 additions & 3 deletions .gitignore
@@ -4,9 +4,6 @@ data/dataset_1_submissions
data/dataset_2_ground_truth

# data for testing and resulting outputs
tests/data/Ground_truth
tests/data/dataset_2_submissions/
tests/data/unprocessed_dataset_2_submissions/submission_x/
tests/results/


1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -10,6 +10,7 @@ repos:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.4
4 changes: 3 additions & 1 deletion README.md
@@ -41,6 +41,9 @@ pip install .
```

## Developer installation

First, make sure Git LFS is installed; without it you will not be able to access the test data. To install it, follow these [guidelines](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage).

If you want to test the installed programs, install the repository in development mode with the following commands:

```bash
@@ -52,7 +55,6 @@ The tests included in the repo can be executed with PyTest as shown below:

```bash
cd /path/to/Cryo-EM-Heterogeneity-Challenge-1
sh tests/scripts/fetch_test_data.sh # download test data from OSF
pytest tests/test_preprocessing.py
pytest tests/test_svd.py
pytest tests/test_map_to_map.py
2 changes: 1 addition & 1 deletion config_files/config_distribution_to_distribution.yaml
@@ -12,4 +12,4 @@ cvxpy_solver: ECOS
optimal_q_kl:
n_iter: 100000
break_atol: 0.0001
output_fname: results/distribution_to_distribution_submission_0.pkl
output_fname: results/distribution_to_distribution_submission_0.pkl
10 changes: 5 additions & 5 deletions config_files/config_map_to_map_distance_matrix.yaml
@@ -1,15 +1,15 @@
data:
n_pix: 224
psize: 2.146
psize: 2.146
submission:
fname: data/dataset_2_ground_truth/submission_0.pt
volume_key: volumes
metadata_key: populations
label_key: id
ground_truth:
volumes: data/dataset_2_ground_truth/maps_gt_flat.pt
metadata: data/dataset_2_ground_truth/metadata.csv
mask:
volumes: data/dataset_2_ground_truth/maps_gt_flat.pt
metadata: data/dataset_2_ground_truth/metadata.csv
mask:
do: true
volume: data/dataset_2_ground_truth/mask_dilated_wide_224x224.mrc
analysis:
@@ -23,4 +23,4 @@ analysis:
normalize:
do: true
method: median_zscore
output: results/map_to_map_distance_matrix_submission_0.pkl
output: results/map_to_map_distance_matrix_submission_0.pkl
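For reference, this config can be loaded with OmegaConf (a declared dependency of the package). A minimal sketch, assuming the key layout shown in the file above:

```python
# Minimal sketch: load the map-to-map config with OmegaConf.
# Key paths follow the YAML shown above; printed values are from that file.
from omegaconf import OmegaConf

config = OmegaConf.load("config_files/config_map_to_map_distance_matrix.yaml")
print(config.data.n_pix)    # 224
print(config.data.psize)    # 2.146
print(config.data.mask.do)  # True
```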
32 changes: 16 additions & 16 deletions pyproject.toml
@@ -38,29 +38,29 @@ classifiers = [
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [
"torch<=2.3.1",
"numpy<=2.0.0",
"natsort<=8.4.0",
"pandas<=2.2.2",
"dataclasses_json<=0.6.7",
"mrcfile<=1.5.0",
"scipy<=1.13.1",
"cvxpy<=1.5.2",
"POT<=0.9.3",
"aspire<=0.12.2",
"jupyter<=1.0.0",
"osfclient<=0.0.5",
"seaborn<=0.13.2",
"ipyfilechooser<=0.6.0",
"torch",
"numpy",
"natsort",
"pandas",
"dataclasses_json",
"mrcfile",
"scipy",
"cvxpy",
"POT",
"aspire",
"jupyter",
"osfclient",
"seaborn",
"ipyfilechooser",
"omegaconf"
]

[project.optional-dependencies]
dev = [
"pytest<=8.2.2",
"pytest",
"mypy",
"pre-commit",
"ruff",
"omegaconf<=2.3.0"
]

[project.urls]
4 changes: 3 additions & 1 deletion src/cryo_challenge/__init__.py
@@ -1 +1,3 @@
from cryo_challenge.__about__ import __version__
from cryo_challenge.__about__ import __version__

__all__ = ["__version__"]
@@ -2,8 +2,6 @@
import numpy as np
import pickle
from scipy.stats import rankdata
import yaml
import argparse
import torch
import ot

@@ -14,10 +12,12 @@


def sort_by_transport(cost):
m,n = cost.shape
_, transport = compute_wasserstein_between_distributions_from_weights_and_cost(np.ones(m) / m, np.ones(n)/n, cost)
indices = np.argsort((transport * np.arange(m)[...,None]).sum(0))
return cost[:,indices], indices, transport
m, n = cost.shape
_, transport = compute_wasserstein_between_distributions_from_weights_and_cost(
np.ones(m) / m, np.ones(n) / n, cost
)
indices = np.argsort((transport * np.arange(m)[..., None]).sum(0))
return cost[:, indices], indices, transport


def compute_wasserstein_between_distributions_from_weights_and_cost(
@@ -65,15 +65,14 @@ def make_assignment_matrix(cost_matrix):


def run(config):

metadata_df = pd.read_csv(config["gt_metadata_fname"])
metadata_df.sort_values("pc1", inplace=True)

with open(config["input_fname"], "rb") as f:
data = pickle.load(f)

# user_submitted_populations = np.ones(80)/80
user_submitted_populations = data["user_submitted_populations"]#.numpy()
user_submitted_populations = data["user_submitted_populations"] # .numpy()
id = torch.load(data["config"]["data"]["submission"]["fname"])["id"]

results_dict = {}
@@ -213,5 +212,5 @@ def optimal_q_kl(n_iter, x_start, A, Window, prob_gt, break_atol):
DistributionToDistributionResultsValidator.from_dict(results_dict)
with open(config["output_fname"], "wb") as f:
pickle.dump(results_dict, f)

return results_dict
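As a side note, `sort_by_transport` orders the columns of a cost matrix by their transport-weighted mean row index. A minimal sketch of the same computation using POT's `ot.emd` directly (the matrix here is hypothetical):

```python
import numpy as np
import ot  # POT, a declared dependency

cost = np.random.rand(6, 8)  # hypothetical m x n cost matrix
m, n = cost.shape

# Optimal transport plan between uniform marginals over rows and columns.
transport = ot.emd(np.ones(m) / m, np.ones(n) / n, cost)

# Sort columns by the transport-weighted mean row index, mirroring
# sort_by_transport above.
indices = np.argsort((transport * np.arange(m)[..., None]).sum(0))
sorted_cost = cost[:, indices]
```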
@@ -42,7 +42,7 @@ def run(config):
user_submission_label = submission[label_key]

# n_trunc = 10
metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"])#[:n_trunc]
metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"]) # [:n_trunc]

results_dict = {}
results_dict["config"] = config
5 changes: 3 additions & 2 deletions src/cryo_challenge/_ploting/plotting_utils.py
@@ -1,6 +1,7 @@
import numpy as np


def res_at_fsc_threshold(fscs, threshold=0.5):
res_fsc_half = np.argmin(fscs > threshold, axis=-1)
fraction_nyquist = 0.5*res_fsc_half / fscs.shape[-1]
return res_fsc_half, fraction_nyquist
fraction_nyquist = 0.5 * res_fsc_half / fscs.shape[-1]
return res_fsc_half, fraction_nyquist
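A minimal usage sketch for `res_at_fsc_threshold`; the FSC curves here are synthetic, and the import path follows the file shown above:

```python
import numpy as np

from cryo_challenge._ploting.plotting_utils import res_at_fsc_threshold

# Synthetic FSC curves: 3 map pairs over 32 frequency shells, decaying 1 -> 0.
fscs = np.tile(np.linspace(1.0, 0.0, 32), (3, 1))

# Index of the first shell at/below the 0.5 threshold, and that shell's
# frequency expressed as a fraction of Nyquist.
res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fscs, threshold=0.5)
print(res_fsc_half, fraction_nyquist)
```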
33 changes: 20 additions & 13 deletions src/cryo_challenge/_preprocessing/dataloader.py
@@ -25,7 +25,11 @@ class SubmissionPreprocessingDataLoader(Dataset):

def __init__(self, submission_config):
self.submission_config = submission_config
self.submission_paths, self.gt_path = self.extract_submission_paths()
self.validate_submission_config()

self.submission_paths, self.population_files, self.gt_path = (
self.extract_submission_paths()
)
self.subs_index = [int(idx) for idx in list(self.submission_config.keys())[1:]]
path_to_gt_ref = os.path.join(
self.gt_path, self.submission_config["gt"]["ref_align_fname"]
@@ -53,12 +57,16 @@ def validate_submission_config(self):
raise ValueError("Box size not found for ground truth")
if "pixel_size" not in value.keys():
raise ValueError("Pixel size not found for ground truth")
if "ref_align_fname" not in value.keys():
raise ValueError(
"Reference align file name not found for ground truth"
)
continue
else:
if "path" not in value.keys():
raise ValueError(f"Path not found for submission {key}")
if "id" not in value.keys():
raise ValueError(f"ID not found for submission {key}")
if "name" not in value.keys():
raise ValueError(f"Name not found for submission {key}")
if "box_size" not in value.keys():
raise ValueError(f"Box size not found for submission {key}")
if "pixel_size" not in value.keys():
@@ -79,11 +87,10 @@ def validate_submission_config(self):
if not os.path.isdir(value["path"]):
raise ValueError(f"Path {value['path']} is not a directory")

ids = list(self.submission_config.keys())[1:]
if ids != list(range(len(ids))):
raise ValueError(
"Submission IDs should be integers starting from 0 and increasing by 1"
)
if not os.path.exists(value["populations_file"]):
raise ValueError(
f"Population file {value['populations_file']} does not exist"
)

return

@@ -142,13 +149,16 @@ def help(cls):

def extract_submission_paths(self):
submission_paths = []
population_files = []
for key, value in self.submission_config.items():
if key == "gt":
gt_path = value["path"]

else:
submission_paths.append(value["path"])
return submission_paths, gt_path
population_files.append(value["populations_file"])

return submission_paths, population_files, gt_path

def __len__(self):
return len(self.submission_paths)
@@ -160,10 +170,7 @@ def __getitem__(self, idx):
vol_paths = [vol_path for vol_path in vol_paths if "mask" not in vol_path]
assert len(vol_paths) > 0, "No volumes found in submission directory"

populations = np.loadtxt(
os.path.join(self.submission_paths[idx], "populations.txt")
)
populations = torch.from_numpy(populations)
populations = torch.from_numpy(np.loadtxt(self.population_files[idx]))

vol0 = mrcfile.open(vol_paths[0], mode="r")
volumes = torch.zeros(
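For reference, a hypothetical `submission_config` that would satisfy the checks in `validate_submission_config` above. All paths and values are illustrative, not from the repository; the exact required per-submission keys follow the validator (the diff shows checks for both "id" and "name", so include whichever your version requires):

```python
submission_config = {
    "gt": {
        "path": "data/ground_truth",          # directory of ground-truth maps
        "box_size": 224,
        "pixel_size": 2.146,
        "ref_align_fname": "ref_volume.mrc",  # alignment reference, now validated
    },
    # Submissions are keyed by their string index.
    "0": {
        "path": "data/submission_0",          # directory of submitted maps
        "name": "submission_0",
        "box_size": 256,
        "pixel_size": 1.073,
        "populations_file": "data/submission_0/populations.txt",
        "submission_version": "1",            # read by preprocess_submissions
    },
}
```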
@@ -122,7 +122,7 @@ def preprocess_submissions(submission_dataset, config):
submission_version = submission_dataset.submission_config[str(idx)][
"submission_version"
]
if submission_version == "0":
if str(submission_version) == "0":
submission_version = ""
else:
submission_version = f" {submission_version}"
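A minimal sketch of the version-suffix logic above; coercing with `str()` makes the "0" check work whether the config stores the version as an int or a string (the helper name is illustrative):

```python
def version_suffix(submission_version) -> str:
    # "0" marks an unversioned submission; anything else becomes a suffix,
    # e.g. "submission 1" instead of "submission".
    if str(submission_version) == "0":
        return ""
    return f" {submission_version}"

assert version_suffix(0) == ""     # int from a parsed config
assert version_suffix("0") == ""   # string from a parsed config
assert version_suffix(1) == " 1"
```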
24 changes: 18 additions & 6 deletions src/cryo_challenge/data/__init__.py
@@ -1,6 +1,18 @@
from ._validation.config_validators import validate_input_config_disttodist as validate_input_config_disttodist
from ._validation.config_validators import validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl
from cryo_challenge.data._validation.output_validators import DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator
from cryo_challenge.data._validation.output_validators import MetricDistToDistValidator as MetricDistToDistValidator
from cryo_challenge.data._validation.output_validators import ReplicateValidatorEMD as ReplicateValidatorEMD
from cryo_challenge.data._validation.output_validators import ReplicateValidatorKL as ReplicateValidatorKL
from ._validation.config_validators import (
validate_input_config_disttodist as validate_input_config_disttodist,
)
from ._validation.config_validators import (
validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl,
)
from cryo_challenge.data._validation.output_validators import (
DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator,
)
from cryo_challenge.data._validation.output_validators import (
MetricDistToDistValidator as MetricDistToDistValidator,
)
from cryo_challenge.data._validation.output_validators import (
ReplicateValidatorEMD as ReplicateValidatorEMD,
)
from cryo_challenge.data._validation.output_validators import (
ReplicateValidatorKL as ReplicateValidatorKL,
)
10 changes: 6 additions & 4 deletions src/cryo_challenge/data/_io/svd_io_utils.py
@@ -106,14 +106,16 @@ def load_ref_vols(box_size_ds: int, path_to_volumes: str, dtype=torch.float32):

# Reshape volumes to correct size
if volumes.dim() == 2:
box_size = int(round((float(volumes.shape[-1]) ** (1. / 3.))))
box_size = int(round((float(volumes.shape[-1]) ** (1.0 / 3.0))))
volumes = torch.reshape(volumes, (-1, box_size, box_size, box_size))
elif volumes.dim() == 4:
pass
else:
raise ValueError(f"The shape of the volumes stored in {path_to_volumes} have the unexpected shape "
f"{torch.shape}. Please, review the file and regenerate it so that volumes stored hasve the "
f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size).")
raise ValueError(
f"The volumes stored in {path_to_volumes} have the unexpected shape "
f"{volumes.shape}. Please review the file and regenerate it so that the stored "
f"volumes have the shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size)."
)

volumes_ds = torch.empty(
(volumes.shape[0], box_size_ds, box_size_ds, box_size_ds), dtype=dtype
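A minimal sketch of the flat-to-cubic reshape performed in `load_ref_vols`; the sizes here are hypothetical:

```python
import torch

volumes = torch.randn(10, 32**3)  # 10 flattened volumes, box size 32

if volumes.dim() == 2:
    # Recover the box size as the cube root of the flattened length.
    box_size = int(round(float(volumes.shape[-1]) ** (1.0 / 3.0)))
    volumes = volumes.reshape(-1, box_size, box_size, box_size)

assert volumes.shape == (10, 32, 32, 32)
```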
