diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py index 0578dfa..7880578 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance_matrix.py @@ -28,6 +28,34 @@ def vmap_distance( )(maps_gt) +class MapToMapDistance: + def __init__(self, config): + self.config = config + + def get_distance(self, map1, map2): + raise NotImplementedError() + + def get_distance_matrix(self, maps1, maps2): + chunk_size_submission = self.config["analysis"]["chunk_size_submission"] + chunk_size_gt = self.config["analysis"]["chunk_size_gt"] + distance_matrix = torch.vmap( + lambda maps1: torch.vmap( + lambda maps2: self.get_distance(maps1, maps2), + chunk_size=chunk_size_submission, + )(maps2), + chunk_size=chunk_size_gt, + )(maps1) + + return distance_matrix + +class L2Distance(MapToMapDistance): + def __init__(self, config): + super().__init__(config) + + def get_distance(self, map1, map2): + return torch.norm(map1 - map2)**2 + + def run(config): """ Compare a submission to ground truth. @@ -99,6 +127,13 @@ def run(config): ) cost_matrix = cost_matrix.numpy() computed_assets["fsc_matrix"] = fsc_matrix + elif cost_label == "nope": + l2_map2map_distance = L2Distance(config) + cost_matrix = l2_map2map_distance.get_distance_matrix( + maps_gt_flat, maps_user_flat + ).numpy() + print('run new method') + else: cost_matrix = vmap_distance( maps_gt_flat, @@ -111,6 +146,7 @@ def run(config): cost_matrix_df = pd.DataFrame( cost_matrix, columns=None, index=metadata_gt.populations.tolist() ) + print(cost_matrix_df) # output results single_distance_results_dict = {