diff --git a/src/cryo_challenge/_preprocessing/dataloader.py b/src/cryo_challenge/_preprocessing/dataloader.py index 4cc75d6..27ca57a 100644 --- a/src/cryo_challenge/_preprocessing/dataloader.py +++ b/src/cryo_challenge/_preprocessing/dataloader.py @@ -25,7 +25,11 @@ class SubmissionPreprocessingDataLoader(Dataset): def __init__(self, submission_config): self.submission_config = submission_config - self.submission_paths, self.gt_path = self.extract_submission_paths() + self.validate_submission_config() + + self.submission_paths, self.population_files, self.gt_path = ( + self.extract_submission_paths() + ) self.subs_index = [int(idx) for idx in list(self.submission_config.keys())[1:]] path_to_gt_ref = os.path.join( self.gt_path, self.submission_config["gt"]["ref_align_fname"] @@ -53,32 +57,40 @@ def validate_submission_config(self): raise ValueError("Box size not found for ground truth") if "pixel_size" not in value.keys(): raise ValueError("Pixel size not found for ground truth") + if "ref_align_fname" not in value.keys(): + raise ValueError( + "Reference align file name not found for ground truth" + ) continue else: if "path" not in value.keys(): raise ValueError(f"Path not found for submission {key}") - if "id" not in value.keys(): - raise ValueError(f"ID not found for submission {key}") + if "name" not in value.keys(): + raise ValueError(f"Name not found for submission {key}") if "box_size" not in value.keys(): raise ValueError(f"Box size not found for submission {key}") if "pixel_size" not in value.keys(): raise ValueError(f"Pixel size not found for submission {key}") if "align" not in value.keys(): raise ValueError(f"Align not found for submission {key}") + if "populations_file" not in value.keys(): + raise ValueError(f"Population file not found for submission {key}") if "flip" not in value.keys(): raise ValueError(f"Flip not found for submission {key}") - + if "submission_version" not in value.keys(): + raise ValueError( + f"Submission version not found for submission {key}" + ) if not os.path.exists(value["path"]): raise ValueError(f"Path {value['path']} does not exist") if not os.path.isdir(value["path"]): raise ValueError(f"Path {value['path']} is not a directory") - ids = list(self.submission_config.keys())[1:] - if ids != list(range(len(ids))): - raise ValueError( - "Submission IDs should be integers starting from 0 and increasing by 1" - ) + if not os.path.exists(value["populations_file"]): + raise ValueError( + f"Population file {value['populations_file']} does not exist" + ) return @@ -137,13 +149,16 @@ def help(cls): def extract_submission_paths(self): submission_paths = [] + population_files = [] for key, value in self.submission_config.items(): if key == "gt": gt_path = value["path"] else: submission_paths.append(value["path"]) - return submission_paths, gt_path + population_files.append(value["populations_file"]) + + return submission_paths, population_files, gt_path def __len__(self): return len(self.submission_paths) @@ -153,13 +168,9 @@ def __getitem__(self, idx): glob.glob(os.path.join(self.submission_paths[idx], "*.mrc")) ) vol_paths = [vol_path for vol_path in vol_paths if "mask" not in vol_path] - assert len(vol_paths) > 0, "No volumes found in submission directory" - populations = np.loadtxt( - os.path.join(self.submission_paths[idx], "populations.txt") - ) - populations = torch.from_numpy(populations) + populations = torch.from_numpy(np.loadtxt(self.population_files[idx])) vol0 = mrcfile.open(vol_paths[0], mode="r") volumes = torch.zeros( diff --git a/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py b/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py index 926b2c1..589239c 100644 --- a/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py +++ b/src/cryo_challenge/_preprocessing/preprocessing_pipeline.py @@ -119,10 +119,21 @@ def preprocess_submissions(submission_dataset, config): # save preprocessed volumes print(" Saving preprocessed submission") + submission_version = submission_dataset.submission_config[str(idx)][ + "submission_version" + ] + if str(submission_version) == "0": + submission_version = "" + else: + submission_version = f" {submission_version}" + print(f" SUBMISSIION VERSION {submission_version}") + submission_id = ice_cream_flavors[random_mapping[idx]] + submission_version + print(f"SUBMISSION ID {submission_id}") + save_submission( volumes, submission_dataset[i]["populations"], - ice_cream_flavors[random_mapping[idx]], + submission_id, idx, config, ) diff --git a/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json b/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json index 354060e..5671514 100644 --- a/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json +++ b/tests/data/unprocessed_dataset_2_submissions/submission_x/submission_config.json @@ -12,6 +12,8 @@ "box_size": 244, "pixel_size": 2.146, "path": "tests/data/unprocessed_dataset_2_submissions/submission_x", - "flip": 1 + "flip": 1, + "populations_file": "tests/data/unprocessed_dataset_2_submissions/submission_x/populations.txt", + "submission_version": "1.0" } }