Skip to content

Commit

Permalink
all metrics implemented and test passing for matched results
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffwoollard committed Sep 11, 2024
1 parent 22ec5a5 commit de3cddc
Showing 1 changed file with 3 additions and 156 deletions.
159 changes: 3 additions & 156 deletions src/cryo_challenge/_map_to_map/map_to_map_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,42 +110,6 @@ def compute_cost(self, map1, map2):
return torch.norm(map1 - map2) ** 2


class CorrelationLowMemoryCheck(MapToMapDistance):
    """Correlation, with the first map normalized on the fly.

    Not technically a distance metric, but a similarity.
    """

    def __init__(self, config):
        """Forward ``config`` to the parent class initializer."""
        super().__init__(config)

    def compute_cost_corr(self, map_1, map_2):
        """Return the sum of the element-wise product of the two maps."""
        return (map_1 * map_2).sum()

    @override
    def get_distance(self, map1, map2, global_store_of_running_results):
        """Score one pair of maps.

        ``map1`` is flattened, shifted by its median, scaled by its standard
        deviation and then indexed with
        ``global_store_of_running_results["mask"]`` before being correlated
        with ``map2``.  ``map2`` is used exactly as given (presumably it is
        already flattened, normalized and masked by the caller -- confirm).

        The caller's ``map1`` is left untouched.
        """
        flat = map1.flatten()
        # flatten() may return a view of (or the very same object as) the
        # caller's tensor, so normalize out of place instead of with -= and /=.
        flat = flat - flat.median()
        flat = flat / flat.std()
        masked = flat[global_store_of_running_results["mask"]]

        return self.compute_cost_corr(masked, map2)

    @override
    def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):
        """Fill a ``len(maps1)`` x ``len(maps2)`` matrix of pairwise scores.

        Entry ``[i, j]`` is ``get_distance(maps1[i], maps2[j], ...)``.
        """
        n_gt, n_user = len(maps1), len(maps2)
        cost_matrix = torch.empty(n_gt, n_user)
        # Index-based loops on purpose: only len() and [] are assumed of the
        # map containers.
        for idx_gt in range(n_gt):
            # Look the row up once per ground-truth map rather than once per pair.
            map_gt = maps1[idx_gt]
            for idx_user in range(n_user):
                cost_matrix[idx_gt, idx_user] = self.get_distance(
                    map_gt,
                    maps2[idx_user],
                    global_store_of_running_results,
                )

        return cost_matrix


def correlation(map1, map2):
    """Return the sum of the element-wise product of ``map1`` and ``map2``."""
    elementwise_product = map1 * map2
    return elementwise_product.sum()

Expand All @@ -158,12 +122,9 @@ class Correlation(MapToMapDistance):
# Initializer: passes ``config`` unchanged to the parent class initializer.
def __init__(self, config):
super().__init__(config)

def compute_cost_corr(self, map_1, map_2):
    """Return the sum of the element-wise product of ``map_1`` and ``map_2``."""
    elementwise_product = map_1 * map_2
    return elementwise_product.sum()

@override
def get_distance(self, map1, map2):
    """Return the correlation score of ``map1`` and ``map2``.

    Delegates to ``compute_cost_corr``, i.e. the sum of the element-wise
    product of the two maps.
    """
    return self.compute_cost_corr(map1, map2)


class CorrelationLowMemory(MapToMapDistanceLowMemory):
Expand Down Expand Up @@ -227,57 +188,9 @@ class BioEM3dDistance(MapToMapDistance):
# Initializer: passes ``config`` unchanged to the parent class initializer.
def __init__(self, config):
super().__init__(config)

def compute_bioem3d_cost(self, map1, map2):
    """
    Evaluate the 3D BioEM cost function for a pair of maps.

    Notes
    -----
    See Eq. 10 in 10.1016/j.jsb.2013.10.006

    Parameters
    ----------
    map1 : torch.Tensor
        shape (n_pix,n_pix,n_pix)
    map2 : torch.Tensor
        shape (n_pix,n_pix,n_pix)

    Returns
    -------
    cost : torch.Tensor
        The negated log-probability, as a 0-dim tensor (every map-dependent
        term is a full reduction over the flattened maps).
    """
    flat_1 = map1.reshape(-1)
    flat_2 = map2.reshape(-1)
    n_vox = len(flat_1)

    # First- and second-order sums over all voxels.
    sum_1 = flat_1.sum()
    sum_2 = flat_2.sum()
    sq_sum_1 = flat_1.pow(2).sum()
    sq_sum_2 = flat_2.pow(2).sum()
    cross_sum = (flat_1 * flat_2).sum()

    two_pi_e = 2 * torch.pi * math.exp(1)
    mixed_term = (
        n_vox * (sq_sum_2 * sq_sum_1 - cross_sum * cross_sum)
        + 2 * sum_1 * cross_sum * sum_2
        - sq_sum_2 * sum_1 * sum_1
        - sq_sum_1 * sum_2 * sum_2
    )
    map2_term = (n_vox - 2) * (n_vox * sq_sum_2 - sum_2 * sum_2)

    # Clamp to the smallest positive normal value of the dtype so that the
    # logarithms below never see zero or a negative argument.
    floor = torch.finfo(flat_1.dtype).tiny
    log_prob = (
        0.5 * torch.pi
        + math.log(two_pi_e) * (1 - n_vox / 2)
        + mixed_term.clamp(floor).log() * (3 / 2 - n_vox / 2)
        + map2_term.clamp(floor).log() * (n_vox / 2 - 2)
    )
    return -log_prob

