From 3e22c8653d728a6b8523051faddcca437accfc22 Mon Sep 17 00:00:00 2001
From: allisonwang-db <allison.wang@databricks.com>
Date: Sat, 2 Sep 2023 16:07:09 -0700
Subject: [PATCH] [SPARK-44640][PYTHON][FOLLOW-UP] Update UDTF error messages
 to include method name

### What changes were proposed in this pull request?

This PR is a follow-up for SPARK-44640 to make the error message of a few UDTF errors more informative by including the method name in the error message (`eval` or `terminate`).

### Why are the changes needed?

To improve error messages.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #42726 from allisonwang-db/SPARK-44640-follow-up.

Authored-by: allisonwang-db <allison.wang@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
---
 python/pyspark/errors/error_classes.py |  8 +++---
 python/pyspark/sql/tests/test_udtf.py  | 21 +++++++++++++++
 python/pyspark/worker.py               | 37 +++++++++++++++++++-------
 3 files changed, 52 insertions(+), 14 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index ca448a169e83b..74f52c416e95b 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -244,7 +244,7 @@
   },
   "INVALID_ARROW_UDTF_RETURN_TYPE" : {
     "message" : [
-      "The return type of the arrow-optimized Python UDTF should be of type 'pandas.DataFrame', but the function returned a value of type <type_name> with value: <value>."
+      "The return type of the arrow-optimized Python UDTF should be of type 'pandas.DataFrame', but the '<func>' method returned a value of type <type_name> with value: <value>."
     ]
   },
   "INVALID_BROADCAST_OPERATION": {
@@ -745,17 +745,17 @@
   },
   "UDTF_INVALID_OUTPUT_ROW_TYPE" : {
     "message" : [
-      "The type of an individual output row in the UDTF is invalid. Each row should be a tuple, list, or dict, but got '<type>'. Please make sure that the output rows are of the correct type."
+      "The type of an individual output row in the '<func>' method of the UDTF is invalid. Each row should be a tuple, list, or dict, but got '<type>'. Please make sure that the output rows are of the correct type."
     ]
   },
   "UDTF_RETURN_NOT_ITERABLE" : {
     "message" : [
-      "The return value of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got '<type>'. Please make sure that the UDTF returns one of these types."
+      "The return value of the '<func>' method of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got '<type>'. Please make sure that the UDTF returns one of these types."
     ]
   },
   "UDTF_RETURN_SCHEMA_MISMATCH" : {
     "message" : [
-      "The number of columns in the result does not match the specified schema. Expected column count: <expected>, Actual column count: <actual>. Please make sure the values returned by the function have the same number of columns as specified in the output schema."
+      "The number of columns in the result does not match the specified schema. Expected column count: <expected>, Actual column count: <actual>. Please make sure the values returned by the '<func>' method have the same number of columns as specified in the output schema."
     ]
   },
   "UDTF_RETURN_TYPE_MISMATCH" : {
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index c5f8b7693c26d..97d5190a5060c 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -190,6 +190,27 @@ def eval(self, a):
         with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"):
             TestUDTF(lit(1)).collect()
 
+    def test_udtf_with_zero_arg_and_invalid_return_value(self):
+        @udtf(returnType="x: int")
+        class TestUDTF:
+            def eval(self):
+                return 1
+
+        with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"):
+            TestUDTF().collect()
+
+    def test_udtf_with_invalid_return_value_in_terminate(self):
+        @udtf(returnType="x: int")
+        class TestUDTF:
+            def eval(self, a):
+                ...
+
+            def terminate(self):
+                return 1
+
+        with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"):
+            TestUDTF(lit(1)).collect()
+
     def test_udtf_eval_with_no_return(self):
         @udtf(returnType="a: int")
         class TestUDTF:
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d95a5c4672f86..fff99f1de3d06 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -773,6 +773,7 @@ def verify_result(result):
                     message_parameters={
                         "type_name": type(result).__name__,
                         "value": str(result),
+                        "func": f.__name__,
                     },
                 )
 
@@ -787,6 +788,7 @@ def verify_result(result):
                         message_parameters={
                             "expected": str(return_type_size),
                             "actual": str(len(result.columns)),
+                            "func": f.__name__,
                         },
                     )
 
@@ -806,9 +808,23 @@ def func(*a: Any, **kw: Any) -> Any:
                     message_parameters={"method_name": f.__name__, "error": str(e)},
                 )
 
+        def check_return_value(res):
+            # Check whether the result of an arrow UDTF is iterable before
+            # using it to construct a pandas DataFrame.
+            if res is not None and not isinstance(res, Iterable):
+                raise PySparkRuntimeError(
+                    error_class="UDTF_RETURN_NOT_ITERABLE",
+                    message_parameters={
+                        "type": type(res).__name__,
+                        "func": f.__name__,
+                    },
+                )
+
         def evaluate(*args: pd.Series, **kwargs: pd.Series):
             if len(args) == 0 and len(kwargs) == 0:
-                yield verify_result(pd.DataFrame(func())), arrow_return_type
+                res = func()
+                check_return_value(res)
+                yield verify_result(pd.DataFrame(res)), arrow_return_type
             else:
                 # Create tuples from the input pandas Series, each tuple
                 # represents a row across all Series.
@@ -820,13 +836,7 @@ def evaluate(*args: pd.Series, **kwargs: pd.Series):
                         *row[:len_args],
                         **{key: row[len_args + i] for i, key in enumerate(keys)},
                     )
-                    if res is not None and not isinstance(res, Iterable):
-                        raise PySparkRuntimeError(
-                            error_class="UDTF_RETURN_NOT_ITERABLE",
-                            message_parameters={
-                                "type": type(res).__name__,
-                            },
-                        )
+                    check_return_value(res)
                     yield verify_result(pd.DataFrame(res)), arrow_return_type
 
         return evaluate
@@ -868,13 +878,17 @@ def verify_and_convert_result(result):
                         message_parameters={
                             "expected": str(return_type_size),
                             "actual": str(len(result)),
+                            "func": f.__name__,
                         },
                     )
 
             if not (isinstance(result, (list, dict, tuple)) or hasattr(result, "__dict__")):
                 raise PySparkRuntimeError(
                     error_class="UDTF_INVALID_OUTPUT_ROW_TYPE",
-                    message_parameters={"type": type(result).__name__},
+                    message_parameters={
+                        "type": type(result).__name__,
+                        "func": f.__name__,
+                    },
                 )
 
             return toInternal(result)
@@ -898,7 +912,10 @@ def evaluate(*a, **kw) -> tuple:
             if not isinstance(res, Iterable):
                 raise PySparkRuntimeError(
                     error_class="UDTF_RETURN_NOT_ITERABLE",
-                    message_parameters={"type": type(res).__name__},
+                    message_parameters={
+                        "type": type(res).__name__,
+                        "func": f.__name__,
+                    },
                 )
 
             # If the function returns a result, we map it to the internal representation and