Skip to content

Commit

Permalink
remove numpy dependency. use torch mmap
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffwoollard committed Sep 18, 2024
1 parent faf7e29 commit 9792ed3
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 48 deletions.
24 changes: 2 additions & 22 deletions src/cryo_challenge/_map_to_map/map_to_map_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,6 @@
from typing_extensions import override
import mrcfile
import numpy as np
from torch.utils.data import Dataset


class GT_Dataset(Dataset):
    """Dataset that serves samples from a memory-mapped ``.npy`` file.

    The array is opened through ``np.load(..., mmap_mode="r")`` so it is not
    read into memory up front. Each ``__getitem__`` call copies the requested
    sample(s) into memory and returns them as a ``torch.Tensor``.

    Attributes:
        npy_file: Path that was passed to the constructor.
        data: The read-only memory-mapped array.
        shape: Shape of the full array.
    """

    def __init__(self, npy_file):
        """Open ``npy_file`` as a read-only memory map.

        Args:
            npy_file: Path to a ``.npy`` file.
        """
        self.npy_file = npy_file
        # Read-only mapping: this class never writes to the array, so "r"
        # avoids needing write permission on the file and rules out
        # accidental in-place modification of the data on disk.
        self.data = np.load(npy_file, mmap_mode="r")

        self.shape = self.data.shape
        self._dim = len(self.data.shape)

    def dim(self):
        """Return the number of dimensions of the full array."""
        return self._dim

    def __len__(self):
        """Return the size of the first axis of the array."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample(s) at ``idx`` as an in-memory tensor.

        Args:
            idx: Anything numpy indexing accepts, or a tensor of indices
                (converted with ``tolist()`` before indexing).

        Returns:
            A ``torch.Tensor`` holding a copy of the selected data.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = self.data[idx]
        # Copy so the tensor owns writable memory that is detached from the
        # read-only file mapping.
        return torch.from_numpy(sample.copy())


def normalize(maps, method):
Expand Down Expand Up @@ -108,6 +86,8 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):
distance_matrix[idxs] = sub_distance_matrix

else:
maps1 = maps1.reshape(len(maps1), -1)
maps2 = maps2.reshape(len(maps2), -1)
distance_matrix = torch.vmap(
lambda maps1: torch.vmap(
lambda maps2: self.get_distance(maps1, maps2),
Expand Down
13 changes: 3 additions & 10 deletions src/cryo_challenge/_map_to_map/map_to_map_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import numpy as np
import pandas as pd
import pickle
import torch

from ..data._validation.output_validators import MapToMapResultsValidator
from .._map_to_map.map_to_map_distance import (
GT_Dataset,
FSCDistance,
Correlation,
L2DistanceNorm,
Expand Down Expand Up @@ -36,8 +34,6 @@ def run(config):

do_low_memory_mode = config["analysis"]["low_memory"]["do"]

n_pix = config["data"]["n_pix"]

submission = torch.load(config["data"]["submission"]["fname"])
submission_volume_key = config["data"]["submission"]["volume_key"]
submission_metadata_key = config["data"]["submission"]["metadata_key"]
Expand All @@ -55,12 +51,9 @@ def run(config):
maps_user_flat = submission[submission_volume_key].reshape(
len(submission["volumes"]), -1
)
if do_low_memory_mode:
maps_gt_flat = GT_Dataset(config["data"]["ground_truth"]["volumes"])
else:
maps_gt_flat = torch.from_numpy(
np.load(config["data"]["ground_truth"]["volumes"])
).reshape(-1, n_pix**3)
maps_gt_flat = torch.load(
config["data"]["ground_truth"]["volumes"], mmap=do_low_memory_mode
)

computed_assets = {}
for distance_label, map_to_map_distance in map_to_map_distances.items():
Expand Down
6 changes: 3 additions & 3 deletions tests/config_files/test_config_map_to_map.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ data:
metadata_key: populations
label_key: id
ground_truth:
volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy
volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt
metadata: tests/data/Ground_truth/test_metadata_10.csv
mask:
do: true
Expand All @@ -19,8 +19,8 @@ analysis:
- bioem
- fsc
- res
chunk_size_submission: 80
chunk_size_gt: 190
chunk_size_submission: 4
chunk_size_gt: 5
low_memory:
do: false
chunk_size_low_memory: null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ data:
metadata_key: populations
label_key: id
ground_truth:
volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy
volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt
metadata: tests/data/Ground_truth/test_metadata_10.csv
mask:
do: true
Expand All @@ -19,12 +19,11 @@ analysis:
- bioem
- fsc
- res
chunk_size_submission: 80
chunk_size_gt: 190
chunk_size_low_memory: 10
chunk_size_submission: 4
chunk_size_gt: 5
low_memory:
do: true
chunk_size_low_memory: 10
chunk_size_low_memory: 2
normalize:
do: true
method: median_zscore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ data:
metadata_key: populations
label_key: id
ground_truth:
volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy
volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt
metadata: tests/data/Ground_truth/test_metadata_10.csv
mask:
do: false
Expand All @@ -19,12 +19,11 @@ analysis:
- bioem
- fsc
- res
chunk_size_submission: 80
chunk_size_gt: 190
chunk_size_low_memory: 10
chunk_size_submission: 4
chunk_size_gt: 5
low_memory:
do: true
chunk_size_low_memory: 10
chunk_size_low_memory: 2
normalize:
do: false
method: dummy-string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ data:
metadata_key: populations
label_key: id
ground_truth:
volumes: tests/data/Ground_truth/test_maps_gt_flat_10.npy
volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt
metadata: tests/data/Ground_truth/test_metadata_10.csv
mask:
do: false
Expand All @@ -19,8 +19,8 @@ analysis:
- bioem
- fsc
- res
chunk_size_submission: 80
chunk_size_gt: 190
chunk_size_submission: 4
chunk_size_gt: 5
low_memory:
do: false
chunk_size_low_memory: null
Expand Down

0 comments on commit 9792ed3

Please sign in to comment.