From 0ce2df11a5d9320a85caa4c6b61dd066e994e1b4 Mon Sep 17 00:00:00 2001 From: DSilva27 Date: Mon, 12 Aug 2024 12:22:27 -0400 Subject: [PATCH 01/30] implement power spectrum and bfactor scaling --- .github/workflows/main_merge_check.yml | 2 +- .../config_distribution_to_distribution.yaml | 2 +- .../config_map_to_map_distance_matrix.yaml | 10 +- src/cryo_challenge/__init__.py | 4 +- .../distribution_to_distribution.py | 17 +- .../_map_to_map/map_to_map_distance_matrix.py | 2 +- src/cryo_challenge/_ploting/plotting_utils.py | 5 +- .../_preprocessing/bfactor_normalize.py | 36 ++ .../_preprocessing/normalize.py | 22 -- src/cryo_challenge/data/__init__.py | 24 +- src/cryo_challenge/data/_io/svd_io_utils.py | 10 +- .../data/_validation/output_validators.py | 21 +- src/cryo_challenge/power_spectrum_utils.py | 71 ++++ .../submission_x/submission_config.json | 2 +- tests/test_distribution_to_distribution.py | 8 +- tests/test_map_to_map.py | 8 +- tests/test_power_spectrum_and_bfactor.py | 72 ++++ tests/test_preprocessing.py | 6 +- tests/test_svd.py | 6 +- tutorials/1_tutorial_preprocessing.ipynb | 4 +- tutorials/2_tutorial_svd.ipynb | 6 +- tutorials/3_tutorial_map2map.ipynb | 17 +- ...4_tutorial_distribution2distribution.ipynb | 13 +- tutorials/5_tutorial_plotting.ipynb | 344 ++++++++++++------ tutorials/README.md | 2 +- 25 files changed, 500 insertions(+), 214 deletions(-) create mode 100644 src/cryo_challenge/_preprocessing/bfactor_normalize.py delete mode 100644 src/cryo_challenge/_preprocessing/normalize.py create mode 100644 src/cryo_challenge/power_spectrum_utils.py create mode 100644 tests/test_power_spectrum_and_bfactor.py diff --git a/.github/workflows/main_merge_check.yml b/.github/workflows/main_merge_check.yml index 8e6f5c5..b4aa6e1 100644 --- a/.github/workflows/main_merge_check.yml +++ b/.github/workflows/main_merge_check.yml @@ -11,4 +11,4 @@ jobs: if: github.base_ref == 'main' && github.head_ref != 'dev' run: | echo "ERROR: You can only merge to main from dev." - exit 1 \ No newline at end of file + exit 1 diff --git a/config_files/config_distribution_to_distribution.yaml b/config_files/config_distribution_to_distribution.yaml index da8b3d9..eaa94ac 100644 --- a/config_files/config_distribution_to_distribution.yaml +++ b/config_files/config_distribution_to_distribution.yaml @@ -12,4 +12,4 @@ cvxpy_solver: ECOS optimal_q_kl: n_iter: 100000 break_atol: 0.0001 -output_fname: results/distribution_to_distribution_submission_0.pkl \ No newline at end of file +output_fname: results/distribution_to_distribution_submission_0.pkl diff --git a/config_files/config_map_to_map_distance_matrix.yaml b/config_files/config_map_to_map_distance_matrix.yaml index 4302227..5a98e28 100644 --- a/config_files/config_map_to_map_distance_matrix.yaml +++ b/config_files/config_map_to_map_distance_matrix.yaml @@ -1,15 +1,15 @@ data: n_pix: 224 - psize: 2.146 + psize: 2.146 submission: fname: data/dataset_2_ground_truth/submission_0.pt volume_key: volumes metadata_key: populations label_key: id ground_truth: - volumes: data/dataset_2_ground_truth/maps_gt_flat.pt - metadata: data/dataset_2_ground_truth/metadata.csv - mask: + volumes: data/dataset_2_ground_truth/maps_gt_flat.pt + metadata: data/dataset_2_ground_truth/metadata.csv + mask: do: true volume: data/dataset_2_ground_truth/mask_dilated_wide_224x224.mrc analysis: @@ -23,4 +23,4 @@ analysis: normalize: do: true method: median_zscore -output: results/map_to_map_distance_matrix_submission_0.pkl \ No newline at end of file +output: results/map_to_map_distance_matrix_submission_0.pkl diff --git a/src/cryo_challenge/__init__.py b/src/cryo_challenge/__init__.py index cafea4e..934c6a8 100644 --- a/src/cryo_challenge/__init__.py +++ b/src/cryo_challenge/__init__.py @@ -1 +1,3 @@ -from cryo_challenge.__about__ import __version__ \ No newline at end of file +from cryo_challenge.__about__ import __version__ + +__all__ = ["__version__"] diff --git a/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py b/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py index 70961d8..18c57bb 100644 --- a/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py +++ b/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py @@ -2,8 +2,6 @@ import numpy as np import pickle from scipy.stats import rankdata -import yaml -import argparse import torch import ot @@ -14,10 +12,12 @@ def sort_by_transport(cost): - m,n = cost.shape - _, transport = compute_wasserstein_between_distributions_from_weights_and_cost(np.ones(m) / m, np.ones(n)/n, cost) - indices = np.argsort((transport * np.arange(m)[...,None]).sum(0)) - return cost[:,indices], indices, transport + m, n = cost.shape + _, transport = compute_wasserstein_between_distributions_from_weights_and_cost( + np.ones(m) / m, np.ones(n) / n, cost + ) + indices = np.argsort((transport * np.arange(m)[..., None]).sum(0)) + return cost[:, indices], indices, transport def compute_wasserstein_between_distributions_from_weights_and_cost( @@ -65,7 +65,6 @@ def make_assignment_matrix(cost_matrix): def run(config): - metadata_df = pd.read_csv(config["gt_metadata_fname"]) metadata_df.sort_values("pc1", inplace=True) @@ -73,7 +72,7 @@ def run(config): data = pickle.load(f) # user_submitted_populations = np.ones(80)/80 - user_submitted_populations = data["user_submitted_populations"]#.numpy() + user_submitted_populations = data["user_submitted_populations"] # .numpy() id = torch.load(data["config"]["data"]["submission"]["fname"])["id"] results_dict = {} @@ -213,5 +212,5 @@ def optimal_q_kl(n_iter, x_start, A, Window, prob_gt, break_atol): DistributionToDistributionResultsValidator.from_dict(results_dict) with open(config["output_fname"], "wb") as f: pickle.dump(results_dict, f) - + return results_dict diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py index 0578dfa..47d9a00 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py @@ -42,7 +42,7 @@ def run(config): user_submission_label = submission[label_key] # n_trunc = 10 - metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"])#[:n_trunc] + metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"]) # [:n_trunc] results_dict = {} results_dict["config"] = config diff --git a/src/cryo_challenge/_ploting/plotting_utils.py b/src/cryo_challenge/_ploting/plotting_utils.py index 04d5ca9..681ab01 100644 --- a/src/cryo_challenge/_ploting/plotting_utils.py +++ b/src/cryo_challenge/_ploting/plotting_utils.py @@ -1,6 +1,7 @@ import numpy as np + def res_at_fsc_threshold(fscs, threshold=0.5): res_fsc_half = np.argmin(fscs > threshold, axis=-1) - fraction_nyquist = 0.5*res_fsc_half / fscs.shape[-1] - return res_fsc_half, fraction_nyquist \ No newline at end of file + fraction_nyquist = 0.5 * res_fsc_half / fscs.shape[-1] + return res_fsc_half, fraction_nyquist diff --git a/src/cryo_challenge/_preprocessing/bfactor_normalize.py b/src/cryo_challenge/_preprocessing/bfactor_normalize.py new file mode 100644 index 0000000..2af5e98 --- /dev/null +++ b/src/cryo_challenge/_preprocessing/bfactor_normalize.py @@ -0,0 +1,36 @@ +import torch +from ..power_spectrum_utils import _centered_fftn, _centered_ifftn + + +def _compute_bfactor_scaling(b_factor, box_size, voxel_size): + x = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) + y = x.clone() + z = x.clone() + x, y, z = torch.meshgrid(x, y, z, indexing="ij") + + s2 = x**2 + y**2 + z**2 + bfactor_scaling_torch = torch.exp(-b_factor * s2 / 4) + + return bfactor_scaling_torch + + +def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=True): + if not in_place: + volumes = volumes.clone() + + b_factor_scaling = _compute_bfactor_scaling(bfactor, volumes.shape[-1], voxel_size) + + if len(volumes.shape) == 3: + volumes = _centered_fftn(volumes, dim=(0, 1, 2)) + volumes = volumes * b_factor_scaling + volumes = _centered_ifftn(volumes, dim=(0, 1, 2)).real + + elif len(volumes.shape) == 4: + volumes = _centered_fftn(volumes, dim=(1, 2, 3)) + volumes = volumes * b_factor_scaling[None, ...] + volumes = _centered_ifftn(volumes, dim=(1, 2, 3)).real + + else: + raise ValueError("Input volumes must have 3 or 4 dimensions.") + + return volumes diff --git a/src/cryo_challenge/_preprocessing/normalize.py b/src/cryo_challenge/_preprocessing/normalize.py deleted file mode 100644 index 73449bf..0000000 --- a/src/cryo_challenge/_preprocessing/normalize.py +++ /dev/null @@ -1,22 +0,0 @@ -''' -TODO: Need to implement this properly - -def normalize_mean_std(vols_flat): - """ - vols_flat.shape is (n_vols, n_pix**3) - vols_flat is a torch tensor - """ - return (vols_flat - vols_flat.mean(-1, keepdims=True)) / vols_flat.std( - -1, keepdims=True - ) - - -def normalize_median_std(vols_flat): - """ - vols_flat.shape is (n_vols, n_pix**3) - vols_flat is a torch tensor - """ - return (vols_flat - vols_flat.median(-1, keepdims=True).values) / vols_flat.std( - -1, keepdims=True - ) -''' diff --git a/src/cryo_challenge/data/__init__.py b/src/cryo_challenge/data/__init__.py index 8b4655c..fb27bbd 100644 --- a/src/cryo_challenge/data/__init__.py +++ b/src/cryo_challenge/data/__init__.py @@ -1,6 +1,18 @@ -from ._validation.config_validators import validate_input_config_disttodist as validate_input_config_disttodist -from ._validation.config_validators import validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl -from cryo_challenge.data._validation.output_validators import DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator -from cryo_challenge.data._validation.output_validators import MetricDistToDistValidator as MetricDistToDistValidator -from cryo_challenge.data._validation.output_validators import ReplicateValidatorEMD as ReplicateValidatorEMD -from cryo_challenge.data._validation.output_validators import ReplicateValidatorKL as ReplicateValidatorKL +from ._validation.config_validators import ( + validate_input_config_disttodist as validate_input_config_disttodist, +) +from ._validation.config_validators import ( + validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl, +) +from cryo_challenge.data._validation.output_validators import ( + DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator, +) +from cryo_challenge.data._validation.output_validators import ( + MetricDistToDistValidator as MetricDistToDistValidator, +) +from cryo_challenge.data._validation.output_validators import ( + ReplicateValidatorEMD as ReplicateValidatorEMD, +) +from cryo_challenge.data._validation.output_validators import ( + ReplicateValidatorKL as ReplicateValidatorKL, +) diff --git a/src/cryo_challenge/data/_io/svd_io_utils.py b/src/cryo_challenge/data/_io/svd_io_utils.py index f194c14..2a4d954 100644 --- a/src/cryo_challenge/data/_io/svd_io_utils.py +++ b/src/cryo_challenge/data/_io/svd_io_utils.py @@ -106,14 +106,16 @@ def load_ref_vols(box_size_ds: int, path_to_volumes: str, dtype=torch.float32): # Reshape volumes to correct size if volumes.dim() == 2: - box_size = int(round((float(volumes.shape[-1]) ** (1. / 3.)))) + box_size = int(round((float(volumes.shape[-1]) ** (1.0 / 3.0)))) volumes = torch.reshape(volumes, (-1, box_size, box_size, box_size)) elif volumes.dim() == 4: pass else: - raise ValueError(f"The shape of the volumes stored in {path_to_volumes} have the unexpected shape " - f"{torch.shape}. Please, review the file and regenerate it so that volumes stored hasve the " - f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size).") + raise ValueError( + f"The shape of the volumes stored in {path_to_volumes} have the unexpected shape " + f"{torch.shape}. Please, review the file and regenerate it so that volumes stored hasve the " + f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size)." + ) volumes_ds = torch.empty( (volumes.shape[0], box_size_ds, box_size_ds, box_size_ds), dtype=dtype diff --git a/src/cryo_challenge/data/_validation/output_validators.py b/src/cryo_challenge/data/_validation/output_validators.py index 35d9791..39cbb67 100644 --- a/src/cryo_challenge/data/_validation/output_validators.py +++ b/src/cryo_challenge/data/_validation/output_validators.py @@ -13,7 +13,7 @@ @dataclass_json @dataclass class MapToMapResultsValidator: - ''' + """ Validate the output dictionary of the map-to-map distance matrix computation. config: dict, input config dictionary. @@ -22,7 +22,8 @@ class MapToMapResultsValidator: l2: dict, L2 results. bioem: dict, BioEM results. fsc: dict, FSC results. - ''' + """ + config: dict user_submitted_populations: torch.Tensor corr: Optional[dict] = None @@ -49,7 +50,7 @@ class ReplicateValidatorEMD: Validate the output dictionary of one EMD in the the distribution-to-distribution pipeline. q_opt: List[float], optimal user submitted distribution, which sums to 1. - EMD_opt: float, EMD between the ground truth distribution (p) and the (optimized) user submitted distribution (q_opt). + EMD_opt: float, EMD between the ground truth distribution (p) and the (optimized) user submitted distribution (q_opt). The transport plan is a joint distribution, such that: summing over the rows gives the (optimized) user submitted distribution, and summing over the columns gives the ground truth distribution. transport_plan_opt: List[List[float]], transport plan between the ground truth distribution (p, rows) and the (optimized) user submitted distribution (q_opt, columns). @@ -61,6 +62,7 @@ class ReplicateValidatorEMD: The transport plan is a joint distribution, such that: summing over the rows gives the user submitted distribution, and summing over the columns gives the ground truth distribution. """ + q_opt: List[float] EMD_opt: float transport_plan_opt: List[List[float]] @@ -87,8 +89,9 @@ class ReplicateValidatorKL: iter_stop: int, number of iterations until convergence. eps_stop: float, stopping criterion. klpq_submitted: float, KL divergence between the ground truth distribution (p) and the user submitted distribution (q). - klqp_submitted: float, KL divergence between the user submitted distribution (q) and the ground truth distribution (p). + klqp_submitted: float, KL divergence between the user submitted distribution (q) and the ground truth distribution (p). """ + q_opt: List[float] klpq_opt: float klqp_opt: float @@ -106,11 +109,12 @@ def __post_init__(self): @dataclass_json @dataclass class MetricDistToDistValidator: - ''' + """ Validate the output dictionary of one map to map metric in the the distribution-to-distribution pipeline. replicates: dict, dictionary of replicates. - ''' + """ + replicates: dict def validate_replicates(self, n_replicates): @@ -126,7 +130,7 @@ def validate_replicates(self, n_replicates): @dataclass_json @dataclass class DistributionToDistributionResultsValidator: - ''' + """ Validate the output dictionary of the distribution-to-distribution pipeline. config: dict, input config dictionary. @@ -136,7 +140,8 @@ class DistributionToDistributionResultsValidator: bioem: dict, BioEM distance results. l2: dict, L2 distance results. corr: dict, correlation distance results. - ''' + """ + config: dict user_submitted_populations: torch.Tensor id: str diff --git a/src/cryo_challenge/power_spectrum_utils.py b/src/cryo_challenge/power_spectrum_utils.py new file mode 100644 index 0000000..0a6338f --- /dev/null +++ b/src/cryo_challenge/power_spectrum_utils.py @@ -0,0 +1,71 @@ +import torch + + +def _cart2sph(x, y, z): + """ + Converts a grid in cartesian coordinates to spherical coordinates. + + Parameters + ---------- + x: torch.tensor + x-coordinate of the grid. + y: torch.tensor + y-coordinate of the grid. + z: torch.tensor + """ + hxy = torch.hypot(x, y) + r = torch.hypot(hxy, z) + el = torch.atan2(z, hxy) + az = torch.atan2(y, x) + return az, el, r + + +def _grid_3d(n, dtype=torch.float32): + start = -n // 2 + 1 + end = n // 2 + + if n % 2 == 0: + start -= 1 / 2 + end -= 1 / 2 + + grid = torch.linspace(start, end, n, dtype=dtype) + z, x, y = torch.meshgrid(grid, grid, grid, indexing="ij") + + phi, theta, r = _cart2sph(x, y, z) + + theta = torch.pi / 2 - theta + + return {"x": x, "y": y, "z": z, "phi": phi, "theta": theta, "r": r} + + +def _centered_fftn(x, dim=None): + x = torch.fft.fftn(x, dim=dim) + x = torch.fft.fftshift(x, dim=dim) + return x + + +def _centered_ifftn(x, dim=None): + x = torch.fft.fftshift(x, dim=dim) + x = torch.fft.ifftn(x, dim=dim) + return x + + +def _compute_power_spectrum_shell(index, volume, radii, shell_width=0.5): + inner_diameter = shell_width + index + outer_diameter = shell_width + (index + 1) + mask = (radii > inner_diameter) & (radii < outer_diameter) + return torch.sum(mask * volume) / torch.sum(mask) + + +def compute_power_spectrum(volume, shell_width=0.5): + L = volume.shape[0] + dtype = torch.float32 + radii = _grid_3d(L, dtype=dtype)["r"] + + # Compute centered Fourier transforms. + vol_fft = torch.abs(_centered_fftn(volume)) ** 2 + + power_spectrum = torch.vmap( + _compute_power_spectrum_shell, in_dims=(0, None, None, None) + )(torch.arange(0, L // 2), vol_fft, radii, shell_width) + return power_spectrum diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json b/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json index b8318b9..3355706 100644 --- a/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json +++ b/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json @@ -9,7 +9,7 @@ "0": { "name": "raw_submission_in_testdata", "align": 1, - "flavor_name": "test flavor", + "flavor_name": "test flavor", "box_size": 244, "pixel_size": 2.146, "path": "tests/data/unprocessed_dataset_2_submissions/submission_x", diff --git a/tests/test_distribution_to_distribution.py b/tests/test_distribution_to_distribution.py index d9c340b..a4cfb79 100644 --- a/tests/test_distribution_to_distribution.py +++ b/tests/test_distribution_to_distribution.py @@ -2,6 +2,8 @@ from cryo_challenge._commands import run_distribution2distribution_pipeline -def test_run_distribution2distribution_pipeline(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_distribution_to_distribution.yaml'}) - run_distribution2distribution_pipeline.main(args) \ No newline at end of file +def test_run_distribution2distribution_pipeline(): + args = OmegaConf.create( + {"config": "tests/config_files/test_config_distribution_to_distribution.yaml"} + ) + run_distribution2distribution_pipeline.main(args) diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index c782a8c..e31f29f 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -2,6 +2,8 @@ from cryo_challenge._commands import run_map2map_pipeline -def test_run_map2map_pipeline(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_map_to_map.yaml'}) - run_map2map_pipeline.main(args) \ No newline at end of file +def test_run_map2map_pipeline(): + args = OmegaConf.create( + {"config": "tests/config_files/test_config_map_to_map.yaml"} + ) + run_map2map_pipeline.main(args) diff --git a/tests/test_power_spectrum_and_bfactor.py b/tests/test_power_spectrum_and_bfactor.py new file mode 100644 index 0000000..8496ba2 --- /dev/null +++ b/tests/test_power_spectrum_and_bfactor.py @@ -0,0 +1,72 @@ +import torch +from cryo_challenge.power_spectrum_utils import _centered_ifftn, compute_power_spectrum +from cryo_challenge._preprocessing.bfactor_normalize import ( + _compute_bfactor_scaling, + bfactor_normalize_volumes, +) + + +def test_compute_power_spectrum(): + box_size = 224 + volume_shape = (box_size, box_size, box_size) + voxel_size = 1.073 * 2 + + freq = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) + x = freq.clone() + y = freq.clone() + z = freq.clone() + x, y, z = torch.meshgrid(x, y, z, indexing="ij") + + s2 = x**2 + y**2 + z**2 + + b_factor = 170 + + gaussian_volume = torch.exp(-b_factor / 4 * s2).reshape(volume_shape) + gaussian_volume = _centered_ifftn(gaussian_volume) + + power_spectrum = compute_power_spectrum(gaussian_volume) + power_spectrum_slice = ( + torch.abs(torch.fft.fftn(gaussian_volume)[: box_size // 2, 0, 0]) ** 2 + ) + + mean_squared_error = torch.mean((power_spectrum - power_spectrum_slice) ** 2) + + assert mean_squared_error < 1e-3 + + return + + +def test_bfactor_normalize_volumes(): + box_size = 128 + volume_shape = (box_size, box_size, box_size) + voxel_size = 1.5 + + freq = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) + x = freq.clone() + y = freq.clone() + z = freq.clone() + x, y, z = torch.meshgrid(x, y, z, indexing="ij") + + s2 = x**2 + y**2 + z**2 + + oscillatory_volume = torch.sin(300 * s2).reshape(volume_shape) + oscillatory_volume = _centered_ifftn(oscillatory_volume) + bfactor_scaling_vol = _compute_bfactor_scaling(170, box_size, voxel_size) + + norm_oscillatory_vol = bfactor_normalize_volumes( + oscillatory_volume, 170, voxel_size, in_place=False + ) + + ps_osci = torch.fft.fftn(oscillatory_volume, dim=(-3, -2, -1), norm="backward")[ + : box_size // 2, 0, 0 + ] + ps_norm_osci = torch.fft.fftn( + norm_oscillatory_vol, dim=(-3, -2, -1), norm="backward" + )[: box_size // 2, 0, 0] + ps_bfactor_scaling = torch.fft.fftshift(bfactor_scaling_vol)[: box_size // 2, 0, 0] + + ps_osci = torch.abs(ps_osci) ** 2 + ps_norm_osci = torch.abs(ps_norm_osci) ** 2 + ps_bfactor_scaling = torch.abs(ps_bfactor_scaling) ** 2 + + assert torch.allclose(ps_norm_osci, ps_osci * ps_bfactor_scaling) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index cbf54e4..31db34e 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -2,6 +2,6 @@ from cryo_challenge._commands import run_preprocessing -def test_run_preprocessing(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_preproc.yaml'}) - run_preprocessing.main(args) \ No newline at end of file +def test_run_preprocessing(): + args = OmegaConf.create({"config": "tests/config_files/test_config_preproc.yaml"}) + run_preprocessing.main(args) diff --git a/tests/test_svd.py b/tests/test_svd.py index 568370e..ea166ea 100644 --- a/tests/test_svd.py +++ b/tests/test_svd.py @@ -2,6 +2,6 @@ from cryo_challenge._commands import run_svd -def test_run_preprocessing(): - args = OmegaConf.create({'config': 'tests/config_files/test_config_svd.yaml'}) - run_svd.main(args) \ No newline at end of file +def test_run_preprocessing(): + args = OmegaConf.create({"config": "tests/config_files/test_config_svd.yaml"}) + run_svd.main(args) diff --git a/tutorials/1_tutorial_preprocessing.ipynb b/tutorials/1_tutorial_preprocessing.ipynb index cc6a459..84db8c9 100644 --- a/tutorials/1_tutorial_preprocessing.ipynb +++ b/tutorials/1_tutorial_preprocessing.ipynb @@ -203,7 +203,7 @@ "# Select path to Config file\n", "# An example of this file is available in the path ../config_files/config_preproc.yaml\n", "config_preproc_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_preproc_path.filter_pattern = '*.yaml'\n", + "config_preproc_path.filter_pattern = \"*.yaml\"\n", "display(config_preproc_path)" ] }, @@ -226,7 +226,7 @@ "if os.path.isabs(output_path):\n", " full_output_path = output_path\n", "else:\n", - " full_output_path = os.path.join(os.getcwd(), '..', output_path)" + " full_output_path = os.path.join(os.getcwd(), \"..\", output_path)" ] }, { diff --git a/tutorials/2_tutorial_svd.ipynb b/tutorials/2_tutorial_svd.ipynb index b41bfba..fe8f432 100644 --- a/tutorials/2_tutorial_svd.ipynb +++ b/tutorials/2_tutorial_svd.ipynb @@ -62,7 +62,7 @@ "# Select path to SVD config file\n", "# An example of this file is available in the path ../config_files/config_svd.yaml\n", "config_svd_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_svd_path.filter_pattern = '*.yaml'\n", + "config_svd_path.filter_pattern = \"*.yaml\"\n", "display(config_svd_path)" ] }, @@ -125,7 +125,7 @@ "source": [ "# Select path to SVD results\n", "svd_results_path = FileChooser(os.path.expanduser(\"~\"))\n", - "svd_results_path.filter_pattern = '*.pt'\n", + "svd_results_path.filter_pattern = \"*.pt\"\n", "display(svd_results_path)" ] }, @@ -316,7 +316,7 @@ "source": [ "# Select path to SVD results\n", "svd_all_vs_all_results_path = FileChooser(os.path.expanduser(\"~\"))\n", - "svd_all_vs_all_results_path.filter_pattern = '*.pt'\n", + "svd_all_vs_all_results_path.filter_pattern = \"*.pt\"\n", "display(svd_all_vs_all_results_path)" ] }, diff --git a/tutorials/3_tutorial_map2map.ipynb b/tutorials/3_tutorial_map2map.ipynb index a3701ff..0497578 100644 --- a/tutorials/3_tutorial_map2map.ipynb +++ b/tutorials/3_tutorial_map2map.ipynb @@ -23,15 +23,8 @@ "\n", "from cryo_challenge.data._validation.config_validators import (\n", " validate_input_config_mtm,\n", - " validate_config_mtm_data, \n", - " validate_config_mtm_data_submission, \n", - " validate_config_mtm_data_ground_truth, \n", - " validate_config_mtm_data_mask, \n", - " validate_config_mtm_analysis, \n", - " validate_config_mtm_analysis_normalize, \n", - " )\n", - "from cryo_challenge.data._validation.output_validators import MapToMapResultsValidator\n", - "from cryo_challenge.data._validation.config_validators import validate_maptomap_result" + ")\n", + "from cryo_challenge.data._validation.output_validators import MapToMapResultsValidator" ] }, { @@ -80,7 +73,7 @@ "# Select path to Map to Map config file\n", "# An example of this file is available in the path ../config_files/config_map_to_map.yaml\n", "config_m2m_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_m2m_path.filter_pattern = '*.yaml'\n", + "config_m2m_path.filter_pattern = \"*.yaml\"\n", "display(config_m2m_path)" ] }, @@ -341,8 +334,8 @@ "metadata": {}, "outputs": [], "source": [ - "with open(os.path.join('../',config[\"output\"]), \"rb\") as f:\n", - " results_dict = pickle.load(f)\n" + "with open(os.path.join(\"../\", config[\"output\"]), \"rb\") as f:\n", + " results_dict = pickle.load(f)" ] }, { diff --git a/tutorials/4_tutorial_distribution2distribution.ipynb b/tutorials/4_tutorial_distribution2distribution.ipynb index 07dc1d9..2a6fc53 100644 --- a/tutorials/4_tutorial_distribution2distribution.ipynb +++ b/tutorials/4_tutorial_distribution2distribution.ipynb @@ -30,7 +30,7 @@ "import pickle\n", "from ipyfilechooser import FileChooser\n", "\n", - "from cryo_challenge.data import validate_input_config_disttodist, validate_config_dtd_optimal_q_kl\n", + "from cryo_challenge.data import validate_input_config_disttodist\n", "from cryo_challenge.data import DistributionToDistributionResultsValidator" ] }, @@ -65,7 +65,7 @@ "# Select path to Distribution to Distribution config file\n", "# An example of this file is available in the path ../config_files/config_distribution_to_distribution.yaml\n", "config_d2d_path = FileChooser(os.path.expanduser(\"~\"))\n", - "config_d2d_path.filter_pattern = '*.yaml'\n", + "config_d2d_path.filter_pattern = \"*.yaml\"\n", "display(config_d2d_path)" ] }, @@ -199,8 +199,8 @@ "metadata": {}, "outputs": [], "source": [ - "with open(os.path.join('../',config[\"output_fname\"]), \"rb\") as f:\n", - " results_dict = pickle.load(f)\n" + "with open(os.path.join(\"../\", config[\"output_fname\"]), \"rb\") as f:\n", + " results_dict = pickle.load(f)" ] }, { @@ -286,7 +286,6 @@ } ], "source": [ - "from cryo_challenge.data import MetricDistToDistValidator\n", "MetricDistToDistValidator?" ] }, @@ -302,9 +301,7 @@ "execution_count": 30, "metadata": {}, "outputs": [], - "source": [ - "from cryo_challenge.data import ReplicateValidatorEMD, ReplicateValidatorKL" - ] + "source": [] }, { "cell_type": "code", diff --git a/tutorials/5_tutorial_plotting.ipynb b/tutorials/5_tutorial_plotting.ipynb index ed8a924..e2648c8 100644 --- a/tutorials/5_tutorial_plotting.ipynb +++ b/tutorials/5_tutorial_plotting.ipynb @@ -23,7 +23,9 @@ "metadata": {}, "outputs": [], "source": [ - "from cryo_challenge._distribution_to_distribution.distribution_to_distribution import sort_by_transport\n", + "from cryo_challenge._distribution_to_distribution.distribution_to_distribution import (\n", + " sort_by_transport,\n", + ")\n", "from cryo_challenge._ploting.plotting_utils import res_at_fsc_threshold\n", "\n", "from dataclasses import dataclass\n", @@ -102,6 +104,7 @@ " map2map_results: List[str]\n", " dist2dist_results: Dict[str, Dict[str, List[str]]]\n", "\n", + "\n", "with open(path_to_config, \"r\") as file:\n", " config = yaml.safe_load(file)\n", "config = PlottingConfig.from_dict(config)\n", @@ -115,7 +118,7 @@ "outputs": [], "source": [ "metadata_df = pd.read_csv(config.gt_metadata)\n", - "metadata_df.sort_values('pc1', inplace=True)\n", + "metadata_df.sort_values(\"pc1\", inplace=True)\n", "gt_ordering = metadata_df.index.tolist()" ] }, @@ -136,12 +139,14 @@ " data_d = {}\n", " for fname in fnames:\n", " if fname not in data_d.keys():\n", - " with open(fname, 'rb') as f:\n", + " with open(fname, \"rb\") as f:\n", " data = pickle.load(f)\n", - " anonymous_label = data[map2map_distance]['user_submission_label']\n", + " anonymous_label = data[map2map_distance][\"user_submission_label\"]\n", " data_d[anonymous_label] = data\n", " return data_d\n", - "data_d = get_fsc_distances(config.map2map_results, 'fsc')" + "\n", + "\n", + "data_d = get_fsc_distances(config.map2map_results, \"fsc\")" ] }, { @@ -162,22 +167,24 @@ ], "source": [ "def plot_fsc_distances(data_d, gt_ordering):\n", - "\n", " smaller_fontsize = 20\n", " larger_fontsize = 30\n", " n_plts = 12\n", " vmin, vmax = np.inf, -np.inf\n", "\n", - " fig, axis = plt.subplots(3,n_plts//3,\n", - " figsize=(40,20),\n", - " # dpi=100,\n", - " )\n", - " fig.suptitle(r'$d_{FSC}$', y=0.95, fontsize=larger_fontsize)\n", + " fig, axis = plt.subplots(\n", + " 3,\n", + " n_plts // 3,\n", + " figsize=(40, 20),\n", + " # dpi=100,\n", + " )\n", + " fig.suptitle(r\"$d_{FSC}$\", y=0.95, fontsize=larger_fontsize)\n", "\n", " for idx, (anonymous_label, data) in enumerate(data_d.items()):\n", - " map2map_dist_matrix = data['fsc']['cost_matrix'].iloc[gt_ordering].values\n", - " sorted_map2map_dist_matrix, indices, transport = sort_by_transport(map2map_dist_matrix)\n", - "\n", + " map2map_dist_matrix = data[\"fsc\"][\"cost_matrix\"].iloc[gt_ordering].values\n", + " sorted_map2map_dist_matrix, indices, transport = sort_by_transport(\n", + " map2map_dist_matrix\n", + " )\n", "\n", " ncols = 4\n", " if map2map_dist_matrix.min() < vmin:\n", @@ -185,14 +192,24 @@ " if map2map_dist_matrix.max() > vmax:\n", " vmax = map2map_dist_matrix.max()\n", "\n", - " ax = axis[idx//ncols,idx%ncols].imshow(sorted_map2map_dist_matrix, aspect='auto', cmap='Blues_r', vmin=vmin, vmax=vmax)\n", - "\n", - "\n", - " axis[idx//ncols,idx%ncols].tick_params(axis='both', labelsize=smaller_fontsize)\n", + " ax = axis[idx // ncols, idx % ncols].imshow(\n", + " sorted_map2map_dist_matrix,\n", + " aspect=\"auto\",\n", + " cmap=\"Blues_r\",\n", + " vmin=vmin,\n", + " vmax=vmax,\n", + " )\n", + "\n", + " axis[idx // ncols, idx % ncols].tick_params(\n", + " axis=\"both\", labelsize=smaller_fontsize\n", + " )\n", " cbar = fig.colorbar(ax)\n", " cbar.ax.tick_params(labelsize=smaller_fontsize)\n", " plot_panel_label = anonymous_label\n", - " axis[idx//ncols,idx%ncols].set_title(plot_panel_label, fontsize=smaller_fontsize)\n", + " axis[idx // ncols, idx % ncols].set_title(\n", + " plot_panel_label, fontsize=smaller_fontsize\n", + " )\n", + "\n", "\n", "plot_fsc_distances(data_d, gt_ordering)" ] @@ -214,12 +231,13 @@ " data_d = {}\n", " for fname in fnames:\n", " if fname not in data_d.keys():\n", - " with open(fname, 'rb') as f:\n", + " with open(fname, \"rb\") as f:\n", " data = pickle.load(f)\n", - " anonymous_label = data['fsc']['user_submission_label']\n", - " data_d[anonymous_label] = data['fsc']['computed_assets']['fsc_matrix']\n", + " anonymous_label = data[\"fsc\"][\"user_submission_label\"]\n", + " data_d[anonymous_label] = data[\"fsc\"][\"computed_assets\"][\"fsc_matrix\"]\n", " return data_d\n", "\n", + "\n", "fscs_sorted_d = get_full_fsc_curve(config.map2map_results)" ] }, @@ -229,9 +247,11 @@ "metadata": {}, "outputs": [], "source": [ - "res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fscs_sorted_d['Cookie Dough'], threshold=0.5)\n", - "n_fourier_bins = fscs_sorted_d['Cookie Dough'].shape[-1]\n", - "units_Angstroms = 2 * 2.146 / (np.arange(1,n_fourier_bins+1) / n_fourier_bins)" + "res_fsc_half, fraction_nyquist = res_at_fsc_threshold(\n", + " fscs_sorted_d[\"Cookie Dough\"], threshold=0.5\n", + ")\n", + "n_fourier_bins = fscs_sorted_d[\"Cookie Dough\"].shape[-1]\n", + "units_Angstroms = 2 * 2.146 / (np.arange(1, n_fourier_bins + 1) / n_fourier_bins)" ] }, { @@ -261,7 +281,7 @@ } ], "source": [ - "plt.imshow(units_Angstroms[res_fsc_half][gt_ordering], aspect='auto', cmap='Blues_r')\n", + "plt.imshow(units_Angstroms[res_fsc_half][gt_ordering], aspect=\"auto\", cmap=\"Blues_r\")\n", "plt.colorbar()" ] }, @@ -283,17 +303,18 @@ ], "source": [ "def plot_res_at_fsc_threshold_distances(fscs_sorted_d, gt_ordering, overwrite_dict={}):\n", - "\n", " smaller_fontsize = 20\n", " larger_fontsize = 30\n", " n_plts = 12\n", " vmin, vmax = np.inf, -np.inf\n", "\n", - " fig, axis = plt.subplots(3,n_plts//3,\n", - " figsize=(40,20),\n", - " # dpi=100,\n", - " )\n", - " fig.suptitle(r'Resolution $(\\AA)$ at $FSC=0.5$', y=0.95, fontsize=larger_fontsize)\n", + " fig, axis = plt.subplots(\n", + " 3,\n", + " n_plts // 3,\n", + " figsize=(40, 20),\n", + " # dpi=100,\n", + " )\n", + " fig.suptitle(r\"Resolution $(\\AA)$ at $FSC=0.5$\", y=0.95, fontsize=larger_fontsize)\n", "\n", " for idx, (anonymous_label, fscs) in enumerate(fscs_sorted_d.items()):\n", " # map2map_dist_matrix = data.iloc[gt_ordering].values\n", @@ -302,27 +323,38 @@ "\n", " sorted_map2map_dist_matrix, _, _ = sort_by_transport(map2map_dist_matrix)\n", "\n", - "\n", " ncols = 4\n", " if map2map_dist_matrix.min() < vmin:\n", " vmin = map2map_dist_matrix.min()\n", " if map2map_dist_matrix.max() > vmax:\n", " vmax = map2map_dist_matrix.max()\n", - " if 'vmax' in overwrite_dict.keys():\n", - " vmax = overwrite_dict['vmax']\n", - " if 'vmin' in overwrite_dict.keys():\n", - " vmin = overwrite_dict['vmin']\n", - "\n", - " ax = axis[idx//ncols,idx%ncols].imshow(sorted_map2map_dist_matrix, aspect='auto', cmap='Blues_r', vmin=vmin, vmax=vmax)\n", - "\n", - "\n", - " axis[idx//ncols,idx%ncols].tick_params(axis='both', labelsize=smaller_fontsize)\n", + " if \"vmax\" in overwrite_dict.keys():\n", + " vmax = overwrite_dict[\"vmax\"]\n", + " if \"vmin\" in overwrite_dict.keys():\n", + " vmin = overwrite_dict[\"vmin\"]\n", + "\n", + " ax = axis[idx // ncols, idx % ncols].imshow(\n", + " sorted_map2map_dist_matrix,\n", + " aspect=\"auto\",\n", + " cmap=\"Blues_r\",\n", + " vmin=vmin,\n", + " vmax=vmax,\n", + " )\n", + "\n", + " axis[idx // ncols, idx % ncols].tick_params(\n", + " axis=\"both\", labelsize=smaller_fontsize\n", + " )\n", " cbar = fig.colorbar(ax)\n", " cbar.ax.tick_params(labelsize=smaller_fontsize)\n", " plot_panel_label = anonymous_label\n", - " axis[idx//ncols,idx%ncols].set_title(plot_panel_label, fontsize=smaller_fontsize)\n", + " axis[idx // ncols, idx % ncols].set_title(\n", + " plot_panel_label, fontsize=smaller_fontsize\n", + " )\n", + "\n", "\n", - "plot_res_at_fsc_threshold_distances(fscs_sorted_d, gt_ordering, overwrite_dict={'vmax': 31})" + "plot_res_at_fsc_threshold_distances(\n", + " fscs_sorted_d, gt_ordering, overwrite_dict={\"vmax\": 31}\n", + ")" ] }, { @@ -338,9 +370,9 @@ "metadata": {}, "outputs": [], "source": [ - "fname = config.dist2dist_results['prob_submitted_plot']['pkl_fnames'][0]\n", + "fname = config.dist2dist_results[\"prob_submitted_plot\"][\"pkl_fnames\"][0]\n", "\n", - "with open(fname, 'rb') as f:\n", + "with open(fname, \"rb\") as f:\n", " data = pickle.load(f)" ] }, @@ -354,13 +386,16 @@ " data_d = {}\n", " for fname in fnames:\n", " if fname not in data_d.keys():\n", - " with open(fname, 'rb') as f:\n", + " with open(fname, \"rb\") as f:\n", " data = pickle.load(f)\n", - " anonymous_label = data['id']\n", + " anonymous_label = data[\"id\"]\n", " data_d[anonymous_label] = data\n", " return data_d\n", "\n", - "dist2dist_results_d = get_dist2dist_results(config.dist2dist_results['prob_submitted_plot']['pkl_fnames'])" + "\n", + "dist2dist_results_d = get_dist2dist_results(\n", + " config.dist2dist_results[\"prob_submitted_plot\"][\"pkl_fnames\"]\n", + ")" ] }, { @@ -381,46 +416,58 @@ ], "source": [ "window_size = 15\n", - "nrows, ncols = 3,4\n", + "nrows, ncols = 3, 4\n", "suptitle = f\"Submitted populations vs optimal populations \\n d_FSC (no rank) | n_replicates={data['config']['n_replicates']} | window_size={window_size} | n_pool_microstate={data['config']['n_pool_microstate']}\"\n", "\n", + "\n", "def plot_q_opt_distances(dist2dist_results_d, suptitle, nrows, ncols):\n", + " fig, axes = plt.subplots(nrows, ncols, figsize=(40, 25))\n", "\n", - " fig, axes = plt.subplots(nrows, ncols, figsize=(40,25))\n", - " \n", - " fig.suptitle(\n", - " suptitle,\n", - " fontsize=30,\n", - " y=0.95)\n", + " fig.suptitle(suptitle, fontsize=30, y=0.95)\n", " alpha = 0.05\n", "\n", - " for idx_fname, (_,data) in enumerate(dist2dist_results_d.items()):\n", - " \n", + " for idx_fname, (_, data) in enumerate(dist2dist_results_d.items()):\n", + " axes[idx_fname // ncols, idx_fname % ncols].plot(\n", + " data[\"user_submitted_populations\"], color=\"black\", label=\"submited\"\n", + " )\n", + " axes[idx_fname // ncols, idx_fname % ncols].set_title(data[\"id\"], fontsize=30)\n", "\n", - " axes[idx_fname//ncols, idx_fname%ncols].plot(data['user_submitted_populations'], color='black', label='submited')\n", - " axes[idx_fname//ncols, idx_fname%ncols].set_title(data['id'], fontsize=30)\n", - "\n", - " def window_q(q_opt,window_size):\n", - " running_avg = np.convolve(q_opt, np.ones(window_size)/window_size, mode='same')\n", + " def window_q(q_opt, window_size):\n", + " running_avg = np.convolve(\n", + " q_opt, np.ones(window_size) / window_size, mode=\"same\"\n", + " )\n", " return running_avg\n", - " \n", - " for replicate_idx in range(data['config']['n_replicates']):\n", - " \n", - " if replicate_idx == 0: \n", - " label_d = {'EMD': 'EMD', 'KL': 'KL', 'KL_raw': 'Unwindowed', 'EMD_raw': 'Unwindowed'}\n", - " else:\n", - " label_d = {'EMD': None, 'KL': None, 'KL_raw': None, 'EMD_raw': None}\n", - " axes[idx_fname//ncols, idx_fname%ncols].plot(window_q(data['fsc']['replicates'][replicate_idx]['EMD']['q_opt'],window_size), color='blue', alpha=alpha, label=label_d['EMD'])\n", - "\n", - " axes[idx_fname//ncols, idx_fname%ncols].set_xlabel('Submission index')\n", - " axes[idx_fname//ncols, idx_fname%ncols].set_ylabel('Population')\n", "\n", - " legend = axes[idx_fname//ncols, idx_fname%ncols].legend()\n", + " for replicate_idx in range(data[\"config\"][\"n_replicates\"]):\n", + " if replicate_idx == 0:\n", + " label_d = {\n", + " \"EMD\": \"EMD\",\n", + " \"KL\": \"KL\",\n", + " \"KL_raw\": \"Unwindowed\",\n", + " \"EMD_raw\": \"Unwindowed\",\n", + " }\n", + " else:\n", + " label_d = {\"EMD\": None, \"KL\": None, \"KL_raw\": None, \"EMD_raw\": None}\n", + " axes[idx_fname // ncols, idx_fname % ncols].plot(\n", + " window_q(\n", + " data[\"fsc\"][\"replicates\"][replicate_idx][\"EMD\"][\"q_opt\"],\n", + " window_size,\n", + " ),\n", + " color=\"blue\",\n", + " alpha=alpha,\n", + " label=label_d[\"EMD\"],\n", + " )\n", + "\n", + " axes[idx_fname // ncols, idx_fname % ncols].set_xlabel(\"Submission index\")\n", + " axes[idx_fname // ncols, idx_fname % ncols].set_ylabel(\"Population\")\n", + "\n", + " legend = axes[idx_fname // ncols, idx_fname % ncols].legend()\n", " for line, text in zip(legend.get_lines(), legend.get_texts()):\n", " text.set_color(line.get_color())\n", " line.set_alpha(1)\n", "\n", - "plot_q_opt_distances(dist2dist_results_d, suptitle,nrows, ncols)" + "\n", + "plot_q_opt_distances(dist2dist_results_d, suptitle, nrows, ncols)" ] }, { @@ -437,7 +484,6 @@ "outputs": [], "source": [ "def wragle_pkl_to_dataframe(pkl_globs):\n", - "\n", " fnames = []\n", " for fname_glob in pkl_globs:\n", " fnames.extend(glob.glob(fname_glob))\n", @@ -446,28 +492,50 @@ "\n", " df_list = []\n", " n_replicates = 30\n", - " metric = 'fsc'\n", + " metric = \"fsc\"\n", "\n", " for fname in fnames:\n", - " with open(fname, 'rb') as f:\n", + " with open(fname, \"rb\") as f:\n", " data = pickle.load(f)\n", "\n", - " df_list.append(pd.DataFrame({\n", - " 'EMD_opt': [data[metric]['replicates'][i]['EMD']['EMD_opt'] for i in range(n_replicates)],\n", - " 'EMD_submitted': [data[metric]['replicates'][i]['EMD']['EMD_submitted'] for i in range(n_replicates)],\n", - " 'klpq_opt': [data[metric]['replicates'][i]['KL']['klpq_opt'] for i in range(n_replicates)],\n", - " 'klqp_opt': [data[metric]['replicates'][i]['KL']['klqp_opt'] for i in range(n_replicates)],\n", - " 'klpq_submitted': [data[metric]['replicates'][i]['KL']['klpq_submitted'] for i in range(n_replicates)], \n", - " 'klqp_submitted': [data[metric]['replicates'][i]['KL']['klqp_submitted'] for i in range(n_replicates)], \n", - " 'id': data['id'],\n", - " 'n_pool_microstate': data['config']['n_pool_microstate'],\n", - " }))\n", + " df_list.append(\n", + " pd.DataFrame(\n", + " {\n", + " \"EMD_opt\": [\n", + " data[metric][\"replicates\"][i][\"EMD\"][\"EMD_opt\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"EMD_submitted\": [\n", + " data[metric][\"replicates\"][i][\"EMD\"][\"EMD_submitted\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"klpq_opt\": [\n", + " data[metric][\"replicates\"][i][\"KL\"][\"klpq_opt\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"klqp_opt\": [\n", + " data[metric][\"replicates\"][i][\"KL\"][\"klqp_opt\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"klpq_submitted\": [\n", + " data[metric][\"replicates\"][i][\"KL\"][\"klpq_submitted\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"klqp_submitted\": [\n", + " data[metric][\"replicates\"][i][\"KL\"][\"klqp_submitted\"]\n", + " for i in range(n_replicates)\n", + " ],\n", + " \"id\": data[\"id\"],\n", + " \"n_pool_microstate\": data[\"config\"][\"n_pool_microstate\"],\n", + " }\n", + " )\n", + " )\n", "\n", " df = pd.concat(df_list)\n", - " df['EMD_opt_norm'] = df['EMD_opt'] / df['n_pool_microstate']\n", - " df['EMD_submitted_norm'] = df['EMD_submitted'] / df['n_pool_microstate']\n", + " df[\"EMD_opt_norm\"] = df[\"EMD_opt\"] / df[\"n_pool_microstate\"]\n", + " df[\"EMD_submitted_norm\"] = df[\"EMD_submitted\"] / df[\"n_pool_microstate\"]\n", "\n", - " return df\n" + " return df" ] }, { @@ -476,7 +544,7 @@ "metadata": {}, "outputs": [], "source": [ - "df = wragle_pkl_to_dataframe(config.dist2dist_results['emd_plot']['pkl_globs'])" + "df = wragle_pkl_to_dataframe(config.dist2dist_results[\"emd_plot\"][\"pkl_globs\"])" ] }, { @@ -498,50 +566,96 @@ "source": [ "def plot_EMD_vs_EMDopt(df, suptitle=None):\n", " alpha = 1\n", - " df_average = df.groupby(['n_pool_microstate','id']).mean().reset_index()\n", - " df_std = df_average.groupby(['id']).std().reset_index().filter(['EMD_opt_norm','EMD_submitted_norm', 'id']).rename(columns={'EMD_opt_norm':'EMD_opt_norm_std', 'EMD_submitted_norm':'EMD_submitted_norm_std'})\n", - " df_average = df.groupby(['id']).mean().reset_index()\n", - "\n", - " df_average_and_error = pd.merge(df_average, df_std, on='id')\n", + " df_average = df.groupby([\"n_pool_microstate\", \"id\"]).mean().reset_index()\n", + " df_std = (\n", + " df_average.groupby([\"id\"])\n", + " .std()\n", + " .reset_index()\n", + " .filter([\"EMD_opt_norm\", \"EMD_submitted_norm\", \"id\"])\n", + " .rename(\n", + " columns={\n", + " \"EMD_opt_norm\": \"EMD_opt_norm_std\",\n", + " \"EMD_submitted_norm\": \"EMD_submitted_norm_std\",\n", + " }\n", + " )\n", + " )\n", + " df_average = df.groupby([\"id\"]).mean().reset_index()\n", + "\n", + " df_average_and_error = pd.merge(df_average, df_std, on=\"id\")\n", "\n", " # Get unique ids\n", - " ids = df_average_and_error['id'].unique()\n", + " ids = df_average_and_error[\"id\"].unique()\n", "\n", " # Define marker styles\n", - " markers = ['o', 'v', '^', '<', '>', 's', 'p', '*', 'h', 'H', '+', 'x', 'D', 'd', '|', '_']\n", + " markers = [\n", + " \"o\",\n", + " \"v\",\n", + " \"^\",\n", + " \"<\",\n", + " \">\",\n", + " \"s\",\n", + " \"p\",\n", + " \"*\",\n", + " \"h\",\n", + " \"H\",\n", + " \"+\",\n", + " \"x\",\n", + " \"D\",\n", + " \"d\",\n", + " \"|\",\n", + " \"_\",\n", + " ]\n", " marker_size = 250\n", "\n", - " plt.style.use('seaborn-v0_8-poster')\n", - " plot_width, plot_height = 8,6\n", + " plt.style.use(\"seaborn-v0_8-poster\")\n", + " plot_width, plot_height = 8, 6\n", " plt.figure(figsize=(plot_width, plot_height), dpi=300)\n", "\n", " # Create a scatter plot for each id\n", " for idx, id_label in enumerate(ids):\n", - " df_average_id = df_average[df_average['id'] == id_label]\n", - " sns.scatterplot(x='EMD_submitted_norm', y='EMD_opt_norm', data=df_average_id, alpha=alpha, marker=markers[idx % len(markers)], label=id_label, s=marker_size)\n", - "\n", - " plt.errorbar(x=df_average_and_error['EMD_submitted_norm'], \n", - " y=df_average_and_error['EMD_opt_norm'], \n", - " xerr=df_average_and_error['EMD_submitted_norm_std'], \n", - " yerr=df_average_and_error['EMD_opt_norm_std'], \n", - " fmt='', alpha=0.05, linestyle='None', ecolor='k', elinewidth=2, capsize=5)\n", + " df_average_id = df_average[df_average[\"id\"] == id_label]\n", + " sns.scatterplot(\n", + " x=\"EMD_submitted_norm\",\n", + " y=\"EMD_opt_norm\",\n", + " data=df_average_id,\n", + " alpha=alpha,\n", + " marker=markers[idx % len(markers)],\n", + " label=id_label,\n", + " s=marker_size,\n", + " )\n", + "\n", + " plt.errorbar(\n", + " x=df_average_and_error[\"EMD_submitted_norm\"],\n", + " y=df_average_and_error[\"EMD_opt_norm\"],\n", + " xerr=df_average_and_error[\"EMD_submitted_norm_std\"],\n", + " yerr=df_average_and_error[\"EMD_opt_norm_std\"],\n", + " fmt=\"\",\n", + " alpha=0.05,\n", + " linestyle=\"None\",\n", + " ecolor=\"k\",\n", + " elinewidth=2,\n", + " capsize=5,\n", + " )\n", "\n", " plt.xlim(left=0.5)\n", " plt.ylim(bottom=0.5)\n", "\n", - " limits = [np.min([plt.xlim(), plt.ylim()]), # min of both axes\n", - " np.max([plt.xlim(), plt.ylim()])] # max of both axes\n", + " limits = [\n", + " np.min([plt.xlim(), plt.ylim()]), # min of both axes\n", + " np.max([plt.xlim(), plt.ylim()]),\n", + " ] # max of both axes\n", "\n", - " plt.plot(limits, limits, 'k-', alpha=0.75, zorder=0)\n", + " plt.plot(limits, limits, \"k-\", alpha=0.75, zorder=0)\n", " plt.xlim(limits)\n", " plt.ylim(limits)\n", - " legend = plt.legend(loc='upper left', fontsize=12)\n", + " legend = plt.legend(loc=\"upper left\", fontsize=12)\n", " for handle in legend.legend_handles:\n", " handle.set_alpha(1)\n", "\n", " plt.suptitle(suptitle)\n", "\n", - "suptitle = r'$d_{FSC}$ (no rank)'\n", + "\n", + "suptitle = r\"$d_{FSC}$ (no rank)\"\n", "plot_EMD_vs_EMDopt(df, suptitle=None)" ] } diff --git a/tutorials/README.md b/tutorials/README.md index d427ba5..842bdf7 100644 --- a/tutorials/README.md +++ b/tutorials/README.md @@ -25,4 +25,4 @@ This notebook walks through generating and analyzing (plots) the map to map dist - output: a `.pkl` file ### `5_tutorial_plotting.ipynb` -This notebook walks through parsing and analyzing (plots) the map to map and distribution to distribution results. \ No newline at end of file +This notebook walks through parsing and analyzing (plots) the map to map and distribution to distribution results. From dd09fdedd49dae51348410d3eb7439cc3263791a Mon Sep 17 00:00:00 2001 From: DSilva27 Date: Mon, 12 Aug 2024 13:40:06 -0400 Subject: [PATCH 02/30] fix issue with preprocessing test --- .../submission_x/submission_config.json | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json b/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json index 1cde04a..e7669ab 100644 --- a/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json +++ b/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json @@ -8,18 +8,13 @@ }, "0": { "name": "raw_submission_in_testdata", - "align": 1, "flavor_name": "test flavor", -<<<<<<< HEAD - "box_size": 244, - "pixel_size": 2.146, -======= + "path": "tests/data/unprocessed_dataset_2_submissions/submission_x", + "populations_file": "tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt", + "submission_version": "1.0", "box_size": 32, "pixel_size": 15.022, ->>>>>>> dev - "path": "tests/data/unprocessed_dataset_2_submissions/submission_x", "flip": 1, - "populations_file": "tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt", - "submission_version": "1.0" + "align": 1 } } From 5c1b901071199f13f37f150e8db13b5419321479 Mon Sep 17 00:00:00 2001 From: DSilva27 Date: Thu, 15 Aug 2024 10:45:30 -0400 Subject: [PATCH 03/30] add docstrings to tests and implemented functions --- .../_preprocessing/bfactor_normalize.py | 61 +++++++++++- src/cryo_challenge/power_spectrum_utils.py | 98 +++++++++++++++++-- tests/test_power_spectrum_and_bfactor.py | 13 +++ 3 files changed, 160 insertions(+), 12 deletions(-) diff --git a/src/cryo_challenge/_preprocessing/bfactor_normalize.py b/src/cryo_challenge/_preprocessing/bfactor_normalize.py index 2af5e98..37d876b 100644 --- a/src/cryo_challenge/_preprocessing/bfactor_normalize.py +++ b/src/cryo_challenge/_preprocessing/bfactor_normalize.py @@ -3,6 +3,25 @@ def _compute_bfactor_scaling(b_factor, box_size, voxel_size): + """ + Compute the B-factor scaling factor for a given B-factor, box size, and voxel size. + The B-factor scaling factor is computed as exp(-B * s^2 / 4), where s is the squared + distance in Fourier space. + + Parameters + ---------- + b_factor: float + B-factor to apply. + box_size: int + Size of the box. + voxel_size: float + Voxel size of the box. + + Returns + ------- + bfactor_scaling_torch: torch.tensor(shape=(box_size, box_size, box_size)) + B-factor scaling factor. + """ x = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) y = x.clone() z = x.clone() @@ -14,7 +33,44 @@ def _compute_bfactor_scaling(b_factor, box_size, voxel_size): return bfactor_scaling_torch -def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=True): +def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=False): + """ + Normalize volumes by applying a B-factor correction. This is done by multiplying + a centered Fourier transform of the volume by the B-factor scaling factor and then + applying the inverse Fourier transform. See _compute_bfactor_scaling for details on the + computation of the B-factor scaling. + + Parameters + ---------- + volumes: torch.tensor + Volumes to normalize. The volumes must have shape (N, N, N) or (n_volumes, N, N, N). + bfactor: float + B-factor to apply. + voxel_size: float + Voxel size of the volumes. + in_place: bool - default: False + Whether to normalize the volumes in place. + + Returns + ------- + volumes: torch.tensor + Normalized volumes. + """ + # assert that volumes have the correct shape + assert volumes.ndim in [ + 3, + 4, + ], "Input volumes must have shape (N, N, N) or (n_volumes, N, N, N)" + + if volumes.ndim == 3: + assert ( + volumes.shape[0] == volumes.shape[1] == volumes.shape[2] + ), "Input volumes must have equal dimensions" + else: + assert ( + volumes.shape[1] == volumes.shape[2] == volumes.shape[3] + ), "Input volumes must have equal dimensions" + if not in_place: volumes = volumes.clone() @@ -30,7 +86,4 @@ def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=True): volumes = volumes * b_factor_scaling[None, ...] volumes = _centered_ifftn(volumes, dim=(1, 2, 3)).real - else: - raise ValueError("Input volumes must have 3 or 4 dimensions.") - return volumes diff --git a/src/cryo_challenge/power_spectrum_utils.py b/src/cryo_challenge/power_spectrum_utils.py index 0a6338f..afae5bc 100644 --- a/src/cryo_challenge/power_spectrum_utils.py +++ b/src/cryo_challenge/power_spectrum_utils.py @@ -21,6 +21,22 @@ def _cart2sph(x, y, z): def _grid_3d(n, dtype=torch.float32): + """ + Generates a centered 3D grid. The grid is given in both cartesian and spherical coordinates. + + Parameters + ---------- + n: int + Size of the grid. + dtype: torch.dtype + Data type of the grid. + + Returns + ------- + grid: dict + Dictionary containing the grid in cartesian and spherical coordinates. + keys: x, y, z, phi, theta, r + """ start = -n // 2 + 1 end = n // 2 @@ -39,33 +55,99 @@ def _grid_3d(n, dtype=torch.float32): def _centered_fftn(x, dim=None): + """ + Wrapper around torch.fft.fftn that centers the Fourier transform. + """ x = torch.fft.fftn(x, dim=dim) x = torch.fft.fftshift(x, dim=dim) return x def _centered_ifftn(x, dim=None): + """ + Wrapper around torch.fft.ifftn that centers the inverse Fourier transform. + """ x = torch.fft.fftshift(x, dim=dim) x = torch.fft.ifftn(x, dim=dim) return x -def _compute_power_spectrum_shell(index, volume, radii, shell_width=0.5): - inner_diameter = shell_width + index - outer_diameter = shell_width + (index + 1) +def _average_over_single_shell(shell_index, volume, radii, shell_width=0.5): + """ + Given a volume in Fourier space, compute the average value of the volume over a shell. + + Parameters + ---------- + shell_index: int + Index of the shell in Fourier space. + volume: torch.tensor + Volume in Fourier space. + radii: torch.tensor + Radii of the Fourier space grid. + shell_width: float + Width of the shell. + + Returns + ------- + average: float + Average value of the volume over the shell. + """ + inner_diameter = shell_width + shell_index + outer_diameter = shell_width + (shell_index + 1) mask = (radii > inner_diameter) & (radii < outer_diameter) return torch.sum(mask * volume) / torch.sum(mask) -def compute_power_spectrum(volume, shell_width=0.5): - L = volume.shape[0] +def _average_over_shells(volume_in_fourier_space, shell_width=0.5): + """ + Vmap wrapper over _average_over_single_shell to compute the average value of a volume in Fourier space over all shells. The input should be a volumetric quantity in Fourier space. + + Parameters + ---------- + volume_in_fourier_space: torch.tensor + Volume in Fourier space. + + Returns + ------- + radial_average: torch.tensor + Average value of the volume over all shells. + """ + L = volume_in_fourier_space.shape[0] dtype = torch.float32 radii = _grid_3d(L, dtype=dtype)["r"] + radial_average = torch.vmap( + _average_over_single_shell, in_dims=(0, None, None, None) + )(torch.arange(0, L // 2), volume_in_fourier_space, radii, shell_width) + + return radial_average + + +def compute_power_spectrum(volume, shell_width=0.5): + """ + Compute the power spectrum of a volume. + + Parameters + ---------- + volume: torch.tensor + Volume for which to compute the power spectrum. + shell_width: float + Width of the shell. + + Returns + ------- + power_spectrum: torch.tensor + Power spectrum of the volume. + + Examples + -------- + volume = mrcfile.open("volume.mrc").data.copy() + volume = torch.tensor(volume, dtype=torch.float32) + power_spectrum = compute_power_spectrum(volume) + """ + # Compute centered Fourier transforms. vol_fft = torch.abs(_centered_fftn(volume)) ** 2 + power_spectrum = _average_over_shells(vol_fft, shell_width=shell_width) - power_spectrum = torch.vmap( - _compute_power_spectrum_shell, in_dims=(0, None, None, None) - )(torch.arange(0, L // 2), vol_fft, radii, shell_width) return power_spectrum diff --git a/tests/test_power_spectrum_and_bfactor.py b/tests/test_power_spectrum_and_bfactor.py index 8496ba2..218632b 100644 --- a/tests/test_power_spectrum_and_bfactor.py +++ b/tests/test_power_spectrum_and_bfactor.py @@ -7,6 +7,12 @@ def test_compute_power_spectrum(): + """ + Test the computation of the power spectrum of a radially symmetric Gaussian volume. + Since the volume is radially symmetric, the power spectrum of the whole volume should be + approximately the power spectrum in a central slice. The computation is not exact as our + averaging over shells is approximated. + """ box_size = 224 volume_shape = (box_size, box_size, box_size) voxel_size = 1.073 * 2 @@ -37,6 +43,13 @@ def test_compute_power_spectrum(): def test_bfactor_normalize_volumes(): + """ + Similarly to the other test, we test the normalization of a radially symmetric volume. + In this case we test with an oscillatory volume, which is a volume with a sinusoidal. + Since both the b-factor correction volume and the volume are radially symmetric, the + power spectrum of the normalized volume should be the same as the power spectrum of + a normalized central slice + """ box_size = 128 volume_shape = (box_size, box_size, box_size) voxel_size = 1.5 From 5572d979992d1bb50429459b051ce89cd1ca70ac Mon Sep 17 00:00:00 2001 From: Miro Astore Date: Mon, 19 Aug 2024 15:36:37 -0400 Subject: [PATCH 04/30] added serial code for centering every volume --- .../_preprocessing/align_utils.py | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/src/cryo_challenge/_preprocessing/align_utils.py b/src/cryo_challenge/_preprocessing/align_utils.py index cbbb6a1..4a976b5 100644 --- a/src/cryo_challenge/_preprocessing/align_utils.py +++ b/src/cryo_challenge/_preprocessing/align_utils.py @@ -134,21 +134,23 @@ def align_submission( -------- volumes (torch.Tensor): aligned submission volumes """ - obj_vol = volumes[0].numpy().astype(np.float32) - - obj_vol = Volume(obj_vol / obj_vol.sum()) - ref_vol = Volume(ref_volume / ref_volume.sum()) - - _, R_est = align_BO( - ref_vol, - obj_vol, - loss_type=params["BOT_loss"], - downsampled_size=params["BOT_box_size"], - max_iters=params["BOT_iter"], - refine=params["BOT_refine"], - ) - R_est = Rotation(R_est.astype(np.float32)) - - volumes = torch.from_numpy(Volume(volumes.numpy()).rotate(R_est)._data) + for i in range(len(volumes)): + print('aligning ' + str(i) + 'th volume' ) + obj_vol = volumes[i].numpy().astype(np.float32) + + obj_vol = Volume(obj_vol / obj_vol.sum()) + ref_vol = Volume(ref_volume / ref_volume.sum()) + + _, R_est = align_BO( + ref_vol, + obj_vol, + loss_type=params["BOT_loss"], + downsampled_size=params["BOT_box_size"], + max_iters=params["BOT_iter"], + refine=params["BOT_refine"], + ) + R_est = Rotation(R_est.astype(np.float32)) + + volumes[i] = torch.from_numpy(Volume(volumes[i].numpy()).rotate(R_est)._data) return volumes From 94b07dd79ec066b39d49cdbd7045061c1e40363d Mon Sep 17 00:00:00 2001 From: Miro Astore Date: Tue, 20 Aug 2024 10:31:29 -0400 Subject: [PATCH 05/30] element wise volume alignment --- src/cryo_challenge/_preprocessing/align_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/cryo_challenge/_preprocessing/align_utils.py b/src/cryo_challenge/_preprocessing/align_utils.py index 4a976b5..d2a7784 100644 --- a/src/cryo_challenge/_preprocessing/align_utils.py +++ b/src/cryo_challenge/_preprocessing/align_utils.py @@ -135,11 +135,10 @@ def align_submission( volumes (torch.Tensor): aligned submission volumes """ for i in range(len(volumes)): - print('aligning ' + str(i) + 'th volume' ) - obj_vol = volumes[i].numpy().astype(np.float32) + obj_vol = volumes[i].numpy().astype(np.float32).copy() obj_vol = Volume(obj_vol / obj_vol.sum()) - ref_vol = Volume(ref_volume / ref_volume.sum()) + ref_vol = Volume(ref_volume.copy() / ref_volume.sum()) _, R_est = align_BO( ref_vol, From 8cf8aea9388331f0c63a68d9815bfa5fd18b8690 Mon Sep 17 00:00:00 2001 From: DSilva27 Date: Tue, 3 Sep 2024 17:08:47 -0400 Subject: [PATCH 06/30] remove centering and align only one volume in preprocessing --- .../_preprocessing/align_utils.py | 72 ++++++++++++++----- .../_preprocessing/preprocessing_pipeline.py | 4 +- 2 files changed, 57 insertions(+), 19 deletions(-) diff --git a/src/cryo_challenge/_preprocessing/align_utils.py b/src/cryo_challenge/_preprocessing/align_utils.py index d2a7784..f0ae4fa 100644 --- a/src/cryo_challenge/_preprocessing/align_utils.py +++ b/src/cryo_challenge/_preprocessing/align_utils.py @@ -116,6 +116,45 @@ def center_submission(volumes: torch.Tensor, pixel_size: float) -> torch.Tensor: return volumes +# def align_submission( +# volumes: torch.Tensor, ref_volume: torch.Tensor, params: dict +# ) -> torch.Tensor: +# """ +# Align submission volumes to ground truth volume + +# Parameters: +# ----------- +# volumes (torch.Tensor): submission volumes +# shape: (n_volumes, im_x, im_y, im_z) +# ref_volume (torch.Tensor): ground truth volume +# shape: (im_x, im_y, im_z) +# params (dict): dictionary containing alignment parameters + +# Returns: +# -------- +# volumes (torch.Tensor): aligned submission volumes +# """ +# for i in range(len(volumes)): +# obj_vol = volumes[i].numpy().astype(np.float32).copy() + +# obj_vol = Volume(obj_vol / obj_vol.sum()) +# ref_vol = Volume(ref_volume.copy() / ref_volume.sum()) + +# _, R_est = align_BO( +# ref_vol, +# obj_vol, +# loss_type=params["BOT_loss"], +# downsampled_size=params["BOT_box_size"], +# max_iters=params["BOT_iter"], +# refine=params["BOT_refine"], +# ) +# R_est = Rotation(R_est.astype(np.float32)) + +# volumes[i] = torch.from_numpy(Volume(volumes[i].numpy()).rotate(R_est)._data) + +# return volumes + + def align_submission( volumes: torch.Tensor, ref_volume: torch.Tensor, params: dict ) -> torch.Tensor: @@ -134,22 +173,21 @@ def align_submission( -------- volumes (torch.Tensor): aligned submission volumes """ - for i in range(len(volumes)): - obj_vol = volumes[i].numpy().astype(np.float32).copy() - - obj_vol = Volume(obj_vol / obj_vol.sum()) - ref_vol = Volume(ref_volume.copy() / ref_volume.sum()) - - _, R_est = align_BO( - ref_vol, - obj_vol, - loss_type=params["BOT_loss"], - downsampled_size=params["BOT_box_size"], - max_iters=params["BOT_iter"], - refine=params["BOT_refine"], - ) - R_est = Rotation(R_est.astype(np.float32)) - - volumes[i] = torch.from_numpy(Volume(volumes[i].numpy()).rotate(R_est)._data) + obj_vol = volumes[0].numpy().astype(np.float32) + + obj_vol = Volume(obj_vol / obj_vol.sum()) + ref_vol = Volume(ref_volume / ref_volume.sum()) + + _, R_est = align_BO( + ref_vol, + obj_vol, + loss_type=params["BOT_loss"], + downsampled_size=params["BOT_box_size"], + max_iters=params["BOT_iter"], + refine=params["BOT_refine"], + ) + R_est = Rotation(R_est.astype(np.float32)) + + volumes = torch.from_numpy(Volume(volumes.numpy()).rotate(R_est)._data) return volumes diff --git a/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py b/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py index 90ccc51..b4c3e61 100644 --- a/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py +++ b/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py @@ -2,7 +2,7 @@ import json import os -from .align_utils import align_submission, center_submission, threshold_submissions +from .align_utils import align_submission, threshold_submissions from .crop_pad_utils import crop_pad_submission from .fourier_utils import downsample_submission @@ -80,7 +80,7 @@ def preprocess_submissions(submission_dataset, config): # center submission print(" Centering submission") - volumes = center_submission(volumes, pixel_size=pixel_size_gt) + # volumes = center_submission(volumes, pixel_size=pixel_size_gt) # flip handedness if submission_dataset.submission_config[str(idx)]["flip"] == 1: From 356696c11b2f23eeed2da21f78f45609c3e53bb2 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 3 Sep 2024 20:16:30 -0400 Subject: [PATCH 07/30] low memory pasisng test --- .../_map_to_map/map_to_map_distance.py | 36 ++++++++++++++++ .../_map_to_map/map_to_map_pipeline.py | 41 +++++++++++++++---- .../data/_validation/output_validators.py | 1 + .../config_files/test_config_map_to_map.yaml | 13 +++--- 4 files changed, 77 insertions(+), 14 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index e253d25..afc496f 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -60,6 +60,42 @@ def get_distance(self, map1, map2): return self.compute_cost_l2(map1, map2) +class CorrelationLowMemory(MapToMapDistance): + """Correlation distance. + + Not technically a distance metric, but a similarity.""" + + def __init__(self, config): + super().__init__(config) + + def compute_cost_corr(self, map_1, map_2): + return (map_1 * map_2).sum() + + @override + def get_distance(self, map1, map2, global_store_of_running_results): + map1 = map1.flatten() + map1 -= map1.median() + map1 /= map1.std() + map1 = map1[global_store_of_running_results["mask"]] + + return self.compute_cost_corr(map1, map2) + + @override + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): + maps_gt_flat = maps1 + maps_user_flat = maps2 + cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) + for idx_gt in range(len(maps_gt_flat)): + for idx_user in range(len(maps_user_flat)): + cost_matrix[idx_gt, idx_user] = self.get_distance( + maps_gt_flat[idx_gt], + maps_user_flat[idx_user], + global_store_of_running_results, + ) + + return cost_matrix + + class Correlation(MapToMapDistance): """Correlation distance. diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index ffc02df..df63caf 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -1,5 +1,6 @@ import mrcfile import pandas as pd +import numpy as np import pickle import torch @@ -7,6 +8,7 @@ from .._map_to_map.map_to_map_distance import ( FSCDistance, Correlation, + CorrelationLowMemory, L2DistanceSum, BioEM3dDistance, FSCResDistance, @@ -16,6 +18,7 @@ AVAILABLE_MAP2MAP_DISTANCES = { "fsc": FSCDistance, "corr": Correlation, + "corr_low_memory": CorrelationLowMemory, "l2": L2DistanceSum, "bioem": BioEM3dDistance, "res": FSCResDistance, @@ -32,7 +35,7 @@ def run(config): for distance_label, distance_class in AVAILABLE_MAP2MAP_DISTANCES.items() } - n_pix = config["data"]["n_pix"] + # n_pix = config["data"]["n_pix"] submission = torch.load(config["data"]["submission"]["fname"]) submission_volume_key = config["data"]["submission"]["volume_key"] @@ -51,28 +54,50 @@ def run(config): maps_user_flat = submission[submission_volume_key].reshape( len(submission["volumes"]), -1 ) - maps_gt_flat = torch.load(config["data"]["ground_truth"]["volumes"]).reshape( - -1, n_pix**3 - ) + # maps_gt_flat = np.load(config["data"]["ground_truth"]["volumes"], mmap_mode='r+')#.reshape(-1, n_pix**3) + from torch.utils.data import Dataset + + class GT_Dataset(Dataset): + def __init__(self, npy_file): + self.npy_file = npy_file + self.data = np.load(npy_file, mmap_mode="r+") + + self.shape = self.data.shape + self._dim = len(self.data.shape) + + def dim(self): + return self._dim + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + sample = self.data[idx] + return torch.from_numpy(sample.copy()) + + maps_gt_flat = GT_Dataset(config["data"]["ground_truth"]["volumes"]) if config["data"]["mask"]["do"]: mask = ( mrcfile.open(config["data"]["mask"]["volume"]).data.astype(bool).flatten() ) - maps_gt_flat = maps_gt_flat[:, mask] + # maps_gt_flat = maps_gt_flat[:, mask] maps_user_flat = maps_user_flat[:, mask] else: - maps_gt_flat.reshape(len(maps_gt_flat), -1, inplace=True) + # maps_gt_flat.reshape(len(maps_gt_flat), -1, inplace=True) maps_user_flat.reshape(len(maps_gt_flat), -1, inplace=True) if config["analysis"]["normalize"]["do"]: if config["analysis"]["normalize"]["method"] == "median_zscore": - maps_gt_flat -= maps_gt_flat.median(dim=1, keepdim=True).values - maps_gt_flat /= maps_gt_flat.std(dim=1, keepdim=True) + # maps_gt_flat -= maps_gt_flat.median(dim=1, keepdim=True).values + # maps_gt_flat /= maps_gt_flat.std(dim=1, keepdim=True) maps_user_flat -= maps_user_flat.median(dim=1, keepdim=True).values maps_user_flat /= maps_user_flat.std(dim=1, keepdim=True) computed_assets = {} + results_dict["mask"] = mask for distance_label, map_to_map_distance in map_to_map_distances.items(): if distance_label in config["analysis"]["metrics"]: # TODO: can remove print("cost matrix", distance_label) diff --git a/src/cryo_challenge/data/_validation/output_validators.py b/src/cryo_challenge/data/_validation/output_validators.py index 9f76a6d..b233a30 100644 --- a/src/cryo_challenge/data/_validation/output_validators.py +++ b/src/cryo_challenge/data/_validation/output_validators.py @@ -27,6 +27,7 @@ class MapToMapResultsValidator: config: dict user_submitted_populations: torch.Tensor corr: Optional[dict] = None + corr_low_memory: Optional[dict] = None l2: Optional[dict] = None bioem: Optional[dict] = None fsc: Optional[dict] = None diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index 7dfa7e9..d1eca4d 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -7,18 +7,19 @@ data: metadata_key: populations label_key: id ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy metadata: tests/data/Ground_truth/test_metadata_10.csv mask: do: true volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc analysis: metrics: - - l2 - - corr - - bioem - - fsc - - res + # - l2 + # - corr + - corr_low_memory + # - bioem + # - fsc + # - res chunk_size_submission: 80 chunk_size_gt: 190 normalize: From 5e3591f4a9be4739b0f5ff113fbb1f54d7607945 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 18:03:28 -0400 Subject: [PATCH 08/30] tests passing for low memory mode t and f --- .../_map_to_map/map_to_map_distance.py | 22 +++++++++ .../_map_to_map/map_to_map_pipeline.py | 48 +++++++------------ .../config_files/test_config_map_to_map.yaml | 6 +-- 3 files changed, 43 insertions(+), 33 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index afc496f..649ca4b 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -4,6 +4,28 @@ from typing_extensions import override import mrcfile import numpy as np +from torch.utils.data import Dataset + + +class GT_Dataset(Dataset): + def __init__(self, npy_file): + self.npy_file = npy_file + self.data = np.load(npy_file, mmap_mode="r+") + + self.shape = self.data.shape + self._dim = len(self.data.shape) + + def dim(self): + return self._dim + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + sample = self.data[idx] + return torch.from_numpy(sample.copy()) class MapToMapDistance: diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index df63caf..8923833 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -1,11 +1,11 @@ import mrcfile import pandas as pd -import numpy as np import pickle import torch from ..data._validation.output_validators import MapToMapResultsValidator from .._map_to_map.map_to_map_distance import ( + GT_Dataset, FSCDistance, Correlation, CorrelationLowMemory, @@ -35,7 +35,9 @@ def run(config): for distance_label, distance_class in AVAILABLE_MAP2MAP_DISTANCES.items() } - # n_pix = config["data"]["n_pix"] + low_memory_mode = False + if not low_memory_mode: + n_pix = config["data"]["n_pix"] submission = torch.load(config["data"]["submission"]["fname"]) submission_volume_key = config["data"]["submission"]["volume_key"] @@ -54,45 +56,31 @@ def run(config): maps_user_flat = submission[submission_volume_key].reshape( len(submission["volumes"]), -1 ) - # maps_gt_flat = np.load(config["data"]["ground_truth"]["volumes"], mmap_mode='r+')#.reshape(-1, n_pix**3) - from torch.utils.data import Dataset - - class GT_Dataset(Dataset): - def __init__(self, npy_file): - self.npy_file = npy_file - self.data = np.load(npy_file, mmap_mode="r+") - - self.shape = self.data.shape - self._dim = len(self.data.shape) - - def dim(self): - return self._dim - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - if torch.is_tensor(idx): - idx = idx.tolist() - sample = self.data[idx] - return torch.from_numpy(sample.copy()) - - maps_gt_flat = GT_Dataset(config["data"]["ground_truth"]["volumes"]) + if low_memory_mode: + maps_gt_flat = GT_Dataset(config["data"]["ground_truth"]["volumes"]) + else: + maps_gt_flat = torch.load(config["data"]["ground_truth"]["volumes"]).reshape( + -1, n_pix**3 + ) if config["data"]["mask"]["do"]: mask = ( mrcfile.open(config["data"]["mask"]["volume"]).data.astype(bool).flatten() ) - # maps_gt_flat = maps_gt_flat[:, mask] + if not low_memory_mode: + maps_gt_flat = maps_gt_flat[:, mask] maps_user_flat = maps_user_flat[:, mask] else: - # maps_gt_flat.reshape(len(maps_gt_flat), -1, inplace=True) + if not low_memory_mode: + maps_gt_flat.reshape(len(maps_gt_flat), -1, inplace=True) maps_user_flat.reshape(len(maps_gt_flat), -1, inplace=True) if config["analysis"]["normalize"]["do"]: if config["analysis"]["normalize"]["method"] == "median_zscore": - # maps_gt_flat -= maps_gt_flat.median(dim=1, keepdim=True).values - # maps_gt_flat /= maps_gt_flat.std(dim=1, keepdim=True) + if not low_memory_mode: + maps_gt_flat -= maps_gt_flat.median(dim=1, keepdim=True).values + if not low_memory_mode: + maps_gt_flat /= maps_gt_flat.std(dim=1, keepdim=True) maps_user_flat -= maps_user_flat.median(dim=1, keepdim=True).values maps_user_flat /= maps_user_flat.std(dim=1, keepdim=True) diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index d1eca4d..73edae3 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -7,7 +7,7 @@ data: metadata_key: populations label_key: id ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt metadata: tests/data/Ground_truth/test_metadata_10.csv mask: do: true @@ -15,8 +15,8 @@ data: analysis: metrics: # - l2 - # - corr - - corr_low_memory + - corr + # - corr_low_memory # - bioem # - fsc # - res From 7c8e24755040a9aa57f427b03426dc168435a931 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 20:17:45 -0400 Subject: [PATCH 09/30] l2, corr, bioem working --- .../_map_to_map/map_to_map_distance.py | 245 +++++++++++++++++- .../_map_to_map/map_to_map_pipeline.py | 31 ++- .../data/_validation/output_validators.py | 17 ++ .../config_files/test_config_map_to_map.yaml | 9 +- .../test_config_map_to_map_low_memory.yaml | 27 ++ tests/test_map_to_map.py | 5 + 6 files changed, 313 insertions(+), 21 deletions(-) create mode 100644 tests/config_files/test_config_map_to_map_low_memory.yaml diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 649ca4b..46280fc 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -55,6 +55,39 @@ def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return {} +class MapToMapDistanceLowMemory(MapToMapDistance): + """General class for map-to-map distance metrics that require low memory.""" + + def __init__(self, config): + super().__init__(config) + + def compute_cost(self, map_1, map_2): + raise NotImplementedError() + + @override + def get_distance(self, map1, map2, global_store_of_running_results): + map1 = map1.flatten() + map1 -= map1.median() + map1 /= map1.std() + map1 = map1[global_store_of_running_results["mask"]] + + return self.compute_cost(map1, map2) + + @override + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): + maps_gt_flat = maps1 + maps_user_flat = maps2 + cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) + for idx_gt in range(len(maps_gt_flat)): + for idx_user in range(len(maps_user_flat)): + cost_matrix[idx_gt, idx_user] = self.get_distance( + maps_gt_flat[idx_gt], + maps_user_flat[idx_user], + global_store_of_running_results, + ) + return cost_matrix + + class L2DistanceNorm(MapToMapDistance): """L2 distance norm""" @@ -66,24 +99,19 @@ def get_distance(self, map1, map2): return torch.norm(map1 - map2) ** 2 -class L2DistanceSum(MapToMapDistance): - """L2 distance. - - Computed by summing the squared differences between the two maps.""" +class L2DistanceNormLowMemory(MapToMapDistanceLowMemory): + """L2 distance norm""" def __init__(self, config): super().__init__(config) - def compute_cost_l2(self, map_1, map_2): - return ((map_1 - map_2) ** 2).sum() - @override - def get_distance(self, map1, map2): - return self.compute_cost_l2(map1, map2) + def compute_cost(self, map1, map2): + return torch.norm(map1 - map2) ** 2 -class CorrelationLowMemory(MapToMapDistance): - """Correlation distance. +class CorrelationLowMemoryCheck(MapToMapDistance): + """Correlation. Not technically a distance metric, but a similarity.""" @@ -118,8 +146,12 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): return cost_matrix +def correlation(map1, map2): + return (map1 * map2).sum() + + class Correlation(MapToMapDistance): - """Correlation distance. + """Correlation. Not technically a distance metric, but a similarity.""" @@ -134,6 +166,61 @@ def get_distance(self, map1, map2): return self.compute_cost_corr(map1, map2) +class CorrelationLowMemory(MapToMapDistanceLowMemory): + """Correlation.""" + + def __init__(self, config): + super().__init__(config) + + @override + def compute_cost(self, map1, map2): + return correlation(map1, map2) + + +def compute_bioem3d_cost(map1, map2): + """ + Compute the cost between two maps using the BioEM cost function in 3D. + + Notes + ----- + See Eq. 10 in 10.1016/j.jsb.2013.10.006 + + Parameters + ---------- + map1 : torch.Tensor + shape (n_pix,n_pix,n_pix) + map2 : torch.Tensor + shape (n_pix,n_pix,n_pix) + + Returns + ------- + cost : torch.Tensor + shape (1,) + """ + m1, m2 = map1.reshape(-1), map2.reshape(-1) + co = m1.sum() + cc = m2.sum() + coo = m1.pow(2).sum() + ccc = m2.pow(2).sum() + coc = (m1 * m2).sum() + + N = len(m1) + + t1 = 2 * torch.pi * math.exp(1) + t2 = N * (ccc * coo - coc * coc) + 2 * co * coc * cc - ccc * co * co - coo * cc * cc + t3 = (N - 2) * (N * ccc - cc * cc) + + smallest_float = torch.finfo(m1.dtype).tiny + log_prob = ( + 0.5 * torch.pi + + math.log(t1) * (1 - N / 2) + + t2.clamp(smallest_float).log() * (3 / 2 - N / 2) + + t3.clamp(smallest_float).log() * (N / 2 - 2) + ) + cost = -log_prob + return cost + + class BioEM3dDistance(MapToMapDistance): """BioEM 3D distance.""" @@ -193,6 +280,83 @@ def get_distance(self, map1, map2): return self.compute_bioem3d_cost(map1, map2) +class BioEM3dDistanceLowMemory(MapToMapDistanceLowMemory): + """BioEM 3D distance.""" + + def __init__(self, config): + super().__init__(config) + + @override + def compute_cost(self, map1, map2): + return compute_bioem3d_cost(map1, map2) + + +def fourier_shell_correlation( + x: torch.Tensor, + y: torch.Tensor, + dim: Sequence[int] = (-3, -2, -1), + normalize: bool = True, + max_k: Optional[int] = None, +): + """Computes Fourier Shell / Ring Correlation (FSC) between x and y. + + Parameters + ---------- + x : torch.Tensor + First input tensor. + y : torch.Tensor + Second input tensor. + dim : Tuple[int, ...] + Dimensions over which to take the Fourier transform. + normalize : bool + Whether to normalize (i.e. compute correlation) or not (i.e. compute covariance). + Note that when `normalize=False`, we still divide by the number of elements in each shell. + max_k : int + The maximum shell to compute the correlation for. + + Returns + ------- + torch.Tensor + The correlation between x and y for each Fourier shell. + """ # noqa: E501 + batch_shape = x.shape[: -len(dim)] + + freqs = [torch.fft.fftfreq(x.shape[d], d=1 / x.shape[d]).to(x) for d in dim] + freq_total = ( + torch.cartesian_prod(*freqs).view(*[len(f) for f in freqs], -1).norm(dim=-1) + ) + + x_f = torch.fft.fftn(x, dim=dim) + y_f = torch.fft.fftn(y, dim=dim) + + n = min(x.shape[d] for d in dim) + + if max_k is None: + max_k = n // 2 + + result = x.new_zeros(batch_shape + (max_k,)) + + for i in range(1, max_k + 1): + mask = (freq_total >= i - 0.5) & (freq_total < i + 0.5) + x_ri = x_f[..., mask] + y_fi = y_f[..., mask] + + if x.is_cuda: + c_i = torch.linalg.vecdot(x_ri, y_fi).real + else: + # vecdot currently bugged on CPU for torch 2.0 in some configurations + c_i = torch.sum(x_ri * y_fi.conj(), dim=-1).real + + if normalize: + c_i /= torch.linalg.norm(x_ri, dim=-1) * torch.linalg.norm(y_fi, dim=-1) + else: + c_i /= x_ri.shape[-1] + + result[..., i - 1] = c_i + + return result + + class FSCDistance(MapToMapDistance): """Fourier Shell Correlation distance. @@ -317,6 +481,59 @@ def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return self.stored_computed_assets # must run get_distance_matrix first +class FSCDistanceLowMemory(MapToMapDistance): + """Fourier Shell Correlation distance.""" + + def __init__(self, config): + super().__init__(config) + self.npix = self.config["data"]["n_pix"] + + def compute_cost(self, map_1, map_2): + raise NotImplementedError() + + @override + def get_distance(self, map1, map2, global_store_of_running_results): + maps_gt_flat = map1 = map1.flatten() + map1 -= map1.median() + map1 /= map1.std() + maps_gt_flat_cube = torch.zeros(self.n_pix**3) + map1 = map1[global_store_of_running_results["mask"]] + maps_gt_flat_cube[:, global_store_of_running_results["mask"]] = maps_gt_flat + + corr_vector = fourier_shell_correlation( + maps_gt_flat_cube.reshape(self.n_pix, self.n_pix, self.n_pix), + map2.reshape(self.n_pix, self.n_pix, self.n_pix), + ) + dist = 1 - corr_vector.mean(dim=1) # TODO: spectral cutoff + self.stored_computed_assets["corr_vector"] = corr_vector + return dist + + @override + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): + maps_gt_flat = maps1 + maps_user_flat = maps2 + cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) + fsc_matrix = torch.zeros( + len(maps_gt_flat), len(maps_user_flat), self.n_pix // 2 + ) + for idx_gt in range(len(maps_gt_flat)): + for idx_user in range(len(maps_user_flat)): + cost_matrix[idx_gt, idx_user] = self.get_distance( + maps_gt_flat[idx_gt], + maps_user_flat[idx_user], + global_store_of_running_results, + ) + fsc_matrix[idx_gt, idx_user] = self.stored_computed_assets[ + "corr_vector" + ] + self.stored_computed_assets = {"fsc_matrix": fsc_matrix} + return cost_matrix + + @override + def get_computed_assets(self, maps1, maps2, global_store_of_running_results): + return self.stored_computed_assets # must run get_distance_matrix first + + class FSCResDistance(MapToMapDistance): """FSC Resolution distance. @@ -351,3 +568,7 @@ def res_at_fsc_threshold(fscs, threshold=0.5): res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fsc_matrix) self.stored_computed_assets = {"fraction_nyquist": fraction_nyquist} return units_Angstroms[res_fsc_half] + + +class FSCResDistanceLowMemory(MapToMapDistance): + pass diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index 8923833..7028eb9 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -9,8 +9,10 @@ FSCDistance, Correlation, CorrelationLowMemory, - L2DistanceSum, + L2DistanceNorm, + L2DistanceNormLowMemory, BioEM3dDistance, + BioEM3dDistanceLowMemory, FSCResDistance, ) @@ -18,12 +20,19 @@ AVAILABLE_MAP2MAP_DISTANCES = { "fsc": FSCDistance, "corr": Correlation, - "corr_low_memory": CorrelationLowMemory, - "l2": L2DistanceSum, + "l2": L2DistanceNorm, "bioem": BioEM3dDistance, "res": FSCResDistance, } +AVAILABLE_MAP2MAP_DISTANCES_LOW_MEMORY = { + "corr_low_memory": CorrelationLowMemory, + "l2_low_memory": L2DistanceNormLowMemory, + "bioem_low_memory": BioEM3dDistanceLowMemory, + "fsc_low_memory": FSCDistance, + "res_low_memory": FSCResDistance, +} + def run(config): """ @@ -33,9 +42,23 @@ def run(config): map_to_map_distances = { distance_label: distance_class(config) for distance_label, distance_class in AVAILABLE_MAP2MAP_DISTANCES.items() + if distance_label in config["analysis"]["metrics"] + } + + map_to_map_distances_low_memory = { + distance_label: distance_class(config) + for distance_label, distance_class in AVAILABLE_MAP2MAP_DISTANCES_LOW_MEMORY.items() + if distance_label in config["analysis"]["metrics"] } - low_memory_mode = False + assert len(map_to_map_distances_low_memory) == 0 or len(map_to_map_distances) == 0 + + if len(map_to_map_distances_low_memory) > 0: + map_to_map_distances = map_to_map_distances_low_memory + low_memory_mode = True + else: + low_memory_mode = False + if not low_memory_mode: n_pix = config["data"]["n_pix"] diff --git a/src/cryo_challenge/data/_validation/output_validators.py b/src/cryo_challenge/data/_validation/output_validators.py index b233a30..ada4492 100644 --- a/src/cryo_challenge/data/_validation/output_validators.py +++ b/src/cryo_challenge/data/_validation/output_validators.py @@ -19,9 +19,13 @@ class MapToMapResultsValidator: config: dict, input config dictionary. user_submitted_populations: torch.Tensor, user submitted populations, which sum to 1. corr: dict, correlation results. + corr_low_memory: dict, correlation results in low memory mode. l2: dict, L2 results. + l2_low_memory: dict, L2 results in low memory mode. bioem: dict, BioEM results. + bioem_low_memory: dict, BioEM results in low memory mode. fsc: dict, FSC results. + fsc_low_memory: dict, FSC results in low memory mode. """ config: dict @@ -29,9 +33,13 @@ class MapToMapResultsValidator: corr: Optional[dict] = None corr_low_memory: Optional[dict] = None l2: Optional[dict] = None + l2_low_memory: Optional[dict] = None bioem: Optional[dict] = None + bioem_low_memory: Optional[dict] = None fsc: Optional[dict] = None + fsc_low_memory: Optional[dict] = None res: Optional[dict] = None + res_low_memory: Optional[dict] = None def __post_init__(self): validate_input_config_mtm(self.config) @@ -142,6 +150,10 @@ class DistributionToDistributionResultsValidator: bioem: dict, BioEM distance results. l2: dict, L2 distance results. corr: dict, correlation distance results. + fsc_low_memory: dict, FSC distance results in low memory mode. + bioem_low_memory: dict, BioEM distance results in low memory mode. + l2_low_memory: dict, L2 distance results in low memory mode. + corr_low_memory: dict, correlation distance results in low memory mode. """ config: dict @@ -152,6 +164,11 @@ class DistributionToDistributionResultsValidator: res: Optional[dict] = None l2: Optional[dict] = None corr: Optional[dict] = None + fsc_low_memory: Optional[dict] = None + bioem_low_memory: Optional[dict] = None + res_low_memory: Optional[dict] = None + l2_low_memory: Optional[dict] = None + corr_low_memory: Optional[dict] = None def __post_init__(self): validate_input_config_disttodist(self.config) diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index 73edae3..7dfa7e9 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -14,12 +14,11 @@ data: volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc analysis: metrics: - # - l2 + - l2 - corr - # - corr_low_memory - # - bioem - # - fsc - # - res + - bioem + - fsc + - res chunk_size_submission: 80 chunk_size_gt: 190 normalize: diff --git a/tests/config_files/test_config_map_to_map_low_memory.yaml b/tests/config_files/test_config_map_to_map_low_memory.yaml new file mode 100644 index 0000000..41f4a1e --- /dev/null +++ b/tests/config_files/test_config_map_to_map_low_memory.yaml @@ -0,0 +1,27 @@ +data: + n_pix: 16 + psize: 30.044 + submission: + fname: tests/data/dataset_2_submissions/submission_1000.pt + volume_key: volumes + metadata_key: populations + label_key: id + ground_truth: + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy + metadata: tests/data/Ground_truth/test_metadata_10.csv + mask: + do: true + volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc +analysis: + metrics: + - l2_low_memory + - corr_low_memory + - bioem_low_memory + # - fsc_low_memory + # - res_low_memory + chunk_size_submission: 80 + chunk_size_gt: 190 + normalize: + do: true + method: median_zscore +output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index e31f29f..301ee12 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -7,3 +7,8 @@ def test_run_map2map_pipeline(): {"config": "tests/config_files/test_config_map_to_map.yaml"} ) run_map2map_pipeline.main(args) + + args = OmegaConf.create( + {"config": "tests/config_files/test_config_map_to_map_low_memory.yaml"} + ) + run_map2map_pipeline.main(args) From 14805d7cd5d2953356927e3339bb8d2e4ebf1ced Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 20:24:46 -0400 Subject: [PATCH 10/30] fsc low memory working --- .../_map_to_map/map_to_map_distance.py | 14 +++++++------- .../_map_to_map/map_to_map_pipeline.py | 6 ++++-- .../test_config_map_to_map_low_memory.yaml | 2 +- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 46280fc..fe54c9a 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -486,26 +486,26 @@ class FSCDistanceLowMemory(MapToMapDistance): def __init__(self, config): super().__init__(config) - self.npix = self.config["data"]["n_pix"] + self.n_pix = self.config["data"]["n_pix"] def compute_cost(self, map_1, map_2): raise NotImplementedError() @override def get_distance(self, map1, map2, global_store_of_running_results): - maps_gt_flat = map1 = map1.flatten() + map_gt_flat = map1 = map1.flatten() map1 -= map1.median() map1 /= map1.std() - maps_gt_flat_cube = torch.zeros(self.n_pix**3) + map_gt_flat_cube = torch.zeros(self.n_pix**3) map1 = map1[global_store_of_running_results["mask"]] - maps_gt_flat_cube[:, global_store_of_running_results["mask"]] = maps_gt_flat + map_gt_flat_cube[global_store_of_running_results["mask"]] = map_gt_flat corr_vector = fourier_shell_correlation( - maps_gt_flat_cube.reshape(self.n_pix, self.n_pix, self.n_pix), + map_gt_flat_cube.reshape(self.n_pix, self.n_pix, self.n_pix), map2.reshape(self.n_pix, self.n_pix, self.n_pix), ) - dist = 1 - corr_vector.mean(dim=1) # TODO: spectral cutoff - self.stored_computed_assets["corr_vector"] = corr_vector + dist = 1 - corr_vector.mean() # TODO: spectral cutoff + self.stored_computed_assets = {"corr_vector": corr_vector} return dist @override diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index 7028eb9..60d3f34 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -7,6 +7,7 @@ from .._map_to_map.map_to_map_distance import ( GT_Dataset, FSCDistance, + FSCDistanceLowMemory, Correlation, CorrelationLowMemory, L2DistanceNorm, @@ -14,6 +15,7 @@ BioEM3dDistance, BioEM3dDistanceLowMemory, FSCResDistance, + FSCResDistanceLowMemory, ) @@ -29,8 +31,8 @@ "corr_low_memory": CorrelationLowMemory, "l2_low_memory": L2DistanceNormLowMemory, "bioem_low_memory": BioEM3dDistanceLowMemory, - "fsc_low_memory": FSCDistance, - "res_low_memory": FSCResDistance, + "fsc_low_memory": FSCDistanceLowMemory, + "res_low_memory": FSCResDistanceLowMemory, } diff --git a/tests/config_files/test_config_map_to_map_low_memory.yaml b/tests/config_files/test_config_map_to_map_low_memory.yaml index 41f4a1e..0d62008 100644 --- a/tests/config_files/test_config_map_to_map_low_memory.yaml +++ b/tests/config_files/test_config_map_to_map_low_memory.yaml @@ -17,7 +17,7 @@ analysis: - l2_low_memory - corr_low_memory - bioem_low_memory - # - fsc_low_memory + - fsc_low_memory # - res_low_memory chunk_size_submission: 80 chunk_size_gt: 190 From 6f5e09eb2cfe55293c6b5d29d7104b41b1d99da7 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 20:42:27 -0400 Subject: [PATCH 11/30] tests passing checking identical output --- .../_commands/run_map2map_pipeline.py | 4 +--- tests/test_map_to_map.py | 19 ++++++++++++++++--- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/cryo_challenge/_commands/run_map2map_pipeline.py b/src/cryo_challenge/_commands/run_map2map_pipeline.py index ab36f7a..90db1aa 100644 --- a/src/cryo_challenge/_commands/run_map2map_pipeline.py +++ b/src/cryo_challenge/_commands/run_map2map_pipeline.py @@ -39,9 +39,7 @@ def main(args): warnexists(config["output"]) mkbasedir(os.path.dirname(config["output"])) - run(config) - - return + return run(config) if __name__ == "__main__": diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index 301ee12..c1817d8 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -1,14 +1,27 @@ from omegaconf import OmegaConf from cryo_challenge._commands import run_map2map_pipeline +import numpy as np def test_run_map2map_pipeline(): args = OmegaConf.create( {"config": "tests/config_files/test_config_map_to_map.yaml"} ) - run_map2map_pipeline.main(args) + results_dict = run_map2map_pipeline.main(args) - args = OmegaConf.create( + args_low_memory = OmegaConf.create( {"config": "tests/config_files/test_config_map_to_map_low_memory.yaml"} ) - run_map2map_pipeline.main(args) + results_dict_low_memory = run_map2map_pipeline.main(args_low_memory) + for metric in ["fsc", "corr", "l2", "bioem"]: + if metric == "fsc": + np.allclose( + results_dict[metric]["computed_assets"]["fsc_matrix"], + results_dict_low_memory[metric + "_low_memory"]["computed_assets"][ + "fsc_matrix" + ], + ) + np.allclose( + results_dict[metric]["cost_matrix"].values, + results_dict_low_memory[metric + "_low_memory"]["cost_matrix"].values, + ) From a7290aae5eca25d4bcea2c322a76ed155d48271e Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 20:47:12 -0400 Subject: [PATCH 12/30] all metrics implemented and test passing for matched results --- .../_map_to_map/map_to_map_distance.py | 34 ++++++++++++++++++- .../test_config_map_to_map_low_memory.yaml | 2 +- tests/test_map_to_map.py | 7 ++++ 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index fe54c9a..b3bbf2c 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -571,4 +571,36 @@ def res_at_fsc_threshold(fscs, threshold=0.5): class FSCResDistanceLowMemory(MapToMapDistance): - pass + """FSC Resolution distance. + + The resolution at which the Fourier Shell Correlation reaches 0.5. + Built on top of the FSCDistance class. This needs to be run first and store the FSC matrix in the computed assets. + """ + + def __init__(self, config): + super().__init__(config) + + @override + def get_distance_matrix( + self, maps1, maps2, global_store_of_running_results + ): # custom method + # get fsc matrix + fourier_pixel_max = ( + self.config["data"]["n_pix"] // 2 + ) # TODO: check for odd psizes if this should be +1 + psize = self.config["data"]["psize"] + fsc_matrix = global_store_of_running_results["fsc_low_memory"][ + "computed_assets" + ]["fsc_matrix"] + units_Angstroms = ( + 2 * psize / (np.arange(1, fourier_pixel_max + 1) / fourier_pixel_max) + ) + + def res_at_fsc_threshold(fscs, threshold=0.5): + res_fsc_half = np.argmin(fscs > threshold, axis=-1) + fraction_nyquist = 0.5 * res_fsc_half / fscs.shape[-1] + return res_fsc_half, fraction_nyquist + + res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fsc_matrix) + self.stored_computed_assets = {"fraction_nyquist": fraction_nyquist} + return units_Angstroms[res_fsc_half] diff --git a/tests/config_files/test_config_map_to_map_low_memory.yaml b/tests/config_files/test_config_map_to_map_low_memory.yaml index 0d62008..4eb6cd0 100644 --- a/tests/config_files/test_config_map_to_map_low_memory.yaml +++ b/tests/config_files/test_config_map_to_map_low_memory.yaml @@ -18,7 +18,7 @@ analysis: - corr_low_memory - bioem_low_memory - fsc_low_memory - # - res_low_memory + - res_low_memory chunk_size_submission: 80 chunk_size_gt: 190 normalize: diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index c1817d8..ed91815 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -21,6 +21,13 @@ def test_run_map2map_pipeline(): "fsc_matrix" ], ) + elif metric == "res": + np.allclose( + results_dict[metric]["computed_assets"]["fraction_nyquist"], + results_dict_low_memory[metric + "_low_memory"]["computed_assets"][ + "fraction_nyquist" + ], + ) np.allclose( results_dict[metric]["cost_matrix"].values, results_dict_low_memory[metric + "_low_memory"]["cost_matrix"].values, From 22ec5a5085cb8a8eaac1d54549ff6296dfa1bde1 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 20:50:09 -0400 Subject: [PATCH 13/30] all metrics implemented and test passing for matched results --- .../_map_to_map/map_to_map_distance.py | 31 +++---------------- 1 file changed, 4 insertions(+), 27 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index b3bbf2c..369c47f 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -543,6 +543,7 @@ class FSCResDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) + self.fsc_label = "fsc" @override def get_distance_matrix( @@ -553,7 +554,7 @@ def get_distance_matrix( self.config["data"]["n_pix"] // 2 ) # TODO: check for odd psizes if this should be +1 psize = self.config["data"]["psize"] - fsc_matrix = global_store_of_running_results["fsc"]["computed_assets"][ + fsc_matrix = global_store_of_running_results[self.fsc_label]["computed_assets"][ "fsc_matrix" ] units_Angstroms = ( @@ -570,7 +571,7 @@ def res_at_fsc_threshold(fscs, threshold=0.5): return units_Angstroms[res_fsc_half] -class FSCResDistanceLowMemory(MapToMapDistance): +class FSCResDistanceLowMemory(FSCResDistance): """FSC Resolution distance. The resolution at which the Fourier Shell Correlation reaches 0.5. @@ -579,28 +580,4 @@ class FSCResDistanceLowMemory(MapToMapDistance): def __init__(self, config): super().__init__(config) - - @override - def get_distance_matrix( - self, maps1, maps2, global_store_of_running_results - ): # custom method - # get fsc matrix - fourier_pixel_max = ( - self.config["data"]["n_pix"] // 2 - ) # TODO: check for odd psizes if this should be +1 - psize = self.config["data"]["psize"] - fsc_matrix = global_store_of_running_results["fsc_low_memory"][ - "computed_assets" - ]["fsc_matrix"] - units_Angstroms = ( - 2 * psize / (np.arange(1, fourier_pixel_max + 1) / fourier_pixel_max) - ) - - def res_at_fsc_threshold(fscs, threshold=0.5): - res_fsc_half = np.argmin(fscs > threshold, axis=-1) - fraction_nyquist = 0.5 * res_fsc_half / fscs.shape[-1] - return res_fsc_half, fraction_nyquist - - res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fsc_matrix) - self.stored_computed_assets = {"fraction_nyquist": fraction_nyquist} - return units_Angstroms[res_fsc_half] + self.fsc_label = "fsc_low_memory" From de3cddc1e0839b507acd60705f02a1b7f6fcc085 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 20:53:37 -0400 Subject: [PATCH 14/30] all metrics implemented and test passing for matched results --- .../_map_to_map/map_to_map_distance.py | 159 +----------------- 1 file changed, 3 insertions(+), 156 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 369c47f..70df449 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -110,42 +110,6 @@ def compute_cost(self, map1, map2): return torch.norm(map1 - map2) ** 2 -class CorrelationLowMemoryCheck(MapToMapDistance): - """Correlation. - - Not technically a distance metric, but a similarity.""" - - def __init__(self, config): - super().__init__(config) - - def compute_cost_corr(self, map_1, map_2): - return (map_1 * map_2).sum() - - @override - def get_distance(self, map1, map2, global_store_of_running_results): - map1 = map1.flatten() - map1 -= map1.median() - map1 /= map1.std() - map1 = map1[global_store_of_running_results["mask"]] - - return self.compute_cost_corr(map1, map2) - - @override - def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): - maps_gt_flat = maps1 - maps_user_flat = maps2 - cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) - for idx_gt in range(len(maps_gt_flat)): - for idx_user in range(len(maps_user_flat)): - cost_matrix[idx_gt, idx_user] = self.get_distance( - maps_gt_flat[idx_gt], - maps_user_flat[idx_user], - global_store_of_running_results, - ) - - return cost_matrix - - def correlation(map1, map2): return (map1 * map2).sum() @@ -158,12 +122,9 @@ class Correlation(MapToMapDistance): def __init__(self, config): super().__init__(config) - def compute_cost_corr(self, map_1, map_2): - return (map_1 * map_2).sum() - @override def get_distance(self, map1, map2): - return self.compute_cost_corr(map1, map2) + return correlation(map1, map2) class CorrelationLowMemory(MapToMapDistanceLowMemory): @@ -227,57 +188,9 @@ class BioEM3dDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) - def compute_bioem3d_cost(self, map1, map2): - """ - Compute the cost between two maps using the BioEM cost function in 3D. - - Notes - ----- - See Eq. 10 in 10.1016/j.jsb.2013.10.006 - - Parameters - ---------- - map1 : torch.Tensor - shape (n_pix,n_pix,n_pix) - map2 : torch.Tensor - shape (n_pix,n_pix,n_pix) - - Returns - ------- - cost : torch.Tensor - shape (1,) - """ - m1, m2 = map1.reshape(-1), map2.reshape(-1) - co = m1.sum() - cc = m2.sum() - coo = m1.pow(2).sum() - ccc = m2.pow(2).sum() - coc = (m1 * m2).sum() - - N = len(m1) - - t1 = 2 * torch.pi * math.exp(1) - t2 = ( - N * (ccc * coo - coc * coc) - + 2 * co * coc * cc - - ccc * co * co - - coo * cc * cc - ) - t3 = (N - 2) * (N * ccc - cc * cc) - - smallest_float = torch.finfo(m1.dtype).tiny - log_prob = ( - 0.5 * torch.pi - + math.log(t1) * (1 - N / 2) - + t2.clamp(smallest_float).log() * (3 / 2 - N / 2) - + t3.clamp(smallest_float).log() * (N / 2 - 2) - ) - cost = -log_prob - return cost - @override def get_distance(self, map1, map2): - return self.compute_bioem3d_cost(map1, map2) + return compute_bioem3d_cost(map1, map2) class BioEM3dDistanceLowMemory(MapToMapDistanceLowMemory): @@ -365,72 +278,6 @@ class FSCDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) - def fourier_shell_correlation( - self, - x: torch.Tensor, - y: torch.Tensor, - dim: Sequence[int] = (-3, -2, -1), - normalize: bool = True, - max_k: Optional[int] = None, - ): - """Computes Fourier Shell / Ring Correlation (FSC) between x and y. - - Parameters - ---------- - x : torch.Tensor - First input tensor. - y : torch.Tensor - Second input tensor. - dim : Tuple[int, ...] - Dimensions over which to take the Fourier transform. - normalize : bool - Whether to normalize (i.e. compute correlation) or not (i.e. compute covariance). - Note that when `normalize=False`, we still divide by the number of elements in each shell. - max_k : int - The maximum shell to compute the correlation for. - - Returns - ------- - torch.Tensor - The correlation between x and y for each Fourier shell. - """ # noqa: E501 - batch_shape = x.shape[: -len(dim)] - - freqs = [torch.fft.fftfreq(x.shape[d], d=1 / x.shape[d]).to(x) for d in dim] - freq_total = ( - torch.cartesian_prod(*freqs).view(*[len(f) for f in freqs], -1).norm(dim=-1) - ) - - x_f = torch.fft.fftn(x, dim=dim) - y_f = torch.fft.fftn(y, dim=dim) - - n = min(x.shape[d] for d in dim) - - if max_k is None: - max_k = n // 2 - - result = x.new_zeros(batch_shape + (max_k,)) - - for i in range(1, max_k + 1): - mask = (freq_total >= i - 0.5) & (freq_total < i + 0.5) - x_ri = x_f[..., mask] - y_fi = y_f[..., mask] - - if x.is_cuda: - c_i = torch.linalg.vecdot(x_ri, y_fi).real - else: - # vecdot currently bugged on CPU for torch 2.0 in some configurations - c_i = torch.sum(x_ri * y_fi.conj(), dim=-1).real - - if normalize: - c_i /= torch.linalg.norm(x_ri, dim=-1) * torch.linalg.norm(y_fi, dim=-1) - else: - c_i /= x_ri.shape[-1] - - result[..., i - 1] = c_i - - return result - def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): """ Compute the cost between two maps using the Fourier Shell Correlation in 3D. @@ -443,7 +290,7 @@ def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) fsc_matrix = torch.zeros(len(maps_gt_flat), len(maps_user_flat), n_pix // 2) for idx in range(len(maps_gt_flat)): - corr_vector = self.fourier_shell_correlation( + corr_vector = fourier_shell_correlation( maps_user_flat.reshape(-1, n_pix, n_pix, n_pix), maps_gt_flat[idx].reshape(n_pix, n_pix, n_pix), ) From f20344e8b6ac36384f34baa8fa96f8fc69d9adad Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 21:10:08 -0400 Subject: [PATCH 15/30] flags for masking and not masking --- .../_map_to_map/map_to_map_distance.py | 40 +++++++++++++------ .../_map_to_map/map_to_map_pipeline.py | 4 ++ 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 70df449..4769528 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -60,6 +60,7 @@ class MapToMapDistanceLowMemory(MapToMapDistance): def __init__(self, config): super().__init__(config) + self.config = config def compute_cost(self, map_1, map_2): raise NotImplementedError() @@ -67,8 +68,14 @@ def compute_cost(self, map_1, map_2): @override def get_distance(self, map1, map2, global_store_of_running_results): map1 = map1.flatten() - map1 -= map1.median() - map1 /= map1.std() + if self.config["analysis"]["normalize"]["do"]: + if self.config["analysis"]["normalize"]["method"] == "median_zscore": + map1 -= map1.median() + map1 /= map1.std() + else: + raise NotImplementedError( + f"Normalization method {self.config['analysis']['normalize']['method']} not implemented." + ) map1 = map1[global_store_of_running_results["mask"]] return self.compute_cost(map1, map2) @@ -308,14 +315,19 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): maps_user_flat = maps2 n_pix = self.config["data"]["n_pix"] maps_gt_flat_cube = torch.zeros(len(maps_gt_flat), n_pix**3) - mask = ( - mrcfile.open(self.config["data"]["mask"]["volume"]) - .data.astype(bool) - .flatten() - ) - maps_gt_flat_cube[:, mask] = maps_gt_flat maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) - maps_user_flat_cube[:, mask] = maps_user_flat + + if self.config["data"]["mask"]["do"]: + mask = ( + mrcfile.open(self.config["data"]["mask"]["volume"]) + .data.astype(bool) + .flatten() + ) + maps_gt_flat_cube[:, mask] = maps_gt_flat + maps_user_flat_cube[:, mask] = maps_user_flat + else: + maps_gt_flat_cube = maps_gt_flat + maps_user_flat_cube = maps_user_flat cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk( maps_gt_flat_cube, maps_user_flat_cube, n_pix @@ -334,6 +346,7 @@ class FSCDistanceLowMemory(MapToMapDistance): def __init__(self, config): super().__init__(config) self.n_pix = self.config["data"]["n_pix"] + self.config = config def compute_cost(self, map_1, map_2): raise NotImplementedError() @@ -341,11 +354,12 @@ def compute_cost(self, map_1, map_2): @override def get_distance(self, map1, map2, global_store_of_running_results): map_gt_flat = map1 = map1.flatten() - map1 -= map1.median() - map1 /= map1.std() map_gt_flat_cube = torch.zeros(self.n_pix**3) - map1 = map1[global_store_of_running_results["mask"]] - map_gt_flat_cube[global_store_of_running_results["mask"]] = map_gt_flat + if self.config["data"]["mask"]["do"]: + map_gt_flat = map_gt_flat[global_store_of_running_results["mask"]] + map_gt_flat_cube[global_store_of_running_results["mask"]] = map_gt_flat + else: + map_gt_flat_cube = map_gt_flat corr_vector = fourier_shell_correlation( map_gt_flat_cube.reshape(self.n_pix, self.n_pix, self.n_pix), diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index 60d3f34..d7db04b 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -108,6 +108,10 @@ def run(config): maps_gt_flat /= maps_gt_flat.std(dim=1, keepdim=True) maps_user_flat -= maps_user_flat.median(dim=1, keepdim=True).values maps_user_flat /= maps_user_flat.std(dim=1, keepdim=True) + else: + raise NotImplementedError( + f"Normalization method {config['analysis']['normalize']['method']} not implemented." + ) computed_assets = {} results_dict["mask"] = mask From 6def1bf6e6a20a4c3f50652002f23eb5f53fc1c2 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 21:15:21 -0400 Subject: [PATCH 16/30] tests for masking and normalization --- ..._to_map_low_memory_nomask_nonormalize.yaml | 27 +++++++++ ..._config_map_to_map_nomask_nonormalize.yaml | 27 +++++++++ tests/test_map_to_map.py | 56 ++++++++++--------- 3 files changed, 85 insertions(+), 25 deletions(-) create mode 100644 tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml create mode 100644 tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml diff --git a/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml new file mode 100644 index 0000000..13a1cf4 --- /dev/null +++ b/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml @@ -0,0 +1,27 @@ +data: + n_pix: 16 + psize: 30.044 + submission: + fname: tests/data/dataset_2_submissions/submission_1000.pt + volume_key: volumes + metadata_key: populations + label_key: id + ground_truth: + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy + metadata: tests/data/Ground_truth/test_metadata_10.csv + mask: + do: false + volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc +analysis: + metrics: + - l2_low_memory + - corr_low_memory + - bioem_low_memory + - fsc_low_memory + - res_low_memory + chunk_size_submission: 80 + chunk_size_gt: 190 + normalize: + do: false + method: median_zscore +output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml new file mode 100644 index 0000000..7856db4 --- /dev/null +++ b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml @@ -0,0 +1,27 @@ +data: + n_pix: 16 + psize: 30.044 + submission: + fname: tests/data/dataset_2_submissions/submission_1000.pt + volume_key: volumes + metadata_key: populations + label_key: id + ground_truth: + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt + metadata: tests/data/Ground_truth/test_metadata_10.csv + mask: + do: false + volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc +analysis: + metrics: + - l2 + - corr + - bioem + - fsc + - res + chunk_size_submission: 80 + chunk_size_gt: 190 + normalize: + do: false + method: median_zscore +output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index ed91815..957a9b1 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -4,31 +4,37 @@ def test_run_map2map_pipeline(): - args = OmegaConf.create( - {"config": "tests/config_files/test_config_map_to_map.yaml"} - ) - results_dict = run_map2map_pipeline.main(args) + for config_fname, config_fname_low_memory in zip( + [ + "tests/config_files/test_config_map_to_map.yaml", + "tests/config_files/test_config_map_to_map.yaml", + ], + [ + "tests/config_files/test_config_map_to_map_low_memory.yaml", + "tests/config_files/test_config_map_to_map_low_memory.yaml", + ], + ): + args = OmegaConf.create({"config": config_fname}) + results_dict = run_map2map_pipeline.main(args) - args_low_memory = OmegaConf.create( - {"config": "tests/config_files/test_config_map_to_map_low_memory.yaml"} - ) - results_dict_low_memory = run_map2map_pipeline.main(args_low_memory) - for metric in ["fsc", "corr", "l2", "bioem"]: - if metric == "fsc": + args_low_memory = OmegaConf.create({"config": config_fname_low_memory}) + results_dict_low_memory = run_map2map_pipeline.main(args_low_memory) + for metric in ["fsc", "corr", "l2", "bioem"]: + if metric == "fsc": + np.allclose( + results_dict[metric]["computed_assets"]["fsc_matrix"], + results_dict_low_memory[metric + "_low_memory"]["computed_assets"][ + "fsc_matrix" + ], + ) + elif metric == "res": + np.allclose( + results_dict[metric]["computed_assets"]["fraction_nyquist"], + results_dict_low_memory[metric + "_low_memory"]["computed_assets"][ + "fraction_nyquist" + ], + ) np.allclose( - results_dict[metric]["computed_assets"]["fsc_matrix"], - results_dict_low_memory[metric + "_low_memory"]["computed_assets"][ - "fsc_matrix" - ], + results_dict[metric]["cost_matrix"].values, + results_dict_low_memory[metric + "_low_memory"]["cost_matrix"].values, ) - elif metric == "res": - np.allclose( - results_dict[metric]["computed_assets"]["fraction_nyquist"], - results_dict_low_memory[metric + "_low_memory"]["computed_assets"][ - "fraction_nyquist" - ], - ) - np.allclose( - results_dict[metric]["cost_matrix"].values, - results_dict_low_memory[metric + "_low_memory"]["cost_matrix"].values, - ) From c0f2f831facafc2fd3bddc5b112dfa0c97343254 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 14:16:11 -0400 Subject: [PATCH 17/30] code duplication for norm --- src/cryo_challenge/_map_to_map/map_to_map_distance.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 4769528..b413ade 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -95,6 +95,10 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): return cost_matrix +def norm2(map1, map2): + return torch.norm(map1 - map2) ** 2 + + class L2DistanceNorm(MapToMapDistance): """L2 distance norm""" @@ -103,7 +107,7 @@ def __init__(self, config): @override def get_distance(self, map1, map2): - return torch.norm(map1 - map2) ** 2 + return norm2(map1, map2) class L2DistanceNormLowMemory(MapToMapDistanceLowMemory): @@ -114,7 +118,7 @@ def __init__(self, config): @override def compute_cost(self, map1, map2): - return torch.norm(map1 - map2) ** 2 + return norm2(map1, map2) def correlation(map1, map2): From f1f8cd8ad1dfc1fdae131f2f9ed3250125320423 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 14:30:59 -0400 Subject: [PATCH 18/30] tests passing. vmap over sub batch --- .../_map_to_map/map_to_map_distance.py | 30 ++++++++++++++----- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index b413ade..317964f 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -40,14 +40,30 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): """Compute the distance matrix between two sets of maps.""" chunk_size_submission = self.config["analysis"]["chunk_size_submission"] chunk_size_gt = self.config["analysis"]["chunk_size_gt"] - distance_matrix = torch.vmap( - lambda maps1: torch.vmap( - lambda maps2: self.get_distance(maps1, maps2), - chunk_size=chunk_size_submission, - )(maps2), - chunk_size=chunk_size_gt, - )(maps1) + # load in memory as torch tensors + if True: # config.low_memory: + distance_matrix = torch.empty(len(maps1), len(maps2)) + n_chunks_low_memory = 100 + for idxs in torch.arange(len(maps1)).chunk(n_chunks_low_memory): + maps1_in_memory = maps1[idxs] + sub_distance_matrix = torch.vmap( + lambda maps1: torch.vmap( + lambda maps2: self.get_distance(maps1, maps2), + chunk_size=chunk_size_submission, + )(maps2), + chunk_size=chunk_size_gt, + )(maps1_in_memory) + distance_matrix[idxs] = sub_distance_matrix + else: + assert False, "Not implemented" + distance_matrix = torch.vmap( + lambda maps1: torch.vmap( + lambda maps2: self.get_distance(maps1, maps2), + chunk_size=chunk_size_submission, + )(maps2), + chunk_size=chunk_size_gt, + )(maps1) return distance_matrix def get_computed_assets(self, maps1, maps2, global_store_of_running_results): From 61608cb3749fb874c725c856152af1c39d3b9c40 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 14:52:32 -0400 Subject: [PATCH 19/30] tests passing. get_sub_distance_matrix --- .../_map_to_map/map_to_map_distance.py | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 317964f..ce650f1 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -31,28 +31,35 @@ def __getitem__(self, idx): class MapToMapDistance: def __init__(self, config): self.config = config + self.chunk_size_gt = self.config["analysis"]["chunk_size_gt"] + self.chunk_size_submission = self.config["analysis"]["chunk_size_submission"] def get_distance(self, map1, map2): """Compute the distance between two maps.""" raise NotImplementedError() + def get_sub_distance_matrix(self, maps1, maps2): + """Compute the distance matrix between two sets of maps.""" + sub_distance_matrix = torch.vmap( + lambda maps1: torch.vmap( + lambda maps2: self.get_distance(maps1, maps2), + chunk_size=self.chunk_size_submission, + )(maps2), + chunk_size=self.chunk_size_gt, + )(maps1) + return sub_distance_matrix + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): """Compute the distance matrix between two sets of maps.""" - chunk_size_submission = self.config["analysis"]["chunk_size_submission"] - chunk_size_gt = self.config["analysis"]["chunk_size_gt"] # load in memory as torch tensors if True: # config.low_memory: distance_matrix = torch.empty(len(maps1), len(maps2)) n_chunks_low_memory = 100 for idxs in torch.arange(len(maps1)).chunk(n_chunks_low_memory): maps1_in_memory = maps1[idxs] - sub_distance_matrix = torch.vmap( - lambda maps1: torch.vmap( - lambda maps2: self.get_distance(maps1, maps2), - chunk_size=chunk_size_submission, - )(maps2), - chunk_size=chunk_size_gt, - )(maps1_in_memory) + sub_distance_matrix = self.get_sub_distance_matrix( + maps1_in_memory, maps2 + ) distance_matrix[idxs] = sub_distance_matrix else: @@ -60,9 +67,9 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): distance_matrix = torch.vmap( lambda maps1: torch.vmap( lambda maps2: self.get_distance(maps1, maps2), - chunk_size=chunk_size_submission, + chunk_size=self.chunk_size_submission, )(maps2), - chunk_size=chunk_size_gt, + chunk_size=self.chunk_size_gt, )(maps1) return distance_matrix From 1440d8419078526460913d911cd538b91daefcb8 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 15:04:55 -0400 Subject: [PATCH 20/30] tests passing with hard coded instantiation --- .../_map_to_map/map_to_map_distance.py | 40 +++++++++++++++++-- .../_map_to_map/map_to_map_pipeline.py | 2 + 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index ce650f1..c24c300 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -38,7 +38,7 @@ def get_distance(self, map1, map2): """Compute the distance between two maps.""" raise NotImplementedError() - def get_sub_distance_matrix(self, maps1, maps2): + def get_sub_distance_matrix(self, maps1, maps2, idxs): """Compute the distance matrix between two sets of maps.""" sub_distance_matrix = torch.vmap( lambda maps1: torch.vmap( @@ -58,7 +58,9 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): for idxs in torch.arange(len(maps1)).chunk(n_chunks_low_memory): maps1_in_memory = maps1[idxs] sub_distance_matrix = self.get_sub_distance_matrix( - maps1_in_memory, maps2 + maps1_in_memory, + maps2, + idxs, ) distance_matrix[idxs] = sub_distance_matrix @@ -311,6 +313,7 @@ class FSCDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) + self.stored_computed_assets = {"fsc_matrix": torch.empty(10, 8, 8)} def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): """ @@ -334,7 +337,7 @@ def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): return cost_matrix, fsc_matrix @override - def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): + def get_sub_distance_matrix(self, maps1, maps2, idxs): """ Applies a mask to the maps and computes the cost matrix using the Fourier Shell Correlation. """ @@ -359,9 +362,38 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk( maps_gt_flat_cube, maps_user_flat_cube, n_pix ) - self.stored_computed_assets = {"fsc_matrix": fsc_matrix} + self.stored_computed_assets["fsc_matrix"][idxs] = fsc_matrix return cost_matrix + # @override + # def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): + # """ + # Applies a mask to the maps and computes the cost matrix using the Fourier Shell Correlation. + # """ + # maps_gt_flat = maps1 + # maps_user_flat = maps2 + # n_pix = self.config["data"]["n_pix"] + # maps_gt_flat_cube = torch.zeros(len(maps_gt_flat), n_pix**3) + # maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) + + # if self.config["data"]["mask"]["do"]: + # mask = ( + # mrcfile.open(self.config["data"]["mask"]["volume"]) + # .data.astype(bool) + # .flatten() + # ) + # maps_gt_flat_cube[:, mask] = maps_gt_flat + # maps_user_flat_cube[:, mask] = maps_user_flat + # else: + # maps_gt_flat_cube = maps_gt_flat + # maps_user_flat_cube = maps_user_flat + + # cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk( + # maps_gt_flat_cube, maps_user_flat_cube, n_pix + # ) + # self.stored_computed_assets = {"fsc_matrix": fsc_matrix} + # return cost_matrix + @override def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return self.stored_computed_assets # must run get_distance_matrix first diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index d7db04b..7016c3f 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -118,6 +118,8 @@ def run(config): for distance_label, map_to_map_distance in map_to_map_distances.items(): if distance_label in config["analysis"]["metrics"]: # TODO: can remove print("cost matrix", distance_label) + print("maps_gt_flat", maps_gt_flat.shape) + print("maps_user_flat", maps_user_flat.shape) cost_matrix = map_to_map_distance.get_distance_matrix( maps_gt_flat, From 555a39fc8a38200da3b03f0a84d5babf96390bd7 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 17:31:02 -0400 Subject: [PATCH 21/30] map_to_map_distance.distance_matrix_precomputation --- .../_map_to_map/map_to_map_distance.py | 15 ++++++++++++++- .../_map_to_map/map_to_map_pipeline.py | 3 +++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index c24c300..14f933a 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -33,6 +33,7 @@ def __init__(self, config): self.config = config self.chunk_size_gt = self.config["analysis"]["chunk_size_gt"] self.chunk_size_submission = self.config["analysis"]["chunk_size_submission"] + self.n_pix = self.config["data"]["n_pix"] def get_distance(self, map1, map2): """Compute the distance between two maps.""" @@ -49,6 +50,10 @@ def get_sub_distance_matrix(self, maps1, maps2, idxs): )(maps1) return sub_distance_matrix + def distance_matrix_precomputation(maps1, maps2, global_store_of_running_results): + """Pre-compute any assets needed for the distance matrix computation.""" + return + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): """Compute the distance matrix between two sets of maps.""" # load in memory as torch tensors @@ -313,7 +318,6 @@ class FSCDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) - self.stored_computed_assets = {"fsc_matrix": torch.empty(10, 8, 8)} def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): """ @@ -336,6 +340,15 @@ def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): cost_matrix[idx] = dist return cost_matrix, fsc_matrix + @override + def distance_matrix_precomputation(self, maps1, maps2): + self.len_maps1 = len(maps1) + self.len_maps2 = len(maps2) + self.stored_computed_assets = { + "fsc_matrix": torch.empty(self.len_maps1, self.len_maps2, self.n_pix // 2) + } + return + @override def get_sub_distance_matrix(self, maps1, maps2, idxs): """ diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index 7016c3f..434280e 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -121,6 +121,9 @@ def run(config): print("maps_gt_flat", maps_gt_flat.shape) print("maps_user_flat", maps_user_flat.shape) + map_to_map_distance.distance_matrix_precomputation( + maps_gt_flat, maps_user_flat + ) cost_matrix = map_to_map_distance.get_distance_matrix( maps_gt_flat, maps_user_flat, From b71e52b47955f8bc2657e30de287dfe363cacb0f Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 17:41:26 -0400 Subject: [PATCH 22/30] chunk_size_low_memory --- src/cryo_challenge/_map_to_map/map_to_map_distance.py | 5 +++-- tests/config_files/test_config_map_to_map.yaml | 1 + tests/config_files/test_config_map_to_map_low_memory.yaml | 1 + ...test_config_map_to_map_low_memory_nomask_nonormalize.yaml | 1 + .../test_config_map_to_map_nomask_nonormalize.yaml | 1 + 5 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 14f933a..e3cbf31 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -34,6 +34,7 @@ def __init__(self, config): self.chunk_size_gt = self.config["analysis"]["chunk_size_gt"] self.chunk_size_submission = self.config["analysis"]["chunk_size_submission"] self.n_pix = self.config["data"]["n_pix"] + self.chunk_size_low_memory = self.config["analysis"]["chunk_size_low_memory"] def get_distance(self, map1, map2): """Compute the distance between two maps.""" @@ -58,9 +59,9 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): """Compute the distance matrix between two sets of maps.""" # load in memory as torch tensors if True: # config.low_memory: + self.n_chunks_low_memory = len(maps1) // self.chunk_size_low_memory distance_matrix = torch.empty(len(maps1), len(maps2)) - n_chunks_low_memory = 100 - for idxs in torch.arange(len(maps1)).chunk(n_chunks_low_memory): + for idxs in torch.arange(len(maps1)).chunk(self.n_chunks_low_memory): maps1_in_memory = maps1[idxs] sub_distance_matrix = self.get_sub_distance_matrix( maps1_in_memory, diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index 7dfa7e9..62651f8 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -21,6 +21,7 @@ analysis: - res chunk_size_submission: 80 chunk_size_gt: 190 + chunk_size_low_memory: 10 normalize: do: true method: median_zscore diff --git a/tests/config_files/test_config_map_to_map_low_memory.yaml b/tests/config_files/test_config_map_to_map_low_memory.yaml index 4eb6cd0..e02e271 100644 --- a/tests/config_files/test_config_map_to_map_low_memory.yaml +++ b/tests/config_files/test_config_map_to_map_low_memory.yaml @@ -21,6 +21,7 @@ analysis: - res_low_memory chunk_size_submission: 80 chunk_size_gt: 190 + chunk_size_low_memory: 10 normalize: do: true method: median_zscore diff --git a/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml index 13a1cf4..49dac78 100644 --- a/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml +++ b/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml @@ -21,6 +21,7 @@ analysis: - res_low_memory chunk_size_submission: 80 chunk_size_gt: 190 + chunk_size_low_memory: 10 normalize: do: false method: median_zscore diff --git a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml index 7856db4..e97abaf 100644 --- a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml +++ b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml @@ -21,6 +21,7 @@ analysis: - res chunk_size_submission: 80 chunk_size_gt: 190 + chunk_size_low_memory: 10 normalize: do: false method: median_zscore From 76fdb462e63e379d630de083ba78740e84a8e503 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 18:54:32 -0400 Subject: [PATCH 23/30] on the fly sub batch normalization --- .../_map_to_map/map_to_map_distance.py | 18 ++++++++++++++++++ .../_map_to_map/map_to_map_pipeline.py | 13 ------------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index e3cbf31..df81ef6 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -28,6 +28,15 @@ def __getitem__(self, idx): return torch.from_numpy(sample.copy()) +def normalize(maps, method): + if method == "median_zscore": + maps -= maps.median(dim=1, keepdim=True).values + maps /= maps.std(dim=1, keepdim=True) + else: + raise NotImplementedError(f"Normalization method {method} not implemented.") + return maps + + class MapToMapDistance: def __init__(self, config): self.config = config @@ -58,11 +67,20 @@ def distance_matrix_precomputation(maps1, maps2, global_store_of_running_results def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): """Compute the distance matrix between two sets of maps.""" # load in memory as torch tensors + if self.config["analysis"]["normalize"]["do"]: + maps2 = normalize( + maps2, method=self.config["analysis"]["normalize"]["method"] + ) if True: # config.low_memory: self.n_chunks_low_memory = len(maps1) // self.chunk_size_low_memory distance_matrix = torch.empty(len(maps1), len(maps2)) for idxs in torch.arange(len(maps1)).chunk(self.n_chunks_low_memory): maps1_in_memory = maps1[idxs] + if self.config["analysis"]["normalize"]["do"]: + maps1_in_memory = normalize( + maps1_in_memory, + method=self.config["analysis"]["normalize"]["method"], + ) sub_distance_matrix = self.get_sub_distance_matrix( maps1_in_memory, maps2, diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index 434280e..0f29e56 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -100,19 +100,6 @@ def run(config): maps_gt_flat.reshape(len(maps_gt_flat), -1, inplace=True) maps_user_flat.reshape(len(maps_gt_flat), -1, inplace=True) - if config["analysis"]["normalize"]["do"]: - if config["analysis"]["normalize"]["method"] == "median_zscore": - if not low_memory_mode: - maps_gt_flat -= maps_gt_flat.median(dim=1, keepdim=True).values - if not low_memory_mode: - maps_gt_flat /= maps_gt_flat.std(dim=1, keepdim=True) - maps_user_flat -= maps_user_flat.median(dim=1, keepdim=True).values - maps_user_flat /= maps_user_flat.std(dim=1, keepdim=True) - else: - raise NotImplementedError( - f"Normalization method {config['analysis']['normalize']['method']} not implemented." - ) - computed_assets = {} results_dict["mask"] = mask for distance_label, map_to_map_distance in map_to_map_distances.items(): From 0a2efc0ad9bb3758a781c5ed33a9b587c6bc937e Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 19:06:12 -0400 Subject: [PATCH 24/30] normalization and masking on the fly in sub batch --- .../_map_to_map/map_to_map_distance.py | 69 ++++++++----------- .../_map_to_map/map_to_map_pipeline.py | 21 ++---- .../config_files/test_config_map_to_map.yaml | 2 +- ..._config_map_to_map_nomask_nonormalize.yaml | 2 +- 4 files changed, 34 insertions(+), 60 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index df81ef6..0782f6f 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -44,6 +44,11 @@ def __init__(self, config): self.chunk_size_submission = self.config["analysis"]["chunk_size_submission"] self.n_pix = self.config["data"]["n_pix"] self.chunk_size_low_memory = self.config["analysis"]["chunk_size_low_memory"] + self.mask = ( + mrcfile.open(self.config["data"]["mask"]["volume"]) + .data.astype(bool) + .flatten() + ) def get_distance(self, map1, map2): """Compute the distance between two maps.""" @@ -66,7 +71,11 @@ def distance_matrix_precomputation(maps1, maps2, global_store_of_running_results def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): """Compute the distance matrix between two sets of maps.""" - # load in memory as torch tensors + if self.config["data"]["mask"]["do"]: + maps2 = maps2[:, self.mask] + else: + maps2.reshape(len(maps2), -1, inplace=True) + if self.config["analysis"]["normalize"]["do"]: maps2 = normalize( maps2, method=self.config["analysis"]["normalize"]["method"] @@ -76,6 +85,10 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): distance_matrix = torch.empty(len(maps1), len(maps2)) for idxs in torch.arange(len(maps1)).chunk(self.n_chunks_low_memory): maps1_in_memory = maps1[idxs] + if self.config["data"]["mask"]["do"]: + maps1_in_memory = maps1_in_memory[:, self.mask] + else: + maps1_in_memory.reshape(len(maps1_in_memory), -1, inplace=True) if self.config["analysis"]["normalize"]["do"]: maps1_in_memory = normalize( maps1_in_memory, @@ -125,7 +138,7 @@ def get_distance(self, map1, map2, global_store_of_running_results): raise NotImplementedError( f"Normalization method {self.config['analysis']['normalize']['method']} not implemented." ) - map1 = map1[global_store_of_running_results["mask"]] + map1 = map1[self.mask] return self.compute_cost(map1, map2) @@ -380,13 +393,8 @@ def get_sub_distance_matrix(self, maps1, maps2, idxs): maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) if self.config["data"]["mask"]["do"]: - mask = ( - mrcfile.open(self.config["data"]["mask"]["volume"]) - .data.astype(bool) - .flatten() - ) - maps_gt_flat_cube[:, mask] = maps_gt_flat - maps_user_flat_cube[:, mask] = maps_user_flat + maps_gt_flat_cube[:, self.mask] = maps_gt_flat + maps_user_flat_cube[:, self.mask] = maps_user_flat else: maps_gt_flat_cube = maps_gt_flat maps_user_flat_cube = maps_user_flat @@ -397,35 +405,6 @@ def get_sub_distance_matrix(self, maps1, maps2, idxs): self.stored_computed_assets["fsc_matrix"][idxs] = fsc_matrix return cost_matrix - # @override - # def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): - # """ - # Applies a mask to the maps and computes the cost matrix using the Fourier Shell Correlation. - # """ - # maps_gt_flat = maps1 - # maps_user_flat = maps2 - # n_pix = self.config["data"]["n_pix"] - # maps_gt_flat_cube = torch.zeros(len(maps_gt_flat), n_pix**3) - # maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) - - # if self.config["data"]["mask"]["do"]: - # mask = ( - # mrcfile.open(self.config["data"]["mask"]["volume"]) - # .data.astype(bool) - # .flatten() - # ) - # maps_gt_flat_cube[:, mask] = maps_gt_flat - # maps_user_flat_cube[:, mask] = maps_user_flat - # else: - # maps_gt_flat_cube = maps_gt_flat - # maps_user_flat_cube = maps_user_flat - - # cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk( - # maps_gt_flat_cube, maps_user_flat_cube, n_pix - # ) - # self.stored_computed_assets = {"fsc_matrix": fsc_matrix} - # return cost_matrix - @override def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return self.stored_computed_assets # must run get_distance_matrix first @@ -447,8 +426,8 @@ def get_distance(self, map1, map2, global_store_of_running_results): map_gt_flat = map1 = map1.flatten() map_gt_flat_cube = torch.zeros(self.n_pix**3) if self.config["data"]["mask"]["do"]: - map_gt_flat = map_gt_flat[global_store_of_running_results["mask"]] - map_gt_flat_cube[global_store_of_running_results["mask"]] = map_gt_flat + map_gt_flat = map_gt_flat[self.mask] + map_gt_flat_cube[self.mask] = map_gt_flat else: map_gt_flat_cube = map_gt_flat @@ -483,7 +462,15 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): @override def get_computed_assets(self, maps1, maps2, global_store_of_running_results): - return self.stored_computed_assets # must run get_distance_matrix first + """ + Return any computed assets that are needed for (downstream) analysis. + + Notes + ----- + The FSC matrix is stored in the computed assets. + Must run get_distance_matrix first. + """ + return self.stored_computed_assets class FSCResDistance(MapToMapDistance): diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index 0f29e56..e12dab0 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -1,4 +1,4 @@ -import mrcfile +import numpy as np import pandas as pd import pickle import torch @@ -84,24 +84,11 @@ def run(config): if low_memory_mode: maps_gt_flat = GT_Dataset(config["data"]["ground_truth"]["volumes"]) else: - maps_gt_flat = torch.load(config["data"]["ground_truth"]["volumes"]).reshape( - -1, n_pix**3 - ) - - if config["data"]["mask"]["do"]: - mask = ( - mrcfile.open(config["data"]["mask"]["volume"]).data.astype(bool).flatten() - ) - if not low_memory_mode: - maps_gt_flat = maps_gt_flat[:, mask] - maps_user_flat = maps_user_flat[:, mask] - else: - if not low_memory_mode: - maps_gt_flat.reshape(len(maps_gt_flat), -1, inplace=True) - maps_user_flat.reshape(len(maps_gt_flat), -1, inplace=True) + maps_gt_flat = torch.from_numpy( + np.load(config["data"]["ground_truth"]["volumes"]) + ).reshape(-1, n_pix**3) computed_assets = {} - results_dict["mask"] = mask for distance_label, map_to_map_distance in map_to_map_distances.items(): if distance_label in config["analysis"]["metrics"]: # TODO: can remove print("cost matrix", distance_label) diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index 62651f8..0d951dc 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -7,7 +7,7 @@ data: metadata_key: populations label_key: id ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy metadata: tests/data/Ground_truth/test_metadata_10.csv mask: do: true diff --git a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml index e97abaf..b73b88f 100644 --- a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml +++ b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml @@ -7,7 +7,7 @@ data: metadata_key: populations label_key: id ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy metadata: tests/data/Ground_truth/test_metadata_10.csv mask: do: false From 42525f4eaf67533a2642b39370afd9af3a03c1df Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 19:48:56 -0400 Subject: [PATCH 25/30] tests passing --- .../_map_to_map/map_to_map_distance.py | 16 ++++++++++------ .../_map_to_map/map_to_map_pipeline.py | 11 ++++------- .../data/_validation/config_validators.py | 1 + tests/config_files/test_config_map_to_map.yaml | 4 +++- .../test_config_map_to_map_low_memory.yaml | 4 +++- ...map_to_map_low_memory_nomask_nonormalize.yaml | 3 +++ ...est_config_map_to_map_nomask_nonormalize.yaml | 4 +++- tests/test_map_to_map.py | 4 ++-- 8 files changed, 29 insertions(+), 18 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 0782f6f..451ca5f 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -40,10 +40,13 @@ def normalize(maps, method): class MapToMapDistance: def __init__(self, config): self.config = config + self.do_low_memory_mode = self.config["analysis"]["low_memory"]["do"] self.chunk_size_gt = self.config["analysis"]["chunk_size_gt"] self.chunk_size_submission = self.config["analysis"]["chunk_size_submission"] self.n_pix = self.config["data"]["n_pix"] - self.chunk_size_low_memory = self.config["analysis"]["chunk_size_low_memory"] + self.chunk_size_low_memory = self.config["analysis"]["low_memory"][ + "chunk_size_low_memory" + ] self.mask = ( mrcfile.open(self.config["data"]["mask"]["volume"]) .data.astype(bool) @@ -74,21 +77,23 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): if self.config["data"]["mask"]["do"]: maps2 = maps2[:, self.mask] else: - maps2.reshape(len(maps2), -1, inplace=True) + maps2 = maps2.reshape(len(maps2), -1) if self.config["analysis"]["normalize"]["do"]: maps2 = normalize( maps2, method=self.config["analysis"]["normalize"]["method"] ) - if True: # config.low_memory: + if True: # self.do_low_memory_mode: self.n_chunks_low_memory = len(maps1) // self.chunk_size_low_memory distance_matrix = torch.empty(len(maps1), len(maps2)) for idxs in torch.arange(len(maps1)).chunk(self.n_chunks_low_memory): maps1_in_memory = maps1[idxs] if self.config["data"]["mask"]["do"]: - maps1_in_memory = maps1_in_memory[:, self.mask] + maps1_in_memory = maps1_in_memory.reshape(len(idxs), -1)[ + :, self.mask + ] else: - maps1_in_memory.reshape(len(maps1_in_memory), -1, inplace=True) + maps1_in_memory = maps1_in_memory.reshape(len(maps1_in_memory), -1) if self.config["analysis"]["normalize"]["do"]: maps1_in_memory = normalize( maps1_in_memory, @@ -102,7 +107,6 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): distance_matrix[idxs] = sub_distance_matrix else: - assert False, "Not implemented" distance_matrix = torch.vmap( lambda maps1: torch.vmap( lambda maps2: self.get_distance(maps1, maps2), diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index e12dab0..8d3d9f5 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -54,15 +54,12 @@ def run(config): } assert len(map_to_map_distances_low_memory) == 0 or len(map_to_map_distances) == 0 - if len(map_to_map_distances_low_memory) > 0: map_to_map_distances = map_to_map_distances_low_memory - low_memory_mode = True - else: - low_memory_mode = False - if not low_memory_mode: - n_pix = config["data"]["n_pix"] + do_low_memory_mode = config["analysis"]["low_memory"]["do"] + + n_pix = config["data"]["n_pix"] submission = torch.load(config["data"]["submission"]["fname"]) submission_volume_key = config["data"]["submission"]["volume_key"] @@ -81,7 +78,7 @@ def run(config): maps_user_flat = submission[submission_volume_key].reshape( len(submission["volumes"]), -1 ) - if low_memory_mode: + if do_low_memory_mode: maps_gt_flat = GT_Dataset(config["data"]["ground_truth"]["volumes"]) else: maps_gt_flat = torch.from_numpy( diff --git a/src/cryo_challenge/data/_validation/config_validators.py b/src/cryo_challenge/data/_validation/config_validators.py index b2fa933..22046c3 100644 --- a/src/cryo_challenge/data/_validation/config_validators.py +++ b/src/cryo_challenge/data/_validation/config_validators.py @@ -158,6 +158,7 @@ def validate_config_mtm_analysis(config_analysis: dict) -> None: "chunk_size_submission": Number, "chunk_size_gt": Number, "normalize": dict, + "low_memory": dict, } validate_generic_config(config_analysis, keys_and_types) diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index 0d951dc..fd5990c 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -21,7 +21,9 @@ analysis: - res chunk_size_submission: 80 chunk_size_gt: 190 - chunk_size_low_memory: 10 + low_memory: + do: false + chunk_size_low_memory: 10 normalize: do: true method: median_zscore diff --git a/tests/config_files/test_config_map_to_map_low_memory.yaml b/tests/config_files/test_config_map_to_map_low_memory.yaml index e02e271..35ef921 100644 --- a/tests/config_files/test_config_map_to_map_low_memory.yaml +++ b/tests/config_files/test_config_map_to_map_low_memory.yaml @@ -21,7 +21,9 @@ analysis: - res_low_memory chunk_size_submission: 80 chunk_size_gt: 190 - chunk_size_low_memory: 10 + low_memory: + do: false + chunk_size_low_memory: null normalize: do: true method: median_zscore diff --git a/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml index 49dac78..4cfb346 100644 --- a/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml +++ b/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml @@ -22,6 +22,9 @@ analysis: chunk_size_submission: 80 chunk_size_gt: 190 chunk_size_low_memory: 10 + low_memory: + do: true + chunk_size_low_memory: null normalize: do: false method: median_zscore diff --git a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml index b73b88f..d37e4ac 100644 --- a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml +++ b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml @@ -21,7 +21,9 @@ analysis: - res chunk_size_submission: 80 chunk_size_gt: 190 - chunk_size_low_memory: 10 + low_memory: + do: true + chunk_size_low_memory: 10 normalize: do: false method: median_zscore diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index 957a9b1..f0d65b7 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -7,11 +7,11 @@ def test_run_map2map_pipeline(): for config_fname, config_fname_low_memory in zip( [ "tests/config_files/test_config_map_to_map.yaml", - "tests/config_files/test_config_map_to_map.yaml", + "tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml", ], [ "tests/config_files/test_config_map_to_map_low_memory.yaml", - "tests/config_files/test_config_map_to_map_low_memory.yaml", + "tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml", ], ): args = OmegaConf.create({"config": config_fname}) From e472372eb2418940801825d400b753b8b7be1343 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 19:53:56 -0400 Subject: [PATCH 26/30] tests passing for low memory off --- src/cryo_challenge/_map_to_map/map_to_map_distance.py | 7 ++++++- tests/config_files/test_config_map_to_map.yaml | 2 +- .../test_config_map_to_map_nomask_nonormalize.yaml | 4 ++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 451ca5f..73f2502 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -83,7 +83,7 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): maps2 = normalize( maps2, method=self.config["analysis"]["normalize"]["method"] ) - if True: # self.do_low_memory_mode: + if self.do_low_memory_mode: self.n_chunks_low_memory = len(maps1) // self.chunk_size_low_memory distance_matrix = torch.empty(len(maps1), len(maps2)) for idxs in torch.arange(len(maps1)).chunk(self.n_chunks_low_memory): @@ -409,6 +409,11 @@ def get_sub_distance_matrix(self, maps1, maps2, idxs): self.stored_computed_assets["fsc_matrix"][idxs] = fsc_matrix return cost_matrix + @override + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): + idxs = torch.arange(len(maps1)) + return self.get_sub_distance_matrix(maps1, maps2, idxs) + @override def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return self.stored_computed_assets # must run get_distance_matrix first diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index fd5990c..689eb82 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -23,7 +23,7 @@ analysis: chunk_size_gt: 190 low_memory: do: false - chunk_size_low_memory: 10 + chunk_size_low_memory: null normalize: do: true method: median_zscore diff --git a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml index d37e4ac..2a9a3a8 100644 --- a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml +++ b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml @@ -22,8 +22,8 @@ analysis: chunk_size_submission: 80 chunk_size_gt: 190 low_memory: - do: true - chunk_size_low_memory: 10 + do: false + chunk_size_low_memory: null normalize: do: false method: median_zscore From 3077f6ae89e5fd822ab606352db1792db803438b Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 20:23:29 -0400 Subject: [PATCH 27/30] tests passing. time to delete separate low memory --- .../_map_to_map/map_to_map_distance.py | 42 +++++++++++++++++-- .../_map_to_map/map_to_map_pipeline.py | 2 - ...config_map_to_map_low_memory_subbatch.yaml | 31 ++++++++++++++ tests/test_map_to_map.py | 31 ++++++++++++++ 4 files changed, 101 insertions(+), 5 deletions(-) create mode 100644 tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 73f2502..eec6828 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -397,8 +397,9 @@ def get_sub_distance_matrix(self, maps1, maps2, idxs): maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) if self.config["data"]["mask"]["do"]: - maps_gt_flat_cube[:, self.mask] = maps_gt_flat + maps_gt_flat_cube[:, self.mask] = maps_gt_flat[:] maps_user_flat_cube[:, self.mask] = maps_user_flat + else: maps_gt_flat_cube = maps_gt_flat maps_user_flat_cube = maps_user_flat @@ -411,8 +412,43 @@ def get_sub_distance_matrix(self, maps1, maps2, idxs): @override def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): - idxs = torch.arange(len(maps1)) - return self.get_sub_distance_matrix(maps1, maps2, idxs) + """Compute the distance matrix between two sets of maps.""" + if self.config["data"]["mask"]["do"]: + maps2 = maps2[:, self.mask] + else: + maps2 = maps2.reshape(len(maps2), -1) + + if self.config["analysis"]["normalize"]["do"]: + maps2 = normalize( + maps2, method=self.config["analysis"]["normalize"]["method"] + ) + if self.chunk_size_low_memory is None: + self.n_chunks_low_memory = 1 + else: + self.n_chunks_low_memory = len(maps1) // self.chunk_size_low_memory + distance_matrix = torch.empty(len(maps1), len(maps2)) + for idxs in torch.arange(len(maps1)).chunk(self.n_chunks_low_memory): + maps1_in_memory = maps1[idxs] + + if self.config["data"]["mask"]["do"]: + maps1_in_memory = maps1_in_memory[:].reshape(len(idxs), -1)[ + :, self.mask + ] + + else: + maps1_in_memory = maps1_in_memory.reshape(len(maps1_in_memory), -1) + if self.config["analysis"]["normalize"]["do"]: + maps1_in_memory = normalize( + maps1_in_memory, + method=self.config["analysis"]["normalize"]["method"], + ) + sub_distance_matrix = self.get_sub_distance_matrix( + maps1_in_memory, + maps2, + idxs, + ) + distance_matrix[idxs] = sub_distance_matrix + return distance_matrix @override def get_computed_assets(self, maps1, maps2, global_store_of_running_results): diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index 8d3d9f5..c85b5d1 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -89,8 +89,6 @@ def run(config): for distance_label, map_to_map_distance in map_to_map_distances.items(): if distance_label in config["analysis"]["metrics"]: # TODO: can remove print("cost matrix", distance_label) - print("maps_gt_flat", maps_gt_flat.shape) - print("maps_user_flat", maps_user_flat.shape) map_to_map_distance.distance_matrix_precomputation( maps_gt_flat, maps_user_flat diff --git a/tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml b/tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml new file mode 100644 index 0000000..7b02d2e --- /dev/null +++ b/tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml @@ -0,0 +1,31 @@ +data: + n_pix: 16 + psize: 30.044 + submission: + fname: tests/data/dataset_2_submissions/submission_1000.pt + volume_key: volumes + metadata_key: populations + label_key: id + ground_truth: + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy + metadata: tests/data/Ground_truth/test_metadata_10.csv + mask: + do: true + volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc +analysis: + metrics: + - l2 + - corr + - bioem + - fsc + - res + chunk_size_submission: 80 + chunk_size_gt: 190 + chunk_size_low_memory: 10 + low_memory: + do: true + chunk_size_low_memory: 10 + normalize: + do: true + method: median_zscore +output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index f0d65b7..492610a 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -4,6 +4,37 @@ def test_run_map2map_pipeline(): + for config_fname, config_fname_low_memory in zip( + [ + "tests/config_files/test_config_map_to_map.yaml", + ], + [ + "tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml", + ], + ): + args = OmegaConf.create({"config": config_fname}) + results_dict = run_map2map_pipeline.main(args) + + args_low_memory = OmegaConf.create({"config": config_fname_low_memory}) + results_dict_low_memory = run_map2map_pipeline.main(args_low_memory) + for metric in ["fsc", "corr", "l2", "bioem"]: + if metric == "fsc": + np.allclose( + results_dict[metric]["computed_assets"]["fsc_matrix"], + results_dict_low_memory[metric]["computed_assets"]["fsc_matrix"], + ) + elif metric == "res": + np.allclose( + results_dict[metric]["computed_assets"]["fraction_nyquist"], + results_dict_low_memory[metric]["computed_assets"][ + "fraction_nyquist" + ], + ) + np.allclose( + results_dict[metric]["cost_matrix"].values, + results_dict_low_memory[metric]["cost_matrix"].values, + ) + for config_fname, config_fname_low_memory in zip( [ "tests/config_files/test_config_map_to_map.yaml", From cb7085e6569e6222e0ea665a2785d595290e56d6 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 20:30:38 -0400 Subject: [PATCH 28/30] all tests passing --- .../_map_to_map/map_to_map_distance.py | 14 +++++---- ...ow_memory_subbatch_nomask_nonormalize.yaml | 31 +++++++++++++++++++ tests/test_map_to_map.py | 2 ++ 3 files changed, 41 insertions(+), 6 deletions(-) create mode 100644 tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index eec6828..e70817e 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -47,11 +47,12 @@ def __init__(self, config): self.chunk_size_low_memory = self.config["analysis"]["low_memory"][ "chunk_size_low_memory" ] - self.mask = ( - mrcfile.open(self.config["data"]["mask"]["volume"]) - .data.astype(bool) - .flatten() - ) + if self.config["data"]["mask"]["do"]: + self.mask = ( + mrcfile.open(self.config["data"]["mask"]["volume"]) + .data.astype(bool) + .flatten() + ) def get_distance(self, map1, map2): """Compute the distance between two maps.""" @@ -142,7 +143,8 @@ def get_distance(self, map1, map2, global_store_of_running_results): raise NotImplementedError( f"Normalization method {self.config['analysis']['normalize']['method']} not implemented." ) - map1 = map1[self.mask] + if self.config["data"]["mask"]["do"]: + map1 = map1[self.mask] return self.compute_cost(map1, map2) diff --git a/tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml new file mode 100644 index 0000000..020f58f --- /dev/null +++ b/tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml @@ -0,0 +1,31 @@ +data: + n_pix: 16 + psize: 30.044 + submission: + fname: tests/data/dataset_2_submissions/submission_1000.pt + volume_key: volumes + metadata_key: populations + label_key: id + ground_truth: + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy + metadata: tests/data/Ground_truth/test_metadata_10.csv + mask: + do: false + volume: dummy-string +analysis: + metrics: + - l2 + - corr + - bioem + - fsc + - res + chunk_size_submission: 80 + chunk_size_gt: 190 + chunk_size_low_memory: 10 + low_memory: + do: true + chunk_size_low_memory: 10 + normalize: + do: false + method: dummy-string +output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index 492610a..2c607e7 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -7,9 +7,11 @@ def test_run_map2map_pipeline(): for config_fname, config_fname_low_memory in zip( [ "tests/config_files/test_config_map_to_map.yaml", + "tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml", ], [ "tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml", + "tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml", ], ): args = OmegaConf.create({"config": config_fname}) From 1f4057e29357da659c24e03c36eb296c17816a68 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 20:37:35 -0400 Subject: [PATCH 29/30] remove low memory versions --- .../_map_to_map/map_to_map_distance.py | 149 ------------------ .../_map_to_map/map_to_map_pipeline.py | 23 --- .../test_config_map_to_map_low_memory.yaml | 30 ---- ..._to_map_low_memory_nomask_nonormalize.yaml | 31 ---- tests/test_map_to_map.py | 35 ---- 5 files changed, 268 deletions(-) delete mode 100644 tests/config_files/test_config_map_to_map_low_memory.yaml delete mode 100644 tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index e70817e..5b9dc8e 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -122,47 +122,6 @@ def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return {} -class MapToMapDistanceLowMemory(MapToMapDistance): - """General class for map-to-map distance metrics that require low memory.""" - - def __init__(self, config): - super().__init__(config) - self.config = config - - def compute_cost(self, map_1, map_2): - raise NotImplementedError() - - @override - def get_distance(self, map1, map2, global_store_of_running_results): - map1 = map1.flatten() - if self.config["analysis"]["normalize"]["do"]: - if self.config["analysis"]["normalize"]["method"] == "median_zscore": - map1 -= map1.median() - map1 /= map1.std() - else: - raise NotImplementedError( - f"Normalization method {self.config['analysis']['normalize']['method']} not implemented." - ) - if self.config["data"]["mask"]["do"]: - map1 = map1[self.mask] - - return self.compute_cost(map1, map2) - - @override - def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): - maps_gt_flat = maps1 - maps_user_flat = maps2 - cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) - for idx_gt in range(len(maps_gt_flat)): - for idx_user in range(len(maps_user_flat)): - cost_matrix[idx_gt, idx_user] = self.get_distance( - maps_gt_flat[idx_gt], - maps_user_flat[idx_user], - global_store_of_running_results, - ) - return cost_matrix - - def norm2(map1, map2): return torch.norm(map1 - map2) ** 2 @@ -178,17 +137,6 @@ def get_distance(self, map1, map2): return norm2(map1, map2) -class L2DistanceNormLowMemory(MapToMapDistanceLowMemory): - """L2 distance norm""" - - def __init__(self, config): - super().__init__(config) - - @override - def compute_cost(self, map1, map2): - return norm2(map1, map2) - - def correlation(map1, map2): return (map1 * map2).sum() @@ -206,17 +154,6 @@ def get_distance(self, map1, map2): return correlation(map1, map2) -class CorrelationLowMemory(MapToMapDistanceLowMemory): - """Correlation.""" - - def __init__(self, config): - super().__init__(config) - - @override - def compute_cost(self, map1, map2): - return correlation(map1, map2) - - def compute_bioem3d_cost(map1, map2): """ Compute the cost between two maps using the BioEM cost function in 3D. @@ -272,17 +209,6 @@ def get_distance(self, map1, map2): return compute_bioem3d_cost(map1, map2) -class BioEM3dDistanceLowMemory(MapToMapDistanceLowMemory): - """BioEM 3D distance.""" - - def __init__(self, config): - super().__init__(config) - - @override - def compute_cost(self, map1, map2): - return compute_bioem3d_cost(map1, map2) - - def fourier_shell_correlation( x: torch.Tensor, y: torch.Tensor, @@ -457,69 +383,6 @@ def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return self.stored_computed_assets # must run get_distance_matrix first -class FSCDistanceLowMemory(MapToMapDistance): - """Fourier Shell Correlation distance.""" - - def __init__(self, config): - super().__init__(config) - self.n_pix = self.config["data"]["n_pix"] - self.config = config - - def compute_cost(self, map_1, map_2): - raise NotImplementedError() - - @override - def get_distance(self, map1, map2, global_store_of_running_results): - map_gt_flat = map1 = map1.flatten() - map_gt_flat_cube = torch.zeros(self.n_pix**3) - if self.config["data"]["mask"]["do"]: - map_gt_flat = map_gt_flat[self.mask] - map_gt_flat_cube[self.mask] = map_gt_flat - else: - map_gt_flat_cube = map_gt_flat - - corr_vector = fourier_shell_correlation( - map_gt_flat_cube.reshape(self.n_pix, self.n_pix, self.n_pix), - map2.reshape(self.n_pix, self.n_pix, self.n_pix), - ) - dist = 1 - corr_vector.mean() # TODO: spectral cutoff - self.stored_computed_assets = {"corr_vector": corr_vector} - return dist - - @override - def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): - maps_gt_flat = maps1 - maps_user_flat = maps2 - cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) - fsc_matrix = torch.zeros( - len(maps_gt_flat), len(maps_user_flat), self.n_pix // 2 - ) - for idx_gt in range(len(maps_gt_flat)): - for idx_user in range(len(maps_user_flat)): - cost_matrix[idx_gt, idx_user] = self.get_distance( - maps_gt_flat[idx_gt], - maps_user_flat[idx_user], - global_store_of_running_results, - ) - fsc_matrix[idx_gt, idx_user] = self.stored_computed_assets[ - "corr_vector" - ] - self.stored_computed_assets = {"fsc_matrix": fsc_matrix} - return cost_matrix - - @override - def get_computed_assets(self, maps1, maps2, global_store_of_running_results): - """ - Return any computed assets that are needed for (downstream) analysis. - - Notes - ----- - The FSC matrix is stored in the computed assets. - Must run get_distance_matrix first. - """ - return self.stored_computed_assets - - class FSCResDistance(MapToMapDistance): """FSC Resolution distance. @@ -555,15 +418,3 @@ def res_at_fsc_threshold(fscs, threshold=0.5): res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fsc_matrix) self.stored_computed_assets = {"fraction_nyquist": fraction_nyquist} return units_Angstroms[res_fsc_half] - - -class FSCResDistanceLowMemory(FSCResDistance): - """FSC Resolution distance. - - The resolution at which the Fourier Shell Correlation reaches 0.5. - Built on top of the FSCDistance class. This needs to be run first and store the FSC matrix in the computed assets. - """ - - def __init__(self, config): - super().__init__(config) - self.fsc_label = "fsc_low_memory" diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index c85b5d1..c281496 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -7,15 +7,10 @@ from .._map_to_map.map_to_map_distance import ( GT_Dataset, FSCDistance, - FSCDistanceLowMemory, Correlation, - CorrelationLowMemory, L2DistanceNorm, - L2DistanceNormLowMemory, BioEM3dDistance, - BioEM3dDistanceLowMemory, FSCResDistance, - FSCResDistanceLowMemory, ) @@ -27,14 +22,6 @@ "res": FSCResDistance, } -AVAILABLE_MAP2MAP_DISTANCES_LOW_MEMORY = { - "corr_low_memory": CorrelationLowMemory, - "l2_low_memory": L2DistanceNormLowMemory, - "bioem_low_memory": BioEM3dDistanceLowMemory, - "fsc_low_memory": FSCDistanceLowMemory, - "res_low_memory": FSCResDistanceLowMemory, -} - def run(config): """ @@ -47,16 +34,6 @@ def run(config): if distance_label in config["analysis"]["metrics"] } - map_to_map_distances_low_memory = { - distance_label: distance_class(config) - for distance_label, distance_class in AVAILABLE_MAP2MAP_DISTANCES_LOW_MEMORY.items() - if distance_label in config["analysis"]["metrics"] - } - - assert len(map_to_map_distances_low_memory) == 0 or len(map_to_map_distances) == 0 - if len(map_to_map_distances_low_memory) > 0: - map_to_map_distances = map_to_map_distances_low_memory - do_low_memory_mode = config["analysis"]["low_memory"]["do"] n_pix = config["data"]["n_pix"] diff --git a/tests/config_files/test_config_map_to_map_low_memory.yaml b/tests/config_files/test_config_map_to_map_low_memory.yaml deleted file mode 100644 index 35ef921..0000000 --- a/tests/config_files/test_config_map_to_map_low_memory.yaml +++ /dev/null @@ -1,30 +0,0 @@ -data: - n_pix: 16 - psize: 30.044 - submission: - fname: tests/data/dataset_2_submissions/submission_1000.pt - volume_key: volumes - metadata_key: populations - label_key: id - ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy - metadata: tests/data/Ground_truth/test_metadata_10.csv - mask: - do: true - volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc -analysis: - metrics: - - l2_low_memory - - corr_low_memory - - bioem_low_memory - - fsc_low_memory - - res_low_memory - chunk_size_submission: 80 - chunk_size_gt: 190 - low_memory: - do: false - chunk_size_low_memory: null - normalize: - do: true - method: median_zscore -output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml deleted file mode 100644 index 4cfb346..0000000 --- a/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml +++ /dev/null @@ -1,31 +0,0 @@ -data: - n_pix: 16 - psize: 30.044 - submission: - fname: tests/data/dataset_2_submissions/submission_1000.pt - volume_key: volumes - metadata_key: populations - label_key: id - ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy - metadata: tests/data/Ground_truth/test_metadata_10.csv - mask: - do: false - volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc -analysis: - metrics: - - l2_low_memory - - corr_low_memory - - bioem_low_memory - - fsc_low_memory - - res_low_memory - chunk_size_submission: 80 - chunk_size_gt: 190 - chunk_size_low_memory: 10 - low_memory: - do: true - chunk_size_low_memory: null - normalize: - do: false - method: median_zscore -output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index 2c607e7..907e6d3 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -36,38 +36,3 @@ def test_run_map2map_pipeline(): results_dict[metric]["cost_matrix"].values, results_dict_low_memory[metric]["cost_matrix"].values, ) - - for config_fname, config_fname_low_memory in zip( - [ - "tests/config_files/test_config_map_to_map.yaml", - "tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml", - ], - [ - "tests/config_files/test_config_map_to_map_low_memory.yaml", - "tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml", - ], - ): - args = OmegaConf.create({"config": config_fname}) - results_dict = run_map2map_pipeline.main(args) - - args_low_memory = OmegaConf.create({"config": config_fname_low_memory}) - results_dict_low_memory = run_map2map_pipeline.main(args_low_memory) - for metric in ["fsc", "corr", "l2", "bioem"]: - if metric == "fsc": - np.allclose( - results_dict[metric]["computed_assets"]["fsc_matrix"], - results_dict_low_memory[metric + "_low_memory"]["computed_assets"][ - "fsc_matrix" - ], - ) - elif metric == "res": - np.allclose( - results_dict[metric]["computed_assets"]["fraction_nyquist"], - results_dict_low_memory[metric + "_low_memory"]["computed_assets"][ - "fraction_nyquist" - ], - ) - np.allclose( - results_dict[metric]["cost_matrix"].values, - results_dict_low_memory[metric + "_low_memory"]["cost_matrix"].values, - ) From b33a6a06f707209efc3c9ec7581b6d4c8fb3e36b Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 20:57:25 -0400 Subject: [PATCH 30/30] update configs to remove low memory metrics --- .../data/_validation/config_validators.py | 1 + .../data/_validation/output_validators.py | 18 ------------------ 2 files changed, 1 insertion(+), 18 deletions(-) diff --git a/src/cryo_challenge/data/_validation/config_validators.py b/src/cryo_challenge/data/_validation/config_validators.py index 22046c3..83083ed 100644 --- a/src/cryo_challenge/data/_validation/config_validators.py +++ b/src/cryo_challenge/data/_validation/config_validators.py @@ -151,6 +151,7 @@ def validate_config_mtm_analysis(config_analysis: dict) -> None: chunk_size_submission: int, is the chunk size for the submission volume. chunk_size_gt: int, is the chunk size for the ground truth volume. normalize: dict, is the normalize part of the analysis part of the config. + low_memory: dict, is the low memory part of the analysis part of the config. # TODO: add validation for low_memory """ # noqa: E501 keys_and_types = { diff --git a/src/cryo_challenge/data/_validation/output_validators.py b/src/cryo_challenge/data/_validation/output_validators.py index ada4492..9f76a6d 100644 --- a/src/cryo_challenge/data/_validation/output_validators.py +++ b/src/cryo_challenge/data/_validation/output_validators.py @@ -19,27 +19,18 @@ class MapToMapResultsValidator: config: dict, input config dictionary. user_submitted_populations: torch.Tensor, user submitted populations, which sum to 1. corr: dict, correlation results. - corr_low_memory: dict, correlation results in low memory mode. l2: dict, L2 results. - l2_low_memory: dict, L2 results in low memory mode. bioem: dict, BioEM results. - bioem_low_memory: dict, BioEM results in low memory mode. fsc: dict, FSC results. - fsc_low_memory: dict, FSC results in low memory mode. """ config: dict user_submitted_populations: torch.Tensor corr: Optional[dict] = None - corr_low_memory: Optional[dict] = None l2: Optional[dict] = None - l2_low_memory: Optional[dict] = None bioem: Optional[dict] = None - bioem_low_memory: Optional[dict] = None fsc: Optional[dict] = None - fsc_low_memory: Optional[dict] = None res: Optional[dict] = None - res_low_memory: Optional[dict] = None def __post_init__(self): validate_input_config_mtm(self.config) @@ -150,10 +141,6 @@ class DistributionToDistributionResultsValidator: bioem: dict, BioEM distance results. l2: dict, L2 distance results. corr: dict, correlation distance results. - fsc_low_memory: dict, FSC distance results in low memory mode. - bioem_low_memory: dict, BioEM distance results in low memory mode. - l2_low_memory: dict, L2 distance results in low memory mode. - corr_low_memory: dict, correlation distance results in low memory mode. """ config: dict @@ -164,11 +151,6 @@ class DistributionToDistributionResultsValidator: res: Optional[dict] = None l2: Optional[dict] = None corr: Optional[dict] = None - fsc_low_memory: Optional[dict] = None - bioem_low_memory: Optional[dict] = None - res_low_memory: Optional[dict] = None - l2_low_memory: Optional[dict] = None - corr_low_memory: Optional[dict] = None def __post_init__(self): validate_input_config_disttodist(self.config)