51 reduce memory burden of pipeline #56

Closed · wants to merge 14 commits into from
1 change: 1 addition & 0 deletions .github/workflows/testing.yml
@@ -15,6 +15,7 @@ jobs:
     strategy:
       matrix:
         python-version: ["3.8", "3.9", "3.10", "3.11"]
+      fail-fast: false

     steps:
13 changes: 9 additions & 4 deletions config_files/config_svd.yaml
@@ -1,14 +1,19 @@
 path_to_volumes: /path/to/volumes
-box_size_ds: 32
+box_size_ds: 128
 submission_list: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
 experiment_mode: "all_vs_ref" # options are "all_vs_all", "all_vs_ref"

 # optional unless experiment_mode is "all_vs_ref"
-path_to_reference: /path/to/reference/volumes.pt
+reference_options:
+    path_to_reference: /path/to/reference/gt_maps.npy
+    n_volumes: 338 # optional, default is all volumes
+    random_subset: False # if False, the subset is chosen as volumes[::skip_vols, ...] to satisfy n_volumes

 dtype: "float32" # options are "float32", "float64"
 output_options:
     # path will be created if it does not exist
     output_path: /path/to/output
     # whether or not to save the processed volumes (downsampled, normalized, etc.)
-    save_volumes: True
+    save_volumes: False
     # whether or not to save the SVD matrices (U, S, V)
-    save_svd_matrices: True
+    save_svd_matrices: False
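For context, the new `reference_options` keys drive the reference-subset selection in `load_ref_vols` (see the `svd_io_utils.py` diff below). A minimal sketch of that logic, with the helper name `select_reference_indices` invented for illustration:

```python
import torch

def select_reference_indices(total_volumes, n_volumes=None, random_subset=False):
    # None means "use every reference volume", matching the validator default.
    if n_volumes is None:
        return torch.arange(total_volumes)
    if random_subset:
        # n_volumes indices drawn without replacement.
        return torch.randperm(total_volumes)[:n_volumes]
    # Evenly spaced subset, i.e. volumes[::skip_vols, ...][:n_volumes].
    skip_vols = total_volumes // n_volumes
    return torch.arange(n_volumes) * skip_vols

# e.g. 338 evenly spaced volumes out of a 1014-volume ground-truth stack:
print(select_reference_indices(1014, n_volumes=338)[:5])  # tensor([0, 3, 6, 9, 12])
```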
2 changes: 1 addition & 1 deletion src/cryo_challenge/_commands/run_preprocessing.py
@@ -5,7 +5,7 @@

 from ..data._validation.config_validators import validate_config_preprocessing
 from .._preprocessing.preprocessing_pipeline import preprocess_submissions
-from .._preprocessing.dataloader import SubmissionPreprocessingDataLoader
+from ..data._dataloaders.preproc_dataloader import SubmissionPreprocessingDataLoader


 def add_args(parser):
3 changes: 1 addition & 2 deletions src/cryo_challenge/_svd/svd_pipeline.py
@@ -179,8 +179,7 @@ def run_all_vs_ref_pipeline(config: dict):
     dtype = torch.float32 if config["dtype"] == "float32" else torch.float64

     ref_volumes, mean_volume = load_ref_vols(
-        box_size_ds=config["box_size_ds"],
-        path_to_volumes=config["path_to_reference"],
+        config,
         dtype=dtype,
     )
21 changes: 21 additions & 0 deletions src/cryo_challenge/data/_dataloaders/gt_dataloader.py
@@ -0,0 +1,21 @@
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+
+
+class GT_Dataset(Dataset):
+    def __init__(self, npy_file):
+        self.npy_file = npy_file
+        self.data = np.load(npy_file, mmap_mode="r")
+
+        self.shape = self.data.shape
+        self._dim = len(self.data.shape)
+
+    def dim(self):
+        return self._dim
+
+    def __getitem__(self, idx):
+        if torch.is_tensor(idx):
+            idx = idx.tolist()
+        sample = self.data[idx]
+        return torch.from_numpy(sample.copy())
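The memory saving comes from `np.load(..., mmap_mode="r")`: the full stack stays on disk and only the indexed slice is copied into RAM. A small self-contained usage sketch (the toy file name and sizes are made up):

```python
import numpy as np

# Build a tiny stand-in for the flattened ground-truth stack: 10 volumes of 16^3 voxels.
np.save("toy_gt.npy", np.random.rand(10, 16**3).astype(np.float32))

ds = GT_Dataset("toy_gt.npy")  # memory-maps the file; no volume data is read yet
print(ds.shape, ds.dim())      # (10, 4096) 2
vol = ds[3]                    # only volume 3 is materialized, as a torch.Tensor
print(vol.shape, vol.dtype)    # torch.Size([4096]) torch.float32
```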
@@ -151,6 +151,7 @@ def __getitem__(self, idx):
             glob.glob(os.path.join(self.submission_paths[idx], "*.mrc"))
         )
         vol_paths = [vol_path for vol_path in vol_paths if "mask" not in vol_path]
+        vol_paths = vol_paths[:3]

         assert len(vol_paths) > 0, "No volumes found in submission directory"
61 changes: 44 additions & 17 deletions src/cryo_challenge/data/_io/svd_io_utils.py
@@ -2,6 +2,7 @@
 from typing import Tuple

 from ..._preprocessing.fourier_utils import downsample_volume
+from ...data._dataloaders.gt_dataloader import GT_Dataset


 def load_volumes(
@@ -75,16 +76,14 @@ def load_volumes(
     return volumes, mean_volumes, metadata


-def load_ref_vols(box_size_ds: int, path_to_volumes: str, dtype=torch.float32):
+def load_ref_vols(config: dict, dtype=torch.float32):
     """
     Load the reference volumes, downsample them, normalize them, and remove the mean volume.

     Parameters
     ----------
-    box_size_ds: int
-        Size of the downsampled box.
-    path_to_volumes: str
-        Path to the file containing the reference volumes. Must be in PyTorch format.
+    config: dict
+        Dictionary containing the configuration parameters.
     dtype: torch.dtype
         Data type of the volumes.

@@ -99,27 +98,55 @@ def load_ref_vols(box_size_ds: int, path_to_volumes: str, dtype=torch.float32):
     >>> path_to_volumes = "/path/to/volumes.pt"
     >>> volumes_ds = load_ref_vols(box_size_ds, path_to_volumes)
     """  # noqa: E501
-    try:
-        volumes = torch.load(path_to_volumes)
-    except (FileNotFoundError, EOFError):
-        raise ValueError("Volumes not found or not in PyTorch format.")
+    path_to_volumes = config["reference_options"]["path_to_reference"]
+    box_size_ds = config["box_size_ds"]
+
+    volumes = GT_Dataset(path_to_volumes)
+
+    if config["reference_options"]["n_volumes"] is None:
+        n_vols = volumes.shape[0]
+        vol_skip = 1
+        random_subset = False
+
+    else:
+        n_vols = config["reference_options"]["n_volumes"]
+        vol_skip = volumes.shape[0] // n_vols
+        random_subset = config["reference_options"]["random_subset"]
+
+    if random_subset:
+        indices = torch.randperm(volumes.shape[0])[:n_vols]
+
+    else:
+        indices = torch.arange(0, n_vols) * vol_skip

     # Reshape volumes to correct size
     if volumes.dim() == 2:
-        box_size = int(round((float(volumes.shape[-1]) ** (1. / 3.))))
-        volumes = torch.reshape(volumes, (-1, box_size, box_size, box_size))
+        box_size = int(round((float(volumes.shape[-1]) ** (1.0 / 3.0))))
+        reshape = True

     elif volumes.dim() == 4:
         pass
     else:
-        raise ValueError(f"The shape of the volumes stored in {path_to_volumes} have the unexpected shape "
-                         f"{torch.shape}. Please, review the file and regenerate it so that volumes stored hasve the "
-                         f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size).")
+        raise ValueError(
+            f"The volumes stored in {path_to_volumes} have an unexpected shape. "
+            f"Please review the file and regenerate it so that the stored volumes have "
+            f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size)."
+        )

     volumes_ds = torch.empty(
-        (volumes.shape[0], box_size_ds, box_size_ds, box_size_ds), dtype=dtype
+        (n_vols, box_size_ds, box_size_ds, box_size_ds), dtype=dtype
     )
-    for i, vol in enumerate(volumes):
-        volumes_ds[i] = downsample_volume(vol, box_size_ds)
+    for i, idx in enumerate(indices):
+        vol = volumes[idx]
+        if reshape:
+            volumes_ds[i] = downsample_volume(
+                vol.reshape(box_size, box_size, box_size), box_size_ds
+            )
+
+        else:
+            volumes_ds[i] = downsample_volume(vol, box_size_ds)
+
+        volumes_ds[i] = volumes_ds[i] / volumes_ds[i].sum()

     mean_volume = volumes_ds.mean(dim=0)
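A usage sketch for the new signature, with a hypothetical config shaped like the one the validator accepts (the test `.npy` path comes from `fetch_test_data.sh` below):

```python
import torch

config = {
    "box_size_ds": 32,
    "reference_options": {
        "path_to_reference": "tests/data/Ground_truth/test_maps_gt_flat_10.npy",
        "n_volumes": None,       # validator default: keep all volumes
        "random_subset": False,  # validator default
    },
}
ref_volumes, mean_volume = load_ref_vols(config, dtype=torch.float32)
print(ref_volumes.shape)  # (n_vols, 32, 32, 32), downsampled and sum-normalized
```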
28 changes: 19 additions & 9 deletions src/cryo_challenge/data/_validation/config_validators.py
@@ -1,7 +1,7 @@
 from numbers import Number
 import pandas as pd
 import os
 from typing import List


 def validate_generic_config(config: dict, reference: dict) -> None:
     """
@@ -282,17 +282,27 @@ def validate_config_svd(config: dict) -> None:
     validate_config_svd_output(config["output_options"])

     if config["experiment_mode"] == "all_vs_ref":
-        if "path_to_reference" not in config.keys():
+        if "reference_options" not in config.keys():
             raise ValueError(
-                "Reference path is required for experiment mode 'all_vs_ref'"
+                "Reference options are required for experiment mode 'all_vs_ref'"
             )

         else:
-            assert isinstance(config["path_to_reference"], str)
-            os.path.exists(config["path_to_reference"])
-            assert (
-                "pt" in config["path_to_reference"]
-            ), "Reference path point to a .pt file"
+            keys_and_types_ref = {
+                "path_to_reference": str,
+            }
+            validate_generic_config(config["reference_options"], keys_and_types_ref)
+
+            assert isinstance(config["reference_options"]["path_to_reference"], str)
+            os.path.exists(config["reference_options"]["path_to_reference"])
+            assert (
+                "npy" in config["reference_options"]["path_to_reference"]
+            ), "Reference path must point to a .npy file"
+
+            if "n_volumes" not in config["reference_options"].keys():
+                config["reference_options"]["n_volumes"] = None
+
+            if "random_subset" not in config["reference_options"].keys():
+                config["reference_options"]["random_subset"] = False

     os.path.exists(config["path_to_volumes"])
     for submission in config["submission_list"]:
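The optional-key handling above is plain in-place defaulting; an equivalent sketch with `dict.setdefault` (not the code in the diff):

```python
def apply_reference_defaults(reference_options: dict) -> dict:
    # Mirrors validate_config_svd: missing keys fall back to "use everything".
    reference_options.setdefault("n_volumes", None)
    reference_options.setdefault("random_subset", False)
    return reference_options

opts = apply_reference_defaults({"path_to_reference": "gt_maps.npy"})
assert opts["n_volumes"] is None and opts["random_subset"] is False
```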
4 changes: 3 additions & 1 deletion tests/config_files/test_config_svd.yaml
@@ -3,7 +3,9 @@ box_size_ds: 32
 submission_list: [0]
 experiment_mode: "all_vs_ref" # options are "all_vs_all", "all_vs_ref"
 # optional unless experiment_mode is "all_vs_ref"
-path_to_reference: tests/data/Ground_truth/test_maps_gt_flat_10.pt
+reference_options:
+    path_to_reference: tests/data/Ground_truth/test_maps_gt_flat_10.npy

 dtype: "float32" # options are "float32", "float64"
 output_options:
     # path will be created if it does not exist
1 change: 1 addition & 0 deletions tests/scripts/fetch_test_data.sh
@@ -3,6 +3,7 @@ wget https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/dataset_2_s
 ADIR=$(pwd)
 ln -s $ADIR/tests/data/dataset_2_submissions/test_submission_0_n8.pt $ADIR/tests/data/dataset_2_submissions/submission_0.pt # symlink for svd which needs submission_0.pt for filename
 wget https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/Ground_truth/test_maps_gt_flat_10.pt?download=true -O tests/data/Ground_truth/test_maps_gt_flat_10.pt
+wget https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/Ground_truth/test_maps_gt_flat_10.npy?download=true -O tests/data/Ground_truth/test_maps_gt_flat_10.npy
 wget https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/Ground_truth/test_metadata_10.csv?download=true -O tests/data/Ground_truth/test_metadata_10.csv
 wget https://files.osf.io/v1/resources/8h6fz/providers/dropbox/tests/Ground_truth/1.mrc?download=true -O tests/data/Ground_truth/1.mrc
 wget https://files.osf.io/v1/resources/8h6fz/providers/dropbox/Ground_truth/mask_dilated_wide_224x224.mrc?download=true -O tests/data/Ground_truth/mask_dilated_wide_224x224.mrc
25 changes: 18 additions & 7 deletions tutorials/2_tutorial_svd.ipynb
@@ -62,7 +62,7 @@
     "# Select path to SVD config file\n",
     "# An example of this file is available in the path ../config_files/config_svd.yaml\n",
     "config_svd_path = FileChooser(os.path.expanduser(\"~\"))\n",
-    "config_svd_path.filter_pattern = '*.yaml'\n",
+    "config_svd_path.filter_pattern = \"*.yaml\"\n",
     "display(config_svd_path)"
    ]
   },
@@ -93,11 +93,22 @@
    "Here is a brief explanation of each key\n",
    "\n",
    "* path_to_volumes (str): this is the path to your submissions (the result of running the preprocessing). They should be called submission_0.pt, submission_1.pt, ...\n",
+   "\n",
    "* box_size_ds (int): you can choose to downsample the volumes to speed up the analysis, or to get rid of high frequency features.\n",
+   "\n",
    "* submission_list (List): here you can choose which submissions are used for the analysis. If you want to use submissions 0, 3, 6; then this should be [0, 3, 6]\n",
+   "\n",
    "* experiment_mode (str): the options are \"all_vs_all\", \"all_vs_ref\". If you are using ref, then the SVD is computed from the reference volumes and the rest of the volumes are projected onto it. Otherwise, all volumes are used to do the projection\n",
-   "* path_to_reference (str): path to the reference volumes (only needed if mode is \"all_vs_ref\")\n",
+   "\n",
+   "* reference_options (dict)\n",
+   "    * path_to_reference (str): path to the reference volumes (only needed if mode is \"all_vs_ref\"). Should be a .npy file that contains all the reference volumes, e.g., maps_flat.npy.\n",
+   "    * n_volumes (int): number of volumes to use for the analysis\n",
+   "    * random_subset (bool): whether to use a random subset or not\n",
+   "\n",
+   "    If you set `random_subset = True`, then the SVD is run on n_volumes randomly chosen volumes. Otherwise, the volumes are chosen as `skip_vols = total_volumes // n_volumes` and something equivalent to `volumes[::skip_vols, ...]`\n",
+   "\n",
    "* dtype (str): can be float32 or float64\n",
+   "\n",
    "* output_options (dict): dictionary with options to personalize the output\n",
    "    * output_path (str): where the volumes will be saved\n",
    "    * save_volumes (bool): whether or not to save the volumes used (this will save the normalized, downsampled, and mean-removed volumes)\n",
@@ -125,7 +136,7 @@
    "source": [
     "# Select path to SVD results\n",
     "svd_results_path = FileChooser(os.path.expanduser(\"~\"))\n",
-    "svd_results_path.filter_pattern = '*.pt'\n",
+    "svd_results_path.filter_pattern = \"*.pt\"\n",
     "display(svd_results_path)"
    ]
   },
@@ -316,7 +327,7 @@
    "source": [
     "# Select path to SVD results\n",
     "svd_all_vs_all_results_path = FileChooser(os.path.expanduser(\"~\"))\n",
-    "svd_all_vs_all_results_path.filter_pattern = '*.pt'\n",
+    "svd_all_vs_all_results_path.filter_pattern = \"*.pt\"\n",
     "display(svd_all_vs_all_results_path)"
    ]
   },
@@ -425,9 +436,9 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "gpucryonerf",
+   "display_name": "cryo-challenge-kernel",
    "language": "python",
-   "name": "python3"
+   "name": "cryo-challenge-kernel"
   },
   "language_info": {
    "codemirror_mode": {
@@ -439,7 +450,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.17"
+   "version": "3.10.10"
   }
  },
 "nbformat": 4,
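Pulling the documented keys together, an all_vs_ref config equivalent to the tutorial's description might look like this in Python (all paths are placeholders):

```python
svd_config = {
    "path_to_volumes": "/path/to/submissions",  # submission_0.pt, submission_1.pt, ...
    "box_size_ds": 32,
    "submission_list": [0, 3, 6],
    "experiment_mode": "all_vs_ref",
    "reference_options": {
        "path_to_reference": "/path/to/maps_flat.npy",
        "n_volumes": 100,
        "random_subset": False,  # evenly spaced: volumes[::total_volumes // 100]
    },
    "dtype": "float32",
    "output_options": {
        "output_path": "/path/to/output",
        "save_volumes": False,
        "save_svd_matrices": False,
    },
}
```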