Skip to content

Commit

Permalink
Merge pull request #74 from flatironinstitute/adding_flavor_keys
Browse files Browse the repository at this point in the history
added flavor keys
  • Loading branch information
DSilva27 authored Aug 9, 2024
2 parents e3a5655 + 4a24b6d commit fc99c96
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 37 deletions.
1 change: 0 additions & 1 deletion config_files/config_preproc.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
submission_config_file: submission_config.json
seed_flavor_assignment: 0
thresh_percentile: 93.0
BOT_box_size: 32
BOT_loss: wemd
Expand Down
44 changes: 11 additions & 33 deletions src/cryo_challenge/_preprocessing/preprocessing_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch
import numpy as np
import json
import os

Expand Down Expand Up @@ -40,44 +39,20 @@ def save_submission(volumes, populations, submission_id, submission_index, confi


def preprocess_submissions(submission_dataset, config):
np.random.seed(config["seed_flavor_assignment"])
ice_cream_flavors = [
"Chocolate",
"Vanilla",
"Cookies N' Cream",
"Mint Chocolate Chip",
"Strawberry",
"Butter Pecan",
"Salted Caramel",
"Pistachio",
"Rocky Road",
"Coffee",
"Cookie Dough",
"Chocolate Chip",
"Neapolitan",
"Cherry",
"Rainbow Sherbet",
"Peanut Butter",
"Cotton Candy",
"Lemon Sorbet",
"Mango",
"Black Raspberry",
]

n_subs = max(submission_dataset.subs_index) + 1
random_mapping = np.random.choice(len(ice_cream_flavors), n_subs, replace=False)
hash_table = {}

box_size_gt = submission_dataset.submission_config["gt"]["box_size"]
pixel_size_gt = submission_dataset.submission_config["gt"]["pixel_size"]
vol_gt_ref = submission_dataset.vol_gt_ref

for i in range(len(submission_dataset)):
idx = submission_dataset.subs_index[i]

hash_table[submission_dataset.submission_config[str(idx)]["name"]] = (
ice_cream_flavors[random_mapping[idx]]
)
sub_flavor = submission_dataset.submission_config[str(idx)]["flavor_name"]
sub_name = submission_dataset.submission_config[str(idx)]["name"]
hash_table[sub_flavor] = {
"name": sub_name,
"filename": f"submission_{idx}.pt",
}

print(f"Preprocessing submission {idx}...")

Expand Down Expand Up @@ -126,8 +101,11 @@ def preprocess_submissions(submission_dataset, config):
submission_version = ""
else:
submission_version = f" {submission_version}"
print(f" SUBMISSIION VERSION {submission_version}")
submission_id = ice_cream_flavors[random_mapping[idx]] + submission_version
print(f" SUBMISSION VERSION {submission_version}")
submission_id = (
submission_dataset.submission_config[str(idx)]["flavor_name"]
+ submission_version
)
print(f"SUBMISSION ID {submission_id}")

save_submission(
Expand Down
3 changes: 1 addition & 2 deletions src/cryo_challenge/data/_validation/config_validators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from numbers import Number
import pandas as pd
import os
from typing import List


def validate_generic_config(config: dict, reference: dict) -> None:
"""
Expand Down Expand Up @@ -48,7 +48,6 @@ def validate_config_preprocessing(config_data: dict) -> None:
"BOT_loss": str,
"BOT_iter": Number,
"BOT_refine": bool,
"seed_flavor_assignment": int,
}
validate_generic_config(config_data, keys_and_types)
return
Expand Down
1 change: 0 additions & 1 deletion tests/config_files/test_config_preproc.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
submission_config_file: tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json
seed_flavor_assignment: 0
thresh_percentile: 93.0
BOT_box_size: 32
BOT_loss: wemd
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"0": {
"name": "raw_submission_in_testdata",
"align": 1,
"flavor_name": "test flavor",
"box_size": 244,
"pixel_size": 2.146,
"path": "tests/data/unprocessed_dataset_2_submissions/submission_x",
Expand Down

0 comments on commit fc99c96

Please sign in to comment.