diff --git a/src/cryo_challenge/_map_to_map/map_to_map_distance.py b/src/cryo_challenge/_map_to_map/map_to_map_distance.py index 5b9dc8e..3021db5 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_distance.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_distance.py @@ -4,28 +4,6 @@ from typing_extensions import override import mrcfile import numpy as np -from torch.utils.data import Dataset - - -class GT_Dataset(Dataset): - def __init__(self, npy_file): - self.npy_file = npy_file - self.data = np.load(npy_file, mmap_mode="r+") - - self.shape = self.data.shape - self._dim = len(self.data.shape) - - def dim(self): - return self._dim - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - if torch.is_tensor(idx): - idx = idx.tolist() - sample = self.data[idx] - return torch.from_numpy(sample.copy()) def normalize(maps, method): @@ -108,6 +86,8 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results): distance_matrix[idxs] = sub_distance_matrix else: + maps1 = maps1.reshape(len(maps1), -1) + maps2 = maps2.reshape(len(maps2), -1) distance_matrix = torch.vmap( lambda maps1: torch.vmap( lambda maps2: self.get_distance(maps1, maps2), diff --git a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py index c281496..06ce66f 100644 --- a/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py +++ b/src/cryo_challenge/_map_to_map/map_to_map_pipeline.py @@ -1,11 +1,9 @@ -import numpy as np import pandas as pd import pickle import torch from ..data._validation.output_validators import MapToMapResultsValidator from .._map_to_map.map_to_map_distance import ( - GT_Dataset, FSCDistance, Correlation, L2DistanceNorm, @@ -36,8 +34,6 @@ def run(config): do_low_memory_mode = config["analysis"]["low_memory"]["do"] - n_pix = config["data"]["n_pix"] - submission = torch.load(config["data"]["submission"]["fname"]) submission_volume_key = config["data"]["submission"]["volume_key"] submission_metadata_key = config["data"]["submission"]["metadata_key"] @@ -55,12 +51,9 @@ def run(config): maps_user_flat = submission[submission_volume_key].reshape( len(submission["volumes"]), -1 ) - if do_low_memory_mode: - maps_gt_flat = GT_Dataset(config["data"]["ground_truth"]["volumes"]) - else: - maps_gt_flat = torch.from_numpy( - np.load(config["data"]["ground_truth"]["volumes"]) - ).reshape(-1, n_pix**3) + maps_gt_flat = torch.load( + config["data"]["ground_truth"]["volumes"], mmap=do_low_memory_mode + ) computed_assets = {} for distance_label, map_to_map_distance in map_to_map_distances.items(): diff --git a/tests/config_files/test_config_map_to_map.yaml b/tests/config_files/test_config_map_to_map.yaml index 689eb82..2244e21 100644 --- a/tests/config_files/test_config_map_to_map.yaml +++ b/tests/config_files/test_config_map_to_map.yaml @@ -7,7 +7,7 @@ data: metadata_key: populations label_key: id ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt metadata: tests/data/Ground_truth/test_metadata_10.csv mask: do: true @@ -19,8 +19,8 @@ analysis: - bioem - fsc - res - chunk_size_submission: 80 - chunk_size_gt: 190 + chunk_size_submission: 4 + chunk_size_gt: 5 low_memory: do: false chunk_size_low_memory: null diff --git a/tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml b/tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml index 7b02d2e..8bc02e7 100644 --- a/tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml +++ b/tests/config_files/test_config_map_to_map_low_memory_subbatch.yaml @@ -7,7 +7,7 @@ data: metadata_key: populations label_key: id ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt metadata: tests/data/Ground_truth/test_metadata_10.csv mask: do: true @@ -19,12 +19,11 @@ analysis: - bioem - fsc - res - chunk_size_submission: 80 - chunk_size_gt: 190 - chunk_size_low_memory: 10 + chunk_size_submission: 4 + chunk_size_gt: 5 low_memory: do: true - chunk_size_low_memory: 10 + chunk_size_low_memory: 2 normalize: do: true method: median_zscore diff --git a/tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml index 020f58f..74b494b 100644 --- a/tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml +++ b/tests/config_files/test_config_map_to_map_low_memory_subbatch_nomask_nonormalize.yaml @@ -7,7 +7,7 @@ data: metadata_key: populations label_key: id ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt metadata: tests/data/Ground_truth/test_metadata_10.csv mask: do: false @@ -19,12 +19,11 @@ analysis: - bioem - fsc - res - chunk_size_submission: 80 - chunk_size_gt: 190 - chunk_size_low_memory: 10 + chunk_size_submission: 4 + chunk_size_gt: 5 low_memory: do: true - chunk_size_low_memory: 10 + chunk_size_low_memory: 2 normalize: do: false method: dummy-string diff --git a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml index 2a9a3a8..a8a4f09 100644 --- a/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml +++ b/tests/config_files/test_config_map_to_map_nomask_nonormalize.yaml @@ -7,7 +7,7 @@ data: metadata_key: populations label_key: id ground_truth: - volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy + volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt metadata: tests/data/Ground_truth/test_metadata_10.csv mask: do: false @@ -19,8 +19,8 @@ analysis: - bioem - fsc - res - chunk_size_submission: 80 - chunk_size_gt: 190 + chunk_size_submission: 4 + chunk_size_gt: 5 low_memory: do: false chunk_size_low_memory: null