Skip to content

Commit

Permalink
implement improved svd pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
DSilva27 committed Aug 28, 2024
1 parent ddb19f9 commit d4592c1
Show file tree
Hide file tree
Showing 8 changed files with 771 additions and 528 deletions.
36 changes: 19 additions & 17 deletions config_files/config_svd.yaml
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
path_to_volumes: /path/to/volumes
box_size_ds: 32
submission_list: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
experiment_mode: "all_vs_ref" # options are "all_vs_all", "all_vs_ref"
path_to_submissions: path/to/preprocessed/submissions/ # where all the submission_i.pt files are
#excluded_submissions: # you can exclude some submissions by filename, default = []
# - "submission_0.pt"
# - "submission_1.pt"
voxel_size: 1.0 # voxel size of the input maps (will probably be removed soon)

power_spectrum_normalization:
ref_vol_key: "FLAVOR" # which submission should be used
ref_vol_index: 0 # which volume of that submission should be used
dtype: float32 # optional, default = float32
svd_max_rank: 5 # optional, default = full rank svd
normalize_params: # optional, if not given there will be no normalization
mask_path: path/to/mask.mrc # default = None, no masking applied
bfactor: 170 # default = None, no bfactor applied
box_size_ds: 16 # default = None, no downsampling applied

# optional unless experiment_mode is "all_vs_ref"
path_to_reference: /path/to/reference/volumes.pt
dtype: "float32" # options are "float32", "float64"
output_options:
# the parent directory of the output file will be created if it does not exist
output_file: /path/to/output_file.pt
# whether or not to save the processed volumes (downsampled, normalized, etc.)
save_volumes: False
# whether or not to save the SVD matrices (U, S, V)
save_svd_matrices: False
gt_params: # optional, if provided there will be extra results
gt_vols_file: path/to/gt_volumes.npy # volumes must be in .npy format (for memory reasons)
skip_vols: 1 # default = 1, no volumes skipped. Equivalent to volumes[::skip_vols]

output_params:
output_file: path/to/output_file.pt # where the results will be saved
save_svd_data: True # optional, default = False
generate_plots: True # optional, default = False
24 changes: 15 additions & 9 deletions src/cryo_challenge/_commands/run_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import os
import yaml

from .._svd.svd_pipeline import run_all_vs_all_pipeline, run_all_vs_ref_pipeline
from ..data._validation.config_validators import validate_config_svd
from .._svd.svd_pipeline import run_svd_noref, run_svd_with_ref
from ..data._validation.config_validators import SVDConfig


def add_args(parser):
Expand Down Expand Up @@ -35,15 +35,21 @@ def main(args):
with open(args.config, "r") as file:
config = yaml.safe_load(file)

validate_config_svd(config)
warnexists(config["output_options"]["output_file"])
mkbasedir(os.path.dirname(config["output_options"]["output_file"]))
config = SVDConfig(**config).dict()

if config["experiment_mode"] == "all_vs_all":
run_all_vs_all_pipeline(config)
warnexists(config["output_params"]["output_file"])
mkbasedir(os.path.dirname(config["output_params"]["output_file"]))

elif config["experiment_mode"] == "all_vs_ref":
run_all_vs_ref_pipeline(config)
output_path = os.path.dirname(config["output_params"]["output_file"])

with open(os.path.join(output_path, "config.yaml"), "w") as file:
yaml.dump(config, file)

if config["gt_params"] is None:
run_svd_noref(config)

else:
run_svd_with_ref(config)

return

Expand Down
Loading

0 comments on commit d4592c1

Please sign in to comment.