[SPARK-42183][PYTHON][ML][TESTS] Exclude pyspark.ml.torch.tests in MyPy tests

### What changes were proposed in this pull request?

This PR proposes to exclude `pyspark.ml.torch.tests` in MyPy tests.

### Why are the changes needed?

The initial intention was to annotate types for public APIs only; see also apache#38991.

### Does this PR introduce _any_ user-facing change?

No, test-only.

### How was this patch tested?

CI in this PR should test it out.
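For local reproduction, a minimal sketch using mypy's programmatic API (`mypy.api` ships with the mypy package itself; the paths assume the working directory is the root of an apache/spark checkout):

# Sketch only: run mypy with Spark's config against the torch test package.
# With the new [mypy-pyspark.ml.torch.tests.*] override in place, errors in
# these files should no longer be reported.
from mypy import api

stdout, stderr, exit_status = api.run(
    ["--config-file", "python/mypy.ini", "python/pyspark/ml/torch/tests"]
)
print(stdout or stderr)
print("exit status:", exit_status)  # 0 means mypy reported no errors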

Closes apache#39740 from HyukjinKwon/SPARK-42183.

Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
HyukjinKwon committed Jan 25, 2023
1 parent 001408c commit e81c2c8
Showing 3 changed files with 20 additions and 21 deletions.
3 changes: 3 additions & 0 deletions python/mypy.ini
@@ -85,6 +85,9 @@ disallow_untyped_defs = False
 [mypy-pyspark.ml.tests.*]
 ignore_errors = True
 
+[mypy-pyspark.ml.torch.tests.*]
+ignore_errors = True
+
 [mypy-pyspark.mllib.tests.*]
 ignore_errors = True

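A note on the effect: with `ignore_errors = True`, mypy still processes modules matching `pyspark.ml.torch.tests.*` but suppresses every error it would report, so the per-line `# type: ignore` comments in the diffs below become redundant, which is why the rest of this commit removes them. An illustrative fragment, assuming torchvision is installed (its lack of type stubs is presumably what forced the inline suppression in the first place):

# Before this change, mypy flagged the untyped import, so the test file
# carried a per-line suppression:
#     from torchvision import transforms, datasets  # type: ignore
# With the module-level ignore_errors override, the bare import suffices:
from torchvision import transforms, datasets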
20 changes: 10 additions & 10 deletions python/pyspark/ml/torch/tests/test_distributor.py
@@ -18,7 +18,7 @@
 import contextlib
 import os
 import shutil
-from six import StringIO  # type: ignore
+from six import StringIO
 import stat
 import subprocess
 import sys
@@ -57,7 +57,7 @@ def patch_stdout() -> StringIO:
 def create_training_function(mnist_dir_path: str) -> Callable:
     import torch.nn as nn
     import torch.nn.functional as F
-    from torchvision import transforms, datasets  # type: ignore
+    from torchvision import transforms, datasets
 
     batch_size = 100
     num_epochs = 1
@@ -99,7 +99,7 @@ def train_fn(learning_rate: float) -> Any:
 
         dist.init_process_group("gloo")
 
-        train_sampler = DistributedSampler(dataset=train_dataset)  # type: ignore
+        train_sampler = DistributedSampler(dataset=train_dataset)
         data_loader = torch.utils.data.DataLoader(
             train_dataset, batch_size=batch_size, sampler=train_sampler
         )
@@ -220,11 +220,11 @@ def test_execute_command(self) -> None:
         )
 
         # include command in the exception message
-        with self.assertRaisesRegexp(RuntimeError, "exit 1"):  # pylint: disable=deprecated-method
+        with self.assertRaisesRegex(RuntimeError, "exit 1"):
             error_command = ["bash", "-c", "exit 1"]
             TorchDistributor._execute_command(error_command)
 
-        with self.assertRaisesRegexp(RuntimeError, "abcdef"):  # pylint: disable=deprecated-method
+        with self.assertRaisesRegex(RuntimeError, "abcdef"):
             error_command = ["bash", "-c", "'abc''def'"]
             TorchDistributor._execute_command(error_command)

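Background on this hunk: `assertRaisesRegexp` has been a deprecated alias of `assertRaisesRegex` since Python 3.2 and is removed in Python 3.12, which is why pylint flagged it as a deprecated-method. A self-contained sketch of the modern spelling (the test class and exception are made up for illustration, independent of TorchDistributor):

import unittest


class RegexAssertionExample(unittest.TestCase):
    def test_error_message_matches(self) -> None:
        # The second argument is a regular expression matched against
        # str() of the raised exception.
        with self.assertRaisesRegex(RuntimeError, "exit 1"):
            raise RuntimeError("Command failed with exit 1")


if __name__ == "__main__":
    unittest.main()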
@@ -359,7 +359,7 @@ def test_local_training_succeeds(self) -> None:
         self.setup_env_vars({CUDA_VISIBLE_DEVICES: cuda_env_var})
 
         dist = TorchDistributor(num_processes, True, use_gpu)
-        dist._run_training_on_pytorch_file = lambda *args: os.environ.get(  # type: ignore
+        dist._run_training_on_pytorch_file = lambda *args: os.environ.get(
             CUDA_VISIBLE_DEVICES, "NONE"
         )
         self.assertEqual(
@@ -429,7 +429,7 @@ def test_dist_training_succeeds(self) -> None:
         for i, (_, num_processes, use_gpu, expected) in enumerate(inputs):
             with self.subTest(f"subtest: {i + 1}"):
                 dist = TorchDistributor(num_processes, False, use_gpu)
-                dist._run_training_on_pytorch_file = lambda *args: os.environ.get(  # type: ignore
+                dist._run_training_on_pytorch_file = lambda *args: os.environ.get(
                     CUDA_VISIBLE_DEVICES, "NONE"
                 )
                 self.assertEqual(
@@ -486,14 +486,14 @@ def test_check_parent_alive(self, mock_clean_and_terminate: Callable) -> None:
         t = threading.Thread(target=check_parent_alive, args=(task,), daemon=True)
         t.start()
         time.sleep(2)
-        self.assertEqual(mock_clean_and_terminate.call_count, 0)  # type: ignore[attr-defined]
+        self.assertEqual(mock_clean_and_terminate.call_count, 0)
 
 
 if __name__ == "__main__":
-    from pyspark.ml.torch.tests.test_distributor import *  # noqa: F401,F403 type: ignore
+    from pyspark.ml.torch.tests.test_distributor import *  # noqa: F401,F403
 
     try:
-        import xmlrunner  # type: ignore
+        import xmlrunner
 
         testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
     except ImportError:
18 changes: 7 additions & 11 deletions python/pyspark/ml/torch/tests/test_log_communication.py
@@ -18,14 +18,14 @@
 from __future__ import absolute_import, division, print_function
 
 import contextlib
-from six import StringIO  # type: ignore
+from six import StringIO
 import sys
 import time
 from typing import Any, Callable
 import unittest
 
 import pyspark.ml.torch.log_communication
-from pyspark.ml.torch.log_communication import (  # type: ignore
+from pyspark.ml.torch.log_communication import (
     LogStreamingServer,
     LogStreamingClient,
     LogStreamingClientBase,
@@ -47,15 +47,11 @@ def patch_stderr() -> StringIO:
 
 class LogStreamingServiceTestCase(unittest.TestCase):
     def setUp(self) -> None:
-        self.default_truncate_msg_len = (
-            pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN  # type: ignore
-        )
-        pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = 10  # type: ignore
+        self.default_truncate_msg_len = pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN
+        pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = 10
 
     def tearDown(self) -> None:
-        pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = (  # type: ignore
-            self.default_truncate_msg_len
-        )
+        pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = self.default_truncate_msg_len
 
     def basic_test(self) -> None:
         server = LogStreamingServer()
@@ -165,10 +161,10 @@ def client_ops_send_a_msg_and_close(client: Any) -> None:
 
 
 if __name__ == "__main__":
-    from pyspark.ml.torch.tests.test_log_communication import *  # noqa: F401,F403 type: ignore
+    from pyspark.ml.torch.tests.test_log_communication import *  # noqa: F401,F403
 
     try:
-        import xmlrunner  # type: ignore
+        import xmlrunner
 
         testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
     except ImportError:
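The `__main__` blocks in both test files keep the existing pattern of preferring `xmlrunner` (the unittest-xml-reporting package, which writes JUnit-style XML reports for CI) and falling back to unittest's default text runner when it is not installed. A standalone sketch of that pattern:

import unittest

try:
    import xmlrunner  # optional third-party package: unittest-xml-reporting

    test_runner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
    test_runner = None  # unittest.main then uses its default TextTestRunner

if __name__ == "__main__":
    unittest.main(testRunner=test_runner, verbosity=2)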
