Skip to content

Commit

Permalink
This should fix lazy loading mixing
Browse files Browse the repository at this point in the history
  • Loading branch information
RandomDefaultUser committed May 2, 2024
1 parent f49e63d commit e1753d0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
3 changes: 3 additions & 0 deletions mala/datahandling/data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ def __build_datasets(self):
self.descriptor_calculator,
self.target_calculator,
self.use_ddp,
self.parameters._configuration["device"]
)
)
self.validation_data_sets.append(
Expand All @@ -651,6 +652,7 @@ def __build_datasets(self):
self.descriptor_calculator,
self.target_calculator,
self.use_ddp,
self.parameters._configuration["device"]
)
)

Expand All @@ -664,6 +666,7 @@ def __build_datasets(self):
self.descriptor_calculator,
self.target_calculator,
self.use_ddp,
self.parameters._configuration["device"]
input_requires_grad=True,
)
)
Expand Down
9 changes: 8 additions & 1 deletion mala/datahandling/lazy_load_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
descriptor_calculator,
target_calculator,
use_ddp,
device,
input_requires_grad=False,
):
self.snapshot_list = []
Expand All @@ -79,6 +80,7 @@ def __init__(
self.use_ddp = use_ddp
self.return_outputs_directly = False
self.input_requires_grad = input_requires_grad
self.device = device

@property
def return_outputs_directly(self):
Expand Down Expand Up @@ -119,8 +121,13 @@ def mix_datasets(self):
used_perm = torch.randperm(self.number_of_snapshots)
barrier()
if self.use_ddp:
used_perm.to(device=self.device)
used_perm = dist.broadcast(used_perm, 0)
self.snapshot_list = [self.snapshot_list[i] for i in used_perm]
self.snapshot_list = [
self.snapshot_list[i] for i in used_perm.to("cpu")
]
else:
self.snapshot_list = [self.snapshot_list[i] for i in used_perm]
self.get_new_data(0)

def get_new_data(self, file_index):
Expand Down

0 comments on commit e1753d0

Please sign in to comment.