diff --git a/mala/common/parameters.py b/mala/common/parameters.py index 711d2aaa9..6a431e04f 100644 --- a/mala/common/parameters.py +++ b/mala/common/parameters.py @@ -1335,15 +1335,18 @@ def use_ddp(self, value): if value: print("initializing torch.distributed.") # JOSHR: - # We start up torch distributed here. As is fairly standard convention, we get the rank - # and world size arguments via environment variables (RANK, WORLD_SIZE). In addition to - # those variables, LOCAL_RANK, MASTER_ADDR and MASTER_PORT should be set. + # We start up torch distributed here. As is fairly standard + # convention, we get the rank and world size arguments via + # environment variables (RANK, WORLD_SIZE). In addition to + # those variables, LOCAL_RANK, MASTER_ADDR and MASTER_PORT + # should be set. rank = int(os.environ.get("RANK")) world_size = int(os.environ.get("WORLD_SIZE")) + dist.init_process_group("nccl", rank=rank, world_size=world_size) - # Invalidate, will be updated in setter. set_ddp_status(value) + # Invalidate, will be updated in setter. self.device = None self._use_ddp = value self.network._update_ddp(self.use_ddp)