add fast launch method based on openmpi (#69)
msftsw authored Dec 25, 2021
1 parent 87c92a7 commit e7d165f
Showing 5 changed files with 153 additions and 77 deletions.
10 changes: 6 additions & 4 deletions README.md
@@ -32,13 +32,15 @@ How to setup Tutel MoE for Pytorch:
* Run Tutel MoE in Distributed Mode:
(Single-Node Multi-GPU based on standard Pytorch distributed launcher:)
$ python3 -m torch.distributed.launch --nproc_per_node=8 -m tutel.examples.helloworld --batch_size=16
(Method A - Torch launcher for `Multi-Node x Multi-GPU`:)
$ ssh <node-ip-0> python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=<node-ip-0> -m tutel.examples.helloworld --batch_size=16
$ ssh <node-ip-1> python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=<node-ip-0> -m tutel.examples.helloworld --batch_size=16
(Method B - Tutel launcher for `Multi-Node x Multi-GPU`, requiring package `openmpi-bin`:)
$ mpiexec -host <node-ip-0>,<node-ip-1>,.. \
-x LOCAL_SIZE=8 -x MASTER_ADDR=<node-ip-0> \
python3 -m tutel.launcher.run -m tutel.examples.helloworld --batch_size=16
```
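For orientation, here is a minimal worker sketch (not part of this commit) that any of the launch commands above could start in place of `tutel.examples.helloworld`. It leans on the `init_data_model_parallel` helper from the `system_init.py` diff below to consume the launcher-provided environment (`MASTER_ADDR`, `OMPI_COMM_WORLD_*` or `LOCAL_RANK`) and to pick the local CUDA device:

```
# Hypothetical worker module; the helper below is defined in tutel/system_init.py.
from tutel import system_init

parallel_env = system_init.init_data_model_parallel(backend='nccl')
parallel_env.dist_print('world size =', parallel_env.global_size,
                        'local device =', parallel_env.local_device)
```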

How to import Tutel-optimized MoE in Pytorch:
3 changes: 3 additions & 0 deletions tutel/launcher/__init__.py
@@ -0,0 +1,3 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

42 changes: 42 additions & 0 deletions tutel/launcher/execl.py
@@ -0,0 +1,42 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os, re, sys
import logging
import argparse

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', default=False, action='store_true')
    parser.add_argument('rest', nargs=argparse.REMAINDER)
    args = parser.parse_args()

    local_rank = int(os.environ['LOCAL_RANK'])
    local_size = int(os.environ['LOCAL_SIZE'])

    os.environ['TUTEL_CUDA_SANDBOX'] = '1'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank)

    cmd_args = []
    try:
        if not os.path.exists('/usr/bin/numactl'):
            raise
        local_size = int(os.environ['LOCAL_SIZE'])
        cpu_nodes = sorted([str(x[4:]) for x in os.listdir('/sys/devices/system/node') if re.match('node[0-9]+', x)])
        if len(cpu_nodes) <= local_size:
            sel_nodes = cpu_nodes[(local_rank // (local_size // len(cpu_nodes))) % len(cpu_nodes)]
        else:
            sel_nodes = cpu_nodes[local_rank::local_size]
        sel_nodes = ','.join(sel_nodes)

        cmd_args = ['/usr/bin/numactl', '--cpunodebind=%s' % sel_nodes]
    except Exception as ex:
        if local_rank == 0:
            logging.warning('`numactl` is not enabled by tutel.launcher.execl')

    cmd_args += [sys.executable, '-m'] if args.m else []
    cmd_args += args.rest
    os.execl(cmd_args[0], *cmd_args)

if __name__ == "__main__":
    main()
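Before the final `os.execl`, the launcher tries to bind each local rank to NUMA nodes through `numactl --cpunodebind`: with at least as many ranks as NUMA nodes, blocks of `local_size // num_nodes` consecutive ranks share one node; with more nodes than ranks, each rank receives a strided slice of the nodes. A standalone sketch of that selection rule, fed with hypothetical node lists instead of `/sys/devices/system/node`:

```
# Standalone sketch of the NUMA-node selection above, with made-up inputs.
def select_numa_nodes(local_rank, local_size, cpu_nodes):
    if len(cpu_nodes) <= local_size:
        # More ranks than nodes: pack consecutive ranks onto the same node.
        sel = cpu_nodes[(local_rank // (local_size // len(cpu_nodes))) % len(cpu_nodes)]
    else:
        # Fewer ranks than nodes: give each rank a strided slice of nodes.
        sel = cpu_nodes[local_rank::local_size]
    return ','.join(sel)

print([select_numa_nodes(r, 8, ['0', '1']) for r in range(8)])
# -> ['0', '0', '0', '0', '1', '1', '1', '1']
print([select_numa_nodes(r, 2, ['0', '1', '2', '3']) for r in range(2)])
# -> ['0,2', '1,3']
```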
25 changes: 25 additions & 0 deletions tutel/launcher/run.py
@@ -0,0 +1,25 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os, sys

def main():
    host_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    host_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    local_size = int(os.environ.get('LOCAL_SIZE', 1))

    master_addr = os.environ['MASTER_ADDR'] if host_size > 1 else 'localhost'
    master_port = int(os.environ.get('MASTER_PORT', 23232))

    cmd_args = [sys.executable, '-m', 'torch.distributed.launch', '--use_env',
        '--nproc_per_node=%d' % local_size,
        '--nnodes=%d' % host_size,
        '--node_rank=%d' % host_rank,
        '--master_addr=%s' % master_addr,
        '--master_port=%s' % master_port,
        '-m', 'tutel.launcher.execl',
    ] + sys.argv[1:]
    os.execl(cmd_args[0], *cmd_args)

if __name__ == "__main__":
    main()
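To make the indirection concrete, the following sketch (with made-up host counts and addresses) prints the `torch.distributed.launch` command line that `tutel.launcher.run` would assemble from the OpenMPI environment before replacing itself via `os.execl`:

```
# Sketch of the command built by tutel.launcher.run; the env values are examples.
import sys

env = {
    'OMPI_COMM_WORLD_SIZE': '2',   # two hosts in the mpiexec host list
    'OMPI_COMM_WORLD_RANK': '1',   # this process runs on the second host
    'LOCAL_SIZE': '8',             # exported via `-x LOCAL_SIZE=8`
    'MASTER_ADDR': '10.0.0.1',     # exported via `-x MASTER_ADDR=<node-ip-0>`
}

cmd = [sys.executable, '-m', 'torch.distributed.launch', '--use_env',
       '--nproc_per_node=%s' % env['LOCAL_SIZE'],
       '--nnodes=%s' % env['OMPI_COMM_WORLD_SIZE'],
       '--node_rank=%s' % env['OMPI_COMM_WORLD_RANK'],
       '--master_addr=%s' % env['MASTER_ADDR'],
       '--master_port=23232',
       '-m', 'tutel.launcher.execl',
       '-m', 'tutel.examples.helloworld', '--batch_size=16']
print(' '.join(cmd))
```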
150 changes: 77 additions & 73 deletions tutel/system_init.py
@@ -25,77 +25,81 @@ def init_affinity_at_program_beginning():
logging.warning('Failed to set NUMA status: %s' % ex)

def init_data_model_parallel(group_count=1, backend='nccl'):
    import torch
    import torch.distributed as dist
    try:
        if ('LOCAL_RANK' not in os.environ) and ('OMPI_COMM_WORLD_SIZE' in os.environ):
            dist.init_process_group(backend=backend,
                init_method='tcp://%s:%s' % (os.environ['MASTER_ADDR'], os.environ.get('MASTER_PORT', '23456')),
                rank=int(os.environ['OMPI_COMM_WORLD_RANK']), world_size=int(os.environ['OMPI_COMM_WORLD_SIZE']))
            dist_local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        else:
            dist.init_process_group(backend=backend)
            dist_local_rank = int(os.environ.get('LOCAL_RANK', 0))
        if TUTEL_CUDA_SANDBOX:
            dist_local_rank = 0
        glob_world_size, glob_world_rank = dist.get_world_size(), dist.get_rank()
        is_distributed = True

        def dist_print(*args):
            if glob_world_rank == 0:
                print(*args)
    except ValueError:
        glob_world_size, glob_world_rank, dist_local_rank = 1, 0, 0
        is_distributed = False
        dist_print = print

    assert glob_world_size % group_count == 0, f"Expected to evenly divide devices into {group_count} groups, while the world size of current session is {glob_world_size}."

    dist_group_size = group_count
    dist_world_size = glob_world_size // dist_group_size
    dist_world_rank = glob_world_rank % dist_world_size
    dist_group_rank = glob_world_rank // dist_world_size

    if is_distributed:
        global_group = model_group = data_group = dist.group.WORLD

        if dist_group_size != glob_world_size:
            groups, inner_ranks = [], []
            for gr in range(dist_group_size):
                group_ranks = [x for x in range(gr * dist_world_size, (gr + 1) * dist_world_size)]
                groups += [dist.new_group(ranks=group_ranks)]
                inner_ranks += [group_ranks]
            model_group = groups[dist_group_rank]

        if dist_world_size != glob_world_size:
            groups, outer_ranks = [], []
            for gr in range(dist_world_size):
                group_ranks = [x for x in range(gr, dist_world_size * dist_group_size, dist_world_size)]
                groups += [dist.new_group(ranks=group_ranks)]
                outer_ranks += [group_ranks]
            data_group = groups[dist_world_rank]
    else:
        model_group, data_group, global_group = None, None, None

    result = init_data_model_parallel
    result.global_size = glob_world_size
    result.global_rank = glob_world_rank
    result.group_count = dist_group_size
    result.data_rank = dist_group_rank
    result.model_rank = dist_world_rank

    if backend == 'nccl':
        result.local_device = torch.device('cuda', dist_local_rank)
        torch.cuda.set_device(result.local_device)
    else:
        result.local_device = torch.device('cpu')

    result.data_group = data_group
    result.model_group = model_group
    result.global_group = global_group

    result.is_distributed = is_distributed
    result.dist_print = dist_print

    # Temp workaround for: https://github.com/pytorch/pytorch/issues/56390
    import atexit
    atexit.register(lambda *args: os._exit(0))

    logging.critical(f'Registering device global rank {result.global_rank}: data_rank = {result.data_rank}, model_rank = {result.model_rank}')
    return result
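As a quick check on the grouping arithmetic above: each of the `group_count` model groups is a contiguous block of `glob_world_size // group_count` ranks, and the data groups stride across those blocks. A small sketch with illustrative sizes (8 ranks, `group_count=2`):

```
# Illustrative rank bookkeeping mirroring init_data_model_parallel.
glob_world_size, group_count = 8, 2
dist_world_size = glob_world_size // group_count       # ranks per model group -> 4

for rank in range(glob_world_size):
    model_rank = rank % dist_world_size                # position inside its model group
    data_rank = rank // dist_world_size                # which model group it belongs to
    print(f'rank {rank}: model_rank={model_rank}, data_rank={data_rank}')

model_groups = [list(range(g * dist_world_size, (g + 1) * dist_world_size)) for g in range(group_count)]
data_groups = [list(range(g, glob_world_size, dist_world_size)) for g in range(dist_world_size)]
print(model_groups)   # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(data_groups)    # [[0, 4], [1, 5], [2, 6], [3, 7]]
```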
