Skip to content

Commit

Permalink
@OverRide from typing_extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffwoollard committed Aug 12, 2024
1 parent 9e323a9 commit 3d3c8f1
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions src/cryo_challenge/_map_to_map/map_to_map_distance.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
import torch
from typing import Optional, Sequence
from typing_extensions import override
import mrcfile

class MapToMapDistance:
Expand Down Expand Up @@ -33,6 +34,7 @@ class L2DistanceNorm(MapToMapDistance):
def __init__(self, config):
super().__init__(config)

@override
def get_distance(self, map1, map2):
return torch.norm(map1 - map2)**2

Expand All @@ -42,7 +44,8 @@ def __init__(self, config):

def compute_cost_l2(self, map_1, map_2):
return ((map_1 - map_2) ** 2).sum()


@override
def get_distance(self, map1, map2):
return self.compute_cost_l2(map1, map2)

Expand All @@ -52,7 +55,8 @@ def __init__(self, config):

def compute_cost_corr(self, map_1, map_2):
return (map_1 * map_2).sum()


@override
def get_distance(self, map1, map2):
return self.compute_cost_corr(map1, map2)

Expand Down Expand Up @@ -102,7 +106,8 @@ def compute_bioem3d_cost(self, map1, map2):
)
cost = -log_prob
return cost


@override
def get_distance(self, map1, map2):
return self.compute_bioem3d_cost(map1, map2)

Expand Down Expand Up @@ -197,6 +202,7 @@ def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix):
cost_matrix[idx] = dist
return cost_matrix, fsc_matrix

@override
def get_distance_matrix(self, maps1, maps2): # custom method
maps_gt_flat = maps1
maps_user_flat = maps2
Expand All @@ -212,6 +218,7 @@ def get_distance_matrix(self, maps1, maps2): # custom method
cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk(maps_gt_flat_cube, maps_user_flat_cube, n_pix)
self.stored_computed_assets = {'fsc_matrix': fsc_matrix}
return cost_matrix


@override
def get_computed_assets(self, maps1, maps2):
return self.stored_computed_assets # must run get_distance_matrix first

0 comments on commit 3d3c8f1

Please sign in to comment.