# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import subprocess
from dataclasses import dataclass

import torch.distributed as dist


@dataclass(frozen=True, repr=True, eq=True, unsafe_hash=True)
class DistributedContext:
    """Immutable description of the current distributed environment."""

    is_distributed: bool
    rank: int
    local_rank: int
    world_size: int
    mode: str

    @property
    def is_leader(self) -> bool:
        return self.rank == 0
def init_distributed_context(port: int) -> DistributedContext:
    # Sometimes the nccl backend hangs on the barrier op
    # (https://github.com/pytorch/pytorch/issues/53658). Since the barrier is
    # the only op we care about here, we use the gloo backend instead.
    BACKEND = "gloo"

    # Default, non-distributed context.
    context = DistributedContext(
        is_distributed=False, rank=0, local_rank=0, world_size=1, mode="none"
    )

    launch_keys = ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK", "LOCAL_RANK"]
    slurm_keys = [
        "SLURM_LOCALID",
        "SLURM_PROCID",
        "SLURM_NTASKS",
        "SLURM_NODEID",
        "SLURM_JOB_NODELIST",
    ]

    # Is it torch.distributed.launch?
    if all(key in os.environ for key in launch_keys):
        init_method = "env://"
        world_size = int(os.environ["WORLD_SIZE"])
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        context = DistributedContext(
            is_distributed=True,
            rank=rank,
            world_size=world_size,
            local_rank=local_rank,
            mode="launch",
        )
        dist.init_process_group(
            backend=BACKEND, init_method=init_method, world_size=world_size, rank=rank
        )
    # Is it Slurm?
    elif all(key in os.environ for key in slurm_keys):
        init_method = "env://"
        local_rank = int(os.environ["SLURM_LOCALID"])
        rank = int(os.environ["SLURM_PROCID"])
        world_size = int(os.environ["SLURM_NTASKS"])
        # Use the first host in the job's node list as the rendezvous address.
        hostnames = subprocess.check_output(
            ["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]]
        )
        leader_addr = hostnames.split()[0].decode("utf-8")
        os.environ["MASTER_ADDR"] = leader_addr
        os.environ["MASTER_PORT"] = str(port)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["RANK"] = str(rank)
        context = DistributedContext(
            is_distributed=True,
            rank=rank,
            local_rank=local_rank,
            world_size=world_size,
            mode="slurm",
        )
        dist.init_process_group(
            backend=BACKEND,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
        )
    return context
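

# A minimal usage sketch, assuming the process is launched in a way that sets
# the torch.distributed.launch or Slurm environment variables checked above.
# The port value 29500 and the barrier call are illustrative choices, not part
# of the function's contract.
if __name__ == "__main__":
    context = init_distributed_context(port=29500)
    if context.is_leader:
        # Only rank 0 reports the resolved context.
        print(f"distributed context: {context}")
    if context.is_distributed:
        # Synchronize all ranks before continuing, e.g. after rank-0-only setup.
        dist.barrier()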