Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolution metric #85

Merged
merged 7 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions config_files/config_map_to_map.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ analysis:
- corr
- bioem
- fsc
- res
chunk_size_submission: 80
chunk_size_gt: 190
normalize:
Expand Down
2 changes: 1 addition & 1 deletion src/cryo_challenge/_commands/run_map2map_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,5 @@ def main(args):

if __name__ == "__main__":
parser = argparse.ArgumentParser(description=__doc__)
args = parser.parse_args()
# args = parser.parse_args()
main(add_args(parser).parse_args())
41 changes: 37 additions & 4 deletions src/cryo_challenge/_map_to_map/map_to_map_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional, Sequence
from typing_extensions import override
import mrcfile
import numpy as np


class MapToMapDistance:
Expand All @@ -13,7 +14,7 @@ def get_distance(self, map1, map2):
"""Compute the distance between two maps."""
raise NotImplementedError()

def get_distance_matrix(self, maps1, maps2):
def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):
"""Compute the distance matrix between two sets of maps."""
chunk_size_submission = self.config["analysis"]["chunk_size_submission"]
chunk_size_gt = self.config["analysis"]["chunk_size_gt"]
Expand All @@ -27,7 +28,7 @@ def get_distance_matrix(self, maps1, maps2):

return distance_matrix

def get_computed_assets(self, maps1, maps2):
def get_computed_assets(self, maps1, maps2, global_store_of_running_results):
"""Return any computed assets that are needed for (downstream) analysis."""
return {}

Expand Down Expand Up @@ -214,7 +215,9 @@ def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix):
return cost_matrix, fsc_matrix

@override
def get_distance_matrix(self, maps1, maps2): # custom method
def get_distance_matrix(
self, maps1, maps2, global_store_of_running_results
): # custom method
maps_gt_flat = maps1
maps_user_flat = maps2
n_pix = self.config["data"]["n_pix"]
Expand All @@ -235,5 +238,35 @@ def get_distance_matrix(self, maps1, maps2): # custom method
return cost_matrix

@override
def get_computed_assets(self, maps1, maps2):
def get_computed_assets(self, maps1, maps2, global_store_of_running_results):
return self.stored_computed_assets # must run get_distance_matrix first


class ResDistance(MapToMapDistance):
def __init__(self, config):
super().__init__(config)

@override
def get_distance_matrix(
self, maps1, maps2, global_store_of_running_results
): # custom method
# get fsc matrix
fourier_pixel_max = (
self.config["data"]["n_pix"] // 2
) # TODO: check for odd psizes if this should be +1
psize = self.config["data"]["psize"]
fsc_matrix = global_store_of_running_results["fsc"]["computed_assets"][
"fsc_matrix"
]
units_Angstroms = (
2 * psize / (np.arange(1, fourier_pixel_max + 1) / fourier_pixel_max)
)

def res_at_fsc_threshold(fscs, threshold=0.5):
res_fsc_half = np.argmin(fscs > threshold, axis=-1)
fraction_nyquist = 0.5 * res_fsc_half / fscs.shape[-1]
return res_fsc_half, fraction_nyquist

res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fsc_matrix)
self.stored_computed_assets = {"fraction_nyquist": fraction_nyquist}
return units_Angstroms[res_fsc_half]
12 changes: 9 additions & 3 deletions src/cryo_challenge/_map_to_map/map_to_map_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Correlation,
L2DistanceSum,
BioEM3dDistance,
ResDistance,
)


Expand All @@ -17,6 +18,7 @@
"corr": Correlation,
"l2": L2DistanceSum,
"bioem": BioEM3dDistance,
"res": ResDistance,
}


Expand Down Expand Up @@ -76,10 +78,14 @@ def run(config):
print("cost matrix", distance_label)

cost_matrix = map_to_map_distance.get_distance_matrix(
maps_gt_flat, maps_user_flat
).numpy()
maps_gt_flat,
maps_user_flat,
global_store_of_running_results=results_dict,
)
computed_assets = map_to_map_distance.get_computed_assets(
maps_gt_flat, maps_user_flat
maps_gt_flat,
maps_user_flat,
global_store_of_running_results=results_dict,
)
computed_assets.update(computed_assets)

Expand Down
2 changes: 2 additions & 0 deletions src/cryo_challenge/data/_validation/output_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class MapToMapResultsValidator:
l2: Optional[dict] = None
bioem: Optional[dict] = None
fsc: Optional[dict] = None
res: Optional[dict] = None

def __post_init__(self):
validate_input_config_mtm(self.config)
Expand Down Expand Up @@ -147,6 +148,7 @@ class DistributionToDistributionResultsValidator:
id: str
fsc: Optional[dict] = None
bioem: Optional[dict] = None
res: Optional[dict] = None
l2: Optional[dict] = None
corr: Optional[dict] = None

Expand Down
1 change: 1 addition & 0 deletions tests/config_files/test_config_map_to_map.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ analysis:
- corr
- bioem
- fsc
- res
chunk_size_submission: 80
chunk_size_gt: 190
normalize:
Expand Down