From 356696c11b2f23eeed2da21f78f45609c3e53bb2 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 3 Sep 2024 20:16:30 -0400 Subject: [PATCH 01/26] low memory pasisng test --- .../_map_to_map/map_to_map_distance.py | 36 ++++++++++++++++ .../_map_to_map/map_to_map_pipeline.py | 41 +++++++++++++++---- .../data/_validation/output_validators.py | 1 + .../config_files/test_config_map_to_map.yaml | 13 +++--- 4 files changed, 77 insertions(+), 14 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index e253d25..afc496f 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -60,6 +60,42 @@ def get_distance(self, map1, map2): return self.compute_cost_l2(map1, map2) +class CorrelationLowMemory(MapToMapDistance): + """Correlation distance. + + Not technically a distance metric, but a similarity.""" + + def __init__(self, config): + super().__init__(config) + + def compute_cost_corr(self, map_1, map_2): + return (map_1 * map_2).sum() + + @override + def get_distance(self, map1, map2, global_store_of_running_results): + map1 = map1.flatten() + map1 -= map1.median() + map1 /= map1.std() + map1 = map1[global_store_of_running_results["mask"]] + + return self.compute_cost_corr(map1, map2) + + @override + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): + maps_gt_flat = maps1 + maps_user_flat = maps2 + cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) + for idx_gt in range(len(maps_gt_flat)): + for idx_user in range(len(maps_user_flat)): + cost_matrix[idx_gt, idx_user] = self.get_distance( + maps_gt_flat[idx_gt], + maps_user_flat[idx_user], + global_store_of_running_results, + ) + + return cost_matrix + + class Correlation(MapToMapDistance): """Correlation distance. 
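For orientation, the low-memory variant added above replaces the vectorized distance computation with an explicit per-pair loop: one ground-truth map at a time is flattened, median/std normalized, restricted to the mask stored in the running-results dict, and then correlated against each already-normalized submission map, so the full ground-truth stack never has to be resident in memory. A minimal, self-contained sketch of that access pattern (the file name, sizes, and random mask below are illustrative assumptions, not part of the patch):

    import numpy as np
    import torch

    # Illustrative sizes: 10 ground-truth maps, 80 submitted maps, 16^3 voxels.
    n_gt, n_sub, n_vox = 10, 80, 16**3

    # Memory-mapped ground-truth stack on disk (stand-in for the real .npy file).
    gt_maps = np.lib.format.open_memmap(
        "gt_maps.npy", mode="w+", dtype=np.float32, shape=(n_gt, n_vox)
    )
    gt_maps[:] = np.random.rand(n_gt, n_vox)

    mask = torch.rand(n_vox) > 0.5                 # stand-in for the .mrc mask
    sub_maps = torch.rand(n_sub, int(mask.sum()))  # already masked and normalized

    cost = torch.empty(n_gt, n_sub)
    for i in range(n_gt):
        # Only one ground-truth map is materialized as a torch tensor at a time.
        gt = torch.from_numpy(gt_maps[i].copy()).flatten()
        gt = (gt - gt.median()) / gt.std()         # median z-score, as in the patch
        gt = gt[mask]                              # keep only the masked voxels
        for j in range(n_sub):
            cost[i, j] = (gt * sub_maps[j]).sum()  # correlation-style similarity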
diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index ffc02df..df63caf 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -1,5 +1,6 @@ import mrcfile import pandas as pd +import numpy as np import pickle import torch @@ -7,6 +8,7 @@ from .._map_to_map.map_to_map_distance import ( FSCDistance, Correlation, + CorrelationLowMemory, L2DistanceSum, BioEM3dDistance, FSCResDistance, @@ -16,6 +18,7 @@ AVAILABLE_MAP2MAP_DISTANCES = { "fsc": FSCDistance, "corr": Correlation, + "corr_low_memory": CorrelationLowMemory, "l2": L2DistanceSum, "bioem": BioEM3dDistance, "res": FSCResDistance, @@ -32,7 +35,7 @@ def run(config): for distance_label, distance_class in AVAILABLE_MAP2MAP_DISTANCES.items() } - n_pix = config["data"]["n_pix"] + # n_pix = config["data"]["n_pix"] submission = torch.load(config["data"]["submission"]["fname"]) submission_volume_key = config["data"]["submission"]["volume_key"] @@ -51,28 +54,50 @@ def run(config): maps_user_flat = submission[submission_volume_key].reshape( len(submission["volumes"]), -1 ) - maps_gt_flat = torch.load(config["data"]["ground_truth"]["volumes"]).reshape( - -1, n_pix**3 - ) + # maps_gt_flat = np.load(config["data"]["ground_truth"]["volumes"], mmap_mode='r+')#.reshape(-1, n_pix**3) + from torch.utils.data import Dataset + + class GT_Dataset(Dataset): + def __init__(self, npy_file): + self.npy_file = npy_file + self.data = np.load(npy_file, mmap_mode="r+") + + self.shape = self.data.shape + self._dim = len(self.data.shape) + + def dim(self): + return self._dim + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + sample = self.data[idx] + return torch.from_numpy(sample.copy()) + + maps_gt_flat = GT_Dataset(config["data"]["ground_truth"]["volumes"]) if config["data"]["mask"]["do"]: mask = ( mrcfile.open(config["data"]["mask"]["volume"]).data.astype(bool).flatten() ) - maps_gt_flat = maps_gt_flat[:, mask] + # maps_gt_flat = maps_gt_flat[:, mask] maps_user_flat = maps_user_flat[:, mask] else: - maps_gt_flat.reshape(len(maps_gt_flat), -1, inplace=True) + # maps_gt_flat.reshape(len(maps_gt_flat), -1, inplace=True) maps_user_flat.reshape(len(maps_gt_flat), -1, inplace=True) if config["analysis"]["normalize"]["do"]: if config["analysis"]["normalize"]["method"] == "median_zscore": - maps_gt_flat -= maps_gt_flat.median(dim=1, keepdim=True).values - maps_gt_flat /= maps_gt_flat.std(dim=1, keepdim=True) + # maps_gt_flat -= maps_gt_flat.median(dim=1, keepdim=True).values + # maps_gt_flat /= maps_gt_flat.std(dim=1, keepdim=True) maps_user_flat -= maps_user_flat.median(dim=1, keepdim=True).values maps_user_flat /= maps_user_flat.std(dim=1, keepdim=True) computed_assets = {} + results_dict["mask"] = mask for distance_label, map_to_map_distance in map_to_map_distances.items(): if distance_label in config["analysis"]["metrics"]: # TODO: can remove print("cost matrix", distance_label) diff --git a/src/cryo_challenge/data/_validation/output_validators.py b/src/cryo_challenge/data/_validation/output_validators.py index 9f76a6d..b233a30 100644 --- a/src/cryo_challenge/data/_validation/output_validators.py +++ b/src/cryo_challenge/data/_validation/output_validators.py @@ -27,6 +27,7 @@ class MapToMapResultsValidator: config: dict user_submitted_populations: torch.Tensor corr: Optional[dict] = None + corr_low_memory: Optional[dict] = None l2: Optional[dict] 
= None bioem: Optional[dict] = None fsc: Optional[dict] = None diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index 7dfa7e9..d1eca4d 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -7,18 +7,19 @@ data: metadata_key: populations label_key: id ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy metadata: tests/data/Ground_truth/test_metadata_10.csv mask: do: true volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc analysis: metrics: - - l2 - - corr - - bioem - - fsc - - res + # - l2 + # - corr + - corr_low_memory + # - bioem + # - fsc + # - res chunk_size_submission: 80 chunk_size_gt: 190 normalize: From 5e3591f4a9be4739b0f5ff113fbb1f54d7607945 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 18:03:28 -0400 Subject: [PATCH 02/26] tests passing for low memory mode t and f --- .../_map_to_map/map_to_map_distance.py | 22 +++++++++ .../_map_to_map/map_to_map_pipeline.py | 48 +++++++------------ .../config_files/test_config_map_to_map.yaml | 6 +-- 3 files changed, 43 insertions(+), 33 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index afc496f..649ca4b 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -4,6 +4,28 @@ from typing_extensions import override import mrcfile import numpy as np +from torch.utils.data import Dataset + + +class GT_Dataset(Dataset): + def __init__(self, npy_file): + self.npy_file = npy_file + self.data = np.load(npy_file, mmap_mode="r+") + + self.shape = self.data.shape + self._dim = len(self.data.shape) + + def dim(self): + return self._dim + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + sample = self.data[idx] + return torch.from_numpy(sample.copy()) class MapToMapDistance: diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index df63caf..8923833 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -1,11 +1,11 @@ import mrcfile import pandas as pd -import numpy as np import pickle import torch from ..data._validation.output_validators import MapToMapResultsValidator from .._map_to_map.map_to_map_distance import ( + GT_Dataset, FSCDistance, Correlation, CorrelationLowMemory, @@ -35,7 +35,9 @@ def run(config): for distance_label, distance_class in AVAILABLE_MAP2MAP_DISTANCES.items() } - # n_pix = config["data"]["n_pix"] + low_memory_mode = False + if not low_memory_mode: + n_pix = config["data"]["n_pix"] submission = torch.load(config["data"]["submission"]["fname"]) submission_volume_key = config["data"]["submission"]["volume_key"] @@ -54,45 +56,31 @@ def run(config): maps_user_flat = submission[submission_volume_key].reshape( len(submission["volumes"]), -1 ) - # maps_gt_flat = np.load(config["data"]["ground_truth"]["volumes"], mmap_mode='r+')#.reshape(-1, n_pix**3) - from torch.utils.data import Dataset - - class GT_Dataset(Dataset): - def __init__(self, npy_file): - self.npy_file = npy_file - self.data = np.load(npy_file, mmap_mode="r+") - - self.shape = self.data.shape - self._dim = len(self.data.shape) - - def dim(self): - return self._dim - 
- def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - if torch.is_tensor(idx): - idx = idx.tolist() - sample = self.data[idx] - return torch.from_numpy(sample.copy()) - - maps_gt_flat = GT_Dataset(config["data"]["ground_truth"]["volumes"]) + if low_memory_mode: + maps_gt_flat = GT_Dataset(config["data"]["ground_truth"]["volumes"]) + else: + maps_gt_flat = torch.load(config["data"]["ground_truth"]["volumes"]).reshape( + -1, n_pix**3 + ) if config["data"]["mask"]["do"]: mask = ( mrcfile.open(config["data"]["mask"]["volume"]).data.astype(bool).flatten() ) - # maps_gt_flat = maps_gt_flat[:, mask] + if not low_memory_mode: + maps_gt_flat = maps_gt_flat[:, mask] maps_user_flat = maps_user_flat[:, mask] else: - # maps_gt_flat.reshape(len(maps_gt_flat), -1, inplace=True) + if not low_memory_mode: + maps_gt_flat.reshape(len(maps_gt_flat), -1, inplace=True) maps_user_flat.reshape(len(maps_gt_flat), -1, inplace=True) if config["analysis"]["normalize"]["do"]: if config["analysis"]["normalize"]["method"] == "median_zscore": - # maps_gt_flat -= maps_gt_flat.median(dim=1, keepdim=True).values - # maps_gt_flat /= maps_gt_flat.std(dim=1, keepdim=True) + if not low_memory_mode: + maps_gt_flat -= maps_gt_flat.median(dim=1, keepdim=True).values + if not low_memory_mode: + maps_gt_flat /= maps_gt_flat.std(dim=1, keepdim=True) maps_user_flat -= maps_user_flat.median(dim=1, keepdim=True).values maps_user_flat /= maps_user_flat.std(dim=1, keepdim=True) diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index d1eca4d..73edae3 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -7,7 +7,7 @@ data: metadata_key: populations label_key: id ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt metadata: tests/data/Ground_truth/test_metadata_10.csv mask: do: true @@ -15,8 +15,8 @@ data: analysis: metrics: # - l2 - # - corr - - corr_low_memory + - corr + # - corr_low_memory # - bioem # - fsc # - res From 7c8e24755040a9aa57f427b03426dc168435a931 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 20:17:45 -0400 Subject: [PATCH 03/26] l2, corr, bioem working --- .../_map_to_map/map_to_map_distance.py | 245 +++++++++++++++++- .../_map_to_map/map_to_map_pipeline.py | 31 ++- .../data/_validation/output_validators.py | 17 ++ .../config_files/test_config_map_to_map.yaml | 9 +- .../test_config_map_to_map_low_memory.yaml | 27 ++ tests/test_map_to_map.py | 5 + 6 files changed, 313 insertions(+), 21 deletions(-) create mode 100644 tests/config_files/test_config_map_to_map_low_memory.yaml diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 649ca4b..46280fc 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -55,6 +55,39 @@ def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return {} +class MapToMapDistanceLowMemory(MapToMapDistance): + """General class for map-to-map distance metrics that require low memory.""" + + def __init__(self, config): + super().__init__(config) + + def compute_cost(self, map_1, map_2): + raise NotImplementedError() + + @override + def get_distance(self, map1, map2, global_store_of_running_results): + map1 = map1.flatten() + map1 -= map1.median() + map1 /= map1.std() + map1 = 
map1[global_store_of_running_results["mask"]] + + return self.compute_cost(map1, map2) + + @override + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): + maps_gt_flat = maps1 + maps_user_flat = maps2 + cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) + for idx_gt in range(len(maps_gt_flat)): + for idx_user in range(len(maps_user_flat)): + cost_matrix[idx_gt, idx_user] = self.get_distance( + maps_gt_flat[idx_gt], + maps_user_flat[idx_user], + global_store_of_running_results, + ) + return cost_matrix + + class L2DistanceNorm(MapToMapDistance): """L2 distance norm""" @@ -66,24 +99,19 @@ def get_distance(self, map1, map2): return torch.norm(map1 - map2) ** 2 -class L2DistanceSum(MapToMapDistance): - """L2 distance. - - Computed by summing the squared differences between the two maps.""" +class L2DistanceNormLowMemory(MapToMapDistanceLowMemory): + """L2 distance norm""" def __init__(self, config): super().__init__(config) - def compute_cost_l2(self, map_1, map_2): - return ((map_1 - map_2) ** 2).sum() - @override - def get_distance(self, map1, map2): - return self.compute_cost_l2(map1, map2) + def compute_cost(self, map1, map2): + return torch.norm(map1 - map2) ** 2 -class CorrelationLowMemory(MapToMapDistance): - """Correlation distance. +class CorrelationLowMemoryCheck(MapToMapDistance): + """Correlation. Not technically a distance metric, but a similarity.""" @@ -118,8 +146,12 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): return cost_matrix +def correlation(map1, map2): + return (map1 * map2).sum() + + class Correlation(MapToMapDistance): - """Correlation distance. + """Correlation. Not technically a distance metric, but a similarity.""" @@ -134,6 +166,61 @@ def get_distance(self, map1, map2): return self.compute_cost_corr(map1, map2) +class CorrelationLowMemory(MapToMapDistanceLowMemory): + """Correlation.""" + + def __init__(self, config): + super().__init__(config) + + @override + def compute_cost(self, map1, map2): + return correlation(map1, map2) + + +def compute_bioem3d_cost(map1, map2): + """ + Compute the cost between two maps using the BioEM cost function in 3D. + + Notes + ----- + See Eq. 
10 in 10.1016/j.jsb.2013.10.006 + + Parameters + ---------- + map1 : torch.Tensor + shape (n_pix,n_pix,n_pix) + map2 : torch.Tensor + shape (n_pix,n_pix,n_pix) + + Returns + ------- + cost : torch.Tensor + shape (1,) + """ + m1, m2 = map1.reshape(-1), map2.reshape(-1) + co = m1.sum() + cc = m2.sum() + coo = m1.pow(2).sum() + ccc = m2.pow(2).sum() + coc = (m1 * m2).sum() + + N = len(m1) + + t1 = 2 * torch.pi * math.exp(1) + t2 = N * (ccc * coo - coc * coc) + 2 * co * coc * cc - ccc * co * co - coo * cc * cc + t3 = (N - 2) * (N * ccc - cc * cc) + + smallest_float = torch.finfo(m1.dtype).tiny + log_prob = ( + 0.5 * torch.pi + + math.log(t1) * (1 - N / 2) + + t2.clamp(smallest_float).log() * (3 / 2 - N / 2) + + t3.clamp(smallest_float).log() * (N / 2 - 2) + ) + cost = -log_prob + return cost + + class BioEM3dDistance(MapToMapDistance): """BioEM 3D distance.""" @@ -193,6 +280,83 @@ def get_distance(self, map1, map2): return self.compute_bioem3d_cost(map1, map2) +class BioEM3dDistanceLowMemory(MapToMapDistanceLowMemory): + """BioEM 3D distance.""" + + def __init__(self, config): + super().__init__(config) + + @override + def compute_cost(self, map1, map2): + return compute_bioem3d_cost(map1, map2) + + +def fourier_shell_correlation( + x: torch.Tensor, + y: torch.Tensor, + dim: Sequence[int] = (-3, -2, -1), + normalize: bool = True, + max_k: Optional[int] = None, +): + """Computes Fourier Shell / Ring Correlation (FSC) between x and y. + + Parameters + ---------- + x : torch.Tensor + First input tensor. + y : torch.Tensor + Second input tensor. + dim : Tuple[int, ...] + Dimensions over which to take the Fourier transform. + normalize : bool + Whether to normalize (i.e. compute correlation) or not (i.e. compute covariance). + Note that when `normalize=False`, we still divide by the number of elements in each shell. + max_k : int + The maximum shell to compute the correlation for. + + Returns + ------- + torch.Tensor + The correlation between x and y for each Fourier shell. + """ # noqa: E501 + batch_shape = x.shape[: -len(dim)] + + freqs = [torch.fft.fftfreq(x.shape[d], d=1 / x.shape[d]).to(x) for d in dim] + freq_total = ( + torch.cartesian_prod(*freqs).view(*[len(f) for f in freqs], -1).norm(dim=-1) + ) + + x_f = torch.fft.fftn(x, dim=dim) + y_f = torch.fft.fftn(y, dim=dim) + + n = min(x.shape[d] for d in dim) + + if max_k is None: + max_k = n // 2 + + result = x.new_zeros(batch_shape + (max_k,)) + + for i in range(1, max_k + 1): + mask = (freq_total >= i - 0.5) & (freq_total < i + 0.5) + x_ri = x_f[..., mask] + y_fi = y_f[..., mask] + + if x.is_cuda: + c_i = torch.linalg.vecdot(x_ri, y_fi).real + else: + # vecdot currently bugged on CPU for torch 2.0 in some configurations + c_i = torch.sum(x_ri * y_fi.conj(), dim=-1).real + + if normalize: + c_i /= torch.linalg.norm(x_ri, dim=-1) * torch.linalg.norm(y_fi, dim=-1) + else: + c_i /= x_ri.shape[-1] + + result[..., i - 1] = c_i + + return result + + class FSCDistance(MapToMapDistance): """Fourier Shell Correlation distance. 
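Since fourier_shell_correlation now lives at module level, a quick usage sketch may help; the tensor sizes below are illustrative, and the assertions only document the expected output shape (one value per Fourier shell, with max_k defaulting to n_pix // 2):

    import torch
    from cryo_challenge._map_to_map.map_to_map_distance import fourier_shell_correlation

    n_pix = 16
    vol_a = torch.rand(n_pix, n_pix, n_pix)
    vol_b = vol_a + 0.1 * torch.randn(n_pix, n_pix, n_pix)

    # One correlation value per Fourier shell.
    fsc = fourier_shell_correlation(vol_a, vol_b)
    assert fsc.shape == (n_pix // 2,)

    # A batch in the leading dimension is also supported, mirroring how the
    # pipeline applies it: a stack of user maps against one ground-truth map.
    fsc_batch = fourier_shell_correlation(vol_a.expand(5, -1, -1, -1), vol_b)
    assert fsc_batch.shape == (5, n_pix // 2)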
@@ -317,6 +481,59 @@ def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return self.stored_computed_assets # must run get_distance_matrix first +class FSCDistanceLowMemory(MapToMapDistance): + """Fourier Shell Correlation distance.""" + + def __init__(self, config): + super().__init__(config) + self.npix = self.config["data"]["n_pix"] + + def compute_cost(self, map_1, map_2): + raise NotImplementedError() + + @override + def get_distance(self, map1, map2, global_store_of_running_results): + maps_gt_flat = map1 = map1.flatten() + map1 -= map1.median() + map1 /= map1.std() + maps_gt_flat_cube = torch.zeros(self.n_pix**3) + map1 = map1[global_store_of_running_results["mask"]] + maps_gt_flat_cube[:, global_store_of_running_results["mask"]] = maps_gt_flat + + corr_vector = fourier_shell_correlation( + maps_gt_flat_cube.reshape(self.n_pix, self.n_pix, self.n_pix), + map2.reshape(self.n_pix, self.n_pix, self.n_pix), + ) + dist = 1 - corr_vector.mean(dim=1) # TODO: spectral cutoff + self.stored_computed_assets["corr_vector"] = corr_vector + return dist + + @override + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): + maps_gt_flat = maps1 + maps_user_flat = maps2 + cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) + fsc_matrix = torch.zeros( + len(maps_gt_flat), len(maps_user_flat), self.n_pix // 2 + ) + for idx_gt in range(len(maps_gt_flat)): + for idx_user in range(len(maps_user_flat)): + cost_matrix[idx_gt, idx_user] = self.get_distance( + maps_gt_flat[idx_gt], + maps_user_flat[idx_user], + global_store_of_running_results, + ) + fsc_matrix[idx_gt, idx_user] = self.stored_computed_assets[ + "corr_vector" + ] + self.stored_computed_assets = {"fsc_matrix": fsc_matrix} + return cost_matrix + + @override + def get_computed_assets(self, maps1, maps2, global_store_of_running_results): + return self.stored_computed_assets # must run get_distance_matrix first + + class FSCResDistance(MapToMapDistance): """FSC Resolution distance. 
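The FSC resolution metric does not recompute anything expensive: it reads the stored FSC matrix, finds the first shell where the curve is no longer above 0.5, and converts that shell index to Angstroms via 2 * psize / (k / k_max). A small worked example of just that conversion, using the 16-pixel box and 30.044 A pixel size from the test config and a toy FSC curve:

    import numpy as np

    n_pix, psize = 16, 30.044           # box and pixel size from the test config
    fourier_pixel_max = n_pix // 2      # number of FSC shells (up to Nyquist)

    # Shell k (1-based) corresponds to a resolution of 2 * psize / (k / k_max).
    units_angstroms = 2 * psize / (np.arange(1, fourier_pixel_max + 1) / fourier_pixel_max)

    fsc = np.linspace(1.0, 0.0, fourier_pixel_max)   # toy monotonically decaying curve

    # Index of the first shell no longer above the 0.5 threshold.
    res_fsc_half = np.argmin(fsc > 0.5)
    fraction_nyquist = 0.5 * res_fsc_half / fsc.shape[-1]

    print(units_angstroms[res_fsc_half], fraction_nyquist)
    # res_fsc_half == 4 here (i.e. k = 5), giving 2 * 30.044 / (5 / 8) ~ 96 A and 0.25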
@@ -351,3 +568,7 @@ def res_at_fsc_threshold(fscs, threshold=0.5): res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fsc_matrix) self.stored_computed_assets = {"fraction_nyquist": fraction_nyquist} return units_Angstroms[res_fsc_half] + + +class FSCResDistanceLowMemory(MapToMapDistance): + pass diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index 8923833..7028eb9 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -9,8 +9,10 @@ FSCDistance, Correlation, CorrelationLowMemory, - L2DistanceSum, + L2DistanceNorm, + L2DistanceNormLowMemory, BioEM3dDistance, + BioEM3dDistanceLowMemory, FSCResDistance, ) @@ -18,12 +20,19 @@ AVAILABLE_MAP2MAP_DISTANCES = { "fsc": FSCDistance, "corr": Correlation, - "corr_low_memory": CorrelationLowMemory, - "l2": L2DistanceSum, + "l2": L2DistanceNorm, "bioem": BioEM3dDistance, "res": FSCResDistance, } +AVAILABLE_MAP2MAP_DISTANCES_LOW_MEMORY = { + "corr_low_memory": CorrelationLowMemory, + "l2_low_memory": L2DistanceNormLowMemory, + "bioem_low_memory": BioEM3dDistanceLowMemory, + "fsc_low_memory": FSCDistance, + "res_low_memory": FSCResDistance, +} + def run(config): """ @@ -33,9 +42,23 @@ def run(config): map_to_map_distances = { distance_label: distance_class(config) for distance_label, distance_class in AVAILABLE_MAP2MAP_DISTANCES.items() + if distance_label in config["analysis"]["metrics"] + } + + map_to_map_distances_low_memory = { + distance_label: distance_class(config) + for distance_label, distance_class in AVAILABLE_MAP2MAP_DISTANCES_LOW_MEMORY.items() + if distance_label in config["analysis"]["metrics"] } - low_memory_mode = False + assert len(map_to_map_distances_low_memory) == 0 or len(map_to_map_distances) == 0 + + if len(map_to_map_distances_low_memory) > 0: + map_to_map_distances = map_to_map_distances_low_memory + low_memory_mode = True + else: + low_memory_mode = False + if not low_memory_mode: n_pix = config["data"]["n_pix"] diff --git a/src/cryo_challenge/data/_validation/output_validators.py b/src/cryo_challenge/data/_validation/output_validators.py index b233a30..ada4492 100644 --- a/src/cryo_challenge/data/_validation/output_validators.py +++ b/src/cryo_challenge/data/_validation/output_validators.py @@ -19,9 +19,13 @@ class MapToMapResultsValidator: config: dict, input config dictionary. user_submitted_populations: torch.Tensor, user submitted populations, which sum to 1. corr: dict, correlation results. + corr_low_memory: dict, correlation results in low memory mode. l2: dict, L2 results. + l2_low_memory: dict, L2 results in low memory mode. bioem: dict, BioEM results. + bioem_low_memory: dict, BioEM results in low memory mode. fsc: dict, FSC results. + fsc_low_memory: dict, FSC results in low memory mode. """ config: dict @@ -29,9 +33,13 @@ class MapToMapResultsValidator: corr: Optional[dict] = None corr_low_memory: Optional[dict] = None l2: Optional[dict] = None + l2_low_memory: Optional[dict] = None bioem: Optional[dict] = None + bioem_low_memory: Optional[dict] = None fsc: Optional[dict] = None + fsc_low_memory: Optional[dict] = None res: Optional[dict] = None + res_low_memory: Optional[dict] = None def __post_init__(self): validate_input_config_mtm(self.config) @@ -142,6 +150,10 @@ class DistributionToDistributionResultsValidator: bioem: dict, BioEM distance results. l2: dict, L2 distance results. corr: dict, correlation distance results. 
+ fsc_low_memory: dict, FSC distance results in low memory mode. + bioem_low_memory: dict, BioEM distance results in low memory mode. + l2_low_memory: dict, L2 distance results in low memory mode. + corr_low_memory: dict, correlation distance results in low memory mode. """ config: dict @@ -152,6 +164,11 @@ class DistributionToDistributionResultsValidator: res: Optional[dict] = None l2: Optional[dict] = None corr: Optional[dict] = None + fsc_low_memory: Optional[dict] = None + bioem_low_memory: Optional[dict] = None + res_low_memory: Optional[dict] = None + l2_low_memory: Optional[dict] = None + corr_low_memory: Optional[dict] = None def __post_init__(self): validate_input_config_disttodist(self.config) diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index 73edae3..7dfa7e9 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -14,12 +14,11 @@ data: volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc analysis: metrics: - # - l2 + - l2 - corr - # - corr_low_memory - # - bioem - # - fsc - # - res + - bioem + - fsc + - res chunk_size_submission: 80 chunk_size_gt: 190 normalize: diff --git a/tests/config_files/test_config_map_to_map_low_memory.yaml b/tests/config_files/test_config_map_to_map_low_memory.yaml new file mode 100644 index 0000000..41f4a1e --- /dev/null +++ b/tests/config_files/test_config_map_to_map_low_memory.yaml @@ -0,0 +1,27 @@ +data: + n_pix: 16 + psize: 30.044 + submission: + fname: tests/data/dataset_2_submissions/submission_1000.pt + volume_key: volumes + metadata_key: populations + label_key: id + ground_truth: + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy + metadata: tests/data/Ground_truth/test_metadata_10.csv + mask: + do: true + volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc +analysis: + metrics: + - l2_low_memory + - corr_low_memory + - bioem_low_memory + # - fsc_low_memory + # - res_low_memory + chunk_size_submission: 80 + chunk_size_gt: 190 + normalize: + do: true + method: median_zscore +output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index e31f29f..301ee12 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -7,3 +7,8 @@ def test_run_map2map_pipeline(): {"config": "tests/config_files/test_config_map_to_map.yaml"} ) run_map2map_pipeline.main(args) + + args = OmegaConf.create( + {"config": "tests/config_files/test_config_map_to_map_low_memory.yaml"} + ) + run_map2map_pipeline.main(args) From 14805d7cd5d2953356927e3339bb8d2e4ebf1ced Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 20:24:46 -0400 Subject: [PATCH 04/26] fsc low memory working --- .../_map_to_map/map_to_map_distance.py | 14 +++++++------- .../_map_to_map/map_to_map_pipeline.py | 6 ++++-- .../test_config_map_to_map_low_memory.yaml | 2 +- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 46280fc..fe54c9a 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -486,26 +486,26 @@ class FSCDistanceLowMemory(MapToMapDistance): def __init__(self, config): super().__init__(config) - self.npix = self.config["data"]["n_pix"] + self.n_pix = self.config["data"]["n_pix"] def compute_cost(self, map_1, map_2): raise NotImplementedError() 
@override def get_distance(self, map1, map2, global_store_of_running_results): - maps_gt_flat = map1 = map1.flatten() + map_gt_flat = map1 = map1.flatten() map1 -= map1.median() map1 /= map1.std() - maps_gt_flat_cube = torch.zeros(self.n_pix**3) + map_gt_flat_cube = torch.zeros(self.n_pix**3) map1 = map1[global_store_of_running_results["mask"]] - maps_gt_flat_cube[:, global_store_of_running_results["mask"]] = maps_gt_flat + map_gt_flat_cube[global_store_of_running_results["mask"]] = map_gt_flat corr_vector = fourier_shell_correlation( - maps_gt_flat_cube.reshape(self.n_pix, self.n_pix, self.n_pix), + map_gt_flat_cube.reshape(self.n_pix, self.n_pix, self.n_pix), map2.reshape(self.n_pix, self.n_pix, self.n_pix), ) - dist = 1 - corr_vector.mean(dim=1) # TODO: spectral cutoff - self.stored_computed_assets["corr_vector"] = corr_vector + dist = 1 - corr_vector.mean() # TODO: spectral cutoff + self.stored_computed_assets = {"corr_vector": corr_vector} return dist @override diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index 7028eb9..60d3f34 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -7,6 +7,7 @@ from .._map_to_map.map_to_map_distance import ( GT_Dataset, FSCDistance, + FSCDistanceLowMemory, Correlation, CorrelationLowMemory, L2DistanceNorm, @@ -14,6 +15,7 @@ BioEM3dDistance, BioEM3dDistanceLowMemory, FSCResDistance, + FSCResDistanceLowMemory, ) @@ -29,8 +31,8 @@ "corr_low_memory": CorrelationLowMemory, "l2_low_memory": L2DistanceNormLowMemory, "bioem_low_memory": BioEM3dDistanceLowMemory, - "fsc_low_memory": FSCDistance, - "res_low_memory": FSCResDistance, + "fsc_low_memory": FSCDistanceLowMemory, + "res_low_memory": FSCResDistanceLowMemory, } diff --git a/tests/config_files/test_config_map_to_map_low_memory.yaml b/tests/config_files/test_config_map_to_map_low_memory.yaml index 41f4a1e..0d62008 100644 --- a/tests/config_files/test_config_map_to_map_low_memory.yaml +++ b/tests/config_files/test_config_map_to_map_low_memory.yaml @@ -17,7 +17,7 @@ analysis: - l2_low_memory - corr_low_memory - bioem_low_memory - # - fsc_low_memory + - fsc_low_memory # - res_low_memory chunk_size_submission: 80 chunk_size_gt: 190 From 6f5e09eb2cfe55293c6b5d29d7104b41b1d99da7 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 20:42:27 -0400 Subject: [PATCH 05/26] tests passing checking identical output --- .../_commands/run_map2map_pipeline.py | 4 +--- tests/test_map_to_map.py | 19 ++++++++++++++++--- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/cryo_challenge/_commands/run_map2map_pipeline.py b/src/cryo_challenge/_commands/run_map2map_pipeline.py index ab36f7a..90db1aa 100644 --- a/src/cryo_challenge/_commands/run_map2map_pipeline.py +++ b/src/cryo_challenge/_commands/run_map2map_pipeline.py @@ -39,9 +39,7 @@ def main(args): warnexists(config["output"]) mkbasedir(os.path.dirname(config["output"])) - run(config) - - return + return run(config) if __name__ == "__main__": diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index 301ee12..c1817d8 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -1,14 +1,27 @@ from omegaconf import OmegaConf from cryo_challenge._commands import run_map2map_pipeline +import numpy as np def test_run_map2map_pipeline(): args = OmegaConf.create( {"config": "tests/config_files/test_config_map_to_map.yaml"} ) - run_map2map_pipeline.main(args) + results_dict 
= run_map2map_pipeline.main(args) - args = OmegaConf.create( + args_low_memory = OmegaConf.create( {"config": "tests/config_files/test_config_map_to_map_low_memory.yaml"} ) - run_map2map_pipeline.main(args) + results_dict_low_memory = run_map2map_pipeline.main(args_low_memory) + for metric in ["fsc", "corr", "l2", "bioem"]: + if metric == "fsc": + np.allclose( + results_dict[metric]["computed_assets"]["fsc_matrix"], + results_dict_low_memory[metric + "_low_memory"]["computed_assets"][ + "fsc_matrix" + ], + ) + np.allclose( + results_dict[metric]["cost_matrix"].values, + results_dict_low_memory[metric + "_low_memory"]["cost_matrix"].values, + ) From a7290aae5eca25d4bcea2c322a76ed155d48271e Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 20:47:12 -0400 Subject: [PATCH 06/26] all metrics implemented and test passing for matched results --- .../_map_to_map/map_to_map_distance.py | 34 ++++++++++++++++++- .../test_config_map_to_map_low_memory.yaml | 2 +- tests/test_map_to_map.py | 7 ++++ 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index fe54c9a..b3bbf2c 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -571,4 +571,36 @@ def res_at_fsc_threshold(fscs, threshold=0.5): class FSCResDistanceLowMemory(MapToMapDistance): - pass + """FSC Resolution distance. + + The resolution at which the Fourier Shell Correlation reaches 0.5. + Built on top of the FSCDistance class. This needs to be run first and store the FSC matrix in the computed assets. + """ + + def __init__(self, config): + super().__init__(config) + + @override + def get_distance_matrix( + self, maps1, maps2, global_store_of_running_results + ): # custom method + # get fsc matrix + fourier_pixel_max = ( + self.config["data"]["n_pix"] // 2 + ) # TODO: check for odd psizes if this should be +1 + psize = self.config["data"]["psize"] + fsc_matrix = global_store_of_running_results["fsc_low_memory"][ + "computed_assets" + ]["fsc_matrix"] + units_Angstroms = ( + 2 * psize / (np.arange(1, fourier_pixel_max + 1) / fourier_pixel_max) + ) + + def res_at_fsc_threshold(fscs, threshold=0.5): + res_fsc_half = np.argmin(fscs > threshold, axis=-1) + fraction_nyquist = 0.5 * res_fsc_half / fscs.shape[-1] + return res_fsc_half, fraction_nyquist + + res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fsc_matrix) + self.stored_computed_assets = {"fraction_nyquist": fraction_nyquist} + return units_Angstroms[res_fsc_half] diff --git a/tests/config_files/test_config_map_to_map_low_memory.yaml b/tests/config_files/test_config_map_to_map_low_memory.yaml index 0d62008..4eb6cd0 100644 --- a/tests/config_files/test_config_map_to_map_low_memory.yaml +++ b/tests/config_files/test_config_map_to_map_low_memory.yaml @@ -18,7 +18,7 @@ analysis: - corr_low_memory - bioem_low_memory - fsc_low_memory - # - res_low_memory + - res_low_memory chunk_size_submission: 80 chunk_size_gt: 190 normalize: diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index c1817d8..ed91815 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -21,6 +21,13 @@ def test_run_map2map_pipeline(): "fsc_matrix" ], ) + elif metric == "res": + np.allclose( + results_dict[metric]["computed_assets"]["fraction_nyquist"], + results_dict_low_memory[metric + "_low_memory"]["computed_assets"][ + "fraction_nyquist" + ], + ) np.allclose( 
results_dict[metric]["cost_matrix"].values, results_dict_low_memory[metric + "_low_memory"]["cost_matrix"].values, From 22ec5a5085cb8a8eaac1d54549ff6296dfa1bde1 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 20:50:09 -0400 Subject: [PATCH 07/26] all metrics implemented and test passing for matched results --- .../_map_to_map/map_to_map_distance.py | 31 +++---------------- 1 file changed, 4 insertions(+), 27 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index b3bbf2c..369c47f 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -543,6 +543,7 @@ class FSCResDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) + self.fsc_label = "fsc" @override def get_distance_matrix( @@ -553,7 +554,7 @@ def get_distance_matrix( self.config["data"]["n_pix"] // 2 ) # TODO: check for odd psizes if this should be +1 psize = self.config["data"]["psize"] - fsc_matrix = global_store_of_running_results["fsc"]["computed_assets"][ + fsc_matrix = global_store_of_running_results[self.fsc_label]["computed_assets"][ "fsc_matrix" ] units_Angstroms = ( @@ -570,7 +571,7 @@ def res_at_fsc_threshold(fscs, threshold=0.5): return units_Angstroms[res_fsc_half] -class FSCResDistanceLowMemory(MapToMapDistance): +class FSCResDistanceLowMemory(FSCResDistance): """FSC Resolution distance. The resolution at which the Fourier Shell Correlation reaches 0.5. @@ -579,28 +580,4 @@ class FSCResDistanceLowMemory(MapToMapDistance): def __init__(self, config): super().__init__(config) - - @override - def get_distance_matrix( - self, maps1, maps2, global_store_of_running_results - ): # custom method - # get fsc matrix - fourier_pixel_max = ( - self.config["data"]["n_pix"] // 2 - ) # TODO: check for odd psizes if this should be +1 - psize = self.config["data"]["psize"] - fsc_matrix = global_store_of_running_results["fsc_low_memory"][ - "computed_assets" - ]["fsc_matrix"] - units_Angstroms = ( - 2 * psize / (np.arange(1, fourier_pixel_max + 1) / fourier_pixel_max) - ) - - def res_at_fsc_threshold(fscs, threshold=0.5): - res_fsc_half = np.argmin(fscs > threshold, axis=-1) - fraction_nyquist = 0.5 * res_fsc_half / fscs.shape[-1] - return res_fsc_half, fraction_nyquist - - res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fsc_matrix) - self.stored_computed_assets = {"fraction_nyquist": fraction_nyquist} - return units_Angstroms[res_fsc_half] + self.fsc_label = "fsc_low_memory" From de3cddc1e0839b507acd60705f02a1b7f6fcc085 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 20:53:37 -0400 Subject: [PATCH 08/26] all metrics implemented and test passing for matched results --- .../_map_to_map/map_to_map_distance.py | 159 +----------------- 1 file changed, 3 insertions(+), 156 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 369c47f..70df449 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -110,42 +110,6 @@ def compute_cost(self, map1, map2): return torch.norm(map1 - map2) ** 2 -class CorrelationLowMemoryCheck(MapToMapDistance): - """Correlation. 
- - Not technically a distance metric, but a similarity.""" - - def __init__(self, config): - super().__init__(config) - - def compute_cost_corr(self, map_1, map_2): - return (map_1 * map_2).sum() - - @override - def get_distance(self, map1, map2, global_store_of_running_results): - map1 = map1.flatten() - map1 -= map1.median() - map1 /= map1.std() - map1 = map1[global_store_of_running_results["mask"]] - - return self.compute_cost_corr(map1, map2) - - @override - def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): - maps_gt_flat = maps1 - maps_user_flat = maps2 - cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) - for idx_gt in range(len(maps_gt_flat)): - for idx_user in range(len(maps_user_flat)): - cost_matrix[idx_gt, idx_user] = self.get_distance( - maps_gt_flat[idx_gt], - maps_user_flat[idx_user], - global_store_of_running_results, - ) - - return cost_matrix - - def correlation(map1, map2): return (map1 * map2).sum() @@ -158,12 +122,9 @@ class Correlation(MapToMapDistance): def __init__(self, config): super().__init__(config) - def compute_cost_corr(self, map_1, map_2): - return (map_1 * map_2).sum() - @override def get_distance(self, map1, map2): - return self.compute_cost_corr(map1, map2) + return correlation(map1, map2) class CorrelationLowMemory(MapToMapDistanceLowMemory): @@ -227,57 +188,9 @@ class BioEM3dDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) - def compute_bioem3d_cost(self, map1, map2): - """ - Compute the cost between two maps using the BioEM cost function in 3D. - - Notes - ----- - See Eq. 10 in 10.1016/j.jsb.2013.10.006 - - Parameters - ---------- - map1 : torch.Tensor - shape (n_pix,n_pix,n_pix) - map2 : torch.Tensor - shape (n_pix,n_pix,n_pix) - - Returns - ------- - cost : torch.Tensor - shape (1,) - """ - m1, m2 = map1.reshape(-1), map2.reshape(-1) - co = m1.sum() - cc = m2.sum() - coo = m1.pow(2).sum() - ccc = m2.pow(2).sum() - coc = (m1 * m2).sum() - - N = len(m1) - - t1 = 2 * torch.pi * math.exp(1) - t2 = ( - N * (ccc * coo - coc * coc) - + 2 * co * coc * cc - - ccc * co * co - - coo * cc * cc - ) - t3 = (N - 2) * (N * ccc - cc * cc) - - smallest_float = torch.finfo(m1.dtype).tiny - log_prob = ( - 0.5 * torch.pi - + math.log(t1) * (1 - N / 2) - + t2.clamp(smallest_float).log() * (3 / 2 - N / 2) - + t3.clamp(smallest_float).log() * (N / 2 - 2) - ) - cost = -log_prob - return cost - @override def get_distance(self, map1, map2): - return self.compute_bioem3d_cost(map1, map2) + return compute_bioem3d_cost(map1, map2) class BioEM3dDistanceLowMemory(MapToMapDistanceLowMemory): @@ -365,72 +278,6 @@ class FSCDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) - def fourier_shell_correlation( - self, - x: torch.Tensor, - y: torch.Tensor, - dim: Sequence[int] = (-3, -2, -1), - normalize: bool = True, - max_k: Optional[int] = None, - ): - """Computes Fourier Shell / Ring Correlation (FSC) between x and y. - - Parameters - ---------- - x : torch.Tensor - First input tensor. - y : torch.Tensor - Second input tensor. - dim : Tuple[int, ...] - Dimensions over which to take the Fourier transform. - normalize : bool - Whether to normalize (i.e. compute correlation) or not (i.e. compute covariance). - Note that when `normalize=False`, we still divide by the number of elements in each shell. - max_k : int - The maximum shell to compute the correlation for. - - Returns - ------- - torch.Tensor - The correlation between x and y for each Fourier shell. 
- """ # noqa: E501 - batch_shape = x.shape[: -len(dim)] - - freqs = [torch.fft.fftfreq(x.shape[d], d=1 / x.shape[d]).to(x) for d in dim] - freq_total = ( - torch.cartesian_prod(*freqs).view(*[len(f) for f in freqs], -1).norm(dim=-1) - ) - - x_f = torch.fft.fftn(x, dim=dim) - y_f = torch.fft.fftn(y, dim=dim) - - n = min(x.shape[d] for d in dim) - - if max_k is None: - max_k = n // 2 - - result = x.new_zeros(batch_shape + (max_k,)) - - for i in range(1, max_k + 1): - mask = (freq_total >= i - 0.5) & (freq_total < i + 0.5) - x_ri = x_f[..., mask] - y_fi = y_f[..., mask] - - if x.is_cuda: - c_i = torch.linalg.vecdot(x_ri, y_fi).real - else: - # vecdot currently bugged on CPU for torch 2.0 in some configurations - c_i = torch.sum(x_ri * y_fi.conj(), dim=-1).real - - if normalize: - c_i /= torch.linalg.norm(x_ri, dim=-1) * torch.linalg.norm(y_fi, dim=-1) - else: - c_i /= x_ri.shape[-1] - - result[..., i - 1] = c_i - - return result - def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): """ Compute the cost between two maps using the Fourier Shell Correlation in 3D. @@ -443,7 +290,7 @@ def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) fsc_matrix = torch.zeros(len(maps_gt_flat), len(maps_user_flat), n_pix // 2) for idx in range(len(maps_gt_flat)): - corr_vector = self.fourier_shell_correlation( + corr_vector = fourier_shell_correlation( maps_user_flat.reshape(-1, n_pix, n_pix, n_pix), maps_gt_flat[idx].reshape(n_pix, n_pix, n_pix), ) From f20344e8b6ac36384f34baa8fa96f8fc69d9adad Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 21:10:08 -0400 Subject: [PATCH 09/26] flags for masking and not masking --- .../_map_to_map/map_to_map_distance.py | 40 +++++++++++++------ .../_map_to_map/map_to_map_pipeline.py | 4 ++ 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 70df449..4769528 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -60,6 +60,7 @@ class MapToMapDistanceLowMemory(MapToMapDistance): def __init__(self, config): super().__init__(config) + self.config = config def compute_cost(self, map_1, map_2): raise NotImplementedError() @@ -67,8 +68,14 @@ def compute_cost(self, map_1, map_2): @override def get_distance(self, map1, map2, global_store_of_running_results): map1 = map1.flatten() - map1 -= map1.median() - map1 /= map1.std() + if self.config["analysis"]["normalize"]["do"]: + if self.config["analysis"]["normalize"]["method"] == "median_zscore": + map1 -= map1.median() + map1 /= map1.std() + else: + raise NotImplementedError( + f"Normalization method {self.config['analysis']['normalize']['method']} not implemented." 
+ ) map1 = map1[global_store_of_running_results["mask"]] return self.compute_cost(map1, map2) @@ -308,14 +315,19 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): maps_user_flat = maps2 n_pix = self.config["data"]["n_pix"] maps_gt_flat_cube = torch.zeros(len(maps_gt_flat), n_pix**3) - mask = ( - mrcfile.open(self.config["data"]["mask"]["volume"]) - .data.astype(bool) - .flatten() - ) - maps_gt_flat_cube[:, mask] = maps_gt_flat maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) - maps_user_flat_cube[:, mask] = maps_user_flat + + if self.config["data"]["mask"]["do"]: + mask = ( + mrcfile.open(self.config["data"]["mask"]["volume"]) + .data.astype(bool) + .flatten() + ) + maps_gt_flat_cube[:, mask] = maps_gt_flat + maps_user_flat_cube[:, mask] = maps_user_flat + else: + maps_gt_flat_cube = maps_gt_flat + maps_user_flat_cube = maps_user_flat cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk( maps_gt_flat_cube, maps_user_flat_cube, n_pix @@ -334,6 +346,7 @@ class FSCDistanceLowMemory(MapToMapDistance): def __init__(self, config): super().__init__(config) self.n_pix = self.config["data"]["n_pix"] + self.config = config def compute_cost(self, map_1, map_2): raise NotImplementedError() @@ -341,11 +354,12 @@ def compute_cost(self, map_1, map_2): @override def get_distance(self, map1, map2, global_store_of_running_results): map_gt_flat = map1 = map1.flatten() - map1 -= map1.median() - map1 /= map1.std() map_gt_flat_cube = torch.zeros(self.n_pix**3) - map1 = map1[global_store_of_running_results["mask"]] - map_gt_flat_cube[global_store_of_running_results["mask"]] = map_gt_flat + if self.config["data"]["mask"]["do"]: + map_gt_flat = map_gt_flat[global_store_of_running_results["mask"]] + map_gt_flat_cube[global_store_of_running_results["mask"]] = map_gt_flat + else: + map_gt_flat_cube = map_gt_flat corr_vector = fourier_shell_correlation( map_gt_flat_cube.reshape(self.n_pix, self.n_pix, self.n_pix), diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index 60d3f34..d7db04b 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -108,6 +108,10 @@ def run(config): maps_gt_flat /= maps_gt_flat.std(dim=1, keepdim=True) maps_user_flat -= maps_user_flat.median(dim=1, keepdim=True).values maps_user_flat /= maps_user_flat.std(dim=1, keepdim=True) + else: + raise NotImplementedError( + f"Normalization method {config['analysis']['normalize']['method']} not implemented." 
+ ) computed_assets = {} results_dict["mask"] = mask From 6def1bf6e6a20a4c3f50652002f23eb5f53fc1c2 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 10 Sep 2024 21:15:21 -0400 Subject: [PATCH 10/26] tests for masking and normalization --- ..._to_map_low_memory_nomask_nonormalize.yaml | 27 +++++++++ ..._config_map_to_map_nomask_nonormalize.yaml | 27 +++++++++ tests/test_map_to_map.py | 56 ++++++++++--------- 3 files changed, 85 insertions(+), 25 deletions(-) create mode 100644 tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml create mode 100644 tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml diff --git a/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml new file mode 100644 index 0000000..13a1cf4 --- /dev/null +++ b/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml @@ -0,0 +1,27 @@ +data: + n_pix: 16 + psize: 30.044 + submission: + fname: tests/data/dataset_2_submissions/submission_1000.pt + volume_key: volumes + metadata_key: populations + label_key: id + ground_truth: + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy + metadata: tests/data/Ground_truth/test_metadata_10.csv + mask: + do: false + volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc +analysis: + metrics: + - l2_low_memory + - corr_low_memory + - bioem_low_memory + - fsc_low_memory + - res_low_memory + chunk_size_submission: 80 + chunk_size_gt: 190 + normalize: + do: false + method: median_zscore +output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml new file mode 100644 index 0000000..7856db4 --- /dev/null +++ b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml @@ -0,0 +1,27 @@ +data: + n_pix: 16 + psize: 30.044 + submission: + fname: tests/data/dataset_2_submissions/submission_1000.pt + volume_key: volumes + metadata_key: populations + label_key: id + ground_truth: + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt + metadata: tests/data/Ground_truth/test_metadata_10.csv + mask: + do: false + volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc +analysis: + metrics: + - l2 + - corr + - bioem + - fsc + - res + chunk_size_submission: 80 + chunk_size_gt: 190 + normalize: + do: false + method: median_zscore +output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index ed91815..957a9b1 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -4,31 +4,37 @@ def test_run_map2map_pipeline(): - args = OmegaConf.create( - {"config": "tests/config_files/test_config_map_to_map.yaml"} - ) - results_dict = run_map2map_pipeline.main(args) + for config_fname, config_fname_low_memory in zip( + [ + "tests/config_files/test_config_map_to_map.yaml", + "tests/config_files/test_config_map_to_map.yaml", + ], + [ + "tests/config_files/test_config_map_to_map_low_memory.yaml", + "tests/config_files/test_config_map_to_map_low_memory.yaml", + ], + ): + args = OmegaConf.create({"config": config_fname}) + results_dict = run_map2map_pipeline.main(args) - args_low_memory = OmegaConf.create( - {"config": "tests/config_files/test_config_map_to_map_low_memory.yaml"} - ) - results_dict_low_memory = run_map2map_pipeline.main(args_low_memory) - for metric in ["fsc", "corr", 
"l2", "bioem"]: - if metric == "fsc": + args_low_memory = OmegaConf.create({"config": config_fname_low_memory}) + results_dict_low_memory = run_map2map_pipeline.main(args_low_memory) + for metric in ["fsc", "corr", "l2", "bioem"]: + if metric == "fsc": + np.allclose( + results_dict[metric]["computed_assets"]["fsc_matrix"], + results_dict_low_memory[metric + "_low_memory"]["computed_assets"][ + "fsc_matrix" + ], + ) + elif metric == "res": + np.allclose( + results_dict[metric]["computed_assets"]["fraction_nyquist"], + results_dict_low_memory[metric + "_low_memory"]["computed_assets"][ + "fraction_nyquist" + ], + ) np.allclose( - results_dict[metric]["computed_assets"]["fsc_matrix"], - results_dict_low_memory[metric + "_low_memory"]["computed_assets"][ - "fsc_matrix" - ], + results_dict[metric]["cost_matrix"].values, + results_dict_low_memory[metric + "_low_memory"]["cost_matrix"].values, ) - elif metric == "res": - np.allclose( - results_dict[metric]["computed_assets"]["fraction_nyquist"], - results_dict_low_memory[metric + "_low_memory"]["computed_assets"][ - "fraction_nyquist" - ], - ) - np.allclose( - results_dict[metric]["cost_matrix"].values, - results_dict_low_memory[metric + "_low_memory"]["cost_matrix"].values, - ) From c0f2f831facafc2fd3bddc5b112dfa0c97343254 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 14:16:11 -0400 Subject: [PATCH 11/26] code duplication for norm --- src/cryo_challenge/_map_to_map/map_to_map_distance.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 4769528..b413ade 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -95,6 +95,10 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): return cost_matrix +def norm2(map1, map2): + return torch.norm(map1 - map2) ** 2 + + class L2DistanceNorm(MapToMapDistance): """L2 distance norm""" @@ -103,7 +107,7 @@ def __init__(self, config): @override def get_distance(self, map1, map2): - return torch.norm(map1 - map2) ** 2 + return norm2(map1, map2) class L2DistanceNormLowMemory(MapToMapDistanceLowMemory): @@ -114,7 +118,7 @@ def __init__(self, config): @override def compute_cost(self, map1, map2): - return torch.norm(map1 - map2) ** 2 + return norm2(map1, map2) def correlation(map1, map2): From f1f8cd8ad1dfc1fdae131f2f9ed3250125320423 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 14:30:59 -0400 Subject: [PATCH 12/26] tests passing. 
vmap over sub batch --- .../_map_to_map/map_to_map_distance.py | 30 ++++++++++++++----- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index b413ade..317964f 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -40,14 +40,30 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): """Compute the distance matrix between two sets of maps.""" chunk_size_submission = self.config["analysis"]["chunk_size_submission"] chunk_size_gt = self.config["analysis"]["chunk_size_gt"] - distance_matrix = torch.vmap( - lambda maps1: torch.vmap( - lambda maps2: self.get_distance(maps1, maps2), - chunk_size=chunk_size_submission, - )(maps2), - chunk_size=chunk_size_gt, - )(maps1) + # load in memory as torch tensors + if True: # config.low_memory: + distance_matrix = torch.empty(len(maps1), len(maps2)) + n_chunks_low_memory = 100 + for idxs in torch.arange(len(maps1)).chunk(n_chunks_low_memory): + maps1_in_memory = maps1[idxs] + sub_distance_matrix = torch.vmap( + lambda maps1: torch.vmap( + lambda maps2: self.get_distance(maps1, maps2), + chunk_size=chunk_size_submission, + )(maps2), + chunk_size=chunk_size_gt, + )(maps1_in_memory) + distance_matrix[idxs] = sub_distance_matrix + else: + assert False, "Not implemented" + distance_matrix = torch.vmap( + lambda maps1: torch.vmap( + lambda maps2: self.get_distance(maps1, maps2), + chunk_size=chunk_size_submission, + )(maps2), + chunk_size=chunk_size_gt, + )(maps1) return distance_matrix def get_computed_assets(self, maps1, maps2, global_store_of_running_results): From 61608cb3749fb874c725c856152af1c39d3b9c40 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 14:52:32 -0400 Subject: [PATCH 13/26] tests passing. 
get_sub_distance_matrix --- .../_map_to_map/map_to_map_distance.py | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 317964f..ce650f1 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -31,28 +31,35 @@ def __getitem__(self, idx): class MapToMapDistance: def __init__(self, config): self.config = config + self.chunk_size_gt = self.config["analysis"]["chunk_size_gt"] + self.chunk_size_submission = self.config["analysis"]["chunk_size_submission"] def get_distance(self, map1, map2): """Compute the distance between two maps.""" raise NotImplementedError() + def get_sub_distance_matrix(self, maps1, maps2): + """Compute the distance matrix between two sets of maps.""" + sub_distance_matrix = torch.vmap( + lambda maps1: torch.vmap( + lambda maps2: self.get_distance(maps1, maps2), + chunk_size=self.chunk_size_submission, + )(maps2), + chunk_size=self.chunk_size_gt, + )(maps1) + return sub_distance_matrix + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): """Compute the distance matrix between two sets of maps.""" - chunk_size_submission = self.config["analysis"]["chunk_size_submission"] - chunk_size_gt = self.config["analysis"]["chunk_size_gt"] # load in memory as torch tensors if True: # config.low_memory: distance_matrix = torch.empty(len(maps1), len(maps2)) n_chunks_low_memory = 100 for idxs in torch.arange(len(maps1)).chunk(n_chunks_low_memory): maps1_in_memory = maps1[idxs] - sub_distance_matrix = torch.vmap( - lambda maps1: torch.vmap( - lambda maps2: self.get_distance(maps1, maps2), - chunk_size=chunk_size_submission, - )(maps2), - chunk_size=chunk_size_gt, - )(maps1_in_memory) + sub_distance_matrix = self.get_sub_distance_matrix( + maps1_in_memory, maps2 + ) distance_matrix[idxs] = sub_distance_matrix else: @@ -60,9 +67,9 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): distance_matrix = torch.vmap( lambda maps1: torch.vmap( lambda maps2: self.get_distance(maps1, maps2), - chunk_size=chunk_size_submission, + chunk_size=self.chunk_size_submission, )(maps2), - chunk_size=chunk_size_gt, + chunk_size=self.chunk_size_gt, )(maps1) return distance_matrix From 1440d8419078526460913d911cd538b91daefcb8 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 15:04:55 -0400 Subject: [PATCH 14/26] tests passing with hard coded instantiation --- .../_map_to_map/map_to_map_distance.py | 40 +++++++++++++++++-- .../_map_to_map/map_to_map_pipeline.py | 2 + 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index ce650f1..c24c300 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -38,7 +38,7 @@ def get_distance(self, map1, map2): """Compute the distance between two maps.""" raise NotImplementedError() - def get_sub_distance_matrix(self, maps1, maps2): + def get_sub_distance_matrix(self, maps1, maps2, idxs): """Compute the distance matrix between two sets of maps.""" sub_distance_matrix = torch.vmap( lambda maps1: torch.vmap( @@ -58,7 +58,9 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): for idxs in torch.arange(len(maps1)).chunk(n_chunks_low_memory): maps1_in_memory = 
maps1[idxs] sub_distance_matrix = self.get_sub_distance_matrix( - maps1_in_memory, maps2 + maps1_in_memory, + maps2, + idxs, ) distance_matrix[idxs] = sub_distance_matrix @@ -311,6 +313,7 @@ class FSCDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) + self.stored_computed_assets = {"fsc_matrix": torch.empty(10, 8, 8)} def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): """ @@ -334,7 +337,7 @@ def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): return cost_matrix, fsc_matrix @override - def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): + def get_sub_distance_matrix(self, maps1, maps2, idxs): """ Applies a mask to the maps and computes the cost matrix using the Fourier Shell Correlation. """ @@ -359,9 +362,38 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk( maps_gt_flat_cube, maps_user_flat_cube, n_pix ) - self.stored_computed_assets = {"fsc_matrix": fsc_matrix} + self.stored_computed_assets["fsc_matrix"][idxs] = fsc_matrix return cost_matrix + # @override + # def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): + # """ + # Applies a mask to the maps and computes the cost matrix using the Fourier Shell Correlation. + # """ + # maps_gt_flat = maps1 + # maps_user_flat = maps2 + # n_pix = self.config["data"]["n_pix"] + # maps_gt_flat_cube = torch.zeros(len(maps_gt_flat), n_pix**3) + # maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) + + # if self.config["data"]["mask"]["do"]: + # mask = ( + # mrcfile.open(self.config["data"]["mask"]["volume"]) + # .data.astype(bool) + # .flatten() + # ) + # maps_gt_flat_cube[:, mask] = maps_gt_flat + # maps_user_flat_cube[:, mask] = maps_user_flat + # else: + # maps_gt_flat_cube = maps_gt_flat + # maps_user_flat_cube = maps_user_flat + + # cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk( + # maps_gt_flat_cube, maps_user_flat_cube, n_pix + # ) + # self.stored_computed_assets = {"fsc_matrix": fsc_matrix} + # return cost_matrix + @override def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return self.stored_computed_assets # must run get_distance_matrix first diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index d7db04b..7016c3f 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -118,6 +118,8 @@ def run(config): for distance_label, map_to_map_distance in map_to_map_distances.items(): if distance_label in config["analysis"]["metrics"]: # TODO: can remove print("cost matrix", distance_label) + print("maps_gt_flat", maps_gt_flat.shape) + print("maps_user_flat", maps_user_flat.shape) cost_matrix = map_to_map_distance.get_distance_matrix( maps_gt_flat, From 555a39fc8a38200da3b03f0a84d5babf96390bd7 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 17:31:02 -0400 Subject: [PATCH 15/26] map_to_map_distance.distance_matrix_precomputation --- .../_map_to_map/map_to_map_distance.py | 15 ++++++++++++++- .../_map_to_map/map_to_map_pipeline.py | 3 +++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index c24c300..14f933a 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ 
b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -33,6 +33,7 @@ def __init__(self, config): self.config = config self.chunk_size_gt = self.config["analysis"]["chunk_size_gt"] self.chunk_size_submission = self.config["analysis"]["chunk_size_submission"] + self.n_pix = self.config["data"]["n_pix"] def get_distance(self, map1, map2): """Compute the distance between two maps.""" @@ -49,6 +50,10 @@ def get_sub_distance_matrix(self, maps1, maps2, idxs): )(maps1) return sub_distance_matrix + def distance_matrix_precomputation(maps1, maps2, global_store_of_running_results): + """Pre-compute any assets needed for the distance matrix computation.""" + return + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): """Compute the distance matrix between two sets of maps.""" # load in memory as torch tensors @@ -313,7 +318,6 @@ class FSCDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) - self.stored_computed_assets = {"fsc_matrix": torch.empty(10, 8, 8)} def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): """ @@ -336,6 +340,15 @@ def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): cost_matrix[idx] = dist return cost_matrix, fsc_matrix + @override + def distance_matrix_precomputation(self, maps1, maps2): + self.len_maps1 = len(maps1) + self.len_maps2 = len(maps2) + self.stored_computed_assets = { + "fsc_matrix": torch.empty(self.len_maps1, self.len_maps2, self.n_pix // 2) + } + return + @override def get_sub_distance_matrix(self, maps1, maps2, idxs): """ diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index 7016c3f..434280e 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -121,6 +121,9 @@ def run(config): print("maps_gt_flat", maps_gt_flat.shape) print("maps_user_flat", maps_user_flat.shape) + map_to_map_distance.distance_matrix_precomputation( + maps_gt_flat, maps_user_flat + ) cost_matrix = map_to_map_distance.get_distance_matrix( maps_gt_flat, maps_user_flat, From b71e52b47955f8bc2657e30de287dfe363cacb0f Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 17:41:26 -0400 Subject: [PATCH 16/26] chunk_size_low_memory --- src/cryo_challenge/_map_to_map/map_to_map_distance.py | 5 +++-- tests/config_files/test_config_map_to_map.yaml | 1 + tests/config_files/test_config_map_to_map_low_memory.yaml | 1 + ...test_config_map_to_map_low_memory_nomask_nonormalize.yaml | 1 + .../test_config_map_to_map_nomask_nonormalize.yaml | 1 + 5 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 14f933a..e3cbf31 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -34,6 +34,7 @@ def __init__(self, config): self.chunk_size_gt = self.config["analysis"]["chunk_size_gt"] self.chunk_size_submission = self.config["analysis"]["chunk_size_submission"] self.n_pix = self.config["data"]["n_pix"] + self.chunk_size_low_memory = self.config["analysis"]["chunk_size_low_memory"] def get_distance(self, map1, map2): """Compute the distance between two maps.""" @@ -58,9 +59,9 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): """Compute the distance matrix between two sets of maps.""" # load in memory as torch tensors if True: # 
config.low_memory: + self.n_chunks_low_memory = len(maps1) // self.chunk_size_low_memory distance_matrix = torch.empty(len(maps1), len(maps2)) - n_chunks_low_memory = 100 - for idxs in torch.arange(len(maps1)).chunk(n_chunks_low_memory): + for idxs in torch.arange(len(maps1)).chunk(self.n_chunks_low_memory): maps1_in_memory = maps1[idxs] sub_distance_matrix = self.get_sub_distance_matrix( maps1_in_memory, diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index 7dfa7e9..62651f8 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -21,6 +21,7 @@ analysis: - res chunk_size_submission: 80 chunk_size_gt: 190 + chunk_size_low_memory: 10 normalize: do: true method: median_zscore diff --git a/tests/config_files/test_config_map_to_map_low_memory.yaml b/tests/config_files/test_config_map_to_map_low_memory.yaml index 4eb6cd0..e02e271 100644 --- a/tests/config_files/test_config_map_to_map_low_memory.yaml +++ b/tests/config_files/test_config_map_to_map_low_memory.yaml @@ -21,6 +21,7 @@ analysis: - res_low_memory chunk_size_submission: 80 chunk_size_gt: 190 + chunk_size_low_memory: 10 normalize: do: true method: median_zscore diff --git a/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml index 13a1cf4..49dac78 100644 --- a/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml +++ b/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml @@ -21,6 +21,7 @@ analysis: - res_low_memory chunk_size_submission: 80 chunk_size_gt: 190 + chunk_size_low_memory: 10 normalize: do: false method: median_zscore diff --git a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml index 7856db4..e97abaf 100644 --- a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml +++ b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml @@ -21,6 +21,7 @@ analysis: - res chunk_size_submission: 80 chunk_size_gt: 190 + chunk_size_low_memory: 10 normalize: do: false method: median_zscore From 76fdb462e63e379d630de083ba78740e84a8e503 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 18:54:32 -0400 Subject: [PATCH 17/26] on the fly sub batch normalization --- .../_map_to_map/map_to_map_distance.py | 18 ++++++++++++++++++ .../_map_to_map/map_to_map_pipeline.py | 13 ------------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index e3cbf31..df81ef6 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -28,6 +28,15 @@ def __getitem__(self, idx): return torch.from_numpy(sample.copy()) +def normalize(maps, method): + if method == "median_zscore": + maps -= maps.median(dim=1, keepdim=True).values + maps /= maps.std(dim=1, keepdim=True) + else: + raise NotImplementedError(f"Normalization method {method} not implemented.") + return maps + + class MapToMapDistance: def __init__(self, config): self.config = config @@ -58,11 +67,20 @@ def distance_matrix_precomputation(maps1, maps2, global_store_of_running_results def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): """Compute the distance matrix between two sets of maps.""" # load in 
memory as torch tensors + if self.config["analysis"]["normalize"]["do"]: + maps2 = normalize( + maps2, method=self.config["analysis"]["normalize"]["method"] + ) if True: # config.low_memory: self.n_chunks_low_memory = len(maps1) // self.chunk_size_low_memory distance_matrix = torch.empty(len(maps1), len(maps2)) for idxs in torch.arange(len(maps1)).chunk(self.n_chunks_low_memory): maps1_in_memory = maps1[idxs] + if self.config["analysis"]["normalize"]["do"]: + maps1_in_memory = normalize( + maps1_in_memory, + method=self.config["analysis"]["normalize"]["method"], + ) sub_distance_matrix = self.get_sub_distance_matrix( maps1_in_memory, maps2, diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index 434280e..0f29e56 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -100,19 +100,6 @@ def run(config): maps_gt_flat.reshape(len(maps_gt_flat), -1, inplace=True) maps_user_flat.reshape(len(maps_gt_flat), -1, inplace=True) - if config["analysis"]["normalize"]["do"]: - if config["analysis"]["normalize"]["method"] == "median_zscore": - if not low_memory_mode: - maps_gt_flat -= maps_gt_flat.median(dim=1, keepdim=True).values - if not low_memory_mode: - maps_gt_flat /= maps_gt_flat.std(dim=1, keepdim=True) - maps_user_flat -= maps_user_flat.median(dim=1, keepdim=True).values - maps_user_flat /= maps_user_flat.std(dim=1, keepdim=True) - else: - raise NotImplementedError( - f"Normalization method {config['analysis']['normalize']['method']} not implemented." - ) - computed_assets = {} results_dict["mask"] = mask for distance_label, map_to_map_distance in map_to_map_distances.items(): From 0a2efc0ad9bb3758a781c5ed33a9b587c6bc937e Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 19:06:12 -0400 Subject: [PATCH 18/26] normalization and masking on the fly in sub batch --- .../_map_to_map/map_to_map_distance.py | 69 ++++++++----------- .../_map_to_map/map_to_map_pipeline.py | 21 ++---- .../config_files/test_config_map_to_map.yaml | 2 +- ..._config_map_to_map_nomask_nonormalize.yaml | 2 +- 4 files changed, 34 insertions(+), 60 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index df81ef6..0782f6f 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -44,6 +44,11 @@ def __init__(self, config): self.chunk_size_submission = self.config["analysis"]["chunk_size_submission"] self.n_pix = self.config["data"]["n_pix"] self.chunk_size_low_memory = self.config["analysis"]["chunk_size_low_memory"] + self.mask = ( + mrcfile.open(self.config["data"]["mask"]["volume"]) + .data.astype(bool) + .flatten() + ) def get_distance(self, map1, map2): """Compute the distance between two maps.""" @@ -66,7 +71,11 @@ def distance_matrix_precomputation(maps1, maps2, global_store_of_running_results def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): """Compute the distance matrix between two sets of maps.""" - # load in memory as torch tensors + if self.config["data"]["mask"]["do"]: + maps2 = maps2[:, self.mask] + else: + maps2.reshape(len(maps2), -1, inplace=True) + if self.config["analysis"]["normalize"]["do"]: maps2 = normalize( maps2, method=self.config["analysis"]["normalize"]["method"] @@ -76,6 +85,10 @@ def get_distance_matrix(self, maps1, maps2, 
global_store_of_running_results): distance_matrix = torch.empty(len(maps1), len(maps2)) for idxs in torch.arange(len(maps1)).chunk(self.n_chunks_low_memory): maps1_in_memory = maps1[idxs] + if self.config["data"]["mask"]["do"]: + maps1_in_memory = maps1_in_memory[:, self.mask] + else: + maps1_in_memory.reshape(len(maps1_in_memory), -1, inplace=True) if self.config["analysis"]["normalize"]["do"]: maps1_in_memory = normalize( maps1_in_memory, @@ -125,7 +138,7 @@ def get_distance(self, map1, map2, global_store_of_running_results): raise NotImplementedError( f"Normalization method {self.config['analysis']['normalize']['method']} not implemented." ) - map1 = map1[global_store_of_running_results["mask"]] + map1 = map1[self.mask] return self.compute_cost(map1, map2) @@ -380,13 +393,8 @@ def get_sub_distance_matrix(self, maps1, maps2, idxs): maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) if self.config["data"]["mask"]["do"]: - mask = ( - mrcfile.open(self.config["data"]["mask"]["volume"]) - .data.astype(bool) - .flatten() - ) - maps_gt_flat_cube[:, mask] = maps_gt_flat - maps_user_flat_cube[:, mask] = maps_user_flat + maps_gt_flat_cube[:, self.mask] = maps_gt_flat + maps_user_flat_cube[:, self.mask] = maps_user_flat else: maps_gt_flat_cube = maps_gt_flat maps_user_flat_cube = maps_user_flat @@ -397,35 +405,6 @@ def get_sub_distance_matrix(self, maps1, maps2, idxs): self.stored_computed_assets["fsc_matrix"][idxs] = fsc_matrix return cost_matrix - # @override - # def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): - # """ - # Applies a mask to the maps and computes the cost matrix using the Fourier Shell Correlation. - # """ - # maps_gt_flat = maps1 - # maps_user_flat = maps2 - # n_pix = self.config["data"]["n_pix"] - # maps_gt_flat_cube = torch.zeros(len(maps_gt_flat), n_pix**3) - # maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) - - # if self.config["data"]["mask"]["do"]: - # mask = ( - # mrcfile.open(self.config["data"]["mask"]["volume"]) - # .data.astype(bool) - # .flatten() - # ) - # maps_gt_flat_cube[:, mask] = maps_gt_flat - # maps_user_flat_cube[:, mask] = maps_user_flat - # else: - # maps_gt_flat_cube = maps_gt_flat - # maps_user_flat_cube = maps_user_flat - - # cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk( - # maps_gt_flat_cube, maps_user_flat_cube, n_pix - # ) - # self.stored_computed_assets = {"fsc_matrix": fsc_matrix} - # return cost_matrix - @override def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return self.stored_computed_assets # must run get_distance_matrix first @@ -447,8 +426,8 @@ def get_distance(self, map1, map2, global_store_of_running_results): map_gt_flat = map1 = map1.flatten() map_gt_flat_cube = torch.zeros(self.n_pix**3) if self.config["data"]["mask"]["do"]: - map_gt_flat = map_gt_flat[global_store_of_running_results["mask"]] - map_gt_flat_cube[global_store_of_running_results["mask"]] = map_gt_flat + map_gt_flat = map_gt_flat[self.mask] + map_gt_flat_cube[self.mask] = map_gt_flat else: map_gt_flat_cube = map_gt_flat @@ -483,7 +462,15 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): @override def get_computed_assets(self, maps1, maps2, global_store_of_running_results): - return self.stored_computed_assets # must run get_distance_matrix first + """ + Return any computed assets that are needed for (downstream) analysis. + + Notes + ----- + The FSC matrix is stored in the computed assets. + Must run get_distance_matrix first. 
+ """ + return self.stored_computed_assets class FSCResDistance(MapToMapDistance): diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index 0f29e56..e12dab0 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -1,4 +1,4 @@ -import mrcfile +import numpy as np import pandas as pd import pickle import torch @@ -84,24 +84,11 @@ def run(config): if low_memory_mode: maps_gt_flat = GT_Dataset(config["data"]["ground_truth"]["volumes"]) else: - maps_gt_flat = torch.load(config["data"]["ground_truth"]["volumes"]).reshape( - -1, n_pix**3 - ) - - if config["data"]["mask"]["do"]: - mask = ( - mrcfile.open(config["data"]["mask"]["volume"]).data.astype(bool).flatten() - ) - if not low_memory_mode: - maps_gt_flat = maps_gt_flat[:, mask] - maps_user_flat = maps_user_flat[:, mask] - else: - if not low_memory_mode: - maps_gt_flat.reshape(len(maps_gt_flat), -1, inplace=True) - maps_user_flat.reshape(len(maps_gt_flat), -1, inplace=True) + maps_gt_flat = torch.from_numpy( + np.load(config["data"]["ground_truth"]["volumes"]) + ).reshape(-1, n_pix**3) computed_assets = {} - results_dict["mask"] = mask for distance_label, map_to_map_distance in map_to_map_distances.items(): if distance_label in config["analysis"]["metrics"]: # TODO: can remove print("cost matrix", distance_label) diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index 62651f8..0d951dc 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -7,7 +7,7 @@ data: metadata_key: populations label_key: id ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy metadata: tests/data/Ground_truth/test_metadata_10.csv mask: do: true diff --git a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml index e97abaf..b73b88f 100644 --- a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml +++ b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml @@ -7,7 +7,7 @@ data: metadata_key: populations label_key: id ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy metadata: tests/data/Ground_truth/test_metadata_10.csv mask: do: false From 42525f4eaf67533a2642b39370afd9af3a03c1df Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 19:48:56 -0400 Subject: [PATCH 19/26] tests passing --- .../_map_to_map/map_to_map_distance.py | 16 ++++++++++------ .../_map_to_map/map_to_map_pipeline.py | 11 ++++------- .../data/_validation/config_validators.py | 1 + tests/config_files/test_config_map_to_map.yaml | 4 +++- .../test_config_map_to_map_low_memory.yaml | 4 +++- ...map_to_map_low_memory_nomask_nonormalize.yaml | 3 +++ ...est_config_map_to_map_nomask_nonormalize.yaml | 4 +++- tests/test_map_to_map.py | 4 ++-- 8 files changed, 29 insertions(+), 18 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 0782f6f..451ca5f 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -40,10 +40,13 @@ def normalize(maps, method): class MapToMapDistance: def __init__(self, config): self.config 
= config + self.do_low_memory_mode = self.config["analysis"]["low_memory"]["do"] self.chunk_size_gt = self.config["analysis"]["chunk_size_gt"] self.chunk_size_submission = self.config["analysis"]["chunk_size_submission"] self.n_pix = self.config["data"]["n_pix"] - self.chunk_size_low_memory = self.config["analysis"]["chunk_size_low_memory"] + self.chunk_size_low_memory = self.config["analysis"]["low_memory"][ + "chunk_size_low_memory" + ] self.mask = ( mrcfile.open(self.config["data"]["mask"]["volume"]) .data.astype(bool) @@ -74,21 +77,23 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): if self.config["data"]["mask"]["do"]: maps2 = maps2[:, self.mask] else: - maps2.reshape(len(maps2), -1, inplace=True) + maps2 = maps2.reshape(len(maps2), -1) if self.config["analysis"]["normalize"]["do"]: maps2 = normalize( maps2, method=self.config["analysis"]["normalize"]["method"] ) - if True: # config.low_memory: + if True: # self.do_low_memory_mode: self.n_chunks_low_memory = len(maps1) // self.chunk_size_low_memory distance_matrix = torch.empty(len(maps1), len(maps2)) for idxs in torch.arange(len(maps1)).chunk(self.n_chunks_low_memory): maps1_in_memory = maps1[idxs] if self.config["data"]["mask"]["do"]: - maps1_in_memory = maps1_in_memory[:, self.mask] + maps1_in_memory = maps1_in_memory.reshape(len(idxs), -1)[ + :, self.mask + ] else: - maps1_in_memory.reshape(len(maps1_in_memory), -1, inplace=True) + maps1_in_memory = maps1_in_memory.reshape(len(maps1_in_memory), -1) if self.config["analysis"]["normalize"]["do"]: maps1_in_memory = normalize( maps1_in_memory, @@ -102,7 +107,6 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): distance_matrix[idxs] = sub_distance_matrix else: - assert False, "Not implemented" distance_matrix = torch.vmap( lambda maps1: torch.vmap( lambda maps2: self.get_distance(maps1, maps2), diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index e12dab0..8d3d9f5 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -54,15 +54,12 @@ def run(config): } assert len(map_to_map_distances_low_memory) == 0 or len(map_to_map_distances) == 0 - if len(map_to_map_distances_low_memory) > 0: map_to_map_distances = map_to_map_distances_low_memory - low_memory_mode = True - else: - low_memory_mode = False - if not low_memory_mode: - n_pix = config["data"]["n_pix"] + do_low_memory_mode = config["analysis"]["low_memory"]["do"] + + n_pix = config["data"]["n_pix"] submission = torch.load(config["data"]["submission"]["fname"]) submission_volume_key = config["data"]["submission"]["volume_key"] @@ -81,7 +78,7 @@ def run(config): maps_user_flat = submission[submission_volume_key].reshape( len(submission["volumes"]), -1 ) - if low_memory_mode: + if do_low_memory_mode: maps_gt_flat = GT_Dataset(config["data"]["ground_truth"]["volumes"]) else: maps_gt_flat = torch.from_numpy( diff --git a/src/cryo_challenge/data/_validation/config_validators.py b/src/cryo_challenge/data/_validation/config_validators.py index b2fa933..22046c3 100644 --- a/src/cryo_challenge/data/_validation/config_validators.py +++ b/src/cryo_challenge/data/_validation/config_validators.py @@ -158,6 +158,7 @@ def validate_config_mtm_analysis(config_analysis: dict) -> None: "chunk_size_submission": Number, "chunk_size_gt": Number, "normalize": dict, + "low_memory": dict, } validate_generic_config(config_analysis, keys_and_types) diff 
--git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index 0d951dc..fd5990c 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -21,7 +21,9 @@ analysis: - res chunk_size_submission: 80 chunk_size_gt: 190 - chunk_size_low_memory: 10 + low_memory: + do: false + chunk_size_low_memory: 10 normalize: do: true method: median_zscore diff --git a/tests/config_files/test_config_map_to_map_low_memory.yaml b/tests/config_files/test_config_map_to_map_low_memory.yaml index e02e271..35ef921 100644 --- a/tests/config_files/test_config_map_to_map_low_memory.yaml +++ b/tests/config_files/test_config_map_to_map_low_memory.yaml @@ -21,7 +21,9 @@ analysis: - res_low_memory chunk_size_submission: 80 chunk_size_gt: 190 - chunk_size_low_memory: 10 + low_memory: + do: false + chunk_size_low_memory: null normalize: do: true method: median_zscore diff --git a/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml index 49dac78..4cfb346 100644 --- a/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml +++ b/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml @@ -22,6 +22,9 @@ analysis: chunk_size_submission: 80 chunk_size_gt: 190 chunk_size_low_memory: 10 + low_memory: + do: true + chunk_size_low_memory: null normalize: do: false method: median_zscore diff --git a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml index b73b88f..d37e4ac 100644 --- a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml +++ b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml @@ -21,7 +21,9 @@ analysis: - res chunk_size_submission: 80 chunk_size_gt: 190 - chunk_size_low_memory: 10 + low_memory: + do: true + chunk_size_low_memory: 10 normalize: do: false method: median_zscore diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index 957a9b1..f0d65b7 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -7,11 +7,11 @@ def test_run_map2map_pipeline(): for config_fname, config_fname_low_memory in zip( [ "tests/config_files/test_config_map_to_map.yaml", - "tests/config_files/test_config_map_to_map.yaml", + "tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml", ], [ "tests/config_files/test_config_map_to_map_low_memory.yaml", - "tests/config_files/test_config_map_to_map_low_memory.yaml", + "tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml", ], ): args = OmegaConf.create({"config": config_fname}) From e472372eb2418940801825d400b753b8b7be1343 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 19:53:56 -0400 Subject: [PATCH 20/26] tests passing for low memory off --- src/cryo_challenge/_map_to_map/map_to_map_distance.py | 7 ++++++- tests/config_files/test_config_map_to_map.yaml | 2 +- .../test_config_map_to_map_nomask_nonormalize.yaml | 4 ++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 451ca5f..73f2502 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -83,7 +83,7 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): maps2 = normalize( maps2, 
method=self.config["analysis"]["normalize"]["method"] ) - if True: # self.do_low_memory_mode: + if self.do_low_memory_mode: self.n_chunks_low_memory = len(maps1) // self.chunk_size_low_memory distance_matrix = torch.empty(len(maps1), len(maps2)) for idxs in torch.arange(len(maps1)).chunk(self.n_chunks_low_memory): @@ -409,6 +409,11 @@ def get_sub_distance_matrix(self, maps1, maps2, idxs): self.stored_computed_assets["fsc_matrix"][idxs] = fsc_matrix return cost_matrix + @override + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): + idxs = torch.arange(len(maps1)) + return self.get_sub_distance_matrix(maps1, maps2, idxs) + @override def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return self.stored_computed_assets # must run get_distance_matrix first diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index fd5990c..689eb82 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -23,7 +23,7 @@ analysis: chunk_size_gt: 190 low_memory: do: false - chunk_size_low_memory: 10 + chunk_size_low_memory: null normalize: do: true method: median_zscore diff --git a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml index d37e4ac..2a9a3a8 100644 --- a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml +++ b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml @@ -22,8 +22,8 @@ analysis: chunk_size_submission: 80 chunk_size_gt: 190 low_memory: - do: true - chunk_size_low_memory: 10 + do: false + chunk_size_low_memory: null normalize: do: false method: median_zscore From 3077f6ae89e5fd822ab606352db1792db803438b Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 20:23:29 -0400 Subject: [PATCH 21/26] tests passing. 
time to delete separate low memory --- .../_map_to_map/map_to_map_distance.py | 42 +++++++++++++++++-- .../_map_to_map/map_to_map_pipeline.py | 2 - ...config_map_to_map_low_memory_subbatch.yaml | 31 ++++++++++++++ tests/test_map_to_map.py | 31 ++++++++++++++ 4 files changed, 101 insertions(+), 5 deletions(-) create mode 100644 tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 73f2502..eec6828 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -397,8 +397,9 @@ def get_sub_distance_matrix(self, maps1, maps2, idxs): maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) if self.config["data"]["mask"]["do"]: - maps_gt_flat_cube[:, self.mask] = maps_gt_flat + maps_gt_flat_cube[:, self.mask] = maps_gt_flat[:] maps_user_flat_cube[:, self.mask] = maps_user_flat + else: maps_gt_flat_cube = maps_gt_flat maps_user_flat_cube = maps_user_flat @@ -411,8 +412,43 @@ def get_sub_distance_matrix(self, maps1, maps2, idxs): @override def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): - idxs = torch.arange(len(maps1)) - return self.get_sub_distance_matrix(maps1, maps2, idxs) + """Compute the distance matrix between two sets of maps.""" + if self.config["data"]["mask"]["do"]: + maps2 = maps2[:, self.mask] + else: + maps2 = maps2.reshape(len(maps2), -1) + + if self.config["analysis"]["normalize"]["do"]: + maps2 = normalize( + maps2, method=self.config["analysis"]["normalize"]["method"] + ) + if self.chunk_size_low_memory is None: + self.n_chunks_low_memory = 1 + else: + self.n_chunks_low_memory = len(maps1) // self.chunk_size_low_memory + distance_matrix = torch.empty(len(maps1), len(maps2)) + for idxs in torch.arange(len(maps1)).chunk(self.n_chunks_low_memory): + maps1_in_memory = maps1[idxs] + + if self.config["data"]["mask"]["do"]: + maps1_in_memory = maps1_in_memory[:].reshape(len(idxs), -1)[ + :, self.mask + ] + + else: + maps1_in_memory = maps1_in_memory.reshape(len(maps1_in_memory), -1) + if self.config["analysis"]["normalize"]["do"]: + maps1_in_memory = normalize( + maps1_in_memory, + method=self.config["analysis"]["normalize"]["method"], + ) + sub_distance_matrix = self.get_sub_distance_matrix( + maps1_in_memory, + maps2, + idxs, + ) + distance_matrix[idxs] = sub_distance_matrix + return distance_matrix @override def get_computed_assets(self, maps1, maps2, global_store_of_running_results): diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index 8d3d9f5..c85b5d1 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -89,8 +89,6 @@ def run(config): for distance_label, map_to_map_distance in map_to_map_distances.items(): if distance_label in config["analysis"]["metrics"]: # TODO: can remove print("cost matrix", distance_label) - print("maps_gt_flat", maps_gt_flat.shape) - print("maps_user_flat", maps_user_flat.shape) map_to_map_distance.distance_matrix_precomputation( maps_gt_flat, maps_user_flat diff --git a/tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml b/tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml new file mode 100644 index 0000000..7b02d2e --- /dev/null +++ b/tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml @@ -0,0 +1,31 @@ +data: + n_pix: 
16 + psize: 30.044 + submission: + fname: tests/data/dataset_2_submissions/submission_1000.pt + volume_key: volumes + metadata_key: populations + label_key: id + ground_truth: + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy + metadata: tests/data/Ground_truth/test_metadata_10.csv + mask: + do: true + volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc +analysis: + metrics: + - l2 + - corr + - bioem + - fsc + - res + chunk_size_submission: 80 + chunk_size_gt: 190 + chunk_size_low_memory: 10 + low_memory: + do: true + chunk_size_low_memory: 10 + normalize: + do: true + method: median_zscore +output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index f0d65b7..492610a 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -4,6 +4,37 @@ def test_run_map2map_pipeline(): + for config_fname, config_fname_low_memory in zip( + [ + "tests/config_files/test_config_map_to_map.yaml", + ], + [ + "tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml", + ], + ): + args = OmegaConf.create({"config": config_fname}) + results_dict = run_map2map_pipeline.main(args) + + args_low_memory = OmegaConf.create({"config": config_fname_low_memory}) + results_dict_low_memory = run_map2map_pipeline.main(args_low_memory) + for metric in ["fsc", "corr", "l2", "bioem"]: + if metric == "fsc": + np.allclose( + results_dict[metric]["computed_assets"]["fsc_matrix"], + results_dict_low_memory[metric]["computed_assets"]["fsc_matrix"], + ) + elif metric == "res": + np.allclose( + results_dict[metric]["computed_assets"]["fraction_nyquist"], + results_dict_low_memory[metric]["computed_assets"][ + "fraction_nyquist" + ], + ) + np.allclose( + results_dict[metric]["cost_matrix"].values, + results_dict_low_memory[metric]["cost_matrix"].values, + ) + for config_fname, config_fname_low_memory in zip( [ "tests/config_files/test_config_map_to_map.yaml", From cb7085e6569e6222e0ea665a2785d595290e56d6 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 20:30:38 -0400 Subject: [PATCH 22/26] all tests passing --- .../_map_to_map/map_to_map_distance.py | 14 +++++---- ...ow_memory_subbatch_nomask_nonormalize.yaml | 31 +++++++++++++++++++ tests/test_map_to_map.py | 2 ++ 3 files changed, 41 insertions(+), 6 deletions(-) create mode 100644 tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index eec6828..e70817e 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -47,11 +47,12 @@ def __init__(self, config): self.chunk_size_low_memory = self.config["analysis"]["low_memory"][ "chunk_size_low_memory" ] - self.mask = ( - mrcfile.open(self.config["data"]["mask"]["volume"]) - .data.astype(bool) - .flatten() - ) + if self.config["data"]["mask"]["do"]: + self.mask = ( + mrcfile.open(self.config["data"]["mask"]["volume"]) + .data.astype(bool) + .flatten() + ) def get_distance(self, map1, map2): """Compute the distance between two maps.""" @@ -142,7 +143,8 @@ def get_distance(self, map1, map2, global_store_of_running_results): raise NotImplementedError( f"Normalization method {self.config['analysis']['normalize']['method']} not implemented." 
) - map1 = map1[self.mask] + if self.config["data"]["mask"]["do"]: + map1 = map1[self.mask] return self.compute_cost(map1, map2) diff --git a/tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml new file mode 100644 index 0000000..020f58f --- /dev/null +++ b/tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml @@ -0,0 +1,31 @@ +data: + n_pix: 16 + psize: 30.044 + submission: + fname: tests/data/dataset_2_submissions/submission_1000.pt + volume_key: volumes + metadata_key: populations + label_key: id + ground_truth: + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy + metadata: tests/data/Ground_truth/test_metadata_10.csv + mask: + do: false + volume: dummy-string +analysis: + metrics: + - l2 + - corr + - bioem + - fsc + - res + chunk_size_submission: 80 + chunk_size_gt: 190 + chunk_size_low_memory: 10 + low_memory: + do: true + chunk_size_low_memory: 10 + normalize: + do: false + method: dummy-string +output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index 492610a..2c607e7 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -7,9 +7,11 @@ def test_run_map2map_pipeline(): for config_fname, config_fname_low_memory in zip( [ "tests/config_files/test_config_map_to_map.yaml", + "tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml", ], [ "tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml", + "tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml", ], ): args = OmegaConf.create({"config": config_fname}) From 1f4057e29357da659c24e03c36eb296c17816a68 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 20:37:35 -0400 Subject: [PATCH 23/26] remove low memory versions --- .../_map_to_map/map_to_map_distance.py | 149 ------------------ .../_map_to_map/map_to_map_pipeline.py | 23 --- .../test_config_map_to_map_low_memory.yaml | 30 ---- ..._to_map_low_memory_nomask_nonormalize.yaml | 31 ---- tests/test_map_to_map.py | 35 ---- 5 files changed, 268 deletions(-) delete mode 100644 tests/config_files/test_config_map_to_map_low_memory.yaml delete mode 100644 tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index e70817e..5b9dc8e 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -122,47 +122,6 @@ def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return {} -class MapToMapDistanceLowMemory(MapToMapDistance): - """General class for map-to-map distance metrics that require low memory.""" - - def __init__(self, config): - super().__init__(config) - self.config = config - - def compute_cost(self, map_1, map_2): - raise NotImplementedError() - - @override - def get_distance(self, map1, map2, global_store_of_running_results): - map1 = map1.flatten() - if self.config["analysis"]["normalize"]["do"]: - if self.config["analysis"]["normalize"]["method"] == "median_zscore": - map1 -= map1.median() - map1 /= map1.std() - else: - raise NotImplementedError( - f"Normalization method {self.config['analysis']['normalize']['method']} not implemented." 
- ) - if self.config["data"]["mask"]["do"]: - map1 = map1[self.mask] - - return self.compute_cost(map1, map2) - - @override - def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): - maps_gt_flat = maps1 - maps_user_flat = maps2 - cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) - for idx_gt in range(len(maps_gt_flat)): - for idx_user in range(len(maps_user_flat)): - cost_matrix[idx_gt, idx_user] = self.get_distance( - maps_gt_flat[idx_gt], - maps_user_flat[idx_user], - global_store_of_running_results, - ) - return cost_matrix - - def norm2(map1, map2): return torch.norm(map1 - map2) ** 2 @@ -178,17 +137,6 @@ def get_distance(self, map1, map2): return norm2(map1, map2) -class L2DistanceNormLowMemory(MapToMapDistanceLowMemory): - """L2 distance norm""" - - def __init__(self, config): - super().__init__(config) - - @override - def compute_cost(self, map1, map2): - return norm2(map1, map2) - - def correlation(map1, map2): return (map1 * map2).sum() @@ -206,17 +154,6 @@ def get_distance(self, map1, map2): return correlation(map1, map2) -class CorrelationLowMemory(MapToMapDistanceLowMemory): - """Correlation.""" - - def __init__(self, config): - super().__init__(config) - - @override - def compute_cost(self, map1, map2): - return correlation(map1, map2) - - def compute_bioem3d_cost(map1, map2): """ Compute the cost between two maps using the BioEM cost function in 3D. @@ -272,17 +209,6 @@ def get_distance(self, map1, map2): return compute_bioem3d_cost(map1, map2) -class BioEM3dDistanceLowMemory(MapToMapDistanceLowMemory): - """BioEM 3D distance.""" - - def __init__(self, config): - super().__init__(config) - - @override - def compute_cost(self, map1, map2): - return compute_bioem3d_cost(map1, map2) - - def fourier_shell_correlation( x: torch.Tensor, y: torch.Tensor, @@ -457,69 +383,6 @@ def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return self.stored_computed_assets # must run get_distance_matrix first -class FSCDistanceLowMemory(MapToMapDistance): - """Fourier Shell Correlation distance.""" - - def __init__(self, config): - super().__init__(config) - self.n_pix = self.config["data"]["n_pix"] - self.config = config - - def compute_cost(self, map_1, map_2): - raise NotImplementedError() - - @override - def get_distance(self, map1, map2, global_store_of_running_results): - map_gt_flat = map1 = map1.flatten() - map_gt_flat_cube = torch.zeros(self.n_pix**3) - if self.config["data"]["mask"]["do"]: - map_gt_flat = map_gt_flat[self.mask] - map_gt_flat_cube[self.mask] = map_gt_flat - else: - map_gt_flat_cube = map_gt_flat - - corr_vector = fourier_shell_correlation( - map_gt_flat_cube.reshape(self.n_pix, self.n_pix, self.n_pix), - map2.reshape(self.n_pix, self.n_pix, self.n_pix), - ) - dist = 1 - corr_vector.mean() # TODO: spectral cutoff - self.stored_computed_assets = {"corr_vector": corr_vector} - return dist - - @override - def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): - maps_gt_flat = maps1 - maps_user_flat = maps2 - cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) - fsc_matrix = torch.zeros( - len(maps_gt_flat), len(maps_user_flat), self.n_pix // 2 - ) - for idx_gt in range(len(maps_gt_flat)): - for idx_user in range(len(maps_user_flat)): - cost_matrix[idx_gt, idx_user] = self.get_distance( - maps_gt_flat[idx_gt], - maps_user_flat[idx_user], - global_store_of_running_results, - ) - fsc_matrix[idx_gt, idx_user] = self.stored_computed_assets[ - "corr_vector" - ] - 
self.stored_computed_assets = {"fsc_matrix": fsc_matrix} - return cost_matrix - - @override - def get_computed_assets(self, maps1, maps2, global_store_of_running_results): - """ - Return any computed assets that are needed for (downstream) analysis. - - Notes - ----- - The FSC matrix is stored in the computed assets. - Must run get_distance_matrix first. - """ - return self.stored_computed_assets - - class FSCResDistance(MapToMapDistance): """FSC Resolution distance. @@ -555,15 +418,3 @@ def res_at_fsc_threshold(fscs, threshold=0.5): res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fsc_matrix) self.stored_computed_assets = {"fraction_nyquist": fraction_nyquist} return units_Angstroms[res_fsc_half] - - -class FSCResDistanceLowMemory(FSCResDistance): - """FSC Resolution distance. - - The resolution at which the Fourier Shell Correlation reaches 0.5. - Built on top of the FSCDistance class. This needs to be run first and store the FSC matrix in the computed assets. - """ - - def __init__(self, config): - super().__init__(config) - self.fsc_label = "fsc_low_memory" diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index c85b5d1..c281496 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -7,15 +7,10 @@ from .._map_to_map.map_to_map_distance import ( GT_Dataset, FSCDistance, - FSCDistanceLowMemory, Correlation, - CorrelationLowMemory, L2DistanceNorm, - L2DistanceNormLowMemory, BioEM3dDistance, - BioEM3dDistanceLowMemory, FSCResDistance, - FSCResDistanceLowMemory, ) @@ -27,14 +22,6 @@ "res": FSCResDistance, } -AVAILABLE_MAP2MAP_DISTANCES_LOW_MEMORY = { - "corr_low_memory": CorrelationLowMemory, - "l2_low_memory": L2DistanceNormLowMemory, - "bioem_low_memory": BioEM3dDistanceLowMemory, - "fsc_low_memory": FSCDistanceLowMemory, - "res_low_memory": FSCResDistanceLowMemory, -} - def run(config): """ @@ -47,16 +34,6 @@ def run(config): if distance_label in config["analysis"]["metrics"] } - map_to_map_distances_low_memory = { - distance_label: distance_class(config) - for distance_label, distance_class in AVAILABLE_MAP2MAP_DISTANCES_LOW_MEMORY.items() - if distance_label in config["analysis"]["metrics"] - } - - assert len(map_to_map_distances_low_memory) == 0 or len(map_to_map_distances) == 0 - if len(map_to_map_distances_low_memory) > 0: - map_to_map_distances = map_to_map_distances_low_memory - do_low_memory_mode = config["analysis"]["low_memory"]["do"] n_pix = config["data"]["n_pix"] diff --git a/tests/config_files/test_config_map_to_map_low_memory.yaml b/tests/config_files/test_config_map_to_map_low_memory.yaml deleted file mode 100644 index 35ef921..0000000 --- a/tests/config_files/test_config_map_to_map_low_memory.yaml +++ /dev/null @@ -1,30 +0,0 @@ -data: - n_pix: 16 - psize: 30.044 - submission: - fname: tests/data/dataset_2_submissions/submission_1000.pt - volume_key: volumes - metadata_key: populations - label_key: id - ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy - metadata: tests/data/Ground_truth/test_metadata_10.csv - mask: - do: true - volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc -analysis: - metrics: - - l2_low_memory - - corr_low_memory - - bioem_low_memory - - fsc_low_memory - - res_low_memory - chunk_size_submission: 80 - chunk_size_gt: 190 - low_memory: - do: false - chunk_size_low_memory: null - normalize: - do: true - method: median_zscore -output: 
tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml deleted file mode 100644 index 4cfb346..0000000 --- a/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml +++ /dev/null @@ -1,31 +0,0 @@ -data: - n_pix: 16 - psize: 30.044 - submission: - fname: tests/data/dataset_2_submissions/submission_1000.pt - volume_key: volumes - metadata_key: populations - label_key: id - ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy - metadata: tests/data/Ground_truth/test_metadata_10.csv - mask: - do: false - volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc -analysis: - metrics: - - l2_low_memory - - corr_low_memory - - bioem_low_memory - - fsc_low_memory - - res_low_memory - chunk_size_submission: 80 - chunk_size_gt: 190 - chunk_size_low_memory: 10 - low_memory: - do: true - chunk_size_low_memory: null - normalize: - do: false - method: median_zscore -output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index 2c607e7..907e6d3 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -36,38 +36,3 @@ def test_run_map2map_pipeline(): results_dict[metric]["cost_matrix"].values, results_dict_low_memory[metric]["cost_matrix"].values, ) - - for config_fname, config_fname_low_memory in zip( - [ - "tests/config_files/test_config_map_to_map.yaml", - "tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml", - ], - [ - "tests/config_files/test_config_map_to_map_low_memory.yaml", - "tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml", - ], - ): - args = OmegaConf.create({"config": config_fname}) - results_dict = run_map2map_pipeline.main(args) - - args_low_memory = OmegaConf.create({"config": config_fname_low_memory}) - results_dict_low_memory = run_map2map_pipeline.main(args_low_memory) - for metric in ["fsc", "corr", "l2", "bioem"]: - if metric == "fsc": - np.allclose( - results_dict[metric]["computed_assets"]["fsc_matrix"], - results_dict_low_memory[metric + "_low_memory"]["computed_assets"][ - "fsc_matrix" - ], - ) - elif metric == "res": - np.allclose( - results_dict[metric]["computed_assets"]["fraction_nyquist"], - results_dict_low_memory[metric + "_low_memory"]["computed_assets"][ - "fraction_nyquist" - ], - ) - np.allclose( - results_dict[metric]["cost_matrix"].values, - results_dict_low_memory[metric + "_low_memory"]["cost_matrix"].values, - ) From b33a6a06f707209efc3c9ec7581b6d4c8fb3e36b Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 20:57:25 -0400 Subject: [PATCH 24/26] update configs to remove low memory metrics --- .../data/_validation/config_validators.py | 1 + .../data/_validation/output_validators.py | 18 ------------------ 2 files changed, 1 insertion(+), 18 deletions(-) diff --git a/src/cryo_challenge/data/_validation/config_validators.py b/src/cryo_challenge/data/_validation/config_validators.py index 22046c3..83083ed 100644 --- a/src/cryo_challenge/data/_validation/config_validators.py +++ b/src/cryo_challenge/data/_validation/config_validators.py @@ -151,6 +151,7 @@ def validate_config_mtm_analysis(config_analysis: dict) -> None: chunk_size_submission: int, is the chunk size for the submission volume. chunk_size_gt: int, is the chunk size for the ground truth volume. 
normalize: dict, is the normalize part of the analysis part of the config. + low_memory: dict, is the low memory part of the analysis part of the config. # TODO: add validation for low_memory """ # noqa: E501 keys_and_types = { diff --git a/src/cryo_challenge/data/_validation/output_validators.py b/src/cryo_challenge/data/_validation/output_validators.py index ada4492..9f76a6d 100644 --- a/src/cryo_challenge/data/_validation/output_validators.py +++ b/src/cryo_challenge/data/_validation/output_validators.py @@ -19,27 +19,18 @@ class MapToMapResultsValidator: config: dict, input config dictionary. user_submitted_populations: torch.Tensor, user submitted populations, which sum to 1. corr: dict, correlation results. - corr_low_memory: dict, correlation results in low memory mode. l2: dict, L2 results. - l2_low_memory: dict, L2 results in low memory mode. bioem: dict, BioEM results. - bioem_low_memory: dict, BioEM results in low memory mode. fsc: dict, FSC results. - fsc_low_memory: dict, FSC results in low memory mode. """ config: dict user_submitted_populations: torch.Tensor corr: Optional[dict] = None - corr_low_memory: Optional[dict] = None l2: Optional[dict] = None - l2_low_memory: Optional[dict] = None bioem: Optional[dict] = None - bioem_low_memory: Optional[dict] = None fsc: Optional[dict] = None - fsc_low_memory: Optional[dict] = None res: Optional[dict] = None - res_low_memory: Optional[dict] = None def __post_init__(self): validate_input_config_mtm(self.config) @@ -150,10 +141,6 @@ class DistributionToDistributionResultsValidator: bioem: dict, BioEM distance results. l2: dict, L2 distance results. corr: dict, correlation distance results. - fsc_low_memory: dict, FSC distance results in low memory mode. - bioem_low_memory: dict, BioEM distance results in low memory mode. - l2_low_memory: dict, L2 distance results in low memory mode. - corr_low_memory: dict, correlation distance results in low memory mode. """ config: dict @@ -164,11 +151,6 @@ class DistributionToDistributionResultsValidator: fsc: Optional[dict] = None bioem: Optional[dict] = None res: Optional[dict] = None l2: Optional[dict] = None corr: Optional[dict] = None - fsc_low_memory: Optional[dict] = None - bioem_low_memory: Optional[dict] = None - res_low_memory: Optional[dict] = None - l2_low_memory: Optional[dict] = None - corr_low_memory: Optional[dict] = None def __post_init__(self): validate_input_config_disttodist(self.config) From 9792ed351dc81dcb3486116a024f59507a1bda71 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Wed, 18 Sep 2024 14:04:32 -0400 Subject: [PATCH 25/26] remove numpy dependency.
use torch mmap --- .../_map_to_map/map_to_map_distance.py | 24 ++----------------- .../_map_to_map/map_to_map_pipeline.py | 13 +++------- .../config_files/test_config_map_to_map.yaml | 6 ++--- ...config_map_to_map_low_memory_subbatch.yaml | 9 ++++--- ...ow_memory_subbatch_nomask_nonormalize.yaml | 9 ++++--- ..._config_map_to_map_nomask_nonormalize.yaml | 6 ++--- 6 files changed, 19 insertions(+), 48 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 5b9dc8e..3021db5 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -4,28 +4,6 @@ from typing_extensions import override import mrcfile import numpy as np -from torch.utils.data import Dataset - - -class GT_Dataset(Dataset): - def __init__(self, npy_file): - self.npy_file = npy_file - self.data = np.load(npy_file, mmap_mode="r+") - - self.shape = self.data.shape - self._dim = len(self.data.shape) - - def dim(self): - return self._dim - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - if torch.is_tensor(idx): - idx = idx.tolist() - sample = self.data[idx] - return torch.from_numpy(sample.copy()) def normalize(maps, method): @@ -108,6 +86,8 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): distance_matrix[idxs] = sub_distance_matrix else: + maps1 = maps1.reshape(len(maps1), -1) + maps2 = maps2.reshape(len(maps2), -1) distance_matrix = torch.vmap( lambda maps1: torch.vmap( lambda maps2: self.get_distance(maps1, maps2), diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index c281496..06ce66f 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -1,11 +1,9 @@ -import numpy as np import pandas as pd import pickle import torch from ..data._validation.output_validators import MapToMapResultsValidator from .._map_to_map.map_to_map_distance import ( - GT_Dataset, FSCDistance, Correlation, L2DistanceNorm, @@ -36,8 +34,6 @@ def run(config): do_low_memory_mode = config["analysis"]["low_memory"]["do"] - n_pix = config["data"]["n_pix"] - submission = torch.load(config["data"]["submission"]["fname"]) submission_volume_key = config["data"]["submission"]["volume_key"] submission_metadata_key = config["data"]["submission"]["metadata_key"] @@ -55,12 +51,9 @@ def run(config): maps_user_flat = submission[submission_volume_key].reshape( len(submission["volumes"]), -1 ) - if do_low_memory_mode: - maps_gt_flat = GT_Dataset(config["data"]["ground_truth"]["volumes"]) - else: - maps_gt_flat = torch.from_numpy( - np.load(config["data"]["ground_truth"]["volumes"]) - ).reshape(-1, n_pix**3) + maps_gt_flat = torch.load( + config["data"]["ground_truth"]["volumes"], mmap=do_low_memory_mode + ) computed_assets = {} for distance_label, map_to_map_distance in map_to_map_distances.items(): diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index 689eb82..2244e21 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -7,7 +7,7 @@ data: metadata_key: populations label_key: id ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt metadata: tests/data/Ground_truth/test_metadata_10.csv mask: do: true @@ -19,8 +19,8 @@ 
     - bioem
     - fsc
     - res
-  chunk_size_submission: 80
-  chunk_size_gt: 190
+  chunk_size_submission: 4
+  chunk_size_gt: 5
   low_memory:
     do: false
     chunk_size_low_memory: null
diff --git a/tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml b/tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml
index 7b02d2e..8bc02e7 100644
--- a/tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml
+++ b/tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml
@@ -7,7 +7,7 @@ data:
     metadata_key: populations
     label_key: id
   ground_truth:
-    volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy
+    volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt
     metadata: tests/data/Ground_truth/test_metadata_10.csv
   mask:
     do: true
@@ -19,12 +19,11 @@ analysis:
     - bioem
     - fsc
     - res
-  chunk_size_submission: 80
-  chunk_size_gt: 190
-  chunk_size_low_memory: 10
+  chunk_size_submission: 4
+  chunk_size_gt: 5
   low_memory:
     do: true
-    chunk_size_low_memory: 10
+    chunk_size_low_memory: 2
   normalize:
     do: true
     method: median_zscore
diff --git a/tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml
index 020f58f..74b494b 100644
--- a/tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml
+++ b/tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml
@@ -7,7 +7,7 @@ data:
     metadata_key: populations
     label_key: id
   ground_truth:
-    volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy
+    volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt
     metadata: tests/data/Ground_truth/test_metadata_10.csv
   mask:
     do: false
@@ -19,12 +19,11 @@ analysis:
     - bioem
     - fsc
     - res
-  chunk_size_submission: 80
-  chunk_size_gt: 190
-  chunk_size_low_memory: 10
+  chunk_size_submission: 4
+  chunk_size_gt: 5
   low_memory:
     do: true
-    chunk_size_low_memory: 10
+    chunk_size_low_memory: 2
   normalize:
     do: false
     method: dummy-string
diff --git a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml
index 2a9a3a8..a8a4f09 100644
--- a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml
+++ b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml
@@ -7,7 +7,7 @@ data:
     metadata_key: populations
     label_key: id
   ground_truth:
-    volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy
+    volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt
     metadata: tests/data/Ground_truth/test_metadata_10.csv
   mask:
     do: false
@@ -19,8 +19,8 @@ analysis:
     - bioem
     - fsc
     - res
-  chunk_size_submission: 80
-  chunk_size_gt: 190
+  chunk_size_submission: 4
+  chunk_size_gt: 5
   low_memory:
     do: false
     chunk_size_low_memory: null
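The patch above switches the ground-truth test fixtures from .npy to .pt so that torch.load can memory-map them when low_memory mode is enabled. The conversion itself is not shown in the series; a rough one-off sketch is given below, assuming PyTorch >= 2.1 (where torch.load gained the mmap argument) and using the fixture paths from the configs above. The helper name convert_gt_npy_to_pt is illustrative only.

import numpy as np
import torch


def convert_gt_npy_to_pt(npy_path: str, pt_path: str) -> None:
    # One full read offline, so the pipeline never has to touch the .npy again.
    maps = torch.from_numpy(np.load(npy_path))
    # torch.save writes the zipfile format, which torch.load can memory-map.
    torch.save(maps, pt_path)


convert_gt_npy_to_pt(
    "tests/data/Ground_truth/test_maps_gt_flat_10.npy",
    "tests/data/Ground_truth/test_maps_gt_flat_10.pt",
)

# In low-memory mode the pipeline then maps the file lazily instead of
# reading the whole stack into RAM:
maps_gt_flat = torch.load(
    "tests/data/Ground_truth/test_maps_gt_flat_10.pt", mmap=True
)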

From 8977e865efc56d18eda4146f1749d919cc1eef2b Mon Sep 17 00:00:00 2001
From: DSilva27
Date: Mon, 23 Sep 2024 17:15:51 -0400
Subject: [PATCH 26/26] update the icecream-to-submission tables instead of
 overwriting them

---
 .../_preprocessing/preprocessing_pipeline.py  | 22 ++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py b/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py
index b4c3e61..5053a29 100644
--- a/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py
+++ b/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py
@@ -38,6 +38,22 @@ def save_submission(volumes, populations, submission_id, submission_index, confi
     return submission_dict
 
 
+def update_hash_table(hash_table_path, hash_table):
+    if os.path.exists(hash_table_path):
+        with open(hash_table_path, "r") as f:
+            hash_table_old = json.load(f)
+        hash_table_old.update(hash_table)
+
+        with open(hash_table_path, "w") as f:
+            json.dump(hash_table_old, f, indent=4)
+
+    else:
+        with open(hash_table_path, "w") as f:
+            json.dump(hash_table, f, indent=4)
+
+    return
+
+
 def preprocess_submissions(submission_dataset, config):
     hash_table = {}
     box_size_gt = submission_dataset.submission_config["gt"]["box_size"]
@@ -79,7 +95,7 @@ def preprocess_submissions(submission_dataset, config):
         volumes = threshold_submissions(volumes, config["thresh_percentile"])
 
         # center submission
-        print(" Centering submission")
+        # print(" Centering submission")
         # volumes = center_submission(volumes, pixel_size=pixel_size_gt)
 
         # flip handedness
@@ -121,7 +137,7 @@ def preprocess_submissions(submission_dataset, config):
     hash_table_path = os.path.join(
         config["output_path"], "submission_to_icecream_table.json"
     )
-    with open(hash_table_path, "w") as f:
-        json.dump(hash_table, f, indent=4)
+
+    update_hash_table(hash_table_path, hash_table)
 
     return
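The update_hash_table helper added above merges new entries into an existing submission_to_icecream_table.json instead of clobbering it, so repeated preprocessing runs accumulate entries rather than overwrite them. A small usage sketch follows, assuming the package is importable; the temporary directory and the example submission/icecream keys are illustrative only.

import json
import os
import tempfile

from cryo_challenge._preprocessing.preprocessing_pipeline import update_hash_table

with tempfile.TemporaryDirectory() as tmpdir:
    table_path = os.path.join(tmpdir, "submission_to_icecream_table.json")
    update_hash_table(table_path, {"submission_0": "icecream_3"})  # first run
    update_hash_table(table_path, {"submission_1": "icecream_7"})  # later run
    with open(table_path, "r") as f:
        table = json.load(f)
    # both entries survive because the table is updated, not overwritten
    assert table == {
        "submission_0": "icecream_3",
        "submission_1": "icecream_7",
    }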