diff --git a/.gitattributes b/.gitattributes deleted file mode 100644 index 7540267..0000000 --- a/.gitattributes +++ /dev/null @@ -1,4 +0,0 @@ -*.ipynb linguist-vendored -*.mrc filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/main_merge_check.yml b/.github/workflows/main_merge_check.yml index b4aa6e1..8e6f5c5 100644 --- a/.github/workflows/main_merge_check.yml +++ b/.github/workflows/main_merge_check.yml @@ -11,4 +11,4 @@ jobs: if: github.base_ref == 'main' && github.head_ref != 'dev' run: | echo "ERROR: You can only merge to main from dev." - exit 1 + exit 1 \ No newline at end of file diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml deleted file mode 100644 index 05e06e0..0000000 --- a/.github/workflows/ruff.yml +++ /dev/null @@ -1,12 +0,0 @@ -# Runs the Ruff linter and formatter. - -name: Lint - -on: [push, pull_request] - -jobs: - lint: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: chartboost/ruff-action@v1 diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index ffc2fc8..413047d 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -26,22 +26,30 @@ jobs: python-version: ${{ matrix.python-version }} cache: 'pip' # caching pip dependencies - - name: Install Git LFS - run: | - sudo apt-get update - sudo apt-get install git-lfs - git lfs install - - name: Pull LFS Files - run: git lfs pull + - name: Cache test data + id: cache_test_data + uses: actions/cache@v3 + with: + path: | + tests/data + data + key: venv-${{ runner.os }}-${{ env.pythonLocation }}-${{ hashFiles('**/tests/scripts/fetch_test_data.sh') }} + - name: Install dependencies run: | python -m pip install --upgrade pip pip install . pip install pytest omegaconf - + + - name: Get test data from OSF + if: ${{ steps.cache_test_data.outputs.cache-hit != 'true' }} + run: | + sh tests/scripts/fetch_test_data.sh + - name: Test with pytest run: | pytest tests/test_preprocessing.py pytest tests/test_svd.py pytest tests/test_map_to_map.py pytest tests/test_distribution_to_distribution.py + diff --git a/.gitignore b/.gitignore index 659f92a..5ede44a 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,9 @@ data/dataset_1_submissions data/dataset_2_ground_truth # data for testing and resulting outputs +tests/data/Ground_truth +tests/data/dataset_2_submissions/ +tests/data/unprocessed_dataset_2_submissions/submission_x/ tests/results/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2d1bb35..e3c79b8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,6 @@ repos: - id: trailing-whitespace - id: end-of-file-fixer - id: check-yaml - - id: check-added-large-files - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. rev: v0.3.4 diff --git a/README.md b/README.md index c27413b..165b497 100644 --- a/README.md +++ b/README.md @@ -41,9 +41,6 @@ pip install . ``` ## Developer installation - -First of all, make sure to have git lfs installed, otherwise you will have no access to the testing data. For installing, please follow these [guidelines](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage). 
- If you are interested in testing the programs previously installed, please, install the repository in development mode with the following commands: ```bash @@ -55,6 +52,7 @@ The test included in the repo can be executed with PyTest as shown below: ```bash cd /path/to/Cryo-EM-Heterogeneity-Challenge-1 +sh tests/scripts/fetch_test_data.sh # download test data from OSF pytest tests/test_preprocessing.py pytest tests/test_svd.py pytest tests/test_map_to_map.py diff --git a/config_files/config_distribution_to_distribution.yaml b/config_files/config_distribution_to_distribution.yaml index eaa94ac..da8b3d9 100644 --- a/config_files/config_distribution_to_distribution.yaml +++ b/config_files/config_distribution_to_distribution.yaml @@ -12,4 +12,4 @@ cvxpy_solver: ECOS optimal_q_kl: n_iter: 100000 break_atol: 0.0001 -output_fname: results/distribution_to_distribution_submission_0.pkl +output_fname: results/distribution_to_distribution_submission_0.pkl \ No newline at end of file diff --git a/config_files/config_map_to_map_distance_matrix.yaml b/config_files/config_map_to_map_distance_matrix.yaml index 5a98e28..4302227 100644 --- a/config_files/config_map_to_map_distance_matrix.yaml +++ b/config_files/config_map_to_map_distance_matrix.yaml @@ -1,15 +1,15 @@ data: n_pix: 224 - psize: 2.146 + psize: 2.146 submission: fname: data/dataset_2_ground_truth/submission_0.pt volume_key: volumes metadata_key: populations label_key: id ground_truth: - volumes: data/dataset_2_ground_truth/maps_gt_flat.pt - metadata: data/dataset_2_ground_truth/metadata.csv - mask: + volumes: data/dataset_2_ground_truth/maps_gt_flat.pt + metadata: data/dataset_2_ground_truth/metadata.csv + mask: do: true volume: data/dataset_2_ground_truth/mask_dilated_wide_224x224.mrc analysis: @@ -23,4 +23,4 @@ analysis: normalize: do: true method: median_zscore -output: results/map_to_map_distance_matrix_submission_0.pkl +output: results/map_to_map_distance_matrix_submission_0.pkl \ No newline at end of file diff --git a/src/cryo_challenge/__init__.py b/src/cryo_challenge/__init__.py index 934c6a8..cafea4e 100644 --- a/src/cryo_challenge/__init__.py +++ b/src/cryo_challenge/__init__.py @@ -1,3 +1 @@ -from cryo_challenge.__about__ import __version__ - -__all__ = ["__version__"] +from cryo_challenge.__about__ import __version__ \ No newline at end of file diff --git a/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py b/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py index 18c57bb..70961d8 100644 --- a/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py +++ b/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py @@ -2,6 +2,8 @@ import numpy as np import pickle from scipy.stats import rankdata +import yaml +import argparse import torch import ot @@ -12,12 +14,10 @@ def sort_by_transport(cost): - m, n = cost.shape - _, transport = compute_wasserstein_between_distributions_from_weights_and_cost( - np.ones(m) / m, np.ones(n) / n, cost - ) - indices = np.argsort((transport * np.arange(m)[..., None]).sum(0)) - return cost[:, indices], indices, transport + m,n = cost.shape + _, transport = compute_wasserstein_between_distributions_from_weights_and_cost(np.ones(m) / m, np.ones(n)/n, cost) + indices = np.argsort((transport * np.arange(m)[...,None]).sum(0)) + return cost[:,indices], indices, transport def compute_wasserstein_between_distributions_from_weights_and_cost( @@ -65,6 +65,7 @@ def make_assignment_matrix(cost_matrix): def 
run(config): + metadata_df = pd.read_csv(config["gt_metadata_fname"]) metadata_df.sort_values("pc1", inplace=True) @@ -72,7 +73,7 @@ def run(config): data = pickle.load(f) # user_submitted_populations = np.ones(80)/80 - user_submitted_populations = data["user_submitted_populations"] # .numpy() + user_submitted_populations = data["user_submitted_populations"]#.numpy() id = torch.load(data["config"]["data"]["submission"]["fname"])["id"] results_dict = {} @@ -212,5 +213,5 @@ def optimal_q_kl(n_iter, x_start, A, Window, prob_gt, break_atol): DistributionToDistributionResultsValidator.from_dict(results_dict) with open(config["output_fname"], "wb") as f: pickle.dump(results_dict, f) - + return results_dict diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py index 47d9a00..0578dfa 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py @@ -42,7 +42,7 @@ def run(config): user_submission_label = submission[label_key] # n_trunc = 10 - metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"]) # [:n_trunc] + metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"])#[:n_trunc] results_dict = {} results_dict["config"] = config diff --git a/src/cryo_challenge/_ploting/plotting_utils.py b/src/cryo_challenge/_ploting/plotting_utils.py index 681ab01..04d5ca9 100644 --- a/src/cryo_challenge/_ploting/plotting_utils.py +++ b/src/cryo_challenge/_ploting/plotting_utils.py @@ -1,7 +1,6 @@ import numpy as np - def res_at_fsc_threshold(fscs, threshold=0.5): res_fsc_half = np.argmin(fscs > threshold, axis=-1) - fraction_nyquist = 0.5 * res_fsc_half / fscs.shape[-1] - return res_fsc_half, fraction_nyquist + fraction_nyquist = 0.5*res_fsc_half / fscs.shape[-1] + return res_fsc_half, fraction_nyquist \ No newline at end of file diff --git a/src/cryo_challenge/data/__init__.py b/src/cryo_challenge/data/__init__.py index fb27bbd..8b4655c 100644 --- a/src/cryo_challenge/data/__init__.py +++ b/src/cryo_challenge/data/__init__.py @@ -1,18 +1,6 @@ -from ._validation.config_validators import ( - validate_input_config_disttodist as validate_input_config_disttodist, -) -from ._validation.config_validators import ( - validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl, -) -from cryo_challenge.data._validation.output_validators import ( - DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator, -) -from cryo_challenge.data._validation.output_validators import ( - MetricDistToDistValidator as MetricDistToDistValidator, -) -from cryo_challenge.data._validation.output_validators import ( - ReplicateValidatorEMD as ReplicateValidatorEMD, -) -from cryo_challenge.data._validation.output_validators import ( - ReplicateValidatorKL as ReplicateValidatorKL, -) +from ._validation.config_validators import validate_input_config_disttodist as validate_input_config_disttodist +from ._validation.config_validators import validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl +from cryo_challenge.data._validation.output_validators import DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator +from cryo_challenge.data._validation.output_validators import MetricDistToDistValidator as MetricDistToDistValidator +from cryo_challenge.data._validation.output_validators import ReplicateValidatorEMD as ReplicateValidatorEMD +from 
cryo_challenge.data._validation.output_validators import ReplicateValidatorKL as ReplicateValidatorKL diff --git a/src/cryo_challenge/data/_io/svd_io_utils.py b/src/cryo_challenge/data/_io/svd_io_utils.py index 2a4d954..f194c14 100644 --- a/src/cryo_challenge/data/_io/svd_io_utils.py +++ b/src/cryo_challenge/data/_io/svd_io_utils.py @@ -106,16 +106,14 @@ def load_ref_vols(box_size_ds: int, path_to_volumes: str, dtype=torch.float32): # Reshape volumes to correct size if volumes.dim() == 2: - box_size = int(round((float(volumes.shape[-1]) ** (1.0 / 3.0)))) + box_size = int(round((float(volumes.shape[-1]) ** (1. / 3.)))) volumes = torch.reshape(volumes, (-1, box_size, box_size, box_size)) elif volumes.dim() == 4: pass else: - raise ValueError( - f"The shape of the volumes stored in {path_to_volumes} have the unexpected shape " - f"{torch.shape}. Please, review the file and regenerate it so that volumes stored hasve the " - f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size)." - ) + raise ValueError(f"The volumes stored in {path_to_volumes} have the unexpected shape " + f"{volumes.shape}. Please review the file and regenerate it so that the stored volumes have the " + f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size).") volumes_ds = torch.empty( (volumes.shape[0], box_size_ds, box_size_ds, box_size_ds), dtype=dtype diff --git a/src/cryo_challenge/data/_validation/config_validators.py b/src/cryo_challenge/data/_validation/config_validators.py index 5acc6ba..93316a0 100644 --- a/src/cryo_challenge/data/_validation/config_validators.py +++ b/src/cryo_challenge/data/_validation/config_validators.py @@ -1,7 +1,7 @@ from numbers import Number import pandas as pd import os - +from typing import List def validate_generic_config(config: dict, reference: dict) -> None: """ diff --git a/src/cryo_challenge/data/_validation/output_validators.py b/src/cryo_challenge/data/_validation/output_validators.py index 39cbb67..35d9791 100644 --- a/src/cryo_challenge/data/_validation/output_validators.py +++ b/src/cryo_challenge/data/_validation/output_validators.py @@ -13,7 +13,7 @@ @dataclass_json @dataclass class MapToMapResultsValidator: - """ + ''' Validate the output dictionary of the map-to-map distance matrix computation. config: dict, input config dictionary. @@ -22,8 +22,7 @@ class MapToMapResultsValidator: l2: dict, L2 results. bioem: dict, BioEM results. fsc: dict, FSC results. - """ - + ''' config: dict user_submitted_populations: torch.Tensor corr: Optional[dict] = None @@ -50,7 +49,7 @@ class ReplicateValidatorEMD: Validate the output dictionary of one EMD in the the distribution-to-distribution pipeline. q_opt: List[float], optimal user submitted distribution, which sums to 1. - EMD_opt: float, EMD between the ground truth distribution (p) and the (optimized) user submitted distribution (q_opt). + EMD_opt: float, EMD between the ground truth distribution (p) and the (optimized) user submitted distribution (q_opt). The transport plan is a joint distribution, such that: summing over the rows gives the (optimized) user submitted distribution, and summing over the columns gives the ground truth distribution. transport_plan_opt: List[List[float]], transport plan between the ground truth distribution (p, rows) and the (optimized) user submitted distribution (q_opt, columns).
@@ -62,7 +61,6 @@ class ReplicateValidatorEMD: The transport plan is a joint distribution, such that: summing over the rows gives the user submitted distribution, and summing over the columns gives the ground truth distribution. """ - q_opt: List[float] EMD_opt: float transport_plan_opt: List[List[float]] @@ -89,9 +87,8 @@ class ReplicateValidatorKL: iter_stop: int, number of iterations until convergence. eps_stop: float, stopping criterion. klpq_submitted: float, KL divergence between the ground truth distribution (p) and the user submitted distribution (q). - klqp_submitted: float, KL divergence between the user submitted distribution (q) and the ground truth distribution (p). + klqp_submitted: float, KL divergence between the user submitted distribution (q) and the ground truth distribution (p). """ - q_opt: List[float] klpq_opt: float klqp_opt: float @@ -109,12 +106,11 @@ def __post_init__(self): @dataclass_json @dataclass class MetricDistToDistValidator: - """ + ''' Validate the output dictionary of one map to map metric in the the distribution-to-distribution pipeline. replicates: dict, dictionary of replicates. - """ - + ''' replicates: dict def validate_replicates(self, n_replicates): @@ -130,7 +126,7 @@ def validate_replicates(self, n_replicates): @dataclass_json @dataclass class DistributionToDistributionResultsValidator: - """ + ''' Validate the output dictionary of the distribution-to-distribution pipeline. config: dict, input config dictionary. @@ -140,8 +136,7 @@ class DistributionToDistributionResultsValidator: bioem: dict, BioEM distance results. l2: dict, L2 distance results. corr: dict, correlation distance results. - """ - + ''' config: dict user_submitted_populations: torch.Tensor id: str diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index ecc5cde..c70d5aa 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -2,7 +2,7 @@ data: n_pix: 224 psize: 2.146 submission: - fname: tests/data/dataset_2_submissions/submission_10000.pt + fname: tests/data/dataset_2_submissions/test_submission_0_n8.pt volume_key: volumes metadata_key: populations label_key: id diff --git a/tests/config_files/test_config_svd.yaml b/tests/config_files/test_config_svd.yaml index ff0b413..c392525 100644 --- a/tests/config_files/test_config_svd.yaml +++ b/tests/config_files/test_config_svd.yaml @@ -1,6 +1,6 @@ path_to_volumes: tests/data/dataset_2_submissions/ box_size_ds: 32 -submission_list: [10000] +submission_list: [0] experiment_mode: "all_vs_ref" # options are "all_vs_all", "all_vs_ref" # optional unless experiment_mode is "all_vs_ref" path_to_reference: tests/data/Ground_truth/test_maps_gt_flat_10.pt diff --git a/tests/data/Ground_truth/1.mrc b/tests/data/Ground_truth/1.mrc deleted file mode 100644 index 8c2a13e..0000000 --- a/tests/data/Ground_truth/1.mrc +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f866dcc00684a672ad12f5bfd61254843a0213a4b34efbe1510e002e3e357f20 -size 44958720 diff --git a/tests/data/Ground_truth/mask_dilated_wide_224x224.mrc b/tests/data/Ground_truth/mask_dilated_wide_224x224.mrc deleted file mode 100644 index 7b02aea..0000000 --- a/tests/data/Ground_truth/mask_dilated_wide_224x224.mrc +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:2483eb52fbeb28a2d5851749270f0127fb287563332c923b33312bef567627d4 -size 44958720 diff --git a/tests/data/Ground_truth/test_maps_gt_flat_10.npy 
b/tests/data/Ground_truth/test_maps_gt_flat_10.npy deleted file mode 100644 index 983e61b..0000000 --- a/tests/data/Ground_truth/test_maps_gt_flat_10.npy +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:4b4733700fca5d5b1de2a8e106a46259c4a9b19b81f946776929e2cde0bdbd6e -size 449577088 diff --git a/tests/data/Ground_truth/test_maps_gt_flat_10.pt b/tests/data/Ground_truth/test_maps_gt_flat_10.pt deleted file mode 100644 index d0c5e40..0000000 --- a/tests/data/Ground_truth/test_maps_gt_flat_10.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e204191c9536d8a0d09fec5e5648acfa6a09c5182f60c0b3c6744ac5c4a3cb82 -size 449577746 diff --git a/tests/data/Ground_truth/test_metadata_10.csv b/tests/data/Ground_truth/test_metadata_10.csv deleted file mode 100644 index de041fc..0000000 --- a/tests/data/Ground_truth/test_metadata_10.csv +++ /dev/null @@ -1,11 +0,0 @@ -index,volumes,populations_count,pc1,populations -3238,13396.mrc,1,-231.62100638454024,2.9636654614427123e-05 -3020,10063.mrc,13,-179.257841640357,0.0003852765099875 -1113,01421.mrc,12,-141.8237629192062,0.0003556398553731 -3592,21858.mrc,4,-101.62462603216916,0.0001185466184577 -1947,03298.mrc,5,-34.99878410436052,0.0001481832730721 -2097,03764.mrc,6,3.946553364334135,0.0001778199276865 -1574,02336.mrc,5,44.70670231717438,0.0001481832730721 -2813,08011.mrc,8,108.6308222660271,0.0002370932369154 -88,00090.mrc,21,147.70416251702042,0.0006223697469029 -771,00906.mrc,11,186.3446095998357,0.0003260032007586 diff --git a/tests/data/dataset_2_submissions/submission_10000.pt b/tests/data/dataset_2_submissions/submission_10000.pt deleted file mode 100644 index b3a67d4..0000000 --- a/tests/data/dataset_2_submissions/submission_10000.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9abfc103cbf4a8885b22716cfb3ae0a9bb038f81a2b5f3c1e237ec14d407dba6 -size 359662738 diff --git a/tests/data/test_maps_gt_flat_2.pt b/tests/data/test_maps_gt_flat_2.pt index f74c145..dc3293c 100644 Binary files a/tests/data/test_maps_gt_flat_2.pt and b/tests/data/test_maps_gt_flat_2.pt differ diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/1.mrc b/tests/data/unprocessed_dataset_2_submissions/submission_x/1.mrc deleted file mode 100644 index 8c2a13e..0000000 --- a/tests/data/unprocessed_dataset_2_submissions/submission_x/1.mrc +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f866dcc00684a672ad12f5bfd61254843a0213a4b34efbe1510e002e3e357f20 -size 44958720 diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/2.mrc b/tests/data/unprocessed_dataset_2_submissions/submission_x/2.mrc deleted file mode 100644 index 4381ab3..0000000 --- a/tests/data/unprocessed_dataset_2_submissions/submission_x/2.mrc +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:4551a156c3fc495f8e895ee8c37cea397b04b06902e5d6bb2516e8699fe9bc8e -size 44958720 diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/3.mrc b/tests/data/unprocessed_dataset_2_submissions/submission_x/3.mrc deleted file mode 100644 index 68fb0e9..0000000 --- a/tests/data/unprocessed_dataset_2_submissions/submission_x/3.mrc +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9e46737d911b12fa0624169b0f5c6ae746df9b5c68114d6a76831b5c974fe114 -size 44958720 diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/4.mrc 
b/tests/data/unprocessed_dataset_2_submissions/submission_x/4.mrc deleted file mode 100644 index 9f7b60e..0000000 --- a/tests/data/unprocessed_dataset_2_submissions/submission_x/4.mrc +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:203123b395a1d39a90c1d46f071ffaa096299c9d969c35cf532afe2c5c403641 -size 44958720 diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt b/tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt deleted file mode 100644 index 53e4b60..0000000 --- a/tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt +++ /dev/null @@ -1,4 +0,0 @@ -0.25 -0.25 -0.25 -0.25 diff --git a/tests/scripts/fetch_test_data.sh b/tests/scripts/fetch_test_data.sh new file mode 100644 index 0000000..c252871 --- /dev/null +++ b/tests/scripts/fetch_test_data.sh @@ -0,0 +1,12 @@ +mkdir -p tests/data/dataset_2_submissions tests/results tests/data/unprocessed_dataset_2_submissions/submission_x tests/data/Ground_truth +wget "https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/dataset_2_submissions/test_submission_0_n8.pt?download=true" -O tests/data/dataset_2_submissions/test_submission_0_n8.pt +ADIR=$(pwd) +ln -s $ADIR/tests/data/dataset_2_submissions/test_submission_0_n8.pt $ADIR/tests/data/dataset_2_submissions/submission_0.pt # symlink for svd which needs submission_0.pt for filename +wget "https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/Ground_truth/test_maps_gt_flat_10.pt?download=true" -O tests/data/Ground_truth/test_maps_gt_flat_10.pt +wget "https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/Ground_truth/test_metadata_10.csv?download=true" -O tests/data/Ground_truth/test_metadata_10.csv +wget "https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/Ground_truth/1.mrc?download=true" -O tests/data/Ground_truth/1.mrc +wget "https://files.osf.io/v1/resources/8h6fz/providers/dropbox/Ground_truth/mask_dilated_wide_224x224.mrc?download=true" -O tests/data/Ground_truth/mask_dilated_wide_224x224.mrc +for FILE in 1.mrc 2.mrc 3.mrc 4.mrc populations.txt +do + wget "https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/unprocessed_dataset_2_submissions/submission_x/${FILE}?download=true" -O tests/data/unprocessed_dataset_2_submissions/submission_x/${FILE} +done diff --git a/tests/test_distribution_to_distribution.py b/tests/test_distribution_to_distribution.py index a4cfb79..d9c340b 100644 --- a/tests/test_distribution_to_distribution.py +++ b/tests/test_distribution_to_distribution.py @@ -2,8 +2,6 @@ from cryo_challenge._commands import run_distribution2distribution_pipeline -def test_run_distribution2distribution_pipeline(): - args = OmegaConf.create( - {"config": "tests/config_files/test_config_distribution_to_distribution.yaml"} - ) - run_distribution2distribution_pipeline.main(args) +def test_run_distribution2distribution_pipeline(): + args = OmegaConf.create({'config': 'tests/config_files/test_config_distribution_to_distribution.yaml'}) + run_distribution2distribution_pipeline.main(args) \ No newline at end of file diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index e31f29f..c782a8c 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -2,8 +2,6 @@ from cryo_challenge._commands import run_map2map_pipeline -def test_run_map2map_pipeline(): - args = OmegaConf.create( - {"config": "tests/config_files/test_config_map_to_map.yaml"} - ) - 
run_map2map_pipeline.main(args) +def test_run_map2map_pipeline(): + args = OmegaConf.create({'config': 'tests/config_files/test_config_map_to_map.yaml'}) + run_map2map_pipeline.main(args) \ No newline at end of file diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 31db34e..cbf54e4 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -2,6 +2,6 @@ from cryo_challenge._commands import run_preprocessing -def test_run_preprocessing(): - args = OmegaConf.create({"config": "tests/config_files/test_config_preproc.yaml"}) - run_preprocessing.main(args) +def test_run_preprocessing(): + args = OmegaConf.create({'config': 'tests/config_files/test_config_preproc.yaml'}) + run_preprocessing.main(args) \ No newline at end of file diff --git a/tests/test_svd.py b/tests/test_svd.py index ea166ea..568370e 100644 --- a/tests/test_svd.py +++ b/tests/test_svd.py @@ -2,6 +2,6 @@ from cryo_challenge._commands import run_svd -def test_run_preprocessing(): - args = OmegaConf.create({"config": "tests/config_files/test_config_svd.yaml"}) - run_svd.main(args) +def test_run_preprocessing(): + args = OmegaConf.create({'config': 'tests/config_files/test_config_svd.yaml'}) + run_svd.main(args) \ No newline at end of file diff --git a/tutorials/1_tutorial_preprocessing.ipynb b/tutorials/1_tutorial_preprocessing.ipynb index b2063ca..f04b511 100644 --- a/tutorials/1_tutorial_preprocessing.ipynb +++ b/tutorials/1_tutorial_preprocessing.ipynb @@ -242,7 +242,7 @@ "# Select path to Config file\n", "# An example of this file is available in the path ../config_files/config_preproc.yaml\n", "config_preproc_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_preproc_path.filter_pattern = \"*.yaml\"\n", + "config_preproc_path.filter_pattern = '*.yaml'\n", "display(config_preproc_path)" ] }, @@ -265,7 +265,7 @@ "if os.path.isabs(output_path):\n", " full_output_path = output_path\n", "else:\n", - " full_output_path = os.path.join(os.getcwd(), \"..\", output_path)" + " full_output_path = os.path.join(os.getcwd(), '..', output_path)" ] }, { diff --git a/tutorials/2_tutorial_svd.ipynb b/tutorials/2_tutorial_svd.ipynb index fe8f432..b41bfba 100644 --- a/tutorials/2_tutorial_svd.ipynb +++ b/tutorials/2_tutorial_svd.ipynb @@ -62,7 +62,7 @@ "# Select path to SVD config file\n", "# An example of this file is available in the path ../config_files/config_svd.yaml\n", "config_svd_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_svd_path.filter_pattern = \"*.yaml\"\n", + "config_svd_path.filter_pattern = '*.yaml'\n", "display(config_svd_path)" ] }, @@ -125,7 +125,7 @@ "source": [ "# Select path to SVD results\n", "svd_results_path = FileChooser(os.path.expanduser(\"~\"))\n", - "svd_results_path.filter_pattern = \"*.pt\"\n", + "svd_results_path.filter_pattern = '*.pt'\n", "display(svd_results_path)" ] }, @@ -316,7 +316,7 @@ "source": [ "# Select path to SVD results\n", "svd_all_vs_all_results_path = FileChooser(os.path.expanduser(\"~\"))\n", - "svd_all_vs_all_results_path.filter_pattern = \"*.pt\"\n", + "svd_all_vs_all_results_path.filter_pattern = '*.pt'\n", "display(svd_all_vs_all_results_path)" ] }, diff --git a/tutorials/3_tutorial_map2map.ipynb b/tutorials/3_tutorial_map2map.ipynb index 0497578..a3701ff 100644 --- a/tutorials/3_tutorial_map2map.ipynb +++ b/tutorials/3_tutorial_map2map.ipynb @@ -23,8 +23,15 @@ "\n", "from cryo_challenge.data._validation.config_validators import (\n", " validate_input_config_mtm,\n", - ")\n", - "from 
cryo_challenge.data._validation.output_validators import MapToMapResultsValidator" + " validate_config_mtm_data, \n", + " validate_config_mtm_data_submission, \n", + " validate_config_mtm_data_ground_truth, \n", + " validate_config_mtm_data_mask, \n", + " validate_config_mtm_analysis, \n", + " validate_config_mtm_analysis_normalize, \n", + " )\n", + "from cryo_challenge.data._validation.output_validators import MapToMapResultsValidator\n", + "from cryo_challenge.data._validation.config_validators import validate_maptomap_result" ] }, { @@ -73,7 +80,7 @@ "# Select path to Map to Map config file\n", "# An example of this file is available in the path ../config_files/config_map_to_map.yaml\n", "config_m2m_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_m2m_path.filter_pattern = \"*.yaml\"\n", + "config_m2m_path.filter_pattern = '*.yaml'\n", "display(config_m2m_path)" ] }, @@ -334,8 +341,8 @@ "metadata": {}, "outputs": [], "source": [ - "with open(os.path.join(\"../\", config[\"output\"]), \"rb\") as f:\n", - " results_dict = pickle.load(f)" + "with open(os.path.join('../',config[\"output\"]), \"rb\") as f:\n", + " results_dict = pickle.load(f)\n" ] }, { diff --git a/tutorials/4_tutorial_distribution2distribution.ipynb b/tutorials/4_tutorial_distribution2distribution.ipynb index 2a6fc53..07dc1d9 100644 --- a/tutorials/4_tutorial_distribution2distribution.ipynb +++ b/tutorials/4_tutorial_distribution2distribution.ipynb @@ -30,7 +30,7 @@ "import pickle\n", "from ipyfilechooser import FileChooser\n", "\n", - "from cryo_challenge.data import validate_input_config_disttodist\n", + "from cryo_challenge.data import validate_input_config_disttodist, validate_config_dtd_optimal_q_kl\n", "from cryo_challenge.data import DistributionToDistributionResultsValidator" ] }, @@ -65,7 +65,7 @@ "# Select path to Distribution to Distribution config file\n", "# An example of this file is available in the path ../config_files/config_distribution_to_distribution.yaml\n", "config_d2d_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_d2d_path.filter_pattern = \"*.yaml\"\n", + "config_d2d_path.filter_pattern = '*.yaml'\n", "display(config_d2d_path)" ] }, @@ -199,8 +199,8 @@ "metadata": {}, "outputs": [], "source": [ - "with open(os.path.join(\"../\", config[\"output_fname\"]), \"rb\") as f:\n", - " results_dict = pickle.load(f)" + "with open(os.path.join('../',config[\"output_fname\"]), \"rb\") as f:\n", + " results_dict = pickle.load(f)\n" ] }, { @@ -286,6 +286,7 @@ } ], "source": [ + "from cryo_challenge.data import MetricDistToDistValidator\n", "MetricDistToDistValidator?" 
] }, @@ -301,7 +302,9 @@ "execution_count": 30, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from cryo_challenge.data import ReplicateValidatorEMD, ReplicateValidatorKL" + ] }, { "cell_type": "code", diff --git a/tutorials/5_tutorial_plotting.ipynb b/tutorials/5_tutorial_plotting.ipynb index e2648c8..ed8a924 100644 --- a/tutorials/5_tutorial_plotting.ipynb +++ b/tutorials/5_tutorial_plotting.ipynb @@ -23,9 +23,7 @@ "metadata": {}, "outputs": [], "source": [ - "from cryo_challenge._distribution_to_distribution.distribution_to_distribution import (\n", - " sort_by_transport,\n", - ")\n", + "from cryo_challenge._distribution_to_distribution.distribution_to_distribution import sort_by_transport\n", "from cryo_challenge._ploting.plotting_utils import res_at_fsc_threshold\n", "\n", "from dataclasses import dataclass\n", @@ -104,7 +102,6 @@ " map2map_results: List[str]\n", " dist2dist_results: Dict[str, Dict[str, List[str]]]\n", "\n", - "\n", "with open(path_to_config, \"r\") as file:\n", " config = yaml.safe_load(file)\n", "config = PlottingConfig.from_dict(config)\n", @@ -118,7 +115,7 @@ "outputs": [], "source": [ "metadata_df = pd.read_csv(config.gt_metadata)\n", - "metadata_df.sort_values(\"pc1\", inplace=True)\n", + "metadata_df.sort_values('pc1', inplace=True)\n", "gt_ordering = metadata_df.index.tolist()" ] }, @@ -139,14 +136,12 @@ " data_d = {}\n", " for fname in fnames:\n", " if fname not in data_d.keys():\n", - " with open(fname, \"rb\") as f:\n", + " with open(fname, 'rb') as f:\n", " data = pickle.load(f)\n", - " anonymous_label = data[map2map_distance][\"user_submission_label\"]\n", + " anonymous_label = data[map2map_distance]['user_submission_label']\n", " data_d[anonymous_label] = data\n", " return data_d\n", - "\n", - "\n", - "data_d = get_fsc_distances(config.map2map_results, \"fsc\")" + "data_d = get_fsc_distances(config.map2map_results, 'fsc')" ] }, { @@ -167,24 +162,22 @@ ], "source": [ "def plot_fsc_distances(data_d, gt_ordering):\n", + "\n", " smaller_fontsize = 20\n", " larger_fontsize = 30\n", " n_plts = 12\n", " vmin, vmax = np.inf, -np.inf\n", "\n", - " fig, axis = plt.subplots(\n", - " 3,\n", - " n_plts // 3,\n", - " figsize=(40, 20),\n", - " # dpi=100,\n", - " )\n", - " fig.suptitle(r\"$d_{FSC}$\", y=0.95, fontsize=larger_fontsize)\n", + " fig, axis = plt.subplots(3,n_plts//3,\n", + " figsize=(40,20),\n", + " # dpi=100,\n", + " )\n", + " fig.suptitle(r'$d_{FSC}$', y=0.95, fontsize=larger_fontsize)\n", "\n", " for idx, (anonymous_label, data) in enumerate(data_d.items()):\n", - " map2map_dist_matrix = data[\"fsc\"][\"cost_matrix\"].iloc[gt_ordering].values\n", - " sorted_map2map_dist_matrix, indices, transport = sort_by_transport(\n", - " map2map_dist_matrix\n", - " )\n", + " map2map_dist_matrix = data['fsc']['cost_matrix'].iloc[gt_ordering].values\n", + " sorted_map2map_dist_matrix, indices, transport = sort_by_transport(map2map_dist_matrix)\n", + "\n", "\n", " ncols = 4\n", " if map2map_dist_matrix.min() < vmin:\n", @@ -192,24 +185,14 @@ " if map2map_dist_matrix.max() > vmax:\n", " vmax = map2map_dist_matrix.max()\n", "\n", - " ax = axis[idx // ncols, idx % ncols].imshow(\n", - " sorted_map2map_dist_matrix,\n", - " aspect=\"auto\",\n", - " cmap=\"Blues_r\",\n", - " vmin=vmin,\n", - " vmax=vmax,\n", - " )\n", - "\n", - " axis[idx // ncols, idx % ncols].tick_params(\n", - " axis=\"both\", labelsize=smaller_fontsize\n", - " )\n", + " ax = axis[idx//ncols,idx%ncols].imshow(sorted_map2map_dist_matrix, aspect='auto', cmap='Blues_r', vmin=vmin, 
vmax=vmax)\n", + "\n", + "\n", + " axis[idx//ncols,idx%ncols].tick_params(axis='both', labelsize=smaller_fontsize)\n", " cbar = fig.colorbar(ax)\n", " cbar.ax.tick_params(labelsize=smaller_fontsize)\n", " plot_panel_label = anonymous_label\n", - " axis[idx // ncols, idx % ncols].set_title(\n", - " plot_panel_label, fontsize=smaller_fontsize\n", - " )\n", - "\n", + " axis[idx//ncols,idx%ncols].set_title(plot_panel_label, fontsize=smaller_fontsize)\n", "\n", "plot_fsc_distances(data_d, gt_ordering)" ] @@ -231,13 +214,12 @@ " data_d = {}\n", " for fname in fnames:\n", " if fname not in data_d.keys():\n", - " with open(fname, \"rb\") as f:\n", + " with open(fname, 'rb') as f:\n", " data = pickle.load(f)\n", - " anonymous_label = data[\"fsc\"][\"user_submission_label\"]\n", - " data_d[anonymous_label] = data[\"fsc\"][\"computed_assets\"][\"fsc_matrix\"]\n", + " anonymous_label = data['fsc']['user_submission_label']\n", + " data_d[anonymous_label] = data['fsc']['computed_assets']['fsc_matrix']\n", " return data_d\n", "\n", - "\n", "fscs_sorted_d = get_full_fsc_curve(config.map2map_results)" ] }, @@ -247,11 +229,9 @@ "metadata": {}, "outputs": [], "source": [ - "res_fsc_half, fraction_nyquist = res_at_fsc_threshold(\n", - " fscs_sorted_d[\"Cookie Dough\"], threshold=0.5\n", - ")\n", - "n_fourier_bins = fscs_sorted_d[\"Cookie Dough\"].shape[-1]\n", - "units_Angstroms = 2 * 2.146 / (np.arange(1, n_fourier_bins + 1) / n_fourier_bins)" + "res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fscs_sorted_d['Cookie Dough'], threshold=0.5)\n", + "n_fourier_bins = fscs_sorted_d['Cookie Dough'].shape[-1]\n", + "units_Angstroms = 2 * 2.146 / (np.arange(1,n_fourier_bins+1) / n_fourier_bins)" ] }, { @@ -281,7 +261,7 @@ } ], "source": [ - "plt.imshow(units_Angstroms[res_fsc_half][gt_ordering], aspect=\"auto\", cmap=\"Blues_r\")\n", + "plt.imshow(units_Angstroms[res_fsc_half][gt_ordering], aspect='auto', cmap='Blues_r')\n", "plt.colorbar()" ] }, @@ -303,18 +283,17 @@ ], "source": [ "def plot_res_at_fsc_threshold_distances(fscs_sorted_d, gt_ordering, overwrite_dict={}):\n", + "\n", " smaller_fontsize = 20\n", " larger_fontsize = 30\n", " n_plts = 12\n", " vmin, vmax = np.inf, -np.inf\n", "\n", - " fig, axis = plt.subplots(\n", - " 3,\n", - " n_plts // 3,\n", - " figsize=(40, 20),\n", - " # dpi=100,\n", - " )\n", - " fig.suptitle(r\"Resolution $(\\AA)$ at $FSC=0.5$\", y=0.95, fontsize=larger_fontsize)\n", + " fig, axis = plt.subplots(3,n_plts//3,\n", + " figsize=(40,20),\n", + " # dpi=100,\n", + " )\n", + " fig.suptitle(r'Resolution $(\\AA)$ at $FSC=0.5$', y=0.95, fontsize=larger_fontsize)\n", "\n", " for idx, (anonymous_label, fscs) in enumerate(fscs_sorted_d.items()):\n", " # map2map_dist_matrix = data.iloc[gt_ordering].values\n", @@ -323,38 +302,27 @@ "\n", " sorted_map2map_dist_matrix, _, _ = sort_by_transport(map2map_dist_matrix)\n", "\n", + "\n", " ncols = 4\n", " if map2map_dist_matrix.min() < vmin:\n", " vmin = map2map_dist_matrix.min()\n", " if map2map_dist_matrix.max() > vmax:\n", " vmax = map2map_dist_matrix.max()\n", - " if \"vmax\" in overwrite_dict.keys():\n", - " vmax = overwrite_dict[\"vmax\"]\n", - " if \"vmin\" in overwrite_dict.keys():\n", - " vmin = overwrite_dict[\"vmin\"]\n", - "\n", - " ax = axis[idx // ncols, idx % ncols].imshow(\n", - " sorted_map2map_dist_matrix,\n", - " aspect=\"auto\",\n", - " cmap=\"Blues_r\",\n", - " vmin=vmin,\n", - " vmax=vmax,\n", - " )\n", - "\n", - " axis[idx // ncols, idx % ncols].tick_params(\n", - " axis=\"both\", labelsize=smaller_fontsize\n", - " )\n", + " 
if 'vmax' in overwrite_dict.keys():\n", + " vmax = overwrite_dict['vmax']\n", + " if 'vmin' in overwrite_dict.keys():\n", + " vmin = overwrite_dict['vmin']\n", + "\n", + " ax = axis[idx//ncols,idx%ncols].imshow(sorted_map2map_dist_matrix, aspect='auto', cmap='Blues_r', vmin=vmin, vmax=vmax)\n", + "\n", + "\n", + " axis[idx//ncols,idx%ncols].tick_params(axis='both', labelsize=smaller_fontsize)\n", " cbar = fig.colorbar(ax)\n", " cbar.ax.tick_params(labelsize=smaller_fontsize)\n", " plot_panel_label = anonymous_label\n", - " axis[idx // ncols, idx % ncols].set_title(\n", - " plot_panel_label, fontsize=smaller_fontsize\n", - " )\n", - "\n", + " axis[idx//ncols,idx%ncols].set_title(plot_panel_label, fontsize=smaller_fontsize)\n", "\n", - "plot_res_at_fsc_threshold_distances(\n", - " fscs_sorted_d, gt_ordering, overwrite_dict={\"vmax\": 31}\n", - ")" + "plot_res_at_fsc_threshold_distances(fscs_sorted_d, gt_ordering, overwrite_dict={'vmax': 31})" ] }, { @@ -370,9 +338,9 @@ "metadata": {}, "outputs": [], "source": [ - "fname = config.dist2dist_results[\"prob_submitted_plot\"][\"pkl_fnames\"][0]\n", + "fname = config.dist2dist_results['prob_submitted_plot']['pkl_fnames'][0]\n", "\n", - "with open(fname, \"rb\") as f:\n", + "with open(fname, 'rb') as f:\n", " data = pickle.load(f)" ] }, @@ -386,16 +354,13 @@ " data_d = {}\n", " for fname in fnames:\n", " if fname not in data_d.keys():\n", - " with open(fname, \"rb\") as f:\n", + " with open(fname, 'rb') as f:\n", " data = pickle.load(f)\n", - " anonymous_label = data[\"id\"]\n", + " anonymous_label = data['id']\n", " data_d[anonymous_label] = data\n", " return data_d\n", "\n", - "\n", - "dist2dist_results_d = get_dist2dist_results(\n", - " config.dist2dist_results[\"prob_submitted_plot\"][\"pkl_fnames\"]\n", - ")" + "dist2dist_results_d = get_dist2dist_results(config.dist2dist_results['prob_submitted_plot']['pkl_fnames'])" ] }, { @@ -416,58 +381,46 @@ ], "source": [ "window_size = 15\n", - "nrows, ncols = 3, 4\n", + "nrows, ncols = 3,4\n", "suptitle = f\"Submitted populations vs optimal populations \\n d_FSC (no rank) | n_replicates={data['config']['n_replicates']} | window_size={window_size} | n_pool_microstate={data['config']['n_pool_microstate']}\"\n", "\n", - "\n", "def plot_q_opt_distances(dist2dist_results_d, suptitle, nrows, ncols):\n", - " fig, axes = plt.subplots(nrows, ncols, figsize=(40, 25))\n", "\n", - " fig.suptitle(suptitle, fontsize=30, y=0.95)\n", + " fig, axes = plt.subplots(nrows, ncols, figsize=(40,25))\n", + " \n", + " fig.suptitle(\n", + " suptitle,\n", + " fontsize=30,\n", + " y=0.95)\n", " alpha = 0.05\n", "\n", - " for idx_fname, (_, data) in enumerate(dist2dist_results_d.items()):\n", - " axes[idx_fname // ncols, idx_fname % ncols].plot(\n", - " data[\"user_submitted_populations\"], color=\"black\", label=\"submited\"\n", - " )\n", - " axes[idx_fname // ncols, idx_fname % ncols].set_title(data[\"id\"], fontsize=30)\n", + " for idx_fname, (_,data) in enumerate(dist2dist_results_d.items()):\n", + " \n", "\n", - " def window_q(q_opt, window_size):\n", - " running_avg = np.convolve(\n", - " q_opt, np.ones(window_size) / window_size, mode=\"same\"\n", - " )\n", - " return running_avg\n", + " axes[idx_fname//ncols, idx_fname%ncols].plot(data['user_submitted_populations'], color='black', label='submited')\n", + " axes[idx_fname//ncols, idx_fname%ncols].set_title(data['id'], fontsize=30)\n", "\n", - " for replicate_idx in range(data[\"config\"][\"n_replicates\"]):\n", - " if replicate_idx == 0:\n", - " label_d = {\n", - " \"EMD\": 
\"EMD\",\n", - " \"KL\": \"KL\",\n", - " \"KL_raw\": \"Unwindowed\",\n", - " \"EMD_raw\": \"Unwindowed\",\n", - " }\n", + " def window_q(q_opt,window_size):\n", + " running_avg = np.convolve(q_opt, np.ones(window_size)/window_size, mode='same')\n", + " return running_avg\n", + " \n", + " for replicate_idx in range(data['config']['n_replicates']):\n", + " \n", + " if replicate_idx == 0: \n", + " label_d = {'EMD': 'EMD', 'KL': 'KL', 'KL_raw': 'Unwindowed', 'EMD_raw': 'Unwindowed'}\n", " else:\n", - " label_d = {\"EMD\": None, \"KL\": None, \"KL_raw\": None, \"EMD_raw\": None}\n", - " axes[idx_fname // ncols, idx_fname % ncols].plot(\n", - " window_q(\n", - " data[\"fsc\"][\"replicates\"][replicate_idx][\"EMD\"][\"q_opt\"],\n", - " window_size,\n", - " ),\n", - " color=\"blue\",\n", - " alpha=alpha,\n", - " label=label_d[\"EMD\"],\n", - " )\n", - "\n", - " axes[idx_fname // ncols, idx_fname % ncols].set_xlabel(\"Submission index\")\n", - " axes[idx_fname // ncols, idx_fname % ncols].set_ylabel(\"Population\")\n", - "\n", - " legend = axes[idx_fname // ncols, idx_fname % ncols].legend()\n", + " label_d = {'EMD': None, 'KL': None, 'KL_raw': None, 'EMD_raw': None}\n", + " axes[idx_fname//ncols, idx_fname%ncols].plot(window_q(data['fsc']['replicates'][replicate_idx]['EMD']['q_opt'],window_size), color='blue', alpha=alpha, label=label_d['EMD'])\n", + "\n", + " axes[idx_fname//ncols, idx_fname%ncols].set_xlabel('Submission index')\n", + " axes[idx_fname//ncols, idx_fname%ncols].set_ylabel('Population')\n", + "\n", + " legend = axes[idx_fname//ncols, idx_fname%ncols].legend()\n", " for line, text in zip(legend.get_lines(), legend.get_texts()):\n", " text.set_color(line.get_color())\n", " line.set_alpha(1)\n", "\n", - "\n", - "plot_q_opt_distances(dist2dist_results_d, suptitle, nrows, ncols)" + "plot_q_opt_distances(dist2dist_results_d, suptitle,nrows, ncols)" ] }, { @@ -484,6 +437,7 @@ "outputs": [], "source": [ "def wragle_pkl_to_dataframe(pkl_globs):\n", + "\n", " fnames = []\n", " for fname_glob in pkl_globs:\n", " fnames.extend(glob.glob(fname_glob))\n", @@ -492,50 +446,28 @@ "\n", " df_list = []\n", " n_replicates = 30\n", - " metric = \"fsc\"\n", + " metric = 'fsc'\n", "\n", " for fname in fnames:\n", - " with open(fname, \"rb\") as f:\n", + " with open(fname, 'rb') as f:\n", " data = pickle.load(f)\n", "\n", - " df_list.append(\n", - " pd.DataFrame(\n", - " {\n", - " \"EMD_opt\": [\n", - " data[metric][\"replicates\"][i][\"EMD\"][\"EMD_opt\"]\n", - " for i in range(n_replicates)\n", - " ],\n", - " \"EMD_submitted\": [\n", - " data[metric][\"replicates\"][i][\"EMD\"][\"EMD_submitted\"]\n", - " for i in range(n_replicates)\n", - " ],\n", - " \"klpq_opt\": [\n", - " data[metric][\"replicates\"][i][\"KL\"][\"klpq_opt\"]\n", - " for i in range(n_replicates)\n", - " ],\n", - " \"klqp_opt\": [\n", - " data[metric][\"replicates\"][i][\"KL\"][\"klqp_opt\"]\n", - " for i in range(n_replicates)\n", - " ],\n", - " \"klpq_submitted\": [\n", - " data[metric][\"replicates\"][i][\"KL\"][\"klpq_submitted\"]\n", - " for i in range(n_replicates)\n", - " ],\n", - " \"klqp_submitted\": [\n", - " data[metric][\"replicates\"][i][\"KL\"][\"klqp_submitted\"]\n", - " for i in range(n_replicates)\n", - " ],\n", - " \"id\": data[\"id\"],\n", - " \"n_pool_microstate\": data[\"config\"][\"n_pool_microstate\"],\n", - " }\n", - " )\n", - " )\n", + " df_list.append(pd.DataFrame({\n", + " 'EMD_opt': [data[metric]['replicates'][i]['EMD']['EMD_opt'] for i in range(n_replicates)],\n", + " 'EMD_submitted': 
[data[metric]['replicates'][i]['EMD']['EMD_submitted'] for i in range(n_replicates)],\n", + " 'klpq_opt': [data[metric]['replicates'][i]['KL']['klpq_opt'] for i in range(n_replicates)],\n", + " 'klqp_opt': [data[metric]['replicates'][i]['KL']['klqp_opt'] for i in range(n_replicates)],\n", + " 'klpq_submitted': [data[metric]['replicates'][i]['KL']['klpq_submitted'] for i in range(n_replicates)], \n", + " 'klqp_submitted': [data[metric]['replicates'][i]['KL']['klqp_submitted'] for i in range(n_replicates)], \n", + " 'id': data['id'],\n", + " 'n_pool_microstate': data['config']['n_pool_microstate'],\n", + " }))\n", "\n", " df = pd.concat(df_list)\n", - " df[\"EMD_opt_norm\"] = df[\"EMD_opt\"] / df[\"n_pool_microstate\"]\n", - " df[\"EMD_submitted_norm\"] = df[\"EMD_submitted\"] / df[\"n_pool_microstate\"]\n", + " df['EMD_opt_norm'] = df['EMD_opt'] / df['n_pool_microstate']\n", + " df['EMD_submitted_norm'] = df['EMD_submitted'] / df['n_pool_microstate']\n", "\n", - " return df" + " return df\n" ] }, { @@ -544,7 +476,7 @@ "metadata": {}, "outputs": [], "source": [ - "df = wragle_pkl_to_dataframe(config.dist2dist_results[\"emd_plot\"][\"pkl_globs\"])" + "df = wragle_pkl_to_dataframe(config.dist2dist_results['emd_plot']['pkl_globs'])" ] }, { @@ -566,96 +498,50 @@ "source": [ "def plot_EMD_vs_EMDopt(df, suptitle=None):\n", " alpha = 1\n", - " df_average = df.groupby([\"n_pool_microstate\", \"id\"]).mean().reset_index()\n", - " df_std = (\n", - " df_average.groupby([\"id\"])\n", - " .std()\n", - " .reset_index()\n", - " .filter([\"EMD_opt_norm\", \"EMD_submitted_norm\", \"id\"])\n", - " .rename(\n", - " columns={\n", - " \"EMD_opt_norm\": \"EMD_opt_norm_std\",\n", - " \"EMD_submitted_norm\": \"EMD_submitted_norm_std\",\n", - " }\n", - " )\n", - " )\n", - " df_average = df.groupby([\"id\"]).mean().reset_index()\n", - "\n", - " df_average_and_error = pd.merge(df_average, df_std, on=\"id\")\n", + " df_average = df.groupby(['n_pool_microstate','id']).mean().reset_index()\n", + " df_std = df_average.groupby(['id']).std().reset_index().filter(['EMD_opt_norm','EMD_submitted_norm', 'id']).rename(columns={'EMD_opt_norm':'EMD_opt_norm_std', 'EMD_submitted_norm':'EMD_submitted_norm_std'})\n", + " df_average = df.groupby(['id']).mean().reset_index()\n", + "\n", + " df_average_and_error = pd.merge(df_average, df_std, on='id')\n", "\n", " # Get unique ids\n", - " ids = df_average_and_error[\"id\"].unique()\n", + " ids = df_average_and_error['id'].unique()\n", "\n", " # Define marker styles\n", - " markers = [\n", - " \"o\",\n", - " \"v\",\n", - " \"^\",\n", - " \"<\",\n", - " \">\",\n", - " \"s\",\n", - " \"p\",\n", - " \"*\",\n", - " \"h\",\n", - " \"H\",\n", - " \"+\",\n", - " \"x\",\n", - " \"D\",\n", - " \"d\",\n", - " \"|\",\n", - " \"_\",\n", - " ]\n", + " markers = ['o', 'v', '^', '<', '>', 's', 'p', '*', 'h', 'H', '+', 'x', 'D', 'd', '|', '_']\n", " marker_size = 250\n", "\n", - " plt.style.use(\"seaborn-v0_8-poster\")\n", - " plot_width, plot_height = 8, 6\n", + " plt.style.use('seaborn-v0_8-poster')\n", + " plot_width, plot_height = 8,6\n", " plt.figure(figsize=(plot_width, plot_height), dpi=300)\n", "\n", " # Create a scatter plot for each id\n", " for idx, id_label in enumerate(ids):\n", - " df_average_id = df_average[df_average[\"id\"] == id_label]\n", - " sns.scatterplot(\n", - " x=\"EMD_submitted_norm\",\n", - " y=\"EMD_opt_norm\",\n", - " data=df_average_id,\n", - " alpha=alpha,\n", - " marker=markers[idx % len(markers)],\n", - " label=id_label,\n", - " s=marker_size,\n", - " )\n", - "\n", - " 
plt.errorbar(\n", - " x=df_average_and_error[\"EMD_submitted_norm\"],\n", - " y=df_average_and_error[\"EMD_opt_norm\"],\n", - " xerr=df_average_and_error[\"EMD_submitted_norm_std\"],\n", - " yerr=df_average_and_error[\"EMD_opt_norm_std\"],\n", - " fmt=\"\",\n", - " alpha=0.05,\n", - " linestyle=\"None\",\n", - " ecolor=\"k\",\n", - " elinewidth=2,\n", - " capsize=5,\n", - " )\n", + " df_average_id = df_average[df_average['id'] == id_label]\n", + " sns.scatterplot(x='EMD_submitted_norm', y='EMD_opt_norm', data=df_average_id, alpha=alpha, marker=markers[idx % len(markers)], label=id_label, s=marker_size)\n", + "\n", + " plt.errorbar(x=df_average_and_error['EMD_submitted_norm'], \n", + " y=df_average_and_error['EMD_opt_norm'], \n", + " xerr=df_average_and_error['EMD_submitted_norm_std'], \n", + " yerr=df_average_and_error['EMD_opt_norm_std'], \n", + " fmt='', alpha=0.05, linestyle='None', ecolor='k', elinewidth=2, capsize=5)\n", "\n", " plt.xlim(left=0.5)\n", " plt.ylim(bottom=0.5)\n", "\n", - " limits = [\n", - " np.min([plt.xlim(), plt.ylim()]), # min of both axes\n", - " np.max([plt.xlim(), plt.ylim()]),\n", - " ] # max of both axes\n", + " limits = [np.min([plt.xlim(), plt.ylim()]), # min of both axes\n", + " np.max([plt.xlim(), plt.ylim()])] # max of both axes\n", "\n", - " plt.plot(limits, limits, \"k-\", alpha=0.75, zorder=0)\n", + " plt.plot(limits, limits, 'k-', alpha=0.75, zorder=0)\n", " plt.xlim(limits)\n", " plt.ylim(limits)\n", - " legend = plt.legend(loc=\"upper left\", fontsize=12)\n", + " legend = plt.legend(loc='upper left', fontsize=12)\n", " for handle in legend.legend_handles:\n", " handle.set_alpha(1)\n", "\n", " plt.suptitle(suptitle)\n", "\n", - "\n", - "suptitle = r\"$d_{FSC}$ (no rank)\"\n", + "suptitle = r'$d_{FSC}$ (no rank)'\n", "plot_EMD_vs_EMDopt(df, suptitle=None)" ] } diff --git a/tutorials/README.md b/tutorials/README.md index 842bdf7..d427ba5 100644 --- a/tutorials/README.md +++ b/tutorials/README.md @@ -25,4 +25,4 @@ This notebook walks through generating and analyzing (plots) the map to map dist - output: a `.pkl` file ### `5_tutorial_plotting.ipynb` -This notebook walks through parsing and analyzing (plots) the map to map and distribution to distribution results. +This notebook walks through parsing and analyzing (plots) the map to map and distribution to distribution results. \ No newline at end of file
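Note on the test-data flow introduced by this diff: Git LFS is dropped entirely (`.gitattributes`, the LFS pointer files under `tests/data/`, and the LFS install/pull steps in `testing.yml`), and test fixtures are instead downloaded from OSF by `tests/scripts/fetch_test_data.sh`. In CI, `actions/cache@v3` keys the cached `tests/data` and `data` directories on `hashFiles('**/tests/scripts/fetch_test_data.sh')`, so the "Get test data from OSF" step runs only on a cache miss, and any edit to the fetch script invalidates the cache and triggers a fresh download.

For repeated local runs, a minimal idempotent variant of the fetch script is sketched below. The OSF resource ID and file paths are copied from the diff above; the `fetch` helper and its skip-if-present guard are illustrative assumptions, not part of the committed script:

```bash
#!/bin/sh
# Base URL copied from tests/scripts/fetch_test_data.sh in the diff above.
OSF_BASE="https://files.osf.io/v1/resources/8h6fz/providers/dropbox"

# Hypothetical helper: download $1 (path under the OSF provider root) to $2,
# skipping the download when the destination file already exists.
fetch() {
    if [ ! -f "$2" ]; then
        mkdir -p "$(dirname "$2")"
        # Quote the URL so the shell cannot glob-expand the '?' in '?download=true'.
        wget "${OSF_BASE}/${1}?download=true" -O "$2"
    fi
}

fetch tests/dataset_2_submissions/test_submission_0_n8.pt tests/data/dataset_2_submissions/test_submission_0_n8.pt
fetch tests/Ground_truth/test_maps_gt_flat_10.pt tests/data/Ground_truth/test_maps_gt_flat_10.pt
```

Skipping files that are already present keeps repeated local test runs cheap and mirrors the cache-hit behavior of the CI workflow.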