From a20921392c848cfe7d44e52ca4715cc7d091c530 Mon Sep 17 00:00:00 2001 From: Sarthak Pati Date: Wed, 21 Jul 2021 19:16:07 -0400 Subject: [PATCH 1/2] Update gandlf_data.py --- fets/data/pytorch/gandlf_data.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/fets/data/pytorch/gandlf_data.py b/fets/data/pytorch/gandlf_data.py index 99a6c08..564e3b5 100644 --- a/fets/data/pytorch/gandlf_data.py +++ b/fets/data/pytorch/gandlf_data.py @@ -20,6 +20,15 @@ from fets.data.gandlf_utils import get_dataframe_and_headers from fets.data import get_appropriate_file_paths_from_subject_dir +## added for reproducibility +def seed_worker(worker_id): + worker_seed = torch.initial_seed() % 2**32 + numpy.random.seed(worker_seed) + random.seed(worker_seed) + +g = torch.Generator() +g.manual_seed(0) +## added for reproducibility # adapted from https://codereview.stackexchange.com/questions/132914/crop-black-border-of-image-using-numpy/132933#132933 def crop_image_outside_zeros(array, psize): @@ -672,7 +681,7 @@ def get_loaders(self, data_frame, train, augmentations): preprocessing=self.preprocessing, in_memory=self.in_memory) if train: - loader = DataLoader(data, shuffle=True, batch_size=self.batch_size) + loader = DataLoader(data, shuffle=True, batch_size=self.batch_size, worker_init_fn=seed_worker) else: loader = DataLoader(data, shuffle=False, batch_size=1) From 9a1d9f78e717c9e093e963e4569e04fe9d520961 Mon Sep 17 00:00:00 2001 From: Sarthak Pati Date: Wed, 21 Jul 2021 19:50:21 -0400 Subject: [PATCH 2/2] Update gandlf_data.py --- fets/data/pytorch/gandlf_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fets/data/pytorch/gandlf_data.py b/fets/data/pytorch/gandlf_data.py index 564e3b5..e1ebb12 100644 --- a/fets/data/pytorch/gandlf_data.py +++ b/fets/data/pytorch/gandlf_data.py @@ -683,7 +683,7 @@ def get_loaders(self, data_frame, train, augmentations): if train: loader = DataLoader(data, shuffle=True, batch_size=self.batch_size, 
worker_init_fn=seed_worker) else: - loader = DataLoader(data, shuffle=False, batch_size=1) + loader = DataLoader(data, shuffle=False, batch_size=1, worker_init_fn=seed_worker) companion_loader = None if train: