From c0f2f831facafc2fd3bddc5b112dfa0c97343254 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Tue, 17 Sep 2024 14:16:11 -0400 Subject: [PATCH] code duplication for norm --- src/cryo_challenge/_map_to_map/map_to_map_distance.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 4769528..b413ade 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -95,6 +95,10 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): return cost_matrix +def norm2(map1, map2): + return torch.norm(map1 - map2) ** 2 + + class L2DistanceNorm(MapToMapDistance): """L2 distance norm""" @@ -103,7 +107,7 @@ def __init__(self, config): @override def get_distance(self, map1, map2): - return torch.norm(map1 - map2) ** 2 + return norm2(map1, map2) class L2DistanceNormLowMemory(MapToMapDistanceLowMemory): @@ -114,7 +118,7 @@ def __init__(self, config): @override def compute_cost(self, map1, map2): - return torch.norm(map1 - map2) ** 2 + return norm2(map1, map2) def correlation(map1, map2):