Skip to content

Commit

Permalink
Merge pull request #85 from flatironinstitute/resolution_metric
Browse files Browse the repository at this point in the history
Resolution metric
  • Loading branch information
DSilva27 authored Aug 12, 2024
2 parents 62563ef + 1cd0225 commit d1381c6
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 8 deletions.
1 change: 1 addition & 0 deletions config_files/config_map_to_map.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ analysis:
- corr
- bioem
- fsc
- res
chunk_size_submission: 80
chunk_size_gt: 190
normalize:
Expand Down
2 changes: 1 addition & 1 deletion src/cryo_challenge/_commands/run_map2map_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,5 @@ def main(args):

if __name__ == "__main__":
parser = argparse.ArgumentParser(description=__doc__)
args = parser.parse_args()
# args = parser.parse_args()
main(add_args(parser).parse_args())
41 changes: 37 additions & 4 deletions src/cryo_challenge/_map_to_map/map_to_map_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional, Sequence
from typing_extensions import override
import mrcfile
import numpy as np


class MapToMapDistance:
Expand All @@ -13,7 +14,7 @@ def get_distance(self, map1, map2):
"""Compute the distance between two maps."""
raise NotImplementedError()

def get_distance_matrix(self, maps1, maps2):
def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):
"""Compute the distance matrix between two sets of maps."""
chunk_size_submission = self.config["analysis"]["chunk_size_submission"]
chunk_size_gt = self.config["analysis"]["chunk_size_gt"]
Expand All @@ -27,7 +28,7 @@ def get_distance_matrix(self, maps1, maps2):

return distance_matrix

def get_computed_assets(self, maps1, maps2):
def get_computed_assets(self, maps1, maps2, global_store_of_running_results):
"""Return any computed assets that are needed for (downstream) analysis."""
return {}

Expand Down Expand Up @@ -214,7 +215,9 @@ def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix):
return cost_matrix, fsc_matrix

@override
def get_distance_matrix(self, maps1, maps2): # custom method
def get_distance_matrix(
self, maps1, maps2, global_store_of_running_results
): # custom method
maps_gt_flat = maps1
maps_user_flat = maps2
n_pix = self.config["data"]["n_pix"]
Expand All @@ -235,5 +238,35 @@ def get_distance_matrix(self, maps1, maps2): # custom method
return cost_matrix

@override
def get_computed_assets(self, maps1, maps2):
def get_computed_assets(self, maps1, maps2, global_store_of_running_results):
return self.stored_computed_assets # must run get_distance_matrix first


class ResDistance(MapToMapDistance):
def __init__(self, config):
super().__init__(config)

@override
def get_distance_matrix(
self, maps1, maps2, global_store_of_running_results
): # custom method
# get fsc matrix
fourier_pixel_max = (
self.config["data"]["n_pix"] // 2
) # TODO: check for odd psizes if this should be +1
psize = self.config["data"]["psize"]
fsc_matrix = global_store_of_running_results["fsc"]["computed_assets"][
"fsc_matrix"
]
units_Angstroms = (
2 * psize / (np.arange(1, fourier_pixel_max + 1) / fourier_pixel_max)
)

def res_at_fsc_threshold(fscs, threshold=0.5):
res_fsc_half = np.argmin(fscs > threshold, axis=-1)
fraction_nyquist = 0.5 * res_fsc_half / fscs.shape[-1]
return res_fsc_half, fraction_nyquist

res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fsc_matrix)
self.stored_computed_assets = {"fraction_nyquist": fraction_nyquist}
return units_Angstroms[res_fsc_half]
12 changes: 9 additions & 3 deletions src/cryo_challenge/_map_to_map/map_to_map_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Correlation,
L2DistanceSum,
BioEM3dDistance,
ResDistance,
)


Expand All @@ -17,6 +18,7 @@
"corr": Correlation,
"l2": L2DistanceSum,
"bioem": BioEM3dDistance,
"res": ResDistance,
}


Expand Down Expand Up @@ -76,10 +78,14 @@ def run(config):
print("cost matrix", distance_label)

cost_matrix = map_to_map_distance.get_distance_matrix(
maps_gt_flat, maps_user_flat
).numpy()
maps_gt_flat,
maps_user_flat,
global_store_of_running_results=results_dict,
)
computed_assets = map_to_map_distance.get_computed_assets(
maps_gt_flat, maps_user_flat
maps_gt_flat,
maps_user_flat,
global_store_of_running_results=results_dict,
)
computed_assets.update(computed_assets)

Expand Down
2 changes: 2 additions & 0 deletions src/cryo_challenge/data/_validation/output_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class MapToMapResultsValidator:
l2: Optional[dict] = None
bioem: Optional[dict] = None
fsc: Optional[dict] = None
res: Optional[dict] = None

def __post_init__(self):
validate_input_config_mtm(self.config)
Expand Down Expand Up @@ -147,6 +148,7 @@ class DistributionToDistributionResultsValidator:
id: str
fsc: Optional[dict] = None
bioem: Optional[dict] = None
res: Optional[dict] = None
l2: Optional[dict] = None
corr: Optional[dict] = None

Expand Down
1 change: 1 addition & 0 deletions tests/config_files/test_config_map_to_map.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ analysis:
- corr
- bioem
- fsc
- res
chunk_size_submission: 80
chunk_size_gt: 190
normalize:
Expand Down

0 comments on commit d1381c6

Please sign in to comment.