diff --git a/src/cryo_challenge/_commands/run_map2map_pipeline.py b/src/cryo_challenge/_commands/run_map2map_pipeline.py index ab36f7a..90db1aa 100644 --- a/src/cryo_challenge/_commands/run_map2map_pipeline.py +++ b/src/cryo_challenge/_commands/run_map2map_pipeline.py @@ -39,9 +39,7 @@ def main(args): warnexists(config["output"]) mkbasedir(os.path.dirname(config["output"])) - run(config) - - return + return run(config) if __name__ == "__main__": diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index e253d25..5b9dc8e 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -4,28 +4,117 @@ from typing_extensions import override import mrcfile import numpy as np +from torch.utils.data import Dataset + + +class GT_Dataset(Dataset): + def __init__(self, npy_file): + self.npy_file = npy_file + self.data = np.load(npy_file, mmap_mode="r+") + + self.shape = self.data.shape + self._dim = len(self.data.shape) + + def dim(self): + return self._dim + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + sample = self.data[idx] + return torch.from_numpy(sample.copy()) + + +def normalize(maps, method): + if method == "median_zscore": + maps -= maps.median(dim=1, keepdim=True).values + maps /= maps.std(dim=1, keepdim=True) + else: + raise NotImplementedError(f"Normalization method {method} not implemented.") + return maps class MapToMapDistance: def __init__(self, config): self.config = config + self.do_low_memory_mode = self.config["analysis"]["low_memory"]["do"] + self.chunk_size_gt = self.config["analysis"]["chunk_size_gt"] + self.chunk_size_submission = self.config["analysis"]["chunk_size_submission"] + self.n_pix = self.config["data"]["n_pix"] + self.chunk_size_low_memory = self.config["analysis"]["low_memory"][ + "chunk_size_low_memory" + ] + if self.config["data"]["mask"]["do"]: + self.mask = ( + mrcfile.open(self.config["data"]["mask"]["volume"]) + .data.astype(bool) + .flatten() + ) def get_distance(self, map1, map2): """Compute the distance between two maps.""" raise NotImplementedError() - def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): + def get_sub_distance_matrix(self, maps1, maps2, idxs): """Compute the distance matrix between two sets of maps.""" - chunk_size_submission = self.config["analysis"]["chunk_size_submission"] - chunk_size_gt = self.config["analysis"]["chunk_size_gt"] - distance_matrix = torch.vmap( + sub_distance_matrix = torch.vmap( lambda maps1: torch.vmap( lambda maps2: self.get_distance(maps1, maps2), - chunk_size=chunk_size_submission, + chunk_size=self.chunk_size_submission, )(maps2), - chunk_size=chunk_size_gt, + chunk_size=self.chunk_size_gt, )(maps1) + return sub_distance_matrix + def distance_matrix_precomputation(maps1, maps2, global_store_of_running_results): + """Pre-compute any assets needed for the distance matrix computation.""" + return + + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): + """Compute the distance matrix between two sets of maps.""" + if self.config["data"]["mask"]["do"]: + maps2 = maps2[:, self.mask] + else: + maps2 = maps2.reshape(len(maps2), -1) + + if self.config["analysis"]["normalize"]["do"]: + maps2 = normalize( + maps2, method=self.config["analysis"]["normalize"]["method"] + ) + if self.do_low_memory_mode: + self.n_chunks_low_memory = len(maps1) // 
self.chunk_size_low_memory + distance_matrix = torch.empty(len(maps1), len(maps2)) + for idxs in torch.arange(len(maps1)).chunk(self.n_chunks_low_memory): + maps1_in_memory = maps1[idxs] + if self.config["data"]["mask"]["do"]: + maps1_in_memory = maps1_in_memory.reshape(len(idxs), -1)[ + :, self.mask + ] + else: + maps1_in_memory = maps1_in_memory.reshape(len(maps1_in_memory), -1) + if self.config["analysis"]["normalize"]["do"]: + maps1_in_memory = normalize( + maps1_in_memory, + method=self.config["analysis"]["normalize"]["method"], + ) + sub_distance_matrix = self.get_sub_distance_matrix( + maps1_in_memory, + maps2, + idxs, + ) + distance_matrix[idxs] = sub_distance_matrix + + else: + distance_matrix = torch.vmap( + lambda maps1: torch.vmap( + lambda maps2: self.get_distance(maps1, maps2), + chunk_size=self.chunk_size_submission, + )(maps2), + chunk_size=self.chunk_size_gt, + )(maps1) return distance_matrix def get_computed_assets(self, maps1, maps2, global_store_of_running_results): @@ -33,6 +122,10 @@ def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return {} +def norm2(map1, map2): + return torch.norm(map1 - map2) ** 2 + + class L2DistanceNorm(MapToMapDistance): """L2 distance norm""" @@ -41,39 +134,68 @@ def __init__(self, config): @override def get_distance(self, map1, map2): - return torch.norm(map1 - map2) ** 2 + return norm2(map1, map2) -class L2DistanceSum(MapToMapDistance): - """L2 distance. - - Computed by summing the squared differences between the two maps.""" - - def __init__(self, config): - super().__init__(config) - - def compute_cost_l2(self, map_1, map_2): - return ((map_1 - map_2) ** 2).sum() - - @override - def get_distance(self, map1, map2): - return self.compute_cost_l2(map1, map2) +def correlation(map1, map2): + return (map1 * map2).sum() class Correlation(MapToMapDistance): - """Correlation distance. + """Correlation. Not technically a distance metric, but a similarity.""" def __init__(self, config): super().__init__(config) - def compute_cost_corr(self, map_1, map_2): - return (map_1 * map_2).sum() - @override def get_distance(self, map1, map2): - return self.compute_cost_corr(map1, map2) + return correlation(map1, map2) + + +def compute_bioem3d_cost(map1, map2): + """ + Compute the cost between two maps using the BioEM cost function in 3D. + + Notes + ----- + See Eq. 10 in 10.1016/j.jsb.2013.10.006 + + Parameters + ---------- + map1 : torch.Tensor + shape (n_pix,n_pix,n_pix) + map2 : torch.Tensor + shape (n_pix,n_pix,n_pix) + + Returns + ------- + cost : torch.Tensor + shape (1,) + """ + m1, m2 = map1.reshape(-1), map2.reshape(-1) + co = m1.sum() + cc = m2.sum() + coo = m1.pow(2).sum() + ccc = m2.pow(2).sum() + coc = (m1 * m2).sum() + + N = len(m1) + + t1 = 2 * torch.pi * math.exp(1) + t2 = N * (ccc * coo - coc * coc) + 2 * co * coc * cc - ccc * co * co - coo * cc * cc + t3 = (N - 2) * (N * ccc - cc * cc) + + smallest_float = torch.finfo(m1.dtype).tiny + log_prob = ( + 0.5 * torch.pi + + math.log(t1) * (1 - N / 2) + + t2.clamp(smallest_float).log() * (3 / 2 - N / 2) + + t3.clamp(smallest_float).log() * (N / 2 - 2) + ) + cost = -log_prob + return cost class BioEM3dDistance(MapToMapDistance): @@ -82,57 +204,75 @@ class BioEM3dDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) - def compute_bioem3d_cost(self, map1, map2): - """ - Compute the cost between two maps using the BioEM cost function in 3D. - - Notes - ----- - See Eq. 
10 in 10.1016/j.jsb.2013.10.006 - - Parameters - ---------- - map1 : torch.Tensor - shape (n_pix,n_pix,n_pix) - map2 : torch.Tensor - shape (n_pix,n_pix,n_pix) - - Returns - ------- - cost : torch.Tensor - shape (1,) - """ - m1, m2 = map1.reshape(-1), map2.reshape(-1) - co = m1.sum() - cc = m2.sum() - coo = m1.pow(2).sum() - ccc = m2.pow(2).sum() - coc = (m1 * m2).sum() - - N = len(m1) - - t1 = 2 * torch.pi * math.exp(1) - t2 = ( - N * (ccc * coo - coc * coc) - + 2 * co * coc * cc - - ccc * co * co - - coo * cc * cc - ) - t3 = (N - 2) * (N * ccc - cc * cc) - - smallest_float = torch.finfo(m1.dtype).tiny - log_prob = ( - 0.5 * torch.pi - + math.log(t1) * (1 - N / 2) - + t2.clamp(smallest_float).log() * (3 / 2 - N / 2) - + t3.clamp(smallest_float).log() * (N / 2 - 2) - ) - cost = -log_prob - return cost - @override def get_distance(self, map1, map2): - return self.compute_bioem3d_cost(map1, map2) + return compute_bioem3d_cost(map1, map2) + + +def fourier_shell_correlation( + x: torch.Tensor, + y: torch.Tensor, + dim: Sequence[int] = (-3, -2, -1), + normalize: bool = True, + max_k: Optional[int] = None, +): + """Computes Fourier Shell / Ring Correlation (FSC) between x and y. + + Parameters + ---------- + x : torch.Tensor + First input tensor. + y : torch.Tensor + Second input tensor. + dim : Tuple[int, ...] + Dimensions over which to take the Fourier transform. + normalize : bool + Whether to normalize (i.e. compute correlation) or not (i.e. compute covariance). + Note that when `normalize=False`, we still divide by the number of elements in each shell. + max_k : int + The maximum shell to compute the correlation for. + + Returns + ------- + torch.Tensor + The correlation between x and y for each Fourier shell. + """ # noqa: E501 + batch_shape = x.shape[: -len(dim)] + + freqs = [torch.fft.fftfreq(x.shape[d], d=1 / x.shape[d]).to(x) for d in dim] + freq_total = ( + torch.cartesian_prod(*freqs).view(*[len(f) for f in freqs], -1).norm(dim=-1) + ) + + x_f = torch.fft.fftn(x, dim=dim) + y_f = torch.fft.fftn(y, dim=dim) + + n = min(x.shape[d] for d in dim) + + if max_k is None: + max_k = n // 2 + + result = x.new_zeros(batch_shape + (max_k,)) + + for i in range(1, max_k + 1): + mask = (freq_total >= i - 0.5) & (freq_total < i + 0.5) + x_ri = x_f[..., mask] + y_fi = y_f[..., mask] + + if x.is_cuda: + c_i = torch.linalg.vecdot(x_ri, y_fi).real + else: + # vecdot currently bugged on CPU for torch 2.0 in some configurations + c_i = torch.sum(x_ri * y_fi.conj(), dim=-1).real + + if normalize: + c_i /= torch.linalg.norm(x_ri, dim=-1) * torch.linalg.norm(y_fi, dim=-1) + else: + c_i /= x_ri.shape[-1] + + result[..., i - 1] = c_i + + return result class FSCDistance(MapToMapDistance): @@ -143,72 +283,6 @@ class FSCDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) - def fourier_shell_correlation( - self, - x: torch.Tensor, - y: torch.Tensor, - dim: Sequence[int] = (-3, -2, -1), - normalize: bool = True, - max_k: Optional[int] = None, - ): - """Computes Fourier Shell / Ring Correlation (FSC) between x and y. - - Parameters - ---------- - x : torch.Tensor - First input tensor. - y : torch.Tensor - Second input tensor. - dim : Tuple[int, ...] - Dimensions over which to take the Fourier transform. - normalize : bool - Whether to normalize (i.e. compute correlation) or not (i.e. compute covariance). - Note that when `normalize=False`, we still divide by the number of elements in each shell. - max_k : int - The maximum shell to compute the correlation for. 
- - Returns - ------- - torch.Tensor - The correlation between x and y for each Fourier shell. - """ # noqa: E501 - batch_shape = x.shape[: -len(dim)] - - freqs = [torch.fft.fftfreq(x.shape[d], d=1 / x.shape[d]).to(x) for d in dim] - freq_total = ( - torch.cartesian_prod(*freqs).view(*[len(f) for f in freqs], -1).norm(dim=-1) - ) - - x_f = torch.fft.fftn(x, dim=dim) - y_f = torch.fft.fftn(y, dim=dim) - - n = min(x.shape[d] for d in dim) - - if max_k is None: - max_k = n // 2 - - result = x.new_zeros(batch_shape + (max_k,)) - - for i in range(1, max_k + 1): - mask = (freq_total >= i - 0.5) & (freq_total < i + 0.5) - x_ri = x_f[..., mask] - y_fi = y_f[..., mask] - - if x.is_cuda: - c_i = torch.linalg.vecdot(x_ri, y_fi).real - else: - # vecdot currently bugged on CPU for torch 2.0 in some configurations - c_i = torch.sum(x_ri * y_fi.conj(), dim=-1).real - - if normalize: - c_i /= torch.linalg.norm(x_ri, dim=-1) * torch.linalg.norm(y_fi, dim=-1) - else: - c_i /= x_ri.shape[-1] - - result[..., i - 1] = c_i - - return result - def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): """ Compute the cost between two maps using the Fourier Shell Correlation in 3D. @@ -221,7 +295,7 @@ def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) fsc_matrix = torch.zeros(len(maps_gt_flat), len(maps_user_flat), n_pix // 2) for idx in range(len(maps_gt_flat)): - corr_vector = self.fourier_shell_correlation( + corr_vector = fourier_shell_correlation( maps_user_flat.reshape(-1, n_pix, n_pix, n_pix), maps_gt_flat[idx].reshape(n_pix, n_pix, n_pix), ) @@ -231,7 +305,16 @@ def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): return cost_matrix, fsc_matrix @override - def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): + def distance_matrix_precomputation(self, maps1, maps2): + self.len_maps1 = len(maps1) + self.len_maps2 = len(maps2) + self.stored_computed_assets = { + "fsc_matrix": torch.empty(self.len_maps1, self.len_maps2, self.n_pix // 2) + } + return + + @override + def get_sub_distance_matrix(self, maps1, maps2, idxs): """ Applies a mask to the maps and computes the cost matrix using the Fourier Shell Correlation. 
""" @@ -239,21 +322,62 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): maps_user_flat = maps2 n_pix = self.config["data"]["n_pix"] maps_gt_flat_cube = torch.zeros(len(maps_gt_flat), n_pix**3) - mask = ( - mrcfile.open(self.config["data"]["mask"]["volume"]) - .data.astype(bool) - .flatten() - ) - maps_gt_flat_cube[:, mask] = maps_gt_flat maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3) - maps_user_flat_cube[:, mask] = maps_user_flat + + if self.config["data"]["mask"]["do"]: + maps_gt_flat_cube[:, self.mask] = maps_gt_flat[:] + maps_user_flat_cube[:, self.mask] = maps_user_flat + + else: + maps_gt_flat_cube = maps_gt_flat + maps_user_flat_cube = maps_user_flat cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk( maps_gt_flat_cube, maps_user_flat_cube, n_pix ) - self.stored_computed_assets = {"fsc_matrix": fsc_matrix} + self.stored_computed_assets["fsc_matrix"][idxs] = fsc_matrix return cost_matrix + @override + def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): + """Compute the distance matrix between two sets of maps.""" + if self.config["data"]["mask"]["do"]: + maps2 = maps2[:, self.mask] + else: + maps2 = maps2.reshape(len(maps2), -1) + + if self.config["analysis"]["normalize"]["do"]: + maps2 = normalize( + maps2, method=self.config["analysis"]["normalize"]["method"] + ) + if self.chunk_size_low_memory is None: + self.n_chunks_low_memory = 1 + else: + self.n_chunks_low_memory = len(maps1) // self.chunk_size_low_memory + distance_matrix = torch.empty(len(maps1), len(maps2)) + for idxs in torch.arange(len(maps1)).chunk(self.n_chunks_low_memory): + maps1_in_memory = maps1[idxs] + + if self.config["data"]["mask"]["do"]: + maps1_in_memory = maps1_in_memory[:].reshape(len(idxs), -1)[ + :, self.mask + ] + + else: + maps1_in_memory = maps1_in_memory.reshape(len(maps1_in_memory), -1) + if self.config["analysis"]["normalize"]["do"]: + maps1_in_memory = normalize( + maps1_in_memory, + method=self.config["analysis"]["normalize"]["method"], + ) + sub_distance_matrix = self.get_sub_distance_matrix( + maps1_in_memory, + maps2, + idxs, + ) + distance_matrix[idxs] = sub_distance_matrix + return distance_matrix + @override def get_computed_assets(self, maps1, maps2, global_store_of_running_results): return self.stored_computed_assets # must run get_distance_matrix first @@ -268,6 +392,7 @@ class FSCResDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) + self.fsc_label = "fsc" @override def get_distance_matrix( @@ -278,7 +403,7 @@ def get_distance_matrix( self.config["data"]["n_pix"] // 2 ) # TODO: check for odd psizes if this should be +1 psize = self.config["data"]["psize"] - fsc_matrix = global_store_of_running_results["fsc"]["computed_assets"][ + fsc_matrix = global_store_of_running_results[self.fsc_label]["computed_assets"][ "fsc_matrix" ] units_Angstroms = ( diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index ffc02df..c281496 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -1,13 +1,14 @@ -import mrcfile +import numpy as np import pandas as pd import pickle import torch from ..data._validation.output_validators import MapToMapResultsValidator from .._map_to_map.map_to_map_distance import ( + GT_Dataset, FSCDistance, Correlation, - L2DistanceSum, + L2DistanceNorm, BioEM3dDistance, FSCResDistance, ) @@ -16,7 +17,7 @@ 
AVAILABLE_MAP2MAP_DISTANCES = { "fsc": FSCDistance, "corr": Correlation, - "l2": L2DistanceSum, + "l2": L2DistanceNorm, "bioem": BioEM3dDistance, "res": FSCResDistance, } @@ -30,8 +31,11 @@ def run(config): map_to_map_distances = { distance_label: distance_class(config) for distance_label, distance_class in AVAILABLE_MAP2MAP_DISTANCES.items() + if distance_label in config["analysis"]["metrics"] } + do_low_memory_mode = config["analysis"]["low_memory"]["do"] + n_pix = config["data"]["n_pix"] submission = torch.load(config["data"]["submission"]["fname"]) @@ -51,32 +55,21 @@ def run(config): maps_user_flat = submission[submission_volume_key].reshape( len(submission["volumes"]), -1 ) - maps_gt_flat = torch.load(config["data"]["ground_truth"]["volumes"]).reshape( - -1, n_pix**3 - ) - - if config["data"]["mask"]["do"]: - mask = ( - mrcfile.open(config["data"]["mask"]["volume"]).data.astype(bool).flatten() - ) - maps_gt_flat = maps_gt_flat[:, mask] - maps_user_flat = maps_user_flat[:, mask] + if do_low_memory_mode: + maps_gt_flat = GT_Dataset(config["data"]["ground_truth"]["volumes"]) else: - maps_gt_flat.reshape(len(maps_gt_flat), -1, inplace=True) - maps_user_flat.reshape(len(maps_gt_flat), -1, inplace=True) - - if config["analysis"]["normalize"]["do"]: - if config["analysis"]["normalize"]["method"] == "median_zscore": - maps_gt_flat -= maps_gt_flat.median(dim=1, keepdim=True).values - maps_gt_flat /= maps_gt_flat.std(dim=1, keepdim=True) - maps_user_flat -= maps_user_flat.median(dim=1, keepdim=True).values - maps_user_flat /= maps_user_flat.std(dim=1, keepdim=True) + maps_gt_flat = torch.from_numpy( + np.load(config["data"]["ground_truth"]["volumes"]) + ).reshape(-1, n_pix**3) computed_assets = {} for distance_label, map_to_map_distance in map_to_map_distances.items(): if distance_label in config["analysis"]["metrics"]: # TODO: can remove print("cost matrix", distance_label) + map_to_map_distance.distance_matrix_precomputation( + maps_gt_flat, maps_user_flat + ) cost_matrix = map_to_map_distance.get_distance_matrix( maps_gt_flat, maps_user_flat, diff --git a/src/cryo_challenge/_preprocessing/align_utils.py b/src/cryo_challenge/_preprocessing/align_utils.py index cbbb6a1..f0ae4fa 100644 --- a/src/cryo_challenge/_preprocessing/align_utils.py +++ b/src/cryo_challenge/_preprocessing/align_utils.py @@ -116,6 +116,45 @@ def center_submission(volumes: torch.Tensor, pixel_size: float) -> torch.Tensor: return volumes +# def align_submission( +# volumes: torch.Tensor, ref_volume: torch.Tensor, params: dict +# ) -> torch.Tensor: +# """ +# Align submission volumes to ground truth volume + +# Parameters: +# ----------- +# volumes (torch.Tensor): submission volumes +# shape: (n_volumes, im_x, im_y, im_z) +# ref_volume (torch.Tensor): ground truth volume +# shape: (im_x, im_y, im_z) +# params (dict): dictionary containing alignment parameters + +# Returns: +# -------- +# volumes (torch.Tensor): aligned submission volumes +# """ +# for i in range(len(volumes)): +# obj_vol = volumes[i].numpy().astype(np.float32).copy() + +# obj_vol = Volume(obj_vol / obj_vol.sum()) +# ref_vol = Volume(ref_volume.copy() / ref_volume.sum()) + +# _, R_est = align_BO( +# ref_vol, +# obj_vol, +# loss_type=params["BOT_loss"], +# downsampled_size=params["BOT_box_size"], +# max_iters=params["BOT_iter"], +# refine=params["BOT_refine"], +# ) +# R_est = Rotation(R_est.astype(np.float32)) + +# volumes[i] = torch.from_numpy(Volume(volumes[i].numpy()).rotate(R_est)._data) + +# return volumes + + def align_submission( volumes: torch.Tensor, 
ref_volume: torch.Tensor, params: dict ) -> torch.Tensor: diff --git a/src/cryo_challenge/_preprocessing/bfactor_normalize.py b/src/cryo_challenge/_preprocessing/bfactor_normalize.py new file mode 100644 index 0000000..37d876b --- /dev/null +++ b/src/cryo_challenge/_preprocessing/bfactor_normalize.py @@ -0,0 +1,89 @@ +import torch +from ..power_spectrum_utils import _centered_fftn, _centered_ifftn + + +def _compute_bfactor_scaling(b_factor, box_size, voxel_size): + """ + Compute the B-factor scaling factor for a given B-factor, box size, and voxel size. + The B-factor scaling factor is computed as exp(-B * s^2 / 4), where s is the squared + distance in Fourier space. + + Parameters + ---------- + b_factor: float + B-factor to apply. + box_size: int + Size of the box. + voxel_size: float + Voxel size of the box. + + Returns + ------- + bfactor_scaling_torch: torch.tensor(shape=(box_size, box_size, box_size)) + B-factor scaling factor. + """ + x = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) + y = x.clone() + z = x.clone() + x, y, z = torch.meshgrid(x, y, z, indexing="ij") + + s2 = x**2 + y**2 + z**2 + bfactor_scaling_torch = torch.exp(-b_factor * s2 / 4) + + return bfactor_scaling_torch + + +def bfactor_normalize_volumes(volumes, bfactor, voxel_size, in_place=False): + """ + Normalize volumes by applying a B-factor correction. This is done by multiplying + a centered Fourier transform of the volume by the B-factor scaling factor and then + applying the inverse Fourier transform. See _compute_bfactor_scaling for details on the + computation of the B-factor scaling. + + Parameters + ---------- + volumes: torch.tensor + Volumes to normalize. The volumes must have shape (N, N, N) or (n_volumes, N, N, N). + bfactor: float + B-factor to apply. + voxel_size: float + Voxel size of the volumes. + in_place: bool - default: False + Whether to normalize the volumes in place. + + Returns + ------- + volumes: torch.tensor + Normalized volumes. + """ + # assert that volumes have the correct shape + assert volumes.ndim in [ + 3, + 4, + ], "Input volumes must have shape (N, N, N) or (n_volumes, N, N, N)" + + if volumes.ndim == 3: + assert ( + volumes.shape[0] == volumes.shape[1] == volumes.shape[2] + ), "Input volumes must have equal dimensions" + else: + assert ( + volumes.shape[1] == volumes.shape[2] == volumes.shape[3] + ), "Input volumes must have equal dimensions" + + if not in_place: + volumes = volumes.clone() + + b_factor_scaling = _compute_bfactor_scaling(bfactor, volumes.shape[-1], voxel_size) + + if len(volumes.shape) == 3: + volumes = _centered_fftn(volumes, dim=(0, 1, 2)) + volumes = volumes * b_factor_scaling + volumes = _centered_ifftn(volumes, dim=(0, 1, 2)).real + + elif len(volumes.shape) == 4: + volumes = _centered_fftn(volumes, dim=(1, 2, 3)) + volumes = volumes * b_factor_scaling[None, ...] 
+ volumes = _centered_ifftn(volumes, dim=(1, 2, 3)).real + + return volumes diff --git a/src/cryo_challenge/_preprocessing/normalize.py b/src/cryo_challenge/_preprocessing/normalize.py deleted file mode 100644 index 73449bf..0000000 --- a/src/cryo_challenge/_preprocessing/normalize.py +++ /dev/null @@ -1,22 +0,0 @@ -''' -TODO: Need to implement this properly - -def normalize_mean_std(vols_flat): - """ - vols_flat.shape is (n_vols, n_pix**3) - vols_flat is a torch tensor - """ - return (vols_flat - vols_flat.mean(-1, keepdims=True)) / vols_flat.std( - -1, keepdims=True - ) - - -def normalize_median_std(vols_flat): - """ - vols_flat.shape is (n_vols, n_pix**3) - vols_flat is a torch tensor - """ - return (vols_flat - vols_flat.median(-1, keepdims=True).values) / vols_flat.std( - -1, keepdims=True - ) -''' diff --git a/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py b/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py index 90ccc51..b4c3e61 100644 --- a/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py +++ b/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py @@ -2,7 +2,7 @@ import json import os -from .align_utils import align_submission, center_submission, threshold_submissions +from .align_utils import align_submission, threshold_submissions from .crop_pad_utils import crop_pad_submission from .fourier_utils import downsample_submission @@ -80,7 +80,7 @@ def preprocess_submissions(submission_dataset, config): # center submission print(" Centering submission") - volumes = center_submission(volumes, pixel_size=pixel_size_gt) + # volumes = center_submission(volumes, pixel_size=pixel_size_gt) # flip handedness if submission_dataset.submission_config[str(idx)]["flip"] == 1: diff --git a/src/cryo_challenge/data/_validation/config_validators.py b/src/cryo_challenge/data/_validation/config_validators.py index b2fa933..83083ed 100644 --- a/src/cryo_challenge/data/_validation/config_validators.py +++ b/src/cryo_challenge/data/_validation/config_validators.py @@ -151,6 +151,7 @@ def validate_config_mtm_analysis(config_analysis: dict) -> None: chunk_size_submission: int, is the chunk size for the submission volume. chunk_size_gt: int, is the chunk size for the ground truth volume. normalize: dict, is the normalize part of the analysis part of the config. + low_memory: dict, is the low memory part of the analysis part of the config. # TODO: add validation for low_memory """ # noqa: E501 keys_and_types = { @@ -158,6 +159,7 @@ def validate_config_mtm_analysis(config_analysis: dict) -> None: "chunk_size_submission": Number, "chunk_size_gt": Number, "normalize": dict, + "low_memory": dict, } validate_generic_config(config_analysis, keys_and_types) diff --git a/src/cryo_challenge/power_spectrum_utils.py b/src/cryo_challenge/power_spectrum_utils.py new file mode 100644 index 0000000..afae5bc --- /dev/null +++ b/src/cryo_challenge/power_spectrum_utils.py @@ -0,0 +1,153 @@ +import torch + + +def _cart2sph(x, y, z): + """ + Converts a grid in cartesian coordinates to spherical coordinates. + + Parameters + ---------- + x: torch.tensor + x-coordinate of the grid. + y: torch.tensor + y-coordinate of the grid. + z: torch.tensor + """ + hxy = torch.hypot(x, y) + r = torch.hypot(hxy, z) + el = torch.atan2(z, hxy) + az = torch.atan2(y, x) + return az, el, r + + +def _grid_3d(n, dtype=torch.float32): + """ + Generates a centered 3D grid. The grid is given in both cartesian and spherical coordinates. + + Parameters + ---------- + n: int + Size of the grid. 
+ dtype: torch.dtype + Data type of the grid. + + Returns + ------- + grid: dict + Dictionary containing the grid in cartesian and spherical coordinates. + keys: x, y, z, phi, theta, r + """ + start = -n // 2 + 1 + end = n // 2 + + if n % 2 == 0: + start -= 1 / 2 + end -= 1 / 2 + + grid = torch.linspace(start, end, n, dtype=dtype) + z, x, y = torch.meshgrid(grid, grid, grid, indexing="ij") + + phi, theta, r = _cart2sph(x, y, z) + + theta = torch.pi / 2 - theta + + return {"x": x, "y": y, "z": z, "phi": phi, "theta": theta, "r": r} + + +def _centered_fftn(x, dim=None): + """ + Wrapper around torch.fft.fftn that centers the Fourier transform. + """ + x = torch.fft.fftn(x, dim=dim) + x = torch.fft.fftshift(x, dim=dim) + return x + + +def _centered_ifftn(x, dim=None): + """ + Wrapper around torch.fft.ifftn that centers the inverse Fourier transform. + """ + x = torch.fft.fftshift(x, dim=dim) + x = torch.fft.ifftn(x, dim=dim) + return x + + +def _average_over_single_shell(shell_index, volume, radii, shell_width=0.5): + """ + Given a volume in Fourier space, compute the average value of the volume over a shell. + + Parameters + ---------- + shell_index: int + Index of the shell in Fourier space. + volume: torch.tensor + Volume in Fourier space. + radii: torch.tensor + Radii of the Fourier space grid. + shell_width: float + Width of the shell. + + Returns + ------- + average: float + Average value of the volume over the shell. + """ + inner_diameter = shell_width + shell_index + outer_diameter = shell_width + (shell_index + 1) + mask = (radii > inner_diameter) & (radii < outer_diameter) + return torch.sum(mask * volume) / torch.sum(mask) + + +def _average_over_shells(volume_in_fourier_space, shell_width=0.5): + """ + Vmap wrapper over _average_over_single_shell to compute the average value of a volume in Fourier space over all shells. The input should be a volumetric quantity in Fourier space. + + Parameters + ---------- + volume_in_fourier_space: torch.tensor + Volume in Fourier space. + + Returns + ------- + radial_average: torch.tensor + Average value of the volume over all shells. + """ + L = volume_in_fourier_space.shape[0] + dtype = torch.float32 + radii = _grid_3d(L, dtype=dtype)["r"] + + radial_average = torch.vmap( + _average_over_single_shell, in_dims=(0, None, None, None) + )(torch.arange(0, L // 2), volume_in_fourier_space, radii, shell_width) + + return radial_average + + +def compute_power_spectrum(volume, shell_width=0.5): + """ + Compute the power spectrum of a volume. + + Parameters + ---------- + volume: torch.tensor + Volume for which to compute the power spectrum. + shell_width: float + Width of the shell. + + Returns + ------- + power_spectrum: torch.tensor + Power spectrum of the volume. + + Examples + -------- + volume = mrcfile.open("volume.mrc").data.copy() + volume = torch.tensor(volume, dtype=torch.float32) + power_spectrum = compute_power_spectrum(volume) + """ + + # Compute centered Fourier transforms. 
+ vol_fft = torch.abs(_centered_fftn(volume)) ** 2 + power_spectrum = _average_over_shells(vol_fft, shell_width=shell_width) + + return power_spectrum diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index 7dfa7e9..689eb82 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -7,7 +7,7 @@ data: metadata_key: populations label_key: id ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy metadata: tests/data/Ground_truth/test_metadata_10.csv mask: do: true @@ -21,6 +21,9 @@ analysis: - res chunk_size_submission: 80 chunk_size_gt: 190 + low_memory: + do: false + chunk_size_low_memory: null normalize: do: true method: median_zscore diff --git a/tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml b/tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml new file mode 100644 index 0000000..7b02d2e --- /dev/null +++ b/tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml @@ -0,0 +1,31 @@ +data: + n_pix: 16 + psize: 30.044 + submission: + fname: tests/data/dataset_2_submissions/submission_1000.pt + volume_key: volumes + metadata_key: populations + label_key: id + ground_truth: + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy + metadata: tests/data/Ground_truth/test_metadata_10.csv + mask: + do: true + volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc +analysis: + metrics: + - l2 + - corr + - bioem + - fsc + - res + chunk_size_submission: 80 + chunk_size_gt: 190 + chunk_size_low_memory: 10 + low_memory: + do: true + chunk_size_low_memory: 10 + normalize: + do: true + method: median_zscore +output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml new file mode 100644 index 0000000..020f58f --- /dev/null +++ b/tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml @@ -0,0 +1,31 @@ +data: + n_pix: 16 + psize: 30.044 + submission: + fname: tests/data/dataset_2_submissions/submission_1000.pt + volume_key: volumes + metadata_key: populations + label_key: id + ground_truth: + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy + metadata: tests/data/Ground_truth/test_metadata_10.csv + mask: + do: false + volume: dummy-string +analysis: + metrics: + - l2 + - corr + - bioem + - fsc + - res + chunk_size_submission: 80 + chunk_size_gt: 190 + chunk_size_low_memory: 10 + low_memory: + do: true + chunk_size_low_memory: 10 + normalize: + do: false + method: dummy-string +output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml new file mode 100644 index 0000000..2a9a3a8 --- /dev/null +++ b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml @@ -0,0 +1,30 @@ +data: + n_pix: 16 + psize: 30.044 + submission: + fname: tests/data/dataset_2_submissions/submission_1000.pt + volume_key: volumes + metadata_key: populations + label_key: id + ground_truth: + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy + metadata: tests/data/Ground_truth/test_metadata_10.csv + mask: + do: false + volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc +analysis: + 
metrics: + - l2 + - corr + - bioem + - fsc + - res + chunk_size_submission: 80 + chunk_size_gt: 190 + low_memory: + do: false + chunk_size_low_memory: null + normalize: + do: false + method: median_zscore +output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json b/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json index 1fb797d..e7669ab 100644 --- a/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json +++ b/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json @@ -8,13 +8,13 @@ }, "0": { "name": "raw_submission_in_testdata", - "align": 1, "flavor_name": "test flavor", + "path": "tests/data/unprocessed_dataset_2_submissions/submission_x", + "populations_file": "tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt", + "submission_version": "1.0", "box_size": 32, "pixel_size": 15.022, - "path": "tests/data/unprocessed_dataset_2_submissions/submission_x", "flip": 1, - "populations_file": "tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt", - "submission_version": "1.0" + "align": 1 } } diff --git a/tests/test_map_to_map.py b/tests/test_map_to_map.py index e31f29f..907e6d3 100644 --- a/tests/test_map_to_map.py +++ b/tests/test_map_to_map.py @@ -1,9 +1,38 @@ from omegaconf import OmegaConf from cryo_challenge._commands import run_map2map_pipeline +import numpy as np def test_run_map2map_pipeline(): - args = OmegaConf.create( - {"config": "tests/config_files/test_config_map_to_map.yaml"} - ) - run_map2map_pipeline.main(args) + for config_fname, config_fname_low_memory in zip( + [ + "tests/config_files/test_config_map_to_map.yaml", + "tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml", + ], + [ + "tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml", + "tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml", + ], + ): + args = OmegaConf.create({"config": config_fname}) + results_dict = run_map2map_pipeline.main(args) + + args_low_memory = OmegaConf.create({"config": config_fname_low_memory}) + results_dict_low_memory = run_map2map_pipeline.main(args_low_memory) + for metric in ["fsc", "corr", "l2", "bioem"]: + if metric == "fsc": + np.allclose( + results_dict[metric]["computed_assets"]["fsc_matrix"], + results_dict_low_memory[metric]["computed_assets"]["fsc_matrix"], + ) + elif metric == "res": + np.allclose( + results_dict[metric]["computed_assets"]["fraction_nyquist"], + results_dict_low_memory[metric]["computed_assets"][ + "fraction_nyquist" + ], + ) + np.allclose( + results_dict[metric]["cost_matrix"].values, + results_dict_low_memory[metric]["cost_matrix"].values, + ) diff --git a/tests/test_power_spectrum_and_bfactor.py b/tests/test_power_spectrum_and_bfactor.py new file mode 100644 index 0000000..218632b --- /dev/null +++ b/tests/test_power_spectrum_and_bfactor.py @@ -0,0 +1,85 @@ +import torch +from cryo_challenge.power_spectrum_utils import _centered_ifftn, compute_power_spectrum +from cryo_challenge._preprocessing.bfactor_normalize import ( + _compute_bfactor_scaling, + bfactor_normalize_volumes, +) + + +def test_compute_power_spectrum(): + """ + Test the computation of the power spectrum of a radially symmetric Gaussian volume. + Since the volume is radially symmetric, the power spectrum of the whole volume should be + approximately the power spectrum in a central slice. 
The computation is not exact as our + averaging over shells is approximated. + """ + box_size = 224 + volume_shape = (box_size, box_size, box_size) + voxel_size = 1.073 * 2 + + freq = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) + x = freq.clone() + y = freq.clone() + z = freq.clone() + x, y, z = torch.meshgrid(x, y, z, indexing="ij") + + s2 = x**2 + y**2 + z**2 + + b_factor = 170 + + gaussian_volume = torch.exp(-b_factor / 4 * s2).reshape(volume_shape) + gaussian_volume = _centered_ifftn(gaussian_volume) + + power_spectrum = compute_power_spectrum(gaussian_volume) + power_spectrum_slice = ( + torch.abs(torch.fft.fftn(gaussian_volume)[: box_size // 2, 0, 0]) ** 2 + ) + + mean_squared_error = torch.mean((power_spectrum - power_spectrum_slice) ** 2) + + assert mean_squared_error < 1e-3 + + return + + +def test_bfactor_normalize_volumes(): + """ + Similarly to the other test, we test the normalization of a radially symmetric volume. + In this case we test with an oscillatory volume, which is a volume with a sinusoidal. + Since both the b-factor correction volume and the volume are radially symmetric, the + power spectrum of the normalized volume should be the same as the power spectrum of + a normalized central slice + """ + box_size = 128 + volume_shape = (box_size, box_size, box_size) + voxel_size = 1.5 + + freq = torch.fft.fftshift(torch.fft.fftfreq(box_size, d=voxel_size)) + x = freq.clone() + y = freq.clone() + z = freq.clone() + x, y, z = torch.meshgrid(x, y, z, indexing="ij") + + s2 = x**2 + y**2 + z**2 + + oscillatory_volume = torch.sin(300 * s2).reshape(volume_shape) + oscillatory_volume = _centered_ifftn(oscillatory_volume) + bfactor_scaling_vol = _compute_bfactor_scaling(170, box_size, voxel_size) + + norm_oscillatory_vol = bfactor_normalize_volumes( + oscillatory_volume, 170, voxel_size, in_place=False + ) + + ps_osci = torch.fft.fftn(oscillatory_volume, dim=(-3, -2, -1), norm="backward")[ + : box_size // 2, 0, 0 + ] + ps_norm_osci = torch.fft.fftn( + norm_oscillatory_vol, dim=(-3, -2, -1), norm="backward" + )[: box_size // 2, 0, 0] + ps_bfactor_scaling = torch.fft.fftshift(bfactor_scaling_vol)[: box_size // 2, 0, 0] + + ps_osci = torch.abs(ps_osci) ** 2 + ps_norm_osci = torch.abs(ps_norm_osci) ** 2 + ps_bfactor_scaling = torch.abs(ps_bfactor_scaling) ** 2 + + assert torch.allclose(ps_norm_osci, ps_osci * ps_bfactor_scaling)
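

A minimal usage sketch of the B-factor and power-spectrum utilities introduced above; the map path, B-factor, and voxel size below are placeholder values, not assets shipped with the repository.

    import mrcfile
    import torch

    from cryo_challenge._preprocessing.bfactor_normalize import bfactor_normalize_volumes
    from cryo_challenge.power_spectrum_utils import compute_power_spectrum

    # Load a cubic map (placeholder path) as a float32 tensor of shape (N, N, N).
    volume = torch.tensor(mrcfile.open("volume.mrc").data.copy(), dtype=torch.float32)

    # Apply a B-factor of 170 A^2: each Fourier coefficient is multiplied by
    # exp(-B * s^2 / 4), so the radially averaged power spectrum is damped by
    # exp(-B * s^2 / 2), which is the relation test_bfactor_normalize_volumes checks.
    filtered = bfactor_normalize_volumes(volume, bfactor=170.0, voxel_size=1.073, in_place=False)

    ps_raw = compute_power_spectrum(volume)
    ps_filtered = compute_power_spectrum(filtered)

Assuming the test configs added in this patch are present, the low-memory map-to-map path can be exercised end to end the same way the updated test does:

    from omegaconf import OmegaConf
    from cryo_challenge._commands import run_map2map_pipeline

    # With low_memory.do = true, GT_Dataset memory-maps the ground-truth .npy file
    # and the distance matrix is filled chunk by chunk, so only
    # chunk_size_low_memory ground-truth maps are held in memory at a time.
    args = OmegaConf.create(
        {"config": "tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml"}
    )
    results = run_map2map_pipeline.main(args)
    fsc_matrix = results["fsc"]["computed_assets"]["fsc_matrix"]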