Skip to content

Commit

Permalink
global_store_of_running_results for res distance from fsc computed as…
Browse files Browse the repository at this point in the history
…sets
  • Loading branch information
geoffwoollard committed Jul 12, 2024
1 parent f863466 commit 2ae38ac
Showing 1 changed file with 8 additions and 14 deletions.
22 changes: 8 additions & 14 deletions src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, config):
def get_distance(self, map1, map2):
raise NotImplementedError()

def get_distance_matrix(self, maps1, maps2):
def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):
chunk_size_submission = self.config["analysis"]["chunk_size_submission"]
chunk_size_gt = self.config["analysis"]["chunk_size_gt"]
distance_matrix = torch.vmap(
Expand All @@ -33,7 +33,7 @@ def get_distance_matrix(self, maps1, maps2):

return distance_matrix.numpy()

def get_computed_assets(self, maps1, maps2):
def get_computed_assets(self, maps1, maps2, global_store_of_running_results):
return {}

class L2DistanceNorm(MapToMapDistance):
Expand Down Expand Up @@ -68,7 +68,7 @@ class FSCDistance(MapToMapDistance):
def __init__(self, config):
super().__init__(config)

def get_distance_matrix(self, maps1, maps2): # custom method
def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): # custom method
maps_gt_flat = maps1
maps_user_flat = maps2
n_pix = self.config["data"]["n_pix"]
Expand All @@ -84,24 +84,18 @@ def get_distance_matrix(self, maps1, maps2): # custom method
self.stored_computed_assets = {'fsc_matrix': fsc_matrix}
return cost_matrix.numpy()

def get_computed_assets(self, maps1, maps2):
def get_computed_assets(self, maps1, maps2, global_store_of_running_results):
return self.stored_computed_assets # must run get_distance_matrix first

class ResDistance(MapToMapDistance):
def __init__(self, config):
super().__init__(config)

def get_distance_matrix(self, maps1, maps2): # custom method
def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): # custom method
# get fsc matrix
fourier_pixel_max = 120
psize = 2.146
fname = 'tests/results/test_map_to_map_distance_matrix_submission_0.pkl' #self.config['external']

with open(fname, 'rb') as f:
data = pickle.load(f)

# fsc_matrix = fscs_sorted = torch.zeros(len(maps1), len(maps2), fourier_pixel_max)
fsc_matrix = data['fsc']['computed_assets']['fsc_matrix']
fsc_matrix = global_store_of_running_results['fsc']['computed_assets']['fsc_matrix']
units_Angstroms = 2 * psize / (np.arange(1,fourier_pixel_max+1) / fourier_pixel_max)
def res_at_fsc_threshold(fscs, threshold=0.5):
res_fsc_half = np.argmin(fscs > threshold, axis=-1)
Expand Down Expand Up @@ -172,10 +166,10 @@ def run(config):
print("cost matrix", distance_label)

cost_matrix = map_to_map_distance.get_distance_matrix(
maps_gt_flat, maps_user_flat
maps_gt_flat, maps_user_flat, global_store_of_running_results=results_dict,
)
computed_assets = map_to_map_distance.get_computed_assets(
maps_gt_flat, maps_user_flat
maps_gt_flat, maps_user_flat, global_store_of_running_results=results_dict,
)
computed_assets.update(computed_assets)

Expand Down

0 comments on commit 2ae38ac

Please sign in to comment.