diff --git a/python/mypy.ini b/python/mypy.ini
index 5f662a4a2375b..a845cd88bd84f 100644
--- a/python/mypy.ini
+++ b/python/mypy.ini
@@ -85,6 +85,9 @@ disallow_untyped_defs = False
 [mypy-pyspark.ml.tests.*]
 ignore_errors = True
 
+[mypy-pyspark.ml.torch.tests.*]
+ignore_errors = True
+
 [mypy-pyspark.mllib.tests.*]
 ignore_errors = True
 
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py b/python/pyspark/ml/torch/tests/test_distributor.py
index 0f4a4a23dc086..baf68757f67c3 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -18,7 +18,7 @@
 import contextlib
 import os
 import shutil
-from six import StringIO  # type: ignore
+from six import StringIO
 import stat
 import subprocess
 import sys
@@ -57,7 +57,7 @@ def patch_stdout() -> StringIO:
 def create_training_function(mnist_dir_path: str) -> Callable:
     import torch.nn as nn
     import torch.nn.functional as F
-    from torchvision import transforms, datasets  # type: ignore
+    from torchvision import transforms, datasets
 
     batch_size = 100
     num_epochs = 1
@@ -99,7 +99,7 @@ def train_fn(learning_rate: float) -> Any:
 
         dist.init_process_group("gloo")
 
-        train_sampler = DistributedSampler(dataset=train_dataset)  # type: ignore
+        train_sampler = DistributedSampler(dataset=train_dataset)
         data_loader = torch.utils.data.DataLoader(
             train_dataset, batch_size=batch_size, sampler=train_sampler
         )
@@ -220,11 +220,11 @@ def test_execute_command(self) -> None:
         )
 
         # include command in the exception message
-        with self.assertRaisesRegexp(RuntimeError, "exit 1"):  # pylint: disable=deprecated-method
+        with self.assertRaisesRegex(RuntimeError, "exit 1"):
             error_command = ["bash", "-c", "exit 1"]
             TorchDistributor._execute_command(error_command)
 
-        with self.assertRaisesRegexp(RuntimeError, "abcdef"):  # pylint: disable=deprecated-method
+        with self.assertRaisesRegex(RuntimeError, "abcdef"):
             error_command = ["bash", "-c", "'abc''def'"]
             TorchDistributor._execute_command(error_command)
 
@@ -359,7 +359,7 @@ def test_local_training_succeeds(self) -> None:
                     self.setup_env_vars({CUDA_VISIBLE_DEVICES: cuda_env_var})
 
                 dist = TorchDistributor(num_processes, True, use_gpu)
-                dist._run_training_on_pytorch_file = lambda *args: os.environ.get(  # type: ignore
+                dist._run_training_on_pytorch_file = lambda *args: os.environ.get(
                     CUDA_VISIBLE_DEVICES, "NONE"
                 )
                 self.assertEqual(
@@ -429,7 +429,7 @@ def test_dist_training_succeeds(self) -> None:
         for i, (_, num_processes, use_gpu, expected) in enumerate(inputs):
             with self.subTest(f"subtest: {i + 1}"):
                 dist = TorchDistributor(num_processes, False, use_gpu)
-                dist._run_training_on_pytorch_file = lambda *args: os.environ.get(  # type: ignore
+                dist._run_training_on_pytorch_file = lambda *args: os.environ.get(
                     CUDA_VISIBLE_DEVICES, "NONE"
                 )
                 self.assertEqual(
@@ -486,14 +486,14 @@ def test_check_parent_alive(self, mock_clean_and_terminate: Callable) -> None:
         t = threading.Thread(target=check_parent_alive, args=(task,), daemon=True)
         t.start()
         time.sleep(2)
-        self.assertEqual(mock_clean_and_terminate.call_count, 0)  # type: ignore[attr-defined]
+        self.assertEqual(mock_clean_and_terminate.call_count, 0)
 
 
 if __name__ == "__main__":
-    from pyspark.ml.torch.tests.test_distributor import *  # noqa: F401,F403 type: ignore
+    from pyspark.ml.torch.tests.test_distributor import *  # noqa: F401,F403
 
     try:
-        import xmlrunner  # type: ignore
+        import xmlrunner
 
         testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
     except ImportError:
diff --git a/python/pyspark/ml/torch/tests/test_log_communication.py b/python/pyspark/ml/torch/tests/test_log_communication.py
index 0c9379264804f..164c7556d129d 100644
--- a/python/pyspark/ml/torch/tests/test_log_communication.py
+++ b/python/pyspark/ml/torch/tests/test_log_communication.py
@@ -18,14 +18,14 @@
 from __future__ import absolute_import, division, print_function
 
 import contextlib
-from six import StringIO  # type: ignore
+from six import StringIO
 import sys
 import time
 from typing import Any, Callable
 import unittest
 
 import pyspark.ml.torch.log_communication
-from pyspark.ml.torch.log_communication import (  # type: ignore
+from pyspark.ml.torch.log_communication import (
     LogStreamingServer,
     LogStreamingClient,
     LogStreamingClientBase,
@@ -47,15 +47,11 @@ def patch_stderr() -> StringIO:
 
 
 class LogStreamingServiceTestCase(unittest.TestCase):
     def setUp(self) -> None:
-        self.default_truncate_msg_len = (
-            pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN  # type: ignore
-        )
-        pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = 10  # type: ignore
+        self.default_truncate_msg_len = pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN
+        pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = 10
 
     def tearDown(self) -> None:
-        pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = (  # type: ignore
-            self.default_truncate_msg_len
-        )
+        pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = self.default_truncate_msg_len
 
     def basic_test(self) -> None:
         server = LogStreamingServer()
@@ -165,10 +161,10 @@ def client_ops_send_a_msg_and_close(client: Any) -> None:
 
 
 if __name__ == "__main__":
-    from pyspark.ml.torch.tests.test_log_communication import *  # noqa: F401,F403 type: ignore
+    from pyspark.ml.torch.tests.test_log_communication import *  # noqa: F401,F403
 
     try:
-        import xmlrunner  # type: ignore
+        import xmlrunner
 
         testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
     except ImportError:
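
Note (not part of the patch): besides dropping stale `# type: ignore` comments, the diff migrates the tests from `assertRaisesRegexp` to `assertRaisesRegex`. The former is a deprecated alias (deprecated since Python 3.2 and removed in Python 3.12), so switching also lets the `# pylint: disable=deprecated-method` suppressions go away. A minimal standalone sketch of the supported API, using a hypothetical test case not taken from the patch:

import unittest


class ExitCodeTest(unittest.TestCase):
    def test_exit_code_in_message(self) -> None:
        # assertRaisesRegex checks both the exception type and that the
        # exception message matches the given regular expression - the same
        # pattern the patched tests use to assert that the failing command
        # appears in TorchDistributor's error message.
        with self.assertRaisesRegex(RuntimeError, "exit 1"):
            raise RuntimeError("command failed: exit 1")


if __name__ == "__main__":
    unittest.main()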