Skip to content

Commit

Permalink
Merge pull request #14 from saturncloud/feat/local-rank-support
Browse files Browse the repository at this point in the history
feat/local rank support
  • Loading branch information
hhuuggoo authored Feb 10, 2021
2 parents d9a6e4f + a5b65f8 commit 76cf3a5
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 23 deletions.
87 changes: 64 additions & 23 deletions dask_pytorch_ddp/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,47 +3,88 @@
"""

import os
from typing import List, Callable, Tuple, Any
from typing import List, Callable, Any, Dict
from dask.distributed import Client
import torch.distributed as dist


def _get_worker_info(client: Client) -> Tuple[List[str], str]:
def _get_worker_info(client: Client) -> List[Dict]:
"""
returns a list of workers (sorted), and the DNS name for the master host
The master is the 0th worker's host
"""
workers = client.scheduler_info()["workers"]
worker_keys = sorted(workers.keys())
workers_by_host: Dict[str, List[str]] = {}
for key in worker_keys:
worker = workers[key]
host = worker["host"]
workers_by_host.setdefault(host, []).append(key)
host = workers[worker_keys[0]]["host"]
return worker_keys, host
all_workers = []
global_rank = 0
for host in sorted(workers_by_host.keys()):
local_rank = 0
for worker in workers_by_host[host]:
all_workers.append(
dict(
worker=worker,
local_rank=local_rank,
global_rank=global_rank,
host=host,
)
)
local_rank += 1
global_rank += 1
return all_workers


def run(
    client: Client,
    pytorch_function: Callable,
    *args,
    backend: str = "nccl",
    pass_local_rank: bool = False,
    **kwargs
):
    """
    Dispatch a pytorch function over a dask cluster, and return a list of
    futures for the resulting tasks.

    :param client: Dask client connected to the scheduler.
    :param pytorch_function: function to run on every worker.
    :param args: extra positional arguments forwarded to ``pytorch_function``.
    :param backend: ``torch.distributed`` backend name (default ``"nccl"``).
    :param pass_local_rank: when True, each worker's ``local_rank`` (its
        0-based index among the workers on the same host) is forwarded to
        ``pytorch_function`` as a keyword argument.
    :param kwargs: extra keyword arguments forwarded to ``pytorch_function``.
    :return: list of one future per worker, ordered by global rank.
    """
    all_workers = _get_worker_info(client)
    world_size = len(all_workers)
    port = 23456  # TODO: pick a free port?
    # Global rank 0 (the first worker on the first sorted host) is the
    # DDP master; every process connects to it at host:port.
    host = all_workers[0]["host"]

    futures = []
    for worker in all_workers:
        # Copy kwargs so the per-worker local_rank neither leaks between
        # submissions nor mutates the caller's dict.
        submit_kwargs = dict(kwargs)
        if pass_local_rank:
            submit_kwargs["local_rank"] = worker["local_rank"]
        fut = client.submit(
            dispatch_with_ddp,
            *args,
            pytorch_function=pytorch_function,
            master_addr=host,
            master_port=port,
            rank=worker["global_rank"],
            world_size=world_size,
            backend=backend,
            # workers= pins the task to this specific dask worker.
            workers=[worker["worker"]],
            **submit_kwargs
        )
        futures.append(fut)
    return futures


Expand Down
134 changes: 134 additions & 0 deletions tests/test_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,140 @@ def test_run():
assert output == fake_results


def test_run_with_local_rank_simple():
    """With one worker per host, every worker gets local_rank 0."""
    client = Mock()
    client.scheduler_info = Mock(return_value={"workers": workers})

    fake_pytorch_func = Mock()

    sorted_keys = sorted(workers.keys())
    fake_results = []
    for rank in range(len(sorted_keys)):
        fut = Mock()
        fut.result = Mock(return_value=rank)
        fake_results.append(fut)

    client.submit = Mock(side_effect=fake_results)
    output = run(client, fake_pytorch_func, pass_local_rank=True)

    # One submission per worker: global rank follows sorted key order,
    # local_rank is always 0 in this fixture.
    for rank, key in enumerate(sorted_keys):
        client.submit.assert_any_call(
            dispatch_with_ddp,
            pytorch_function=fake_pytorch_func,
            master_addr=host,
            master_port=23456,
            rank=rank,
            local_rank=0,
            world_size=len(workers),
            workers=[key],
            backend="nccl",
        )
    assert output == fake_results


def test_run_with_local_rank_complex():
    """With two workers per host, local_rank restarts at 0 on each host."""
    workers = {
        "tcp://1.2.3.4:8786": {"host": "1.2.3.4"},
        "tcp://1.2.3.4:8787": {"host": "1.2.3.4"},
        "tcp://3.2.3.4:8786": {"host": "3.2.3.4"},
        "tcp://3.2.3.4:8787": {"host": "3.2.3.4"},
    }
    worker_keys = sorted(workers.keys())
    master_host = workers[worker_keys[0]]["host"]

    client = Mock()
    client.scheduler_info = Mock(return_value={"workers": workers})

    fake_pytorch_func = Mock()

    fake_results = []
    for rank in range(len(worker_keys)):
        fut = Mock()
        fut.result = Mock(return_value=rank)
        fake_results.append(fut)

    client.submit = Mock(side_effect=fake_results)
    output = run(client, fake_pytorch_func, pass_local_rank=True)

    # Two hosts with two workers each: local ranks are 0,1 per host while
    # the global rank keeps counting across hosts.
    expected_local_ranks = [0, 1, 0, 1]
    for rank, (key, local_rank) in enumerate(zip(worker_keys, expected_local_ranks)):
        client.submit.assert_any_call(
            dispatch_with_ddp,
            pytorch_function=fake_pytorch_func,
            master_addr=master_host,
            master_port=23456,
            rank=rank,
            local_rank=local_rank,
            world_size=len(workers),
            workers=[key],
            backend="nccl",
        )
    assert output == fake_results


def test_dispatch_with_ddp():
pytorch_func = Mock()

Expand Down

0 comments on commit 76cf3a5

Please sign in to comment.