From 67ebf6f920b81ca8b38718715cce37eaca34caac Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Wed, 10 Jul 2024 10:04:04 -0400 Subject: [PATCH 01/13] new general map to map distance class --- .../_map_to_map/map_to_map_distance_matrix.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py index 0578dfa..7880578 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py @@ -28,6 +28,34 @@ def vmap_distance( )(maps_gt) +class MapToMapDistance: + def __init__(self, config): + self.config = config + + def get_distance(self, map1, map2): + raise NotImplementedError() + + def get_distance_matrix(self, maps1, maps2): + chunk_size_submission = self.config["analysis"]["chunk_size_submission"] + chunk_size_gt = self.config["analysis"]["chunk_size_gt"] + distance_matrix = torch.vmap( + lambda maps1: torch.vmap( + lambda maps2: self.get_distance(maps1, maps2), + chunk_size=chunk_size_submission, + )(maps2), + chunk_size=chunk_size_gt, + )(maps1) + + return distance_matrix + +class L2Distance(MapToMapDistance): + def __init__(self, config): + super().__init__(config) + + def get_distance(self, map1, map2): + return torch.norm(map1 - map2)**2 + + def run(config): """ Compare a submission to ground truth. @@ -99,6 +127,13 @@ def run(config): ) cost_matrix = cost_matrix.numpy() computed_assets["fsc_matrix"] = fsc_matrix + elif cost_label == "nope": + l2_map2map_distance = L2Distance(config) + cost_matrix = l2_map2map_distance.get_distance_matrix( + maps_gt_flat, maps_user_flat + ).numpy() + print('run new method') + else: cost_matrix = vmap_distance( maps_gt_flat, @@ -111,6 +146,7 @@ def run(config): cost_matrix_df = pd.DataFrame( cost_matrix, columns=None, index=metadata_gt.populations.tolist() ) + print(cost_matrix_df) # output results single_distance_results_dict = { From 648a8c91487e28305a193162997118f3c728317c Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Wed, 10 Jul 2024 10:42:41 -0400 Subject: [PATCH 02/13] old and new methods numerically agreeing --- .../_map_to_map/map_to_map_distance_matrix.py | 50 +++++++++++++++---- .../data/_validation/output_validators.py | 6 +++ .../config_files/test_config_map_to_map.yaml | 5 ++ 3 files changed, 51 insertions(+), 10 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py index 7880578..9964069 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py @@ -48,13 +48,33 @@ def get_distance_matrix(self, maps1, maps2): return distance_matrix -class L2Distance(MapToMapDistance): +class L2DistanceNorm(MapToMapDistance): def __init__(self, config): super().__init__(config) def get_distance(self, map1, map2): return torch.norm(map1 - map2)**2 +class L2DistanceSum(MapToMapDistance): + def __init__(self, config): + super().__init__(config) + + def get_distance(self, map1, map2): + return compute_cost_l2(map1, map2) + +class Correlation(MapToMapDistance): + def __init__(self, config): + super().__init__(config) + + def get_distance(self, map1, map2): + return compute_cost_corr(map1, map2) + +class BioEM3dDistance(MapToMapDistance): + def __init__(self, config): + super().__init__(config) + + def get_distance(self, map1, map2): + return compute_bioem3d_cost(map1, map2) def run(config): """ @@ -69,8 +89,7 @@ def run(config): label_key = config["data"]["submission"]["label_key"] user_submission_label = submission[label_key] - # n_trunc = 10 - metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"])#[:n_trunc] + metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"]) results_dict = {} results_dict["config"] = config @@ -80,9 +99,12 @@ def run(config): cost_funcs_d = { "fsc": compute_cost_fsc_chunk, - "corr": compute_cost_corr, - "l2": compute_cost_l2, - "bioem": compute_bioem3d_cost, + "corrold": compute_cost_corr, + "corr": Correlation(config).get_distance_matrix, + "l2old": compute_cost_l2, + "l2": L2DistanceSum(config).get_distance_matrix, + "bioemold": compute_bioem3d_cost, + "bioem": BioEM3dDistance(config).get_distance_matrix, } maps_user_flat = submission[submission_volume_key].reshape( @@ -127,12 +149,20 @@ def run(config): ) cost_matrix = cost_matrix.numpy() computed_assets["fsc_matrix"] = fsc_matrix - elif cost_label == "nope": - l2_map2map_distance = L2Distance(config) - cost_matrix = l2_map2map_distance.get_distance_matrix( + elif cost_label == "l2": + cost_matrix = cost_func( + maps_gt_flat, maps_user_flat + ).numpy() + elif cost_label == "corr": + corr_map2map_distance = Correlation(config) + cost_matrix = corr_map2map_distance.get_distance_matrix( + maps_gt_flat, maps_user_flat + ).numpy() + elif cost_label == "bioem": + bioem_map2map_distance = BioEM3dDistance(config) + cost_matrix = bioem_map2map_distance.get_distance_matrix( maps_gt_flat, maps_user_flat ).numpy() - print('run new method') else: cost_matrix = vmap_distance( diff --git a/src/cryo_challenge/data/_validation/output_validators.py b/src/cryo_challenge/data/_validation/output_validators.py index 35d9791..5a853a5 100644 --- a/src/cryo_challenge/data/_validation/output_validators.py +++ b/src/cryo_challenge/data/_validation/output_validators.py @@ -26,8 +26,11 @@ class MapToMapResultsValidator: config: dict user_submitted_populations: torch.Tensor corr: Optional[dict] = None + corrold: Optional[dict] = None l2: Optional[dict] = None + l2old: Optional[dict] = None bioem: Optional[dict] = None + bioemold: Optional[dict] = None fsc: Optional[dict] = None def __post_init__(self): @@ -142,8 +145,11 @@ class DistributionToDistributionResultsValidator: id: str fsc: Optional[dict] = None bioem: Optional[dict] = None + bioemold: Optional[dict] = None l2: Optional[dict] = None + l2old: Optional[dict] = None corr: Optional[dict] = None + corrold: Optional[dict] = None def __post_init__(self): validate_input_config_disttodist(self.config) diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index 85d663d..17a2223 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -15,6 +15,11 @@ data: analysis: metrics: - l2 + - l2old + - corr + - corrold + - bioem + - bioemold chunk_size_submission: 80 chunk_size_gt: 190 normalize: From 4438b92a3cd5904a9d3ca480f7f5e993de3376be Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Wed, 10 Jul 2024 10:46:18 -0400 Subject: [PATCH 03/13] remove old reference to l2, corr, and bioem3d --- .../_map_to_map/map_to_map_distance_matrix.py | 40 +------------------ .../data/_validation/output_validators.py | 6 --- .../config_files/test_config_map_to_map.yaml | 3 -- 3 files changed, 1 insertion(+), 48 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py index 9964069..426dedf 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py @@ -12,22 +12,6 @@ from ..data._validation.output_validators import MapToMapResultsValidator -def vmap_distance( - maps_gt, - maps_submission, - map_to_map_distance, - chunk_size_gt=None, - chunk_size_submission=None, -): - return torch.vmap( - lambda maps_gt: torch.vmap( - lambda maps_submission: map_to_map_distance(maps_gt, maps_submission), - chunk_size=chunk_size_submission, - )(maps_submission), - chunk_size=chunk_size_gt, - )(maps_gt) - - class MapToMapDistance: def __init__(self, config): self.config = config @@ -99,11 +83,8 @@ def run(config): cost_funcs_d = { "fsc": compute_cost_fsc_chunk, - "corrold": compute_cost_corr, "corr": Correlation(config).get_distance_matrix, - "l2old": compute_cost_l2, "l2": L2DistanceSum(config).get_distance_matrix, - "bioemold": compute_bioem3d_cost, "bioem": BioEM3dDistance(config).get_distance_matrix, } @@ -149,29 +130,10 @@ def run(config): ) cost_matrix = cost_matrix.numpy() computed_assets["fsc_matrix"] = fsc_matrix - elif cost_label == "l2": + else: cost_matrix = cost_func( maps_gt_flat, maps_user_flat ).numpy() - elif cost_label == "corr": - corr_map2map_distance = Correlation(config) - cost_matrix = corr_map2map_distance.get_distance_matrix( - maps_gt_flat, maps_user_flat - ).numpy() - elif cost_label == "bioem": - bioem_map2map_distance = BioEM3dDistance(config) - cost_matrix = bioem_map2map_distance.get_distance_matrix( - maps_gt_flat, maps_user_flat - ).numpy() - - else: - cost_matrix = vmap_distance( - maps_gt_flat, - maps_user_flat, - cost_func, - chunk_size_gt=config["analysis"]["chunk_size_gt"], - chunk_size_submission=config["analysis"]["chunk_size_submission"], - ).numpy() cost_matrix_df = pd.DataFrame( cost_matrix, columns=None, index=metadata_gt.populations.tolist() diff --git a/src/cryo_challenge/data/_validation/output_validators.py b/src/cryo_challenge/data/_validation/output_validators.py index 5a853a5..35d9791 100644 --- a/src/cryo_challenge/data/_validation/output_validators.py +++ b/src/cryo_challenge/data/_validation/output_validators.py @@ -26,11 +26,8 @@ class MapToMapResultsValidator: config: dict user_submitted_populations: torch.Tensor corr: Optional[dict] = None - corrold: Optional[dict] = None l2: Optional[dict] = None - l2old: Optional[dict] = None bioem: Optional[dict] = None - bioemold: Optional[dict] = None fsc: Optional[dict] = None def __post_init__(self): @@ -145,11 +142,8 @@ class DistributionToDistributionResultsValidator: id: str fsc: Optional[dict] = None bioem: Optional[dict] = None - bioemold: Optional[dict] = None l2: Optional[dict] = None - l2old: Optional[dict] = None corr: Optional[dict] = None - corrold: Optional[dict] = None def __post_init__(self): validate_input_config_disttodist(self.config) diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index 17a2223..b2f74bb 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -15,11 +15,8 @@ data: analysis: metrics: - l2 - - l2old - corr - - corrold - bioem - - bioemold chunk_size_submission: 80 chunk_size_gt: 190 normalize: From 59b409de2ff26f1bc6c979adaf0e99f1579a7b45 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Wed, 10 Jul 2024 11:05:30 -0400 Subject: [PATCH 04/13] working... --- .../_map_to_map/map_to_map_distance_matrix.py | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py index 426dedf..cc00ce8 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py @@ -31,6 +31,9 @@ def get_distance_matrix(self, maps1, maps2): )(maps1) return distance_matrix + + def get_computed_assets(self, maps1, maps2): + return {} class L2DistanceNorm(MapToMapDistance): def __init__(self, config): @@ -60,6 +63,13 @@ def __init__(self, config): def get_distance(self, map1, map2): return compute_bioem3d_cost(map1, map2) +class FSCDistance(MapToMapDistance): + def __init__(self, config): + super().__init__(config) + + def get_distance(self, map1, map2): + return compute_cost_fsc_chunk(map1, map2, self.config["data"]["n_pix"]) + def run(config): """ Compare a submission to ground truth. @@ -83,9 +93,9 @@ def run(config): cost_funcs_d = { "fsc": compute_cost_fsc_chunk, - "corr": Correlation(config).get_distance_matrix, - "l2": L2DistanceSum(config).get_distance_matrix, - "bioem": BioEM3dDistance(config).get_distance_matrix, + "corr": Correlation(config), + "l2": L2DistanceSum(config), + "bioem": BioEM3dDistance(config), } maps_user_flat = submission[submission_volume_key].reshape( @@ -114,13 +124,14 @@ def run(config): maps_user_flat /= maps_user_flat.std(dim=1, keepdim=True) computed_assets = {} - for cost_label, cost_func in cost_funcs_d.items(): + for cost_label, map_to_map_distance in cost_funcs_d.items(): if cost_label in config["analysis"]["metrics"]: # TODO: can remove print("cost matrix", cost_label) if ( cost_label == "fsc" ): # TODO: make pydantic (include base class). type hint inputs to this (what it needs like gt volumes and populations) # noqa: E501 + cost_func = map_to_map_distance maps_gt_flat_cube = torch.zeros(len(maps_gt_flat), n_pix**3) maps_gt_flat_cube[:, mask] = maps_gt_flat maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) @@ -131,9 +142,14 @@ def run(config): cost_matrix = cost_matrix.numpy() computed_assets["fsc_matrix"] = fsc_matrix else: - cost_matrix = cost_func( + + cost_matrix = map_to_map_distance.get_distance_matrix( maps_gt_flat, maps_user_flat ).numpy() + computed_assets = map_to_map_distance.get_computed_assets( + maps_gt_flat, maps_user_flat + ) + computed_assets.update(computed_assets) cost_matrix_df = pd.DataFrame( cost_matrix, columns=None, index=metadata_gt.populations.tolist() From 4dddfc030584e415ca2dff9dd2914a7ba2c48274 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Wed, 10 Jul 2024 11:20:18 -0400 Subject: [PATCH 05/13] working... --- .../_map_to_map/map_to_map_distance_matrix.py | 27 +++++++++++-------- .../config_files/test_config_map_to_map.yaml | 5 ++-- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py index cc00ce8..46f7e1c 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py @@ -67,8 +67,18 @@ class FSCDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) - def get_distance(self, map1, map2): - return compute_cost_fsc_chunk(map1, map2, self.config["data"]["n_pix"]) + def get_distance_matrix(self, maps1, maps2): # custom method + maps_gt_flat = maps1 + maps_user_flat = maps2 + n_pix = self.config["data"]["n_pix"] + maps_gt_flat_cube = torch.zeros(len(maps_gt_flat), n_pix**3) + mask = ( + mrcfile.open(self.config["data"]["mask"]["volume"]).data.astype(bool).flatten() + ) + maps_gt_flat_cube[:, mask] = maps_gt_flat + maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) + maps_user_flat_cube[:, mask] = maps_user_flat + return compute_cost_fsc_chunk(maps_gt_flat_cube, maps_user_flat_cube, n_pix) def run(config): """ @@ -92,7 +102,7 @@ def run(config): ) cost_funcs_d = { - "fsc": compute_cost_fsc_chunk, + "fsc": FSCDistance(config), "corr": Correlation(config), "l2": L2DistanceSum(config), "bioem": BioEM3dDistance(config), @@ -104,7 +114,6 @@ def run(config): maps_gt_flat = torch.load(config["data"]["ground_truth"]["volumes"]).reshape( -1, n_pix**3 ) - # maps_gt_flat = torch.randn(n_trunc, n_pix**3) if config["data"]["mask"]["do"]: mask = ( @@ -131,13 +140,9 @@ def run(config): if ( cost_label == "fsc" ): # TODO: make pydantic (include base class). type hint inputs to this (what it needs like gt volumes and populations) # noqa: E501 - cost_func = map_to_map_distance - maps_gt_flat_cube = torch.zeros(len(maps_gt_flat), n_pix**3) - maps_gt_flat_cube[:, mask] = maps_gt_flat - maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) - maps_user_flat_cube[:, mask] = maps_user_flat - cost_matrix, fsc_matrix = cost_func( - maps_gt_flat_cube, maps_user_flat_cube, n_pix + + cost_matrix, fsc_matrix = map_to_map_distance.get_distance_matrix( + maps_gt_flat, maps_user_flat ) cost_matrix = cost_matrix.numpy() computed_assets["fsc_matrix"] = fsc_matrix diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index b2f74bb..09362b5 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -15,8 +15,9 @@ data: analysis: metrics: - l2 - - corr - - bioem + # - corr + # - bioem + - fsc chunk_size_submission: 80 chunk_size_gt: 190 normalize: From e33801a0f0f7f06be09c3387f980fa1efd1bb11b Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Wed, 10 Jul 2024 11:29:58 -0400 Subject: [PATCH 06/13] complted. FSCDistance class --- .../_map_to_map/map_to_map_distance_matrix.py | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py index 46f7e1c..75ce02a 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py @@ -66,7 +66,7 @@ def get_distance(self, map1, map2): class FSCDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) - + def get_distance_matrix(self, maps1, maps2): # custom method maps_gt_flat = maps1 maps_user_flat = maps2 @@ -78,7 +78,13 @@ def get_distance_matrix(self, maps1, maps2): # custom method maps_gt_flat_cube[:, mask] = maps_gt_flat maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) maps_user_flat_cube[:, mask] = maps_user_flat - return compute_cost_fsc_chunk(maps_gt_flat_cube, maps_user_flat_cube, n_pix) + + cost_matrix, fsc_matrix = compute_cost_fsc_chunk(maps_gt_flat_cube, maps_user_flat_cube, n_pix) + self.stored_computed_assets = {'fsc_matrix': fsc_matrix} + return cost_matrix + + def get_computed_assets(self, maps1, maps2): + return self.stored_computed_assets # must run get_distance_matrix first def run(config): """ @@ -137,29 +143,17 @@ def run(config): if cost_label in config["analysis"]["metrics"]: # TODO: can remove print("cost matrix", cost_label) - if ( - cost_label == "fsc" - ): # TODO: make pydantic (include base class). type hint inputs to this (what it needs like gt volumes and populations) # noqa: E501 - - cost_matrix, fsc_matrix = map_to_map_distance.get_distance_matrix( - maps_gt_flat, maps_user_flat - ) - cost_matrix = cost_matrix.numpy() - computed_assets["fsc_matrix"] = fsc_matrix - else: - - cost_matrix = map_to_map_distance.get_distance_matrix( - maps_gt_flat, maps_user_flat - ).numpy() - computed_assets = map_to_map_distance.get_computed_assets( - maps_gt_flat, maps_user_flat - ) - computed_assets.update(computed_assets) + cost_matrix = map_to_map_distance.get_distance_matrix( + maps_gt_flat, maps_user_flat + ).numpy() + computed_assets = map_to_map_distance.get_computed_assets( + maps_gt_flat, maps_user_flat + ) + computed_assets.update(computed_assets) cost_matrix_df = pd.DataFrame( cost_matrix, columns=None, index=metadata_gt.populations.tolist() ) - print(cost_matrix_df) # output results single_distance_results_dict = { From 2730734a19c35db7bd8979b400766142a97a42b0 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Wed, 10 Jul 2024 11:44:13 -0400 Subject: [PATCH 07/13] clean up. rename. remove print --- src/cryo_challenge/_map_to_map/map_to_map_distance.py | 1 - .../_map_to_map/map_to_map_distance_matrix.py | 10 +++++----- tests/config_files/test_config_map_to_map.yaml | 4 ++-- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index e3ec1ba..34e126b 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -133,7 +133,6 @@ def compute_cost_fsc_chunk(maps_gt_flat, maps_user_flat, n_pix): cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) fsc_matrix = torch.zeros(len(maps_gt_flat), len(maps_user_flat), n_pix // 2) for idx in range(len(maps_gt_flat)): - print(idx) corr_vector = fourier_shell_correlation( maps_user_flat.reshape(-1, n_pix, n_pix, n_pix), maps_gt_flat[idx].reshape(n_pix, n_pix, n_pix), diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py index 75ce02a..f424751 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py @@ -107,7 +107,7 @@ def run(config): submission[submission_metadata_key] / submission[submission_metadata_key].sum() ) - cost_funcs_d = { + map_to_map_distances = { "fsc": FSCDistance(config), "corr": Correlation(config), "l2": L2DistanceSum(config), @@ -139,9 +139,9 @@ def run(config): maps_user_flat /= maps_user_flat.std(dim=1, keepdim=True) computed_assets = {} - for cost_label, map_to_map_distance in cost_funcs_d.items(): - if cost_label in config["analysis"]["metrics"]: # TODO: can remove - print("cost matrix", cost_label) + for distance_label, map_to_map_distance in map_to_map_distances.items(): + if distance_label in config["analysis"]["metrics"]: # TODO: can remove + print("cost matrix", distance_label) cost_matrix = map_to_map_distance.get_distance_matrix( maps_gt_flat, maps_user_flat @@ -162,7 +162,7 @@ def run(config): "computed_assets": computed_assets, } - results_dict[cost_label] = single_distance_results_dict + results_dict[distance_label] = single_distance_results_dict # Validate before saving _ = MapToMapResultsValidator.from_dict(results_dict) diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index 09362b5..155bd4f 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -15,8 +15,8 @@ data: analysis: metrics: - l2 - # - corr - # - bioem + - corr + - bioem - fsc chunk_size_submission: 80 chunk_size_gt: 190 From 35f74e332dc47a3dfbe057083b81ee7e0658d33b Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Mon, 12 Aug 2024 11:39:22 -0400 Subject: [PATCH 08/13] docstrings --- src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py index f424751..46d9c04 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py @@ -17,9 +17,11 @@ def __init__(self, config): self.config = config def get_distance(self, map1, map2): + """Compute the distance between two maps.""" raise NotImplementedError() def get_distance_matrix(self, maps1, maps2): + """Compute the distance matrix between two sets of maps.""" chunk_size_submission = self.config["analysis"]["chunk_size_submission"] chunk_size_gt = self.config["analysis"]["chunk_size_gt"] distance_matrix = torch.vmap( @@ -33,6 +35,7 @@ def get_distance_matrix(self, maps1, maps2): return distance_matrix def get_computed_assets(self, maps1, maps2): + """Return any computed assets that are needed for (downstream) analysis.""" return {} class L2DistanceNorm(MapToMapDistance): From ab4576327595a4952c68293caaba0ba0d07052e1 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Mon, 12 Aug 2024 11:43:44 -0400 Subject: [PATCH 09/13] override --- .../_map_to_map/map_to_map_distance_matrix.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py index 46d9c04..40b1ad2 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py @@ -2,6 +2,8 @@ import pandas as pd import pickle import torch +from typing import override + from .map_to_map_distance import ( compute_bioem3d_cost, @@ -42,6 +44,7 @@ class L2DistanceNorm(MapToMapDistance): def __init__(self, config): super().__init__(config) + @override def get_distance(self, map1, map2): return torch.norm(map1 - map2)**2 @@ -49,6 +52,7 @@ class L2DistanceSum(MapToMapDistance): def __init__(self, config): super().__init__(config) + @override def get_distance(self, map1, map2): return compute_cost_l2(map1, map2) @@ -56,6 +60,7 @@ class Correlation(MapToMapDistance): def __init__(self, config): super().__init__(config) + @override def get_distance(self, map1, map2): return compute_cost_corr(map1, map2) @@ -63,13 +68,15 @@ class BioEM3dDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) + @override def get_distance(self, map1, map2): return compute_bioem3d_cost(map1, map2) class FSCDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) - + + @override def get_distance_matrix(self, maps1, maps2): # custom method maps_gt_flat = maps1 maps_user_flat = maps2 @@ -86,6 +93,7 @@ def get_distance_matrix(self, maps1, maps2): # custom method self.stored_computed_assets = {'fsc_matrix': fsc_matrix} return cost_matrix + @override def get_computed_assets(self, maps1, maps2): return self.stored_computed_assets # must run get_distance_matrix first From ed00ecd4f29e9636de372faedd0a65e548315b7f Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Mon, 12 Aug 2024 11:56:36 -0400 Subject: [PATCH 10/13] refactor --- .../_map_to_map/map_to_map_distance.py | 359 +++++++++++------- .../_map_to_map/map_to_map_distance_matrix.py | 110 +----- 2 files changed, 232 insertions(+), 237 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 34e126b..6ad76ce 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -1,143 +1,222 @@ import math import torch -from typing import Optional, Sequence - - -def fourier_shell_correlation( - x: torch.Tensor, - y: torch.Tensor, - dim: Sequence[int] = (-3, -2, -1), - normalize: bool = True, - max_k: Optional[int] = None, -): - """Computes Fourier Shell / Ring Correlation (FSC) between x and y. - - Parameters - ---------- - x : torch.Tensor - First input tensor. - y : torch.Tensor - Second input tensor. - dim : Tuple[int, ...] - Dimensions over which to take the Fourier transform. - normalize : bool - Whether to normalize (i.e. compute correlation) or not (i.e. compute covariance). - Note that when `normalize=False`, we still divide by the number of elements in each shell. - max_k : int - The maximum shell to compute the correlation for. - - Returns - ------- - torch.Tensor - The correlation between x and y for each Fourier shell. - """ # noqa: E501 - batch_shape = x.shape[: -len(dim)] - - freqs = [torch.fft.fftfreq(x.shape[d], d=1 / x.shape[d]).to(x) for d in dim] - freq_total = ( - torch.cartesian_prod(*freqs).view(*[len(f) for f in freqs], -1).norm(dim=-1) - ) - - x_f = torch.fft.fftn(x, dim=dim) - y_f = torch.fft.fftn(y, dim=dim) - - n = min(x.shape[d] for d in dim) - - if max_k is None: - max_k = n // 2 - - result = x.new_zeros(batch_shape + (max_k,)) - - for i in range(1, max_k + 1): - mask = (freq_total >= i - 0.5) & (freq_total < i + 0.5) - x_ri = x_f[..., mask] - y_fi = y_f[..., mask] - - if x.is_cuda: - c_i = torch.linalg.vecdot(x_ri, y_fi).real - else: - # vecdot currently bugged on CPU for torch 2.0 in some configurations - c_i = torch.sum(x_ri * y_fi.conj(), dim=-1).real - - if normalize: - c_i /= torch.linalg.norm(x_ri, dim=-1) * torch.linalg.norm(y_fi, dim=-1) - else: - c_i /= x_ri.shape[-1] - - result[..., i - 1] = c_i - - return result - - -def compute_bioem3d_cost(map1, map2): - """ - Compute the cost between two maps using the BioEM cost function in 3D. - - Notes - ----- - See Eq. 10 in 10.1016/j.jsb.2013.10.006 - - Parameters - ---------- - map1 : torch.Tensor - shape (n_pix,n_pix,n_pix) - map2 : torch.Tensor - shape (n_pix,n_pix,n_pix) - - Returns - ------- - cost : torch.Tensor - shape (1,) - """ - m1, m2 = map1.reshape(-1), map2.reshape(-1) - co = m1.sum() - cc = m2.sum() - coo = m1.pow(2).sum() - ccc = m2.pow(2).sum() - coc = (m1 * m2).sum() - - N = len(m1) - - t1 = 2 * torch.pi * math.exp(1) - t2 = N * (ccc * coo - coc * coc) + 2 * co * coc * cc - ccc * co * co - coo * cc * cc - t3 = (N - 2) * (N * ccc - cc * cc) - - smallest_float = torch.finfo(m1.dtype).tiny - log_prob = ( - 0.5 * torch.pi - + math.log(t1) * (1 - N / 2) - + t2.clamp(smallest_float).log() * (3 / 2 - N / 2) - + t3.clamp(smallest_float).log() * (N / 2 - 2) - ) - cost = -log_prob - return cost - - -def compute_cost_l2(map_1, map_2): - return ((map_1 - map_2) ** 2).sum() - - -def compute_cost_corr(map_1, map_2): - return (map_1 * map_2).sum() - - -def compute_cost_fsc_chunk(maps_gt_flat, maps_user_flat, n_pix): - """ - Compute the cost between two maps using the Fourier Shell Correlation in 3D. - - Notes - ----- - fourier_shell_correlation can only batch on first input set of maps, - so we compute the cost one row (gt map idx) at a time - """ - cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) - fsc_matrix = torch.zeros(len(maps_gt_flat), len(maps_user_flat), n_pix // 2) - for idx in range(len(maps_gt_flat)): - corr_vector = fourier_shell_correlation( - maps_user_flat.reshape(-1, n_pix, n_pix, n_pix), - maps_gt_flat[idx].reshape(n_pix, n_pix, n_pix), +from typing import Optional, Sequence, override +import mrcfile + +class MapToMapDistance: + def __init__(self, config): + self.config = config + + def get_distance(self, map1, map2): + """Compute the distance between two maps.""" + raise NotImplementedError() + + def get_distance_matrix(self, maps1, maps2): + """Compute the distance matrix between two sets of maps.""" + chunk_size_submission = self.config["analysis"]["chunk_size_submission"] + chunk_size_gt = self.config["analysis"]["chunk_size_gt"] + distance_matrix = torch.vmap( + lambda maps1: torch.vmap( + lambda maps2: self.get_distance(maps1, maps2), + chunk_size=chunk_size_submission, + )(maps2), + chunk_size=chunk_size_gt, + )(maps1) + + return distance_matrix + + def get_computed_assets(self, maps1, maps2): + """Return any computed assets that are needed for (downstream) analysis.""" + return {} + +class L2DistanceNorm(MapToMapDistance): + def __init__(self, config): + super().__init__(config) + + @override + def get_distance(self, map1, map2): + return torch.norm(map1 - map2)**2 + +class L2DistanceSum(MapToMapDistance): + def __init__(self, config): + super().__init__(config) + + def compute_cost_l2(map_1, map_2): + return ((map_1 - map_2) ** 2).sum() + + @override + def get_distance(self, map1, map2): + return self.compute_cost_l2(map1, map2) + +class Correlation(MapToMapDistance): + def __init__(self, config): + super().__init__(config) + + def compute_cost_corr(map_1, map_2): + return (map_1 * map_2).sum() + + @override + def get_distance(self, map1, map2): + return self.compute_cost_corr(map1, map2) + +class BioEM3dDistance(MapToMapDistance): + def __init__(self, config): + super().__init__(config) + + def compute_bioem3d_cost(map1, map2): + """ + Compute the cost between two maps using the BioEM cost function in 3D. + + Notes + ----- + See Eq. 10 in 10.1016/j.jsb.2013.10.006 + + Parameters + ---------- + map1 : torch.Tensor + shape (n_pix,n_pix,n_pix) + map2 : torch.Tensor + shape (n_pix,n_pix,n_pix) + + Returns + ------- + cost : torch.Tensor + shape (1,) + """ + m1, m2 = map1.reshape(-1), map2.reshape(-1) + co = m1.sum() + cc = m2.sum() + coo = m1.pow(2).sum() + ccc = m2.pow(2).sum() + coc = (m1 * m2).sum() + + N = len(m1) + + t1 = 2 * torch.pi * math.exp(1) + t2 = N * (ccc * coo - coc * coc) + 2 * co * coc * cc - ccc * co * co - coo * cc * cc + t3 = (N - 2) * (N * ccc - cc * cc) + + smallest_float = torch.finfo(m1.dtype).tiny + log_prob = ( + 0.5 * torch.pi + + math.log(t1) * (1 - N / 2) + + t2.clamp(smallest_float).log() * (3 / 2 - N / 2) + + t3.clamp(smallest_float).log() * (N / 2 - 2) ) - dist = 1 - corr_vector.mean(dim=1) # TODO: spectral cutoff - fsc_matrix[idx] = corr_vector - cost_matrix[idx] = dist - return cost_matrix, fsc_matrix + cost = -log_prob + return cost + + @override + def get_distance(self, map1, map2): + return self.compute_bioem3d_cost(map1, map2) + +class FSCDistance(MapToMapDistance): + def __init__(self, config): + super().__init__(config) + + def fourier_shell_correlation( + x: torch.Tensor, + y: torch.Tensor, + dim: Sequence[int] = (-3, -2, -1), + normalize: bool = True, + max_k: Optional[int] = None, + ): + """Computes Fourier Shell / Ring Correlation (FSC) between x and y. + + Parameters + ---------- + x : torch.Tensor + First input tensor. + y : torch.Tensor + Second input tensor. + dim : Tuple[int, ...] + Dimensions over which to take the Fourier transform. + normalize : bool + Whether to normalize (i.e. compute correlation) or not (i.e. compute covariance). + Note that when `normalize=False`, we still divide by the number of elements in each shell. + max_k : int + The maximum shell to compute the correlation for. + + Returns + ------- + torch.Tensor + The correlation between x and y for each Fourier shell. + """ # noqa: E501 + batch_shape = x.shape[: -len(dim)] + + freqs = [torch.fft.fftfreq(x.shape[d], d=1 / x.shape[d]).to(x) for d in dim] + freq_total = ( + torch.cartesian_prod(*freqs).view(*[len(f) for f in freqs], -1).norm(dim=-1) + ) + + x_f = torch.fft.fftn(x, dim=dim) + y_f = torch.fft.fftn(y, dim=dim) + + n = min(x.shape[d] for d in dim) + + if max_k is None: + max_k = n // 2 + + result = x.new_zeros(batch_shape + (max_k,)) + + for i in range(1, max_k + 1): + mask = (freq_total >= i - 0.5) & (freq_total < i + 0.5) + x_ri = x_f[..., mask] + y_fi = y_f[..., mask] + + if x.is_cuda: + c_i = torch.linalg.vecdot(x_ri, y_fi).real + else: + # vecdot currently bugged on CPU for torch 2.0 in some configurations + c_i = torch.sum(x_ri * y_fi.conj(), dim=-1).real + + if normalize: + c_i /= torch.linalg.norm(x_ri, dim=-1) * torch.linalg.norm(y_fi, dim=-1) + else: + c_i /= x_ri.shape[-1] + + result[..., i - 1] = c_i + + return result + + def compute_cost_fsc_chunk(maps_gt_flat, maps_user_flat, n_pix): + """ + Compute the cost between two maps using the Fourier Shell Correlation in 3D. + + Notes + ----- + fourier_shell_correlation can only batch on first input set of maps, + so we compute the cost one row (gt map idx) at a time + """ + cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) + fsc_matrix = torch.zeros(len(maps_gt_flat), len(maps_user_flat), n_pix // 2) + for idx in range(len(maps_gt_flat)): + corr_vector = self.fourier_shell_correlation( + maps_user_flat.reshape(-1, n_pix, n_pix, n_pix), + maps_gt_flat[idx].reshape(n_pix, n_pix, n_pix), + ) + dist = 1 - corr_vector.mean(dim=1) # TODO: spectral cutoff + fsc_matrix[idx] = corr_vector + cost_matrix[idx] = dist + return cost_matrix, fsc_matrix + + @override + def get_distance_matrix(self, maps1, maps2): # custom method + maps_gt_flat = maps1 + maps_user_flat = maps2 + n_pix = self.config["data"]["n_pix"] + maps_gt_flat_cube = torch.zeros(len(maps_gt_flat), n_pix**3) + mask = ( + mrcfile.open(self.config["data"]["mask"]["volume"]).data.astype(bool).flatten() + ) + maps_gt_flat_cube[:, mask] = maps_gt_flat + maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) + maps_user_flat_cube[:, mask] = maps_user_flat + + cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk(maps_gt_flat_cube, maps_user_flat_cube, n_pix) + self.stored_computed_assets = {'fsc_matrix': fsc_matrix} + return cost_matrix + + @override + def get_computed_assets(self, maps1, maps2): + return self.stored_computed_assets # must run get_distance_matrix first \ No newline at end of file diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py index 40b1ad2..e39942f 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py @@ -2,106 +2,28 @@ import pandas as pd import pickle import torch -from typing import override - -from .map_to_map_distance import ( - compute_bioem3d_cost, - compute_cost_corr, - compute_cost_fsc_chunk, - compute_cost_l2, -) from ..data._validation.output_validators import MapToMapResultsValidator +from .._map_to_map.map_to_map_distance import FSCDistance, Correlation, L2DistanceSum, BioEM3dDistance -class MapToMapDistance: - def __init__(self, config): - self.config = config - - def get_distance(self, map1, map2): - """Compute the distance between two maps.""" - raise NotImplementedError() - - def get_distance_matrix(self, maps1, maps2): - """Compute the distance matrix between two sets of maps.""" - chunk_size_submission = self.config["analysis"]["chunk_size_submission"] - chunk_size_gt = self.config["analysis"]["chunk_size_gt"] - distance_matrix = torch.vmap( - lambda maps1: torch.vmap( - lambda maps2: self.get_distance(maps1, maps2), - chunk_size=chunk_size_submission, - )(maps2), - chunk_size=chunk_size_gt, - )(maps1) - - return distance_matrix - - def get_computed_assets(self, maps1, maps2): - """Return any computed assets that are needed for (downstream) analysis.""" - return {} - -class L2DistanceNorm(MapToMapDistance): - def __init__(self, config): - super().__init__(config) - - @override - def get_distance(self, map1, map2): - return torch.norm(map1 - map2)**2 - -class L2DistanceSum(MapToMapDistance): - def __init__(self, config): - super().__init__(config) - - @override - def get_distance(self, map1, map2): - return compute_cost_l2(map1, map2) - -class Correlation(MapToMapDistance): - def __init__(self, config): - super().__init__(config) - - @override - def get_distance(self, map1, map2): - return compute_cost_corr(map1, map2) - -class BioEM3dDistance(MapToMapDistance): - def __init__(self, config): - super().__init__(config) - - @override - def get_distance(self, map1, map2): - return compute_bioem3d_cost(map1, map2) - -class FSCDistance(MapToMapDistance): - def __init__(self, config): - super().__init__(config) - - @override - def get_distance_matrix(self, maps1, maps2): # custom method - maps_gt_flat = maps1 - maps_user_flat = maps2 - n_pix = self.config["data"]["n_pix"] - maps_gt_flat_cube = torch.zeros(len(maps_gt_flat), n_pix**3) - mask = ( - mrcfile.open(self.config["data"]["mask"]["volume"]).data.astype(bool).flatten() - ) - maps_gt_flat_cube[:, mask] = maps_gt_flat - maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) - maps_user_flat_cube[:, mask] = maps_user_flat - - cost_matrix, fsc_matrix = compute_cost_fsc_chunk(maps_gt_flat_cube, maps_user_flat_cube, n_pix) - self.stored_computed_assets = {'fsc_matrix': fsc_matrix} - return cost_matrix - - @override - def get_computed_assets(self, maps1, maps2): - return self.stored_computed_assets # must run get_distance_matrix first - +AVAILABLE_MAP2MAP_DISTANCES = { + "fsc": FSCDistance, + "corr": Correlation, + "l2": L2DistanceSum, + "bioem": BioEM3dDistance, + } + def run(config): """ Compare a submission to ground truth. """ + map_to_map_distances = { + distance_label: distance_class(config) + for distance_label, distance_class in AVAILABLE_MAP2MAP_DISTANCES.items() + } + n_pix = config["data"]["n_pix"] submission = torch.load(config["data"]["submission"]["fname"]) @@ -118,12 +40,6 @@ def run(config): submission[submission_metadata_key] / submission[submission_metadata_key].sum() ) - map_to_map_distances = { - "fsc": FSCDistance(config), - "corr": Correlation(config), - "l2": L2DistanceSum(config), - "bioem": BioEM3dDistance(config), - } maps_user_flat = submission[submission_volume_key].reshape( len(submission["volumes"]), -1 From 3988c47baad753d3685bf97875e4e0785a25a4d1 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Mon, 12 Aug 2024 12:32:59 -0400 Subject: [PATCH 11/13] passing tests --- .../_map_to_map/map_to_map_distance.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 6ad76ce..9b62543 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -1,6 +1,6 @@ import math import torch -from typing import Optional, Sequence, override +from typing import Optional, Sequence import mrcfile class MapToMapDistance: @@ -33,7 +33,6 @@ class L2DistanceNorm(MapToMapDistance): def __init__(self, config): super().__init__(config) - @override def get_distance(self, map1, map2): return torch.norm(map1 - map2)**2 @@ -41,10 +40,9 @@ class L2DistanceSum(MapToMapDistance): def __init__(self, config): super().__init__(config) - def compute_cost_l2(map_1, map_2): + def compute_cost_l2(self, map_1, map_2): return ((map_1 - map_2) ** 2).sum() - @override def get_distance(self, map1, map2): return self.compute_cost_l2(map1, map2) @@ -52,10 +50,9 @@ class Correlation(MapToMapDistance): def __init__(self, config): super().__init__(config) - def compute_cost_corr(map_1, map_2): + def compute_cost_corr(self, map_1, map_2): return (map_1 * map_2).sum() - @override def get_distance(self, map1, map2): return self.compute_cost_corr(map1, map2) @@ -63,7 +60,7 @@ class BioEM3dDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) - def compute_bioem3d_cost(map1, map2): + def compute_bioem3d_cost(self, map1, map2): """ Compute the cost between two maps using the BioEM cost function in 3D. @@ -106,7 +103,6 @@ def compute_bioem3d_cost(map1, map2): cost = -log_prob return cost - @override def get_distance(self, map1, map2): return self.compute_bioem3d_cost(map1, map2) @@ -115,6 +111,7 @@ def __init__(self, config): super().__init__(config) def fourier_shell_correlation( + self, x: torch.Tensor, y: torch.Tensor, dim: Sequence[int] = (-3, -2, -1), @@ -179,7 +176,7 @@ def fourier_shell_correlation( return result - def compute_cost_fsc_chunk(maps_gt_flat, maps_user_flat, n_pix): + def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): """ Compute the cost between two maps using the Fourier Shell Correlation in 3D. @@ -200,7 +197,6 @@ def compute_cost_fsc_chunk(maps_gt_flat, maps_user_flat, n_pix): cost_matrix[idx] = dist return cost_matrix, fsc_matrix - @override def get_distance_matrix(self, maps1, maps2): # custom method maps_gt_flat = maps1 maps_user_flat = maps2 @@ -217,6 +213,5 @@ def get_distance_matrix(self, maps1, maps2): # custom method self.stored_computed_assets = {'fsc_matrix': fsc_matrix} return cost_matrix - @override def get_computed_assets(self, maps1, maps2): return self.stored_computed_assets # must run get_distance_matrix first \ No newline at end of file From 9e323a9c960ce4a1c9fc99bf65f604e3b6560036 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Mon, 12 Aug 2024 12:37:59 -0400 Subject: [PATCH 12/13] rename config --- ...g_map_to_map_distance_matrix.yaml => config_map_to_map.yaml} | 0 src/cryo_challenge/_commands/run_map2map_pipeline.py | 2 +- .../{map_to_map_distance_matrix.py => map_to_map_pipeline.py} | 0 3 files changed, 1 insertion(+), 1 deletion(-) rename config_files/{config_map_to_map_distance_matrix.yaml => config_map_to_map.yaml} (100%) rename src/cryo_challenge/_map_to_map/{map_to_map_distance_matrix.py => map_to_map_pipeline.py} (100%) diff --git a/config_files/config_map_to_map_distance_matrix.yaml b/config_files/config_map_to_map.yaml similarity index 100% rename from config_files/config_map_to_map_distance_matrix.yaml rename to config_files/config_map_to_map.yaml diff --git a/src/cryo_challenge/_commands/run_map2map_pipeline.py b/src/cryo_challenge/_commands/run_map2map_pipeline.py index 0f3d85f..3cf602d 100644 --- a/src/cryo_challenge/_commands/run_map2map_pipeline.py +++ b/src/cryo_challenge/_commands/run_map2map_pipeline.py @@ -6,7 +6,7 @@ import os import yaml -from .._map_to_map.map_to_map_distance_matrix import run +from .._map_to_map.map_to_map_pipeline import run from ..data._validation.config_validators import validate_input_config_mtm diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py similarity index 100% rename from src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py rename to src/cryo_challenge/_map_to_map/map_to_map_pipeline.py From 3d3c8f1a1a922f33ccaf08977260c3f7dfabe187 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Mon, 12 Aug 2024 12:42:42 -0400 Subject: [PATCH 13/13] @override from typing_extensions --- .../_map_to_map/map_to_map_distance.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 9b62543..55d01d3 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -1,6 +1,7 @@ import math import torch from typing import Optional, Sequence +from typing_extensions import override import mrcfile class MapToMapDistance: @@ -33,6 +34,7 @@ class L2DistanceNorm(MapToMapDistance): def __init__(self, config): super().__init__(config) + @override def get_distance(self, map1, map2): return torch.norm(map1 - map2)**2 @@ -42,7 +44,8 @@ def __init__(self, config): def compute_cost_l2(self, map_1, map_2): return ((map_1 - map_2) ** 2).sum() - + + @override def get_distance(self, map1, map2): return self.compute_cost_l2(map1, map2) @@ -52,7 +55,8 @@ def __init__(self, config): def compute_cost_corr(self, map_1, map_2): return (map_1 * map_2).sum() - + + @override def get_distance(self, map1, map2): return self.compute_cost_corr(map1, map2) @@ -102,7 +106,8 @@ def compute_bioem3d_cost(self, map1, map2): ) cost = -log_prob return cost - + + @override def get_distance(self, map1, map2): return self.compute_bioem3d_cost(map1, map2) @@ -197,6 +202,7 @@ def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): cost_matrix[idx] = dist return cost_matrix, fsc_matrix + @override def get_distance_matrix(self, maps1, maps2): # custom method maps_gt_flat = maps1 maps_user_flat = maps2 @@ -212,6 +218,7 @@ def get_distance_matrix(self, maps1, maps2): # custom method cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk(maps_gt_flat_cube, maps_user_flat_cube, n_pix) self.stored_computed_assets = {'fsc_matrix': fsc_matrix} return cost_matrix - + + @override def get_computed_assets(self, maps1, maps2): return self.stored_computed_assets # must run get_distance_matrix first \ No newline at end of file