Skip to content

Commit

Permalink
remove low memory versions
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffwoollard committed Sep 18, 2024
1 parent cb7085e commit 1f4057e
Show file tree
Hide file tree
Showing 5 changed files with 0 additions and 268 deletions.
149 changes: 0 additions & 149 deletions src/cryo_challenge/_map_to_map/map_to_map_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,47 +122,6 @@ def get_computed_assets(self, maps1, maps2, global_store_of_running_results):
return {}


class MapToMapDistanceLowMemory(MapToMapDistance):
    """Base class for map-to-map distance metrics computed one pair at a time.

    Subclasses implement ``compute_cost``; ``get_distance`` preprocesses the
    first map (optional normalisation and masking, driven by ``config``) and
    ``get_distance_matrix`` fills the pairwise cost matrix with a double loop
    so that only one pair of maps is processed at any time.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

    def compute_cost(self, map_1, map_2):
        """Return the cost between two preprocessed maps (subclass hook).

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError()

    @override
    def get_distance(self, map1, map2, global_store_of_running_results):
        """Preprocess ``map1`` and return ``compute_cost(map1, map2)``.

        ``map1`` is flattened, optionally normalised with the
        ``median_zscore`` method, and optionally restricted to ``self.mask``.
        ``map2`` is passed through untouched.
        ``global_store_of_running_results`` is accepted for interface
        compatibility and is not used here.

        Raises
        ------
        NotImplementedError
            If normalisation is enabled with a method other than
            ``"median_zscore"``.
        """
        map1 = map1.flatten()

        normalize_config = self.config["analysis"]["normalize"]
        if normalize_config["do"]:
            if normalize_config["method"] != "median_zscore":
                raise NotImplementedError(
                    f"Normalization method {self.config['analysis']['normalize']['method']} not implemented."
                )
            # ``flatten`` may return a view of the caller's tensor, so the
            # subtraction is done out of place: the caller's map must not be
            # overwritten. The division can then safely be in place because
            # ``centered`` is a fresh tensor owned by this call.
            centered = map1 - map1.median()
            centered /= centered.std()
            map1 = centered

        if self.config["data"]["mask"]["do"]:
            # NOTE(review): ``self.mask`` is not set in this class; presumably
            # the base class provides it - confirm.
            map1 = map1[self.mask]

        return self.compute_cost(map1, map2)

    @override
    def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):
        """Return the ``len(maps1) x len(maps2)`` matrix of pairwise distances.

        Maps are fetched by integer index (not by iteration) so that any
        container supporting ``len`` and ``__getitem__`` works.
        """
        n_gt = len(maps1)
        n_user = len(maps2)
        cost_matrix = torch.empty(n_gt, n_user)
        for idx_gt in range(n_gt):
            for idx_user in range(n_user):
                cost_matrix[idx_gt, idx_user] = self.get_distance(
                    maps1[idx_gt],
                    maps2[idx_user],
                    global_store_of_running_results,
                )
        return cost_matrix


def norm2(map1, map2):
    """Return the squared norm of the difference ``map1 - map2``."""
    difference = map1 - map2
    return torch.norm(difference) ** 2

Expand All @@ -178,17 +137,6 @@ def get_distance(self, map1, map2):
return norm2(map1, map2)


class L2DistanceNormLowMemory(MapToMapDistanceLowMemory):
    """Squared L2 distance, evaluated pair by pair.

    Construction is inherited unchanged from ``MapToMapDistanceLowMemory``.
    """

    @override
    def compute_cost(self, map1, map2):
        """Return ``norm2(map1, map2)`` for one pair of maps."""
        cost = norm2(map1, map2)
        return cost


def correlation(map1, map2):
    """Return the sum of the element-wise product of ``map1`` and ``map2``."""
    elementwise_product = map1 * map2
    return elementwise_product.sum()

Expand All @@ -206,17 +154,6 @@ def get_distance(self, map1, map2):
return correlation(map1, map2)


class CorrelationLowMemory(MapToMapDistanceLowMemory):
    """Correlation metric, evaluated pair by pair.

    Construction is inherited unchanged from ``MapToMapDistanceLowMemory``.
    """

    @override
    def compute_cost(self, map1, map2):
        """Return ``correlation(map1, map2)`` for one pair of maps."""
        cost = correlation(map1, map2)
        return cost


