diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py
index e748ec16f14..1881286ccbb 100644
--- a/python/cudf_polars/cudf_polars/dsl/expr.py
+++ b/python/cudf_polars/cudf_polars/dsl/expr.py
@@ -19,6 +19,7 @@
 from cudf_polars.dsl.expressions.base import (
     AggInfo,
     Col,
+    ColRef,
     Expr,
     NamedExpr,
 )
@@ -40,6 +41,7 @@
     "LiteralColumn",
     "Len",
     "Col",
+    "ColRef",
     "BooleanFunction",
     "StringFunction",
     "TemporalFunction",
diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/base.py b/python/cudf_polars/cudf_polars/dsl/expressions/base.py
index effe8cb2378..21ba7aea707 100644
--- a/python/cudf_polars/cudf_polars/dsl/expressions/base.py
+++ b/python/cudf_polars/cudf_polars/dsl/expressions/base.py
@@ -20,7 +20,7 @@
     from cudf_polars.containers import Column, DataFrame
 
 
-__all__ = ["Expr", "NamedExpr", "Col", "AggInfo", "ExecutionContext"]
+__all__ = ["Expr", "NamedExpr", "Col", "AggInfo", "ExecutionContext", "ColRef"]
 
 
 class AggInfo(NamedTuple):
@@ -249,3 +249,36 @@ def do_evaluate(
     def collect_agg(self, *, depth: int) -> AggInfo:
         """Collect information about aggregations in groupbys."""
         return AggInfo([(self, plc.aggregation.collect_list(), self)])
+
+
+class ColRef(Expr):
+    __slots__ = ("index", "table_ref")
+    _non_child = ("dtype", "index", "table_ref")
+    index: int
+    table_ref: plc.expressions.TableReference
+
+    def __init__(
+        self,
+        dtype: plc.DataType,
+        index: int,
+        table_ref: plc.expressions.TableReference,
+        column: Expr,
+    ) -> None:
+        if not isinstance(column, Col):
+            raise TypeError("Column reference should only apply to columns")
+        self.dtype = dtype
+        self.index = index
+        self.table_ref = table_ref
+        self.children = (column,)
+
+    def do_evaluate(
+        self,
+        df: DataFrame,
+        *,
+        context: ExecutionContext = ExecutionContext.FRAME,
+        mapping: Mapping[Expr, Column] | None = None,
+    ) -> Column:
+        """Evaluate this expression given a dataframe for context."""
+        raise NotImplementedError(
+            "Only expect this node as part of an expression translated to libcudf AST."
+        )
diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py
index a242ff9300f..bc42b4a254f 100644
--- a/python/cudf_polars/cudf_polars/dsl/ir.py
+++ b/python/cudf_polars/cudf_polars/dsl/ir.py
@@ -29,8 +29,9 @@
 import cudf_polars.dsl.expr as expr
 from cudf_polars.containers import Column, DataFrame
 from cudf_polars.dsl.nodebase import Node
-from cudf_polars.dsl.to_ast import to_parquet_filter
+from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter
 from cudf_polars.utils import dtypes
+from cudf_polars.utils.versions import POLARS_VERSION_GT_112
 
 if TYPE_CHECKING:
     from collections.abc import Callable, Hashable, MutableMapping, Sequence
@@ -48,6 +49,7 @@
     "Select",
     "GroupBy",
     "Join",
+    "ConditionalJoin",
     "HStack",
     "Distinct",
     "Sort",
@@ -522,6 +524,12 @@ def do_evaluate(
             )  # pragma: no cover; post init trips first
         if row_index is not None:
             name, offset = row_index
+            if POLARS_VERSION_GT_112:
+                # If we sliced away some data from the start, that
+                # shifts the row index.
+                # But prior to 1.13, polars had this wrong, so we match behaviour
+                # https://github.com/pola-rs/polars/issues/19607
+                offset += skip_rows  # pragma: no cover; polars 1.13 not yet released
             dtype = schema[name]
             step = plc.interop.from_arrow(
                 pa.scalar(1, type=plc.interop.to_arrow(dtype))
             )
@@ -890,6 +898,66 @@ def do_evaluate(
         return DataFrame(broadcasted).slice(options.slice)
 
 
+class ConditionalJoin(IR):
+    """A conditional inner join of two dataframes on a predicate."""
+
+    __slots__ = ("predicate", "options", "ast_predicate")
+    _non_child = ("schema", "predicate", "options")
+    predicate: expr.Expr
+    options: tuple
+
+    def __init__(
+        self, schema: Schema, predicate: expr.Expr, options: tuple, left: IR, right: IR
+    ) -> None:
+        self.schema = schema
+        self.predicate = predicate
+        self.options = options
+        self.children = (left, right)
+        self.ast_predicate = to_ast(predicate)
+        _, join_nulls, zlice, suffix, coalesce = self.options
+        # Preconditions from polars
+        assert not join_nulls
+        assert not coalesce
+        if self.ast_predicate is None:
+            raise NotImplementedError(
+                f"Conditional join with predicate {predicate}"
+            )  # pragma: no cover; polars never delivers expressions we can't handle
+        self._non_child_args = (self.ast_predicate, zlice, suffix)
+
+    @classmethod
+    def do_evaluate(
+        cls,
+        predicate: plc.expressions.Expression,
+        zlice: tuple[int, int] | None,
+        suffix: str,
+        left: DataFrame,
+        right: DataFrame,
+    ) -> DataFrame:
+        """Evaluate and return a dataframe."""
+        lg, rg = plc.join.conditional_inner_join(left.table, right.table, predicate)
+        left = DataFrame.from_table(
+            plc.copying.gather(
+                left.table, lg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
+            ),
+            left.column_names,
+        )
+        right = DataFrame.from_table(
+            plc.copying.gather(
+                right.table, rg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
+            ),
+            right.column_names,
+        )
+        right = right.rename_columns(
+            {
+                name: f"{name}{suffix}"
+                for name in right.column_names
+                if name in left.column_names_set
+            }
+        )
+        result = left.with_columns(right.columns)
+        return result.slice(zlice)
+
+
 class Join(IR):
     """A join of two dataframes."""
 
diff --git a/python/cudf_polars/cudf_polars/dsl/to_ast.py b/python/cudf_polars/cudf_polars/dsl/to_ast.py
index 9a0838631cc..acc4b3669af 100644
--- a/python/cudf_polars/cudf_polars/dsl/to_ast.py
+++ b/python/cudf_polars/cudf_polars/dsl/to_ast.py
@@ -14,12 +14,14 @@
 from pylibcudf import expressions as plc_expr
 
 from cudf_polars.dsl import expr
-from cudf_polars.dsl.traversal import CachingVisitor
+from cudf_polars.dsl.traversal import CachingVisitor, reuse_if_unchanged
 from cudf_polars.typing import GenericTransformer
 
 if TYPE_CHECKING:
     from collections.abc import Mapping
 
+    from cudf_polars.typing import ExprTransformer
+
 # Can't merge these op-mapping dictionaries because scoped enum values
 # are exposed by cython with equality/hash based one their underlying
 # representation type. So in a dict they are just treated as integers.
@@ -128,7 +130,14 @@ def _to_ast(node: expr.Expr, self: Transformer) -> plc_expr.Expression:
 def _(node: expr.Col, self: Transformer) -> plc_expr.Expression:
     if self.state["for_parquet"]:
         return plc_expr.ColumnNameReference(node.name)
-    return plc_expr.ColumnReference(self.state["name_to_index"][node.name])
+    raise TypeError("Should always be wrapped in a ColRef node before translation")
+
+
+@_to_ast.register
+def _(node: expr.ColRef, self: Transformer) -> plc_expr.Expression:
+    if self.state["for_parquet"]:
+        raise TypeError("Not expecting ColRef node in parquet filter")
+    return plc_expr.ColumnReference(node.index, node.table_ref)
 
 
 @_to_ast.register
@@ -238,9 +247,7 @@ def to_parquet_filter(node: expr.Expr) -> plc_expr.Expression | None:
     return None
 
 
-def to_ast(
-    node: expr.Expr, *, name_to_index: Mapping[str, int]
-) -> plc_expr.Expression | None:
+def to_ast(node: expr.Expr) -> plc_expr.Expression | None:
     """
     Convert an expression to libcudf AST nodes suitable for compute_column.
 
@@ -248,18 +255,66 @@ def to_ast(
     ----------
     node
         Expression to convert.
-    name_to_index
-        Mapping from column names to their index in the table that
-        will be used for expression evaluation.
+
+    Notes
+    -----
+    `Col` nodes must always be wrapped in `ColRef` nodes when
+    converting to an ast expression so that their table reference and
+    index are provided.
 
     Returns
     -------
-    pylibcudf Expressoin if conversion is possible, otherwise None.
+    pylibcudf Expression if conversion is possible, otherwise None.
     """
-    mapper = CachingVisitor(
-        _to_ast, state={"for_parquet": False, "name_to_index": name_to_index}
-    )
+    mapper = CachingVisitor(_to_ast, state={"for_parquet": False})
     try:
         return mapper(node)
     except (KeyError, NotImplementedError):
         return None
+
+
+def _insert_colrefs(node: expr.Expr, rec: ExprTransformer) -> expr.Expr:
+    if isinstance(node, expr.Col):
+        return expr.ColRef(
+            node.dtype,
+            rec.state["name_to_index"][node.name],
+            rec.state["table_ref"],
+            node,
+        )
+    return reuse_if_unchanged(node, rec)
+
+
+def insert_colrefs(
+    node: expr.Expr,
+    *,
+    table_ref: plc.expressions.TableReference,
+    name_to_index: Mapping[str, int],
+) -> expr.Expr:
+    """
+    Insert column references into an expression before conversion to libcudf AST.
+
+    Parameters
+    ----------
+    node
+        Expression to insert references into.
+    table_ref
+        pylibcudf `TableReference` indicating whether column
+        references are coming from the left or right table.
+    name_to_index
+        Mapping from column names to column indices in the table
+        eventually used for evaluation.
+
+    Notes
+    -----
+    All column references are wrapped in the same, singular, table
+    reference, so this function relies on the expression only
+    containing column references from a single table.
+
+    Returns
+    -------
+    New expression with column references inserted.
+ """ + mapper = CachingVisitor( + _insert_colrefs, state={"table_ref": table_ref, "name_to_index": name_to_index} + ) + return mapper(node) diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 5181214819e..2711676d31e 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -9,7 +9,7 @@ import json from contextlib import AbstractContextManager, nullcontext from functools import singledispatch -from typing import TYPE_CHECKING, Any +from typing import Any import pyarrow as pa from typing_extensions import assert_never @@ -21,13 +21,10 @@ import pylibcudf as plc from cudf_polars.dsl import expr, ir -from cudf_polars.dsl.traversal import make_recursive, reuse_if_unchanged +from cudf_polars.dsl.to_ast import insert_colrefs from cudf_polars.typing import NodeTraverser from cudf_polars.utils import dtypes, sorting -if TYPE_CHECKING: - from cudf_polars.typing import ExprTransformer - __all__ = ["translate_ir", "translate_named_expr"] @@ -204,55 +201,40 @@ def _( raise NotImplementedError( f"Unsupported join type {how}" ) # pragma: no cover; asof joins not yet exposed - # No exposure of mixed/conditional joins in pylibcudf yet, so in - # the first instance, implement by doing a cross join followed by - # a filter. - _, join_nulls, zlice, suffix, coalesce = node.options - cross = ir.Join( - schema, - [], - [], - ("cross", join_nulls, None, suffix, coalesce), - inp_left, - inp_right, - ) - dtype = plc.DataType(plc.TypeId.BOOL8) if op2 is None: ops = [op1] else: ops = [op1, op2] - suffix = cross.options[3] - - # Column references in the right table refer to the post-join - # names, so with suffixes. - def _rename(e: expr.Expr, rec: ExprTransformer) -> expr.Expr: - if isinstance(e, expr.Col) and e.name in inp_left.schema: - return type(e)(e.dtype, f"{e.name}{suffix}") - return reuse_if_unchanged(e, rec) - - mapper = make_recursive(_rename) - right_on = [ - expr.NamedExpr( - f"{old.name}{suffix}" if old.name in inp_left.schema else old.name, new - ) - for new, old in zip( - (mapper(e.value) for e in right_on), right_on, strict=True - ) - ] - mask = functools.reduce( + + dtype = plc.DataType(plc.TypeId.BOOL8) + predicate = functools.reduce( functools.partial( expr.BinOp, dtype, plc.binaryop.BinaryOperator.LOGICAL_AND ), ( - expr.BinOp(dtype, expr.BinOp._MAPPING[op], left.value, right.value) + expr.BinOp( + dtype, + expr.BinOp._MAPPING[op], + insert_colrefs( + left.value, + table_ref=plc.expressions.TableReference.LEFT, + name_to_index={ + name: i for i, name in enumerate(inp_left.schema) + }, + ), + insert_colrefs( + right.value, + table_ref=plc.expressions.TableReference.RIGHT, + name_to_index={ + name: i for i, name in enumerate(inp_right.schema) + }, + ), + ) for op, left, right in zip(ops, left_on, right_on, strict=True) ), ) - filtered = ir.Filter(schema, expr.NamedExpr("mask", mask), cross) - if zlice is not None: - offset, length = zlice - return ir.Slice(schema, offset, length, filtered) - return filtered + + return ir.ConditionalJoin(schema, predicate, node.options, inp_left, inp_right) @_translate_ir.register diff --git a/python/cudf_polars/cudf_polars/utils/versions.py b/python/cudf_polars/cudf_polars/utils/versions.py index a119cab3b74..b08cede8f7f 100644 --- a/python/cudf_polars/cudf_polars/utils/versions.py +++ b/python/cudf_polars/cudf_polars/utils/versions.py @@ -14,6 +14,8 @@ POLARS_VERSION_LT_111 = POLARS_VERSION < parse("1.11") POLARS_VERSION_LT_112 = 
POLARS_VERSION < parse("1.12") +POLARS_VERSION_GT_112 = POLARS_VERSION > parse("1.12") +POLARS_VERSION_LT_113 = POLARS_VERSION < parse("1.13") def _ensure_polars_version(): diff --git a/python/cudf_polars/tests/dsl/test_to_ast.py b/python/cudf_polars/tests/dsl/test_to_ast.py index 57d794d4890..8f10f119199 100644 --- a/python/cudf_polars/tests/dsl/test_to_ast.py +++ b/python/cudf_polars/tests/dsl/test_to_ast.py @@ -3,6 +3,7 @@ from __future__ import annotations +import pyarrow as pa import pytest import polars as pl @@ -10,10 +11,11 @@ import pylibcudf as plc +import cudf_polars.dsl.expr as expr_nodes import cudf_polars.dsl.ir as ir_nodes from cudf_polars import translate_ir from cudf_polars.containers.dataframe import DataFrame, NamedColumn -from cudf_polars.dsl.to_ast import to_ast +from cudf_polars.dsl.to_ast import insert_colrefs, to_ast, to_parquet_filter @pytest.fixture(scope="module") @@ -65,7 +67,14 @@ def test_compute_column(expr, df): name_to_index = {c.name: i for i, c in enumerate(table.columns)} def compute_column(e): - ast = to_ast(e.value, name_to_index=name_to_index) + e_with_colrefs = insert_colrefs( + e.value, + table_ref=plc.expressions.TableReference.LEFT, + name_to_index=name_to_index, + ) + with pytest.raises(NotImplementedError): + e_with_colrefs.evaluate(table) + ast = to_ast(e_with_colrefs) if ast is not None: return NamedColumn( plc.transform.compute_column(table.table, ast), name=e.name @@ -77,3 +86,28 @@ def compute_column(e): expect = q.collect() assert_frame_equal(expect, got) + + +def test_invalid_colref_construction_raises(): + literal = expr_nodes.Literal( + plc.DataType(plc.TypeId.INT8), pa.scalar(1, type=pa.int8()) + ) + with pytest.raises(TypeError): + expr_nodes.ColRef( + literal.dtype, 0, plc.expressions.TableReference.LEFT, literal + ) + + +def test_to_ast_without_colref_raises(): + col = expr_nodes.Col(plc.DataType(plc.TypeId.INT8), "a") + + with pytest.raises(TypeError): + to_ast(col) + + +def test_to_parquet_filter_with_colref_raises(): + col = expr_nodes.Col(plc.DataType(plc.TypeId.INT8), "a") + colref = expr_nodes.ColRef(col.dtype, 0, plc.expressions.TableReference.LEFT, col) + + with pytest.raises(TypeError): + to_parquet_filter(colref) diff --git a/python/cudf_polars/tests/test_join.py b/python/cudf_polars/tests/test_join.py index 8ca7a7b9264..2fcbbf21f1c 100644 --- a/python/cudf_polars/tests/test_join.py +++ b/python/cudf_polars/tests/test_join.py @@ -13,7 +13,7 @@ assert_gpu_result_equal, assert_ir_translation_raises, ) -from cudf_polars.utils.versions import POLARS_VERSION_LT_112 +from cudf_polars.utils.versions import POLARS_VERSION_LT_112, POLARS_VERSION_LT_113 @pytest.fixture(params=[False, True], ids=["nulls_not_equal", "nulls_equal"]) @@ -110,7 +110,11 @@ def test_cross_join(left, right, zlice): @pytest.mark.parametrize( - "left_on,right_on", [(pl.col("a"), pl.lit(2)), (pl.lit(2), pl.col("a"))] + "left_on,right_on", + [ + (pl.col("a"), pl.lit(2, dtype=pl.Int64)), + (pl.lit(2, dtype=pl.Int64), pl.col("a")), + ], ) def test_join_literal_key_unsupported(left, right, left_on, right_on): q = left.join(right, left_on=left_on, right_on=right_on, how="inner") @@ -125,7 +129,13 @@ def test_join_literal_key_unsupported(left, right, left_on, right_on): [pl.col("a_right") <= pl.col("a") * 2], [pl.col("b") * 2 > pl.col("a_right"), pl.col("a") == pl.col("c_right")], [pl.col("b") * 2 <= pl.col("a_right"), pl.col("a") < pl.col("c_right")], - [pl.col("b") <= pl.col("a_right") * 7, pl.col("a") < pl.col("d") * 2], + pytest.param( + [pl.col("b") <= 
pl.col("a_right") * 7, pl.col("a") < pl.col("d") * 2], + marks=pytest.mark.xfail( + POLARS_VERSION_LT_113, + reason="https://github.com/pola-rs/polars/issues/19597", + ), + ), ], ) def test_join_where(left, right, conditions, zlice):