From 1fb2c98e75e0d12e3a96016b09e06fe7df40c7e2 Mon Sep 17 00:00:00 2001
From: Lenz Fiedler
Date: Tue, 30 Apr 2024 08:50:09 +0200
Subject: [PATCH] Testing if distributed samplers work as default

---
 mala/common/parameters.py | 33 ---------------------------------
 mala/network/trainer.py   | 41 +++++++++++++++++++----------------------
 2 files changed, 19 insertions(+), 55 deletions(-)

diff --git a/mala/common/parameters.py b/mala/common/parameters.py
index 6a431e04f..65523d048 100644
--- a/mala/common/parameters.py
+++ b/mala/common/parameters.py
@@ -1208,9 +1208,6 @@ def __init__(self):
         # Properties
         self.use_gpu = False
         self.use_ddp = False
-        self.use_distributed_sampler_train = True
-        self.use_distributed_sampler_val = True
-        self.use_distributed_sampler_test = True
         self.use_mpi = False
         self.verbosity = 1
         self.device = "cpu"
@@ -1300,36 +1297,6 @@ def use_ddp(self):
         """Control whether or not dd is used for parallel training."""
         return self._use_ddp
 
-    @property
-    def use_distributed_sampler_train(self):
-        """Control wether or not distributed sampler is used to distribute training data."""
-        return self._use_distributed_sampler_train
-
-    @use_distributed_sampler_train.setter
-    def use_distributed_sampler_train(self, value):
-        """Control whether or not distributed sampler is used to distribute training data."""
-        self._use_distributed_sampler_train = value
-
-    @property
-    def use_distributed_sampler_val(self):
-        """Control whether or not distributed sampler is used to distribute validation data."""
-        return self._use_distributed_sampler_val
-
-    @use_distributed_sampler_val.setter
-    def use_distributed_sampler_val(self, value):
-        """Control whether or not distributed sampler is used to distribute validation data."""
-        self._use_distributed_sampler_val = value
-
-    @property
-    def use_distributed_sampler_test(self):
-        """Control whether or not distributed sampler is used to distribute test data."""
-        return self._use_distributed_sampler_test
-
-    @use_distributed_sampler_test.setter
-    def use_distributed_sampler_test(self, value):
-        """Control whether or not distributed sampler is used to distribute test data."""
-        self._use_distributed_sampler_test = value
-
     @use_ddp.setter
     def use_ddp(self, value):
         if value:
diff --git a/mala/network/trainer.py b/mala/network/trainer.py
index a100d5c35..bb9d4d41b 100644
--- a/mala/network/trainer.py
+++ b/mala/network/trainer.py
@@ -652,36 +652,33 @@ def __prepare_to_train(self, optimizer_dict):
                 if self.data.parameters.use_lazy_loading:
                     do_shuffle = False
 
-                if self.parameters_full.use_distributed_sampler_train:
-                    self.train_sampler = (
-                        torch.utils.data.distributed.DistributedSampler(
-                            self.data.training_data_sets[0],
-                            num_replicas=dist.get_world_size(),
-                            rank=dist.get_rank(),
-                            shuffle=do_shuffle,
-                        )
+                self.train_sampler = (
+                    torch.utils.data.distributed.DistributedSampler(
+                        self.data.training_data_sets[0],
+                        num_replicas=dist.get_world_size(),
+                        rank=dist.get_rank(),
+                        shuffle=do_shuffle,
                     )
-                if self.parameters_full.use_distributed_sampler_val:
-                    self.validation_sampler = (
+                )
+                self.validation_sampler = (
+                    torch.utils.data.distributed.DistributedSampler(
+                        self.data.validation_data_sets[0],
+                        num_replicas=dist.get_world_size(),
+                        rank=dist.get_rank(),
+                        shuffle=False,
+                    )
+                )
+
+                if self.data.test_data_sets:
+                    self.test_sampler = (
                         torch.utils.data.distributed.DistributedSampler(
-                            self.data.validation_data_sets[0],
+                            self.data.test_data_sets[0],
                             num_replicas=dist.get_world_size(),
                             rank=dist.get_rank(),
                             shuffle=False,
                         )
                     )
-                if self.parameters_full.use_distributed_sampler_test:
-                    if self.data.test_data_sets:
-                        self.test_sampler = (
-                            torch.utils.data.distributed.DistributedSampler(
-                                self.data.test_data_sets[0],
-                                num_replicas=dist.get_world_size(),
-                                rank=dist.get_rank(),
-                                shuffle=False,
-                            )
-                        )
-
 
         # Instantiate the learning rate scheduler, if necessary.
         if self.parameters.learning_rate_scheduler == "ReduceLROnPlateau":
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
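
For reference, the samplers created in the trainer hunk are only half of the usual DDP data path: each sampler still has to be handed to a DataLoader, and it must be re-seeded every epoch so each rank sees a fresh shard ordering. Below is a minimal sketch of that standard PyTorch pattern; it does not reproduce MALA's actual DataLoader wiring, and names such as train_set, val_set, and num_epochs are placeholders. It also assumes the process group has already been initialized (e.g. via dist.init_process_group).

    import torch.distributed as dist
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    # One sampler per rank; it shards the dataset across dist.get_world_size() processes.
    train_sampler = DistributedSampler(
        train_set,                            # placeholder Dataset
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        shuffle=True,
    )
    val_sampler = DistributedSampler(val_set, shuffle=False)

    # When a sampler is supplied, the DataLoader's own shuffle flag must stay False.
    train_loader = DataLoader(train_set, batch_size=64, sampler=train_sampler)
    val_loader = DataLoader(val_set, batch_size=64, sampler=val_sampler)

    for epoch in range(num_epochs):
        # Re-seed the sampler so each epoch draws a different shard permutation;
        # without this call every epoch replays the same ordering on every rank.
        train_sampler.set_epoch(epoch)
        for batch in train_loader:
            ...  # forward/backward pass on this rank's shard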