From 0ce2df11a5d9320a85caa4c6b61dd066e994e1b4 Mon Sep 17 00:00:00 2001 From: DSilva27 Date: Mon, 12 Aug 2024 12:22:27 -0400 Subject: [PATCH] implement power spectrum and bfactor scaling --- .github/workflows/main_merge_check.yml | 2 +- .../config_distribution_to_distribution.yaml | 2 +- .../config_map_to_map_distance_matrix.yaml | 10 +- src/cryo_challenge/__init__.py | 4 +- .../distribution_to_distribution.py | 17 +- .../_map_to_map/map_to_map_distance_matrix.py | 2 +- src/cryo_challenge/_ploting/plotting_utils.py | 5 +- .../_preprocessing/bfactor_normalize.py | 36 ++ .../_preprocessing/normalize.py | 22 -- src/cryo_challenge/data/__init__.py | 24 +- src/cryo_challenge/data/_io/svd_io_utils.py | 10 +- .../data/_validation/output_validators.py | 21 +- src/cryo_challenge/power_spectrum_utils.py | 71 ++++ .../submission_x/submission_config.json | 2 +- tests/test_distribution_to_distribution.py | 8 +- tests/test_map_to_map.py | 8 +- tests/test_power_spectrum_and_bfactor.py | 72 ++++ tests/test_preprocessing.py | 6 +- tests/test_svd.py | 6 +- tutorials/1_tutorial_preprocessing.ipynb | 4 +- tutorials/2_tutorial_svd.ipynb | 6 +- tutorials/3_tutorial_map2map.ipynb | 17 +- ...4_tutorial_distribution2distribution.ipynb | 13 +- tutorials/5_tutorial_plotting.ipynb | 344 ++++++++++++------ tutorials/README.md | 2 +- 25 files changed, 500 insertions(+), 214 deletions(-) create mode 100644 src/cryo_challenge/_preprocessing/bfactor_normalize.py delete mode 100644 src/cryo_challenge/_preprocessing/normalize.py create mode 100644 src/cryo_challenge/power_spectrum_utils.py create mode 100644 tests/test_power_spectrum_and_bfactor.py diff --git a/.github/workflows/main_merge_check.yml b/.github/workflows/main_merge_check.yml index 8e6f5c5..b4aa6e1 100644 --- a/.github/workflows/main_merge_check.yml +++ b/.github/workflows/main_merge_check.yml @@ -11,4 +11,4 @@ jobs: if: github.base_ref == 'main' && github.head_ref != 'dev' run: | echo "ERROR: You can only merge to main from dev." - exit 1 \ No newline at end of file + exit 1 diff --git a/config_files/config_distribution_to_distribution.yaml b/config_files/config_distribution_to_distribution.yaml index da8b3d9..eaa94ac 100644 --- a/config_files/config_distribution_to_distribution.yaml +++ b/config_files/config_distribution_to_distribution.yaml @@ -12,4 +12,4 @@ cvxpy_solver: ECOS optimal_q_kl: n_iter: 100000 break_atol: 0.0001 -output_fname: results/distribution_to_distribution_submission_0.pkl \ No newline at end of file +output_fname: results/distribution_to_distribution_submission_0.pkl diff --git a/config_files/config_map_to_map_distance_matrix.yaml b/config_files/config_map_to_map_distance_matrix.yaml index 4302227..5a98e28 100644 --- a/config_files/config_map_to_map_distance_matrix.yaml +++ b/config_files/config_map_to_map_distance_matrix.yaml @@ -1,15 +1,15 @@ data: n_pix: 224 - psize: 2.146 + psize: 2.146 submission: fname: data/dataset_2_ground_truth/submission_0.pt volume_key: volumes metadata_key: populations label_key: id ground_truth: - volumes: data/dataset_2_ground_truth/maps_gt_flat.pt - metadata: data/dataset_2_ground_truth/metadata.csv - mask: + volumes: data/dataset_2_ground_truth/maps_gt_flat.pt + metadata: data/dataset_2_ground_truth/metadata.csv + mask: do: true volume: data/dataset_2_ground_truth/mask_dilated_wide_224x224.mrc analysis: @@ -23,4 +23,4 @@ analysis: normalize: do: true method: median_zscore -output: results/map_to_map_distance_matrix_submission_0.pkl \ No newline at end of file +output: results/map_to_map_distance_matrix_submission_0.pkl diff --git a/src/cryo_challenge/__init__.py b/src/cryo_challenge/__init__.py index cafea4e..934c6a8 100644 --- a/src/cryo_challenge/__init__.py +++ b/src/cryo_challenge/__init__.py @@ -1 +1,3 @@ -from cryo_challenge.__about__ import __version__ \ No newline at end of file +from cryo_challenge.__about__ import __version__ + +__all__ = ["__version__"] diff --git a/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py b/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py index 70961d8..18c57bb 100644 --- a/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py +++ b/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py @@ -2,8 +2,6 @@ import numpy as np import pickle from scipy.stats import rankdata -import yaml -import argparse import torch import ot @@ -14,10 +12,12 @@ def sort_by_transport(cost): - m,n = cost.shape - _, transport = compute_wasserstein_between_distributions_from_weights_and_cost(np.ones(m) / m, np.ones(n)/n, cost) - indices = np.argsort((transport * np.arange(m)[...,None]).sum(0)) - return cost[:,indices], indices, transport + m, n = cost.shape + _, transport = compute_wasserstein_between_distributions_from_weights_and_cost( + np.ones(m) / m, np.ones(n) / n, cost + ) + indices = np.argsort((transport * np.arange(m)[..., None]).sum(0)) + return cost[:, indices], indices, transport def compute_wasserstein_between_distributions_from_weights_and_cost( @@ -65,7 +65,6 @@ def make_assignment_matrix(cost_matrix): def run(config): - metadata_df = pd.read_csv(config["gt_metadata_fname"]) metadata_df.sort_values("pc1", inplace=True) @@ -73,7 +72,7 @@ def run(config): data = pickle.load(f) # user_submitted_populations = np.ones(80)/80 - user_submitted_populations = data["user_submitted_populations"]#.numpy() + user_submitted_populations = data["user_submitted_populations"] # .numpy() id = torch.load(data["config"]["data"]["submission"]["fname"])["id"] results_dict = {} @@ -213,5 +212,5 @@ def optimal_q_kl(n_iter, x_start, A, Window, prob_gt, break_atol): DistributionToDistributionResultsValidator.from_dict(results_dict) with open(config["output_fname"], "wb") as f: pickle.dump(results_dict, f) - + return results_dict diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py index 0578dfa..47d9a00 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py @@ -42,7 +42,7 @@ def run(config): user_submission_label = submission[label_key] # n_trunc = 10 - metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"])#[:n_trunc] + metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"]) # [:n_trunc] results_dict = {} results_dict["config"] = config diff --git a/src/cryo_challenge/_ploting/plotting_utils.py b/src/cryo_challenge/_ploting/plotting_utils.py index 04d5ca9..681ab01 100644 --- a/src/cryo_challenge/_ploting/plotting_utils.py +++ b/src/cryo_challenge/_ploting/plotting_utils.py @@ -1,6 +1,7 @@ import numpy as np + def res_at_fsc_threshold(fscs, threshold=0.5): res_fsc_half = np.argmin(fscs > threshold, axis=-1) - fraction_nyquist = 0.5*res_fsc_half / fscs.shape[-1] - return res_fsc_half, fraction_nyquist \ No newline at end of file + fraction_nyquist = 0.5 * res_fsc_half / fscs.shape[-1] + return res_fsc_half, fraction_nyquist diff --git a/src/cryo_challenge/_preprocessing/bfactor_normalize.py b/src/cryo_challenge/_preprocessing/bfactor_normalize.py new file mode 100644 index 0000000..2af5e98 --- /dev/null +++ b/src/cryo_challenge/_preprocessing/bfactor_normalize.py @@ -0,0 +1,36 @@ +import torch +from ..power_spectrum_utils import _centered_fftn, _centered_ifftn + + +def _compute_bfactor_scaling(b_factor, box_size, voxel_size): + x = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) + y = x.clone() + z = x.clone() + x, y, z = torch.meshgrid(x, y, z, indexing="ij") + + s2 = x**2 + y**2 + z**2 + bfactor_scaling_torch = torch.exp(-b_factor * s2 / 4) + + return bfactor_scaling_torch + + +def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=True): + if not in_place: + volumes = volumes.clone() + + b_factor_scaling = _compute_bfactor_scaling(bfactor, volumes.shape[-1], voxel_size) + + if len(volumes.shape) == 3: + volumes = _centered_fftn(volumes, dim=(0, 1, 2)) + volumes = volumes * b_factor_scaling + volumes = _centered_ifftn(volumes, dim=(0, 1, 2)).real + + elif len(volumes.shape) == 4: + volumes = _centered_fftn(volumes, dim=(1, 2, 3)) + volumes = volumes * b_factor_scaling[None, ...] + volumes = _centered_ifftn(volumes, dim=(1, 2, 3)).real + + else: + raise ValueError("Input volumes must have 3 or 4 dimensions.") + + return volumes diff --git a/src/cryo_challenge/_preprocessing/normalize.py b/src/cryo_challenge/_preprocessing/normalize.py deleted file mode 100644 index 73449bf..0000000 --- a/src/cryo_challenge/_preprocessing/normalize.py +++ /dev/null @@ -1,22 +0,0 @@ -''' -TODO: Need to implement this properly - -def normalize_mean_std(vols_flat): - """ - vols_flat.shape is (n_vols, n_pix**3) - vols_flat is a torch tensor - """ - return (vols_flat - vols_flat.mean(-1, keepdims=True)) / vols_flat.std( - -1, keepdims=True - ) - - -def normalize_median_std(vols_flat): - """ - vols_flat.shape is (n_vols, n_pix**3) - vols_flat is a torch tensor - """ - return (vols_flat - vols_flat.median(-1, keepdims=True).values) / vols_flat.std( - -1, keepdims=True - ) -''' diff --git a/src/cryo_challenge/data/__init__.py b/src/cryo_challenge/data/__init__.py index 8b4655c..fb27bbd 100644 --- a/src/cryo_challenge/data/__init__.py +++ b/src/cryo_challenge/data/__init__.py @@ -1,6 +1,18 @@ -from ._validation.config_validators import validate_input_config_disttodist as validate_input_config_disttodist -from ._validation.config_validators import validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl -from cryo_challenge.data._validation.output_validators import DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator -from cryo_challenge.data._validation.output_validators import MetricDistToDistValidator as MetricDistToDistValidator -from cryo_challenge.data._validation.output_validators import ReplicateValidatorEMD as ReplicateValidatorEMD -from cryo_challenge.data._validation.output_validators import ReplicateValidatorKL as ReplicateValidatorKL +from ._validation.config_validators import ( + validate_input_config_disttodist as validate_input_config_disttodist, +) +from ._validation.config_validators import ( + validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl, +) +from cryo_challenge.data._validation.output_validators import ( + DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator, +) +from cryo_challenge.data._validation.output_validators import ( + MetricDistToDistValidator as MetricDistToDistValidator, +) +from cryo_challenge.data._validation.output_validators import ( + ReplicateValidatorEMD as ReplicateValidatorEMD, +) +from cryo_challenge.data._validation.output_validators import ( + ReplicateValidatorKL as ReplicateValidatorKL, +) diff --git a/src/cryo_challenge/data/_io/svd_io_utils.py b/src/cryo_challenge/data/_io/svd_io_utils.py index f194c14..2a4d954 100644 --- a/src/cryo_challenge/data/_io/svd_io_utils.py +++ b/src/cryo_challenge/data/_io/svd_io_utils.py @@ -106,14 +106,16 @@ def load_ref_vols(box_size_ds: int, path_to_volumes: str, dtype=torch.float32): # Reshape volumes to correct size if volumes.dim() == 2: - box_size = int(round((float(volumes.shape[-1]) ** (1. / 3.)))) + box_size = int(round((float(volumes.shape[-1]) ** (1.0 / 3.0)))) volumes = torch.reshape(volumes, (-1, box_size, box_size, box_size)) elif volumes.dim() == 4: pass else: - raise ValueError(f"The shape of the volumes stored in {path_to_volumes} have the unexpected shape " - f"{torch.shape}. Please, review the file and regenerate it so that volumes stored hasve the " - f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size).") + raise ValueError( + f"The shape of the volumes stored in {path_to_volumes} have the unexpected shape " + f"{torch.shape}. Please, review the file and regenerate it so that volumes stored hasve the " + f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size)." + ) volumes_ds = torch.empty( (volumes.shape[0], box_size_ds, box_size_ds, box_size_ds), dtype=dtype diff --git a/src/cryo_challenge/data/_validation/output_validators.py b/src/cryo_challenge/data/_validation/output_validators.py index 35d9791..39cbb67 100644 --- a/src/cryo_challenge/data/_validation/output_validators.py +++ b/src/cryo_challenge/data/_validation/output_validators.py @@ -13,7 +13,7 @@ @dataclass_json @dataclass class MapToMapResultsValidator: - ''' + """ Validate the output dictionary of the map-to-map distance matrix computation. config: dict, input config dictionary. @@ -22,7 +22,8 @@ class MapToMapResultsValidator: l2: dict, L2 results. bioem: dict, BioEM results. fsc: dict, FSC results. - ''' + """ + config: dict user_submitted_populations: torch.Tensor corr: Optional[dict] = None @@ -49,7 +50,7 @@ class ReplicateValidatorEMD: Validate the output dictionary of one EMD in the the distribution-to-distribution pipeline. q_opt: List[float], optimal user submitted distribution, which sums to 1. - EMD_opt: float, EMD between the ground truth distribution (p) and the (optimized) user submitted distribution (q_opt). + EMD_opt: float, EMD between the ground truth distribution (p) and the (optimized) user submitted distribution (q_opt). The transport plan is a joint distribution, such that: summing over the rows gives the (optimized) user submitted distribution, and summing over the columns gives the ground truth distribution. transport_plan_opt: List[List[float]], transport plan between the ground truth distribution (p, rows) and the (optimized) user submitted distribution (q_opt, columns). @@ -61,6 +62,7 @@ class ReplicateValidatorEMD: The transport plan is a joint distribution, such that: summing over the rows gives the user submitted distribution, and summing over the columns gives the ground truth distribution. """ + q_opt: List[float] EMD_opt: float transport_plan_opt: List[List[float]] @@ -87,8 +89,9 @@ class ReplicateValidatorKL: iter_stop: int, number of iterations until convergence. eps_stop: float, stopping criterion. klpq_submitted: float, KL divergence between the ground truth distribution (p) and the user submitted distribution (q). - klqp_submitted: float, KL divergence between the user submitted distribution (q) and the ground truth distribution (p). + klqp_submitted: float, KL divergence between the user submitted distribution (q) and the ground truth distribution (p). """ + q_opt: List[float] klpq_opt: float klqp_opt: float @@ -106,11 +109,12 @@ def __post_init__(self): @dataclass_json @dataclass class MetricDistToDistValidator: - ''' + """ Validate the output dictionary of one map to map metric in the the distribution-to-distribution pipeline. replicates: dict, dictionary of replicates. - ''' + """ + replicates: dict def validate_replicates(self, n_replicates): @@ -126,7 +130,7 @@ def validate_replicates(self, n_replicates): @dataclass_json @dataclass class DistributionToDistributionResultsValidator: - ''' + """ Validate the output dictionary of the distribution-to-distribution pipeline. config: dict, input config dictionary. @@ -136,7 +140,8 @@ class DistributionToDistributionResultsValidator: bioem: dict, BioEM distance results. l2: dict, L2 distance results. corr: dict, correlation distance results. - ''' + """ + config: dict user_submitted_populations: torch.Tensor id: str diff --git a/src/cryo_challenge/power_spectrum_utils.py b/src/cryo_challenge/power_spectrum_utils.py new file mode 100644 index 0000000..0a6338f --- /dev/null +++ b/src/cryo_challenge/power_spectrum_utils.py @@ -0,0 +1,71 @@ +import torch + + +def _cart2sph(x, y, z): + """ + Converts a grid in cartesian coordinates to spherical coordinates. + + Parameters + ---------- + x: torch.tensor + x-coordinate of the grid. + y: torch.tensor + y-coordinate of the grid. + z: torch.tensor + """ + hxy = torch.hypot(x, y) + r = torch.hypot(hxy, z) + el = torch.atan2(z, hxy) + az = torch.atan2(y, x) + return az, el, r + + +def _grid_3d(n, dtype=torch.float32): + start = -n // 2 + 1 + end = n // 2 + + if n % 2 == 0: + start -= 1 / 2 + end -= 1 / 2 + + grid = torch.linspace(start, end, n, dtype=dtype) + z, x, y = torch.meshgrid(grid, grid, grid, indexing="ij") + + phi, theta, r = _cart2sph(x, y, z) + + theta = torch.pi / 2 - theta + + return {"x": x, "y": y, "z": z, "phi": phi, "theta": theta, "r": r} + + +def _centered_fftn(x, dim=None): + x = torch.fft.fftn(x, dim=dim) + x = torch.fft.fftshift(x, dim=dim) + return x + + +def _centered_ifftn(x, dim=None): + x = torch.fft.fftshift(x, dim=dim) + x = torch.fft.ifftn(x, dim=dim) + return x + + +def _compute_power_spectrum_shell(index, volume, radii, shell_width=0.5): + inner_diameter = shell_width + index + outer_diameter = shell_width + (index + 1) + mask = (radii > inner_diameter) & (radii < outer_diameter) + return torch.sum(mask * volume) / torch.sum(mask) + + +def compute_power_spectrum(volume, shell_width=0.5): + L = volume.shape[0] + dtype = torch.float32 + radii = _grid_3d(L, dtype=dtype)["r"] + + # Compute centered Fourier transforms. + vol_fft = torch.abs(_centered_fftn(volume)) ** 2 + + power_spectrum = torch.vmap( + _compute_power_spectrum_shell, in_dims=(0, None, None, None) + )(torch.arange(0, L // 2), vol_fft, radii, shell_width) + return power_spectrum diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json b/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json index b8318b9..3355706 100644 --- a/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json +++ b/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json @@ -9,7 +9,7 @@ "0": { "name": "raw_submission_in_testdata", "align": 1, - "flavor_name": "test flavor", + "flavor_name": "test flavor", "box_size": 244, "pixel_size": 2.146, "path": "tests/data/unprocessed_dataset_2_submissions/submission_x", diff --git a/tests/test_distribution_to_distribution.py b/tests/test_distribution_to_distribution.py index d9c340b..a4cfb79 100644 --- a/tests/test_distribution_to_distribution.py +++ b/tests/test_distribution_to_distribution.py @@ -2,6 +2,8 @@ from cryo_challenge._commands import run_distribution2distribution_pipeline -def test_run_distribution2distribution_pipeline(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_distribution_to_distribution.yaml'}) - run_distribution2distribution_pipeline.main(args) \ No newline at end of file +def test_run_distribution2distribution_pipeline(): + args = OmegaConf.create( + {"config": "tests/config_files/test_config_distribution_to_distribution.yaml"} + ) + run_distribution2distribution_pipeline.main(args) diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index c782a8c..e31f29f 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -2,6 +2,8 @@ from cryo_challenge._commands import run_map2map_pipeline -def test_run_map2map_pipeline(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_map_to_map.yaml'}) - run_map2map_pipeline.main(args) \ No newline at end of file +def test_run_map2map_pipeline(): + args = OmegaConf.create( + {"config": "tests/config_files/test_config_map_to_map.yaml"} + ) + run_map2map_pipeline.main(args) diff --git a/tests/test_power_spectrum_and_bfactor.py b/tests/test_power_spectrum_and_bfactor.py new file mode 100644 index 0000000..8496ba2 --- /dev/null +++ b/tests/test_power_spectrum_and_bfactor.py @@ -0,0 +1,72 @@ +import torch +from cryo_challenge.power_spectrum_utils import _centered_ifftn, compute_power_spectrum +from cryo_challenge._preprocessing.bfactor_normalize import ( + _compute_bfactor_scaling, + bfactor_normalize_volumes, +) + + +def test_compute_power_spectrum(): + box_size = 224 + volume_shape = (box_size, box_size, box_size) + voxel_size = 1.073 * 2 + + freq = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) + x = freq.clone() + y = freq.clone() + z = freq.clone() + x, y, z = torch.meshgrid(x, y, z, indexing="ij") + + s2 = x**2 + y**2 + z**2 + + b_factor = 170 + + gaussian_volume = torch.exp(-b_factor / 4 * s2).reshape(volume_shape) + gaussian_volume = _centered_ifftn(gaussian_volume) + + power_spectrum = compute_power_spectrum(gaussian_volume) + power_spectrum_slice = ( + torch.abs(torch.fft.fftn(gaussian_volume)[: box_size // 2, 0, 0]) ** 2 + ) + + mean_squared_error = torch.mean((power_spectrum - power_spectrum_slice) ** 2) + + assert mean_squared_error < 1e-3 + + return + + +def test_bfactor_normalize_volumes(): + box_size = 128 + volume_shape = (box_size, box_size, box_size) + voxel_size = 1.5 + + freq = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) + x = freq.clone() + y = freq.clone() + z = freq.clone() + x, y, z = torch.meshgrid(x, y, z, indexing="ij") + + s2 = x**2 + y**2 + z**2 + + oscillatory_volume = torch.sin(300 * s2).reshape(volume_shape) + oscillatory_volume = _centered_ifftn(oscillatory_volume) + bfactor_scaling_vol = _compute_bfactor_scaling(170, box_size, voxel_size) + + norm_oscillatory_vol = bfactor_normalize_volumes( + oscillatory_volume, 170, voxel_size, in_place=False + ) + + ps_osci = torch.fft.fftn(oscillatory_volume, dim=(-3, -2, -1), norm="backward")[ + : box_size // 2, 0, 0 + ] + ps_norm_osci = torch.fft.fftn( + norm_oscillatory_vol, dim=(-3, -2, -1), norm="backward" + )[: box_size // 2, 0, 0] + ps_bfactor_scaling = torch.fft.fftshift(bfactor_scaling_vol)[: box_size // 2, 0, 0] + + ps_osci = torch.abs(ps_osci) ** 2 + ps_norm_osci = torch.abs(ps_norm_osci) ** 2 + ps_bfactor_scaling = torch.abs(ps_bfactor_scaling) ** 2 + + assert torch.allclose(ps_norm_osci, ps_osci * ps_bfactor_scaling) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index cbf54e4..31db34e 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -2,6 +2,6 @@ from cryo_challenge._commands import run_preprocessing -def test_run_preprocessing(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_preproc.yaml'}) - run_preprocessing.main(args) \ No newline at end of file +def test_run_preprocessing(): + args = OmegaConf.create({"config": "tests/config_files/test_config_preproc.yaml"}) + run_preprocessing.main(args) diff --git a/tests/test_svd.py b/tests/test_svd.py index 568370e..ea166ea 100644 --- a/tests/test_svd.py +++ b/tests/test_svd.py @@ -2,6 +2,6 @@ from cryo_challenge._commands import run_svd -def test_run_preprocessing(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_svd.yaml'}) - run_svd.main(args) \ No newline at end of file +def test_run_preprocessing(): + args = OmegaConf.create({"config": "tests/config_files/test_config_svd.yaml"}) + run_svd.main(args) diff --git a/tutorials/1_tutorial_preprocessing.ipynb b/tutorials/1_tutorial_preprocessing.ipynb index cc6a459..84db8c9 100644 --- a/tutorials/1_tutorial_preprocessing.ipynb +++ b/tutorials/1_tutorial_preprocessing.ipynb @@ -203,7 +203,7 @@ "# Select path to Config file\n", "# An example of this file is available in the path ../config_files/config_preproc.yaml\n", "config_preproc_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_preproc_path.filter_pattern = '*.yaml'\n", + "config_preproc_path.filter_pattern = \"*.yaml\"\n", "display(config_preproc_path)" ] }, @@ -226,7 +226,7 @@ "if os.path.isabs(output_path):\n", " full_output_path = output_path\n", "else:\n", - " full_output_path = os.path.join(os.getcwd(), '..', output_path)" + " full_output_path = os.path.join(os.getcwd(), \"..\", output_path)" ] }, { diff --git a/tutorials/2_tutorial_svd.ipynb b/tutorials/2_tutorial_svd.ipynb index b41bfba..fe8f432 100644 --- a/tutorials/2_tutorial_svd.ipynb +++ b/tutorials/2_tutorial_svd.ipynb @@ -62,7 +62,7 @@ "# Select path to SVD config file\n", "# An example of this file is available in the path ../config_files/config_svd.yaml\n", "config_svd_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_svd_path.filter_pattern = '*.yaml'\n", + "config_svd_path.filter_pattern = \"*.yaml\"\n", "display(config_svd_path)" ] }, @@ -125,7 +125,7 @@ "source": [ "# Select path to SVD results\n", "svd_results_path = FileChooser(os.path.expanduser(\"~\"))\n", - "svd_results_path.filter_pattern = '*.pt'\n", + "svd_results_path.filter_pattern = \"*.pt\"\n", "display(svd_results_path)" ] }, @@ -316,7 +316,7 @@ "source": [ "# Select path to SVD results\n", "svd_all_vs_all_results_path = FileChooser(os.path.expanduser(\"~\"))\n", - "svd_all_vs_all_results_path.filter_pattern = '*.pt'\n", + "svd_all_vs_all_results_path.filter_pattern = \"*.pt\"\n", "display(svd_all_vs_all_results_path)" ] }, diff --git a/tutorials/3_tutorial_map2map.ipynb b/tutorials/3_tutorial_map2map.ipynb index a3701ff..0497578 100644 --- a/tutorials/3_tutorial_map2map.ipynb +++ b/tutorials/3_tutorial_map2map.ipynb @@ -23,15 +23,8 @@ "\n", "from cryo_challenge.data._validation.config_validators import (\n", " validate_input_config_mtm,\n", - " validate_config_mtm_data, \n", - " validate_config_mtm_data_submission, \n", - " validate_config_mtm_data_ground_truth, \n", - " validate_config_mtm_data_mask, \n", - " validate_config_mtm_analysis, \n", - " validate_config_mtm_analysis_normalize, \n", - " )\n", - "from cryo_challenge.data._validation.output_validators import MapToMapResultsValidator\n", - "from cryo_challenge.data._validation.config_validators import validate_maptomap_result" + ")\n", + "from cryo_challenge.data._validation.output_validators import MapToMapResultsValidator" ] }, { @@ -80,7 +73,7 @@ "# Select path to Map to Map config file\n", "# An example of this file is available in the path ../config_files/config_map_to_map.yaml\n", "config_m2m_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_m2m_path.filter_pattern = '*.yaml'\n", + "config_m2m_path.filter_pattern = \"*.yaml\"\n", "display(config_m2m_path)" ] }, @@ -341,8 +334,8 @@ "metadata": {}, "outputs": [], "source": [ - "with open(os.path.join('../',config[\"output\"]), \"rb\") as f:\n", - " results_dict = pickle.load(f)\n" + "with open(os.path.join(\"../\", config[\"output\"]), \"rb\") as f:\n", + " results_dict = pickle.load(f)" ] }, { diff --git a/tutorials/4_tutorial_distribution2distribution.ipynb b/tutorials/4_tutorial_distribution2distribution.ipynb index 07dc1d9..2a6fc53 100644 --- a/tutorials/4_tutorial_distribution2distribution.ipynb +++ b/tutorials/4_tutorial_distribution2distribution.ipynb @@ -30,7 +30,7 @@ "import pickle\n", "from ipyfilechooser import FileChooser\n", "\n", - "from cryo_challenge.data import validate_input_config_disttodist, validate_config_dtd_optimal_q_kl\n", + "from cryo_challenge.data import validate_input_config_disttodist\n", "from cryo_challenge.data import DistributionToDistributionResultsValidator" ] }, @@ -65,7 +65,7 @@ "# Select path to Distribution to Distribution config file\n", "# An example of this file is available in the path ../config_files/config_distribution_to_distribution.yaml\n", "config_d2d_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_d2d_path.filter_pattern = '*.yaml'\n", + "config_d2d_path.filter_pattern = \"*.yaml\"\n", "display(config_d2d_path)" ] }, @@ -199,8 +199,8 @@ "metadata": {}, "outputs": [], "source": [ - "with open(os.path.join('../',config[\"output_fname\"]), \"rb\") as f:\n", - " results_dict = pickle.load(f)\n" + "with open(os.path.join(\"../\", config[\"output_fname\"]), \"rb\") as f:\n", + " results_dict = pickle.load(f)" ] }, { @@ -286,7 +286,6 @@ } ], "source": [ - "from cryo_challenge.data import MetricDistToDistValidator\n", "MetricDistToDistValidator?" ] }, @@ -302,9 +301,7 @@ "execution_count": 30, "metadata": {}, "outputs": [], - "source": [ - "from cryo_challenge.data import ReplicateValidatorEMD, ReplicateValidatorKL" - ] + "source": [] }, { "cell_type": "code", diff --git a/tutorials/5_tutorial_plotting.ipynb b/tutorials/5_tutorial_plotting.ipynb index ed8a924..e2648c8 100644 --- a/tutorials/5_tutorial_plotting.ipynb +++ b/tutorials/5_tutorial_plotting.ipynb @@ -23,7 +23,9 @@ "metadata": {}, "outputs": [], "source": [ - "from cryo_challenge._distribution_to_distribution.distribution_to_distribution import sort_by_transport\n", + "from cryo_challenge._distribution_to_distribution.distribution_to_distribution import (\n", + " sort_by_transport,\n", + ")\n", "from cryo_challenge._ploting.plotting_utils import res_at_fsc_threshold\n", "\n", "from dataclasses import dataclass\n", @@ -102,6 +104,7 @@ " map2map_results: List[str]\n", " dist2dist_results: Dict[str, Dict[str, List[str]]]\n", "\n", + "\n", "with open(path_to_config, \"r\") as file:\n", " config = yaml.safe_load(file)\n", "config = PlottingConfig.from_dict(config)\n", @@ -115,7 +118,7 @@ "outputs": [], "source": [ "metadata_df = pd.read_csv(config.gt_metadata)\n", - "metadata_df.sort_values('pc1', inplace=True)\n", + "metadata_df.sort_values(\"pc1\", inplace=True)\n", "gt_ordering = metadata_df.index.tolist()" ] }, @@ -136,12 +139,14 @@ " data_d = {}\n", " for fname in fnames:\n", " if fname not in data_d.keys():\n", - " with open(fname, 'rb') as f:\n", + " with open(fname, \"rb\") as f:\n", " data = pickle.load(f)\n", - " anonymous_label = data[map2map_distance]['user_submission_label']\n", + " anonymous_label = data[map2map_distance][\"user_submission_label\"]\n", " data_d[anonymous_label] = data\n", " return data_d\n", - "data_d = get_fsc_distances(config.map2map_results, 'fsc')" + "\n", + "\n", + "data_d = get_fsc_distances(config.map2map_results, \"fsc\")" ] }, { @@ -162,22 +167,24 @@ ], "source": [ "def plot_fsc_distances(data_d, gt_ordering):\n", - "\n", " smaller_fontsize = 20\n", " larger_fontsize = 30\n", " n_plts = 12\n", " vmin, vmax = np.inf, -np.inf\n", "\n", - " fig, axis = plt.subplots(3,n_plts//3,\n", - " figsize=(40,20),\n", - " # dpi=100,\n", - " )\n", - " fig.suptitle(r'$d_{FSC}$', y=0.95, fontsize=larger_fontsize)\n", + " fig, axis = plt.subplots(\n", + " 3,\n", + " n_plts // 3,\n", + " figsize=(40, 20),\n", + " # dpi=100,\n", + " )\n", + " fig.suptitle(r\"$d_{FSC}$\", y=0.95, fontsize=larger_fontsize)\n", "\n", " for idx, (anonymous_label, data) in enumerate(data_d.items()):\n", - " map2map_dist_matrix = data['fsc']['cost_matrix'].iloc[gt_ordering].values\n", - " sorted_map2map_dist_matrix, indices, transport = sort_by_transport(map2map_dist_matrix)\n", - "\n", + " map2map_dist_matrix = data[\"fsc\"][\"cost_matrix\"].iloc[gt_ordering].values\n", + " sorted_map2map_dist_matrix, indices, transport = sort_by_transport(\n", + " map2map_dist_matrix\n", + " )\n", "\n", " ncols = 4\n", " if map2map_dist_matrix.min() < vmin:\n", @@ -185,14 +192,24 @@ " if map2map_dist_matrix.max() > vmax:\n", " vmax = map2map_dist_matrix.max()\n", "\n", - " ax = axis[idx//ncols,idx%ncols].imshow(sorted_map2map_dist_matrix, aspect='auto', cmap='Blues_r', vmin=vmin, vmax=vmax)\n", - "\n", - "\n", - " axis[idx//ncols,idx%ncols].tick_params(axis='both', labelsize=smaller_fontsize)\n", + " ax = axis[idx // ncols, idx % ncols].imshow(\n", + " sorted_map2map_dist_matrix,\n", + " aspect=\"auto\",\n", + " cmap=\"Blues_r\",\n", + " vmin=vmin,\n", + " vmax=vmax,\n", + " )\n", + "\n", + " axis[idx // ncols, idx % ncols].tick_params(\n", + " axis=\"both\", labelsize=smaller_fontsize\n", + " )\n", " cbar = fig.colorbar(ax)\n", " cbar.ax.tick_params(labelsize=smaller_fontsize)\n", " plot_panel_label = anonymous_label\n", - " axis[idx//ncols,idx%ncols].set_title(plot_panel_label, fontsize=smaller_fontsize)\n", + " axis[idx // ncols, idx % ncols].set_title(\n", + " plot_panel_label, fontsize=smaller_fontsize\n", + " )\n", + "\n", "\n", "plot_fsc_distances(data_d, gt_ordering)" ] @@ -214,12 +231,13 @@ " data_d = {}\n", " for fname in fnames:\n", " if fname not in data_d.keys():\n", - " with open(fname, 'rb') as f:\n", + " with open(fname, \"rb\") as f:\n", " data = pickle.load(f)\n", - " anonymous_label = data['fsc']['user_submission_label']\n", - " data_d[anonymous_label] = data['fsc']['computed_assets']['fsc_matrix']\n", + " anonymous_label = data[\"fsc\"][\"user_submission_label\"]\n", + " data_d[anonymous_label] = data[\"fsc\"][\"computed_assets\"][\"fsc_matrix\"]\n", " return data_d\n", "\n", + "\n", "fscs_sorted_d = get_full_fsc_curve(config.map2map_results)" ] }, @@ -229,9 +247,11 @@ "metadata": {}, "outputs": [], "source": [ - "res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fscs_sorted_d['Cookie Dough'], threshold=0.5)\n", - "n_fourier_bins = fscs_sorted_d['Cookie Dough'].shape[-1]\n", - "units_Angstroms = 2 * 2.146 / (np.arange(1,n_fourier_bins+1) / n_fourier_bins)" + "res_fsc_half, fraction_nyquist = res_at_fsc_threshold(\n", + " fscs_sorted_d[\"Cookie Dough\"], threshold=0.5\n", + ")\n", + "n_fourier_bins = fscs_sorted_d[\"Cookie Dough\"].shape[-1]\n", + "units_Angstroms = 2 * 2.146 / (np.arange(1, n_fourier_bins + 1) / n_fourier_bins)" ] }, { @@ -261,7 +281,7 @@ } ], "source": [ - "plt.imshow(units_Angstroms[res_fsc_half][gt_ordering], aspect='auto', cmap='Blues_r')\n", + "plt.imshow(units_Angstroms[res_fsc_half][gt_ordering], aspect=\"auto\", cmap=\"Blues_r\")\n", "plt.colorbar()" ] }, @@ -283,17 +303,18 @@ ], "source": [ "def plot_res_at_fsc_threshold_distances(fscs_sorted_d, gt_ordering, overwrite_dict={}):\n", - "\n", " smaller_fontsize = 20\n", " larger_fontsize = 30\n", " n_plts = 12\n", " vmin, vmax = np.inf, -np.inf\n", "\n", - " fig, axis = plt.subplots(3,n_plts//3,\n", - " figsize=(40,20),\n", - " # dpi=100,\n", - " )\n", - " fig.suptitle(r'Resolution $(\\AA)$ at $FSC=0.5$', y=0.95, fontsize=larger_fontsize)\n", + " fig, axis = plt.subplots(\n", + " 3,\n", + " n_plts // 3,\n", + " figsize=(40, 20),\n", + " # dpi=100,\n", + " )\n", + " fig.suptitle(r\"Resolution $(\\AA)$ at $FSC=0.5$\", y=0.95, fontsize=larger_fontsize)\n", "\n", " for idx, (anonymous_label, fscs) in enumerate(fscs_sorted_d.items()):\n", " # map2map_dist_matrix = data.iloc[gt_ordering].values\n", @@ -302,27 +323,38 @@ "\n", " sorted_map2map_dist_matrix, _, _ = sort_by_transport(map2map_dist_matrix)\n", "\n", - "\n", " ncols = 4\n", " if map2map_dist_matrix.min() < vmin:\n", " vmin = map2map_dist_matrix.min()\n", " if map2map_dist_matrix.max() > vmax:\n", " vmax = map2map_dist_matrix.max()\n", - " if 'vmax' in overwrite_dict.keys():\n", - " vmax = overwrite_dict['vmax']\n", - " if 'vmin' in overwrite_dict.keys():\n", - " vmin = overwrite_dict['vmin']\n", - "\n", - " ax = axis[idx//ncols,idx%ncols].imshow(sorted_map2map_dist_matrix, aspect='auto', cmap='Blues_r', vmin=vmin, vmax=vmax)\n", - "\n", - "\n", - " axis[idx//ncols,idx%ncols].tick_params(axis='both', labelsize=smaller_fontsize)\n", + " if \"vmax\" in overwrite_dict.keys():\n", + " vmax = overwrite_dict[\"vmax\"]\n", + " if \"vmin\" in overwrite_dict.keys():\n", + " vmin = overwrite_dict[\"vmin\"]\n", + "\n", + " ax = axis[idx // ncols, idx % ncols].imshow(\n", + " sorted_map2map_dist_matrix,\n", + " aspect=\"auto\",\n", + " cmap=\"Blues_r\",\n", + " vmin=vmin,\n", + " vmax=vmax,\n", + " )\n", + "\n", + " axis[idx // ncols, idx % ncols].tick_params(\n", + " axis=\"both\", labelsize=smaller_fontsize\n", + " )\n", " cbar = fig.colorbar(ax)\n", " cbar.ax.tick_params(labelsize=smaller_fontsize)\n", " plot_panel_label = anonymous_label\n", - " axis[idx//ncols,idx%ncols].set_title(plot_panel_label, fontsize=smaller_fontsize)\n", + " axis[idx // ncols, idx % ncols].set_title(\n", + " plot_panel_label, fontsize=smaller_fontsize\n", + " )\n", + "\n", "\n", - "plot_res_at_fsc_threshold_distances(fscs_sorted_d, gt_ordering, overwrite_dict={'vmax': 31})" + "plot_res_at_fsc_threshold_distances(\n", + " fscs_sorted_d, gt_ordering, overwrite_dict={\"vmax\": 31}\n", + ")" ] }, { @@ -338,9 +370,9 @@ "metadata": {}, "outputs": [], "source": [ - "fname = config.dist2dist_results['prob_submitted_plot']['pkl_fnames'][0]\n", + "fname = config.dist2dist_results[\"prob_submitted_plot\"][\"pkl_fnames\"][0]\n", "\n", - "with open(fname, 'rb') as f:\n", + "with open(fname, \"rb\") as f:\n", " data = pickle.load(f)" ] }, @@ -354,13 +386,16 @@ " data_d = {}\n", " for fname in fnames:\n", " if fname not in data_d.keys():\n", - " with open(fname, 'rb') as f:\n", + " with open(fname, \"rb\") as f:\n", " data = pickle.load(f)\n", - " anonymous_label = data['id']\n", + " anonymous_label = data[\"id\"]\n", " data_d[anonymous_label] = data\n", " return data_d\n", "\n", - "dist2dist_results_d = get_dist2dist_results(config.dist2dist_results['prob_submitted_plot']['pkl_fnames'])" + "\n", + "dist2dist_results_d = get_dist2dist_results(\n", + " config.dist2dist_results[\"prob_submitted_plot\"][\"pkl_fnames\"]\n", + ")" ] }, { @@ -381,46 +416,58 @@ ], "source": [ "window_size = 15\n", - "nrows, ncols = 3,4\n", + "nrows, ncols = 3, 4\n", "suptitle = f\"Submitted populations vs optimal populations \\n d_FSC (no rank) | n_replicates={data['config']['n_replicates']} | window_size={window_size} | n_pool_microstate={data['config']['n_pool_microstate']}\"\n", "\n", + "\n", "def plot_q_opt_distances(dist2dist_results_d, suptitle, nrows, ncols):\n", + " fig, axes = plt.subplots(nrows, ncols, figsize=(40, 25))\n", "\n", - " fig, axes = plt.subplots(nrows, ncols, figsize=(40,25))\n", - " \n", - " fig.suptitle(\n", - " suptitle,\n", - " fontsize=30,\n", - " y=0.95)\n", + " fig.suptitle(suptitle, fontsize=30, y=0.95)\n", " alpha = 0.05\n", "\n", - " for idx_fname, (_,data) in enumerate(dist2dist_results_d.items()):\n", - " \n", + " for idx_fname, (_, data) in enumerate(dist2dist_results_d.items()):\n", + " axes[idx_fname // ncols, idx_fname % ncols].plot(\n", + " data[\"user_submitted_populations\"], color=\"black\", label=\"submited\"\n", + " )\n", + " axes[idx_fname // ncols, idx_fname % ncols].set_title(data[\"id\"], fontsize=30)\n", "\n", - " axes[idx_fname//ncols, idx_fname%ncols].plot(data['user_submitted_populations'], color='black', label='submited')\n", - " axes[idx_fname//ncols, idx_fname%ncols].set_title(data['id'], fontsize=30)\n", - "\n", - " def window_q(q_opt,window_size):\n", - " running_avg = np.convolve(q_opt, np.ones(window_size)/window_size, mode='same')\n", + " def window_q(q_opt, window_size):\n", + " running_avg = np.convolve(\n", + " q_opt, np.ones(window_size) / window_size, mode=\"same\"\n", + " )\n", " return running_avg\n", - " \n", - " for replicate_idx in range(data['config']['n_replicates']):\n", - " \n", - " if replicate_idx == 0: \n", - " label_d = {'EMD': 'EMD', 'KL': 'KL', 'KL_raw': 'Unwindowed', 'EMD_raw': 'Unwindowed'}\n", - " else:\n", - " label_d = {'EMD': None, 'KL': None, 'KL_raw': None, 'EMD_raw': None}\n", - " axes[idx_fname//ncols, idx_fname%ncols].plot(window_q(data['fsc']['replicates'][replicate_idx]['EMD']['q_opt'],window_size), color='blue', alpha=alpha, label=label_d['EMD'])\n", - "\n", - " axes[idx_fname//ncols, idx_fname%ncols].set_xlabel('Submission index')\n", - " axes[idx_fname//ncols, idx_fname%ncols].set_ylabel('Population')\n", "\n", - " legend = axes[idx_fname//ncols, idx_fname%ncols].legend()\n", + " for replicate_idx in range(data[\"config\"][\"n_replicates\"]):\n", + " if replicate_idx == 0:\n", + " label_d = {\n", + " \"EMD\": \"EMD\",\n", + " \"KL\": \"KL\",\n", + " \"KL_raw\": \"Unwindowed\",\n", + " \"EMD_raw\": \"Unwindowed\",\n", + " }\n", + " else:\n", + " label_d = {\"EMD\": None, \"KL\": None, \"KL_raw\": None, \"EMD_raw\": None}\n", + " axes[idx_fname // ncols, idx_fname % ncols].plot(\n", + " window_q(\n", + " data[\"fsc\"][\"replicates\"][replicate_idx][\"EMD\"][\"q_opt\"],\n", + " window_size,\n", + " ),\n", + " color=\"blue\",\n", + " alpha=alpha,\n", + " label=label_d[\"EMD\"],\n", + " )\n", + "\n", + " axes[idx_fname // ncols, idx_fname % ncols].set_xlabel(\"Submission index\")\n", + " axes[idx_fname // ncols, idx_fname % ncols].set_ylabel(\"Population\")\n", + "\n", + " legend = axes[idx_fname // ncols, idx_fname % ncols].legend()\n", " for line, text in zip(legend.get_lines(), legend.get_texts()):\n", " text.set_color(line.get_color())\n", " line.set_alpha(1)\n", "\n", - "plot_q_opt_distances(dist2dist_results_d, suptitle,nrows, ncols)" + "\n", + "plot_q_opt_distances(dist2dist_results_d, suptitle, nrows, ncols)" ] }, { @@ -437,7 +484,6 @@ "outputs": [], "source": [ "def wragle_pkl_to_dataframe(pkl_globs):\n", - "\n", " fnames = []\n", " for fname_glob in pkl_globs:\n", " fnames.extend(glob.glob(fname_glob))\n", @@ -446,28 +492,50 @@ "\n", " df_list = []\n", " n_replicates = 30\n", - " metric = 'fsc'\n", + " metric = \"fsc\"\n", "\n", " for fname in fnames:\n", - " with open(fname, 'rb') as f:\n", + " with open(fname, \"rb\") as f:\n", " data = pickle.load(f)\n", "\n", - " df_list.append(pd.DataFrame({\n", - " 'EMD_opt': [data[metric]['replicates'][i]['EMD']['EMD_opt'] for i in range(n_replicates)],\n", - " 'EMD_submitted': [data[metric]['replicates'][i]['EMD']['EMD_submitted'] for i in range(n_replicates)],\n", - " 'klpq_opt': [data[metric]['replicates'][i]['KL']['klpq_opt'] for i in range(n_replicates)],\n", - " 'klqp_opt': [data[metric]['replicates'][i]['KL']['klqp_opt'] for i in range(n_replicates)],\n", - " 'klpq_submitted': [data[metric]['replicates'][i]['KL']['klpq_submitted'] for i in range(n_replicates)], \n", - " 'klqp_submitted': [data[metric]['replicates'][i]['KL']['klqp_submitted'] for i in range(n_replicates)], \n", - " 'id': data['id'],\n", - " 'n_pool_microstate': data['config']['n_pool_microstate'],\n", - " }))\n", + " df_list.append(\n", + " pd.DataFrame(\n", + " {\n", + " \"EMD_opt\": [\n", + " data[metric][\"replicates\"][i][\"EMD\"][\"EMD_opt\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"EMD_submitted\": [\n", + " data[metric][\"replicates\"][i][\"EMD\"][\"EMD_submitted\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"klpq_opt\": [\n", + " data[metric][\"replicates\"][i][\"KL\"][\"klpq_opt\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"klqp_opt\": [\n", + " data[metric][\"replicates\"][i][\"KL\"][\"klqp_opt\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"klpq_submitted\": [\n", + " data[metric][\"replicates\"][i][\"KL\"][\"klpq_submitted\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"klqp_submitted\": [\n", + " data[metric][\"replicates\"][i][\"KL\"][\"klqp_submitted\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"id\": data[\"id\"],\n", + " \"n_pool_microstate\": data[\"config\"][\"n_pool_microstate\"],\n", + " }\n", + " )\n", + " )\n", "\n", " df = pd.concat(df_list)\n", - " df['EMD_opt_norm'] = df['EMD_opt'] / df['n_pool_microstate']\n", - " df['EMD_submitted_norm'] = df['EMD_submitted'] / df['n_pool_microstate']\n", + " df[\"EMD_opt_norm\"] = df[\"EMD_opt\"] / df[\"n_pool_microstate\"]\n", + " df[\"EMD_submitted_norm\"] = df[\"EMD_submitted\"] / df[\"n_pool_microstate\"]\n", "\n", - " return df\n" + " return df" ] }, { @@ -476,7 +544,7 @@ "metadata": {}, "outputs": [], "source": [ - "df = wragle_pkl_to_dataframe(config.dist2dist_results['emd_plot']['pkl_globs'])" + "df = wragle_pkl_to_dataframe(config.dist2dist_results[\"emd_plot\"][\"pkl_globs\"])" ] }, { @@ -498,50 +566,96 @@ "source": [ "def plot_EMD_vs_EMDopt(df, suptitle=None):\n", " alpha = 1\n", - " df_average = df.groupby(['n_pool_microstate','id']).mean().reset_index()\n", - " df_std = df_average.groupby(['id']).std().reset_index().filter(['EMD_opt_norm','EMD_submitted_norm', 'id']).rename(columns={'EMD_opt_norm':'EMD_opt_norm_std', 'EMD_submitted_norm':'EMD_submitted_norm_std'})\n", - " df_average = df.groupby(['id']).mean().reset_index()\n", - "\n", - " df_average_and_error = pd.merge(df_average, df_std, on='id')\n", + " df_average = df.groupby([\"n_pool_microstate\", \"id\"]).mean().reset_index()\n", + " df_std = (\n", + " df_average.groupby([\"id\"])\n", + " .std()\n", + " .reset_index()\n", + " .filter([\"EMD_opt_norm\", \"EMD_submitted_norm\", \"id\"])\n", + " .rename(\n", + " columns={\n", + " \"EMD_opt_norm\": \"EMD_opt_norm_std\",\n", + " \"EMD_submitted_norm\": \"EMD_submitted_norm_std\",\n", + " }\n", + " )\n", + " )\n", + " df_average = df.groupby([\"id\"]).mean().reset_index()\n", + "\n", + " df_average_and_error = pd.merge(df_average, df_std, on=\"id\")\n", "\n", " # Get unique ids\n", - " ids = df_average_and_error['id'].unique()\n", + " ids = df_average_and_error[\"id\"].unique()\n", "\n", " # Define marker styles\n", - " markers = ['o', 'v', '^', '<', '>', 's', 'p', '*', 'h', 'H', '+', 'x', 'D', 'd', '|', '_']\n", + " markers = [\n", + " \"o\",\n", + " \"v\",\n", + " \"^\",\n", + " \"<\",\n", + " \">\",\n", + " \"s\",\n", + " \"p\",\n", + " \"*\",\n", + " \"h\",\n", + " \"H\",\n", + " \"+\",\n", + " \"x\",\n", + " \"D\",\n", + " \"d\",\n", + " \"|\",\n", + " \"_\",\n", + " ]\n", " marker_size = 250\n", "\n", - " plt.style.use('seaborn-v0_8-poster')\n", - " plot_width, plot_height = 8,6\n", + " plt.style.use(\"seaborn-v0_8-poster\")\n", + " plot_width, plot_height = 8, 6\n", " plt.figure(figsize=(plot_width, plot_height), dpi=300)\n", "\n", " # Create a scatter plot for each id\n", " for idx, id_label in enumerate(ids):\n", - " df_average_id = df_average[df_average['id'] == id_label]\n", - " sns.scatterplot(x='EMD_submitted_norm', y='EMD_opt_norm', data=df_average_id, alpha=alpha, marker=markers[idx % len(markers)], label=id_label, s=marker_size)\n", - "\n", - " plt.errorbar(x=df_average_and_error['EMD_submitted_norm'], \n", - " y=df_average_and_error['EMD_opt_norm'], \n", - " xerr=df_average_and_error['EMD_submitted_norm_std'], \n", - " yerr=df_average_and_error['EMD_opt_norm_std'], \n", - " fmt='', alpha=0.05, linestyle='None', ecolor='k', elinewidth=2, capsize=5)\n", + " df_average_id = df_average[df_average[\"id\"] == id_label]\n", + " sns.scatterplot(\n", + " x=\"EMD_submitted_norm\",\n", + " y=\"EMD_opt_norm\",\n", + " data=df_average_id,\n", + " alpha=alpha,\n", + " marker=markers[idx % len(markers)],\n", + " label=id_label,\n", + " s=marker_size,\n", + " )\n", + "\n", + " plt.errorbar(\n", + " x=df_average_and_error[\"EMD_submitted_norm\"],\n", + " y=df_average_and_error[\"EMD_opt_norm\"],\n", + " xerr=df_average_and_error[\"EMD_submitted_norm_std\"],\n", + " yerr=df_average_and_error[\"EMD_opt_norm_std\"],\n", + " fmt=\"\",\n", + " alpha=0.05,\n", + " linestyle=\"None\",\n", + " ecolor=\"k\",\n", + " elinewidth=2,\n", + " capsize=5,\n", + " )\n", "\n", " plt.xlim(left=0.5)\n", " plt.ylim(bottom=0.5)\n", "\n", - " limits = [np.min([plt.xlim(), plt.ylim()]), # min of both axes\n", - " np.max([plt.xlim(), plt.ylim()])] # max of both axes\n", + " limits = [\n", + " np.min([plt.xlim(), plt.ylim()]), # min of both axes\n", + " np.max([plt.xlim(), plt.ylim()]),\n", + " ] # max of both axes\n", "\n", - " plt.plot(limits, limits, 'k-', alpha=0.75, zorder=0)\n", + " plt.plot(limits, limits, \"k-\", alpha=0.75, zorder=0)\n", " plt.xlim(limits)\n", " plt.ylim(limits)\n", - " legend = plt.legend(loc='upper left', fontsize=12)\n", + " legend = plt.legend(loc=\"upper left\", fontsize=12)\n", " for handle in legend.legend_handles:\n", " handle.set_alpha(1)\n", "\n", " plt.suptitle(suptitle)\n", "\n", - "suptitle = r'$d_{FSC}$ (no rank)'\n", + "\n", + "suptitle = r\"$d_{FSC}$ (no rank)\"\n", "plot_EMD_vs_EMDopt(df, suptitle=None)" ] } diff --git a/tutorials/README.md b/tutorials/README.md index d427ba5..842bdf7 100644 --- a/tutorials/README.md +++ b/tutorials/README.md @@ -25,4 +25,4 @@ This notebook walks through generating and analyzing (plots) the map to map dist - output: a `.pkl` file ### `5_tutorial_plotting.ipynb` -This notebook walks through parsing and analyzing (plots) the map to map and distribution to distribution results. \ No newline at end of file +This notebook walks through parsing and analyzing (plots) the map to map and distribution to distribution results.