From 2ae38acde490d3707699181b238b477caf3f9d32 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Fri, 12 Jul 2024 11:53:47 -0400 Subject: [PATCH] global_store_of_running_results for res distance from fsc computed assets --- .../_map_to_map/map_to_map_distance_matrix.py | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py index 0ba9a27..528f84d 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py @@ -20,7 +20,7 @@ def __init__(self, config): def get_distance(self, map1, map2): raise NotImplementedError() - def get_distance_matrix(self, maps1, maps2): + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): chunk_size_submission = self.config["analysis"]["chunk_size_submission"] chunk_size_gt = self.config["analysis"]["chunk_size_gt"] distance_matrix = torch.vmap( @@ -33,7 +33,7 @@ def get_distance_matrix(self, maps1, maps2): return distance_matrix.numpy() - def get_computed_assets(self, maps1, maps2): + def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return {} class L2DistanceNorm(MapToMapDistance): @@ -68,7 +68,7 @@ class FSCDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) - def get_distance_matrix(self, maps1, maps2): # custom method + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): # custom method maps_gt_flat = maps1 maps_user_flat = maps2 n_pix = self.config["data"]["n_pix"] @@ -84,24 +84,18 @@ def get_distance_matrix(self, maps1, maps2): # custom method self.stored_computed_assets = {'fsc_matrix': fsc_matrix} return cost_matrix.numpy() - def get_computed_assets(self, maps1, maps2): + def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return self.stored_computed_assets # must run get_distance_matrix first class ResDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) - def get_distance_matrix(self, maps1, maps2): # custom method + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): # custom method # get fsc matrix fourier_pixel_max = 120 psize = 2.146 - fname = 'tests/results/test_map_to_map_distance_matrix_submission_0.pkl' #self.config['external'] - - with open(fname, 'rb') as f: - data = pickle.load(f) - - # fsc_matrix = fscs_sorted = torch.zeros(len(maps1), len(maps2), fourier_pixel_max) - fsc_matrix = data['fsc']['computed_assets']['fsc_matrix'] + fsc_matrix = global_store_of_running_results['fsc']['computed_assets']['fsc_matrix'] units_Angstroms = 2 * psize / (np.arange(1,fourier_pixel_max+1) / fourier_pixel_max) def res_at_fsc_threshold(fscs, threshold=0.5): res_fsc_half = np.argmin(fscs > threshold, axis=-1) @@ -172,10 +166,10 @@ def run(config): print("cost matrix", distance_label) cost_matrix = map_to_map_distance.get_distance_matrix( - maps_gt_flat, maps_user_flat + maps_gt_flat, maps_user_flat, global_store_of_running_results=results_dict, ) computed_assets = map_to_map_distance.get_computed_assets( - maps_gt_flat, maps_user_flat + maps_gt_flat, maps_user_flat, global_store_of_running_results=results_dict, ) computed_assets.update(computed_assets)