diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index e70817e..5b9dc8e 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -122,47 +122,6 @@ def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return {} -class MapToMapDistanceLowMemory(MapToMapDistance): - """General class for map-to-map distance metrics that require low memory.""" - - def __init__(self, config): - super().__init__(config) - self.config = config - - def compute_cost(self, map_1, map_2): - raise NotImplementedError() - - @override - def get_distance(self, map1, map2, global_store_of_running_results): - map1 = map1.flatten() - if self.config["analysis"]["normalize"]["do"]: - if self.config["analysis"]["normalize"]["method"] == "median_zscore": - map1 -= map1.median() - map1 /= map1.std() - else: - raise NotImplementedError( - f"Normalization method {self.config['analysis']['normalize']['method']} not implemented." - ) - if self.config["data"]["mask"]["do"]: - map1 = map1[self.mask] - - return self.compute_cost(map1, map2) - - @override - def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): - maps_gt_flat = maps1 - maps_user_flat = maps2 - cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) - for idx_gt in range(len(maps_gt_flat)): - for idx_user in range(len(maps_user_flat)): - cost_matrix[idx_gt, idx_user] = self.get_distance( - maps_gt_flat[idx_gt], - maps_user_flat[idx_user], - global_store_of_running_results, - ) - return cost_matrix - - def norm2(map1, map2): return torch.norm(map1 - map2) ** 2 @@ -178,17 +137,6 @@ def get_distance(self, map1, map2): return norm2(map1, map2) -class L2DistanceNormLowMemory(MapToMapDistanceLowMemory): - """L2 distance norm""" - - def __init__(self, config): - super().__init__(config) - - @override - def compute_cost(self, map1, map2): - return norm2(map1, map2) - - def correlation(map1, map2): return (map1 * map2).sum() @@ -206,17 +154,6 @@ def get_distance(self, map1, map2): return correlation(map1, map2) -class CorrelationLowMemory(MapToMapDistanceLowMemory): - """Correlation.""" - - def __init__(self, config): - super().__init__(config) - - @override - def compute_cost(self, map1, map2): - return correlation(map1, map2) - - def compute_bioem3d_cost(map1, map2): """ Compute the cost between two maps using the BioEM cost function in 3D. @@ -272,17 +209,6 @@ def get_distance(self, map1, map2): return compute_bioem3d_cost(map1, map2) -class BioEM3dDistanceLowMemory(MapToMapDistanceLowMemory): - """BioEM 3D distance.""" - - def __init__(self, config): - super().__init__(config) - - @override - def compute_cost(self, map1, map2): - return compute_bioem3d_cost(map1, map2) - - def fourier_shell_correlation( x: torch.Tensor, y: torch.Tensor, @@ -457,69 +383,6 @@ def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return self.stored_computed_assets # must run get_distance_matrix first -class FSCDistanceLowMemory(MapToMapDistance): - """Fourier Shell Correlation distance.""" - - def __init__(self, config): - super().__init__(config) - self.n_pix = self.config["data"]["n_pix"] - self.config = config - - def compute_cost(self, map_1, map_2): - raise NotImplementedError() - - @override - def get_distance(self, map1, map2, global_store_of_running_results): - map_gt_flat = map1 = map1.flatten() - map_gt_flat_cube = torch.zeros(self.n_pix**3) - if self.config["data"]["mask"]["do"]: - map_gt_flat = map_gt_flat[self.mask] - map_gt_flat_cube[self.mask] = map_gt_flat - else: - map_gt_flat_cube = map_gt_flat - - corr_vector = fourier_shell_correlation( - map_gt_flat_cube.reshape(self.n_pix, self.n_pix, self.n_pix), - map2.reshape(self.n_pix, self.n_pix, self.n_pix), - ) - dist = 1 - corr_vector.mean() # TODO: spectral cutoff - self.stored_computed_assets = {"corr_vector": corr_vector} - return dist - - @override - def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): - maps_gt_flat = maps1 - maps_user_flat = maps2 - cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) - fsc_matrix = torch.zeros( - len(maps_gt_flat), len(maps_user_flat), self.n_pix // 2 - ) - for idx_gt in range(len(maps_gt_flat)): - for idx_user in range(len(maps_user_flat)): - cost_matrix[idx_gt, idx_user] = self.get_distance( - maps_gt_flat[idx_gt], - maps_user_flat[idx_user], - global_store_of_running_results, - ) - fsc_matrix[idx_gt, idx_user] = self.stored_computed_assets[ - "corr_vector" - ] - self.stored_computed_assets = {"fsc_matrix": fsc_matrix} - return cost_matrix - - @override - def get_computed_assets(self, maps1, maps2, global_store_of_running_results): - """ - Return any computed assets that are needed for (downstream) analysis. - - Notes - ----- - The FSC matrix is stored in the computed assets. - Must run get_distance_matrix first. - """ - return self.stored_computed_assets - - class FSCResDistance(MapToMapDistance): """FSC Resolution distance. @@ -555,15 +418,3 @@ def res_at_fsc_threshold(fscs, threshold=0.5): res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fsc_matrix) self.stored_computed_assets = {"fraction_nyquist": fraction_nyquist} return units_Angstroms[res_fsc_half] - - -class FSCResDistanceLowMemory(FSCResDistance): - """FSC Resolution distance. - - The resolution at which the Fourier Shell Correlation reaches 0.5. - Built on top of the FSCDistance class. This needs to be run first and store the FSC matrix in the computed assets. - """ - - def __init__(self, config): - super().__init__(config) - self.fsc_label = "fsc_low_memory" diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index c85b5d1..c281496 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -7,15 +7,10 @@ from .._map_to_map.map_to_map_distance import ( GT_Dataset, FSCDistance, - FSCDistanceLowMemory, Correlation, - CorrelationLowMemory, L2DistanceNorm, - L2DistanceNormLowMemory, BioEM3dDistance, - BioEM3dDistanceLowMemory, FSCResDistance, - FSCResDistanceLowMemory, ) @@ -27,14 +22,6 @@ "res": FSCResDistance, } -AVAILABLE_MAP2MAP_DISTANCES_LOW_MEMORY = { - "corr_low_memory": CorrelationLowMemory, - "l2_low_memory": L2DistanceNormLowMemory, - "bioem_low_memory": BioEM3dDistanceLowMemory, - "fsc_low_memory": FSCDistanceLowMemory, - "res_low_memory": FSCResDistanceLowMemory, -} - def run(config): """ @@ -47,16 +34,6 @@ def run(config): if distance_label in config["analysis"]["metrics"] } - map_to_map_distances_low_memory = { - distance_label: distance_class(config) - for distance_label, distance_class in AVAILABLE_MAP2MAP_DISTANCES_LOW_MEMORY.items() - if distance_label in config["analysis"]["metrics"] - } - - assert len(map_to_map_distances_low_memory) == 0 or len(map_to_map_distances) == 0 - if len(map_to_map_distances_low_memory) > 0: - map_to_map_distances = map_to_map_distances_low_memory - do_low_memory_mode = config["analysis"]["low_memory"]["do"] n_pix = config["data"]["n_pix"] diff --git a/tests/config_files/test_config_map_to_map_low_memory.yaml b/tests/config_files/test_config_map_to_map_low_memory.yaml deleted file mode 100644 index 35ef921..0000000 --- a/tests/config_files/test_config_map_to_map_low_memory.yaml +++ /dev/null @@ -1,30 +0,0 @@ -data: - n_pix: 16 - psize: 30.044 - submission: - fname: tests/data/dataset_2_submissions/submission_1000.pt - volume_key: volumes - metadata_key: populations - label_key: id - ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy - metadata: tests/data/Ground_truth/test_metadata_10.csv - mask: - do: true - volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc -analysis: - metrics: - - l2_low_memory - - corr_low_memory - - bioem_low_memory - - fsc_low_memory - - res_low_memory - chunk_size_submission: 80 - chunk_size_gt: 190 - low_memory: - do: false - chunk_size_low_memory: null - normalize: - do: true - method: median_zscore -output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml deleted file mode 100644 index 4cfb346..0000000 --- a/tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml +++ /dev/null @@ -1,31 +0,0 @@ -data: - n_pix: 16 - psize: 30.044 - submission: - fname: tests/data/dataset_2_submissions/submission_1000.pt - volume_key: volumes - metadata_key: populations - label_key: id - ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy - metadata: tests/data/Ground_truth/test_metadata_10.csv - mask: - do: false - volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc -analysis: - metrics: - - l2_low_memory - - corr_low_memory - - bioem_low_memory - - fsc_low_memory - - res_low_memory - chunk_size_submission: 80 - chunk_size_gt: 190 - chunk_size_low_memory: 10 - low_memory: - do: true - chunk_size_low_memory: null - normalize: - do: false - method: median_zscore -output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index 2c607e7..907e6d3 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -36,38 +36,3 @@ def test_run_map2map_pipeline(): results_dict[metric]["cost_matrix"].values, results_dict_low_memory[metric]["cost_matrix"].values, ) - - for config_fname, config_fname_low_memory in zip( - [ - "tests/config_files/test_config_map_to_map.yaml", - "tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml", - ], - [ - "tests/config_files/test_config_map_to_map_low_memory.yaml", - "tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml", - ], - ): - args = OmegaConf.create({"config": config_fname}) - results_dict = run_map2map_pipeline.main(args) - - args_low_memory = OmegaConf.create({"config": config_fname_low_memory}) - results_dict_low_memory = run_map2map_pipeline.main(args_low_memory) - for metric in ["fsc", "corr", "l2", "bioem"]: - if metric == "fsc": - np.allclose( - results_dict[metric]["computed_assets"]["fsc_matrix"], - results_dict_low_memory[metric + "_low_memory"]["computed_assets"][ - "fsc_matrix" - ], - ) - elif metric == "res": - np.allclose( - results_dict[metric]["computed_assets"]["fraction_nyquist"], - results_dict_low_memory[metric + "_low_memory"]["computed_assets"][ - "fraction_nyquist" - ], - ) - np.allclose( - results_dict[metric]["cost_matrix"].values, - results_dict_low_memory[metric + "_low_memory"]["cost_matrix"].values, - )