Skip to content

Commit

Permalink
Finally worked out the unit test for dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
skirmer committed Nov 18, 2020
1 parent 8882106 commit 0da5996
Showing 1 changed file with 25 additions and 24 deletions.
49 changes: 25 additions & 24 deletions tests/test_dispatch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os

from unittest.mock import Mock, patch

from dask_pytorch.dispatch import run, dispatch_with_ddp
Expand Down Expand Up @@ -32,39 +33,39 @@ def test_run():

client.submit.assert_any_call(
dispatch_with_ddp,
fake_pytorch_func,
host,
23456,
0,
len(workers),
"nccl",
pytorch_function=fake_pytorch_func,
master_addr=host,
master_port=23456,
rank=0,
world_size=len(workers),
backend="nccl",
)
client.submit.assert_any_call(
dispatch_with_ddp,
fake_pytorch_func,
host,
23456,
1,
len(workers),
"nccl",
pytorch_function=fake_pytorch_func,
master_addr=host,
master_port=23456,
rank=1,
world_size=len(workers),
backend="nccl",
)
client.submit.assert_any_call(
dispatch_with_ddp,
fake_pytorch_func,
host,
23456,
2,
len(workers),
"nccl",
pytorch_function=fake_pytorch_func,
master_addr=host,
master_port=23456,
rank=2,
world_size=len(workers),
backend="nccl",
)
client.submit.assert_any_call(
dispatch_with_ddp,
fake_pytorch_func,
host,
23456,
3,
len(workers),
"nccl",
pytorch_function=fake_pytorch_func,
master_addr=host,
master_port=23456,
rank=3,
world_size=len(workers),
backend="nccl",
)
assert output == fake_results

Expand Down

0 comments on commit 0da5996

Please sign in to comment.