diff --git a/mala/datahandling/data_handler.py b/mala/datahandling/data_handler.py
index dae111c0d..96d1dc6c0 100644
--- a/mala/datahandling/data_handler.py
+++ b/mala/datahandling/data_handler.py
@@ -640,6 +640,7 @@ def __build_datasets(self):
                     self.descriptor_calculator,
                     self.target_calculator,
                     self.use_ddp,
+                    self.parameters._configuration["device"],
                 )
             )
             self.validation_data_sets.append(
@@ -651,6 +652,7 @@ def __build_datasets(self):
                     self.descriptor_calculator,
                     self.target_calculator,
                     self.use_ddp,
+                    self.parameters._configuration["device"],
                 )
             )
 
@@ -664,6 +666,7 @@ def __build_datasets(self):
                         self.descriptor_calculator,
                         self.target_calculator,
                         self.use_ddp,
+                        self.parameters._configuration["device"],
                         input_requires_grad=True,
                     )
                 )
diff --git a/mala/datahandling/lazy_load_dataset.py b/mala/datahandling/lazy_load_dataset.py
index f37fdb60d..a3af4ab64 100644
--- a/mala/datahandling/lazy_load_dataset.py
+++ b/mala/datahandling/lazy_load_dataset.py
@@ -59,6 +59,7 @@ def __init__(
         descriptor_calculator,
         target_calculator,
         use_ddp,
+        device,
         input_requires_grad=False,
     ):
         self.snapshot_list = []
@@ -79,6 +80,7 @@ def __init__(
         self.use_ddp = use_ddp
         self.return_outputs_directly = False
         self.input_requires_grad = input_requires_grad
+        self.device = device
 
     @property
     def return_outputs_directly(self):
@@ -119,8 +121,15 @@ def mix_datasets(self):
         used_perm = torch.randperm(self.number_of_snapshots)
         barrier()
         if self.use_ddp:
-            used_perm = dist.broadcast(used_perm, 0)
-        self.snapshot_list = [self.snapshot_list[i] for i in used_perm]
+            # Tensor.to() is not in-place; keep the returned tensor.
+            used_perm = used_perm.to(device=self.device)
+            # broadcast() overwrites the tensor in place; do not rebind it.
+            dist.broadcast(used_perm, 0)
+            self.snapshot_list = [
+                self.snapshot_list[i] for i in used_perm.to("cpu")
+            ]
+        else:
+            self.snapshot_list = [self.snapshot_list[i] for i in used_perm]
         self.get_new_data(0)
 
     def get_new_data(self, file_index):