From fdfe737b8fba70002c985e22e0c018521fc707ec Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 7 Oct 2024 13:20:36 +0000 Subject: [PATCH] join_where tests --- python/cudf_polars/tests/test_join.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/python/cudf_polars/tests/test_join.py b/python/cudf_polars/tests/test_join.py index 7d9ec98db97..ff080b716d1 100644 --- a/python/cudf_polars/tests/test_join.py +++ b/python/cudf_polars/tests/test_join.py @@ -39,6 +39,7 @@ def right(): { "a": [1, 4, 3, 7, None, None], "c": [2, 3, 4, 5, 6, 7], + "d": [6, None, 7, 8, -1, 2], } ) @@ -86,3 +87,24 @@ def test_join_literal_key_unsupported(left, right, left_on, right_on): q = left.join(right, left_on=left_on, right_on=right_on, how="inner") assert_ir_translation_raises(q, NotImplementedError) + + +@pytest.mark.parametrize( + "conditions", + [ + [pl.col("a") < pl.col("a_right")], + [pl.col("a_right") <= pl.col("a") * 2], + [pl.col("b") * 2 > pl.col("a_right"), pl.col("a") == pl.col("c_right")], + [pl.col("b") * 2 <= pl.col("a_right"), pl.col("a") < pl.col("c_right")], + pytest.param( + [pl.col("b") <= pl.col("a_right") * 7, pl.col("a") < pl.col("d") * 2], + marks=pytest.mark.xfail( + reason="https://github.com/pola-rs/polars/issues/19119" + ), + ), + ], +) +def test_join_where(left, right, conditions): + q = left.join_where(right, *conditions) + + assert_gpu_result_equal(q, check_row_order=False)