Skip to content

Commit

Permalink
Merge pull request #3 from saturncloud/skirmer/linting
Browse files Browse the repository at this point in the history
fixing unit tests and linting for deployment
  • Loading branch information
skirmer authored Nov 18, 2020
2 parents b790537 + 0da5996 commit d9d1645
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 26 deletions.
8 changes: 7 additions & 1 deletion dask_pytorch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,13 @@ class S3ImageFolder(Dataset):
An image folder that lives in S3. Directories containing the image are classes.
"""

def __init__(self, s3_bucket: str, s3_prefix: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
def __init__(
self,
s3_bucket: str,
s3_prefix: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
):
self.s3_bucket = s3_bucket
self.s3_prefix = s3_prefix
self.all_files = _list_all_files(s3_bucket, s3_prefix)
Expand Down
27 changes: 18 additions & 9 deletions dask_pytorch/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,32 @@ def run(client: Client, pytorch_function: Callable, *args, **kwargs):
futures = [
client.submit(
dispatch_with_ddp,
pytorch_function = pytorch_function,
master_addr = host,
master_port = port,
rank = idx,
world_size = world_size,
backend = "nccl",
*args,
pytorch_function=pytorch_function,
master_addr=host,
master_port=port,
rank=idx,
world_size=world_size,
backend="nccl",
*args,
**kwargs
)
for idx, w in enumerate(worker_keys)
]

return futures


# pylint: disable=keyword-arg-before-vararg
# pylint: disable=too-many-arguments
def dispatch_with_ddp(
pytorch_function: Callable, master_addr: str, master_port: int, rank: int, world_size: int, backend: str = "nccl", *args, **kwargs
pytorch_function: Callable,
master_addr: Any,
master_port: Any,
rank: Any,
world_size: Any,
backend: str = "nccl",
*args,
**kwargs
) -> Any:
"""
runs a pytorch function, setting up torch.distributed before execution
Expand Down
4 changes: 3 additions & 1 deletion dask_pytorch/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def _get_results(self, futures: List[Future], raise_errors: bool = True):
raise
futures = result.not_done

def process_results(self, prefix: str, futures: List[Future], raise_errors: bool = True) -> None:
def process_results(
self, prefix: str, futures: List[Future], raise_errors: bool = True
) -> None:
"""
Process the intermediate results:
result objects will be dictionaries of the form {'path': path, 'data': data}
Expand Down
23 changes: 13 additions & 10 deletions tests/test_data.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from unittest.mock import Mock, patch, ANY


from dask_pytorch.data import BOTOS3ImageFolder
from dask_pytorch.data import S3ImageFolder


def test_image_folder_constructor():
fake_file_list = ["d/a.jpg", "c/b.jpg"]
with patch("dask_pytorch.data.get_all_files", return_value=fake_file_list):
with patch("dask_pytorch.data._list_all_files", return_value=fake_file_list):
fake_transform = Mock()
fake_target_transform = Mock()
folder = BOTOS3ImageFolder(
"fake-bucket", "fake-prefix/fake-prefix", fake_transform, fake_target_transform
folder = S3ImageFolder(
"fake-bucket",
"fake-prefix/fake-prefix",
fake_transform,
fake_target_transform,
)
assert folder.all_files == fake_file_list
assert folder.classes == ["c", "d"]
Expand All @@ -21,17 +24,17 @@ def test_image_folder_constructor():

def test_image_folder_len():
fake_file_list = ["d/a.jpg", "c/b.jpg"]
with patch("dask_pytorch.data.get_all_files", return_value=fake_file_list):
folder = BOTOS3ImageFolder("fake-bucket", "fake-prefix/fake-prefix")
with patch("dask_pytorch.data._list_all_files", return_value=fake_file_list):
folder = S3ImageFolder("fake-bucket", "fake-prefix/fake-prefix")
assert len(folder) == 2


def test_image_folder_getitem():
fake_file_list = ["d/a.jpg", "c/b.jpg"]
with patch("dask_pytorch.data.get_all_files", return_value=fake_file_list):
folder = BOTOS3ImageFolder("fake-bucket", "fake-prefix/fake-prefix")
with patch("dask_pytorch.data.read_s3_fileobj") as read_s3_fileobj, patch(
"dask_pytorch.data.load_image_obj"
with patch("dask_pytorch.data._list_all_files", return_value=fake_file_list):
folder = S3ImageFolder("fake-bucket", "fake-prefix/fake-prefix")
with patch("dask_pytorch.data._read_s3_fileobj") as read_s3_fileobj, patch(
"dask_pytorch.data._load_image_obj"
) as load_image_obj:

read_s3_fileobj.return_value = Mock()
Expand Down
45 changes: 40 additions & 5 deletions tests/test_dispatch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os

from unittest.mock import Mock, patch

from dask_pytorch.dispatch import run, dispatch_with_ddp
Expand Down Expand Up @@ -31,16 +32,40 @@ def test_run():
output = run(client, fake_pytorch_func)

client.submit.assert_any_call(
dispatch_with_ddp, fake_pytorch_func, host, 23456, 0, len(workers), workers=[worker_keys[0]]
dispatch_with_ddp,
pytorch_function=fake_pytorch_func,
master_addr=host,
master_port=23456,
rank=0,
world_size=len(workers),
backend="nccl",
)
client.submit.assert_any_call(
dispatch_with_ddp, fake_pytorch_func, host, 23456, 1, len(workers), workers=[worker_keys[1]]
dispatch_with_ddp,
pytorch_function=fake_pytorch_func,
master_addr=host,
master_port=23456,
rank=1,
world_size=len(workers),
backend="nccl",
)
client.submit.assert_any_call(
dispatch_with_ddp, fake_pytorch_func, host, 23456, 2, len(workers), workers=[worker_keys[2]]
dispatch_with_ddp,
pytorch_function=fake_pytorch_func,
master_addr=host,
master_port=23456,
rank=2,
world_size=len(workers),
backend="nccl",
)
client.submit.assert_any_call(
dispatch_with_ddp, fake_pytorch_func, host, 23456, 3, len(workers), workers=[worker_keys[3]]
dispatch_with_ddp,
pytorch_function=fake_pytorch_func,
master_addr=host,
master_port=23456,
rank=3,
world_size=len(workers),
backend="nccl",
)
assert output == fake_results

Expand All @@ -51,7 +76,17 @@ def test_dispatch_with_ddp():
with patch.object(os, "environ", {}) as environ, patch(
"dask_pytorch.dispatch.dist", return_value=Mock()
) as dist:
dispatch_with_ddp(pytorch_func, "master_addr", 2343, 1, 10, "a", "b", foo="bar")
dispatch_with_ddp(
pytorch_func,
"master_addr",
2343,
1,
10,
"nccl",
"a",
"b",
foo="bar",
)
assert environ["MASTER_ADDR"] == "master_addr"
assert environ["MASTER_PORT"] == "2343"
assert environ["RANK"] == "1"
Expand Down

0 comments on commit d9d1645

Please sign in to comment.