From ed7487ad8a3ab05aaa8f18394b23975dccaa2561 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 28 Nov 2023 20:25:40 +0100
Subject: [PATCH] PT distributed ctx cleanup

---
 returnn/torch/distributed.py | 31 +++++++++++++------------------
 1 file changed, 13 insertions(+), 18 deletions(-)

diff --git a/returnn/torch/distributed.py b/returnn/torch/distributed.py
index 2d1dd2d470..e9f084d9a9 100644
--- a/returnn/torch/distributed.py
+++ b/returnn/torch/distributed.py
@@ -18,10 +18,7 @@ class DistributedContext:
     This class setups some helper functions for torch distributed training
     """
 
-    def __init__(self, config):
-        """
-        :param Config config:
-        """
+    def __init__(self, config: Config):
         import torch.distributed as dist
 
         # when no backend is specified, both gloo and nccl backends will be created
@@ -30,8 +27,8 @@ def __init__(self, config):
         dist.init_process_group(backend=None)
 
         self._config = config
-        self._local_rank = os.environ["LOCAL_RANK"]
-        self._local_size = os.environ["LOCAL_WORLD_SIZE"]
+        self._local_rank = int(os.environ["LOCAL_RANK"])
+        self._local_size = int(os.environ["LOCAL_WORLD_SIZE"])
         self._rank = dist.get_rank()
         self._size = dist.get_world_size()
 
@@ -40,22 +37,20 @@ def __init__(self, config):
             % (socket.gethostname(), os.getpid(), self._rank, self._size, self._local_rank, self._local_size)
         )
 
-    def local_rank(self):
-        """
-        :rtype: int
-        """
+    def local_rank(self) -> int:
+        """local rank"""
         return self._local_rank
 
-    def rank(self):
-        """
-        :rtype: int
-        """
+    def local_size(self) -> int:
+        """local size"""
+        return self._local_size
+
+    def rank(self) -> int:
+        """global rank"""
         return self._rank
 
-    def size(self):
-        """
-        :rtype: int
-        """
+    def size(self) -> int:
+        """global size"""
         return self._size
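
Not part of the patch, just for illustration: a rough standalone sketch of what the cleaned-up context now does with the torchrun-provided environment. The variable names below are illustrative, not RETURNN API. Converting the env values to int (instead of keeping the raw str) matters because callers typically pass the local rank straight into things like torch.cuda.set_device().

    # Illustrative sketch only -- mirrors the behavior of the patched DistributedContext,
    # assuming the process was launched via torchrun (which sets the env vars used below).
    import os
    import socket

    import torch
    import torch.distributed as dist

    dist.init_process_group(backend=None)  # backend=None: torch picks gloo and/or nccl

    local_rank = int(os.environ["LOCAL_RANK"])        # int(), as in the patch
    local_size = int(os.environ["LOCAL_WORLD_SIZE"])  # previously these stayed str
    rank, size = dist.get_rank(), dist.get_world_size()

    print(
        "host %s proc %i: rank %i/%i, local rank %i/%i"
        % (socket.gethostname(), os.getpid(), rank, size, local_rank, local_size)
    )

    if torch.cuda.is_available():
        # this is where the int conversion pays off: set_device() expects an int, not a str
        torch.cuda.set_device(local_rank)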