Skip to content

Commit

Permalink
Merge branch 'dev' into zernike_distance
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffwoollard committed Dec 19, 2024
2 parents de37bf0 + 4602aea commit 3cc6707
Show file tree
Hide file tree
Showing 20 changed files with 1,757 additions and 669 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .
pip install pytest omegaconf
pip install ".[dev]"
- name: Test with pytest
run: |
Expand Down
35 changes: 21 additions & 14 deletions config_files/config_svd.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
path_to_volumes: /path/to/volumes
box_size_ds: 32
submission_list: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
experiment_mode: "all_vs_ref" # options are "all_vs_all", "all_vs_ref"
# optional unless experiment_mode is "all_vs_ref"
path_to_reference: /path/to/reference/volumes.pt
dtype: "float32" # options are "float32", "float64"
output_options:
# path will be created if it does not exist
output_path: /path/to/output
# whether or not to save the processed volumes (downsampled, normalized, etc.)
save_volumes: True
# whether or not to save the SVD matrices (U, S, V)
save_svd_matrices: True
path_to_submissions: path/to/preprocessed/submissions/ # where all the submission_i.pt files are
#excluded_submissions: # you can exclude some submissions by filename, default = []
# - "submission_0.pt"
# - "submission_1.pt"
voxel_size: 1.0 # voxel size of the input maps (will probably be removed soon)

dtype: float32 # optional, default = float32
svd_max_rank: 5 # optional, default = full rank svd
normalize_params: # optional, if not given there will be no normalization
mask_path: path/to/mask.mrc # default = None, no masking applied
bfactor: 170 # default = None, no bfactor applied
box_size_ds: 16 # default = None, no downsampling applied

gt_params: # optional, if provided there will be extra results
gt_vols_file: path/to/gt_volumes.npy # volumes must be in .npy format (memory stuff)
skip_vols: 1 # default = 1, no volumes skipped. Equivalent to volumes[::skip_vols]

