Skip to content

Commit

Permalink
PT distributed ctx cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Nov 28, 2023
1 parent d573b2b commit ed7487a
Showing 1 changed file with 13 additions and 18 deletions.
31 changes: 13 additions & 18 deletions returnn/torch/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@ class DistributedContext:
This class sets up some helper functions for torch distributed training
"""

def __init__(self, config):
"""
:param Config config:
"""
def __init__(self, config: Config):
import torch.distributed as dist

# when no backend is specified, both gloo and nccl backends will be created
Expand All @@ -30,8 +27,8 @@ def __init__(self, config):
dist.init_process_group(backend=None)

self._config = config
self._local_rank = os.environ["LOCAL_RANK"]
self._local_size = os.environ["LOCAL_WORLD_SIZE"]
self._local_rank = int(os.environ["LOCAL_RANK"])
self._local_size = int(os.environ["LOCAL_WORLD_SIZE"])
self._rank = dist.get_rank()
self._size = dist.get_world_size()

Expand All @@ -40,22 +37,20 @@ def __init__(self, config):
% (socket.gethostname(), os.getpid(), self._rank, self._size, self._local_rank, self._local_size)
)

def local_rank(self):
"""
:rtype: int
"""
def local_rank(self) -> int:
"""local rank"""
return self._local_rank

def rank(self):
"""
:rtype: int
"""
def local_size(self) -> int:
"""local size"""
return self._local_size

def rank(self) -> int:
"""global rank"""
return self._rank

def size(self):
"""
:rtype: int
"""
def size(self) -> int:
"""global size"""
return self._size


Expand Down

0 comments on commit ed7487a

Please sign in to comment.