diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 369c47f..70df449 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -110,42 +110,6 @@ def compute_cost(self, map1, map2): return torch.norm(map1 - map2) ** 2 -class CorrelationLowMemoryCheck(MapToMapDistance): - """Correlation. - - Not technically a distance metric, but a similarity.""" - - def __init__(self, config): - super().__init__(config) - - def compute_cost_corr(self, map_1, map_2): - return (map_1 * map_2).sum() - - @override - def get_distance(self, map1, map2, global_store_of_running_results): - map1 = map1.flatten() - map1 -= map1.median() - map1 /= map1.std() - map1 = map1[global_store_of_running_results["mask"]] - - return self.compute_cost_corr(map1, map2) - - @override - def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): - maps_gt_flat = maps1 - maps_user_flat = maps2 - cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) - for idx_gt in range(len(maps_gt_flat)): - for idx_user in range(len(maps_user_flat)): - cost_matrix[idx_gt, idx_user] = self.get_distance( - maps_gt_flat[idx_gt], - maps_user_flat[idx_user], - global_store_of_running_results, - ) - - return cost_matrix - - def correlation(map1, map2): return (map1 * map2).sum() @@ -158,12 +122,9 @@ class Correlation(MapToMapDistance): def __init__(self, config): super().__init__(config) - def compute_cost_corr(self, map_1, map_2): - return (map_1 * map_2).sum() - @override def get_distance(self, map1, map2): - return self.compute_cost_corr(map1, map2) + return correlation(map1, map2) class CorrelationLowMemory(MapToMapDistanceLowMemory): @@ -227,57 +188,9 @@ class BioEM3dDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) - def compute_bioem3d_cost(self, map1, map2): - """ - Compute the cost between two maps using the BioEM cost function in 3D. - - Notes - ----- - See Eq. 10 in 10.1016/j.jsb.2013.10.006 - - Parameters - ---------- - map1 : torch.Tensor - shape (n_pix,n_pix,n_pix) - map2 : torch.Tensor - shape (n_pix,n_pix,n_pix) - - Returns - ------- - cost : torch.Tensor - shape (1,) - """ - m1, m2 = map1.reshape(-1), map2.reshape(-1) - co = m1.sum() - cc = m2.sum() - coo = m1.pow(2).sum() - ccc = m2.pow(2).sum() - coc = (m1 * m2).sum() - - N = len(m1) - - t1 = 2 * torch.pi * math.exp(1) - t2 = ( - N * (ccc * coo - coc * coc) - + 2 * co * coc * cc - - ccc * co * co - - coo * cc * cc - ) - t3 = (N - 2) * (N * ccc - cc * cc) - - smallest_float = torch.finfo(m1.dtype).tiny - log_prob = ( - 0.5 * torch.pi - + math.log(t1) * (1 - N / 2) - + t2.clamp(smallest_float).log() * (3 / 2 - N / 2) - + t3.clamp(smallest_float).log() * (N / 2 - 2) - ) - cost = -log_prob - return cost - @override def get_distance(self, map1, map2): - return self.compute_bioem3d_cost(map1, map2) + return compute_bioem3d_cost(map1, map2) class BioEM3dDistanceLowMemory(MapToMapDistanceLowMemory): @@ -365,72 +278,6 @@ class FSCDistance(MapToMapDistance): def __init__(self, config): super().__init__(config) - def fourier_shell_correlation( - self, - x: torch.Tensor, - y: torch.Tensor, - dim: Sequence[int] = (-3, -2, -1), - normalize: bool = True, - max_k: Optional[int] = None, - ): - """Computes Fourier Shell / Ring Correlation (FSC) between x and y. - - Parameters - ---------- - x : torch.Tensor - First input tensor. - y : torch.Tensor - Second input tensor. - dim : Tuple[int, ...] - Dimensions over which to take the Fourier transform. - normalize : bool - Whether to normalize (i.e. compute correlation) or not (i.e. compute covariance). - Note that when `normalize=False`, we still divide by the number of elements in each shell. - max_k : int - The maximum shell to compute the correlation for. - - Returns - ------- - torch.Tensor - The correlation between x and y for each Fourier shell. - """ # noqa: E501 - batch_shape = x.shape[: -len(dim)] - - freqs = [torch.fft.fftfreq(x.shape[d], d=1 / x.shape[d]).to(x) for d in dim] - freq_total = ( - torch.cartesian_prod(*freqs).view(*[len(f) for f in freqs], -1).norm(dim=-1) - ) - - x_f = torch.fft.fftn(x, dim=dim) - y_f = torch.fft.fftn(y, dim=dim) - - n = min(x.shape[d] for d in dim) - - if max_k is None: - max_k = n // 2 - - result = x.new_zeros(batch_shape + (max_k,)) - - for i in range(1, max_k + 1): - mask = (freq_total >= i - 0.5) & (freq_total < i + 0.5) - x_ri = x_f[..., mask] - y_fi = y_f[..., mask] - - if x.is_cuda: - c_i = torch.linalg.vecdot(x_ri, y_fi).real - else: - # vecdot currently bugged on CPU for torch 2.0 in some configurations - c_i = torch.sum(x_ri * y_fi.conj(), dim=-1).real - - if normalize: - c_i /= torch.linalg.norm(x_ri, dim=-1) * torch.linalg.norm(y_fi, dim=-1) - else: - c_i /= x_ri.shape[-1] - - result[..., i - 1] = c_i - - return result - def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): """ Compute the cost between two maps using the Fourier Shell Correlation in 3D. @@ -443,7 +290,7 @@ def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat)) fsc_matrix = torch.zeros(len(maps_gt_flat), len(maps_user_flat), n_pix // 2) for idx in range(len(maps_gt_flat)): - corr_vector = self.fourier_shell_correlation( + corr_vector = fourier_shell_correlation( maps_user_flat.reshape(-1, n_pix, n_pix, n_pix), maps_gt_flat[idx].reshape(n_pix, n_pix, n_pix), )