Skip to content
/ cudf Public
forked from rapidsai/cudf

Commit

Permalink
Adapt to IR changes in polars 1.11
Browse files Browse the repository at this point in the history
In addition, implement inequality joins by translation to cross-join +
filter. This is the minimum necessary for things to work, and we will
use ast-based conditional and mixed joins in a followup.
  • Loading branch information
wence- committed Oct 23, 2024
1 parent 757b371 commit 1dc6e77
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 11 deletions.
13 changes: 6 additions & 7 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,10 +802,10 @@ class Join(IR):
right_on: tuple[expr.NamedExpr, ...]
"""List of expressions used as keys in the right frame."""
options: tuple[
Literal["inner", "left", "right", "full", "leftsemi", "leftanti", "cross"],
Literal["inner", "left", "right", "full", "semi", "anti", "cross"],
bool,
tuple[int, int] | None,
str | None,
str,
bool,
]
"""
Expand Down Expand Up @@ -840,7 +840,7 @@ def __init__(
@staticmethod
@cache
def _joiners(
how: Literal["inner", "left", "right", "full", "leftsemi", "leftanti"],
how: Literal["inner", "left", "right", "full", "semi", "anti"],
) -> tuple[
Callable, plc.copying.OutOfBoundsPolicy, plc.copying.OutOfBoundsPolicy | None
]:
Expand All @@ -862,13 +862,13 @@ def _joiners(
plc.copying.OutOfBoundsPolicy.NULLIFY,
plc.copying.OutOfBoundsPolicy.NULLIFY,
)
elif how == "leftsemi":
elif how == "semi":
return (
plc.join.left_semi_join,
plc.copying.OutOfBoundsPolicy.DONT_CHECK,
None,
)
elif how == "leftanti":
elif how == "anti":
return (
plc.join.left_anti_join,
plc.copying.OutOfBoundsPolicy.DONT_CHECK,
Expand Down Expand Up @@ -933,7 +933,6 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
"""Evaluate and return a dataframe."""
left, right = (c.evaluate(cache=cache) for c in self.children)
how, join_nulls, zlice, suffix, coalesce = self.options
suffix = "_right" if suffix is None else suffix
if how == "cross":
# Separate implementation, since cross_join returns the
# result, not the gather maps
Expand All @@ -955,7 +954,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
columns[left.num_columns :], right.column_names, strict=True
)
]
return DataFrame([*left_cols, *right_cols])
return DataFrame([*left_cols, *right_cols]).slice(zlice)
# TODO: Waiting on clarity based on https://github.com/pola-rs/polars/issues/17184
left_on = DataFrame(broadcast(*(e.evaluate(left) for e in self.left_on)))
right_on = DataFrame(broadcast(*(e.evaluate(right) for e in self.right_on)))
Expand Down
76 changes: 72 additions & 4 deletions python/cudf_polars/cudf_polars/dsl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

from __future__ import annotations

import functools
import json
from contextlib import AbstractContextManager, nullcontext
from functools import singledispatch
from typing import Any
from typing import TYPE_CHECKING, Any

import pyarrow as pa
import pylibcudf as plc
Expand All @@ -19,9 +20,13 @@
from polars.polars import _expr_nodes as pl_expr, _ir_nodes as pl_ir

from cudf_polars.dsl import expr, ir
from cudf_polars.dsl.traversal import make_recursive, reuse_if_unchanged
from cudf_polars.typing import NodeTraverser
from cudf_polars.utils import dtypes, sorting

if TYPE_CHECKING:
from cudf_polars.typing import ExprTransformer

__all__ = ["translate_ir", "translate_named_expr"]


Expand Down Expand Up @@ -182,7 +187,71 @@ def _(
with set_node(visitor, node.input_right):
inp_right = translate_ir(visitor, n=None)
right_on = [translate_named_expr(visitor, n=e) for e in node.right_on]
return ir.Join(schema, left_on, right_on, node.options, inp_left, inp_right)
if (how := node.options[0]) in {
"inner",
"left",
"right",
"full",
"cross",
"semi",
"anti",
}:
return ir.Join(schema, left_on, right_on, node.options, inp_left, inp_right)
else:
how, op1, op2 = how
if how != "ie_join":
raise NotImplementedError(
f"Unsupported join type {how}"
) # pragma: no cover; asof joins not yet exposed
# No exposure of mixed/conditional joins in pylibcudf yet, so in
# the first instance, implement by doing a cross join followed by
# a filter.
_, join_nulls, zlice, suffix, coalesce = node.options
cross = ir.Join(
schema,
[],
[],
("cross", join_nulls, None, suffix, coalesce),
inp_left,
inp_right,
)
dtype = plc.DataType(plc.TypeId.BOOL8)
if op2 is None:
ops = [op1]
else:
ops = [op1, op2]
suffix = cross.options[3]

# Column references in the right table refer to the post-join
# names, so with suffixes.
def _rename(e: expr.Expr, rec: ExprTransformer) -> expr.Expr:
if isinstance(e, expr.Col) and e.name in inp_left.schema:
return type(e)(e.dtype, f"{e.name}{suffix}")
return reuse_if_unchanged(e, rec)

mapper = make_recursive(_rename)
right_on = [
expr.NamedExpr(
f"{old.name}{suffix}" if old.name in inp_left.schema else old.name, new
)
for new, old in zip(
(mapper(e.value) for e in right_on), right_on, strict=True
)
]
mask = functools.reduce(
functools.partial(
expr.BinOp, dtype, plc.binaryop.BinaryOperator.LOGICAL_AND
),
(
expr.BinOp(dtype, expr.BinOp._MAPPING[op], left.value, right.value)
for op, left, right in zip(ops, left_on, right_on, strict=True)
),
)
filtered = ir.Filter(schema, expr.NamedExpr("mask", mask), cross)
if zlice is not None:
offset, length = zlice
return ir.Slice(schema, offset, length, filtered)
return filtered


@_translate_ir.register
Expand Down Expand Up @@ -319,8 +388,7 @@ def translate_ir(visitor: NodeTraverser, *, n: int | None = None) -> ir.IR:
# IR is versioned with major.minor, minor is bumped for backwards
# compatible changes (e.g. adding new nodes), major is bumped for
# incompatible changes (e.g. renaming nodes).
# Polars 1.7 changes definition of the CSV reader options schema name.
if (version := visitor.version()) >= (3, 0):
if (version := visitor.version()) >= (4, 0):
raise NotImplementedError(
f"No support for polars IR {version=}"
) # pragma: no cover; no such version for now.
Expand Down

0 comments on commit 1dc6e77

Please sign in to comment.