Skip to content

Commit

Permalink
[SPARK-44640][PYTHON][FOLLOW-UP] Update UDTF error messages to include method name
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?

This PR is a follow-up for SPARK-44640 to make the error message of a few UDTF errors more informative by including the method name in the error message (`eval` or `terminate`).

### Why are the changes needed?

To improve error messages.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #42726 from allisonwang-db/SPARK-44640-follow-up.

Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
  • Loading branch information
allisonwang-db authored and ueshin committed Sep 2, 2023
1 parent 967aac1 commit 3e22c86
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 14 deletions.
8 changes: 4 additions & 4 deletions python/pyspark/errors/error_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@
},
"INVALID_ARROW_UDTF_RETURN_TYPE" : {
"message" : [
"The return type of the arrow-optimized Python UDTF should be of type 'pandas.DataFrame', but the function returned a value of type <type_name> with value: <value>."
"The return type of the arrow-optimized Python UDTF should be of type 'pandas.DataFrame', but the '<func>' method returned a value of type <type_name> with value: <value>."
]
},
"INVALID_BROADCAST_OPERATION": {
Expand Down Expand Up @@ -745,17 +745,17 @@
},
"UDTF_INVALID_OUTPUT_ROW_TYPE" : {
"message" : [
"The type of an individual output row in the UDTF is invalid. Each row should be a tuple, list, or dict, but got '<type>'. Please make sure that the output rows are of the correct type."
"The type of an individual output row in the '<func>' method of the UDTF is invalid. Each row should be a tuple, list, or dict, but got '<type>'. Please make sure that the output rows are of the correct type."
]
},
"UDTF_RETURN_NOT_ITERABLE" : {
"message" : [
"The return value of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got '<type>'. Please make sure that the UDTF returns one of these types."
"The return value of the '<func>' method of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got '<type>'. Please make sure that the UDTF returns one of these types."
]
},
"UDTF_RETURN_SCHEMA_MISMATCH" : {
"message" : [
"The number of columns in the result does not match the specified schema. Expected column count: <expected>, Actual column count: <actual>. Please make sure the values returned by the function have the same number of columns as specified in the output schema."
"The number of columns in the result does not match the specified schema. Expected column count: <expected>, Actual column count: <actual>. Please make sure the values returned by the '<func>' method have the same number of columns as specified in the output schema."
]
},
"UDTF_RETURN_TYPE_MISMATCH" : {
Expand Down
21 changes: 21 additions & 0 deletions python/pyspark/sql/tests/test_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,27 @@ def eval(self, a):
with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"):
TestUDTF(lit(1)).collect()

def test_udtf_with_zero_arg_and_invalid_return_value(self):
@udtf(returnType="x: int")
class TestUDTF:
def eval(self):
return 1

with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"):
TestUDTF().collect()

def test_udtf_with_invalid_return_value_in_terminate(self):
@udtf(returnType="x: int")
class TestUDTF:
def eval(self, a):
...

def terminate(self):
return 1

with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"):
TestUDTF(lit(1)).collect()

def test_udtf_eval_with_no_return(self):
@udtf(returnType="a: int")
class TestUDTF:
Expand Down
37 changes: 27 additions & 10 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,7 @@ def verify_result(result):
message_parameters={
"type_name": type(result).__name__,
"value": str(result),
"func": f.__name__,
},
)

Expand All @@ -787,6 +788,7 @@ def verify_result(result):
message_parameters={
"expected": str(return_type_size),
"actual": str(len(result.columns)),
"func": f.__name__,
},
)

Expand All @@ -806,9 +808,23 @@ def func(*a: Any, **kw: Any) -> Any:
message_parameters={"method_name": f.__name__, "error": str(e)},
)

def check_return_value(res):
# Check whether the result of an arrow UDTF is iterable before
# using it to construct a pandas DataFrame.
if res is not None and not isinstance(res, Iterable):
raise PySparkRuntimeError(
error_class="UDTF_RETURN_NOT_ITERABLE",
message_parameters={
"type": type(res).__name__,
"func": f.__name__,
},
)

def evaluate(*args: pd.Series, **kwargs: pd.Series):
if len(args) == 0 and len(kwargs) == 0:
yield verify_result(pd.DataFrame(func())), arrow_return_type
res = func()
check_return_value(res)
yield verify_result(pd.DataFrame(res)), arrow_return_type
else:
# Create tuples from the input pandas Series, each tuple
# represents a row across all Series.
Expand All @@ -820,13 +836,7 @@ def evaluate(*args: pd.Series, **kwargs: pd.Series):
*row[:len_args],
**{key: row[len_args + i] for i, key in enumerate(keys)},
)
if res is not None and not isinstance(res, Iterable):
raise PySparkRuntimeError(
error_class="UDTF_RETURN_NOT_ITERABLE",
message_parameters={
"type": type(res).__name__,
},
)
check_return_value(res)
yield verify_result(pd.DataFrame(res)), arrow_return_type

return evaluate
Expand Down Expand Up @@ -868,13 +878,17 @@ def verify_and_convert_result(result):
message_parameters={
"expected": str(return_type_size),
"actual": str(len(result)),
"func": f.__name__,
},
)

if not (isinstance(result, (list, dict, tuple)) or hasattr(result, "__dict__")):
raise PySparkRuntimeError(
error_class="UDTF_INVALID_OUTPUT_ROW_TYPE",
message_parameters={"type": type(result).__name__},
message_parameters={
"type": type(result).__name__,
"func": f.__name__,
},
)

return toInternal(result)
Expand All @@ -898,7 +912,10 @@ def evaluate(*a, **kw) -> tuple:
if not isinstance(res, Iterable):
raise PySparkRuntimeError(
error_class="UDTF_RETURN_NOT_ITERABLE",
message_parameters={"type": type(res).__name__},
message_parameters={
"type": type(res).__name__,
"func": f.__name__,
},
)

# If the function returns a result, we map it to the internal representation and
Expand Down

0 comments on commit 3e22c86

Please sign in to comment.