diff --git a/.github/workflows/main_merge_check.yml b/.github/workflows/main_merge_check.yml
index 8e6f5c5..b4aa6e1 100644
--- a/.github/workflows/main_merge_check.yml
+++ b/.github/workflows/main_merge_check.yml
@@ -11,4 +11,4 @@ jobs:
       if: github.base_ref == 'main' && github.head_ref != 'dev'
       run: |
         echo "ERROR: You can only merge to main from dev."
-        exit 1
\ No newline at end of file
+        exit 1
diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
new file mode 100644
index 0000000..1b4ed47
--- /dev/null
+++ b/.github/workflows/ruff.yml
@@ -0,0 +1,12 @@
+# Runs the Ruff linter and formatter.
+
+name: Lint
+
+on: [push]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: chartboost/ruff-action@v1
diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml
index 6c665b3..ffc74b0 100644
--- a/.github/workflows/testing.yml
+++ b/.github/workflows/testing.yml
@@ -32,13 +32,6 @@ jobs:
           pip install .
           pip install pytest omegaconf
 
-      - name: Get test data from OSF
-        run: |
-          sh tests/scripts/fetch_test_data.sh
-
       - name: Test with pytest
         run: |
-          pytest tests/test_preprocessing.py
-          pytest tests/test_svd.py
-          pytest tests/test_map_to_map.py
-          pytest tests/test_distribution_to_distribution.py
+          pytest tests
diff --git a/.gitignore b/.gitignore
index 5ede44a..8ddf6bb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,10 +3,7 @@
 data/dataset_2_submissions
 data/dataset_1_submissions
 data/dataset_2_ground_truth
 
-# data for testing and resulting outputs
-tests/data/Ground_truth
-tests/data/dataset_2_submissions/
-tests/data/unprocessed_dataset_2_submissions/submission_x/
+# testing results
 tests/results/
 
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index adc0ebb..4100565 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -12,8 +12,28 @@ The "-e" flag will install the package in editable mode, which means you can edi
 
 ## Things to do before pushing to GitHub
 
-In this project we use Ruff for linting, and pre-commit to make sure that the code being pushed is not broken or goes against PEP8 guidelines. When you run `git commit` the pre-commit pipeline should rune automatically. In the near future we will start using pytest and mypy to perform more checks.
+### Using pre-commit hooks for code formatting and linting
+When you install in developer mode with `".[dev]"` you will install the [pre-commit](https://pre-commit.com/) package. To set up this package, simply run
+
+```bash
+pre-commit install
+```
+
+Then, every time before doing a commit (that is, before `git add` and `git commit`), run the following command:
+
+```bash
+pre-commit run --all-files
+```
+
+This will run `ruff` linting and formatting. If there is anything that cannot be fixed automatically, the command will report the file and line that needs to be fixed before you can commit. Once you have fixed everything, you will be able to run `git add` and `git commit` without issue.
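+
+For a scripted version of this check (for example in a local git hook), the same command can be run from Python. The snippet below is a minimal sketch, not part of the package, and assumes `pre-commit` is installed in the active environment:
+
+```python
+import subprocess
+import sys
+
+
+def lint_all_files() -> int:
+    """Run every configured pre-commit hook on the whole tree (illustrative sketch only)."""
+    # pre-commit returns a nonzero exit code when any hook fails or rewrites a file.
+    result = subprocess.run(["pre-commit", "run", "--all-files"])
+    return result.returncode
+
+
+if __name__ == "__main__":
+    sys.exit(lint_all_files())
+```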
+ + +### Make sure tests run + +```bash +python -m pytest tests/ +``` ## Best practices for contributing diff --git a/config_files/config_distribution_to_distribution.yaml b/config_files/config_distribution_to_distribution.yaml index da8b3d9..eaa94ac 100644 --- a/config_files/config_distribution_to_distribution.yaml +++ b/config_files/config_distribution_to_distribution.yaml @@ -12,4 +12,4 @@ cvxpy_solver: ECOS optimal_q_kl: n_iter: 100000 break_atol: 0.0001 -output_fname: results/distribution_to_distribution_submission_0.pkl \ No newline at end of file +output_fname: results/distribution_to_distribution_submission_0.pkl diff --git a/config_files/config_map_to_map_distance_matrix.yaml b/config_files/config_map_to_map.yaml similarity index 61% rename from config_files/config_map_to_map_distance_matrix.yaml rename to config_files/config_map_to_map.yaml index 3c0994c..bb66486 100644 --- a/config_files/config_map_to_map_distance_matrix.yaml +++ b/config_files/config_map_to_map.yaml @@ -1,15 +1,15 @@ data: n_pix: 224 - psize: 2.146 + psize: 2.146 submission: fname: data/dataset_2_ground_truth/submission_0.pt volume_key: volumes metadata_key: populations label_key: id ground_truth: - volumes: data/dataset_2_ground_truth/maps_gt_flat.pt - metadata: data/dataset_2_ground_truth/metadata.csv - mask: + volumes: data/dataset_2_ground_truth/maps_gt_flat.pt + metadata: data/dataset_2_ground_truth/metadata.csv + mask: do: true volume: data/dataset_2_ground_truth/mask_dilated_wide_224x224.mrc analysis: @@ -24,6 +24,4 @@ analysis: normalize: do: true method: median_zscore -output: results/map_to_map_distance_matrix_submission_0.pkl -# external: -# res: results/map_to_map_distance_matrix_submission_0.pkl \ No newline at end of file +output: results/map_to_map_distance_matrix_submission_0.pkl \ No newline at end of file diff --git a/src/cryo_challenge/__init__.py b/src/cryo_challenge/__init__.py index cafea4e..934c6a8 100644 --- a/src/cryo_challenge/__init__.py +++ b/src/cryo_challenge/__init__.py @@ -1 +1,3 @@ -from cryo_challenge.__about__ import __version__ \ No newline at end of file +from cryo_challenge.__about__ import __version__ + +__all__ = ["__version__"] diff --git a/src/cryo_challenge/_commands/run_map2map_pipeline.py b/src/cryo_challenge/_commands/run_map2map_pipeline.py index c985f6a..ab36f7a 100644 --- a/src/cryo_challenge/_commands/run_map2map_pipeline.py +++ b/src/cryo_challenge/_commands/run_map2map_pipeline.py @@ -6,8 +6,8 @@ import os import yaml -from cryo_challenge._map_to_map.map_to_map_distance_matrix import run -from cryo_challenge.data._validation.config_validators import validate_input_config_mtm +from .._map_to_map.map_to_map_pipeline import run +from ..data._validation.config_validators import validate_input_config_mtm def add_args(parser): diff --git a/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py b/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py index 70961d8..18c57bb 100644 --- a/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py +++ b/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py @@ -2,8 +2,6 @@ import numpy as np import pickle from scipy.stats import rankdata -import yaml -import argparse import torch import ot @@ -14,10 +12,12 @@ def sort_by_transport(cost): - m,n = cost.shape - _, transport = compute_wasserstein_between_distributions_from_weights_and_cost(np.ones(m) / m, np.ones(n)/n, cost) - indices = np.argsort((transport * 
np.arange(m)[...,None]).sum(0)) - return cost[:,indices], indices, transport + m, n = cost.shape + _, transport = compute_wasserstein_between_distributions_from_weights_and_cost( + np.ones(m) / m, np.ones(n) / n, cost + ) + indices = np.argsort((transport * np.arange(m)[..., None]).sum(0)) + return cost[:, indices], indices, transport def compute_wasserstein_between_distributions_from_weights_and_cost( @@ -65,7 +65,6 @@ def make_assignment_matrix(cost_matrix): def run(config): - metadata_df = pd.read_csv(config["gt_metadata_fname"]) metadata_df.sort_values("pc1", inplace=True) @@ -73,7 +72,7 @@ def run(config): data = pickle.load(f) # user_submitted_populations = np.ones(80)/80 - user_submitted_populations = data["user_submitted_populations"]#.numpy() + user_submitted_populations = data["user_submitted_populations"] # .numpy() id = torch.load(data["config"]["data"]["submission"]["fname"])["id"] results_dict = {} @@ -213,5 +212,5 @@ def optimal_q_kl(n_iter, x_start, A, Window, prob_gt, break_atol): DistributionToDistributionResultsValidator.from_dict(results_dict) with open(config["output_fname"], "wb") as f: pickle.dump(results_dict, f) - + return results_dict diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 34e126b..b9f12f5 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -1,143 +1,260 @@ import math import torch from typing import Optional, Sequence +from typing_extensions import override +import mrcfile +import numpy as np -def fourier_shell_correlation( - x: torch.Tensor, - y: torch.Tensor, - dim: Sequence[int] = (-3, -2, -1), - normalize: bool = True, - max_k: Optional[int] = None, -): - """Computes Fourier Shell / Ring Correlation (FSC) between x and y. - - Parameters - ---------- - x : torch.Tensor - First input tensor. - y : torch.Tensor - Second input tensor. - dim : Tuple[int, ...] - Dimensions over which to take the Fourier transform. - normalize : bool - Whether to normalize (i.e. compute correlation) or not (i.e. compute covariance). - Note that when `normalize=False`, we still divide by the number of elements in each shell. - max_k : int - The maximum shell to compute the correlation for. - - Returns - ------- - torch.Tensor - The correlation between x and y for each Fourier shell. - """ # noqa: E501 - batch_shape = x.shape[: -len(dim)] - - freqs = [torch.fft.fftfreq(x.shape[d], d=1 / x.shape[d]).to(x) for d in dim] - freq_total = ( - torch.cartesian_prod(*freqs).view(*[len(f) for f in freqs], -1).norm(dim=-1) - ) - - x_f = torch.fft.fftn(x, dim=dim) - y_f = torch.fft.fftn(y, dim=dim) - - n = min(x.shape[d] for d in dim) - - if max_k is None: - max_k = n // 2 - - result = x.new_zeros(batch_shape + (max_k,)) - - for i in range(1, max_k + 1): - mask = (freq_total >= i - 0.5) & (freq_total < i + 0.5) - x_ri = x_f[..., mask] - y_fi = y_f[..., mask] - - if x.is_cuda: - c_i = torch.linalg.vecdot(x_ri, y_fi).real - else: - # vecdot currently bugged on CPU for torch 2.0 in some configurations - c_i = torch.sum(x_ri * y_fi.conj(), dim=-1).real - - if normalize: - c_i /= torch.linalg.norm(x_ri, dim=-1) * torch.linalg.norm(y_fi, dim=-1) - else: - c_i /= x_ri.shape[-1] - - result[..., i - 1] = c_i - - return result - - -def compute_bioem3d_cost(map1, map2): - """ - Compute the cost between two maps using the BioEM cost function in 3D. - - Notes - ----- - See Eq. 
10 in 10.1016/j.jsb.2013.10.006 - - Parameters - ---------- - map1 : torch.Tensor - shape (n_pix,n_pix,n_pix) - map2 : torch.Tensor - shape (n_pix,n_pix,n_pix) - - Returns - ------- - cost : torch.Tensor - shape (1,) - """ - m1, m2 = map1.reshape(-1), map2.reshape(-1) - co = m1.sum() - cc = m2.sum() - coo = m1.pow(2).sum() - ccc = m2.pow(2).sum() - coc = (m1 * m2).sum() - - N = len(m1) - - t1 = 2 * torch.pi * math.exp(1) - t2 = N * (ccc * coo - coc * coc) + 2 * co * coc * cc - ccc * co * co - coo * cc * cc - t3 = (N - 2) * (N * ccc - cc * cc) - - smallest_float = torch.finfo(m1.dtype).tiny - log_prob = ( - 0.5 * torch.pi - + math.log(t1) * (1 - N / 2) - + t2.clamp(smallest_float).log() * (3 / 2 - N / 2) - + t3.clamp(smallest_float).log() * (N / 2 - 2) - ) - cost = -log_prob - return cost - - -def compute_cost_l2(map_1, map_2): - return ((map_1 - map_2) ** 2).sum() - - -def compute_cost_corr(map_1, map_2): - return (map_1 * map_2).sum() - - -def compute_cost_fsc_chunk(maps_gt_flat, maps_user_flat, n_pix): - """ - Compute the cost between two maps using the Fourier Shell Correlation in 3D. - - Notes - ----- - fourier_shell_correlation can only batch on first input set of maps, - so we compute the cost one row (gt map idx) at a time - """ - cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) - fsc_matrix = torch.zeros(len(maps_gt_flat), len(maps_user_flat), n_pix // 2) - for idx in range(len(maps_gt_flat)): - corr_vector = fourier_shell_correlation( - maps_user_flat.reshape(-1, n_pix, n_pix, n_pix), - maps_gt_flat[idx].reshape(n_pix, n_pix, n_pix), +class MapToMapDistance: + def __init__(self, config): + self.config = config + + def get_distance(self, map1, map2): + """Compute the distance between two maps.""" + raise NotImplementedError() + + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): + """Compute the distance matrix between two sets of maps.""" + chunk_size_submission = self.config["analysis"]["chunk_size_submission"] + chunk_size_gt = self.config["analysis"]["chunk_size_gt"] + distance_matrix = torch.vmap( + lambda maps1: torch.vmap( + lambda maps2: self.get_distance(maps1, maps2), + chunk_size=chunk_size_submission, + )(maps2), + chunk_size=chunk_size_gt, + )(maps1) + + return distance_matrix + + def get_computed_assets(self, maps1, maps2, global_store_of_running_results): + """Return any computed assets that are needed for (downstream) analysis.""" + return {} + + +class L2DistanceNorm(MapToMapDistance): + def __init__(self, config): + super().__init__(config) + + @override + def get_distance(self, map1, map2): + return torch.norm(map1 - map2) ** 2 + + +class L2DistanceSum(MapToMapDistance): + def __init__(self, config): + super().__init__(config) + + def compute_cost_l2(self, map_1, map_2): + return ((map_1 - map_2) ** 2).sum() + + @override + def get_distance(self, map1, map2): + return self.compute_cost_l2(map1, map2) + + +class Correlation(MapToMapDistance): + def __init__(self, config): + super().__init__(config) + + def compute_cost_corr(self, map_1, map_2): + return (map_1 * map_2).sum() + + @override + def get_distance(self, map1, map2): + return self.compute_cost_corr(map1, map2) + + +class BioEM3dDistance(MapToMapDistance): + def __init__(self, config): + super().__init__(config) + + def compute_bioem3d_cost(self, map1, map2): + """ + Compute the cost between two maps using the BioEM cost function in 3D. + + Notes + ----- + See Eq. 
10 in 10.1016/j.jsb.2013.10.006 + + Parameters + ---------- + map1 : torch.Tensor + shape (n_pix,n_pix,n_pix) + map2 : torch.Tensor + shape (n_pix,n_pix,n_pix) + + Returns + ------- + cost : torch.Tensor + shape (1,) + """ + m1, m2 = map1.reshape(-1), map2.reshape(-1) + co = m1.sum() + cc = m2.sum() + coo = m1.pow(2).sum() + ccc = m2.pow(2).sum() + coc = (m1 * m2).sum() + + N = len(m1) + + t1 = 2 * torch.pi * math.exp(1) + t2 = ( + N * (ccc * coo - coc * coc) + + 2 * co * coc * cc + - ccc * co * co + - coo * cc * cc + ) + t3 = (N - 2) * (N * ccc - cc * cc) + + smallest_float = torch.finfo(m1.dtype).tiny + log_prob = ( + 0.5 * torch.pi + + math.log(t1) * (1 - N / 2) + + t2.clamp(smallest_float).log() * (3 / 2 - N / 2) + + t3.clamp(smallest_float).log() * (N / 2 - 2) + ) + cost = -log_prob + return cost + + @override + def get_distance(self, map1, map2): + return self.compute_bioem3d_cost(map1, map2) + + +class FSCDistance(MapToMapDistance): + def __init__(self, config): + super().__init__(config) + + def fourier_shell_correlation( + self, + x: torch.Tensor, + y: torch.Tensor, + dim: Sequence[int] = (-3, -2, -1), + normalize: bool = True, + max_k: Optional[int] = None, + ): + """Computes Fourier Shell / Ring Correlation (FSC) between x and y. + + Parameters + ---------- + x : torch.Tensor + First input tensor. + y : torch.Tensor + Second input tensor. + dim : Tuple[int, ...] + Dimensions over which to take the Fourier transform. + normalize : bool + Whether to normalize (i.e. compute correlation) or not (i.e. compute covariance). + Note that when `normalize=False`, we still divide by the number of elements in each shell. + max_k : int + The maximum shell to compute the correlation for. + + Returns + ------- + torch.Tensor + The correlation between x and y for each Fourier shell. + """ # noqa: E501 + batch_shape = x.shape[: -len(dim)] + + freqs = [torch.fft.fftfreq(x.shape[d], d=1 / x.shape[d]).to(x) for d in dim] + freq_total = ( + torch.cartesian_prod(*freqs).view(*[len(f) for f in freqs], -1).norm(dim=-1) + ) + + x_f = torch.fft.fftn(x, dim=dim) + y_f = torch.fft.fftn(y, dim=dim) + + n = min(x.shape[d] for d in dim) + + if max_k is None: + max_k = n // 2 + + result = x.new_zeros(batch_shape + (max_k,)) + + for i in range(1, max_k + 1): + mask = (freq_total >= i - 0.5) & (freq_total < i + 0.5) + x_ri = x_f[..., mask] + y_fi = y_f[..., mask] + + if x.is_cuda: + c_i = torch.linalg.vecdot(x_ri, y_fi).real + else: + # vecdot currently bugged on CPU for torch 2.0 in some configurations + c_i = torch.sum(x_ri * y_fi.conj(), dim=-1).real + + if normalize: + c_i /= torch.linalg.norm(x_ri, dim=-1) * torch.linalg.norm(y_fi, dim=-1) + else: + c_i /= x_ri.shape[-1] + + result[..., i - 1] = c_i + + return result + + def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): + """ + Compute the cost between two maps using the Fourier Shell Correlation in 3D. 
+ + Notes + ----- + fourier_shell_correlation can only batch on first input set of maps, + so we compute the cost one row (gt map idx) at a time + """ + cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) + fsc_matrix = torch.zeros(len(maps_gt_flat), len(maps_user_flat), n_pix // 2) + for idx in range(len(maps_gt_flat)): + corr_vector = self.fourier_shell_correlation( + maps_user_flat.reshape(-1, n_pix, n_pix, n_pix), + maps_gt_flat[idx].reshape(n_pix, n_pix, n_pix), + ) + dist = 1 - corr_vector.mean(dim=1) # TODO: spectral cutoff + fsc_matrix[idx] = corr_vector + cost_matrix[idx] = dist + return cost_matrix, fsc_matrix + + @override + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): # custom method + maps_gt_flat = maps1 + maps_user_flat = maps2 + n_pix = self.config["data"]["n_pix"] + maps_gt_flat_cube = torch.zeros(len(maps_gt_flat), n_pix**3) + mask = ( + mrcfile.open(self.config["data"]["mask"]["volume"]) + .data.astype(bool) + .flatten() ) - dist = 1 - corr_vector.mean(dim=1) # TODO: spectral cutoff - fsc_matrix[idx] = corr_vector - cost_matrix[idx] = dist - return cost_matrix, fsc_matrix + maps_gt_flat_cube[:, mask] = maps_gt_flat + maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) + maps_user_flat_cube[:, mask] = maps_user_flat + + cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk( + maps_gt_flat_cube, maps_user_flat_cube, n_pix + ) + self.stored_computed_assets = {"fsc_matrix": fsc_matrix} + return cost_matrix + + @override + def get_computed_assets(self, maps1, maps2, global_store_of_running_results): + return self.stored_computed_assets # must run get_distance_matrix first + + +class ResDistance(MapToMapDistance): + def __init__(self, config): + super().__init__(config) + + @override + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): # custom method + # get fsc matrix + fourier_pixel_max = self.config['data']['n_pix'] // 2 # TODO: check for odd psizes if this should be +1 + psize = self.config['data']['psize'] + fsc_matrix = global_store_of_running_results['fsc']['computed_assets']['fsc_matrix'] + units_Angstroms = 2 * psize / (np.arange(1,fourier_pixel_max+1) / fourier_pixel_max) + def res_at_fsc_threshold(fscs, threshold=0.5): + res_fsc_half = np.argmin(fscs > threshold, axis=-1) + fraction_nyquist = 0.5*res_fsc_half / fscs.shape[-1] + return res_fsc_half, fraction_nyquist + res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fsc_matrix) + self.stored_computed_assets = {'fraction_nyquist': fraction_nyquist} + return units_Angstroms[res_fsc_half] \ No newline at end of file diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py deleted file mode 100644 index 1f7d2a4..0000000 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py +++ /dev/null @@ -1,195 +0,0 @@ -import mrcfile -import pandas as pd -import pickle -import torch -import numpy as np - -from cryo_challenge._map_to_map.map_to_map_distance import ( - compute_bioem3d_cost, - compute_cost_corr, - compute_cost_fsc_chunk, - compute_cost_l2, -) -from cryo_challenge.data._validation.output_validators import MapToMapResultsValidator - - -class MapToMapDistance: - def __init__(self, config): - self.config = config - - def get_distance(self, map1, map2): - raise NotImplementedError() - - def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): - chunk_size_submission = self.config["analysis"]["chunk_size_submission"] - 
chunk_size_gt = self.config["analysis"]["chunk_size_gt"] - distance_matrix = torch.vmap( - lambda maps1: torch.vmap( - lambda maps2: self.get_distance(maps1, maps2), - chunk_size=chunk_size_submission, - )(maps2), - chunk_size=chunk_size_gt, - )(maps1) - - return distance_matrix.numpy() - - def get_computed_assets(self, maps1, maps2, global_store_of_running_results): - return {} - -class L2DistanceNorm(MapToMapDistance): - def __init__(self, config): - super().__init__(config) - - def get_distance(self, map1, map2): - return torch.norm(map1 - map2)**2 - -class L2DistanceSum(MapToMapDistance): - def __init__(self, config): - super().__init__(config) - - def get_distance(self, map1, map2): - return compute_cost_l2(map1, map2) - -class Correlation(MapToMapDistance): - def __init__(self, config): - super().__init__(config) - - def get_distance(self, map1, map2): - return compute_cost_corr(map1, map2) - -class BioEM3dDistance(MapToMapDistance): - def __init__(self, config): - super().__init__(config) - - def get_distance(self, map1, map2): - return compute_bioem3d_cost(map1, map2) - -class FSCDistance(MapToMapDistance): - def __init__(self, config): - super().__init__(config) - - def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): # custom method - maps_gt_flat = maps1 - maps_user_flat = maps2 - n_pix = self.config["data"]["n_pix"] - maps_gt_flat_cube = torch.zeros(len(maps_gt_flat), n_pix**3) - mask = ( - mrcfile.open(self.config["data"]["mask"]["volume"]).data.astype(bool).flatten() - ) - maps_gt_flat_cube[:, mask] = maps_gt_flat - maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) - maps_user_flat_cube[:, mask] = maps_user_flat - - cost_matrix, fsc_matrix = compute_cost_fsc_chunk(maps_gt_flat_cube, maps_user_flat_cube, n_pix) - self.stored_computed_assets = {'fsc_matrix': fsc_matrix} - return cost_matrix.numpy() - - def get_computed_assets(self, maps1, maps2, global_store_of_running_results): - return self.stored_computed_assets # must run get_distance_matrix first - -class ResDistance(MapToMapDistance): - def __init__(self, config): - super().__init__(config) - - def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): # custom method - # get fsc matrix - fourier_pixel_max = self.config['data']['n_pix'] // 2 # TODO: check for odd psizes if this should be +1 - psize = self.config['data']['psize'] - fsc_matrix = global_store_of_running_results['fsc']['computed_assets']['fsc_matrix'] - units_Angstroms = 2 * psize / (np.arange(1,fourier_pixel_max+1) / fourier_pixel_max) - def res_at_fsc_threshold(fscs, threshold=0.5): - res_fsc_half = np.argmin(fscs > threshold, axis=-1) - fraction_nyquist = 0.5*res_fsc_half / fscs.shape[-1] - return res_fsc_half, fraction_nyquist - res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fsc_matrix) - self.stored_computed_assets = {'fraction_nyquist': fraction_nyquist} - return units_Angstroms[res_fsc_half] - - - -def run(config): - """ - Compare a submission to ground truth. 
- """ - - n_pix = config["data"]["n_pix"] - - submission = torch.load(config["data"]["submission"]["fname"]) - submission_volume_key = config["data"]["submission"]["volume_key"] - submission_metadata_key = config["data"]["submission"]["metadata_key"] - label_key = config["data"]["submission"]["label_key"] - user_submission_label = submission[label_key] - - metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"]) - - results_dict = {} - results_dict["config"] = config - results_dict["user_submitted_populations"] = ( - submission[submission_metadata_key] / submission[submission_metadata_key].sum() - ) - - map_to_map_distances = { - "fsc": FSCDistance(config), - "corr": Correlation(config), - "l2": L2DistanceSum(config), - "bioem": BioEM3dDistance(config), - "res": ResDistance(config), - } - - maps_user_flat = submission[submission_volume_key].reshape( - len(submission["volumes"]), -1 - ) - maps_gt_flat = torch.load(config["data"]["ground_truth"]["volumes"]).reshape( - -1, n_pix**3 - ) - - if config["data"]["mask"]["do"]: - mask = ( - mrcfile.open(config["data"]["mask"]["volume"]).data.astype(bool).flatten() - ) - maps_gt_flat = maps_gt_flat[:, mask] - maps_user_flat = maps_user_flat[:, mask] - else: - maps_gt_flat.reshape(len(maps_gt_flat), -1, inplace=True) - maps_user_flat.reshape(len(maps_gt_flat), -1, inplace=True) - - if config["analysis"]["normalize"]["do"]: - if config["analysis"]["normalize"]["method"] == "median_zscore": - maps_gt_flat -= maps_gt_flat.median(dim=1, keepdim=True).values - maps_gt_flat /= maps_gt_flat.std(dim=1, keepdim=True) - maps_user_flat -= maps_user_flat.median(dim=1, keepdim=True).values - maps_user_flat /= maps_user_flat.std(dim=1, keepdim=True) - - computed_assets = {} - for distance_label, map_to_map_distance in map_to_map_distances.items(): - if distance_label in config["analysis"]["metrics"]: # TODO: can remove - print("cost matrix", distance_label) - - cost_matrix = map_to_map_distance.get_distance_matrix( - maps_gt_flat, maps_user_flat, global_store_of_running_results=results_dict, - ) - computed_assets = map_to_map_distance.get_computed_assets( - maps_gt_flat, maps_user_flat, global_store_of_running_results=results_dict, - ) - computed_assets.update(computed_assets) - - cost_matrix_df = pd.DataFrame( - cost_matrix, columns=None, index=metadata_gt.populations.tolist() - ) - - # output results - single_distance_results_dict = { - "cost_matrix": cost_matrix_df, - "user_submission_label": user_submission_label, - "computed_assets": computed_assets, - } - - results_dict[distance_label] = single_distance_results_dict - - # Validate before saving - _ = MapToMapResultsValidator.from_dict(results_dict) - - with open(config["output"], "wb") as f: - pickle.dump(results_dict, f) - - return results_dict diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py new file mode 100644 index 0000000..923e42b --- /dev/null +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -0,0 +1,108 @@ +import mrcfile +import pandas as pd +import pickle +import torch +import numpy as np + +from ..data._validation.output_validators import MapToMapResultsValidator +from .._map_to_map.map_to_map_distance import ( + FSCDistance, + Correlation, + L2DistanceSum, + BioEM3dDistance, + ResDistance, +) + + +AVAILABLE_MAP2MAP_DISTANCES = { + "fsc": FSCDistance, + "corr": Correlation, + "l2": L2DistanceSum, + "bioem": BioEM3dDistance, + "res": ResDistance, +} + + +def run(config): + """ + Compare a submission 
to ground truth.
+    """
+
+    map_to_map_distances = {
+        distance_label: distance_class(config)
+        for distance_label, distance_class in AVAILABLE_MAP2MAP_DISTANCES.items()
+    }
+
+    n_pix = config["data"]["n_pix"]
+
+    submission = torch.load(config["data"]["submission"]["fname"])
+    submission_volume_key = config["data"]["submission"]["volume_key"]
+    submission_metadata_key = config["data"]["submission"]["metadata_key"]
+    label_key = config["data"]["submission"]["label_key"]
+    user_submission_label = submission[label_key]
+
+    metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"])
+
+    results_dict = {}
+    results_dict["config"] = config
+    results_dict["user_submitted_populations"] = (
+        submission[submission_metadata_key] / submission[submission_metadata_key].sum()
+    )
+
+    maps_user_flat = submission[submission_volume_key].reshape(
+        len(submission["volumes"]), -1
+    )
+    maps_gt_flat = torch.load(config["data"]["ground_truth"]["volumes"]).reshape(
+        -1, n_pix**3
+    )
+
+    if config["data"]["mask"]["do"]:
+        mask = (
+            mrcfile.open(config["data"]["mask"]["volume"]).data.astype(bool).flatten()
+        )
+        maps_gt_flat = maps_gt_flat[:, mask]
+        maps_user_flat = maps_user_flat[:, mask]
+    else:
+        maps_gt_flat = maps_gt_flat.reshape(len(maps_gt_flat), -1)
+        maps_user_flat = maps_user_flat.reshape(len(maps_user_flat), -1)
+
+    if config["analysis"]["normalize"]["do"]:
+        if config["analysis"]["normalize"]["method"] == "median_zscore":
+            maps_gt_flat -= maps_gt_flat.median(dim=1, keepdim=True).values
+            maps_gt_flat /= maps_gt_flat.std(dim=1, keepdim=True)
+            maps_user_flat -= maps_user_flat.median(dim=1, keepdim=True).values
+            maps_user_flat /= maps_user_flat.std(dim=1, keepdim=True)
+
+    computed_assets = {}
+    for distance_label, map_to_map_distance in map_to_map_distances.items():
+        if distance_label in config["analysis"]["metrics"]:  # TODO: can remove
+            print("cost matrix", distance_label)
+
+            cost_matrix = map_to_map_distance.get_distance_matrix(
+                maps_gt_flat, maps_user_flat, global_store_of_running_results=results_dict,
+            )
+            computed_assets = map_to_map_distance.get_computed_assets(
+                maps_gt_flat, maps_user_flat, global_store_of_running_results=results_dict,
+            )
+            # each distance's computed assets are stored per label in results_dict below
+
+            cost_matrix_df = pd.DataFrame(
+                cost_matrix, columns=None, index=metadata_gt.populations.tolist()
+            )
+
+            # output results
+            single_distance_results_dict = {
+                "cost_matrix": cost_matrix_df,
+                "user_submission_label": user_submission_label,
+                "computed_assets": computed_assets,
+            }
+
+            results_dict[distance_label] = single_distance_results_dict
+
+    # Validate before saving
+    _ = MapToMapResultsValidator.from_dict(results_dict)
+
+    with open(config["output"], "wb") as f:
+        pickle.dump(results_dict, f)
+
+    return results_dict
diff --git a/src/cryo_challenge/_ploting/plotting_utils.py b/src/cryo_challenge/_ploting/plotting_utils.py
index 04d5ca9..681ab01 100644
--- a/src/cryo_challenge/_ploting/plotting_utils.py
+++ b/src/cryo_challenge/_ploting/plotting_utils.py
@@ -1,6 +1,7 @@
 import numpy as np
 
+
 def res_at_fsc_threshold(fscs, threshold=0.5):
     res_fsc_half = np.argmin(fscs > threshold, axis=-1)
-    fraction_nyquist = 0.5*res_fsc_half / fscs.shape[-1]
-    return res_fsc_half, fraction_nyquist
\ No newline at end of file
+    fraction_nyquist = 0.5 * res_fsc_half / fscs.shape[-1]
+    return res_fsc_half, fraction_nyquist
diff --git a/src/cryo_challenge/data/__init__.py b/src/cryo_challenge/data/__init__.py
index 8b4655c..fb27bbd 100644
--- a/src/cryo_challenge/data/__init__.py
+++ b/src/cryo_challenge/data/__init__.py
@@ -1,6 +1,18 @@
-from ._validation.config_validators import validate_input_config_disttodist as validate_input_config_disttodist
-from ._validation.config_validators import validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl
-from cryo_challenge.data._validation.output_validators import DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator
-from cryo_challenge.data._validation.output_validators import MetricDistToDistValidator as MetricDistToDistValidator
-from cryo_challenge.data._validation.output_validators import ReplicateValidatorEMD as ReplicateValidatorEMD
-from cryo_challenge.data._validation.output_validators import ReplicateValidatorKL as ReplicateValidatorKL
+from ._validation.config_validators import (
+    validate_input_config_disttodist as validate_input_config_disttodist,
+)
+from ._validation.config_validators import (
+    validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl,
+)
+from cryo_challenge.data._validation.output_validators import (
+    DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator,
+)
+from cryo_challenge.data._validation.output_validators import (
+    MetricDistToDistValidator as MetricDistToDistValidator,
+)
+from cryo_challenge.data._validation.output_validators import (
+    ReplicateValidatorEMD as ReplicateValidatorEMD,
+)
+from cryo_challenge.data._validation.output_validators import (
+    ReplicateValidatorKL as ReplicateValidatorKL,
+)
diff --git a/src/cryo_challenge/data/_io/svd_io_utils.py b/src/cryo_challenge/data/_io/svd_io_utils.py
index f194c14..2a4d954 100644
--- a/src/cryo_challenge/data/_io/svd_io_utils.py
+++ b/src/cryo_challenge/data/_io/svd_io_utils.py
@@ -106,14 +106,16 @@ def load_ref_vols(box_size_ds: int, path_to_volumes: str, dtype=torch.float32):
 
     # Reshape volumes to correct size
     if volumes.dim() == 2:
-        box_size = int(round((float(volumes.shape[-1]) ** (1. / 3.))))
+        box_size = int(round((float(volumes.shape[-1]) ** (1.0 / 3.0))))
         volumes = torch.reshape(volumes, (-1, box_size, box_size, box_size))
     elif volumes.dim() == 4:
         pass
     else:
-        raise ValueError(f"The shape of the volumes stored in {path_to_volumes} have the unexpected shape "
-                         f"{torch.shape}. Please, review the file and regenerate it so that volumes stored hasve the "
-                         f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size).")
+        raise ValueError(
+            f"The volumes stored in {path_to_volumes} have the unexpected shape "
+            f"{volumes.shape}. Please review the file and regenerate it so that the stored volumes have the "
+            f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size)."
+        )
 
     volumes_ds = torch.empty(
         (volumes.shape[0], box_size_ds, box_size_ds, box_size_ds), dtype=dtype
diff --git a/src/cryo_challenge/data/_validation/output_validators.py b/src/cryo_challenge/data/_validation/output_validators.py
index 5f4a14f..9f76a6d 100644
--- a/src/cryo_challenge/data/_validation/output_validators.py
+++ b/src/cryo_challenge/data/_validation/output_validators.py
@@ -13,7 +13,7 @@
 @dataclass_json
 @dataclass
 class MapToMapResultsValidator:
-    '''
+    """
     Validate the output dictionary of the map-to-map distance matrix computation.
 
     config: dict, input config dictionary.
@@ -22,7 +22,8 @@ class MapToMapResultsValidator:
     l2: dict, L2 results.
     bioem: dict, BioEM results.
     fsc: dict, FSC results.
- ''' + """ + config: dict user_submitted_populations: torch.Tensor corr: Optional[dict] = None @@ -50,7 +51,7 @@ class ReplicateValidatorEMD: Validate the output dictionary of one EMD in the the distribution-to-distribution pipeline. q_opt: List[float], optimal user submitted distribution, which sums to 1. - EMD_opt: float, EMD between the ground truth distribution (p) and the (optimized) user submitted distribution (q_opt). + EMD_opt: float, EMD between the ground truth distribution (p) and the (optimized) user submitted distribution (q_opt). The transport plan is a joint distribution, such that: summing over the rows gives the (optimized) user submitted distribution, and summing over the columns gives the ground truth distribution. transport_plan_opt: List[List[float]], transport plan between the ground truth distribution (p, rows) and the (optimized) user submitted distribution (q_opt, columns). @@ -62,6 +63,7 @@ class ReplicateValidatorEMD: The transport plan is a joint distribution, such that: summing over the rows gives the user submitted distribution, and summing over the columns gives the ground truth distribution. """ + q_opt: List[float] EMD_opt: float transport_plan_opt: List[List[float]] @@ -88,8 +90,9 @@ class ReplicateValidatorKL: iter_stop: int, number of iterations until convergence. eps_stop: float, stopping criterion. klpq_submitted: float, KL divergence between the ground truth distribution (p) and the user submitted distribution (q). - klqp_submitted: float, KL divergence between the user submitted distribution (q) and the ground truth distribution (p). + klqp_submitted: float, KL divergence between the user submitted distribution (q) and the ground truth distribution (p). """ + q_opt: List[float] klpq_opt: float klqp_opt: float @@ -107,11 +110,12 @@ def __post_init__(self): @dataclass_json @dataclass class MetricDistToDistValidator: - ''' + """ Validate the output dictionary of one map to map metric in the the distribution-to-distribution pipeline. replicates: dict, dictionary of replicates. - ''' + """ + replicates: dict def validate_replicates(self, n_replicates): @@ -127,7 +131,7 @@ def validate_replicates(self, n_replicates): @dataclass_json @dataclass class DistributionToDistributionResultsValidator: - ''' + """ Validate the output dictionary of the distribution-to-distribution pipeline. config: dict, input config dictionary. @@ -137,7 +141,8 @@ class DistributionToDistributionResultsValidator: bioem: dict, BioEM distance results. l2: dict, L2 distance results. corr: dict, correlation distance results. 
- ''' + """ + config: dict user_submitted_populations: torch.Tensor id: str diff --git a/tests/config_files/test_config_distribution_to_distribution.yaml b/tests/config_files/test_config_distribution_to_distribution.yaml index 05b6317..e4f465d 100644 --- a/tests/config_files/test_config_distribution_to_distribution.yaml +++ b/tests/config_files/test_config_distribution_to_distribution.yaml @@ -1,4 +1,4 @@ -input_fname: tests/results/test_map_to_map_distance_matrix_submission_0.pkl +input_fname: tests/data/data_for_dist_to_dist/test_map_to_map_distance_matrix_submission_0.pkl metrics: - l2 gt_metadata_fname: tests/data/Ground_truth/test_metadata_10.csv diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index cbf6d09..7dfa7e9 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -1,8 +1,8 @@ data: - n_pix: 224 - psize: 2.146 + n_pix: 16 + psize: 30.044 submission: - fname: tests/data/dataset_2_submissions/test_submission_0_n8.pt + fname: tests/data/dataset_2_submissions/submission_1000.pt volume_key: volumes metadata_key: populations label_key: id @@ -11,7 +11,7 @@ data: metadata: tests/data/Ground_truth/test_metadata_10.csv mask: do: true - volume: tests/data/Ground_truth/mask_dilated_wide_224x224.mrc + volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc analysis: metrics: - l2 diff --git a/tests/config_files/test_config_svd.yaml b/tests/config_files/test_config_svd.yaml index c392525..acff44b 100644 --- a/tests/config_files/test_config_svd.yaml +++ b/tests/config_files/test_config_svd.yaml @@ -1,6 +1,6 @@ path_to_volumes: tests/data/dataset_2_submissions/ -box_size_ds: 32 -submission_list: [0] +box_size_ds: 16 +submission_list: [1000] experiment_mode: "all_vs_ref" # options are "all_vs_all", "all_vs_ref" # optional unless experiment_mode is "all_vs_ref" path_to_reference: tests/data/Ground_truth/test_maps_gt_flat_10.pt diff --git a/tests/data/Ground_truth/1.mrc b/tests/data/Ground_truth/1.mrc new file mode 100644 index 0000000..9e4c873 Binary files /dev/null and b/tests/data/Ground_truth/1.mrc differ diff --git a/tests/data/Ground_truth/mask_dilated_wide_224x224.mrc b/tests/data/Ground_truth/mask_dilated_wide_224x224.mrc new file mode 100644 index 0000000..e32296e Binary files /dev/null and b/tests/data/Ground_truth/mask_dilated_wide_224x224.mrc differ diff --git a/tests/data/Ground_truth/test_maps_gt_flat_10.npy b/tests/data/Ground_truth/test_maps_gt_flat_10.npy new file mode 100644 index 0000000..874afc8 Binary files /dev/null and b/tests/data/Ground_truth/test_maps_gt_flat_10.npy differ diff --git a/tests/data/Ground_truth/test_maps_gt_flat_10.pt b/tests/data/Ground_truth/test_maps_gt_flat_10.pt new file mode 100644 index 0000000..2bbe009 Binary files /dev/null and b/tests/data/Ground_truth/test_maps_gt_flat_10.pt differ diff --git a/tests/data/Ground_truth/test_mask_dilated_wide.mrc b/tests/data/Ground_truth/test_mask_dilated_wide.mrc new file mode 100644 index 0000000..307e80f Binary files /dev/null and b/tests/data/Ground_truth/test_mask_dilated_wide.mrc differ diff --git a/tests/data/Ground_truth/test_metadata_10.csv b/tests/data/Ground_truth/test_metadata_10.csv new file mode 100644 index 0000000..de041fc --- /dev/null +++ b/tests/data/Ground_truth/test_metadata_10.csv @@ -0,0 +1,11 @@ +index,volumes,populations_count,pc1,populations +3238,13396.mrc,1,-231.62100638454024,2.9636654614427123e-05 +3020,10063.mrc,13,-179.257841640357,0.0003852765099875 
+1113,01421.mrc,12,-141.8237629192062,0.0003556398553731 +3592,21858.mrc,4,-101.62462603216916,0.0001185466184577 +1947,03298.mrc,5,-34.99878410436052,0.0001481832730721 +2097,03764.mrc,6,3.946553364334135,0.0001778199276865 +1574,02336.mrc,5,44.70670231717438,0.0001481832730721 +2813,08011.mrc,8,108.6308222660271,0.0002370932369154 +88,00090.mrc,21,147.70416251702042,0.0006223697469029 +771,00906.mrc,11,186.3446095998357,0.0003260032007586 diff --git a/tests/data/data_for_dist_to_dist/test_map_to_map_distance_matrix_submission_0.pkl b/tests/data/data_for_dist_to_dist/test_map_to_map_distance_matrix_submission_0.pkl new file mode 100644 index 0000000..26e9653 Binary files /dev/null and b/tests/data/data_for_dist_to_dist/test_map_to_map_distance_matrix_submission_0.pkl differ diff --git a/tests/data/dataset_2_submissions/submission_1000.pt b/tests/data/dataset_2_submissions/submission_1000.pt new file mode 100644 index 0000000..c2f6051 Binary files /dev/null and b/tests/data/dataset_2_submissions/submission_1000.pt differ diff --git a/tests/data/test_maps_gt_flat_2.pt b/tests/data/test_maps_gt_flat_2.pt deleted file mode 100644 index dc3293c..0000000 Binary files a/tests/data/test_maps_gt_flat_2.pt and /dev/null differ diff --git a/tests/data/test_metadata_2.csv b/tests/data/test_metadata_2.csv deleted file mode 100644 index b564400..0000000 --- a/tests/data/test_metadata_2.csv +++ /dev/null @@ -1,3 +0,0 @@ -index,volumes,populations_count,pc1,populations -3238,13396.mrc,1,-231.62100638454024,2.9636654614427123e-05 -3789,30099.mrc,2,243.32448171011487,5.927330922885425e-05 diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/1.mrc b/tests/data/unprocessed_dataset_2_submissions/submission_x/1.mrc new file mode 100644 index 0000000..eb89602 Binary files /dev/null and b/tests/data/unprocessed_dataset_2_submissions/submission_x/1.mrc differ diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/2.mrc b/tests/data/unprocessed_dataset_2_submissions/submission_x/2.mrc new file mode 100644 index 0000000..7796db1 Binary files /dev/null and b/tests/data/unprocessed_dataset_2_submissions/submission_x/2.mrc differ diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/3.mrc b/tests/data/unprocessed_dataset_2_submissions/submission_x/3.mrc new file mode 100644 index 0000000..9d050e2 Binary files /dev/null and b/tests/data/unprocessed_dataset_2_submissions/submission_x/3.mrc differ diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/4.mrc b/tests/data/unprocessed_dataset_2_submissions/submission_x/4.mrc new file mode 100644 index 0000000..ec4d181 Binary files /dev/null and b/tests/data/unprocessed_dataset_2_submissions/submission_x/4.mrc differ diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt b/tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt new file mode 100644 index 0000000..53e4b60 --- /dev/null +++ b/tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt @@ -0,0 +1,4 @@ +0.25 +0.25 +0.25 +0.25 diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json b/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json index b8318b9..1fb797d 100644 --- a/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json +++ b/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json @@ -2,16 +2,16 @@ "gt": { "name": "gt", "path": 
"tests/data/unprocessed_dataset_2_submissions/submission_x", - "box_size": 224, - "pixel_size": 2.146, + "box_size": 16, + "pixel_size": 30.044, "ref_align_fname": "1.mrc" }, "0": { "name": "raw_submission_in_testdata", "align": 1, - "flavor_name": "test flavor", - "box_size": 244, - "pixel_size": 2.146, + "flavor_name": "test flavor", + "box_size": 32, + "pixel_size": 15.022, "path": "tests/data/unprocessed_dataset_2_submissions/submission_x", "flip": 1, "populations_file": "tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt", diff --git a/tests/scripts/fetch_test_data.sh b/tests/scripts/fetch_test_data.sh deleted file mode 100644 index c252871..0000000 --- a/tests/scripts/fetch_test_data.sh +++ /dev/null @@ -1,12 +0,0 @@ -mkdir -p tests/data/dataset_2_submissions tests/data/dataset_2_submissions tests/results tests/data/unprocessed_dataset_2_submissions/submission_x tests/data/Ground_truth/ tests/data/Ground_truth -wget https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/dataset_2_submissions/test_submission_0_n8.pt?download=true -O tests/data/dataset_2_submissions/test_submission_0_n8.pt -ADIR=$(pwd) -ln -s $ADIR/tests/data/dataset_2_submissions/test_submission_0_n8.pt $ADIR/tests/data/dataset_2_submissions/submission_0.pt # symlink for svd which needs submission_0.pt for filename -wget https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/Ground_truth/test_maps_gt_flat_10.pt?download=true -O tests/data/Ground_truth/test_maps_gt_flat_10.pt -wget https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/Ground_truth/test_metadata_10.csv?download=true -O tests/data/Ground_truth/test_metadata_10.csv -wget https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/Ground_truth/1.mrc?download=true -O tests/data/Ground_truth/1.mrc -wget https://files.osf.io/v1/resources/8h6fz/providers/dropbox/Ground_truth/mask_dilated_wide_224x224.mrc?download=true -O tests/data/Ground_truth/mask_dilated_wide_224x224.mrc -for FILE in 1.mrc 2.mrc 3.mrc 4.mrc populations.txt -do - wget https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/unprocessed_dataset_2_submissions/submission_x/${FILE}?download=true -O tests/data/unprocessed_dataset_2_submissions/submission_x/${FILE} -done diff --git a/tests/test_distribution_to_distribution.py b/tests/test_distribution_to_distribution.py index d9c340b..a4cfb79 100644 --- a/tests/test_distribution_to_distribution.py +++ b/tests/test_distribution_to_distribution.py @@ -2,6 +2,8 @@ from cryo_challenge._commands import run_distribution2distribution_pipeline -def test_run_distribution2distribution_pipeline(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_distribution_to_distribution.yaml'}) - run_distribution2distribution_pipeline.main(args) \ No newline at end of file +def test_run_distribution2distribution_pipeline(): + args = OmegaConf.create( + {"config": "tests/config_files/test_config_distribution_to_distribution.yaml"} + ) + run_distribution2distribution_pipeline.main(args) diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index c782a8c..e31f29f 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -2,6 +2,8 @@ from cryo_challenge._commands import run_map2map_pipeline -def test_run_map2map_pipeline(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_map_to_map.yaml'}) - run_map2map_pipeline.main(args) \ No newline at end of file +def test_run_map2map_pipeline(): + args = OmegaConf.create( + {"config": 
"tests/config_files/test_config_map_to_map.yaml"} + ) + run_map2map_pipeline.main(args) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index cbf54e4..31db34e 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -2,6 +2,6 @@ from cryo_challenge._commands import run_preprocessing -def test_run_preprocessing(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_preproc.yaml'}) - run_preprocessing.main(args) \ No newline at end of file +def test_run_preprocessing(): + args = OmegaConf.create({"config": "tests/config_files/test_config_preproc.yaml"}) + run_preprocessing.main(args) diff --git a/tests/test_svd.py b/tests/test_svd.py index 568370e..ea166ea 100644 --- a/tests/test_svd.py +++ b/tests/test_svd.py @@ -2,6 +2,6 @@ from cryo_challenge._commands import run_svd -def test_run_preprocessing(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_svd.yaml'}) - run_svd.main(args) \ No newline at end of file +def test_run_preprocessing(): + args = OmegaConf.create({"config": "tests/config_files/test_config_svd.yaml"}) + run_svd.main(args) diff --git a/tutorials/1_tutorial_preprocessing.ipynb b/tutorials/1_tutorial_preprocessing.ipynb index cc6a459..84db8c9 100644 --- a/tutorials/1_tutorial_preprocessing.ipynb +++ b/tutorials/1_tutorial_preprocessing.ipynb @@ -203,7 +203,7 @@ "# Select path to Config file\n", "# An example of this file is available in the path ../config_files/config_preproc.yaml\n", "config_preproc_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_preproc_path.filter_pattern = '*.yaml'\n", + "config_preproc_path.filter_pattern = \"*.yaml\"\n", "display(config_preproc_path)" ] }, @@ -226,7 +226,7 @@ "if os.path.isabs(output_path):\n", " full_output_path = output_path\n", "else:\n", - " full_output_path = os.path.join(os.getcwd(), '..', output_path)" + " full_output_path = os.path.join(os.getcwd(), \"..\", output_path)" ] }, { diff --git a/tutorials/2_tutorial_svd.ipynb b/tutorials/2_tutorial_svd.ipynb index b41bfba..fe8f432 100644 --- a/tutorials/2_tutorial_svd.ipynb +++ b/tutorials/2_tutorial_svd.ipynb @@ -62,7 +62,7 @@ "# Select path to SVD config file\n", "# An example of this file is available in the path ../config_files/config_svd.yaml\n", "config_svd_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_svd_path.filter_pattern = '*.yaml'\n", + "config_svd_path.filter_pattern = \"*.yaml\"\n", "display(config_svd_path)" ] }, @@ -125,7 +125,7 @@ "source": [ "# Select path to SVD results\n", "svd_results_path = FileChooser(os.path.expanduser(\"~\"))\n", - "svd_results_path.filter_pattern = '*.pt'\n", + "svd_results_path.filter_pattern = \"*.pt\"\n", "display(svd_results_path)" ] }, @@ -316,7 +316,7 @@ "source": [ "# Select path to SVD results\n", "svd_all_vs_all_results_path = FileChooser(os.path.expanduser(\"~\"))\n", - "svd_all_vs_all_results_path.filter_pattern = '*.pt'\n", + "svd_all_vs_all_results_path.filter_pattern = \"*.pt\"\n", "display(svd_all_vs_all_results_path)" ] }, diff --git a/tutorials/3_tutorial_map2map.ipynb b/tutorials/3_tutorial_map2map.ipynb index a3701ff..0497578 100644 --- a/tutorials/3_tutorial_map2map.ipynb +++ b/tutorials/3_tutorial_map2map.ipynb @@ -23,15 +23,8 @@ "\n", "from cryo_challenge.data._validation.config_validators import (\n", " validate_input_config_mtm,\n", - " validate_config_mtm_data, \n", - " validate_config_mtm_data_submission, \n", - " validate_config_mtm_data_ground_truth, \n", - " validate_config_mtm_data_mask, \n", - " 
validate_config_mtm_analysis, \n", - " validate_config_mtm_analysis_normalize, \n", - " )\n", - "from cryo_challenge.data._validation.output_validators import MapToMapResultsValidator\n", - "from cryo_challenge.data._validation.config_validators import validate_maptomap_result" + ")\n", + "from cryo_challenge.data._validation.output_validators import MapToMapResultsValidator" ] }, { @@ -80,7 +73,7 @@ "# Select path to Map to Map config file\n", "# An example of this file is available in the path ../config_files/config_map_to_map.yaml\n", "config_m2m_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_m2m_path.filter_pattern = '*.yaml'\n", + "config_m2m_path.filter_pattern = \"*.yaml\"\n", "display(config_m2m_path)" ] }, @@ -341,8 +334,8 @@ "metadata": {}, "outputs": [], "source": [ - "with open(os.path.join('../',config[\"output\"]), \"rb\") as f:\n", - " results_dict = pickle.load(f)\n" + "with open(os.path.join(\"../\", config[\"output\"]), \"rb\") as f:\n", + " results_dict = pickle.load(f)" ] }, { diff --git a/tutorials/4_tutorial_distribution2distribution.ipynb b/tutorials/4_tutorial_distribution2distribution.ipynb index 07dc1d9..2a6fc53 100644 --- a/tutorials/4_tutorial_distribution2distribution.ipynb +++ b/tutorials/4_tutorial_distribution2distribution.ipynb @@ -30,7 +30,7 @@ "import pickle\n", "from ipyfilechooser import FileChooser\n", "\n", - "from cryo_challenge.data import validate_input_config_disttodist, validate_config_dtd_optimal_q_kl\n", + "from cryo_challenge.data import validate_input_config_disttodist\n", "from cryo_challenge.data import DistributionToDistributionResultsValidator" ] }, @@ -65,7 +65,7 @@ "# Select path to Distribution to Distribution config file\n", "# An example of this file is available in the path ../config_files/config_distribution_to_distribution.yaml\n", "config_d2d_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_d2d_path.filter_pattern = '*.yaml'\n", + "config_d2d_path.filter_pattern = \"*.yaml\"\n", "display(config_d2d_path)" ] }, @@ -199,8 +199,8 @@ "metadata": {}, "outputs": [], "source": [ - "with open(os.path.join('../',config[\"output_fname\"]), \"rb\") as f:\n", - " results_dict = pickle.load(f)\n" + "with open(os.path.join(\"../\", config[\"output_fname\"]), \"rb\") as f:\n", + " results_dict = pickle.load(f)" ] }, { @@ -286,7 +286,6 @@ } ], "source": [ - "from cryo_challenge.data import MetricDistToDistValidator\n", "MetricDistToDistValidator?" 
    ]
   },
@@ -302,9 +301,7 @@
    "execution_count": 30,
    "metadata": {},
    "outputs": [],
-    "source": [
-     "from cryo_challenge.data import ReplicateValidatorEMD, ReplicateValidatorKL"
-    ]
+    "source": []
   },
   {
    "cell_type": "code",
diff --git a/tutorials/5_tutorial_plotting.ipynb b/tutorials/5_tutorial_plotting.ipynb
index ed8a924..e2648c8 100644
--- a/tutorials/5_tutorial_plotting.ipynb
+++ b/tutorials/5_tutorial_plotting.ipynb
@@ -23,7 +23,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-     "from cryo_challenge._distribution_to_distribution.distribution_to_distribution import sort_by_transport\n",
+     "from cryo_challenge._distribution_to_distribution.distribution_to_distribution import (\n",
+     "    sort_by_transport,\n",
+     ")\n",
     "from cryo_challenge._ploting.plotting_utils import res_at_fsc_threshold\n",
     "\n",
     "from dataclasses import dataclass\n",
@@ -102,6 +104,7 @@
     "    map2map_results: List[str]\n",
     "    dist2dist_results: Dict[str, Dict[str, List[str]]]\n",
     "\n",
+     "\n",
     "with open(path_to_config, \"r\") as file:\n",
     "    config = yaml.safe_load(file)\n",
     "config = PlottingConfig.from_dict(config)\n",
@@ -115,7 +118,7 @@
    "outputs": [],
    "source": [
     "metadata_df = pd.read_csv(config.gt_metadata)\n",
-     "metadata_df.sort_values('pc1', inplace=True)\n",
+     "metadata_df.sort_values(\"pc1\", inplace=True)\n",
     "gt_ordering = metadata_df.index.tolist()"
    ]
   },
@@ -136,12 +139,14 @@
     "    data_d = {}\n",
     "    for fname in fnames:\n",
     "        if fname not in data_d.keys():\n",
-     "            with open(fname, 'rb') as f:\n",
+     "            with open(fname, \"rb\") as f:\n",
     "                data = pickle.load(f)\n",
-     "            anonymous_label = data[map2map_distance]['user_submission_label']\n",
+     "            anonymous_label = data[map2map_distance][\"user_submission_label\"]\n",
     "            data_d[anonymous_label] = data\n",
     "    return data_d\n",
-     "data_d = get_fsc_distances(config.map2map_results, 'fsc')"
+     "\n",
+     "\n",
+     "data_d = get_fsc_distances(config.map2map_results, \"fsc\")"
    ]
   },
   {
@@ -162,22 +167,24 @@
    ],
    "source": [
     "def plot_fsc_distances(data_d, gt_ordering):\n",
-     "\n",
     "    smaller_fontsize = 20\n",
     "    larger_fontsize = 30\n",
     "    n_plts = 12\n",
     "    vmin, vmax = np.inf, -np.inf\n",
     "\n",
-     "    fig, axis = plt.subplots(3,n_plts//3,\n",
-     "                             figsize=(40,20),\n",
-     "                             # dpi=100,\n",
-     "                             )\n",
-     "    fig.suptitle(r'$d_{FSC}$', y=0.95, fontsize=larger_fontsize)\n",
+     "    fig, axis = plt.subplots(\n",
+     "        3,\n",
+     "        n_plts // 3,\n",
+     "        figsize=(40, 20),\n",
+     "        # dpi=100,\n",
+     "    )\n",
+     "    fig.suptitle(r\"$d_{FSC}$\", y=0.95, fontsize=larger_fontsize)\n",
     "\n",
     "    for idx, (anonymous_label, data) in enumerate(data_d.items()):\n",
-     "        map2map_dist_matrix = data['fsc']['cost_matrix'].iloc[gt_ordering].values\n",
-     "        sorted_map2map_dist_matrix, indices, transport = sort_by_transport(map2map_dist_matrix)\n",
-     "\n",
+     "        map2map_dist_matrix = data[\"fsc\"][\"cost_matrix\"].iloc[gt_ordering].values\n",
+     "        sorted_map2map_dist_matrix, indices, transport = sort_by_transport(\n",
+     "            map2map_dist_matrix\n",
+     "        )\n",
     "\n",
     "        ncols = 4\n",
     "        if map2map_dist_matrix.min() < vmin:\n",
@@ -185,14 +192,24 @@
     "        if map2map_dist_matrix.max() > vmax:\n",
     "            vmax = map2map_dist_matrix.max()\n",
     "\n",
-     "        ax = axis[idx//ncols,idx%ncols].imshow(sorted_map2map_dist_matrix, aspect='auto', cmap='Blues_r', vmin=vmin, vmax=vmax)\n",
-     "\n",
-     "\n",
-     "        axis[idx//ncols,idx%ncols].tick_params(axis='both', labelsize=smaller_fontsize)\n",
+     "        ax = axis[idx // ncols, idx % ncols].imshow(\n",
+     "            sorted_map2map_dist_matrix,\n",
+     "            aspect=\"auto\",\n",
+     "            cmap=\"Blues_r\",\n",
+     "            vmin=vmin,\n",
+     "            vmax=vmax,\n",
+     "        )\n",
+     "\n",
+     "        axis[idx // ncols, idx % ncols].tick_params(\n",
+     "            axis=\"both\", labelsize=smaller_fontsize\n",
+     "        )\n",
     "        cbar = fig.colorbar(ax)\n",
     "        cbar.ax.tick_params(labelsize=smaller_fontsize)\n",
     "        plot_panel_label = anonymous_label\n",
-     "        axis[idx//ncols,idx%ncols].set_title(plot_panel_label, fontsize=smaller_fontsize)\n",
+     "        axis[idx // ncols, idx % ncols].set_title(\n",
+     "            plot_panel_label, fontsize=smaller_fontsize\n",
+     "        )\n",
+     "\n",
     "\n",
     "plot_fsc_distances(data_d, gt_ordering)"
    ]
@@ -214,12 +231,13 @@
     "    data_d = {}\n",
     "    for fname in fnames:\n",
     "        if fname not in data_d.keys():\n",
-     "            with open(fname, 'rb') as f:\n",
+     "            with open(fname, \"rb\") as f:\n",
     "                data = pickle.load(f)\n",
-     "            anonymous_label = data['fsc']['user_submission_label']\n",
-     "            data_d[anonymous_label] = data['fsc']['computed_assets']['fsc_matrix']\n",
+     "            anonymous_label = data[\"fsc\"][\"user_submission_label\"]\n",
+     "            data_d[anonymous_label] = data[\"fsc\"][\"computed_assets\"][\"fsc_matrix\"]\n",
     "    return data_d\n",
     "\n",
+     "\n",
     "fscs_sorted_d = get_full_fsc_curve(config.map2map_results)"
    ]
   },
@@ -229,9 +247,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-     "res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fscs_sorted_d['Cookie Dough'], threshold=0.5)\n",
-     "n_fourier_bins = fscs_sorted_d['Cookie Dough'].shape[-1]\n",
-     "units_Angstroms = 2 * 2.146 / (np.arange(1,n_fourier_bins+1) / n_fourier_bins)"
+     "res_fsc_half, fraction_nyquist = res_at_fsc_threshold(\n",
+     "    fscs_sorted_d[\"Cookie Dough\"], threshold=0.5\n",
+     ")\n",
+     "n_fourier_bins = fscs_sorted_d[\"Cookie Dough\"].shape[-1]\n",
+     "units_Angstroms = 2 * 2.146 / (np.arange(1, n_fourier_bins + 1) / n_fourier_bins)"
    ]
   },
   {
@@ -261,7 +281,7 @@
     }
    ],
    "source": [
-     "plt.imshow(units_Angstroms[res_fsc_half][gt_ordering], aspect='auto', cmap='Blues_r')\n",
+     "plt.imshow(units_Angstroms[res_fsc_half][gt_ordering], aspect=\"auto\", cmap=\"Blues_r\")\n",
     "plt.colorbar()"
    ]
   },
@@ -283,17 +303,18 @@
    ],
    "source": [
     "def plot_res_at_fsc_threshold_distances(fscs_sorted_d, gt_ordering, overwrite_dict={}):\n",
-     "\n",
     "    smaller_fontsize = 20\n",
     "    larger_fontsize = 30\n",
     "    n_plts = 12\n",
     "    vmin, vmax = np.inf, -np.inf\n",
     "\n",
-     "    fig, axis = plt.subplots(3,n_plts//3,\n",
-     "                             figsize=(40,20),\n",
-     "                             # dpi=100,\n",
-     "                             )\n",
-     "    fig.suptitle(r'Resolution $(\\AA)$ at $FSC=0.5$', y=0.95, fontsize=larger_fontsize)\n",
+     "    fig, axis = plt.subplots(\n",
+     "        3,\n",
+     "        n_plts // 3,\n",
+     "        figsize=(40, 20),\n",
+     "        # dpi=100,\n",
+     "    )\n",
+     "    fig.suptitle(r\"Resolution $(\\AA)$ at $FSC=0.5$\", y=0.95, fontsize=larger_fontsize)\n",
     "\n",
     "    for idx, (anonymous_label, fscs) in enumerate(fscs_sorted_d.items()):\n",
     "        # map2map_dist_matrix = data.iloc[gt_ordering].values\n",
@@ -302,27 +323,38 @@
     "\n",
     "        sorted_map2map_dist_matrix, _, _ = sort_by_transport(map2map_dist_matrix)\n",
     "\n",
-     "\n",
     "        ncols = 4\n",
     "        if map2map_dist_matrix.min() < vmin:\n",
     "            vmin = map2map_dist_matrix.min()\n",
     "        if map2map_dist_matrix.max() > vmax:\n",
     "            vmax = map2map_dist_matrix.max()\n",
-     "        if 'vmax' in overwrite_dict.keys():\n",
-     "            vmax = overwrite_dict['vmax']\n",
-     "        if 'vmin' in overwrite_dict.keys():\n",
-     "            vmin = overwrite_dict['vmin']\n",
-     "\n",
-     "        ax = axis[idx//ncols,idx%ncols].imshow(sorted_map2map_dist_matrix, aspect='auto', cmap='Blues_r', vmin=vmin, vmax=vmax)\n",
-     "\n",
-     "\n",
-     "        axis[idx//ncols,idx%ncols].tick_params(axis='both', labelsize=smaller_fontsize)\n",
+     "        if \"vmax\" in overwrite_dict.keys():\n",
+     "            vmax = overwrite_dict[\"vmax\"]\n",
+     "        if \"vmin\" in overwrite_dict.keys():\n",
+     "            vmin = overwrite_dict[\"vmin\"]\n",
+     "\n",
+     "        ax = axis[idx // ncols, idx % ncols].imshow(\n",
+     "            sorted_map2map_dist_matrix,\n",
+     "            aspect=\"auto\",\n",
+     "            cmap=\"Blues_r\",\n",
+     "            vmin=vmin,\n",
+     "            vmax=vmax,\n",
+     "        )\n",
+     "\n",
+     "        axis[idx // ncols, idx % ncols].tick_params(\n",
+     "            axis=\"both\", labelsize=smaller_fontsize\n",
+     "        )\n",
     "        cbar = fig.colorbar(ax)\n",
     "        cbar.ax.tick_params(labelsize=smaller_fontsize)\n",
     "        plot_panel_label = anonymous_label\n",
-     "        axis[idx//ncols,idx%ncols].set_title(plot_panel_label, fontsize=smaller_fontsize)\n",
+     "        axis[idx // ncols, idx % ncols].set_title(\n",
+     "            plot_panel_label, fontsize=smaller_fontsize\n",
+     "        )\n",
+     "\n",
     "\n",
-     "plot_res_at_fsc_threshold_distances(fscs_sorted_d, gt_ordering, overwrite_dict={'vmax': 31})"
+     "plot_res_at_fsc_threshold_distances(\n",
+     "    fscs_sorted_d, gt_ordering, overwrite_dict={\"vmax\": 31}\n",
+     ")"
    ]
   },
   {
@@ -338,9 +370,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-     "fname = config.dist2dist_results['prob_submitted_plot']['pkl_fnames'][0]\n",
+     "fname = config.dist2dist_results[\"prob_submitted_plot\"][\"pkl_fnames\"][0]\n",
     "\n",
-     "with open(fname, 'rb') as f:\n",
+     "with open(fname, \"rb\") as f:\n",
     "    data = pickle.load(f)"
    ]
   },
@@ -354,13 +386,16 @@
     "    data_d = {}\n",
     "    for fname in fnames:\n",
     "        if fname not in data_d.keys():\n",
-     "            with open(fname, 'rb') as f:\n",
+     "            with open(fname, \"rb\") as f:\n",
     "                data = pickle.load(f)\n",
-     "            anonymous_label = data['id']\n",
+     "            anonymous_label = data[\"id\"]\n",
     "            data_d[anonymous_label] = data\n",
     "    return data_d\n",
     "\n",
-     "dist2dist_results_d = get_dist2dist_results(config.dist2dist_results['prob_submitted_plot']['pkl_fnames'])"
+     "\n",
+     "dist2dist_results_d = get_dist2dist_results(\n",
+     "    config.dist2dist_results[\"prob_submitted_plot\"][\"pkl_fnames\"]\n",
+     ")"
    ]
   },
   {
@@ -381,46 +416,58 @@
    ],
    "source": [
     "window_size = 15\n",
-     "nrows, ncols = 3,4\n",
+     "nrows, ncols = 3, 4\n",
     "suptitle = f\"Submitted populations vs optimal populations \\n d_FSC (no rank) | n_replicates={data['config']['n_replicates']} | window_size={window_size} | n_pool_microstate={data['config']['n_pool_microstate']}\"\n",
     "\n",
+     "\n",
     "def plot_q_opt_distances(dist2dist_results_d, suptitle, nrows, ncols):\n",
+     "    fig, axes = plt.subplots(nrows, ncols, figsize=(40, 25))\n",
     "\n",
-     "    fig, axes = plt.subplots(nrows, ncols, figsize=(40,25))\n",
-     "    \n",
-     "    fig.suptitle(\n",
-     "        suptitle,\n",
-     "        fontsize=30,\n",
-     "        y=0.95)\n",
+     "    fig.suptitle(suptitle, fontsize=30, y=0.95)\n",
     "    alpha = 0.05\n",
     "\n",
-     "    for idx_fname, (_,data) in enumerate(dist2dist_results_d.items()):\n",
-     "        \n",
+     "    for idx_fname, (_, data) in enumerate(dist2dist_results_d.items()):\n",
+     "        axes[idx_fname // ncols, idx_fname % ncols].plot(\n",
+     "            data[\"user_submitted_populations\"], color=\"black\", label=\"submitted\"\n",
+     "        )\n",
+     "        axes[idx_fname // ncols, idx_fname % ncols].set_title(data[\"id\"], fontsize=30)\n",
     "\n",
-     "        axes[idx_fname//ncols, idx_fname%ncols].plot(data['user_submitted_populations'], color='black', label='submited')\n",
-     "        axes[idx_fname//ncols, idx_fname%ncols].set_title(data['id'], fontsize=30)\n",
-     "\n",
-     "        def window_q(q_opt,window_size):\n",
-     "            running_avg = np.convolve(q_opt, np.ones(window_size)/window_size, mode='same')\n",
+     "        def window_q(q_opt, window_size):\n",
+     "            running_avg = np.convolve(\n",
+     "                q_opt, np.ones(window_size) / window_size, mode=\"same\"\n",
+     "            )\n",
     "            return running_avg\n",
-     "        \n",
-     "        for replicate_idx in range(data['config']['n_replicates']):\n",
-     "            \n",
-     "            if replicate_idx == 0: \n",
-     "                label_d = {'EMD': 'EMD', 'KL': 'KL', 'KL_raw': 'Unwindowed', 'EMD_raw': 'Unwindowed'}\n",
-     "            else:\n",
-     "                label_d = {'EMD': None, 'KL': None, 'KL_raw': None, 'EMD_raw': None}\n",
-     "            axes[idx_fname//ncols, idx_fname%ncols].plot(window_q(data['fsc']['replicates'][replicate_idx]['EMD']['q_opt'],window_size), color='blue', alpha=alpha, label=label_d['EMD'])\n",
-     "\n",
-     "        axes[idx_fname//ncols, idx_fname%ncols].set_xlabel('Submission index')\n",
-     "        axes[idx_fname//ncols, idx_fname%ncols].set_ylabel('Population')\n",
     "\n",
-     "        legend = axes[idx_fname//ncols, idx_fname%ncols].legend()\n",
+     "        for replicate_idx in range(data[\"config\"][\"n_replicates\"]):\n",
+     "            if replicate_idx == 0:\n",
+     "                label_d = {\n",
+     "                    \"EMD\": \"EMD\",\n",
+     "                    \"KL\": \"KL\",\n",
+     "                    \"KL_raw\": \"Unwindowed\",\n",
+     "                    \"EMD_raw\": \"Unwindowed\",\n",
+     "                }\n",
+     "            else:\n",
+     "                label_d = {\"EMD\": None, \"KL\": None, \"KL_raw\": None, \"EMD_raw\": None}\n",
+     "            axes[idx_fname // ncols, idx_fname % ncols].plot(\n",
+     "                window_q(\n",
+     "                    data[\"fsc\"][\"replicates\"][replicate_idx][\"EMD\"][\"q_opt\"],\n",
+     "                    window_size,\n",
+     "                ),\n",
+     "                color=\"blue\",\n",
+     "                alpha=alpha,\n",
+     "                label=label_d[\"EMD\"],\n",
+     "            )\n",
+     "\n",
+     "        axes[idx_fname // ncols, idx_fname % ncols].set_xlabel(\"Submission index\")\n",
+     "        axes[idx_fname // ncols, idx_fname % ncols].set_ylabel(\"Population\")\n",
+     "\n",
+     "        legend = axes[idx_fname // ncols, idx_fname % ncols].legend()\n",
     "        for line, text in zip(legend.get_lines(), legend.get_texts()):\n",
     "            text.set_color(line.get_color())\n",
     "            line.set_alpha(1)\n",
     "\n",
-     "plot_q_opt_distances(dist2dist_results_d, suptitle,nrows, ncols)"
+     "\n",
+     "plot_q_opt_distances(dist2dist_results_d, suptitle, nrows, ncols)"
    ]
   },
   {
@@ -437,7 +484,6 @@
    "outputs": [],
    "source": [
     "def wragle_pkl_to_dataframe(pkl_globs):\n",
-     "\n",
     "    fnames = []\n",
     "    for fname_glob in pkl_globs:\n",
     "        fnames.extend(glob.glob(fname_glob))\n",
@@ -446,28 +492,50 @@
     "\n",
     "    df_list = []\n",
     "    n_replicates = 30\n",
-     "    metric = 'fsc'\n",
+     "    metric = \"fsc\"\n",
     "\n",
     "    for fname in fnames:\n",
-     "        with open(fname, 'rb') as f:\n",
+     "        with open(fname, \"rb\") as f:\n",
     "            data = pickle.load(f)\n",
     "\n",
-     "        df_list.append(pd.DataFrame({\n",
-     "            'EMD_opt': [data[metric]['replicates'][i]['EMD']['EMD_opt'] for i in range(n_replicates)],\n",
-     "            'EMD_submitted': [data[metric]['replicates'][i]['EMD']['EMD_submitted'] for i in range(n_replicates)],\n",
-     "            'klpq_opt': [data[metric]['replicates'][i]['KL']['klpq_opt'] for i in range(n_replicates)],\n",
-     "            'klqp_opt': [data[metric]['replicates'][i]['KL']['klqp_opt'] for i in range(n_replicates)],\n",
-     "            'klpq_submitted': [data[metric]['replicates'][i]['KL']['klpq_submitted'] for i in range(n_replicates)], \n",
-     "            'klqp_submitted': [data[metric]['replicates'][i]['KL']['klqp_submitted'] for i in range(n_replicates)], \n",
-     "            'id': data['id'],\n",
-     "            'n_pool_microstate': data['config']['n_pool_microstate'],\n",
-     "            }))\n",
+     "        df_list.append(\n",
+     "            pd.DataFrame(\n",
+     "                {\n",
+     "                    \"EMD_opt\": [\n",
+     "                        data[metric][\"replicates\"][i][\"EMD\"][\"EMD_opt\"]\n",
+     "                        for i in range(n_replicates)\n",
+     "                    ],\n",
+     "                    \"EMD_submitted\": [\n",
+     "                        data[metric][\"replicates\"][i][\"EMD\"][\"EMD_submitted\"]\n",
+     "                        for i in range(n_replicates)\n",
+     "                    ],\n",
+     "                    \"klpq_opt\": [\n",
+     "                        data[metric][\"replicates\"][i][\"KL\"][\"klpq_opt\"]\n",
+     "                        for i in range(n_replicates)\n",
+     "                    ],\n",
+     "                    \"klqp_opt\": [\n",
+     "                        data[metric][\"replicates\"][i][\"KL\"][\"klqp_opt\"]\n",
+     "                        for i in range(n_replicates)\n",
+     "                    ],\n",
+     "                    \"klpq_submitted\": [\n",
+     "                        data[metric][\"replicates\"][i][\"KL\"][\"klpq_submitted\"]\n",
+     "                        for i in range(n_replicates)\n",
+     "                    ],\n",
+     "                    \"klqp_submitted\": [\n",
+     "                        data[metric][\"replicates\"][i][\"KL\"][\"klqp_submitted\"]\n",
+     "                        for i in range(n_replicates)\n",
+     "                    ],\n",
+     "                    \"id\": data[\"id\"],\n",
+     "                    \"n_pool_microstate\": data[\"config\"][\"n_pool_microstate\"],\n",
+     "                }\n",
+     "            )\n",
+     "        )\n",
    "\n",
     "    df = pd.concat(df_list)\n",
-     "    df['EMD_opt_norm'] = df['EMD_opt'] / df['n_pool_microstate']\n",
-     "    df['EMD_submitted_norm'] = df['EMD_submitted'] / df['n_pool_microstate']\n",
+     "    df[\"EMD_opt_norm\"] = df[\"EMD_opt\"] / df[\"n_pool_microstate\"]\n",
+     "    df[\"EMD_submitted_norm\"] = df[\"EMD_submitted\"] / df[\"n_pool_microstate\"]\n",
     "\n",
-     "    return df\n"
+     "    return df"
    ]
   },
   {
@@ -476,7 +544,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-     "df = wragle_pkl_to_dataframe(config.dist2dist_results['emd_plot']['pkl_globs'])"
+     "df = wragle_pkl_to_dataframe(config.dist2dist_results[\"emd_plot\"][\"pkl_globs\"])"
    ]
   },
@@ -498,50 +566,96 @@
    "source": [
     "def plot_EMD_vs_EMDopt(df, suptitle=None):\n",
     "    alpha = 1\n",
-     "    df_average = df.groupby(['n_pool_microstate','id']).mean().reset_index()\n",
-     "    df_std = df_average.groupby(['id']).std().reset_index().filter(['EMD_opt_norm','EMD_submitted_norm', 'id']).rename(columns={'EMD_opt_norm':'EMD_opt_norm_std', 'EMD_submitted_norm':'EMD_submitted_norm_std'})\n",
-     "    df_average = df.groupby(['id']).mean().reset_index()\n",
-     "\n",
-     "    df_average_and_error = pd.merge(df_average, df_std, on='id')\n",
+     "    df_average = df.groupby([\"n_pool_microstate\", \"id\"]).mean().reset_index()\n",
+     "    df_std = (\n",
+     "        df_average.groupby([\"id\"])\n",
+     "        .std()\n",
+     "        .reset_index()\n",
+     "        .filter([\"EMD_opt_norm\", \"EMD_submitted_norm\", \"id\"])\n",
+     "        .rename(\n",
+     "            columns={\n",
+     "                \"EMD_opt_norm\": \"EMD_opt_norm_std\",\n",
+     "                \"EMD_submitted_norm\": \"EMD_submitted_norm_std\",\n",
+     "            }\n",
+     "        )\n",
+     "    )\n",
+     "    df_average = df.groupby([\"id\"]).mean().reset_index()\n",
+     "\n",
+     "    df_average_and_error = pd.merge(df_average, df_std, on=\"id\")\n",
     "\n",
     "    # Get unique ids\n",
-     "    ids = df_average_and_error['id'].unique()\n",
+     "    ids = df_average_and_error[\"id\"].unique()\n",
     "\n",
     "    # Define marker styles\n",
-     "    markers = ['o', 'v', '^', '<', '>', 's', 'p', '*', 'h', 'H', '+', 'x', 'D', 'd', '|', '_']\n",
+     "    markers = [\n",
+     "        \"o\",\n",
+     "        \"v\",\n",
+     "        \"^\",\n",
+     "        \"<\",\n",
+     "        \">\",\n",
+     "        \"s\",\n",
+     "        \"p\",\n",
+     "        \"*\",\n",
+     "        \"h\",\n",
+     "        \"H\",\n",
+     "        \"+\",\n",
+     "        \"x\",\n",
+     "        \"D\",\n",
+     "        \"d\",\n",
+     "        \"|\",\n",
+     "        \"_\",\n",
+     "    ]\n",
     "    marker_size = 250\n",
     "\n",
-     "    plt.style.use('seaborn-v0_8-poster')\n",
-     "    plot_width, plot_height = 8,6\n",
+     "    plt.style.use(\"seaborn-v0_8-poster\")\n",
+     "    plot_width, plot_height = 8, 6\n",
     "    plt.figure(figsize=(plot_width, plot_height), dpi=300)\n",
     "\n",
     "    # Create a scatter plot for each id\n",
     "    for idx, id_label in enumerate(ids):\n",
-     "        df_average_id = df_average[df_average['id'] == id_label]\n",
-     "        sns.scatterplot(x='EMD_submitted_norm', y='EMD_opt_norm', data=df_average_id, alpha=alpha, marker=markers[idx % len(markers)], label=id_label, s=marker_size)\n",
-     "\n",
-     "    plt.errorbar(x=df_average_and_error['EMD_submitted_norm'], \n",
-     "                 y=df_average_and_error['EMD_opt_norm'], \n",
-     "                 xerr=df_average_and_error['EMD_submitted_norm_std'], \n",
-     "                 yerr=df_average_and_error['EMD_opt_norm_std'], \n",
-     "                 fmt='', alpha=0.05, linestyle='None', ecolor='k', elinewidth=2, capsize=5)\n",
+     "        df_average_id = df_average[df_average[\"id\"] == id_label]\n",
+     "        sns.scatterplot(\n",
+     "            x=\"EMD_submitted_norm\",\n",
+     "            y=\"EMD_opt_norm\",\n",
+     "            data=df_average_id,\n",
+     "            alpha=alpha,\n",
+     "            marker=markers[idx % len(markers)],\n",
+     "            label=id_label,\n",
+     "            s=marker_size,\n",
+     "        )\n",
+     "\n",
+     "    plt.errorbar(\n",
+     "        x=df_average_and_error[\"EMD_submitted_norm\"],\n",
+     "        y=df_average_and_error[\"EMD_opt_norm\"],\n",
+     "        xerr=df_average_and_error[\"EMD_submitted_norm_std\"],\n",
+     "        yerr=df_average_and_error[\"EMD_opt_norm_std\"],\n",
+     "        fmt=\"\",\n",
+     "        alpha=0.05,\n",
+     "        linestyle=\"None\",\n",
+     "        ecolor=\"k\",\n",
+     "        elinewidth=2,\n",
+     "        capsize=5,\n",
+     "    )\n",
     "\n",
     "    plt.xlim(left=0.5)\n",
     "    plt.ylim(bottom=0.5)\n",
     "\n",
-     "    limits = [np.min([plt.xlim(), plt.ylim()]), # min of both axes\n",
-     "              np.max([plt.xlim(), plt.ylim()])] # max of both axes\n",
+     "    limits = [\n",
+     "        np.min([plt.xlim(), plt.ylim()]),  # min of both axes\n",
+     "        np.max([plt.xlim(), plt.ylim()]),\n",
+     "    ]  # max of both axes\n",
     "\n",
-     "    plt.plot(limits, limits, 'k-', alpha=0.75, zorder=0)\n",
+     "    plt.plot(limits, limits, \"k-\", alpha=0.75, zorder=0)\n",
     "    plt.xlim(limits)\n",
     "    plt.ylim(limits)\n",
-     "    legend = plt.legend(loc='upper left', fontsize=12)\n",
+     "    legend = plt.legend(loc=\"upper left\", fontsize=12)\n",
     "    for handle in legend.legend_handles:\n",
     "        handle.set_alpha(1)\n",
     "\n",
     "    plt.suptitle(suptitle)\n",
     "\n",
-     "suptitle = r'$d_{FSC}$ (no rank)'\n",
+     "\n",
+     "suptitle = r\"$d_{FSC}$ (no rank)\"\n",
     "plot_EMD_vs_EMDopt(df, suptitle=None)"
    ]
   }
diff --git a/tutorials/README.md b/tutorials/README.md
index d427ba5..842bdf7 100644
--- a/tutorials/README.md
+++ b/tutorials/README.md
@@ -25,4 +25,4 @@ This notebook walks through generating and analyzing (plots) the map to map dist
 - output: a `.pkl` file
 
 ### `5_tutorial_plotting.ipynb`
-This notebook walks through parsing and analyzing (plots) the map to map and distribution to distribution results.
\ No newline at end of file
+This notebook walks through parsing and analyzing (plots) the map to map and distribution to distribution results.
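The tutorial-5 cells in this diff reorder each map-to-map cost matrix with `sort_by_transport` before plotting, so a submission whose maps track the ground-truth ordering shows up as a near-diagonal heatmap. A minimal sketch of the idea, assuming uniform marginals and the POT (`ot`) package; the `_sketch` name and the column-ordering heuristic are illustrative, and the package's own implementation may differ:

```python
import numpy as np
import ot  # POT: Python Optimal Transport


def sort_by_transport_sketch(cost):
    """Illustrative stand-in for cryo_challenge's sort_by_transport.

    Computes an optimal transport plan between uniform weights on the
    rows and columns of `cost`, then reorders the columns so the plan
    is approximately diagonal.
    """
    m, n = cost.shape
    transport = ot.emd(np.ones(m) / m, np.ones(n) / n, np.asarray(cost, dtype=float))
    # Sort each column by the average row index its transported mass lands on.
    indices = np.argsort((transport * np.arange(m)[:, None]).sum(axis=0))
    return cost[:, indices], indices, transport
```

Because the notebook sorts the ground-truth rows by `pc1` first, a strongly banded matrix after this reordering suggests the submission's states vary along that same coordinate.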
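The resolution heatmap in the same notebook rests on converting Fourier-bin indices to Ångströms: bin k of n sits at spatial frequency (k / n) × Nyquist with Nyquist = 1 / (2 · psize), so its real-space resolution is 2 · psize · n / k — exactly the `units_Angstroms = 2 * 2.146 / (np.arange(1, n_fourier_bins + 1) / n_fourier_bins)` expression in the diff, with the 2.146 Å pixel size from the map-to-map config. A self-contained sketch of the conversion, using a hypothetical threshold-crossing helper in place of `res_at_fsc_threshold` (the package's helper may behave differently):

```python
import numpy as np


def first_bin_below(fscs, threshold=0.5):
    """First Fourier bin where each FSC curve drops below `threshold`.

    Hypothetical stand-in for res_at_fsc_threshold; `fscs` has shape
    (..., n_fourier_bins). np.argmax returns the index of the first True
    along the last axis (and 0 if the curve never crosses the threshold).
    """
    return np.argmax(fscs < threshold, axis=-1)


psize = 2.146  # pixel size in Angstroms, as in config_files/config_map_to_map.yaml
n_fourier_bins = 112  # illustrative; the notebook reads it from the FSC matrix shape

# Resolution of (1-based) bin k: 2 * psize * n_fourier_bins / k.
units_Angstroms = 2 * psize / (np.arange(1, n_fourier_bins + 1) / n_fourier_bins)

# Worked example with a synthetic, monotonically decaying FSC curve:
fsc = np.exp(-np.arange(n_fourier_bins) / 30.0)
k = first_bin_below(fsc, threshold=0.5)
print(f"FSC falls below 0.5 at bin {k}, i.e. ~{units_Angstroms[k]:.1f} A")
```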