diff --git a/.github/workflows/main_merge_check.yml b/.github/workflows/main_merge_check.yml
index 8e6f5c5..b4aa6e1 100644
--- a/.github/workflows/main_merge_check.yml
+++ b/.github/workflows/main_merge_check.yml
@@ -11,4 +11,4 @@ jobs:
       if: github.base_ref == 'main' && github.head_ref != 'dev'
       run: |
         echo "ERROR: You can only merge to main from dev."
-        exit 1
\ No newline at end of file
+        exit 1
diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
new file mode 100644
index 0000000..1b4ed47
--- /dev/null
+++ b/.github/workflows/ruff.yml
@@ -0,0 +1,12 @@
+# Runs the Ruff linter and formatter.
+
+name: Lint
+
+on: [push]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: chartboost/ruff-action@v1
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index adc0ebb..4100565 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -12,8 +12,28 @@ The "-e" flag will install the package in editable mode, which means you can edi
 
 ## Things to do before pushing to GitHub
 
-In this project we use Ruff for linting, and pre-commit to make sure that the code being pushed is not broken or goes against PEP8 guidelines. When you run `git commit` the pre-commit pipeline should rune automatically. In the near future we will start using pytest and mypy to perform more checks.
+### Using pre-commit hooks for code formatting and linting
+When you install in developer mode with `".[dev]"` you will also install the [pre-commit](https://pre-commit.com/) package. To set up this package, simply run
+
+```bash
+pre-commit install
+```
+
+Then, every time before committing (that is, before `git add` and `git commit`), run the following command:
+
+```bash
+pre-commit run --all-files
+```
+
+This will run `ruff` linting and formatting. If anything cannot be fixed automatically, the command will report the file and line that need to be fixed before you can commit. Once you have fixed everything, you will be able to run `git add` and `git commit` without issue.
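+
+As a minimal illustration, here is a hypothetical snippet (not from this repository) with the kind of problem `ruff` cannot fix automatically; the hook reports the offending file and line so you can correct it by hand:
+
+```python
+# example.py -- a made-up file, shown only to illustrate a lint finding
+def total(items):
+    # Typo below: Ruff reports F821 (undefined name 'itms') and cannot auto-fix it.
+    return sum(item.price for item in itms)
+```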
+ + +### Make sure tests run + +```bash +python -m pytest tests/ +``` ## Best practices for contributing diff --git a/config_files/config_distribution_to_distribution.yaml b/config_files/config_distribution_to_distribution.yaml index da8b3d9..eaa94ac 100644 --- a/config_files/config_distribution_to_distribution.yaml +++ b/config_files/config_distribution_to_distribution.yaml @@ -12,4 +12,4 @@ cvxpy_solver: ECOS optimal_q_kl: n_iter: 100000 break_atol: 0.0001 -output_fname: results/distribution_to_distribution_submission_0.pkl \ No newline at end of file +output_fname: results/distribution_to_distribution_submission_0.pkl diff --git a/config_files/config_map_to_map.yaml b/config_files/config_map_to_map.yaml index 4302227..5a98e28 100644 --- a/config_files/config_map_to_map.yaml +++ b/config_files/config_map_to_map.yaml @@ -1,15 +1,15 @@ data: n_pix: 224 - psize: 2.146 + psize: 2.146 submission: fname: data/dataset_2_ground_truth/submission_0.pt volume_key: volumes metadata_key: populations label_key: id ground_truth: - volumes: data/dataset_2_ground_truth/maps_gt_flat.pt - metadata: data/dataset_2_ground_truth/metadata.csv - mask: + volumes: data/dataset_2_ground_truth/maps_gt_flat.pt + metadata: data/dataset_2_ground_truth/metadata.csv + mask: do: true volume: data/dataset_2_ground_truth/mask_dilated_wide_224x224.mrc analysis: @@ -23,4 +23,4 @@ analysis: normalize: do: true method: median_zscore -output: results/map_to_map_distance_matrix_submission_0.pkl \ No newline at end of file +output: results/map_to_map_distance_matrix_submission_0.pkl diff --git a/src/cryo_challenge/__init__.py b/src/cryo_challenge/__init__.py index cafea4e..934c6a8 100644 --- a/src/cryo_challenge/__init__.py +++ b/src/cryo_challenge/__init__.py @@ -1 +1,3 @@ -from cryo_challenge.__about__ import __version__ \ No newline at end of file +from cryo_challenge.__about__ import __version__ + +__all__ = ["__version__"] diff --git a/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py b/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py index 70961d8..18c57bb 100644 --- a/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py +++ b/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py @@ -2,8 +2,6 @@ import numpy as np import pickle from scipy.stats import rankdata -import yaml -import argparse import torch import ot @@ -14,10 +12,12 @@ def sort_by_transport(cost): - m,n = cost.shape - _, transport = compute_wasserstein_between_distributions_from_weights_and_cost(np.ones(m) / m, np.ones(n)/n, cost) - indices = np.argsort((transport * np.arange(m)[...,None]).sum(0)) - return cost[:,indices], indices, transport + m, n = cost.shape + _, transport = compute_wasserstein_between_distributions_from_weights_and_cost( + np.ones(m) / m, np.ones(n) / n, cost + ) + indices = np.argsort((transport * np.arange(m)[..., None]).sum(0)) + return cost[:, indices], indices, transport def compute_wasserstein_between_distributions_from_weights_and_cost( @@ -65,7 +65,6 @@ def make_assignment_matrix(cost_matrix): def run(config): - metadata_df = pd.read_csv(config["gt_metadata_fname"]) metadata_df.sort_values("pc1", inplace=True) @@ -73,7 +72,7 @@ def run(config): data = pickle.load(f) # user_submitted_populations = np.ones(80)/80 - user_submitted_populations = data["user_submitted_populations"]#.numpy() + user_submitted_populations = data["user_submitted_populations"] # .numpy() id = 
torch.load(data["config"]["data"]["submission"]["fname"])["id"] results_dict = {} @@ -213,5 +212,5 @@ def optimal_q_kl(n_iter, x_start, A, Window, prob_gt, break_atol): DistributionToDistributionResultsValidator.from_dict(results_dict) with open(config["output_fname"], "wb") as f: pickle.dump(results_dict, f) - + return results_dict diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 55d01d3..1e0b71d 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -1,9 +1,10 @@ import math import torch from typing import Optional, Sequence -from typing_extensions import override +from typing_extensions import override import mrcfile + class MapToMapDistance: def __init__(self, config): self.config = config @@ -25,30 +26,33 @@ def get_distance_matrix(self, maps1, maps2): )(maps1) return distance_matrix - + def get_computed_assets(self, maps1, maps2): """Return any computed assets that are needed for (downstream) analysis.""" return {} + class L2DistanceNorm(MapToMapDistance): def __init__(self, config): super().__init__(config) @override def get_distance(self, map1, map2): - return torch.norm(map1 - map2)**2 - + return torch.norm(map1 - map2) ** 2 + + class L2DistanceSum(MapToMapDistance): def __init__(self, config): super().__init__(config) def compute_cost_l2(self, map_1, map_2): return ((map_1 - map_2) ** 2).sum() - - @override + + @override def get_distance(self, map1, map2): return self.compute_cost_l2(map1, map2) - + + class Correlation(MapToMapDistance): def __init__(self, config): super().__init__(config) @@ -56,9 +60,10 @@ def __init__(self, config): def compute_cost_corr(self, map_1, map_2): return (map_1 * map_2).sum() - @override + @override def get_distance(self, map1, map2): - return self.compute_cost_corr(map1, map2) + return self.compute_cost_corr(map1, map2) + class BioEM3dDistance(MapToMapDistance): def __init__(self, config): @@ -94,7 +99,12 @@ def compute_bioem3d_cost(self, map1, map2): N = len(m1) t1 = 2 * torch.pi * math.exp(1) - t2 = N * (ccc * coo - coc * coc) + 2 * co * coc * cc - ccc * co * co - coo * cc * cc + t2 = ( + N * (ccc * coo - coc * coc) + + 2 * co * coc * cc + - ccc * co * co + - coo * cc * cc + ) t3 = (N - 2) * (N * ccc - cc * cc) smallest_float = torch.finfo(m1.dtype).tiny @@ -107,10 +117,11 @@ def compute_bioem3d_cost(self, map1, map2): cost = -log_prob return cost - @override + @override def get_distance(self, map1, map2): - return self.compute_bioem3d_cost(map1, map2) - + return self.compute_bioem3d_cost(map1, map2) + + class FSCDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) @@ -203,22 +214,26 @@ def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): return cost_matrix, fsc_matrix @override - def get_distance_matrix(self, maps1, maps2): # custom method + def get_distance_matrix(self, maps1, maps2): # custom method maps_gt_flat = maps1 maps_user_flat = maps2 n_pix = self.config["data"]["n_pix"] maps_gt_flat_cube = torch.zeros(len(maps_gt_flat), n_pix**3) mask = ( - mrcfile.open(self.config["data"]["mask"]["volume"]).data.astype(bool).flatten() + mrcfile.open(self.config["data"]["mask"]["volume"]) + .data.astype(bool) + .flatten() ) maps_gt_flat_cube[:, mask] = maps_gt_flat maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) maps_user_flat_cube[:, mask] = maps_user_flat - - cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk(maps_gt_flat_cube, 
maps_user_flat_cube, n_pix) - self.stored_computed_assets = {'fsc_matrix': fsc_matrix} + + cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk( + maps_gt_flat_cube, maps_user_flat_cube, n_pix + ) + self.stored_computed_assets = {"fsc_matrix": fsc_matrix} return cost_matrix @override def get_computed_assets(self, maps1, maps2): - return self.stored_computed_assets # must run get_distance_matrix first \ No newline at end of file + return self.stored_computed_assets # must run get_distance_matrix first diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index e39942f..116dcac 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -4,15 +4,21 @@ import torch from ..data._validation.output_validators import MapToMapResultsValidator -from .._map_to_map.map_to_map_distance import FSCDistance, Correlation, L2DistanceSum, BioEM3dDistance +from .._map_to_map.map_to_map_distance import ( + FSCDistance, + Correlation, + L2DistanceSum, + BioEM3dDistance, +) AVAILABLE_MAP2MAP_DISTANCES = { - "fsc": FSCDistance, - "corr": Correlation, - "l2": L2DistanceSum, - "bioem": BioEM3dDistance, - } + "fsc": FSCDistance, + "corr": Correlation, + "l2": L2DistanceSum, + "bioem": BioEM3dDistance, +} + def run(config): """ @@ -40,7 +46,6 @@ def run(config): submission[submission_metadata_key] / submission[submission_metadata_key].sum() ) - maps_user_flat = submission[submission_volume_key].reshape( len(submission["volumes"]), -1 ) diff --git a/src/cryo_challenge/_ploting/plotting_utils.py b/src/cryo_challenge/_ploting/plotting_utils.py index 04d5ca9..681ab01 100644 --- a/src/cryo_challenge/_ploting/plotting_utils.py +++ b/src/cryo_challenge/_ploting/plotting_utils.py @@ -1,6 +1,7 @@ import numpy as np + def res_at_fsc_threshold(fscs, threshold=0.5): res_fsc_half = np.argmin(fscs > threshold, axis=-1) - fraction_nyquist = 0.5*res_fsc_half / fscs.shape[-1] - return res_fsc_half, fraction_nyquist \ No newline at end of file + fraction_nyquist = 0.5 * res_fsc_half / fscs.shape[-1] + return res_fsc_half, fraction_nyquist diff --git a/src/cryo_challenge/data/__init__.py b/src/cryo_challenge/data/__init__.py index 8b4655c..fb27bbd 100644 --- a/src/cryo_challenge/data/__init__.py +++ b/src/cryo_challenge/data/__init__.py @@ -1,6 +1,18 @@ -from ._validation.config_validators import validate_input_config_disttodist as validate_input_config_disttodist -from ._validation.config_validators import validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl -from cryo_challenge.data._validation.output_validators import DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator -from cryo_challenge.data._validation.output_validators import MetricDistToDistValidator as MetricDistToDistValidator -from cryo_challenge.data._validation.output_validators import ReplicateValidatorEMD as ReplicateValidatorEMD -from cryo_challenge.data._validation.output_validators import ReplicateValidatorKL as ReplicateValidatorKL +from ._validation.config_validators import ( + validate_input_config_disttodist as validate_input_config_disttodist, +) +from ._validation.config_validators import ( + validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl, +) +from cryo_challenge.data._validation.output_validators import ( + DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator, +) +from 
cryo_challenge.data._validation.output_validators import (
+    MetricDistToDistValidator as MetricDistToDistValidator,
+)
+from cryo_challenge.data._validation.output_validators import (
+    ReplicateValidatorEMD as ReplicateValidatorEMD,
+)
+from cryo_challenge.data._validation.output_validators import (
+    ReplicateValidatorKL as ReplicateValidatorKL,
+)
diff --git a/src/cryo_challenge/data/_io/svd_io_utils.py b/src/cryo_challenge/data/_io/svd_io_utils.py
index f194c14..2a4d954 100644
--- a/src/cryo_challenge/data/_io/svd_io_utils.py
+++ b/src/cryo_challenge/data/_io/svd_io_utils.py
@@ -106,14 +106,16 @@ def load_ref_vols(box_size_ds: int, path_to_volumes: str, dtype=torch.float32):
 
     # Reshape volumes to correct size
     if volumes.dim() == 2:
-        box_size = int(round((float(volumes.shape[-1]) ** (1. / 3.))))
+        box_size = int(round((float(volumes.shape[-1]) ** (1.0 / 3.0))))
         volumes = torch.reshape(volumes, (-1, box_size, box_size, box_size))
     elif volumes.dim() == 4:
         pass
     else:
-        raise ValueError(f"The shape of the volumes stored in {path_to_volumes} have the unexpected shape "
-                         f"{torch.shape}. Please, review the file and regenerate it so that volumes stored hasve the "
-                         f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size).")
+        raise ValueError(
+            f"The volumes stored in {path_to_volumes} have an unexpected shape "
+            f"{volumes.shape}. Please review the file and regenerate it so that the stored volumes "
+            f"have shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size)."
+        )
 
     volumes_ds = torch.empty(
         (volumes.shape[0], box_size_ds, box_size_ds, box_size_ds), dtype=dtype
diff --git a/src/cryo_challenge/data/_validation/output_validators.py b/src/cryo_challenge/data/_validation/output_validators.py
index 35d9791..39cbb67 100644
--- a/src/cryo_challenge/data/_validation/output_validators.py
+++ b/src/cryo_challenge/data/_validation/output_validators.py
@@ -13,7 +13,7 @@
 @dataclass_json
 @dataclass
 class MapToMapResultsValidator:
-    '''
+    """
     Validate the output dictionary of the map-to-map distance matrix computation.
 
     config: dict, input config dictionary.
@@ -22,7 +22,8 @@ class MapToMapResultsValidator:
     l2: dict, L2 results.
     bioem: dict, BioEM results.
     fsc: dict, FSC results.
-    '''
+    """
+
     config: dict
     user_submitted_populations: torch.Tensor
     corr: Optional[dict] = None
@@ -49,7 +50,7 @@ class ReplicateValidatorEMD:
     Validate the output dictionary of one EMD in the the distribution-to-distribution pipeline.
 
     q_opt: List[float], optimal user submitted distribution, which sums to 1.
-    EMD_opt: float, EMD between the ground truth distribution (p) and the (optimized) user submitted distribution (q_opt). 
+    EMD_opt: float, EMD between the ground truth distribution (p) and the (optimized) user submitted distribution (q_opt).
         The transport plan is a joint distribution, such that:
         summing over the rows gives the (optimized) user submitted distribution, and summing over the columns gives the ground truth distribution.
     transport_plan_opt: List[List[float]], transport plan between the ground truth distribution (p, rows) and the (optimized) user submitted distribution (q_opt, columns).
@@ -61,6 +62,7 @@
         The transport plan is a joint distribution, such that:
         summing over the rows gives the user submitted distribution, and summing over the columns gives the ground truth distribution.
""" + q_opt: List[float] EMD_opt: float transport_plan_opt: List[List[float]] @@ -87,8 +89,9 @@ class ReplicateValidatorKL: iter_stop: int, number of iterations until convergence. eps_stop: float, stopping criterion. klpq_submitted: float, KL divergence between the ground truth distribution (p) and the user submitted distribution (q). - klqp_submitted: float, KL divergence between the user submitted distribution (q) and the ground truth distribution (p). + klqp_submitted: float, KL divergence between the user submitted distribution (q) and the ground truth distribution (p). """ + q_opt: List[float] klpq_opt: float klqp_opt: float @@ -106,11 +109,12 @@ def __post_init__(self): @dataclass_json @dataclass class MetricDistToDistValidator: - ''' + """ Validate the output dictionary of one map to map metric in the the distribution-to-distribution pipeline. replicates: dict, dictionary of replicates. - ''' + """ + replicates: dict def validate_replicates(self, n_replicates): @@ -126,7 +130,7 @@ def validate_replicates(self, n_replicates): @dataclass_json @dataclass class DistributionToDistributionResultsValidator: - ''' + """ Validate the output dictionary of the distribution-to-distribution pipeline. config: dict, input config dictionary. @@ -136,7 +140,8 @@ class DistributionToDistributionResultsValidator: bioem: dict, BioEM distance results. l2: dict, L2 distance results. corr: dict, correlation distance results. - ''' + """ + config: dict user_submitted_populations: torch.Tensor id: str diff --git a/tests/test_distribution_to_distribution.py b/tests/test_distribution_to_distribution.py index d9c340b..a4cfb79 100644 --- a/tests/test_distribution_to_distribution.py +++ b/tests/test_distribution_to_distribution.py @@ -2,6 +2,8 @@ from cryo_challenge._commands import run_distribution2distribution_pipeline -def test_run_distribution2distribution_pipeline(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_distribution_to_distribution.yaml'}) - run_distribution2distribution_pipeline.main(args) \ No newline at end of file +def test_run_distribution2distribution_pipeline(): + args = OmegaConf.create( + {"config": "tests/config_files/test_config_distribution_to_distribution.yaml"} + ) + run_distribution2distribution_pipeline.main(args) diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index c782a8c..e31f29f 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -2,6 +2,8 @@ from cryo_challenge._commands import run_map2map_pipeline -def test_run_map2map_pipeline(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_map_to_map.yaml'}) - run_map2map_pipeline.main(args) \ No newline at end of file +def test_run_map2map_pipeline(): + args = OmegaConf.create( + {"config": "tests/config_files/test_config_map_to_map.yaml"} + ) + run_map2map_pipeline.main(args) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index cbf54e4..31db34e 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -2,6 +2,6 @@ from cryo_challenge._commands import run_preprocessing -def test_run_preprocessing(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_preproc.yaml'}) - run_preprocessing.main(args) \ No newline at end of file +def test_run_preprocessing(): + args = OmegaConf.create({"config": "tests/config_files/test_config_preproc.yaml"}) + run_preprocessing.main(args) diff --git a/tests/test_svd.py b/tests/test_svd.py index 568370e..ea166ea 100644 --- a/tests/test_svd.py +++ b/tests/test_svd.py @@ 
-2,6 +2,6 @@ from cryo_challenge._commands import run_svd -def test_run_preprocessing(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_svd.yaml'}) - run_svd.main(args) \ No newline at end of file +def test_run_preprocessing(): + args = OmegaConf.create({"config": "tests/config_files/test_config_svd.yaml"}) + run_svd.main(args) diff --git a/tutorials/1_tutorial_preprocessing.ipynb b/tutorials/1_tutorial_preprocessing.ipynb index cc6a459..84db8c9 100644 --- a/tutorials/1_tutorial_preprocessing.ipynb +++ b/tutorials/1_tutorial_preprocessing.ipynb @@ -203,7 +203,7 @@ "# Select path to Config file\n", "# An example of this file is available in the path ../config_files/config_preproc.yaml\n", "config_preproc_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_preproc_path.filter_pattern = '*.yaml'\n", + "config_preproc_path.filter_pattern = \"*.yaml\"\n", "display(config_preproc_path)" ] }, @@ -226,7 +226,7 @@ "if os.path.isabs(output_path):\n", " full_output_path = output_path\n", "else:\n", - " full_output_path = os.path.join(os.getcwd(), '..', output_path)" + " full_output_path = os.path.join(os.getcwd(), \"..\", output_path)" ] }, { diff --git a/tutorials/2_tutorial_svd.ipynb b/tutorials/2_tutorial_svd.ipynb index b41bfba..fe8f432 100644 --- a/tutorials/2_tutorial_svd.ipynb +++ b/tutorials/2_tutorial_svd.ipynb @@ -62,7 +62,7 @@ "# Select path to SVD config file\n", "# An example of this file is available in the path ../config_files/config_svd.yaml\n", "config_svd_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_svd_path.filter_pattern = '*.yaml'\n", + "config_svd_path.filter_pattern = \"*.yaml\"\n", "display(config_svd_path)" ] }, @@ -125,7 +125,7 @@ "source": [ "# Select path to SVD results\n", "svd_results_path = FileChooser(os.path.expanduser(\"~\"))\n", - "svd_results_path.filter_pattern = '*.pt'\n", + "svd_results_path.filter_pattern = \"*.pt\"\n", "display(svd_results_path)" ] }, @@ -316,7 +316,7 @@ "source": [ "# Select path to SVD results\n", "svd_all_vs_all_results_path = FileChooser(os.path.expanduser(\"~\"))\n", - "svd_all_vs_all_results_path.filter_pattern = '*.pt'\n", + "svd_all_vs_all_results_path.filter_pattern = \"*.pt\"\n", "display(svd_all_vs_all_results_path)" ] }, diff --git a/tutorials/3_tutorial_map2map.ipynb b/tutorials/3_tutorial_map2map.ipynb index a3701ff..0497578 100644 --- a/tutorials/3_tutorial_map2map.ipynb +++ b/tutorials/3_tutorial_map2map.ipynb @@ -23,15 +23,8 @@ "\n", "from cryo_challenge.data._validation.config_validators import (\n", " validate_input_config_mtm,\n", - " validate_config_mtm_data, \n", - " validate_config_mtm_data_submission, \n", - " validate_config_mtm_data_ground_truth, \n", - " validate_config_mtm_data_mask, \n", - " validate_config_mtm_analysis, \n", - " validate_config_mtm_analysis_normalize, \n", - " )\n", - "from cryo_challenge.data._validation.output_validators import MapToMapResultsValidator\n", - "from cryo_challenge.data._validation.config_validators import validate_maptomap_result" + ")\n", + "from cryo_challenge.data._validation.output_validators import MapToMapResultsValidator" ] }, { @@ -80,7 +73,7 @@ "# Select path to Map to Map config file\n", "# An example of this file is available in the path ../config_files/config_map_to_map.yaml\n", "config_m2m_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_m2m_path.filter_pattern = '*.yaml'\n", + "config_m2m_path.filter_pattern = \"*.yaml\"\n", "display(config_m2m_path)" ] }, @@ -341,8 +334,8 @@ "metadata": {}, "outputs": [], 
"source": [ - "with open(os.path.join('../',config[\"output\"]), \"rb\") as f:\n", - " results_dict = pickle.load(f)\n" + "with open(os.path.join(\"../\", config[\"output\"]), \"rb\") as f:\n", + " results_dict = pickle.load(f)" ] }, { diff --git a/tutorials/4_tutorial_distribution2distribution.ipynb b/tutorials/4_tutorial_distribution2distribution.ipynb index 07dc1d9..2a6fc53 100644 --- a/tutorials/4_tutorial_distribution2distribution.ipynb +++ b/tutorials/4_tutorial_distribution2distribution.ipynb @@ -30,7 +30,7 @@ "import pickle\n", "from ipyfilechooser import FileChooser\n", "\n", - "from cryo_challenge.data import validate_input_config_disttodist, validate_config_dtd_optimal_q_kl\n", + "from cryo_challenge.data import validate_input_config_disttodist\n", "from cryo_challenge.data import DistributionToDistributionResultsValidator" ] }, @@ -65,7 +65,7 @@ "# Select path to Distribution to Distribution config file\n", "# An example of this file is available in the path ../config_files/config_distribution_to_distribution.yaml\n", "config_d2d_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_d2d_path.filter_pattern = '*.yaml'\n", + "config_d2d_path.filter_pattern = \"*.yaml\"\n", "display(config_d2d_path)" ] }, @@ -199,8 +199,8 @@ "metadata": {}, "outputs": [], "source": [ - "with open(os.path.join('../',config[\"output_fname\"]), \"rb\") as f:\n", - " results_dict = pickle.load(f)\n" + "with open(os.path.join(\"../\", config[\"output_fname\"]), \"rb\") as f:\n", + " results_dict = pickle.load(f)" ] }, { @@ -286,7 +286,6 @@ } ], "source": [ - "from cryo_challenge.data import MetricDistToDistValidator\n", "MetricDistToDistValidator?" ] }, @@ -302,9 +301,7 @@ "execution_count": 30, "metadata": {}, "outputs": [], - "source": [ - "from cryo_challenge.data import ReplicateValidatorEMD, ReplicateValidatorKL" - ] + "source": [] }, { "cell_type": "code", diff --git a/tutorials/5_tutorial_plotting.ipynb b/tutorials/5_tutorial_plotting.ipynb index ed8a924..e2648c8 100644 --- a/tutorials/5_tutorial_plotting.ipynb +++ b/tutorials/5_tutorial_plotting.ipynb @@ -23,7 +23,9 @@ "metadata": {}, "outputs": [], "source": [ - "from cryo_challenge._distribution_to_distribution.distribution_to_distribution import sort_by_transport\n", + "from cryo_challenge._distribution_to_distribution.distribution_to_distribution import (\n", + " sort_by_transport,\n", + ")\n", "from cryo_challenge._ploting.plotting_utils import res_at_fsc_threshold\n", "\n", "from dataclasses import dataclass\n", @@ -102,6 +104,7 @@ " map2map_results: List[str]\n", " dist2dist_results: Dict[str, Dict[str, List[str]]]\n", "\n", + "\n", "with open(path_to_config, \"r\") as file:\n", " config = yaml.safe_load(file)\n", "config = PlottingConfig.from_dict(config)\n", @@ -115,7 +118,7 @@ "outputs": [], "source": [ "metadata_df = pd.read_csv(config.gt_metadata)\n", - "metadata_df.sort_values('pc1', inplace=True)\n", + "metadata_df.sort_values(\"pc1\", inplace=True)\n", "gt_ordering = metadata_df.index.tolist()" ] }, @@ -136,12 +139,14 @@ " data_d = {}\n", " for fname in fnames:\n", " if fname not in data_d.keys():\n", - " with open(fname, 'rb') as f:\n", + " with open(fname, \"rb\") as f:\n", " data = pickle.load(f)\n", - " anonymous_label = data[map2map_distance]['user_submission_label']\n", + " anonymous_label = data[map2map_distance][\"user_submission_label\"]\n", " data_d[anonymous_label] = data\n", " return data_d\n", - "data_d = get_fsc_distances(config.map2map_results, 'fsc')" + "\n", + "\n", + "data_d = 
get_fsc_distances(config.map2map_results, \"fsc\")" ] }, { @@ -162,22 +167,24 @@ ], "source": [ "def plot_fsc_distances(data_d, gt_ordering):\n", - "\n", " smaller_fontsize = 20\n", " larger_fontsize = 30\n", " n_plts = 12\n", " vmin, vmax = np.inf, -np.inf\n", "\n", - " fig, axis = plt.subplots(3,n_plts//3,\n", - " figsize=(40,20),\n", - " # dpi=100,\n", - " )\n", - " fig.suptitle(r'$d_{FSC}$', y=0.95, fontsize=larger_fontsize)\n", + " fig, axis = plt.subplots(\n", + " 3,\n", + " n_plts // 3,\n", + " figsize=(40, 20),\n", + " # dpi=100,\n", + " )\n", + " fig.suptitle(r\"$d_{FSC}$\", y=0.95, fontsize=larger_fontsize)\n", "\n", " for idx, (anonymous_label, data) in enumerate(data_d.items()):\n", - " map2map_dist_matrix = data['fsc']['cost_matrix'].iloc[gt_ordering].values\n", - " sorted_map2map_dist_matrix, indices, transport = sort_by_transport(map2map_dist_matrix)\n", - "\n", + " map2map_dist_matrix = data[\"fsc\"][\"cost_matrix\"].iloc[gt_ordering].values\n", + " sorted_map2map_dist_matrix, indices, transport = sort_by_transport(\n", + " map2map_dist_matrix\n", + " )\n", "\n", " ncols = 4\n", " if map2map_dist_matrix.min() < vmin:\n", @@ -185,14 +192,24 @@ " if map2map_dist_matrix.max() > vmax:\n", " vmax = map2map_dist_matrix.max()\n", "\n", - " ax = axis[idx//ncols,idx%ncols].imshow(sorted_map2map_dist_matrix, aspect='auto', cmap='Blues_r', vmin=vmin, vmax=vmax)\n", - "\n", - "\n", - " axis[idx//ncols,idx%ncols].tick_params(axis='both', labelsize=smaller_fontsize)\n", + " ax = axis[idx // ncols, idx % ncols].imshow(\n", + " sorted_map2map_dist_matrix,\n", + " aspect=\"auto\",\n", + " cmap=\"Blues_r\",\n", + " vmin=vmin,\n", + " vmax=vmax,\n", + " )\n", + "\n", + " axis[idx // ncols, idx % ncols].tick_params(\n", + " axis=\"both\", labelsize=smaller_fontsize\n", + " )\n", " cbar = fig.colorbar(ax)\n", " cbar.ax.tick_params(labelsize=smaller_fontsize)\n", " plot_panel_label = anonymous_label\n", - " axis[idx//ncols,idx%ncols].set_title(plot_panel_label, fontsize=smaller_fontsize)\n", + " axis[idx // ncols, idx % ncols].set_title(\n", + " plot_panel_label, fontsize=smaller_fontsize\n", + " )\n", + "\n", "\n", "plot_fsc_distances(data_d, gt_ordering)" ] @@ -214,12 +231,13 @@ " data_d = {}\n", " for fname in fnames:\n", " if fname not in data_d.keys():\n", - " with open(fname, 'rb') as f:\n", + " with open(fname, \"rb\") as f:\n", " data = pickle.load(f)\n", - " anonymous_label = data['fsc']['user_submission_label']\n", - " data_d[anonymous_label] = data['fsc']['computed_assets']['fsc_matrix']\n", + " anonymous_label = data[\"fsc\"][\"user_submission_label\"]\n", + " data_d[anonymous_label] = data[\"fsc\"][\"computed_assets\"][\"fsc_matrix\"]\n", " return data_d\n", "\n", + "\n", "fscs_sorted_d = get_full_fsc_curve(config.map2map_results)" ] }, @@ -229,9 +247,11 @@ "metadata": {}, "outputs": [], "source": [ - "res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fscs_sorted_d['Cookie Dough'], threshold=0.5)\n", - "n_fourier_bins = fscs_sorted_d['Cookie Dough'].shape[-1]\n", - "units_Angstroms = 2 * 2.146 / (np.arange(1,n_fourier_bins+1) / n_fourier_bins)" + "res_fsc_half, fraction_nyquist = res_at_fsc_threshold(\n", + " fscs_sorted_d[\"Cookie Dough\"], threshold=0.5\n", + ")\n", + "n_fourier_bins = fscs_sorted_d[\"Cookie Dough\"].shape[-1]\n", + "units_Angstroms = 2 * 2.146 / (np.arange(1, n_fourier_bins + 1) / n_fourier_bins)" ] }, { @@ -261,7 +281,7 @@ } ], "source": [ - "plt.imshow(units_Angstroms[res_fsc_half][gt_ordering], aspect='auto', cmap='Blues_r')\n", + 
"plt.imshow(units_Angstroms[res_fsc_half][gt_ordering], aspect=\"auto\", cmap=\"Blues_r\")\n", "plt.colorbar()" ] }, @@ -283,17 +303,18 @@ ], "source": [ "def plot_res_at_fsc_threshold_distances(fscs_sorted_d, gt_ordering, overwrite_dict={}):\n", - "\n", " smaller_fontsize = 20\n", " larger_fontsize = 30\n", " n_plts = 12\n", " vmin, vmax = np.inf, -np.inf\n", "\n", - " fig, axis = plt.subplots(3,n_plts//3,\n", - " figsize=(40,20),\n", - " # dpi=100,\n", - " )\n", - " fig.suptitle(r'Resolution $(\\AA)$ at $FSC=0.5$', y=0.95, fontsize=larger_fontsize)\n", + " fig, axis = plt.subplots(\n", + " 3,\n", + " n_plts // 3,\n", + " figsize=(40, 20),\n", + " # dpi=100,\n", + " )\n", + " fig.suptitle(r\"Resolution $(\\AA)$ at $FSC=0.5$\", y=0.95, fontsize=larger_fontsize)\n", "\n", " for idx, (anonymous_label, fscs) in enumerate(fscs_sorted_d.items()):\n", " # map2map_dist_matrix = data.iloc[gt_ordering].values\n", @@ -302,27 +323,38 @@ "\n", " sorted_map2map_dist_matrix, _, _ = sort_by_transport(map2map_dist_matrix)\n", "\n", - "\n", " ncols = 4\n", " if map2map_dist_matrix.min() < vmin:\n", " vmin = map2map_dist_matrix.min()\n", " if map2map_dist_matrix.max() > vmax:\n", " vmax = map2map_dist_matrix.max()\n", - " if 'vmax' in overwrite_dict.keys():\n", - " vmax = overwrite_dict['vmax']\n", - " if 'vmin' in overwrite_dict.keys():\n", - " vmin = overwrite_dict['vmin']\n", - "\n", - " ax = axis[idx//ncols,idx%ncols].imshow(sorted_map2map_dist_matrix, aspect='auto', cmap='Blues_r', vmin=vmin, vmax=vmax)\n", - "\n", - "\n", - " axis[idx//ncols,idx%ncols].tick_params(axis='both', labelsize=smaller_fontsize)\n", + " if \"vmax\" in overwrite_dict.keys():\n", + " vmax = overwrite_dict[\"vmax\"]\n", + " if \"vmin\" in overwrite_dict.keys():\n", + " vmin = overwrite_dict[\"vmin\"]\n", + "\n", + " ax = axis[idx // ncols, idx % ncols].imshow(\n", + " sorted_map2map_dist_matrix,\n", + " aspect=\"auto\",\n", + " cmap=\"Blues_r\",\n", + " vmin=vmin,\n", + " vmax=vmax,\n", + " )\n", + "\n", + " axis[idx // ncols, idx % ncols].tick_params(\n", + " axis=\"both\", labelsize=smaller_fontsize\n", + " )\n", " cbar = fig.colorbar(ax)\n", " cbar.ax.tick_params(labelsize=smaller_fontsize)\n", " plot_panel_label = anonymous_label\n", - " axis[idx//ncols,idx%ncols].set_title(plot_panel_label, fontsize=smaller_fontsize)\n", + " axis[idx // ncols, idx % ncols].set_title(\n", + " plot_panel_label, fontsize=smaller_fontsize\n", + " )\n", + "\n", "\n", - "plot_res_at_fsc_threshold_distances(fscs_sorted_d, gt_ordering, overwrite_dict={'vmax': 31})" + "plot_res_at_fsc_threshold_distances(\n", + " fscs_sorted_d, gt_ordering, overwrite_dict={\"vmax\": 31}\n", + ")" ] }, { @@ -338,9 +370,9 @@ "metadata": {}, "outputs": [], "source": [ - "fname = config.dist2dist_results['prob_submitted_plot']['pkl_fnames'][0]\n", + "fname = config.dist2dist_results[\"prob_submitted_plot\"][\"pkl_fnames\"][0]\n", "\n", - "with open(fname, 'rb') as f:\n", + "with open(fname, \"rb\") as f:\n", " data = pickle.load(f)" ] }, @@ -354,13 +386,16 @@ " data_d = {}\n", " for fname in fnames:\n", " if fname not in data_d.keys():\n", - " with open(fname, 'rb') as f:\n", + " with open(fname, \"rb\") as f:\n", " data = pickle.load(f)\n", - " anonymous_label = data['id']\n", + " anonymous_label = data[\"id\"]\n", " data_d[anonymous_label] = data\n", " return data_d\n", "\n", - "dist2dist_results_d = get_dist2dist_results(config.dist2dist_results['prob_submitted_plot']['pkl_fnames'])" + "\n", + "dist2dist_results_d = get_dist2dist_results(\n", + " 
config.dist2dist_results[\"prob_submitted_plot\"][\"pkl_fnames\"]\n", + ")" ] }, { @@ -381,46 +416,58 @@ ], "source": [ "window_size = 15\n", - "nrows, ncols = 3,4\n", + "nrows, ncols = 3, 4\n", "suptitle = f\"Submitted populations vs optimal populations \\n d_FSC (no rank) | n_replicates={data['config']['n_replicates']} | window_size={window_size} | n_pool_microstate={data['config']['n_pool_microstate']}\"\n", "\n", + "\n", "def plot_q_opt_distances(dist2dist_results_d, suptitle, nrows, ncols):\n", + " fig, axes = plt.subplots(nrows, ncols, figsize=(40, 25))\n", "\n", - " fig, axes = plt.subplots(nrows, ncols, figsize=(40,25))\n", - " \n", - " fig.suptitle(\n", - " suptitle,\n", - " fontsize=30,\n", - " y=0.95)\n", + " fig.suptitle(suptitle, fontsize=30, y=0.95)\n", " alpha = 0.05\n", "\n", - " for idx_fname, (_,data) in enumerate(dist2dist_results_d.items()):\n", - " \n", + " for idx_fname, (_, data) in enumerate(dist2dist_results_d.items()):\n", + " axes[idx_fname // ncols, idx_fname % ncols].plot(\n", + " data[\"user_submitted_populations\"], color=\"black\", label=\"submited\"\n", + " )\n", + " axes[idx_fname // ncols, idx_fname % ncols].set_title(data[\"id\"], fontsize=30)\n", "\n", - " axes[idx_fname//ncols, idx_fname%ncols].plot(data['user_submitted_populations'], color='black', label='submited')\n", - " axes[idx_fname//ncols, idx_fname%ncols].set_title(data['id'], fontsize=30)\n", - "\n", - " def window_q(q_opt,window_size):\n", - " running_avg = np.convolve(q_opt, np.ones(window_size)/window_size, mode='same')\n", + " def window_q(q_opt, window_size):\n", + " running_avg = np.convolve(\n", + " q_opt, np.ones(window_size) / window_size, mode=\"same\"\n", + " )\n", " return running_avg\n", - " \n", - " for replicate_idx in range(data['config']['n_replicates']):\n", - " \n", - " if replicate_idx == 0: \n", - " label_d = {'EMD': 'EMD', 'KL': 'KL', 'KL_raw': 'Unwindowed', 'EMD_raw': 'Unwindowed'}\n", - " else:\n", - " label_d = {'EMD': None, 'KL': None, 'KL_raw': None, 'EMD_raw': None}\n", - " axes[idx_fname//ncols, idx_fname%ncols].plot(window_q(data['fsc']['replicates'][replicate_idx]['EMD']['q_opt'],window_size), color='blue', alpha=alpha, label=label_d['EMD'])\n", - "\n", - " axes[idx_fname//ncols, idx_fname%ncols].set_xlabel('Submission index')\n", - " axes[idx_fname//ncols, idx_fname%ncols].set_ylabel('Population')\n", "\n", - " legend = axes[idx_fname//ncols, idx_fname%ncols].legend()\n", + " for replicate_idx in range(data[\"config\"][\"n_replicates\"]):\n", + " if replicate_idx == 0:\n", + " label_d = {\n", + " \"EMD\": \"EMD\",\n", + " \"KL\": \"KL\",\n", + " \"KL_raw\": \"Unwindowed\",\n", + " \"EMD_raw\": \"Unwindowed\",\n", + " }\n", + " else:\n", + " label_d = {\"EMD\": None, \"KL\": None, \"KL_raw\": None, \"EMD_raw\": None}\n", + " axes[idx_fname // ncols, idx_fname % ncols].plot(\n", + " window_q(\n", + " data[\"fsc\"][\"replicates\"][replicate_idx][\"EMD\"][\"q_opt\"],\n", + " window_size,\n", + " ),\n", + " color=\"blue\",\n", + " alpha=alpha,\n", + " label=label_d[\"EMD\"],\n", + " )\n", + "\n", + " axes[idx_fname // ncols, idx_fname % ncols].set_xlabel(\"Submission index\")\n", + " axes[idx_fname // ncols, idx_fname % ncols].set_ylabel(\"Population\")\n", + "\n", + " legend = axes[idx_fname // ncols, idx_fname % ncols].legend()\n", " for line, text in zip(legend.get_lines(), legend.get_texts()):\n", " text.set_color(line.get_color())\n", " line.set_alpha(1)\n", "\n", - "plot_q_opt_distances(dist2dist_results_d, suptitle,nrows, ncols)" + "\n", + 
"plot_q_opt_distances(dist2dist_results_d, suptitle, nrows, ncols)" ] }, { @@ -437,7 +484,6 @@ "outputs": [], "source": [ "def wragle_pkl_to_dataframe(pkl_globs):\n", - "\n", " fnames = []\n", " for fname_glob in pkl_globs:\n", " fnames.extend(glob.glob(fname_glob))\n", @@ -446,28 +492,50 @@ "\n", " df_list = []\n", " n_replicates = 30\n", - " metric = 'fsc'\n", + " metric = \"fsc\"\n", "\n", " for fname in fnames:\n", - " with open(fname, 'rb') as f:\n", + " with open(fname, \"rb\") as f:\n", " data = pickle.load(f)\n", "\n", - " df_list.append(pd.DataFrame({\n", - " 'EMD_opt': [data[metric]['replicates'][i]['EMD']['EMD_opt'] for i in range(n_replicates)],\n", - " 'EMD_submitted': [data[metric]['replicates'][i]['EMD']['EMD_submitted'] for i in range(n_replicates)],\n", - " 'klpq_opt': [data[metric]['replicates'][i]['KL']['klpq_opt'] for i in range(n_replicates)],\n", - " 'klqp_opt': [data[metric]['replicates'][i]['KL']['klqp_opt'] for i in range(n_replicates)],\n", - " 'klpq_submitted': [data[metric]['replicates'][i]['KL']['klpq_submitted'] for i in range(n_replicates)], \n", - " 'klqp_submitted': [data[metric]['replicates'][i]['KL']['klqp_submitted'] for i in range(n_replicates)], \n", - " 'id': data['id'],\n", - " 'n_pool_microstate': data['config']['n_pool_microstate'],\n", - " }))\n", + " df_list.append(\n", + " pd.DataFrame(\n", + " {\n", + " \"EMD_opt\": [\n", + " data[metric][\"replicates\"][i][\"EMD\"][\"EMD_opt\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"EMD_submitted\": [\n", + " data[metric][\"replicates\"][i][\"EMD\"][\"EMD_submitted\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"klpq_opt\": [\n", + " data[metric][\"replicates\"][i][\"KL\"][\"klpq_opt\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"klqp_opt\": [\n", + " data[metric][\"replicates\"][i][\"KL\"][\"klqp_opt\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"klpq_submitted\": [\n", + " data[metric][\"replicates\"][i][\"KL\"][\"klpq_submitted\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"klqp_submitted\": [\n", + " data[metric][\"replicates\"][i][\"KL\"][\"klqp_submitted\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"id\": data[\"id\"],\n", + " \"n_pool_microstate\": data[\"config\"][\"n_pool_microstate\"],\n", + " }\n", + " )\n", + " )\n", "\n", " df = pd.concat(df_list)\n", - " df['EMD_opt_norm'] = df['EMD_opt'] / df['n_pool_microstate']\n", - " df['EMD_submitted_norm'] = df['EMD_submitted'] / df['n_pool_microstate']\n", + " df[\"EMD_opt_norm\"] = df[\"EMD_opt\"] / df[\"n_pool_microstate\"]\n", + " df[\"EMD_submitted_norm\"] = df[\"EMD_submitted\"] / df[\"n_pool_microstate\"]\n", "\n", - " return df\n" + " return df" ] }, { @@ -476,7 +544,7 @@ "metadata": {}, "outputs": [], "source": [ - "df = wragle_pkl_to_dataframe(config.dist2dist_results['emd_plot']['pkl_globs'])" + "df = wragle_pkl_to_dataframe(config.dist2dist_results[\"emd_plot\"][\"pkl_globs\"])" ] }, { @@ -498,50 +566,96 @@ "source": [ "def plot_EMD_vs_EMDopt(df, suptitle=None):\n", " alpha = 1\n", - " df_average = df.groupby(['n_pool_microstate','id']).mean().reset_index()\n", - " df_std = df_average.groupby(['id']).std().reset_index().filter(['EMD_opt_norm','EMD_submitted_norm', 'id']).rename(columns={'EMD_opt_norm':'EMD_opt_norm_std', 'EMD_submitted_norm':'EMD_submitted_norm_std'})\n", - " df_average = df.groupby(['id']).mean().reset_index()\n", - "\n", - " df_average_and_error = pd.merge(df_average, df_std, on='id')\n", + " df_average = 
df.groupby([\"n_pool_microstate\", \"id\"]).mean().reset_index()\n", + " df_std = (\n", + " df_average.groupby([\"id\"])\n", + " .std()\n", + " .reset_index()\n", + " .filter([\"EMD_opt_norm\", \"EMD_submitted_norm\", \"id\"])\n", + " .rename(\n", + " columns={\n", + " \"EMD_opt_norm\": \"EMD_opt_norm_std\",\n", + " \"EMD_submitted_norm\": \"EMD_submitted_norm_std\",\n", + " }\n", + " )\n", + " )\n", + " df_average = df.groupby([\"id\"]).mean().reset_index()\n", + "\n", + " df_average_and_error = pd.merge(df_average, df_std, on=\"id\")\n", "\n", " # Get unique ids\n", - " ids = df_average_and_error['id'].unique()\n", + " ids = df_average_and_error[\"id\"].unique()\n", "\n", " # Define marker styles\n", - " markers = ['o', 'v', '^', '<', '>', 's', 'p', '*', 'h', 'H', '+', 'x', 'D', 'd', '|', '_']\n", + " markers = [\n", + " \"o\",\n", + " \"v\",\n", + " \"^\",\n", + " \"<\",\n", + " \">\",\n", + " \"s\",\n", + " \"p\",\n", + " \"*\",\n", + " \"h\",\n", + " \"H\",\n", + " \"+\",\n", + " \"x\",\n", + " \"D\",\n", + " \"d\",\n", + " \"|\",\n", + " \"_\",\n", + " ]\n", " marker_size = 250\n", "\n", - " plt.style.use('seaborn-v0_8-poster')\n", - " plot_width, plot_height = 8,6\n", + " plt.style.use(\"seaborn-v0_8-poster\")\n", + " plot_width, plot_height = 8, 6\n", " plt.figure(figsize=(plot_width, plot_height), dpi=300)\n", "\n", " # Create a scatter plot for each id\n", " for idx, id_label in enumerate(ids):\n", - " df_average_id = df_average[df_average['id'] == id_label]\n", - " sns.scatterplot(x='EMD_submitted_norm', y='EMD_opt_norm', data=df_average_id, alpha=alpha, marker=markers[idx % len(markers)], label=id_label, s=marker_size)\n", - "\n", - " plt.errorbar(x=df_average_and_error['EMD_submitted_norm'], \n", - " y=df_average_and_error['EMD_opt_norm'], \n", - " xerr=df_average_and_error['EMD_submitted_norm_std'], \n", - " yerr=df_average_and_error['EMD_opt_norm_std'], \n", - " fmt='', alpha=0.05, linestyle='None', ecolor='k', elinewidth=2, capsize=5)\n", + " df_average_id = df_average[df_average[\"id\"] == id_label]\n", + " sns.scatterplot(\n", + " x=\"EMD_submitted_norm\",\n", + " y=\"EMD_opt_norm\",\n", + " data=df_average_id,\n", + " alpha=alpha,\n", + " marker=markers[idx % len(markers)],\n", + " label=id_label,\n", + " s=marker_size,\n", + " )\n", + "\n", + " plt.errorbar(\n", + " x=df_average_and_error[\"EMD_submitted_norm\"],\n", + " y=df_average_and_error[\"EMD_opt_norm\"],\n", + " xerr=df_average_and_error[\"EMD_submitted_norm_std\"],\n", + " yerr=df_average_and_error[\"EMD_opt_norm_std\"],\n", + " fmt=\"\",\n", + " alpha=0.05,\n", + " linestyle=\"None\",\n", + " ecolor=\"k\",\n", + " elinewidth=2,\n", + " capsize=5,\n", + " )\n", "\n", " plt.xlim(left=0.5)\n", " plt.ylim(bottom=0.5)\n", "\n", - " limits = [np.min([plt.xlim(), plt.ylim()]), # min of both axes\n", - " np.max([plt.xlim(), plt.ylim()])] # max of both axes\n", + " limits = [\n", + " np.min([plt.xlim(), plt.ylim()]), # min of both axes\n", + " np.max([plt.xlim(), plt.ylim()]),\n", + " ] # max of both axes\n", "\n", - " plt.plot(limits, limits, 'k-', alpha=0.75, zorder=0)\n", + " plt.plot(limits, limits, \"k-\", alpha=0.75, zorder=0)\n", " plt.xlim(limits)\n", " plt.ylim(limits)\n", - " legend = plt.legend(loc='upper left', fontsize=12)\n", + " legend = plt.legend(loc=\"upper left\", fontsize=12)\n", " for handle in legend.legend_handles:\n", " handle.set_alpha(1)\n", "\n", " plt.suptitle(suptitle)\n", "\n", - "suptitle = r'$d_{FSC}$ (no rank)'\n", + "\n", + "suptitle = r\"$d_{FSC}$ (no rank)\"\n", 
"plot_EMD_vs_EMDopt(df, suptitle=None)" ] } diff --git a/tutorials/README.md b/tutorials/README.md index d427ba5..842bdf7 100644 --- a/tutorials/README.md +++ b/tutorials/README.md @@ -25,4 +25,4 @@ This notebook walks through generating and analyzing (plots) the map to map dist - output: a `.pkl` file ### `5_tutorial_plotting.ipynb` -This notebook walks through parsing and analyzing (plots) the map to map and distribution to distribution results. \ No newline at end of file +This notebook walks through parsing and analyzing (plots) the map to map and distribution to distribution results.