def compute_bioem3d_cost(map1, map2):
"""
Compute the cost between two maps using the BioEM cost function in 3D.
Expand Down Expand Up @@ -272,17 +209,6 @@ def get_distance(self, map1, map2):
return compute_bioem3d_cost(map1, map2)


class BioEM3dDistanceLowMemory(MapToMapDistanceLowMemory):
    """BioEM 3D cost, evaluated pair by pair.

    Construction is inherited unchanged from ``MapToMapDistanceLowMemory``.
    """

    @override
    def compute_cost(self, map1, map2):
        """Return ``compute_bioem3d_cost(map1, map2)`` for one pair of maps."""
        cost = compute_bioem3d_cost(map1, map2)
        return cost


def fourier_shell_correlation(
x: torch.Tensor,
y: torch.Tensor,
Expand Down Expand Up @@ -457,69 +383,6 @@ def get_computed_assets(self, maps1, maps2, global_store_of_running_results):
return self.stored_computed_assets # must run get_distance_matrix first


class FSCDistanceLowMemory(MapToMapDistance):
    """Fourier Shell Correlation distance, computed one pair of maps at a time.

    The distance for a pair is ``1 - mean(FSC curve)``. The FSC curves of all
    pairs are collected into an ``fsc_matrix`` that is exposed through
    ``get_computed_assets`` once ``get_distance_matrix`` has run.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Read from the argument rather than from ``self.config`` so this
        # does not depend on the base class having already stored the config.
        self.n_pix = config["data"]["n_pix"]

    def compute_cost(self, map_1, map_2):
        """Unused hook kept for interface parity; always raises.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError()

    @override
    def get_distance(self, map1, map2, global_store_of_running_results):
        """Return ``1 - mean(FSC(map1, map2))`` for one pair of maps.

        When masking is enabled in the config, voxels of ``map1`` outside
        ``self.mask`` are zeroed before the FSC is computed. The FSC curve of
        this pair is left in ``self.stored_computed_assets["corr_vector"]``.
        ``global_store_of_running_results`` is accepted for interface
        compatibility and is not used here.
        """
        n_pix = self.n_pix
        map_gt_flat = map1.flatten()

        if self.config["data"]["mask"]["do"]:
            # Allocate the zero-filled cube only when it is actually needed;
            # without masking the flattened map is used directly.
            # NOTE(review): ``self.mask`` is not set in this class; presumably
            # the base class provides it - confirm.
            map_gt_flat_cube = torch.zeros(n_pix**3)
            map_gt_flat_cube[self.mask] = map_gt_flat[self.mask]
        else:
            map_gt_flat_cube = map_gt_flat

        corr_vector = fourier_shell_correlation(
            map_gt_flat_cube.reshape(n_pix, n_pix, n_pix),
            map2.reshape(n_pix, n_pix, n_pix),
        )
        dist = 1 - corr_vector.mean()  # TODO: spectral cutoff
        self.stored_computed_assets = {"corr_vector": corr_vector}
        return dist

    @override
    def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):
        """Return the pairwise FSC distance matrix and store the FSC curves.

        After this call ``self.stored_computed_assets`` holds ``fsc_matrix``
        of shape ``(len(maps1), len(maps2), n_pix // 2)``.
        """
        n_gt = len(maps1)
        n_user = len(maps2)
        cost_matrix = torch.empty(n_gt, n_user)
        # NOTE(review): assumes the FSC curve has ``n_pix // 2`` entries -
        # confirm against ``fourier_shell_correlation``.
        fsc_matrix = torch.zeros(n_gt, n_user, self.n_pix // 2)
        for idx_gt in range(n_gt):
            for idx_user in range(n_user):
                cost_matrix[idx_gt, idx_user] = self.get_distance(
                    maps1[idx_gt],
                    maps2[idx_user],
                    global_store_of_running_results,
                )
                # ``get_distance`` leaves this pair's curve in the store.
                fsc_matrix[idx_gt, idx_user] = self.stored_computed_assets[
                    "corr_vector"
                ]
        self.stored_computed_assets = {"fsc_matrix": fsc_matrix}
        return cost_matrix

    @override
    def get_computed_assets(self, maps1, maps2, global_store_of_running_results):
        """
        Return any computed assets that are needed for (downstream) analysis.

        Notes
        -----
        The FSC matrix is stored in the computed assets.
        Must run get_distance_matrix first.
        """
        return self.stored_computed_assets


class FSCResDistance(MapToMapDistance):
"""FSC Resolution distance.
Expand Down Expand Up @@ -555,15 +418,3 @@ def res_at_fsc_threshold(fscs, threshold=0.5):
res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fsc_matrix)
self.stored_computed_assets = {"fraction_nyquist": fraction_nyquist}
return units_Angstroms[res_fsc_half]


