remove mypy checks
allisonwang-db committed Aug 22, 2023
1 parent cfa6ae2 commit 8d0ab83
Showing 2 changed files with 27 additions and 21 deletions.
42 changes: 21 additions & 21 deletions examples/src/main/python/sql/udtf.py
@@ -35,7 +35,7 @@ def python_udtf_simple_example(spark: SparkSession) -> None:
 
     # Define the UDTF class and implement the required `eval` method.
     class SquareNumbers:
-        def eval(self, start: int, end: int):  # type: ignore[no-untyped-def]
+        def eval(self, start: int, end: int):
             for num in range(start, end + 1):
                 yield (num, num * num)

@@ -45,7 +45,7 @@ def eval(self, start: int, end: int):  # type: ignore[no-untyped-def]
     square_num = udtf(SquareNumbers, returnType="num: int, squared: int")
 
     # Invoke the UDTF in PySpark.
-    square_num(lit(1), lit(3)).show()  # type: ignore
+    square_num(lit(1), lit(3)).show()
     # +---+-------+
     # |num|squared|
     # +---+-------+
@@ -60,14 +60,14 @@ def python_udtf_decorator_example(spark: SparkSession) -> None:
     from pyspark.sql.functions import lit, udtf
 
     # Define a UDTF using the `udtf` decorator directly on the class.
-    @udtf(returnType="num: int, squared: int")  # type: ignore
+    @udtf(returnType="num: int, squared: int")
     class SquareNumbers:
-        def eval(self, start: int, end: int):  # type: ignore[no-untyped-def]
+        def eval(self, start: int, end: int):
             for num in range(start, end + 1):
                 yield (num, num * num)
 
     # Invoke the UDTF in PySpark using the SquareNumbers class directly.
-    SquareNumbers(lit(1), lit(3)).show()  # type: ignore
+    SquareNumbers(lit(1), lit(3)).show()
     # +---+-------+
     # |num|squared|
     # +---+-------+
@@ -81,14 +81,14 @@ def python_udtf_registration(spark: SparkSession) -> None:
 
     from pyspark.sql.functions import udtf
 
-    @udtf(returnType="word: string")  # type: ignore
+    @udtf(returnType="word: string")
     class WordSplitter:
-        def eval(self, text: str):  # type: ignore[no-untyped-def]
+        def eval(self, text: str):
             for word in text.split(" "):
                 yield (word.strip(),)
 
     # Register the UDTF for use in Spark SQL.
-    spark.udtf.register("split_words", WordSplitter)  # type: ignore
+    spark.udtf.register("split_words", WordSplitter)
 
     # Example: Using the UDTF in SQL.
     spark.sql("SELECT * FROM split_words('hello world')").show()
@@ -120,9 +120,9 @@ def python_udtf_arrow_example(spark: SparkSession) -> None:
 
     from pyspark.sql.functions import udtf
 
-    @udtf(returnType="c1: int, c2: int", useArrow=True)  # type: ignore
+    @udtf(returnType="c1: int, c2: int", useArrow=True)
     class PlusOne:
-        def eval(self, x: int):  # type: ignore[no-untyped-def]
+        def eval(self, x: int):
             yield x, x + 1
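The hunk's context window ends before this example's invocation. For reference, an Arrow-optimized UDTF is invoked like any other; a minimal sketch, not part of the diff:

    from pyspark.sql.functions import lit

    # useArrow=True only changes the serialization between the JVM and Python.
    PlusOne(lit(1)).show()
    # Expect a single row: c1=1, c2=2.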


@@ -131,16 +131,16 @@ def python_udtf_date_expander_example(spark: SparkSession) -> None:
     from datetime import datetime, timedelta
     from pyspark.sql.functions import lit, udtf
 
-    @udtf(returnType="date: string")  # type: ignore
+    @udtf(returnType="date: string")
     class DateExpander:
-        def eval(self, start_date: str, end_date: str):  # type: ignore[no-untyped-def]
+        def eval(self, start_date: str, end_date: str):
             current = datetime.strptime(start_date, '%Y-%m-%d')
             end = datetime.strptime(end_date, '%Y-%m-%d')
             while current <= end:
                 yield (current.strftime('%Y-%m-%d'),)
                 current += timedelta(days=1)
 
-    DateExpander(lit("2023-02-25"), lit("2023-03-01")).show()  # type: ignore
+    DateExpander(lit("2023-02-25"), lit("2023-03-01")).show()
     # +----------+
     # |      date|
     # +----------+
@@ -156,18 +156,18 @@ def python_udtf_terminate_example(spark: SparkSession) -> None:
 
     from pyspark.sql.functions import udtf
 
-    @udtf(returnType="cnt: int")  # type: ignore
+    @udtf(returnType="cnt: int")
     class CountUDTF:
-        def __init__(self):  # type: ignore[no-untyped-def]
+        def __init__(self):
             self.count = 0
 
-        def eval(self, x: int):  # type: ignore[no-untyped-def]
+        def eval(self, x: int):
             self.count += 1
 
-        def terminate(self):  # type: ignore[no-untyped-def]
+        def terminate(self):
             yield self.count,
 
-    spark.udtf.register("count_udtf", CountUDTF)  # type: ignore
+    spark.udtf.register("count_udtf", CountUDTF)
     spark.sql("SELECT * FROM range(0, 10, 1, 1), LATERAL count_udtf(id)").show()
     # +---+---+
     # | id|cnt|
@@ -181,13 +181,13 @@ def python_udtf_table_argument(spark: SparkSession) -> None:
     from pyspark.sql.functions import udtf
     from pyspark.sql.types import Row
 
-    @udtf(returnType="id: int")  # type: ignore
+    @udtf(returnType="id: int")
     class FilterUDTF:
-        def eval(self, row: Row):  # type: ignore[no-untyped-def]
+        def eval(self, row: Row):
             if row["id"] > 5:
                 yield row["id"],
 
-    spark.udtf.register("filter_udtf", FilterUDTF)  # type: ignore
+    spark.udtf.register("filter_udtf", FilterUDTF)
 
     spark.sql("SELECT * FROM filter_udtf(TABLE(SELECT * FROM range(10)))").show()
     # +---+
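All of the `python_udtf_*_example` functions above expect a live SparkSession. A hypothetical driver sketch showing how one of them might be run (the file's actual `__main__` block lies outside this diff and may differ):

    from pyspark.sql import SparkSession

    if __name__ == "__main__":
        spark = SparkSession.builder.appName("Python UDTF examples").getOrCreate()
        python_udtf_simple_example(spark)
        spark.stop()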
6 changes: 6 additions & 0 deletions python/mypy.ini
@@ -80,6 +80,12 @@ disallow_untyped_defs = False
 [mypy-pyspark.worker]
 disallow_untyped_defs = False
 
+; Allow untyped defs and disable certain error codes in examples
+
+[mypy-python.sql.udtf]
+disallow_untyped_defs = False
+disable_error_code = attr-defined,arg-type,call-arg,union-attr
+
 ; Ignore errors in tests
 
 [mypy-pyspark.ml.tests.*]
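The new `[mypy-python.sql.udtf]` section is what lets the inline suppressions above be deleted: `disallow_untyped_defs = False` admits unannotated methods such as `eval`, while `disable_error_code` turns off the four listed checks for the whole module. A rough sketch of what each setting now permits (hypothetical lines; the mapping of error codes to specific statements is an assumption):

    class Example:
        def eval(self, x):  # unannotated def: allowed by disallow_untyped_defs = False
            yield x,

    square_num(lit(1), lit(3)).show()  # no longer needs an arg-type/call-arg ignore
    spark.udtf.register("square_num", square_num)  # no longer needs an attr-defined/union-attr ignore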
