diff --git a/dask_pytorch/data.py b/dask_pytorch/data.py index 959c091..c0b407a 100644 --- a/dask_pytorch/data.py +++ b/dask_pytorch/data.py @@ -59,7 +59,13 @@ class S3ImageFolder(Dataset): An image folder that lives in S3. Directories containing the image are classes. """ - def __init__(self, s3_bucket: str, s3_prefix: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None): + def __init__( + self, + s3_bucket: str, + s3_prefix: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ): self.s3_bucket = s3_bucket self.s3_prefix = s3_prefix self.all_files = _list_all_files(s3_bucket, s3_prefix) diff --git a/dask_pytorch/dispatch.py b/dask_pytorch/dispatch.py index 66606a9..447cf03 100644 --- a/dask_pytorch/dispatch.py +++ b/dask_pytorch/dispatch.py @@ -31,23 +31,32 @@ def run(client: Client, pytorch_function: Callable, *args, **kwargs): futures = [ client.submit( dispatch_with_ddp, - pytorch_function = pytorch_function, - master_addr = host, - master_port = port, - rank = idx, - world_size = world_size, - backend = "nccl", - *args, + pytorch_function=pytorch_function, + master_addr=host, + master_port=port, + rank=idx, + world_size=world_size, + backend="nccl", + *args, **kwargs ) for idx, w in enumerate(worker_keys) ] - + return futures +# pylint: disable=keyword-arg-before-vararg +# pylint: disable=too-many-arguments def dispatch_with_ddp( - pytorch_function: Callable, master_addr: str, master_port: int, rank: int, world_size: int, backend: str = "nccl", *args, **kwargs + pytorch_function: Callable, + master_addr: Any, + master_port: Any, + rank: Any, + world_size: Any, + backend: str = "nccl", + *args, + **kwargs ) -> Any: """ runs a pytorch function, setting up torch.distributed before execution diff --git a/dask_pytorch/results.py b/dask_pytorch/results.py index 2fa661d..53802af 100644 --- a/dask_pytorch/results.py +++ b/dask_pytorch/results.py @@ -61,7 +61,9 @@ def _get_results(self, futures: List[Future], raise_errors: bool = True): raise futures = result.not_done - def process_results(self, prefix: str, futures: List[Future], raise_errors: bool = True) -> None: + def process_results( + self, prefix: str, futures: List[Future], raise_errors: bool = True + ) -> None: """ Process the intermediate results: result objects will be dictionaries of the form {'path': path, 'data': data} diff --git a/tests/test_data.py b/tests/test_data.py index 77fd000..d3d1e99 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,16 +1,19 @@ from unittest.mock import Mock, patch, ANY -from dask_pytorch.data import BOTOS3ImageFolder +from dask_pytorch.data import S3ImageFolder def test_image_folder_constructor(): fake_file_list = ["d/a.jpg", "c/b.jpg"] - with patch("dask_pytorch.data.get_all_files", return_value=fake_file_list): + with patch("dask_pytorch.data._list_all_files", return_value=fake_file_list): fake_transform = Mock() fake_target_transform = Mock() - folder = BOTOS3ImageFolder( - "fake-bucket", "fake-prefix/fake-prefix", fake_transform, fake_target_transform + folder = S3ImageFolder( + "fake-bucket", + "fake-prefix/fake-prefix", + fake_transform, + fake_target_transform, ) assert folder.all_files == fake_file_list assert folder.classes == ["c", "d"] @@ -21,17 +24,17 @@ def test_image_folder_constructor(): def test_image_folder_len(): fake_file_list = ["d/a.jpg", "c/b.jpg"] - with patch("dask_pytorch.data.get_all_files", return_value=fake_file_list): - folder = BOTOS3ImageFolder("fake-bucket", "fake-prefix/fake-prefix") + with patch("dask_pytorch.data._list_all_files", return_value=fake_file_list): + folder = S3ImageFolder("fake-bucket", "fake-prefix/fake-prefix") assert len(folder) == 2 def test_image_folder_getitem(): fake_file_list = ["d/a.jpg", "c/b.jpg"] - with patch("dask_pytorch.data.get_all_files", return_value=fake_file_list): - folder = BOTOS3ImageFolder("fake-bucket", "fake-prefix/fake-prefix") - with patch("dask_pytorch.data.read_s3_fileobj") as read_s3_fileobj, patch( - "dask_pytorch.data.load_image_obj" + with patch("dask_pytorch.data._list_all_files", return_value=fake_file_list): + folder = S3ImageFolder("fake-bucket", "fake-prefix/fake-prefix") + with patch("dask_pytorch.data._read_s3_fileobj") as read_s3_fileobj, patch( + "dask_pytorch.data._load_image_obj" ) as load_image_obj: read_s3_fileobj.return_value = Mock() diff --git a/tests/test_dispatch.py b/tests/test_dispatch.py index f0fd434..ff5c7ec 100644 --- a/tests/test_dispatch.py +++ b/tests/test_dispatch.py @@ -1,4 +1,5 @@ import os + from unittest.mock import Mock, patch from dask_pytorch.dispatch import run, dispatch_with_ddp @@ -31,16 +32,40 @@ def test_run(): output = run(client, fake_pytorch_func) client.submit.assert_any_call( - dispatch_with_ddp, fake_pytorch_func, host, 23456, 0, len(workers), workers=[worker_keys[0]] + dispatch_with_ddp, + pytorch_function=fake_pytorch_func, + master_addr=host, + master_port=23456, + rank=0, + world_size=len(workers), + backend="nccl", ) client.submit.assert_any_call( - dispatch_with_ddp, fake_pytorch_func, host, 23456, 1, len(workers), workers=[worker_keys[1]] + dispatch_with_ddp, + pytorch_function=fake_pytorch_func, + master_addr=host, + master_port=23456, + rank=1, + world_size=len(workers), + backend="nccl", ) client.submit.assert_any_call( - dispatch_with_ddp, fake_pytorch_func, host, 23456, 2, len(workers), workers=[worker_keys[2]] + dispatch_with_ddp, + pytorch_function=fake_pytorch_func, + master_addr=host, + master_port=23456, + rank=2, + world_size=len(workers), + backend="nccl", ) client.submit.assert_any_call( - dispatch_with_ddp, fake_pytorch_func, host, 23456, 3, len(workers), workers=[worker_keys[3]] + dispatch_with_ddp, + pytorch_function=fake_pytorch_func, + master_addr=host, + master_port=23456, + rank=3, + world_size=len(workers), + backend="nccl", ) assert output == fake_results @@ -51,7 +76,17 @@ def test_dispatch_with_ddp(): with patch.object(os, "environ", {}) as environ, patch( "dask_pytorch.dispatch.dist", return_value=Mock() ) as dist: - dispatch_with_ddp(pytorch_func, "master_addr", 2343, 1, 10, "a", "b", foo="bar") + dispatch_with_ddp( + pytorch_func, + "master_addr", + 2343, + 1, + 10, + "nccl", + "a", + "b", + foo="bar", + ) assert environ["MASTER_ADDR"] == "master_addr" assert environ["MASTER_PORT"] == "2343" assert environ["RANK"] == "1"