output_params:
output_file: path/to/output_file.pt # where the results will be saved
save_svd_data: True # optional, default = False
generate_plots: True # optional, default = False
110 changes: 110 additions & 0 deletions figure_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""
Some random code that I have found to be useful for plotting figures.
This should become part of the main repo at some point, I will leave it out for now.
- David
"""

from natsort import natsorted

# Here is how I generate the general dictionary parameter for plots
COLORS = {
"Coffee": "#97b4ff",
"Salted Caramel": "#97b4ff",
"Neapolitan": "#648fff",
"Peanut Butter": "#1858ff",
"Cherry": "#b3a4f7",
"Pina Colada": "#8c75f2",
"Chocolate": "#785ef0",
"Cookie Dough": "#512fec",
"Chocolate Chip": "#3d18e9",
"Vanilla": "#e35299",
"Mango": "#dc267f",
"Black Raspberry": "#ff8032",
"Rocky Road": "#fe6100",
"Ground Truth": "#ffb000",
"Mint Chocolate Chip": "#ffb000",
"Bubble Gum": "#ffb000",
}

PLOT_SETUP = {
"Salted Caramel": {"category": "1", "marker": "o"},
"Neapolitan": {"category": "1", "marker": "v"},
"Peanut Butter": {"category": "1", "marker": "^"},
"Coffee": {"category": "1", "marker": "<"},
"Cherry": {"category": "2", "marker": "o"},
"Pina Colada": {"category": "2", "marker": "v"},
"Cookie Dough": {"category": "2", "marker": "^"},
"Chocolate Chip": {"category": "2", "marker": "<"},
"Chocolate": {"category": "2", "marker": ">"},
"Vanilla": {"category": "3", "marker": "o"},
"Mango": {"category": "3", "marker": "v"},
"Rocky Road": {"category": "4", "marker": "o"},
"Black Raspberry": {"category": "4", "marker": "v"},
"Ground Truth": {"category": "5", "marker": "o"},
"Bubble Gum": {"category": "5", "marker": "v"},
"Mint Chocolate Chip": {"category": "5", "marker": "^"},
}

for key in list(PLOT_SETUP.keys()):
# PLOT_SETUP[key]["color"] = COLORS[PLOT_SETUP[key]["category"]]
PLOT_SETUP[key]["color"] = COLORS[key]


# These two functions are useful when setting the order of how to plot figures
def compare_strings(fixed_string, other_string):
return other_string.startswith(fixed_string)


def sort_labels_category(labels, plot_setup):
labels_sorted = []
for i in range(5): # there are 5 categories
for label in labels:
if plot_setup[label]["category"] == str(i + 1):
labels_sorted.append(label)

return labels_sorted


labels = ... # get labels from somwhere (pipeline results for example)

# This is the particular plot_setup for your data
plot_setup = {}
for i, label in enumerate(labels):
for (
possible_label
) in PLOT_SETUP.keys(): # generalized for labels like FLAVOR 1, FLAVOR 2, etc.
# print(label, possible_label)
if compare_strings(possible_label, label):
plot_setup[label] = PLOT_SETUP[possible_label]

for label in labels:
if label not in plot_setup.keys():
raise ValueError(f"Label {label} not found in PLOT_SETUP")

labels = sort_labels_category(natsorted(labels), plot_setup)


# Then I do something like this, which let's me configure how the
# labels will be displayed in the plot
labels_for_plot = {
"Neapolitan": "Neapolitan R1",
"Neapolitan 2": "Neapolitan R2",
"Peanut Butter": "Peanut Butter R1",
"Peanut Butter 2": "Peanut Butter R2",
"Salted Caramel": "Salted Caramel R1",
"Salted Caramel 2": "Salted Caramel R2 1",
"Salted Caramel 3": "Salted Caramel R2 2",
"Chocolate": "Chocolate R1",
"Chocolate 2": "Chocolate R2",
"Chocolate Chip": "Chocolate Chip R1",
"Cookie Dough": "Cookie Dough R1",
"Cookie Dough 2": "Cookie Dough R2",
"Pina Colada 1": "Piña Colada R2",
"Mango": "Mango R1",
"Vanilla": "Vanilla R1",
"Vanilla 2": "Vanilla R2",
"Black Raspberry": "Black Raspberry R1",
"Black Raspberry 2": "Black Raspberry R2",
"Rocky Road": "Rocky Road R1",
}
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ dependencies = [
"osfclient",
"seaborn",
"ipyfilechooser",
"omegaconf"
"omegaconf",
"pydantic",
"ecos"
]

[project.optional-dependencies]
Expand Down
24 changes: 15 additions & 9 deletions src/cryo_challenge/_commands/run_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import os
import yaml

from .._svd.svd_pipeline import run_all_vs_all_pipeline, run_all_vs_ref_pipeline
from ..data._validation.config_validators import validate_config_svd
from .._svd.svd_pipeline import run_svd_noref, run_svd_with_ref
from ..data._validation.config_validators import SVDConfig


def add_args(parser):
Expand Down Expand Up @@ -35,15 +35,21 @@ def main(args):
with open(args.config, "r") as file:
config = yaml.safe_load(file)

validate_config_svd(config)
warnexists(config["output_options"]["output_path"])
mkbasedir(config["output_options"]["output_path"])
config = SVDConfig(**config).model_dump()

if config["experiment_mode"] == "all_vs_all":
run_all_vs_all_pipeline(config)
warnexists(config["output_params"]["output_file"])
mkbasedir(os.path.dirname(config["output_params"]["output_file"]))

elif config["experiment_mode"] == "all_vs_ref":
run_all_vs_ref_pipeline(config)
output_path = os.path.dirname(config["output_params"]["output_file"])

with open(os.path.join(output_path, "config.yaml"), "w") as file:
yaml.dump(config, file)

if config["gt_params"] is None:
run_svd_noref(config)

else:
run_svd_with_ref(config)

return

Expand Down
22 changes: 19 additions & 3 deletions src/cryo_challenge/_preprocessing/preprocessing_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,22 @@ def save_submission(volumes, populations, submission_id, submission_index, confi
return submission_dict


def update_hash_table(hash_table_path, hash_table):
if os.path.exists(hash_table_path):
with open(hash_table_path, "r") as f:
hash_table_old = json.load(f)
hash_table_old.update(hash_table)

with open(hash_table_path, "w") as f:
json.dump(hash_table_old, f, indent=4)

else:
with open(hash_table_path, "w") as f:
json.dump(hash_table, f, indent=4)

return


def preprocess_submissions(submission_dataset, config):
hash_table = {}
box_size_gt = submission_dataset.submission_config["gt"]["box_size"]
Expand Down Expand Up @@ -79,7 +95,7 @@ def preprocess_submissions(submission_dataset, config):
volumes = threshold_submissions(volumes, config["thresh_percentile"])

# center submission
print(" Centering submission")
# print(" Centering submission")
# volumes = center_submission(volumes, pixel_size=pixel_size_gt)

# flip handedness
Expand Down Expand Up @@ -121,7 +137,7 @@ def preprocess_submissions(submission_dataset, config):
hash_table_path = os.path.join(
config["output_path"], "submission_to_icecream_table.json"
)
with open(hash_table_path, "w") as f:
json.dump(hash_table, f, indent=4)

update_hash_table(hash_table_path, hash_table)

return
Loading

0 comments on commit 3cc6707

Please sign in to comment.