diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..7540267 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,4 @@ +*.ipynb linguist-vendored +*.mrc filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/main_merge_check.yml b/.github/workflows/main_merge_check.yml index 8e6f5c5..b4aa6e1 100644 --- a/.github/workflows/main_merge_check.yml +++ b/.github/workflows/main_merge_check.yml @@ -11,4 +11,4 @@ jobs: if: github.base_ref == 'main' && github.head_ref != 'dev' run: | echo "ERROR: You can only merge to main from dev." - exit 1 \ No newline at end of file + exit 1 diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml new file mode 100644 index 0000000..05e06e0 --- /dev/null +++ b/.github/workflows/ruff.yml @@ -0,0 +1,12 @@ +# Runs the Ruff linter and formatter. + +name: Lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: chartboost/ruff-action@v1 diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index f8a40f2..ffc2fc8 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -15,6 +15,7 @@ jobs: strategy: matrix: python-version: ["3.8", "3.9", "3.10", "3.11"] + fail-fast: false steps: @@ -25,30 +26,22 @@ jobs: python-version: ${{ matrix.python-version }} cache: 'pip' # caching pip dependencies - - name: Cache test data - id: cache_test_data - uses: actions/cache@v3 - with: - path: | - tests/data - data - key: venv-${{ runner.os }}-${{ env.pythonLocation }}-${{ hashFiles('**/tests/scripts/fetch_test_data.sh') }} - + - name: Install Git LFS + run: | + sudo apt-get update + sudo apt-get install git-lfs + git lfs install + - name: Pull LFS Files + run: git lfs pull - name: Install dependencies run: | python -m pip install --upgrade pip pip install . 
pip install pytest omegaconf - - - name: Get test data from OSF - if: ${{ steps.cache_test_data.outputs.cache-hit != 'true' }} - run: | - sh tests/scripts/fetch_test_data.sh - + - name: Test with pytest run: | pytest tests/test_preprocessing.py pytest tests/test_svd.py pytest tests/test_map_to_map.py pytest tests/test_distribution_to_distribution.py - diff --git a/.gitignore b/.gitignore index 5ede44a..659f92a 100644 --- a/.gitignore +++ b/.gitignore @@ -4,9 +4,6 @@ data/dataset_1_submissions data/dataset_2_ground_truth # data for testing and resulting outputs -tests/data/Ground_truth -tests/data/dataset_2_submissions/ -tests/data/unprocessed_dataset_2_submissions/submission_x/ tests/results/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e3c79b8..2d1bb35 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,6 +10,7 @@ repos: - id: trailing-whitespace - id: end-of-file-fixer - id: check-yaml + - id: check-added-large-files - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. rev: v0.3.4 diff --git a/README.md b/README.md index 165b497..c27413b 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,9 @@ pip install . ``` ## Developer installation + +First, make sure Git LFS is installed; otherwise you will not have access to the test data. To install it, please follow these [guidelines](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage).
+ If you are interested in testing the programs previously installed, please, install the repository in development mode with the following commands: ```bash @@ -52,7 +55,6 @@ The test included in the repo can be executed with PyTest as shown below: ```bash cd /path/to/Cryo-EM-Heterogeneity-Challenge-1 -sh tests/scripts/fetch_test_data.sh # download test data from OSF pytest tests/test_preprocessing.py pytest tests/test_svd.py pytest tests/test_map_to_map.py diff --git a/config_files/config_distribution_to_distribution.yaml b/config_files/config_distribution_to_distribution.yaml index da8b3d9..eaa94ac 100644 --- a/config_files/config_distribution_to_distribution.yaml +++ b/config_files/config_distribution_to_distribution.yaml @@ -12,4 +12,4 @@ cvxpy_solver: ECOS optimal_q_kl: n_iter: 100000 break_atol: 0.0001 -output_fname: results/distribution_to_distribution_submission_0.pkl \ No newline at end of file +output_fname: results/distribution_to_distribution_submission_0.pkl diff --git a/config_files/config_map_to_map_distance_matrix.yaml b/config_files/config_map_to_map_distance_matrix.yaml index 4302227..5a98e28 100644 --- a/config_files/config_map_to_map_distance_matrix.yaml +++ b/config_files/config_map_to_map_distance_matrix.yaml @@ -1,15 +1,15 @@ data: n_pix: 224 - psize: 2.146 + psize: 2.146 submission: fname: data/dataset_2_ground_truth/submission_0.pt volume_key: volumes metadata_key: populations label_key: id ground_truth: - volumes: data/dataset_2_ground_truth/maps_gt_flat.pt - metadata: data/dataset_2_ground_truth/metadata.csv - mask: + volumes: data/dataset_2_ground_truth/maps_gt_flat.pt + metadata: data/dataset_2_ground_truth/metadata.csv + mask: do: true volume: data/dataset_2_ground_truth/mask_dilated_wide_224x224.mrc analysis: @@ -23,4 +23,4 @@ analysis: normalize: do: true method: median_zscore -output: results/map_to_map_distance_matrix_submission_0.pkl \ No newline at end of file +output: results/map_to_map_distance_matrix_submission_0.pkl diff 
--git a/pyproject.toml b/pyproject.toml index ba0facf..597890e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,29 +38,29 @@ classifiers = [ "Programming Language :: Python :: Implementation :: PyPy", ] dependencies = [ - "torch<=2.3.1", - "numpy<=2.0.0", - "natsort<=8.4.0", - "pandas<=2.2.2", - "dataclasses_json<=0.6.7", - "mrcfile<=1.5.0", - "scipy<=1.13.1", - "cvxpy<=1.5.2", - "POT<=0.9.3", - "aspire<=0.12.2", - "jupyter<=1.0.0", - "osfclient<=0.0.5", - "seaborn<=0.13.2", - "ipyfilechooser<=0.6.0", + "torch", + "numpy", + "natsort", + "pandas", + "dataclasses_json", + "mrcfile", + "scipy", + "cvxpy", + "POT", + "aspire", + "jupyter", + "osfclient", + "seaborn", + "ipyfilechooser", + "omegaconf" ] [project.optional-dependencies] dev = [ - "pytest<=8.2.2", + "pytest", "mypy", "pre-commit", "ruff", - "omegaconf<=2.3.0" ] [project.urls] diff --git a/src/cryo_challenge/__init__.py b/src/cryo_challenge/__init__.py index cafea4e..934c6a8 100644 --- a/src/cryo_challenge/__init__.py +++ b/src/cryo_challenge/__init__.py @@ -1 +1,3 @@ -from cryo_challenge.__about__ import __version__ \ No newline at end of file +from cryo_challenge.__about__ import __version__ + +__all__ = ["__version__"] diff --git a/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py b/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py index 70961d8..18c57bb 100644 --- a/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py +++ b/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py @@ -2,8 +2,6 @@ import numpy as np import pickle from scipy.stats import rankdata -import yaml -import argparse import torch import ot @@ -14,10 +12,12 @@ def sort_by_transport(cost): - m,n = cost.shape - _, transport = compute_wasserstein_between_distributions_from_weights_and_cost(np.ones(m) / m, np.ones(n)/n, cost) - indices = np.argsort((transport * np.arange(m)[...,None]).sum(0)) - return 
cost[:,indices], indices, transport + m, n = cost.shape + _, transport = compute_wasserstein_between_distributions_from_weights_and_cost( + np.ones(m) / m, np.ones(n) / n, cost + ) + indices = np.argsort((transport * np.arange(m)[..., None]).sum(0)) + return cost[:, indices], indices, transport def compute_wasserstein_between_distributions_from_weights_and_cost( @@ -65,7 +65,6 @@ def make_assignment_matrix(cost_matrix): def run(config): - metadata_df = pd.read_csv(config["gt_metadata_fname"]) metadata_df.sort_values("pc1", inplace=True) @@ -73,7 +72,7 @@ def run(config): data = pickle.load(f) # user_submitted_populations = np.ones(80)/80 - user_submitted_populations = data["user_submitted_populations"]#.numpy() + user_submitted_populations = data["user_submitted_populations"] # .numpy() id = torch.load(data["config"]["data"]["submission"]["fname"])["id"] results_dict = {} @@ -213,5 +212,5 @@ def optimal_q_kl(n_iter, x_start, A, Window, prob_gt, break_atol): DistributionToDistributionResultsValidator.from_dict(results_dict) with open(config["output_fname"], "wb") as f: pickle.dump(results_dict, f) - + return results_dict diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py index 0578dfa..47d9a00 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py @@ -42,7 +42,7 @@ def run(config): user_submission_label = submission[label_key] # n_trunc = 10 - metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"])#[:n_trunc] + metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"]) # [:n_trunc] results_dict = {} results_dict["config"] = config diff --git a/src/cryo_challenge/_ploting/plotting_utils.py b/src/cryo_challenge/_ploting/plotting_utils.py index 04d5ca9..681ab01 100644 --- a/src/cryo_challenge/_ploting/plotting_utils.py +++ b/src/cryo_challenge/_ploting/plotting_utils.py @@ 
-1,6 +1,7 @@ import numpy as np + def res_at_fsc_threshold(fscs, threshold=0.5): res_fsc_half = np.argmin(fscs > threshold, axis=-1) - fraction_nyquist = 0.5*res_fsc_half / fscs.shape[-1] - return res_fsc_half, fraction_nyquist \ No newline at end of file + fraction_nyquist = 0.5 * res_fsc_half / fscs.shape[-1] + return res_fsc_half, fraction_nyquist diff --git a/src/cryo_challenge/_preprocessing/dataloader.py b/src/cryo_challenge/_preprocessing/dataloader.py index f9da0c4..27ca57a 100644 --- a/src/cryo_challenge/_preprocessing/dataloader.py +++ b/src/cryo_challenge/_preprocessing/dataloader.py @@ -25,7 +25,11 @@ class SubmissionPreprocessingDataLoader(Dataset): def __init__(self, submission_config): self.submission_config = submission_config - self.submission_paths, self.gt_path = self.extract_submission_paths() + self.validate_submission_config() + + self.submission_paths, self.population_files, self.gt_path = ( + self.extract_submission_paths() + ) self.subs_index = [int(idx) for idx in list(self.submission_config.keys())[1:]] path_to_gt_ref = os.path.join( self.gt_path, self.submission_config["gt"]["ref_align_fname"] @@ -53,12 +57,16 @@ def validate_submission_config(self): raise ValueError("Box size not found for ground truth") if "pixel_size" not in value.keys(): raise ValueError("Pixel size not found for ground truth") + if "ref_align_fname" not in value.keys(): + raise ValueError( + "Reference align file name not found for ground truth" + ) continue else: if "path" not in value.keys(): raise ValueError(f"Path not found for submission {key}") - if "id" not in value.keys(): - raise ValueError(f"ID not found for submission {key}") + if "name" not in value.keys(): + raise ValueError(f"Name not found for submission {key}") if "box_size" not in value.keys(): raise ValueError(f"Box size not found for submission {key}") if "pixel_size" not in value.keys(): @@ -79,11 +87,10 @@ def validate_submission_config(self): if not os.path.isdir(value["path"]): raise 
ValueError(f"Path {value['path']} is not a directory") - ids = list(self.submission_config.keys())[1:] - if ids != list(range(len(ids))): - raise ValueError( - "Submission IDs should be integers starting from 0 and increasing by 1" - ) + if not os.path.exists(value["populations_file"]): + raise ValueError( + f"Population file {value['populations_file']} does not exist" + ) return @@ -142,13 +149,16 @@ def help(cls): def extract_submission_paths(self): submission_paths = [] + population_files = [] for key, value in self.submission_config.items(): if key == "gt": gt_path = value["path"] else: submission_paths.append(value["path"]) - return submission_paths, gt_path + population_files.append(value["populations_file"]) + + return submission_paths, population_files, gt_path def __len__(self): return len(self.submission_paths) @@ -160,10 +170,7 @@ def __getitem__(self, idx): vol_paths = [vol_path for vol_path in vol_paths if "mask" not in vol_path] assert len(vol_paths) > 0, "No volumes found in submission directory" - populations = np.loadtxt( - os.path.join(self.submission_paths[idx], "populations.txt") - ) - populations = torch.from_numpy(populations) + populations = torch.from_numpy(np.loadtxt(self.population_files[idx])) vol0 = mrcfile.open(vol_paths[0], mode="r") volumes = torch.zeros( diff --git a/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py b/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py index b6dea6f..589239c 100644 --- a/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py +++ b/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py @@ -122,7 +122,7 @@ def preprocess_submissions(submission_dataset, config): submission_version = submission_dataset.submission_config[str(idx)][ "submission_version" ] - if submission_version == "0": + if str(submission_version) == "0": submission_version = "" else: submission_version = f" {submission_version}" diff --git a/src/cryo_challenge/data/__init__.py 
b/src/cryo_challenge/data/__init__.py index 8b4655c..fb27bbd 100644 --- a/src/cryo_challenge/data/__init__.py +++ b/src/cryo_challenge/data/__init__.py @@ -1,6 +1,18 @@ -from ._validation.config_validators import validate_input_config_disttodist as validate_input_config_disttodist -from ._validation.config_validators import validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl -from cryo_challenge.data._validation.output_validators import DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator -from cryo_challenge.data._validation.output_validators import MetricDistToDistValidator as MetricDistToDistValidator -from cryo_challenge.data._validation.output_validators import ReplicateValidatorEMD as ReplicateValidatorEMD -from cryo_challenge.data._validation.output_validators import ReplicateValidatorKL as ReplicateValidatorKL +from ._validation.config_validators import ( + validate_input_config_disttodist as validate_input_config_disttodist, +) +from ._validation.config_validators import ( + validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl, +) +from cryo_challenge.data._validation.output_validators import ( + DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator, +) +from cryo_challenge.data._validation.output_validators import ( + MetricDistToDistValidator as MetricDistToDistValidator, +) +from cryo_challenge.data._validation.output_validators import ( + ReplicateValidatorEMD as ReplicateValidatorEMD, +) +from cryo_challenge.data._validation.output_validators import ( + ReplicateValidatorKL as ReplicateValidatorKL, +) diff --git a/src/cryo_challenge/data/_io/svd_io_utils.py b/src/cryo_challenge/data/_io/svd_io_utils.py index f194c14..2a4d954 100644 --- a/src/cryo_challenge/data/_io/svd_io_utils.py +++ b/src/cryo_challenge/data/_io/svd_io_utils.py @@ -106,14 +106,16 @@ def load_ref_vols(box_size_ds: int, path_to_volumes: str, dtype=torch.float32): # Reshape volumes 
to correct size if volumes.dim() == 2: - box_size = int(round((float(volumes.shape[-1]) ** (1. / 3.)))) + box_size = int(round((float(volumes.shape[-1]) ** (1.0 / 3.0)))) volumes = torch.reshape(volumes, (-1, box_size, box_size, box_size)) elif volumes.dim() == 4: pass else: - raise ValueError(f"The shape of the volumes stored in {path_to_volumes} have the unexpected shape " - f"{torch.shape}. Please, review the file and regenerate it so that volumes stored hasve the " - f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size).") + raise ValueError( + f"The shape of the volumes stored in {path_to_volumes} have the unexpected shape " + f"{torch.shape}. Please, review the file and regenerate it so that volumes stored hasve the " + f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size)." + ) volumes_ds = torch.empty( (volumes.shape[0], box_size_ds, box_size_ds, box_size_ds), dtype=dtype diff --git a/src/cryo_challenge/data/_validation/config_validators.py b/src/cryo_challenge/data/_validation/config_validators.py index 93316a0..5acc6ba 100644 --- a/src/cryo_challenge/data/_validation/config_validators.py +++ b/src/cryo_challenge/data/_validation/config_validators.py @@ -1,7 +1,7 @@ from numbers import Number import pandas as pd import os -from typing import List + def validate_generic_config(config: dict, reference: dict) -> None: """ diff --git a/src/cryo_challenge/data/_validation/output_validators.py b/src/cryo_challenge/data/_validation/output_validators.py index 35d9791..39cbb67 100644 --- a/src/cryo_challenge/data/_validation/output_validators.py +++ b/src/cryo_challenge/data/_validation/output_validators.py @@ -13,7 +13,7 @@ @dataclass_json @dataclass class MapToMapResultsValidator: - ''' + """ Validate the output dictionary of the map-to-map distance matrix computation. config: dict, input config dictionary. @@ -22,7 +22,8 @@ class MapToMapResultsValidator: l2: dict, L2 results. bioem: dict, BioEM results. 
fsc: dict, FSC results. - ''' + """ + config: dict user_submitted_populations: torch.Tensor corr: Optional[dict] = None @@ -49,7 +50,7 @@ class ReplicateValidatorEMD: Validate the output dictionary of one EMD in the the distribution-to-distribution pipeline. q_opt: List[float], optimal user submitted distribution, which sums to 1. - EMD_opt: float, EMD between the ground truth distribution (p) and the (optimized) user submitted distribution (q_opt). + EMD_opt: float, EMD between the ground truth distribution (p) and the (optimized) user submitted distribution (q_opt). The transport plan is a joint distribution, such that: summing over the rows gives the (optimized) user submitted distribution, and summing over the columns gives the ground truth distribution. transport_plan_opt: List[List[float]], transport plan between the ground truth distribution (p, rows) and the (optimized) user submitted distribution (q_opt, columns). @@ -61,6 +62,7 @@ class ReplicateValidatorEMD: The transport plan is a joint distribution, such that: summing over the rows gives the user submitted distribution, and summing over the columns gives the ground truth distribution. """ + q_opt: List[float] EMD_opt: float transport_plan_opt: List[List[float]] @@ -87,8 +89,9 @@ class ReplicateValidatorKL: iter_stop: int, number of iterations until convergence. eps_stop: float, stopping criterion. klpq_submitted: float, KL divergence between the ground truth distribution (p) and the user submitted distribution (q). - klqp_submitted: float, KL divergence between the user submitted distribution (q) and the ground truth distribution (p). + klqp_submitted: float, KL divergence between the user submitted distribution (q) and the ground truth distribution (p). 
""" + q_opt: List[float] klpq_opt: float klqp_opt: float @@ -106,11 +109,12 @@ def __post_init__(self): @dataclass_json @dataclass class MetricDistToDistValidator: - ''' + """ Validate the output dictionary of one map to map metric in the the distribution-to-distribution pipeline. replicates: dict, dictionary of replicates. - ''' + """ + replicates: dict def validate_replicates(self, n_replicates): @@ -126,7 +130,7 @@ def validate_replicates(self, n_replicates): @dataclass_json @dataclass class DistributionToDistributionResultsValidator: - ''' + """ Validate the output dictionary of the distribution-to-distribution pipeline. config: dict, input config dictionary. @@ -136,7 +140,8 @@ class DistributionToDistributionResultsValidator: bioem: dict, BioEM distance results. l2: dict, L2 distance results. corr: dict, correlation distance results. - ''' + """ + config: dict user_submitted_populations: torch.Tensor id: str diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index c70d5aa..ecc5cde 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -2,7 +2,7 @@ data: n_pix: 224 psize: 2.146 submission: - fname: tests/data/dataset_2_submissions/test_submission_0_n8.pt + fname: tests/data/dataset_2_submissions/submission_10000.pt volume_key: volumes metadata_key: populations label_key: id diff --git a/tests/config_files/test_config_svd.yaml b/tests/config_files/test_config_svd.yaml index c392525..ff0b413 100644 --- a/tests/config_files/test_config_svd.yaml +++ b/tests/config_files/test_config_svd.yaml @@ -1,6 +1,6 @@ path_to_volumes: tests/data/dataset_2_submissions/ box_size_ds: 32 -submission_list: [0] +submission_list: [10000] experiment_mode: "all_vs_ref" # options are "all_vs_all", "all_vs_ref" # optional unless experiment_mode is "all_vs_ref" path_to_reference: tests/data/Ground_truth/test_maps_gt_flat_10.pt diff --git a/tests/data/Ground_truth/1.mrc 
b/tests/data/Ground_truth/1.mrc new file mode 100644 index 0000000..8c2a13e --- /dev/null +++ b/tests/data/Ground_truth/1.mrc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f866dcc00684a672ad12f5bfd61254843a0213a4b34efbe1510e002e3e357f20 +size 44958720 diff --git a/tests/data/Ground_truth/mask_dilated_wide_224x224.mrc b/tests/data/Ground_truth/mask_dilated_wide_224x224.mrc new file mode 100644 index 0000000..7b02aea --- /dev/null +++ b/tests/data/Ground_truth/mask_dilated_wide_224x224.mrc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2483eb52fbeb28a2d5851749270f0127fb287563332c923b33312bef567627d4 +size 44958720 diff --git a/tests/data/Ground_truth/test_maps_gt_flat_10.npy b/tests/data/Ground_truth/test_maps_gt_flat_10.npy new file mode 100644 index 0000000..983e61b --- /dev/null +++ b/tests/data/Ground_truth/test_maps_gt_flat_10.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b4733700fca5d5b1de2a8e106a46259c4a9b19b81f946776929e2cde0bdbd6e +size 449577088 diff --git a/tests/data/Ground_truth/test_maps_gt_flat_10.pt b/tests/data/Ground_truth/test_maps_gt_flat_10.pt new file mode 100644 index 0000000..d0c5e40 --- /dev/null +++ b/tests/data/Ground_truth/test_maps_gt_flat_10.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e204191c9536d8a0d09fec5e5648acfa6a09c5182f60c0b3c6744ac5c4a3cb82 +size 449577746 diff --git a/tests/data/Ground_truth/test_metadata_10.csv b/tests/data/Ground_truth/test_metadata_10.csv new file mode 100644 index 0000000..de041fc --- /dev/null +++ b/tests/data/Ground_truth/test_metadata_10.csv @@ -0,0 +1,11 @@ +index,volumes,populations_count,pc1,populations +3238,13396.mrc,1,-231.62100638454024,2.9636654614427123e-05 +3020,10063.mrc,13,-179.257841640357,0.0003852765099875 +1113,01421.mrc,12,-141.8237629192062,0.0003556398553731 +3592,21858.mrc,4,-101.62462603216916,0.0001185466184577 +1947,03298.mrc,5,-34.99878410436052,0.0001481832730721 
+2097,03764.mrc,6,3.946553364334135,0.0001778199276865 +1574,02336.mrc,5,44.70670231717438,0.0001481832730721 +2813,08011.mrc,8,108.6308222660271,0.0002370932369154 +88,00090.mrc,21,147.70416251702042,0.0006223697469029 +771,00906.mrc,11,186.3446095998357,0.0003260032007586 diff --git a/tests/data/dataset_2_submissions/submission_10000.pt b/tests/data/dataset_2_submissions/submission_10000.pt new file mode 100644 index 0000000..b3a67d4 --- /dev/null +++ b/tests/data/dataset_2_submissions/submission_10000.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9abfc103cbf4a8885b22716cfb3ae0a9bb038f81a2b5f3c1e237ec14d407dba6 +size 359662738 diff --git a/tests/data/test_maps_gt_flat_2.pt b/tests/data/test_maps_gt_flat_2.pt index dc3293c..f74c145 100644 Binary files a/tests/data/test_maps_gt_flat_2.pt and b/tests/data/test_maps_gt_flat_2.pt differ diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/1.mrc b/tests/data/unprocessed_dataset_2_submissions/submission_x/1.mrc new file mode 100644 index 0000000..8c2a13e --- /dev/null +++ b/tests/data/unprocessed_dataset_2_submissions/submission_x/1.mrc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f866dcc00684a672ad12f5bfd61254843a0213a4b34efbe1510e002e3e357f20 +size 44958720 diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/2.mrc b/tests/data/unprocessed_dataset_2_submissions/submission_x/2.mrc new file mode 100644 index 0000000..4381ab3 --- /dev/null +++ b/tests/data/unprocessed_dataset_2_submissions/submission_x/2.mrc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4551a156c3fc495f8e895ee8c37cea397b04b06902e5d6bb2516e8699fe9bc8e +size 44958720 diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/3.mrc b/tests/data/unprocessed_dataset_2_submissions/submission_x/3.mrc new file mode 100644 index 0000000..68fb0e9 --- /dev/null +++ b/tests/data/unprocessed_dataset_2_submissions/submission_x/3.mrc @@ -0,0 
+1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e46737d911b12fa0624169b0f5c6ae746df9b5c68114d6a76831b5c974fe114 +size 44958720 diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/4.mrc b/tests/data/unprocessed_dataset_2_submissions/submission_x/4.mrc new file mode 100644 index 0000000..9f7b60e --- /dev/null +++ b/tests/data/unprocessed_dataset_2_submissions/submission_x/4.mrc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:203123b395a1d39a90c1d46f071ffaa096299c9d969c35cf532afe2c5c403641 +size 44958720 diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt b/tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt new file mode 100644 index 0000000..53e4b60 --- /dev/null +++ b/tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt @@ -0,0 +1,4 @@ +0.25 +0.25 +0.25 +0.25 diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json b/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json index 2fd1766..03bb0b0 100644 --- a/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json +++ b/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json @@ -8,11 +8,22 @@ }, "0": { "name": "raw_submission_in_testdata", - "submission_version": "0", - "flip": 0, + "submission_version": 0, + "flip": 1, "align": 1, "box_size": 244, "pixel_size": 2.146, - "path": "tests/data/unprocessed_dataset_2_submissions/submission_x" + "path": "tests/data/unprocessed_dataset_2_submissions/submission_x", + "populations_file": "tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt" + }, + "1": { + "name": "raw_submission_in_testdata", + "submission_version": "1", + "flip": 0, + "align": 0, + "box_size": 244, + "pixel_size": 2.146, + "path": "tests/data/unprocessed_dataset_2_submissions/submission_x", + "populations_file": 
"tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt" } } diff --git a/tests/scripts/fetch_test_data.sh b/tests/scripts/fetch_test_data.sh deleted file mode 100644 index c252871..0000000 --- a/tests/scripts/fetch_test_data.sh +++ /dev/null @@ -1,12 +0,0 @@ -mkdir -p tests/data/dataset_2_submissions tests/data/dataset_2_submissions tests/results tests/data/unprocessed_dataset_2_submissions/submission_x tests/data/Ground_truth/ tests/data/Ground_truth -wget https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/dataset_2_submissions/test_submission_0_n8.pt?download=true -O tests/data/dataset_2_submissions/test_submission_0_n8.pt -ADIR=$(pwd) -ln -s $ADIR/tests/data/dataset_2_submissions/test_submission_0_n8.pt $ADIR/tests/data/dataset_2_submissions/submission_0.pt # symlink for svd which needs submission_0.pt for filename -wget https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/Ground_truth/test_maps_gt_flat_10.pt?download=true -O tests/data/Ground_truth/test_maps_gt_flat_10.pt -wget https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/Ground_truth/test_metadata_10.csv?download=true -O tests/data/Ground_truth/test_metadata_10.csv -wget https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/Ground_truth/1.mrc?download=true -O tests/data/Ground_truth/1.mrc -wget https://files.osf.io/v1/resources/8h6fz/providers/dropbox/Ground_truth/mask_dilated_wide_224x224.mrc?download=true -O tests/data/Ground_truth/mask_dilated_wide_224x224.mrc -for FILE in 1.mrc 2.mrc 3.mrc 4.mrc populations.txt -do - wget https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/unprocessed_dataset_2_submissions/submission_x/${FILE}?download=true -O tests/data/unprocessed_dataset_2_submissions/submission_x/${FILE} -done diff --git a/tests/test_distribution_to_distribution.py b/tests/test_distribution_to_distribution.py index d9c340b..a4cfb79 100644 --- a/tests/test_distribution_to_distribution.py +++ 
b/tests/test_distribution_to_distribution.py @@ -2,6 +2,8 @@ from cryo_challenge._commands import run_distribution2distribution_pipeline -def test_run_distribution2distribution_pipeline(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_distribution_to_distribution.yaml'}) - run_distribution2distribution_pipeline.main(args) \ No newline at end of file +def test_run_distribution2distribution_pipeline(): + args = OmegaConf.create( + {"config": "tests/config_files/test_config_distribution_to_distribution.yaml"} + ) + run_distribution2distribution_pipeline.main(args) diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index c782a8c..e31f29f 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -2,6 +2,8 @@ from cryo_challenge._commands import run_map2map_pipeline -def test_run_map2map_pipeline(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_map_to_map.yaml'}) - run_map2map_pipeline.main(args) \ No newline at end of file +def test_run_map2map_pipeline(): + args = OmegaConf.create( + {"config": "tests/config_files/test_config_map_to_map.yaml"} + ) + run_map2map_pipeline.main(args) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index cbf54e4..31db34e 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -2,6 +2,6 @@ from cryo_challenge._commands import run_preprocessing -def test_run_preprocessing(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_preproc.yaml'}) - run_preprocessing.main(args) \ No newline at end of file +def test_run_preprocessing(): + args = OmegaConf.create({"config": "tests/config_files/test_config_preproc.yaml"}) + run_preprocessing.main(args) diff --git a/tests/test_svd.py b/tests/test_svd.py index 568370e..ea166ea 100644 --- a/tests/test_svd.py +++ b/tests/test_svd.py @@ -2,6 +2,6 @@ from cryo_challenge._commands import run_svd -def test_run_preprocessing(): - args = OmegaConf.create({'config': 
'tests/config_files/test_config_svd.yaml'}) - run_svd.main(args) \ No newline at end of file +def test_run_preprocessing(): + args = OmegaConf.create({"config": "tests/config_files/test_config_svd.yaml"}) + run_svd.main(args) diff --git a/tutorials/1_tutorial_preprocessing.ipynb b/tutorials/1_tutorial_preprocessing.ipynb index 0c718e4..b2063ca 100644 --- a/tutorials/1_tutorial_preprocessing.ipynb +++ b/tutorials/1_tutorial_preprocessing.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2024-06-17T15:40:12.854854Z", @@ -17,7 +17,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2024-06-17T15:40:20.557563Z", @@ -30,7 +30,6 @@ "import os\n", "import torch\n", "import matplotlib.pyplot as plt\n", - "import numpy as np\n", "import yaml\n", "from ipyfilechooser import FileChooser" ] @@ -80,6 +79,17 @@ "display(submission1_path)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Select path to populations (submission 1)\n", + "submission1_pop_path = FileChooser(path_to_sub_set.selected_path)\n", + "display(submission1_pop_path)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -97,6 +107,26 @@ "display(submission2_path)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Select path to populations (submission 2)\n", + "submission2_pop_path = FileChooser(path_to_sub_set.selected_path)\n", + "display(submission2_pop_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "submission2_pop_path.selected" + ] + }, { "cell_type": "code", "execution_count": null, @@ -116,7 +146,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": 
"2024-06-13T07:40:59.387306Z", @@ -134,25 +164,31 @@ " \"ref_align_fname\": \"1.mrc\",\n", " },\n", " 0: {\n", + " \"path\": submission1_path.selected_path,\n", + " \"populations_file\": submission1_pop_path.selected,\n", " \"name\": \"submission1\",\n", - " \"align\": 0,\n", + " \"submission_version\": 0, # does not change the submission id\n", " \"box_size\": 144,\n", " \"pixel_size\": 1.073 * 2,\n", - " \"path\": submission1_path.selected_path,\n", + " \"flip\": 0,\n", + " \"align\": 0,\n", " },\n", " 1: {\n", + " \"path\": submission2_path.selected_path,\n", + " \"populations_file\": submission2_pop_path.selected,\n", " \"name\": \"submission2\",\n", - " \"align\": 1,\n", + " \"submission_version\": 1, # makes the id \"ice cream name 1\"\n", " \"box_size\": 288,\n", " \"pixel_size\": 1.073,\n", - " \"path\": submission2_path.selected_path,\n", + " \"flip\": 1, # flip the z axis. DO AN ALIGN if you set this to 1\n", + " \"align\": 1,\n", " },\n", "}" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2024-06-13T07:41:01.194466Z", @@ -174,17 +210,22 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "After you create your submission_config, simply grab a copy of the file \"config_preproc.yaml\" from the provided config_files, and change the path for the \"submission_config_file\" to the file we created in the previous cell. Also change the path for the output. The rest of the parameters you can leave untouched. Please see the publication \"Singer, A., & Yang, R. (2024). Alignment of density maps in Wasserstein distance. Biological Imaging, 4, e5\" for more details. Then simply run\n", + "Lastly, to run the preprocessing pipeline follow these steps\n", + "\n", + "0. Make sure to activate your environment and have the package installed!\n", + "\n", + "1. 
Grab a copy of the file `config_preproc.yaml` from our config file templates.\n", "\n", - "```bash\n", - "cryo_challenge run_preprocessing --config /path/to/config_preproc.yaml\n", - "```\n", + "2. In the copied config file, update the value of `submission_config_file` to match the path to the file we created in the last cell.\n", "\n", - "Note: make sure to activate your environment and have the package installed!\n", + "3. Optionally, change the other parameters. \n", + " * Most of the parameters (BOT_* and thresh_percentile) are for the alignment. For details on how they work, please see the publication \"Singer, A., & Yang, R. (2024). Alignment of density maps in Wasserstein distance. Biological Imaging, 4, e5\". \n", "\n", - "You can run the following cell to visualize your volumes (more precisely, a projection of them)\n", + " * The other parameters are self-explanatory; \"seed_flavor_assignment\" changes which submission gets assigned which ice cream flavor. Keep this if you want to revert anonymity.\n", "\n", - "IMPORTANT: The execution of the previous program relies on the existence of file to be saved at {{ submission1_path.selected_path }} with a specific formatting. The file must be named \"populations.txt\", and should be formatted as a single row/column CSV file containing the populations computed from your results. If the previous file is not included, the execution of the program will result in a runtime error." + "4.
Run the command: `cryo_challenge run_preprocessing --config /path/to/config_preproc.yaml`\n", + "\n", + "You can run the following cell to visualize your volumes (more precisely, a projection of them)\n" ] }, { @@ -201,13 +242,13 @@ "# Select path to Config file\n", "# An example of this file is available in the path ../config_files/config_preproc.yaml\n", "config_preproc_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_preproc_path.filter_pattern = '*.yaml'\n", + "config_preproc_path.filter_pattern = \"*.yaml\"\n", "display(config_preproc_path)" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2024-06-13T07:43:16.259106Z", @@ -224,7 +265,7 @@ "if os.path.isabs(output_path):\n", " full_output_path = output_path\n", "else:\n", - " full_output_path = os.path.join(os.getcwd(), '..', output_path)" + " full_output_path = os.path.join(os.getcwd(), \"..\", output_path)" ] }, { @@ -240,12 +281,10 @@ "source": [ "n_submissions = 2 # change this to however many submissions you preprocessed\n", "\n", - "fig, ax = plt.subplots(2, 6, figsize=(20, 8)) # change values here too\n", + "fig, ax = plt.subplots(1, 2, figsize=(10, 4)) # change values here too\n", "\n", "for i in range(n_submissions):\n", - " idx = np.random.randint(\n", - " 0, 20\n", - " ) # doing random volumes to check that everything went fine\n", + " idx = 0\n", "\n", " submission = torch.load(os.path.join(full_output_path, f\"submission_{i}.pt\"))\n", " print(submission[\"volumes\"].shape, submission[\"id\"])\n", @@ -256,9 +295,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "cryo-challenge-kernel", "language": "python", - "name": "python3" + "name": "cryo-challenge-kernel" }, "language_info": { "codemirror_mode": { @@ -270,7 +309,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.17" + "version": "3.10.10" } }, "nbformat": 4, 
diff --git a/tutorials/2_tutorial_svd.ipynb b/tutorials/2_tutorial_svd.ipynb index b41bfba..fe8f432 100644 --- a/tutorials/2_tutorial_svd.ipynb +++ b/tutorials/2_tutorial_svd.ipynb @@ -62,7 +62,7 @@ "# Select path to SVD config file\n", "# An example of this file is available in the path ../config_files/config_svd.yaml\n", "config_svd_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_svd_path.filter_pattern = '*.yaml'\n", + "config_svd_path.filter_pattern = \"*.yaml\"\n", "display(config_svd_path)" ] }, @@ -125,7 +125,7 @@ "source": [ "# Select path to SVD results\n", "svd_results_path = FileChooser(os.path.expanduser(\"~\"))\n", - "svd_results_path.filter_pattern = '*.pt'\n", + "svd_results_path.filter_pattern = \"*.pt\"\n", "display(svd_results_path)" ] }, @@ -316,7 +316,7 @@ "source": [ "# Select path to SVD results\n", "svd_all_vs_all_results_path = FileChooser(os.path.expanduser(\"~\"))\n", - "svd_all_vs_all_results_path.filter_pattern = '*.pt'\n", + "svd_all_vs_all_results_path.filter_pattern = \"*.pt\"\n", "display(svd_all_vs_all_results_path)" ] }, diff --git a/tutorials/3_tutorial_map2map.ipynb b/tutorials/3_tutorial_map2map.ipynb index a3701ff..0497578 100644 --- a/tutorials/3_tutorial_map2map.ipynb +++ b/tutorials/3_tutorial_map2map.ipynb @@ -23,15 +23,8 @@ "\n", "from cryo_challenge.data._validation.config_validators import (\n", " validate_input_config_mtm,\n", - " validate_config_mtm_data, \n", - " validate_config_mtm_data_submission, \n", - " validate_config_mtm_data_ground_truth, \n", - " validate_config_mtm_data_mask, \n", - " validate_config_mtm_analysis, \n", - " validate_config_mtm_analysis_normalize, \n", - " )\n", - "from cryo_challenge.data._validation.output_validators import MapToMapResultsValidator\n", - "from cryo_challenge.data._validation.config_validators import validate_maptomap_result" + ")\n", + "from cryo_challenge.data._validation.output_validators import MapToMapResultsValidator" ] }, { @@ -80,7 +73,7 @@ "# Select path 
to Map to Map config file\n", "# An example of this file is available in the path ../config_files/config_map_to_map.yaml\n", "config_m2m_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_m2m_path.filter_pattern = '*.yaml'\n", + "config_m2m_path.filter_pattern = \"*.yaml\"\n", "display(config_m2m_path)" ] }, @@ -341,8 +334,8 @@ "metadata": {}, "outputs": [], "source": [ - "with open(os.path.join('../',config[\"output\"]), \"rb\") as f:\n", - " results_dict = pickle.load(f)\n" + "with open(os.path.join(\"../\", config[\"output\"]), \"rb\") as f:\n", + " results_dict = pickle.load(f)" ] }, { diff --git a/tutorials/4_tutorial_distribution2distribution.ipynb b/tutorials/4_tutorial_distribution2distribution.ipynb index 07dc1d9..2a6fc53 100644 --- a/tutorials/4_tutorial_distribution2distribution.ipynb +++ b/tutorials/4_tutorial_distribution2distribution.ipynb @@ -30,7 +30,7 @@ "import pickle\n", "from ipyfilechooser import FileChooser\n", "\n", - "from cryo_challenge.data import validate_input_config_disttodist, validate_config_dtd_optimal_q_kl\n", + "from cryo_challenge.data import validate_input_config_disttodist\n", "from cryo_challenge.data import DistributionToDistributionResultsValidator" ] }, @@ -65,7 +65,7 @@ "# Select path to Distribution to Distribution config file\n", "# An example of this file is available in the path ../config_files/config_distribution_to_distribution.yaml\n", "config_d2d_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_d2d_path.filter_pattern = '*.yaml'\n", + "config_d2d_path.filter_pattern = \"*.yaml\"\n", "display(config_d2d_path)" ] }, @@ -199,8 +199,8 @@ "metadata": {}, "outputs": [], "source": [ - "with open(os.path.join('../',config[\"output_fname\"]), \"rb\") as f:\n", - " results_dict = pickle.load(f)\n" + "with open(os.path.join(\"../\", config[\"output_fname\"]), \"rb\") as f:\n", + " results_dict = pickle.load(f)" ] }, { @@ -286,7 +286,6 @@ } ], "source": [ - "from cryo_challenge.data import 
MetricDistToDistValidator\n", "MetricDistToDistValidator?" ] }, @@ -302,9 +301,7 @@ "execution_count": 30, "metadata": {}, "outputs": [], - "source": [ - "from cryo_challenge.data import ReplicateValidatorEMD, ReplicateValidatorKL" - ] + "source": [] }, { "cell_type": "code", diff --git a/tutorials/5_tutorial_plotting.ipynb b/tutorials/5_tutorial_plotting.ipynb index ed8a924..e2648c8 100644 --- a/tutorials/5_tutorial_plotting.ipynb +++ b/tutorials/5_tutorial_plotting.ipynb @@ -23,7 +23,9 @@ "metadata": {}, "outputs": [], "source": [ - "from cryo_challenge._distribution_to_distribution.distribution_to_distribution import sort_by_transport\n", + "from cryo_challenge._distribution_to_distribution.distribution_to_distribution import (\n", + " sort_by_transport,\n", + ")\n", "from cryo_challenge._ploting.plotting_utils import res_at_fsc_threshold\n", "\n", "from dataclasses import dataclass\n", @@ -102,6 +104,7 @@ " map2map_results: List[str]\n", " dist2dist_results: Dict[str, Dict[str, List[str]]]\n", "\n", + "\n", "with open(path_to_config, \"r\") as file:\n", " config = yaml.safe_load(file)\n", "config = PlottingConfig.from_dict(config)\n", @@ -115,7 +118,7 @@ "outputs": [], "source": [ "metadata_df = pd.read_csv(config.gt_metadata)\n", - "metadata_df.sort_values('pc1', inplace=True)\n", + "metadata_df.sort_values(\"pc1\", inplace=True)\n", "gt_ordering = metadata_df.index.tolist()" ] }, @@ -136,12 +139,14 @@ " data_d = {}\n", " for fname in fnames:\n", " if fname not in data_d.keys():\n", - " with open(fname, 'rb') as f:\n", + " with open(fname, \"rb\") as f:\n", " data = pickle.load(f)\n", - " anonymous_label = data[map2map_distance]['user_submission_label']\n", + " anonymous_label = data[map2map_distance][\"user_submission_label\"]\n", " data_d[anonymous_label] = data\n", " return data_d\n", - "data_d = get_fsc_distances(config.map2map_results, 'fsc')" + "\n", + "\n", + "data_d = get_fsc_distances(config.map2map_results, \"fsc\")" ] }, { @@ -162,22 +167,24 @@ ], 
"source": [ "def plot_fsc_distances(data_d, gt_ordering):\n", - "\n", " smaller_fontsize = 20\n", " larger_fontsize = 30\n", " n_plts = 12\n", " vmin, vmax = np.inf, -np.inf\n", "\n", - " fig, axis = plt.subplots(3,n_plts//3,\n", - " figsize=(40,20),\n", - " # dpi=100,\n", - " )\n", - " fig.suptitle(r'$d_{FSC}$', y=0.95, fontsize=larger_fontsize)\n", + " fig, axis = plt.subplots(\n", + " 3,\n", + " n_plts // 3,\n", + " figsize=(40, 20),\n", + " # dpi=100,\n", + " )\n", + " fig.suptitle(r\"$d_{FSC}$\", y=0.95, fontsize=larger_fontsize)\n", "\n", " for idx, (anonymous_label, data) in enumerate(data_d.items()):\n", - " map2map_dist_matrix = data['fsc']['cost_matrix'].iloc[gt_ordering].values\n", - " sorted_map2map_dist_matrix, indices, transport = sort_by_transport(map2map_dist_matrix)\n", - "\n", + " map2map_dist_matrix = data[\"fsc\"][\"cost_matrix\"].iloc[gt_ordering].values\n", + " sorted_map2map_dist_matrix, indices, transport = sort_by_transport(\n", + " map2map_dist_matrix\n", + " )\n", "\n", " ncols = 4\n", " if map2map_dist_matrix.min() < vmin:\n", @@ -185,14 +192,24 @@ " if map2map_dist_matrix.max() > vmax:\n", " vmax = map2map_dist_matrix.max()\n", "\n", - " ax = axis[idx//ncols,idx%ncols].imshow(sorted_map2map_dist_matrix, aspect='auto', cmap='Blues_r', vmin=vmin, vmax=vmax)\n", - "\n", - "\n", - " axis[idx//ncols,idx%ncols].tick_params(axis='both', labelsize=smaller_fontsize)\n", + " ax = axis[idx // ncols, idx % ncols].imshow(\n", + " sorted_map2map_dist_matrix,\n", + " aspect=\"auto\",\n", + " cmap=\"Blues_r\",\n", + " vmin=vmin,\n", + " vmax=vmax,\n", + " )\n", + "\n", + " axis[idx // ncols, idx % ncols].tick_params(\n", + " axis=\"both\", labelsize=smaller_fontsize\n", + " )\n", " cbar = fig.colorbar(ax)\n", " cbar.ax.tick_params(labelsize=smaller_fontsize)\n", " plot_panel_label = anonymous_label\n", - " axis[idx//ncols,idx%ncols].set_title(plot_panel_label, fontsize=smaller_fontsize)\n", + " axis[idx // ncols, idx % ncols].set_title(\n", + " 
plot_panel_label, fontsize=smaller_fontsize\n", + " )\n", + "\n", "\n", "plot_fsc_distances(data_d, gt_ordering)" ] @@ -214,12 +231,13 @@ " data_d = {}\n", " for fname in fnames:\n", " if fname not in data_d.keys():\n", - " with open(fname, 'rb') as f:\n", + " with open(fname, \"rb\") as f:\n", " data = pickle.load(f)\n", - " anonymous_label = data['fsc']['user_submission_label']\n", - " data_d[anonymous_label] = data['fsc']['computed_assets']['fsc_matrix']\n", + " anonymous_label = data[\"fsc\"][\"user_submission_label\"]\n", + " data_d[anonymous_label] = data[\"fsc\"][\"computed_assets\"][\"fsc_matrix\"]\n", " return data_d\n", "\n", + "\n", "fscs_sorted_d = get_full_fsc_curve(config.map2map_results)" ] }, @@ -229,9 +247,11 @@ "metadata": {}, "outputs": [], "source": [ - "res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fscs_sorted_d['Cookie Dough'], threshold=0.5)\n", - "n_fourier_bins = fscs_sorted_d['Cookie Dough'].shape[-1]\n", - "units_Angstroms = 2 * 2.146 / (np.arange(1,n_fourier_bins+1) / n_fourier_bins)" + "res_fsc_half, fraction_nyquist = res_at_fsc_threshold(\n", + " fscs_sorted_d[\"Cookie Dough\"], threshold=0.5\n", + ")\n", + "n_fourier_bins = fscs_sorted_d[\"Cookie Dough\"].shape[-1]\n", + "units_Angstroms = 2 * 2.146 / (np.arange(1, n_fourier_bins + 1) / n_fourier_bins)" ] }, { @@ -261,7 +281,7 @@ } ], "source": [ - "plt.imshow(units_Angstroms[res_fsc_half][gt_ordering], aspect='auto', cmap='Blues_r')\n", + "plt.imshow(units_Angstroms[res_fsc_half][gt_ordering], aspect=\"auto\", cmap=\"Blues_r\")\n", "plt.colorbar()" ] }, @@ -283,17 +303,18 @@ ], "source": [ "def plot_res_at_fsc_threshold_distances(fscs_sorted_d, gt_ordering, overwrite_dict={}):\n", - "\n", " smaller_fontsize = 20\n", " larger_fontsize = 30\n", " n_plts = 12\n", " vmin, vmax = np.inf, -np.inf\n", "\n", - " fig, axis = plt.subplots(3,n_plts//3,\n", - " figsize=(40,20),\n", - " # dpi=100,\n", - " )\n", - " fig.suptitle(r'Resolution $(\\AA)$ at $FSC=0.5$', y=0.95, 
fontsize=larger_fontsize)\n", + " fig, axis = plt.subplots(\n", + " 3,\n", + " n_plts // 3,\n", + " figsize=(40, 20),\n", + " # dpi=100,\n", + " )\n", + " fig.suptitle(r\"Resolution $(\\AA)$ at $FSC=0.5$\", y=0.95, fontsize=larger_fontsize)\n", "\n", " for idx, (anonymous_label, fscs) in enumerate(fscs_sorted_d.items()):\n", " # map2map_dist_matrix = data.iloc[gt_ordering].values\n", @@ -302,27 +323,38 @@ "\n", " sorted_map2map_dist_matrix, _, _ = sort_by_transport(map2map_dist_matrix)\n", "\n", - "\n", " ncols = 4\n", " if map2map_dist_matrix.min() < vmin:\n", " vmin = map2map_dist_matrix.min()\n", " if map2map_dist_matrix.max() > vmax:\n", " vmax = map2map_dist_matrix.max()\n", - " if 'vmax' in overwrite_dict.keys():\n", - " vmax = overwrite_dict['vmax']\n", - " if 'vmin' in overwrite_dict.keys():\n", - " vmin = overwrite_dict['vmin']\n", - "\n", - " ax = axis[idx//ncols,idx%ncols].imshow(sorted_map2map_dist_matrix, aspect='auto', cmap='Blues_r', vmin=vmin, vmax=vmax)\n", - "\n", - "\n", - " axis[idx//ncols,idx%ncols].tick_params(axis='both', labelsize=smaller_fontsize)\n", + " if \"vmax\" in overwrite_dict.keys():\n", + " vmax = overwrite_dict[\"vmax\"]\n", + " if \"vmin\" in overwrite_dict.keys():\n", + " vmin = overwrite_dict[\"vmin\"]\n", + "\n", + " ax = axis[idx // ncols, idx % ncols].imshow(\n", + " sorted_map2map_dist_matrix,\n", + " aspect=\"auto\",\n", + " cmap=\"Blues_r\",\n", + " vmin=vmin,\n", + " vmax=vmax,\n", + " )\n", + "\n", + " axis[idx // ncols, idx % ncols].tick_params(\n", + " axis=\"both\", labelsize=smaller_fontsize\n", + " )\n", " cbar = fig.colorbar(ax)\n", " cbar.ax.tick_params(labelsize=smaller_fontsize)\n", " plot_panel_label = anonymous_label\n", - " axis[idx//ncols,idx%ncols].set_title(plot_panel_label, fontsize=smaller_fontsize)\n", + " axis[idx // ncols, idx % ncols].set_title(\n", + " plot_panel_label, fontsize=smaller_fontsize\n", + " )\n", + "\n", "\n", - "plot_res_at_fsc_threshold_distances(fscs_sorted_d, gt_ordering, 
overwrite_dict={'vmax': 31})" + "plot_res_at_fsc_threshold_distances(\n", + " fscs_sorted_d, gt_ordering, overwrite_dict={\"vmax\": 31}\n", + ")" ] }, { @@ -338,9 +370,9 @@ "metadata": {}, "outputs": [], "source": [ - "fname = config.dist2dist_results['prob_submitted_plot']['pkl_fnames'][0]\n", + "fname = config.dist2dist_results[\"prob_submitted_plot\"][\"pkl_fnames\"][0]\n", "\n", - "with open(fname, 'rb') as f:\n", + "with open(fname, \"rb\") as f:\n", " data = pickle.load(f)" ] }, @@ -354,13 +386,16 @@ " data_d = {}\n", " for fname in fnames:\n", " if fname not in data_d.keys():\n", - " with open(fname, 'rb') as f:\n", + " with open(fname, \"rb\") as f:\n", " data = pickle.load(f)\n", - " anonymous_label = data['id']\n", + " anonymous_label = data[\"id\"]\n", " data_d[anonymous_label] = data\n", " return data_d\n", "\n", - "dist2dist_results_d = get_dist2dist_results(config.dist2dist_results['prob_submitted_plot']['pkl_fnames'])" + "\n", + "dist2dist_results_d = get_dist2dist_results(\n", + " config.dist2dist_results[\"prob_submitted_plot\"][\"pkl_fnames\"]\n", + ")" ] }, { @@ -381,46 +416,58 @@ ], "source": [ "window_size = 15\n", - "nrows, ncols = 3,4\n", + "nrows, ncols = 3, 4\n", "suptitle = f\"Submitted populations vs optimal populations \\n d_FSC (no rank) | n_replicates={data['config']['n_replicates']} | window_size={window_size} | n_pool_microstate={data['config']['n_pool_microstate']}\"\n", "\n", + "\n", "def plot_q_opt_distances(dist2dist_results_d, suptitle, nrows, ncols):\n", + " fig, axes = plt.subplots(nrows, ncols, figsize=(40, 25))\n", "\n", - " fig, axes = plt.subplots(nrows, ncols, figsize=(40,25))\n", - " \n", - " fig.suptitle(\n", - " suptitle,\n", - " fontsize=30,\n", - " y=0.95)\n", + " fig.suptitle(suptitle, fontsize=30, y=0.95)\n", " alpha = 0.05\n", "\n", - " for idx_fname, (_,data) in enumerate(dist2dist_results_d.items()):\n", - " \n", + " for idx_fname, (_, data) in enumerate(dist2dist_results_d.items()):\n", + " axes[idx_fname // 
ncols, idx_fname % ncols].plot(\n", + " data[\"user_submitted_populations\"], color=\"black\", label=\"submited\"\n", + " )\n", + " axes[idx_fname // ncols, idx_fname % ncols].set_title(data[\"id\"], fontsize=30)\n", "\n", - " axes[idx_fname//ncols, idx_fname%ncols].plot(data['user_submitted_populations'], color='black', label='submited')\n", - " axes[idx_fname//ncols, idx_fname%ncols].set_title(data['id'], fontsize=30)\n", - "\n", - " def window_q(q_opt,window_size):\n", - " running_avg = np.convolve(q_opt, np.ones(window_size)/window_size, mode='same')\n", + " def window_q(q_opt, window_size):\n", + " running_avg = np.convolve(\n", + " q_opt, np.ones(window_size) / window_size, mode=\"same\"\n", + " )\n", " return running_avg\n", - " \n", - " for replicate_idx in range(data['config']['n_replicates']):\n", - " \n", - " if replicate_idx == 0: \n", - " label_d = {'EMD': 'EMD', 'KL': 'KL', 'KL_raw': 'Unwindowed', 'EMD_raw': 'Unwindowed'}\n", - " else:\n", - " label_d = {'EMD': None, 'KL': None, 'KL_raw': None, 'EMD_raw': None}\n", - " axes[idx_fname//ncols, idx_fname%ncols].plot(window_q(data['fsc']['replicates'][replicate_idx]['EMD']['q_opt'],window_size), color='blue', alpha=alpha, label=label_d['EMD'])\n", - "\n", - " axes[idx_fname//ncols, idx_fname%ncols].set_xlabel('Submission index')\n", - " axes[idx_fname//ncols, idx_fname%ncols].set_ylabel('Population')\n", "\n", - " legend = axes[idx_fname//ncols, idx_fname%ncols].legend()\n", + " for replicate_idx in range(data[\"config\"][\"n_replicates\"]):\n", + " if replicate_idx == 0:\n", + " label_d = {\n", + " \"EMD\": \"EMD\",\n", + " \"KL\": \"KL\",\n", + " \"KL_raw\": \"Unwindowed\",\n", + " \"EMD_raw\": \"Unwindowed\",\n", + " }\n", + " else:\n", + " label_d = {\"EMD\": None, \"KL\": None, \"KL_raw\": None, \"EMD_raw\": None}\n", + " axes[idx_fname // ncols, idx_fname % ncols].plot(\n", + " window_q(\n", + " data[\"fsc\"][\"replicates\"][replicate_idx][\"EMD\"][\"q_opt\"],\n", + " window_size,\n", + " ),\n", + " 
color=\"blue\",\n", + " alpha=alpha,\n", + " label=label_d[\"EMD\"],\n", + " )\n", + "\n", + " axes[idx_fname // ncols, idx_fname % ncols].set_xlabel(\"Submission index\")\n", + " axes[idx_fname // ncols, idx_fname % ncols].set_ylabel(\"Population\")\n", + "\n", + " legend = axes[idx_fname // ncols, idx_fname % ncols].legend()\n", " for line, text in zip(legend.get_lines(), legend.get_texts()):\n", " text.set_color(line.get_color())\n", " line.set_alpha(1)\n", "\n", - "plot_q_opt_distances(dist2dist_results_d, suptitle,nrows, ncols)" + "\n", + "plot_q_opt_distances(dist2dist_results_d, suptitle, nrows, ncols)" ] }, { @@ -437,7 +484,6 @@ "outputs": [], "source": [ "def wragle_pkl_to_dataframe(pkl_globs):\n", - "\n", " fnames = []\n", " for fname_glob in pkl_globs:\n", " fnames.extend(glob.glob(fname_glob))\n", @@ -446,28 +492,50 @@ "\n", " df_list = []\n", " n_replicates = 30\n", - " metric = 'fsc'\n", + " metric = \"fsc\"\n", "\n", " for fname in fnames:\n", - " with open(fname, 'rb') as f:\n", + " with open(fname, \"rb\") as f:\n", " data = pickle.load(f)\n", "\n", - " df_list.append(pd.DataFrame({\n", - " 'EMD_opt': [data[metric]['replicates'][i]['EMD']['EMD_opt'] for i in range(n_replicates)],\n", - " 'EMD_submitted': [data[metric]['replicates'][i]['EMD']['EMD_submitted'] for i in range(n_replicates)],\n", - " 'klpq_opt': [data[metric]['replicates'][i]['KL']['klpq_opt'] for i in range(n_replicates)],\n", - " 'klqp_opt': [data[metric]['replicates'][i]['KL']['klqp_opt'] for i in range(n_replicates)],\n", - " 'klpq_submitted': [data[metric]['replicates'][i]['KL']['klpq_submitted'] for i in range(n_replicates)], \n", - " 'klqp_submitted': [data[metric]['replicates'][i]['KL']['klqp_submitted'] for i in range(n_replicates)], \n", - " 'id': data['id'],\n", - " 'n_pool_microstate': data['config']['n_pool_microstate'],\n", - " }))\n", + " df_list.append(\n", + " pd.DataFrame(\n", + " {\n", + " \"EMD_opt\": [\n", + " 
data[metric][\"replicates\"][i][\"EMD\"][\"EMD_opt\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"EMD_submitted\": [\n", + " data[metric][\"replicates\"][i][\"EMD\"][\"EMD_submitted\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"klpq_opt\": [\n", + " data[metric][\"replicates\"][i][\"KL\"][\"klpq_opt\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"klqp_opt\": [\n", + " data[metric][\"replicates\"][i][\"KL\"][\"klqp_opt\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"klpq_submitted\": [\n", + " data[metric][\"replicates\"][i][\"KL\"][\"klpq_submitted\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"klqp_submitted\": [\n", + " data[metric][\"replicates\"][i][\"KL\"][\"klqp_submitted\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"id\": data[\"id\"],\n", + " \"n_pool_microstate\": data[\"config\"][\"n_pool_microstate\"],\n", + " }\n", + " )\n", + " )\n", "\n", " df = pd.concat(df_list)\n", - " df['EMD_opt_norm'] = df['EMD_opt'] / df['n_pool_microstate']\n", - " df['EMD_submitted_norm'] = df['EMD_submitted'] / df['n_pool_microstate']\n", + " df[\"EMD_opt_norm\"] = df[\"EMD_opt\"] / df[\"n_pool_microstate\"]\n", + " df[\"EMD_submitted_norm\"] = df[\"EMD_submitted\"] / df[\"n_pool_microstate\"]\n", "\n", - " return df\n" + " return df" ] }, { @@ -476,7 +544,7 @@ "metadata": {}, "outputs": [], "source": [ - "df = wragle_pkl_to_dataframe(config.dist2dist_results['emd_plot']['pkl_globs'])" + "df = wragle_pkl_to_dataframe(config.dist2dist_results[\"emd_plot\"][\"pkl_globs\"])" ] }, { @@ -498,50 +566,96 @@ "source": [ "def plot_EMD_vs_EMDopt(df, suptitle=None):\n", " alpha = 1\n", - " df_average = df.groupby(['n_pool_microstate','id']).mean().reset_index()\n", - " df_std = df_average.groupby(['id']).std().reset_index().filter(['EMD_opt_norm','EMD_submitted_norm', 'id']).rename(columns={'EMD_opt_norm':'EMD_opt_norm_std', 'EMD_submitted_norm':'EMD_submitted_norm_std'})\n", - " df_average = 
df.groupby(['id']).mean().reset_index()\n", - "\n", - " df_average_and_error = pd.merge(df_average, df_std, on='id')\n", + " df_average = df.groupby([\"n_pool_microstate\", \"id\"]).mean().reset_index()\n", + " df_std = (\n", + " df_average.groupby([\"id\"])\n", + " .std()\n", + " .reset_index()\n", + " .filter([\"EMD_opt_norm\", \"EMD_submitted_norm\", \"id\"])\n", + " .rename(\n", + " columns={\n", + " \"EMD_opt_norm\": \"EMD_opt_norm_std\",\n", + " \"EMD_submitted_norm\": \"EMD_submitted_norm_std\",\n", + " }\n", + " )\n", + " )\n", + " df_average = df.groupby([\"id\"]).mean().reset_index()\n", + "\n", + " df_average_and_error = pd.merge(df_average, df_std, on=\"id\")\n", "\n", " # Get unique ids\n", - " ids = df_average_and_error['id'].unique()\n", + " ids = df_average_and_error[\"id\"].unique()\n", "\n", " # Define marker styles\n", - " markers = ['o', 'v', '^', '<', '>', 's', 'p', '*', 'h', 'H', '+', 'x', 'D', 'd', '|', '_']\n", + " markers = [\n", + " \"o\",\n", + " \"v\",\n", + " \"^\",\n", + " \"<\",\n", + " \">\",\n", + " \"s\",\n", + " \"p\",\n", + " \"*\",\n", + " \"h\",\n", + " \"H\",\n", + " \"+\",\n", + " \"x\",\n", + " \"D\",\n", + " \"d\",\n", + " \"|\",\n", + " \"_\",\n", + " ]\n", " marker_size = 250\n", "\n", - " plt.style.use('seaborn-v0_8-poster')\n", - " plot_width, plot_height = 8,6\n", + " plt.style.use(\"seaborn-v0_8-poster\")\n", + " plot_width, plot_height = 8, 6\n", " plt.figure(figsize=(plot_width, plot_height), dpi=300)\n", "\n", " # Create a scatter plot for each id\n", " for idx, id_label in enumerate(ids):\n", - " df_average_id = df_average[df_average['id'] == id_label]\n", - " sns.scatterplot(x='EMD_submitted_norm', y='EMD_opt_norm', data=df_average_id, alpha=alpha, marker=markers[idx % len(markers)], label=id_label, s=marker_size)\n", - "\n", - " plt.errorbar(x=df_average_and_error['EMD_submitted_norm'], \n", - " y=df_average_and_error['EMD_opt_norm'], \n", - " xerr=df_average_and_error['EMD_submitted_norm_std'], \n", - " 
yerr=df_average_and_error['EMD_opt_norm_std'], \n", - " fmt='', alpha=0.05, linestyle='None', ecolor='k', elinewidth=2, capsize=5)\n", + " df_average_id = df_average[df_average[\"id\"] == id_label]\n", + " sns.scatterplot(\n", + " x=\"EMD_submitted_norm\",\n", + " y=\"EMD_opt_norm\",\n", + " data=df_average_id,\n", + " alpha=alpha,\n", + " marker=markers[idx % len(markers)],\n", + " label=id_label,\n", + " s=marker_size,\n", + " )\n", + "\n", + " plt.errorbar(\n", + " x=df_average_and_error[\"EMD_submitted_norm\"],\n", + " y=df_average_and_error[\"EMD_opt_norm\"],\n", + " xerr=df_average_and_error[\"EMD_submitted_norm_std\"],\n", + " yerr=df_average_and_error[\"EMD_opt_norm_std\"],\n", + " fmt=\"\",\n", + " alpha=0.05,\n", + " linestyle=\"None\",\n", + " ecolor=\"k\",\n", + " elinewidth=2,\n", + " capsize=5,\n", + " )\n", "\n", " plt.xlim(left=0.5)\n", " plt.ylim(bottom=0.5)\n", "\n", - " limits = [np.min([plt.xlim(), plt.ylim()]), # min of both axes\n", - " np.max([plt.xlim(), plt.ylim()])] # max of both axes\n", + " limits = [\n", + " np.min([plt.xlim(), plt.ylim()]), # min of both axes\n", + " np.max([plt.xlim(), plt.ylim()]),\n", + " ] # max of both axes\n", "\n", - " plt.plot(limits, limits, 'k-', alpha=0.75, zorder=0)\n", + " plt.plot(limits, limits, \"k-\", alpha=0.75, zorder=0)\n", " plt.xlim(limits)\n", " plt.ylim(limits)\n", - " legend = plt.legend(loc='upper left', fontsize=12)\n", + " legend = plt.legend(loc=\"upper left\", fontsize=12)\n", " for handle in legend.legend_handles:\n", " handle.set_alpha(1)\n", "\n", " plt.suptitle(suptitle)\n", "\n", - "suptitle = r'$d_{FSC}$ (no rank)'\n", + "\n", + "suptitle = r\"$d_{FSC}$ (no rank)\"\n", "plot_EMD_vs_EMDopt(df, suptitle=None)" ] } diff --git a/tutorials/README.md b/tutorials/README.md index d427ba5..842bdf7 100644 --- a/tutorials/README.md +++ b/tutorials/README.md @@ -25,4 +25,4 @@ This notebook walks through generating and analyzing (plots) the map to map dist - output: a `.pkl` file ### 
`5_tutorial_plotting.ipynb` -This notebook walks through parsing and analyzing (plots) the map to map and distribution to distribution results. \ No newline at end of file +This notebook walks through parsing and analyzing (plots) the map to map and distribution to distribution results.