class FSCResDistanceLowMemory(FSCResDistance):
    """FSC resolution distance for the low-memory pipeline.

    The resolution at which the Fourier Shell Correlation reaches 0.5.
    Behaves like ``FSCResDistance`` except that ``fsc_label`` is set to
    ``"fsc_low_memory"``. The FSC distance must be run first so that the
    FSC matrix is available in the computed assets.
    """

    def __init__(self, config):
        super().__init__(config)
        # Label of the FSC metric whose results this class builds on.
        self.fsc_label = "fsc_low_memory"
23 changes: 0 additions & 23 deletions src/cryo_challenge/_map_to_map/map_to_map_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,10 @@
from .._map_to_map.map_to_map_distance import (
GT_Dataset,
FSCDistance,
FSCDistanceLowMemory,
Correlation,
CorrelationLowMemory,
L2DistanceNorm,
L2DistanceNormLowMemory,
BioEM3dDistance,
BioEM3dDistanceLowMemory,
FSCResDistance,
FSCResDistanceLowMemory,
)


Expand All @@ -27,14 +22,6 @@
"res": FSCResDistance,
}

# Registry of low-memory distance classes, keyed by the metric label that is
# looked up in config["analysis"]["metrics"].
AVAILABLE_MAP2MAP_DISTANCES_LOW_MEMORY = {
    "corr_low_memory": CorrelationLowMemory,
    "l2_low_memory": L2DistanceNormLowMemory,
    "bioem_low_memory": BioEM3dDistanceLowMemory,
    "fsc_low_memory": FSCDistanceLowMemory,
    "res_low_memory": FSCResDistanceLowMemory,
}


def run(config):
"""
Expand All @@ -47,16 +34,6 @@ def run(config):
if distance_label in config["analysis"]["metrics"]
}

map_to_map_distances_low_memory = {
distance_label: distance_class(config)
for distance_label, distance_class in AVAILABLE_MAP2MAP_DISTANCES_LOW_MEMORY.items()
if distance_label in config["analysis"]["metrics"]
}

assert len(map_to_map_distances_low_memory) == 0 or len(map_to_map_distances) == 0
if len(map_to_map_distances_low_memory) > 0:
map_to_map_distances = map_to_map_distances_low_memory

do_low_memory_mode = config["analysis"]["low_memory"]["do"]

n_pix = config["data"]["n_pix"]
Expand Down
30 changes: 0 additions & 30 deletions tests/config_files/test_config_map_to_map_low_memory.yaml

This file was deleted.

This file was deleted.

35 changes: 0 additions & 35 deletions tests/test_map_to_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,38 +36,3 @@ def test_run_map2map_pipeline():
results_dict[metric]["cost_matrix"].values,
results_dict_low_memory[metric]["cost_matrix"].values,
)

for config_fname, config_fname_low_memory in zip(
[
"tests/config_files/test_config_map_to_map.yaml",
"tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml",
],
[
"tests/config_files/test_config_map_to_map_low_memory.yaml",
"tests/config_files/test_config_map_to_map_low_memory_nomask_nonormalize.yaml",
],
):
args = OmegaConf.create({"config": config_fname})
results_dict = run_map2map_pipeline.main(args)

args_low_memory = OmegaConf.create({"config": config_fname_low_memory})
results_dict_low_memory = run_map2map_pipeline.main(args_low_memory)
for metric in ["fsc", "corr", "l2", "bioem"]:
if metric == "fsc":
np.allclose(
results_dict[metric]["computed_assets"]["fsc_matrix"],
results_dict_low_memory[metric + "_low_memory"]["computed_assets"][
"fsc_matrix"
],
)
elif metric == "res":
np.allclose(
results_dict[metric]["computed_assets"]["fraction_nyquist"],
results_dict_low_memory[metric + "_low_memory"]["computed_assets"][
"fraction_nyquist"
],
)
np.allclose(
results_dict[metric]["cost_matrix"].values,
results_dict_low_memory[metric + "_low_memory"]["cost_matrix"].values,
)

0 comments on commit 1f4057e

Please sign in to comment.