Testing if distributed samplers work as default
RandomDefaultUser committed Apr 30, 2024
1 parent 04b0050 · commit 1fb2c98
Showing 2 changed files with 19 additions and 55 deletions.
33 changes: 0 additions & 33 deletions mala/common/parameters.py
@@ -1208,9 +1208,6 @@ def __init__(self):
         # Properties
         self.use_gpu = False
         self.use_ddp = False
-        self.use_distributed_sampler_train = True
-        self.use_distributed_sampler_val = True
-        self.use_distributed_sampler_test = True
         self.use_mpi = False
         self.verbosity = 1
         self.device = "cpu"
@@ -1300,36 +1297,6 @@ def use_ddp(self):
         """Control whether or not DDP is used for parallel training."""
         return self._use_ddp
 
-    @property
-    def use_distributed_sampler_train(self):
-        """Control whether or not distributed sampler is used to distribute training data."""
-        return self._use_distributed_sampler_train
-
-    @use_distributed_sampler_train.setter
-    def use_distributed_sampler_train(self, value):
-        """Control whether or not distributed sampler is used to distribute training data."""
-        self._use_distributed_sampler_train = value
-
-    @property
-    def use_distributed_sampler_val(self):
-        """Control whether or not distributed sampler is used to distribute validation data."""
-        return self._use_distributed_sampler_val
-
-    @use_distributed_sampler_val.setter
-    def use_distributed_sampler_val(self, value):
-        """Control whether or not distributed sampler is used to distribute validation data."""
-        self._use_distributed_sampler_val = value
-
-    @property
-    def use_distributed_sampler_test(self):
-        """Control whether or not distributed sampler is used to distribute test data."""
-        return self._use_distributed_sampler_test
-
-    @use_distributed_sampler_test.setter
-    def use_distributed_sampler_test(self, value):
-        """Control whether or not distributed sampler is used to distribute test data."""
-        self._use_distributed_sampler_test = value
-
     @use_ddp.setter
     def use_ddp(self, value):
         if value:
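With the three use_distributed_sampler_* flags removed, enabling DDP is the only switch left on the user side; the trainer now builds the samplers unconditionally (the test sampler only if test data is present). A minimal usage sketch, not part of this commit and assuming the top-level mala.Parameters object exposes use_ddp as shown in the hunk above:

import mala

# Hypothetical sketch: after this commit, only use_ddp needs to be set.
parameters = mala.Parameters()
parameters.use_ddp = True
# The following flags no longer exist after this commit:
# parameters.use_distributed_sampler_train = True
# parameters.use_distributed_sampler_val = True
# parameters.use_distributed_sampler_test = True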
41 changes: 19 additions & 22 deletions mala/network/trainer.py
@@ -652,36 +652,33 @@ def __prepare_to_train(self, optimizer_dict):
             if self.data.parameters.use_lazy_loading:
                 do_shuffle = False
 
-            if self.parameters_full.use_distributed_sampler_train:
-                self.train_sampler = (
-                    torch.utils.data.distributed.DistributedSampler(
-                        self.data.training_data_sets[0],
-                        num_replicas=dist.get_world_size(),
-                        rank=dist.get_rank(),
-                        shuffle=do_shuffle,
-                    )
+            self.train_sampler = (
+                torch.utils.data.distributed.DistributedSampler(
+                    self.data.training_data_sets[0],
+                    num_replicas=dist.get_world_size(),
+                    rank=dist.get_rank(),
+                    shuffle=do_shuffle,
                 )
-            if self.parameters_full.use_distributed_sampler_val:
-                self.validation_sampler = (
+            )
+            self.validation_sampler = (
+                torch.utils.data.distributed.DistributedSampler(
+                    self.data.validation_data_sets[0],
+                    num_replicas=dist.get_world_size(),
+                    rank=dist.get_rank(),
+                    shuffle=False,
+                )
+            )
+
+            if self.data.test_data_sets:
+                self.test_sampler = (
                     torch.utils.data.distributed.DistributedSampler(
-                        self.data.validation_data_sets[0],
+                        self.data.test_data_sets[0],
                         num_replicas=dist.get_world_size(),
                         rank=dist.get_rank(),
                         shuffle=False,
                     )
                 )
-
-            if self.parameters_full.use_distributed_sampler_test:
-                if self.data.test_data_sets:
-                    self.test_sampler = (
-                        torch.utils.data.distributed.DistributedSampler(
-                            self.data.test_data_sets[0],
-                            num_replicas=dist.get_world_size(),
-                            rank=dist.get_rank(),
-                            shuffle=False,
-                        )
-                    )
 
             # Instantiate the learning rate scheduler, if necessary.
             if self.parameters.learning_rate_scheduler == "ReduceLROnPlateau":
                 self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
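For readers unfamiliar with the pattern in the hunk above, the sketch below shows in plain PyTorch (not MALA code) how a DistributedSampler built with num_replicas and rank is typically attached to a DataLoader and re-seeded per epoch. It assumes the process group has already been initialized, e.g. via torchrun and dist.init_process_group.

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Standalone illustration of the sampler pattern used above:
# each rank receives a disjoint shard of the dataset.
# Assumes dist.init_process_group(...) has already been called.
dataset = TensorDataset(torch.randn(1024, 8), torch.randn(1024, 1))
sampler = DistributedSampler(
    dataset,
    num_replicas=dist.get_world_size(),  # total number of ranks
    rank=dist.get_rank(),                # rank of this process
    shuffle=True,                        # corresponds to do_shuffle above
)
# shuffle must stay False in the DataLoader when a sampler is supplied.
loader = DataLoader(dataset, batch_size=32, sampler=sampler, shuffle=False)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reshuffle differently every epoch
    for inputs, targets in loader:
        pass  # training step would go here

One consequence worth noting: with the samplers created unconditionally, validation and test data are sharded across ranks as well, so per-rank metrics generally need to be reduced across ranks before reporting.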