@override
def get_distance(self, map1, map2):
    """Return the BioEM 3D cost between ``map1`` and ``map2``.

    Delegates to ``compute_bioem3d_cost``.
    """
    return self.compute_bioem3d_cost(map1, map2)


class BioEM3dDistanceLowMemory(MapToMapDistanceLowMemory):
Expand Down Expand Up @@ -365,72 +278,6 @@ class FSCDistance(MapToMapDistance):
# Initializer: passes ``config`` unchanged to the parent class initializer.
def __init__(self, config):
super().__init__(config)

def fourier_shell_correlation(
    self,
    x: torch.Tensor,
    y: torch.Tensor,
    dim: Sequence[int] = (-3, -2, -1),
    normalize: bool = True,
    max_k: Optional[int] = None,
):
    """Computes Fourier Shell / Ring Correlation (FSC) between x and y.

    Parameters
    ----------
    x : torch.Tensor
        First input tensor.
    y : torch.Tensor
        Second input tensor.
    dim : Tuple[int, ...]
        Dimensions over which to take the Fourier transform.
    normalize : bool
        Whether to normalize (i.e. compute correlation) or not (i.e. compute covariance).
        Note that when `normalize=False`, we still divide by the number of elements in each shell.
    max_k : int
        The maximum shell to compute the correlation for. Defaults to half
        of the smallest transformed axis length.

    Returns
    -------
    torch.Tensor
        The correlation between x and y for each Fourier shell; the last
        axis has length `max_k` and holds shells 1..max_k.
    """  # noqa: E501
    axis_sizes = [x.shape[d] for d in dim]
    leading_shape = x.shape[: -len(dim)]

    # Per-axis frequency coordinates (sample spacing 1/size), cast to x's
    # dtype/device, then combined into one radial frequency per grid point.
    per_axis = [torch.fft.fftfreq(size, d=1 / size).to(x) for size in axis_sizes]
    radius = torch.cartesian_prod(*per_axis).view(*axis_sizes, -1).norm(dim=-1)

    x_hat = torch.fft.fftn(x, dim=dim)
    y_hat = torch.fft.fftn(y, dim=dim)

    if max_k is None:
        max_k = min(axis_sizes) // 2

    shells = x.new_zeros(leading_shape + (max_k,))

    for shell in range(1, max_k + 1):
        # Grid points whose radial frequency lies within half a unit of `shell`.
        in_shell = (radius >= shell - 0.5) & (radius < shell + 0.5)
        x_shell = x_hat[..., in_shell]
        y_shell = y_hat[..., in_shell]

        if x.is_cuda:
            corr = torch.linalg.vecdot(x_shell, y_shell).real
        else:
            # vecdot currently bugged on CPU for torch 2.0 in some configurations
            corr = torch.sum(x_shell * y_shell.conj(), dim=-1).real

        if normalize:
            scale = torch.linalg.norm(x_shell, dim=-1) * torch.linalg.norm(
                y_shell, dim=-1
            )
        else:
            scale = x_shell.shape[-1]

        shells[..., shell - 1] = corr / scale

    return shells

def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix):
"""
Compute the cost between two maps using the Fourier Shell Correlation in 3D.
Expand All @@ -443,7 +290,7 @@ def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix):
cost_matrix = torch.empty(len(maps_gt_flat), len(maps_user_flat))
fsc_matrix = torch.zeros(len(maps_gt_flat), len(maps_user_flat), n_pix // 2)
for idx in range(len(maps_gt_flat)):
corr_vector = self.fourier_shell_correlation(
corr_vector = fourier_shell_correlation(
maps_user_flat.reshape(-1, n_pix, n_pix, n_pix),
maps_gt_flat[idx].reshape(n_pix, n_pix, n_pix),
)
Expand Down

0 comments on commit de3cddc

Please sign in to comment.