From e7071c0237da75967b2f1e222d9f3b8293a82f86 Mon Sep 17 00:00:00 2001
From: Xinrong Meng
Date: Mon, 2 Dec 2024 09:59:07 +0800
Subject: [PATCH] [SPARK-50435][PYTHON][TESTS] Use assertDataFrameEqual in pyspark.sql.tests.test_functions

### What changes were proposed in this pull request?

Use `assertDataFrameEqual` in pyspark.sql.tests.test_functions.

### Why are the changes needed?

`assertDataFrameEqual` is explicitly built to handle DataFrame-specific comparisons, including schema. So we propose to replace `assertEqual` with `assertDataFrameEqual`. (A minimal usage sketch follows the diff below.)

Part of https://issues.apache.org/jira/browse/SPARK-50435.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49011 from xinrong-meng/impr_test_functions.

Lead-authored-by: Xinrong Meng
Co-authored-by: Hyukjin Kwon
Signed-off-by: Xinrong Meng
---
 python/pyspark/sql/tests/test_functions.py | 196 ++++++++++-----------
 1 file changed, 92 insertions(+), 104 deletions(-)

diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index e192366676ad8..4607d5d3411fe 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -31,7 +31,7 @@
 from pyspark.sql.column import Column
 from pyspark.sql.functions.builtin import nullifzero, randstr, uniform, zeroifnull
 from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
-from pyspark.testing.utils import have_numpy
+from pyspark.testing.utils import have_numpy, assertDataFrameEqual


 class FunctionsTestsMixin:
@@ -344,29 +344,29 @@ def test_try_parse_url(self):
             [("https://spark.apache.org/path?query=1", "QUERY", "query")],
             ["url", "part", "key"],
         )
-        actual = df.select(F.try_parse_url(df.url, df.part, df.key)).collect()
-        self.assertEqual(actual, [Row("1")])
+        actual = df.select(F.try_parse_url(df.url, df.part, df.key))
+        assertDataFrameEqual(actual, [Row("1")])

         df = self.spark.createDataFrame(
             [("inva lid://spark.apache.org/path?query=1", "QUERY", "query")],
             ["url", "part", "key"],
         )
-        actual = df.select(F.try_parse_url(df.url, df.part, df.key)).collect()
-        self.assertEqual(actual, [Row(None)])
+        actual = df.select(F.try_parse_url(df.url, df.part, df.key))
+        assertDataFrameEqual(actual, [Row(None)])

     def test_try_make_timestamp(self):
         data = [(2024, 5, 22, 10, 30, 0)]
         df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"])
         actual = df.select(
             F.try_make_timestamp(df.year, df.month, df.day, df.hour, df.minute, df.second)
-        ).collect()
-        self.assertEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30))])
+        )
+        assertDataFrameEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30))])

         data = [(2024, 13, 22, 10, 30, 0)]
         df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"])
         actual = df.select(
             F.try_make_timestamp(df.year, df.month, df.day, df.hour, df.minute, df.second)
-        ).collect()
-        self.assertEqual(actual, [Row(None)])
+        )
+        assertDataFrameEqual(actual, [Row(None)])

     def test_try_make_timestamp_ltz(self):
         # use local timezone here to avoid flakiness
@@ -378,8 +378,8 @@ def test_try_make_timestamp_ltz(self):
             F.try_make_timestamp_ltz(
                 df.year, df.month, df.day, df.hour, df.minute, df.second, df.timezone
             )
-        ).collect()
-        self.assertEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30, 0))])
+        )
+        assertDataFrameEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30, 0))])

         # use local timezone here to avoid flakiness
         data = [(2024, 13, 22, 10, 30, 0, datetime.datetime.now().astimezone().tzinfo.__str__())]
@@ -390,23 +390,23 @@ def test_try_make_timestamp_ltz(self):
             F.try_make_timestamp_ltz(
                 df.year, df.month, df.day, df.hour, df.minute, df.second, df.timezone
             )
-        ).collect()
-        self.assertEqual(actual, [Row(None)])
+        )
+        assertDataFrameEqual(actual, [Row(None)])

     def test_try_make_timestamp_ntz(self):
         data = [(2024, 5, 22, 10, 30, 0)]
         df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"])
         actual = df.select(
             F.try_make_timestamp_ntz(df.year, df.month, df.day, df.hour, df.minute, df.second)
-        ).collect()
-        self.assertEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30))])
+        )
+        assertDataFrameEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30))])

         data = [(2024, 13, 22, 10, 30, 0)]
         df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"])
         actual = df.select(
             F.try_make_timestamp_ntz(df.year, df.month, df.day, df.hour, df.minute, df.second)
-        ).collect()
-        self.assertEqual(actual, [Row(None)])
+        )
+        assertDataFrameEqual(actual, [Row(None)])

     def test_string_functions(self):
         string_functions = [
@@ -448,51 +448,51 @@ def test_string_functions(self):
         )

         for name in string_functions:
-            self.assertEqual(
-                df.select(getattr(F, name)("name")).first()[0],
-                df.select(getattr(F, name)(F.col("name"))).first()[0],
+            assertDataFrameEqual(
+                df.select(getattr(F, name)("name")),
+                df.select(getattr(F, name)(F.col("name"))),
             )

     def test_collation(self):
         df = self.spark.createDataFrame([("a",), ("b",)], ["name"])
-        actual = df.select(F.collation(F.collate("name", "UNICODE"))).distinct().collect()
-        self.assertEqual([Row("SYSTEM.BUILTIN.UNICODE")], actual)
+        actual = df.select(F.collation(F.collate("name", "UNICODE"))).distinct()
+        assertDataFrameEqual([Row("SYSTEM.BUILTIN.UNICODE")], actual)

     def test_try_make_interval(self):
         df = self.spark.createDataFrame([(2147483647,)], ["num"])
-        actual = df.select(F.isnull(F.try_make_interval("num"))).collect()
-        self.assertEqual([Row(True)], actual)
+        actual = df.select(F.isnull(F.try_make_interval("num")))
+        assertDataFrameEqual([Row(True)], actual)

     def test_octet_length_function(self):
         # SPARK-36751: add octet length api for python
         df = self.spark.createDataFrame([("cat",), ("\U0001F408",)], ["cat"])
-        actual = df.select(F.octet_length("cat")).collect()
-        self.assertEqual([Row(3), Row(4)], actual)
+        actual = df.select(F.octet_length("cat"))
+        assertDataFrameEqual([Row(3), Row(4)], actual)

     def test_bit_length_function(self):
         # SPARK-36751: add bit length api for python
         df = self.spark.createDataFrame([("cat",), ("\U0001F408",)], ["cat"])
-        actual = df.select(F.bit_length("cat")).collect()
-        self.assertEqual([Row(24), Row(32)], actual)
+        actual = df.select(F.bit_length("cat"))
+        assertDataFrameEqual([Row(24), Row(32)], actual)

     def test_array_contains_function(self):
         df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ["data"])
-        actual = df.select(F.array_contains(df.data, "1").alias("b")).collect()
-        self.assertEqual([Row(b=True), Row(b=False)], actual)
+        actual = df.select(F.array_contains(df.data, "1").alias("b"))
+        assertDataFrameEqual([Row(b=True), Row(b=False)], actual)

     def test_levenshtein_function(self):
         df = self.spark.createDataFrame([("kitten", "sitting")], ["l", "r"])
-        actual_without_threshold = df.select(F.levenshtein(df.l, df.r).alias("b")).collect()
-        self.assertEqual([Row(b=3)], actual_without_threshold)
-        actual_with_threshold = df.select(F.levenshtein(df.l, df.r, 2).alias("b")).collect()
-        self.assertEqual([Row(b=-1)], actual_with_threshold)
+        actual_without_threshold = df.select(F.levenshtein(df.l, df.r).alias("b"))
+        assertDataFrameEqual([Row(b=3)], actual_without_threshold)
+        actual_with_threshold = df.select(F.levenshtein(df.l, df.r, 2).alias("b"))
+        assertDataFrameEqual([Row(b=-1)], actual_with_threshold)

     def test_between_function(self):
         df = self.spark.createDataFrame(
             [Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]
         )
-        self.assertEqual(
-            [Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], df.filter(df.a.between(df.b, df.c)).collect()
+        assertDataFrameEqual(
+            [Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], df.filter(df.a.between(df.b, df.c))
         )

     def test_dayofweek(self):
@@ -608,7 +608,7 @@ def test_first_last_ignorenulls(self):
             F.last(df2.id, False).alias("c"),
             F.last(df2.id, True).alias("d"),
         )
-        self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())
+        assertDataFrameEqual([Row(a=None, b=1, c=None, d=98)], df3)

     def test_approxQuantile(self):
         df = self.spark.createDataFrame([Row(a=i, b=i + 10) for i in range(10)])
@@ -666,20 +666,20 @@ def test_sort_with_nulls_order(self):
         df = self.spark.createDataFrame(
             [("Tom", 80), (None, 60), ("Alice", 50)], ["name", "height"]
         )
-        self.assertEqual(
-            df.select(df.name).orderBy(F.asc_nulls_first("name")).collect(),
+        assertDataFrameEqual(
+            df.select(df.name).orderBy(F.asc_nulls_first("name")),
             [Row(name=None), Row(name="Alice"), Row(name="Tom")],
         )
-        self.assertEqual(
-            df.select(df.name).orderBy(F.asc_nulls_last("name")).collect(),
+        assertDataFrameEqual(
+            df.select(df.name).orderBy(F.asc_nulls_last("name")),
             [Row(name="Alice"), Row(name="Tom"), Row(name=None)],
         )
-        self.assertEqual(
-            df.select(df.name).orderBy(F.desc_nulls_first("name")).collect(),
+        assertDataFrameEqual(
+            df.select(df.name).orderBy(F.desc_nulls_first("name")),
             [Row(name=None), Row(name="Tom"), Row(name="Alice")],
         )
-        self.assertEqual(
-            df.select(df.name).orderBy(F.desc_nulls_last("name")).collect(),
+        assertDataFrameEqual(
+            df.select(df.name).orderBy(F.desc_nulls_last("name")),
             [Row(name="Tom"), Row(name="Alice"), Row(name=None)],
         )
@@ -716,20 +716,16 @@ def test_slice(self):
         )

         expected = [Row(sliced=[2, 3]), Row(sliced=[5])]
-        self.assertEqual(df.select(F.slice(df.x, 2, 2).alias("sliced")).collect(), expected)
-        self.assertEqual(
-            df.select(F.slice(df.x, F.lit(2), F.lit(2)).alias("sliced")).collect(), expected
-        )
-        self.assertEqual(
-            df.select(F.slice("x", "index", "len").alias("sliced")).collect(), expected
-        )
+        assertDataFrameEqual(df.select(F.slice(df.x, 2, 2).alias("sliced")), expected)
+        assertDataFrameEqual(df.select(F.slice(df.x, F.lit(2), F.lit(2)).alias("sliced")), expected)
+        assertDataFrameEqual(df.select(F.slice("x", "index", "len").alias("sliced")), expected)

-        self.assertEqual(
-            df.select(F.slice(df.x, F.size(df.x) - 1, F.lit(1)).alias("sliced")).collect(),
+        assertDataFrameEqual(
+            df.select(F.slice(df.x, F.size(df.x) - 1, F.lit(1)).alias("sliced")),
             [Row(sliced=[2]), Row(sliced=[4])],
         )
-        self.assertEqual(
-            df.select(F.slice(df.x, F.lit(1), F.size(df.x) - 1).alias("sliced")).collect(),
+        assertDataFrameEqual(
+            df.select(F.slice(df.x, F.lit(1), F.size(df.x) - 1).alias("sliced")),
             [Row(sliced=[1, 2]), Row(sliced=[4])],
         )
@@ -738,11 +734,9 @@ def test_array_repeat(self):
         df = df.withColumn("repeat_n", F.lit(3))

         expected = [Row(val=[0, 0, 0])]
-        self.assertEqual(df.select(F.array_repeat("id", 3).alias("val")).collect(), expected)
-        self.assertEqual(df.select(F.array_repeat("id", F.lit(3)).alias("val")).collect(), expected)
-        self.assertEqual(
-            df.select(F.array_repeat("id", "repeat_n").alias("val")).collect(), expected
-        )
+        assertDataFrameEqual(df.select(F.array_repeat("id", 3).alias("val")), expected)
+        assertDataFrameEqual(df.select(F.array_repeat("id", F.lit(3)).alias("val")), expected)
+        assertDataFrameEqual(df.select(F.array_repeat("id", "repeat_n").alias("val")), expected)

     def test_input_file_name_udf(self):
         df = self.spark.read.text("python/test_support/hello/hello.txt")
@@ -754,11 +748,11 @@ def test_least(self):
         df = self.spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"])

         expected = [Row(least=1)]
-        self.assertEqual(df.select(F.least(df.a, df.b, df.c).alias("least")).collect(), expected)
-        self.assertEqual(
-            df.select(F.least(F.lit(3), F.lit(5), F.lit(1)).alias("least")).collect(), expected
+        assertDataFrameEqual(df.select(F.least(df.a, df.b, df.c).alias("least")), expected)
+        assertDataFrameEqual(
+            df.select(F.least(F.lit(3), F.lit(5), F.lit(1)).alias("least")), expected
         )
-        self.assertEqual(df.select(F.least("a", "b", "c").alias("least")).collect(), expected)
+        assertDataFrameEqual(df.select(F.least("a", "b", "c").alias("least")), expected)

         with self.assertRaises(PySparkValueError) as pe:
             df.select(F.least(df.a).alias("least")).collect()
@@ -800,11 +794,9 @@ def test_overlay(self):
         df = self.spark.createDataFrame([("SPARK_SQL", "CORE", 7, 0)], ("x", "y", "pos", "len"))
         exp = [Row(ol="SPARK_CORESQL")]
-        self.assertEqual(df.select(F.overlay(df.x, df.y, 7, 0).alias("ol")).collect(), exp)
-        self.assertEqual(
-            df.select(F.overlay(df.x, df.y, F.lit(7), F.lit(0)).alias("ol")).collect(), exp
-        )
-        self.assertEqual(df.select(F.overlay("x", "y", "pos", "len").alias("ol")).collect(), exp)
+        assertDataFrameEqual(df.select(F.overlay(df.x, df.y, 7, 0).alias("ol")), exp)
+        assertDataFrameEqual(df.select(F.overlay(df.x, df.y, F.lit(7), F.lit(0)).alias("ol")), exp)
+        assertDataFrameEqual(df.select(F.overlay("x", "y", "pos", "len").alias("ol")), exp)

         with self.assertRaises(PySparkTypeError) as pe:
             df.select(F.overlay(df.x, df.y, 7.5, 0).alias("ol")).collect()
@@ -1164,8 +1156,8 @@ def test_assert_true(self):
     def check_assert_true(self, tpe):
         df = self.spark.range(3)

-        self.assertEqual(
-            df.select(F.assert_true(df.id < 3)).toDF("val").collect(),
+        assertDataFrameEqual(
+            df.select(F.assert_true(df.id < 3)).toDF("val"),
             [Row(val=None), Row(val=None), Row(val=None)],
         )
@@ -1302,17 +1294,17 @@ def test_np_scalar_input(self):
         df = self.spark.createDataFrame([([1, 2, 3],), ([],)], ["data"])
         for dtype in [np.int8, np.int16, np.int32, np.int64]:
-            res = df.select(F.array_contains(df.data, dtype(1)).alias("b")).collect()
-            self.assertEqual([Row(b=True), Row(b=False)], res)
-            res = df.select(F.array_position(df.data, dtype(1)).alias("c")).collect()
-            self.assertEqual([Row(c=1), Row(c=0)], res)
+            res = df.select(F.array_contains(df.data, dtype(1)).alias("b"))
+            assertDataFrameEqual([Row(b=True), Row(b=False)], res)
+            res = df.select(F.array_position(df.data, dtype(1)).alias("c"))
+            assertDataFrameEqual([Row(c=1), Row(c=0)], res)

         df = self.spark.createDataFrame([([1.0, 2.0, 3.0],), ([],)], ["data"])
         for dtype in [np.float32, np.float64]:
-            res = df.select(F.array_contains(df.data, dtype(1)).alias("b")).collect()
-            self.assertEqual([Row(b=True), Row(b=False)], res)
-            res = df.select(F.array_position(df.data, dtype(1)).alias("c")).collect()
-            self.assertEqual([Row(c=1), Row(c=0)], res)
+            res = df.select(F.array_contains(df.data, dtype(1)).alias("b"))
+            assertDataFrameEqual([Row(b=True), Row(b=False)], res)
+            res = df.select(F.array_position(df.data, dtype(1)).alias("c"))
+            assertDataFrameEqual([Row(c=1), Row(c=0)], res)

     @unittest.skipIf(not have_numpy, "NumPy not installed")
     def test_ndarray_input(self):
@@ -1729,46 +1721,42 @@ class IntEnum(Enum):
     def test_nullifzero_zeroifnull(self):
         df = self.spark.createDataFrame([(0,), (1,)], ["a"])
-        result = df.select(nullifzero(df.a).alias("r")).collect()
-        self.assertEqual([Row(r=None), Row(r=1)], result)
+        result = df.select(nullifzero(df.a).alias("r"))
+        assertDataFrameEqual([Row(r=None), Row(r=1)], result)

         df = self.spark.createDataFrame([(None,), (1,)], ["a"])
-        result = df.select(zeroifnull(df.a).alias("r")).collect()
-        self.assertEqual([Row(r=0), Row(r=1)], result)
+        result = df.select(zeroifnull(df.a).alias("r"))
+        assertDataFrameEqual([Row(r=0), Row(r=1)], result)

     def test_randstr_uniform(self):
         df = self.spark.createDataFrame([(0,)], ["a"])
-        result = df.select(randstr(F.lit(5), F.lit(0)).alias("x")).selectExpr("length(x)").collect()
-        self.assertEqual([Row(5)], result)
+        result = df.select(randstr(F.lit(5), F.lit(0)).alias("x")).selectExpr("length(x)")
+        assertDataFrameEqual([Row(5)], result)
         # The random seed is optional.
-        result = df.select(randstr(F.lit(5)).alias("x")).selectExpr("length(x)").collect()
-        self.assertEqual([Row(5)], result)
+        result = df.select(randstr(F.lit(5)).alias("x")).selectExpr("length(x)")
+        assertDataFrameEqual([Row(5)], result)

         df = self.spark.createDataFrame([(0,)], ["a"])
-        result = (
-            df.select(uniform(F.lit(10), F.lit(20), F.lit(0)).alias("x"))
-            .selectExpr("x > 5")
-            .collect()
-        )
-        self.assertEqual([Row(True)], result)
+        result = df.select(uniform(F.lit(10), F.lit(20), F.lit(0)).alias("x")).selectExpr("x > 5")
+        assertDataFrameEqual([Row(True)], result)
         # The random seed is optional.
-        result = df.select(uniform(F.lit(10), F.lit(20)).alias("x")).selectExpr("x > 5").collect()
-        self.assertEqual([Row(True)], result)
+        result = df.select(uniform(F.lit(10), F.lit(20)).alias("x")).selectExpr("x > 5")
+        assertDataFrameEqual([Row(True)], result)

     def test_string_validation(self):
         df = self.spark.createDataFrame([("abc",)], ["a"])
         # test is_valid_utf8
-        result_is_valid_utf8 = df.select(F.is_valid_utf8(df.a).alias("r")).collect()
-        self.assertEqual([Row(r=True)], result_is_valid_utf8)
+        result_is_valid_utf8 = df.select(F.is_valid_utf8(df.a).alias("r"))
+        assertDataFrameEqual([Row(r=True)], result_is_valid_utf8)
         # test make_valid_utf8
-        result_make_valid_utf8 = df.select(F.make_valid_utf8(df.a).alias("r")).collect()
-        self.assertEqual([Row(r="abc")], result_make_valid_utf8)
+        result_make_valid_utf8 = df.select(F.make_valid_utf8(df.a).alias("r"))
+        assertDataFrameEqual([Row(r="abc")], result_make_valid_utf8)
         # test validate_utf8
-        result_validate_utf8 = df.select(F.validate_utf8(df.a).alias("r")).collect()
-        self.assertEqual([Row(r="abc")], result_validate_utf8)
+        result_validate_utf8 = df.select(F.validate_utf8(df.a).alias("r"))
+        assertDataFrameEqual([Row(r="abc")], result_validate_utf8)
         # test try_validate_utf8
-        result_try_validate_utf8 = df.select(F.try_validate_utf8(df.a).alias("r")).collect()
-        self.assertEqual([Row(r="abc")], result_try_validate_utf8)
+        result_try_validate_utf8 = df.select(F.try_validate_utf8(df.a).alias("r"))
+        assertDataFrameEqual([Row(r="abc")], result_try_validate_utf8)


 class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin):
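For reference, a minimal, self-contained sketch of the `assertDataFrameEqual` pattern the diff above adopts. The local `SparkSession` setup and the `octet_length` column used here are illustrative assumptions for this sketch, not part of the change itself:

```python
from pyspark.sql import Row, SparkSession
from pyspark.sql import functions as F
from pyspark.testing.utils import assertDataFrameEqual

# Illustrative local session; in the test suite the shared session from
# ReusedSQLTestCase plays this role.
spark = SparkSession.builder.master("local[1]").appName("assertDataFrameEqual-demo").getOrCreate()

df = spark.createDataFrame([("cat",), ("\U0001F408",)], ["cat"])

# Old pattern: collect to a list of Rows, then compare plain Python objects.
assert df.select(F.octet_length("cat")).collect() == [Row(3), Row(4)]

# New pattern: pass the DataFrame itself (no .collect()); the expected value
# may be another DataFrame or a list of Rows, and DataFrame-vs-DataFrame
# comparisons also check the schema.
assertDataFrameEqual(df.select(F.octet_length("cat")), [Row(3), Row(4)])

spark.stop()
```

Dropping `.collect()` keeps the comparison at the DataFrame level, which is what lets `assertDataFrameEqual` report row and schema mismatches instead of a plain list inequality.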