Skip to content

Commit

Permalink
WIP: conditional join
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed Nov 1, 2024
1 parent aaeeee8 commit 8629c74
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 44 deletions.
56 changes: 55 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import cudf_polars.dsl.expr as expr
from cudf_polars.containers import Column, DataFrame
from cudf_polars.dsl.nodebase import Node
from cudf_polars.dsl.to_ast import to_parquet_filter
from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter
from cudf_polars.utils import dtypes

if TYPE_CHECKING:
Expand All @@ -48,6 +48,7 @@
"Select",
"GroupBy",
"Join",
"ConditionalJoin",
"HStack",
"Distinct",
"Sort",
Expand Down Expand Up @@ -802,6 +803,59 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
return DataFrame(broadcasted).slice(self.options.slice)


class ConditionalJoin(IR):
"""A conditional inner join of two dataframes on a predicate."""

__slots__ = ("predicate", "options", "ast_predicate")
_non_child = ("schema", "predicate", "options")
predicate: expr.Expr
options: tuple

def __init__(
self, schema: Schema, predicate: expr.Expr, options: tuple, left: IR, right: IR
) -> None:
self.schema = schema
self.predicate = predicate
self.options = options
self.children = (left, right)
self.ast_predicate = to_ast(predicate)
_, join_nulls, _, _, coalesce = self.options
# Preconditions from polars
assert not join_nulls
assert not coalesce
if self.ast_predicate is None:
raise NotImplementedError(f"Conditional join with predicate {predicate}")

def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
"""Evaluate and return a dataframe."""
left, right = (c.evaluate(cache=cache) for c in self.children)
_, _, zlice, suffix, _ = self.options
lg, rg = plc.join.conditional_inner_join(
left.table, right.table, self.ast_predicate
)
left = DataFrame.from_table(
plc.copying.gather(
left.table, lg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
),
left.column_names,
)
right = DataFrame.from_table(
plc.copying.gather(
right.table, rg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
),
right.column_names,
)
right = right.rename_columns(
{
name: f"{name}{suffix}"
for name in right.column_names
if name in left.column_names_set
}
)
result = left.with_columns(right.columns)
return result.slice(zlice)


class Join(IR):
"""A join of two dataframes."""

Expand Down
68 changes: 25 additions & 43 deletions python/cudf_polars/cudf_polars/dsl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import json
from contextlib import AbstractContextManager, nullcontext
from functools import singledispatch
from typing import TYPE_CHECKING, Any
from typing import Any

import pyarrow as pa
from typing_extensions import assert_never
Expand All @@ -21,13 +21,10 @@
import pylibcudf as plc

from cudf_polars.dsl import expr, ir
from cudf_polars.dsl.traversal import make_recursive, reuse_if_unchanged
from cudf_polars.dsl.to_ast import insert_tablerefs
from cudf_polars.typing import NodeTraverser
from cudf_polars.utils import dtypes, sorting

if TYPE_CHECKING:
from cudf_polars.typing import ExprTransformer

__all__ = ["translate_ir", "translate_named_expr"]


Expand Down Expand Up @@ -204,55 +201,40 @@ def _(
raise NotImplementedError(
f"Unsupported join type {how}"
) # pragma: no cover; asof joins not yet exposed
# No exposure of mixed/conditional joins in pylibcudf yet, so in
# the first instance, implement by doing a cross join followed by
# a filter.
_, join_nulls, zlice, suffix, coalesce = node.options
cross = ir.Join(
schema,
[],
[],
("cross", join_nulls, None, suffix, coalesce),
inp_left,
inp_right,
)
dtype = plc.DataType(plc.TypeId.BOOL8)
if op2 is None:
ops = [op1]
else:
ops = [op1, op2]
suffix = cross.options[3]

# Column references in the right table refer to the post-join
# names, so with suffixes.
def _rename(e: expr.Expr, rec: ExprTransformer) -> expr.Expr:
if isinstance(e, expr.Col) and e.name in inp_left.schema:
return type(e)(e.dtype, f"{e.name}{suffix}")
return reuse_if_unchanged(e, rec)

mapper = make_recursive(_rename)
right_on = [
expr.NamedExpr(
f"{old.name}{suffix}" if old.name in inp_left.schema else old.name, new
)
for new, old in zip(
(mapper(e.value) for e in right_on), right_on, strict=True
)
]
mask = functools.reduce(

dtype = plc.DataType(plc.TypeId.BOOL8)
predicate = functools.reduce(
functools.partial(
expr.BinOp, dtype, plc.binaryop.BinaryOperator.LOGICAL_AND
),
(
expr.BinOp(dtype, expr.BinOp._MAPPING[op], left.value, right.value)
expr.BinOp(
dtype,
expr.BinOp._MAPPING[op],
insert_tablerefs(
left.value,
table_ref=plc.expressions.TableReference.LEFT,
name_to_index={
name: i for i, name in enumerate(inp_left.schema)
},
),
insert_tablerefs(
right.value,
table_ref=plc.expressions.TableReference.RIGHT,
name_to_index={
name: i for i, name in enumerate(inp_right.schema)
},
),
)
for op, left, right in zip(ops, left_on, right_on, strict=True)
),
)
filtered = ir.Filter(schema, expr.NamedExpr("mask", mask), cross)
if zlice is not None:
offset, length = zlice
return ir.Slice(schema, offset, length, filtered)
return filtered

return ir.ConditionalJoin(schema, predicate, node.options, inp_left, inp_right)


@_translate_ir.register
Expand Down

0 comments on commit 8629c74

Please sign in to comment.