Skip to content

Commit

Permalink
implement submission version and volume flipping
Browse files Browse the repository at this point in the history
  • Loading branch information
DSilva27 committed Aug 6, 2024
1 parent f80dc6a commit eb5bb62
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 16 deletions.
41 changes: 26 additions & 15 deletions src/cryo_challenge/_preprocessing/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ class SubmissionPreprocessingDataLoader(Dataset):

def __init__(self, submission_config):
self.submission_config = submission_config
self.submission_paths, self.gt_path = self.extract_submission_paths()
self.validate_submission_config()

self.submission_paths, self.population_files, self.gt_path = (
self.extract_submission_paths()
)
self.subs_index = [int(idx) for idx in list(self.submission_config.keys())[1:]]
path_to_gt_ref = os.path.join(
self.gt_path, self.submission_config["gt"]["ref_align_fname"]
Expand Down Expand Up @@ -53,32 +57,40 @@ def validate_submission_config(self):
raise ValueError("Box size not found for ground truth")
if "pixel_size" not in value.keys():
raise ValueError("Pixel size not found for ground truth")
if "ref_align_fname" not in value.keys():
raise ValueError(
"Reference align file name not found for ground truth"
)
continue
else:
if "path" not in value.keys():
raise ValueError(f"Path not found for submission {key}")
if "id" not in value.keys():
raise ValueError(f"ID not found for submission {key}")
if "name" not in value.keys():
raise ValueError(f"Name not found for submission {key}")
if "box_size" not in value.keys():
raise ValueError(f"Box size not found for submission {key}")
if "pixel_size" not in value.keys():
raise ValueError(f"Pixel size not found for submission {key}")
if "align" not in value.keys():
raise ValueError(f"Align not found for submission {key}")
if "populations_file" not in value.keys():
raise ValueError(f"Population file not found for submission {key}")
if "flip" not in value.keys():
raise ValueError(f"Flip not found for submission {key}")

if "submission_version" not in value.keys():
raise ValueError(
f"Submission version not found for submission {key}"
)
if not os.path.exists(value["path"]):
raise ValueError(f"Path {value['path']} does not exist")

if not os.path.isdir(value["path"]):
raise ValueError(f"Path {value['path']} is not a directory")

ids = list(self.submission_config.keys())[1:]
if ids != list(range(len(ids))):
raise ValueError(
"Submission IDs should be integers starting from 0 and increasing by 1"
)
if not os.path.exists(value["populations_file"]):
raise ValueError(
f"Population file {value['populations_file']} does not exist"
)

return

Expand Down Expand Up @@ -137,13 +149,16 @@ def help(cls):

def extract_submission_paths(self):
submission_paths = []
population_files = []
for key, value in self.submission_config.items():
if key == "gt":
gt_path = value["path"]

else:
submission_paths.append(value["path"])
return submission_paths, gt_path
population_files.append(value["populations_file"])

return submission_paths, population_files, gt_path

def __len__(self):
return len(self.submission_paths)
Expand All @@ -153,13 +168,9 @@ def __getitem__(self, idx):
glob.glob(os.path.join(self.submission_paths[idx], "*.mrc"))
)
vol_paths = [vol_path for vol_path in vol_paths if "mask" not in vol_path]

assert len(vol_paths) > 0, "No volumes found in submission directory"

populations = np.loadtxt(
os.path.join(self.submission_paths[idx], "populations.txt")
)
populations = torch.from_numpy(populations)
populations = torch.from_numpy(np.loadtxt(self.population_files[idx]))

vol0 = mrcfile.open(vol_paths[0], mode="r")
volumes = torch.zeros(
Expand Down
13 changes: 12 additions & 1 deletion src/cryo_challenge/_preprocessing/preprocessing_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,21 @@ def preprocess_submissions(submission_dataset, config):

# save preprocessed volumes
print(" Saving preprocessed submission")
submission_version = submission_dataset.submission_config[str(idx)][
"submission_version"
]
if str(submission_version) == "0":
submission_version = ""
else:
submission_version = f" {submission_version}"
print(f" SUBMISSIION VERSION {submission_version}")
submission_id = ice_cream_flavors[random_mapping[idx]] + submission_version
print(f"SUBMISSION ID {submission_id}")

save_submission(
volumes,
submission_dataset[i]["populations"],
ice_cream_flavors[random_mapping[idx]],
submission_id,
idx,
config,
)
Expand Down

0 comments on commit eb5bb62

Please sign in to comment.