From 3d3c8f1a1a922f33ccaf08977260c3f7dfabe187 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Mon, 12 Aug 2024 12:42:42 -0400 Subject: [PATCH] @override from typing_extensions --- .../_map_to_map/map_to_map_distance.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 9b62543..55d01d3 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -1,6 +1,7 @@ import math import torch from typing import Optional, Sequence +from typing_extensions import override import mrcfile class MapToMapDistance: @@ -33,6 +34,7 @@ class L2DistanceNorm(MapToMapDistance): def __init__(self, config): super().__init__(config) + @override def get_distance(self, map1, map2): return torch.norm(map1 - map2)**2 @@ -42,7 +44,8 @@ def __init__(self, config): def compute_cost_l2(self, map_1, map_2): return ((map_1 - map_2) ** 2).sum() - + + @override def get_distance(self, map1, map2): return self.compute_cost_l2(map1, map2) @@ -52,7 +55,8 @@ def __init__(self, config): def compute_cost_corr(self, map_1, map_2): return (map_1 * map_2).sum() - + + @override def get_distance(self, map1, map2): return self.compute_cost_corr(map1, map2) @@ -102,7 +106,8 @@ def compute_bioem3d_cost(self, map1, map2): ) cost = -log_prob return cost - + + @override def get_distance(self, map1, map2): return self.compute_bioem3d_cost(map1, map2) @@ -197,6 +202,7 @@ def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix): cost_matrix[idx] = dist return cost_matrix, fsc_matrix + @override def get_distance_matrix(self, maps1, maps2): # custom method maps_gt_flat = maps1 maps_user_flat = maps2 @@ -212,6 +218,7 @@ def get_distance_matrix(self, maps1, maps2): # custom method cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk(maps_gt_flat_cube, maps_user_flat_cube, n_pix) self.stored_computed_assets = {'fsc_matrix': fsc_matrix} return cost_matrix - + + @override def get_computed_assets(self, maps1, maps2): return self.stored_computed_assets # must run get_distance_matrix first \ No newline at end of file