From 30ee2e001fb238b9c00fae631ff581db81b8e5ec Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 25 Oct 2024 12:05:02 +0000 Subject: [PATCH] Slightly simplify evaluate_node implementations --- python/cudf_polars/cudf_polars/dsl/ir.py | 253 ++++++++++++++--------- 1 file changed, 153 insertions(+), 100 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 28a7203f05b..39e992a140d 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -125,9 +125,12 @@ def broadcast(*columns: Column, target_length: int | None = None) -> list[Column class IR(Node["IR"]): """Abstract plan node, representing an unevaluated dataframe.""" - __slots__ = ("schema",) + __slots__ = ("schema", "_non_child_args") # This annotation is needed because of https://github.com/python/mypy/issues/17981 _non_child: ClassVar[tuple[str, ...]] = ("schema",) + # Concrete classes should set this up with the arguments that will + # be passed to do_evaluate. + _non_child_args: tuple[Any, ...] schema: Schema """Mapping from column names to their data types.""" @@ -144,42 +147,34 @@ def get_hashable(self) -> Hashable: schema_hash = tuple(self.schema.items()) return (type(self), schema_hash, args) - def _eval_arguments(self, children: Sequence[DataFrame]) -> Sequence: - # Construct arguments for evaluate_node. - # By default, this is _non_child attributes followed by - # the evaluated children. - return ( - *(getattr(self, attr) for attr in self._non_child), - *children, - ) - - @classmethod - def evaluate_node(cls, *args: Any, **kwargs: Any) -> DataFrame: - """ - Evaluate the node (given its evaluated children), and return a dataframe. + # Hacky to avoid type-checking issues, just advertise the + # signature. Both mypy and pyright complain if we have an abstract + # method that takes arbitrary *args, but the subclasses have + # tighter signatures. This complaint is correct because the + # subclass is not Liskov-substitutable for the superclass. + # However, we know do_evaluate will only be called with the + # correct arguments by "construction". + do_evaluate: Callable[..., DataFrame] + """ + Evaluate the node (given its evaluated children), and return a dataframe. - Parameters - ---------- - *args - Positional arguments specified in IR._eval_arguments. - **kwargs - Key-word arguments. This should be empty! + Parameters + ---------- + args + Non child arguments followed by any evaluated dataframe inputs. - Returns - ------- - DataFrame (on device) representing the evaluation of this plan - node. + Returns + ------- + DataFrame (on device) representing the evaluation of this plan + node. - Raises - ------ - NotImplementedError - If we couldn't evaluate things. Ideally this should not occur, - since the translation phase should pick up things that we - cannot handle. - """ - raise NotImplementedError( - f"Evaluation of plan {cls.__name__}" - ) # pragma: no cover + Raises + ------ + NotImplementedError + If we couldn't evaluate things. Ideally this should not occur, + since the translation phase should pick up things that we + cannot handle. + """ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """ @@ -191,6 +186,12 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: Mapping from cached node ids to constructed DataFrames. Used to implement evaluation of the `Cache` node. + Notes + ----- + Prefer not to override this method. Instead implement + :meth:`do_evaluate` which doesn't encode a recursion scheme + and just assumes already evaluated inputs. + Returns ------- DataFrame (on device) representing the evaluation of this plan @@ -203,8 +204,10 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: since the translation phase should pick up things that we cannot handle. """ - evald_children = [child.evaluate(cache=cache) for child in self.children] - return self.evaluate_node(*self._eval_arguments(evald_children)) + return self.do_evaluate( + *self._non_child_args, + *(child.evaluate(cache=cache) for child in self.children), + ) class PythonScan(IR): @@ -221,6 +224,7 @@ def __init__(self, schema: Schema, options: Any, predicate: expr.NamedExpr | Non self.schema = schema self.options = options self.predicate = predicate + self._non_child_args = (schema, options, predicate) self.children = () raise NotImplementedError("PythonScan not implemented") @@ -293,6 +297,18 @@ def __init__( self.n_rows = n_rows self.row_index = row_index self.predicate = predicate + self._non_child_args = ( + schema, + typ, + reader_options, + cloud_options, + paths, + with_columns, + skip_rows, + n_rows, + row_index, + predicate, + ) self.children = () if self.typ not in ("csv", "parquet", "ndjson"): # pragma: no cover # This line is unhittable ATM since IPC/Anonymous scan raise @@ -376,7 +392,7 @@ def get_hashable(self) -> Hashable: ) @classmethod - def evaluate_node( + def do_evaluate( cls, schema: Schema, typ: str, @@ -542,9 +558,21 @@ def __init__(self, schema: Schema, key: int, value: IR): self.schema = schema self.key = key self.children = (value,) + self._non_child_args = (key,) + + @classmethod + def do_evaluate( + cls, key: int, df: DataFrame + ) -> DataFrame: # pragma: no cover; basic evaluation never calls this + """Evaluate and return a dataframe.""" + # Our value has already been computed for us, so let's just + # return it. + return df def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" + # We must override the recursion scheme because we don't want + # to recurse if we're in the cache. try: return cache[self.key] except KeyError: @@ -579,6 +607,7 @@ def __init__( self.df = df self.projection = tuple(projection) if projection is not None else None self.predicate = predicate + self._non_child_args = (schema, df, projection, predicate) self.children = () def get_hashable(self) -> Hashable: @@ -592,7 +621,13 @@ def get_hashable(self) -> Hashable: return (type(self), schema_hash, id(self.df), self.projection, self.predicate) @classmethod - def evaluate_node(cls, schema, df, projection, predicate) -> DataFrame: + def do_evaluate( + cls, + schema: Schema, + df: Any, + projection: tuple[str, ...] | None, + predicate: expr.NamedExpr | None, + ) -> DataFrame: """Evaluate and return a dataframe.""" pdf = pl.DataFrame._from_pydf(df) if projection is not None: @@ -630,12 +665,12 @@ def __init__( self.exprs = tuple(exprs) self.should_broadcast = should_broadcast self.children = (df,) + self._non_child_args = (exprs, should_broadcast) @classmethod - def evaluate_node( + def do_evaluate( cls, - schema: Schema, - exprs: Sequence[expr.NamedExpr], + exprs: tuple[expr.NamedExpr, ...], should_broadcast: bool, # noqa: FBT001 df: DataFrame, ) -> DataFrame: @@ -665,11 +700,12 @@ def __init__( self.schema = schema self.exprs = tuple(exprs) self.children = (df,) + self._non_child_args = (exprs,) @classmethod - def evaluate_node( - cls, schema: Schema, exprs: Sequence[expr.NamedExpr], df: DataFrame - ): + def do_evaluate( + cls, exprs: tuple[expr.NamedExpr, ...], df: DataFrame + ) -> DataFrame: # pragma: no cover; not exposed by polars yet """Evaluate and return a dataframe.""" columns = broadcast(*(e.evaluate(df) for e in exprs)) assert all(column.obj.size() == 1 for column in columns) @@ -720,6 +756,13 @@ def __init__( if any(GroupBy.check_agg(a.value) > 1 for a in self.agg_requests): raise NotImplementedError("Nested aggregations in groupby") self.agg_infos = [req.collect_agg(depth=0) for req in self.agg_requests] + self._non_child_args = ( + keys, + agg_requests, + maintain_order, + options, + self.agg_infos, + ) @staticmethod def check_agg(agg: expr.Expr) -> int: @@ -749,24 +792,14 @@ def check_agg(agg: expr.Expr) -> int: else: raise NotImplementedError(f"No handler for {agg=}") - def _eval_arguments(self, children: Sequence[DataFrame]) -> Sequence: - return ( - self.keys, - self.agg_requests, - self.maintain_order, - self.options, - self.agg_infos, # Need agg_infos - *children, - ) - @classmethod - def evaluate_node( + def do_evaluate( cls, keys_in: Sequence[expr.NamedExpr], agg_requests: Sequence[expr.NamedExpr], maintain_order: bool, # noqa: FBT001 options: Any, - agg_infos: Sequence, + agg_infos: Sequence[expr.AggInfo], df: DataFrame, ): """Evaluate and return a dataframe.""" @@ -893,6 +926,7 @@ def __init__( self.right_on = tuple(right_on) self.options = options self.children = (left, right) + self._non_child_args = (self.left_on, self.right_on, self.options) if any( isinstance(e.value, expr.Literal) for e in itertools.chain(self.left_on, self.right_on) @@ -992,12 +1026,17 @@ def _reorder_maps( ).columns() @classmethod - def evaluate_node( + def do_evaluate( cls, - schema: Schema, left_on_exprs: Sequence[expr.NamedExpr], right_on_exprs: Sequence[expr.NamedExpr], - options: Any, + options: tuple[ + Literal["inner", "left", "right", "full", "semi", "anti", "cross"], + bool, + tuple[int, int] | None, + str, + bool, + ], left: DataFrame, right: DataFrame, ) -> DataFrame: @@ -1106,18 +1145,18 @@ def __init__( self.schema = schema self.columns = tuple(columns) self.should_broadcast = should_broadcast + self._non_child_args = (self.columns, self.should_broadcast) self.children = (df,) @classmethod - def evaluate_node( + def do_evaluate( cls, - schema: Schema, - columns_in: Sequence[expr.NamedExpr], + exprs: Sequence[expr.NamedExpr], should_broadcast: bool, # noqa: FBT001 df: DataFrame, ) -> DataFrame: """Evaluate and return a dataframe.""" - columns = [c.evaluate(df) for c in columns_in] + columns = [c.evaluate(df) for c in exprs] if should_broadcast: columns = broadcast(*columns, target_length=df.num_rows) else: @@ -1128,9 +1167,7 @@ def evaluate_node( # table that might have mismatching column lengths will # never be turned into a pylibcudf Table with all columns # by the Select, which is why this is safe. - assert all( - e.name and e.name.startswith("__POLARS_CSER_0x") for e in columns - ) + assert all(e.name.startswith("__POLARS_CSER_0x") for e in exprs) return df.with_columns(columns) @@ -1163,6 +1200,7 @@ def __init__( self.subset = subset self.zlice = zlice self.stable = stable + self._non_child_args = (keep, subset, zlice, stable) self.children = (df,) _KEEP_MAP: ClassVar[dict[str, plc.stream_compaction.DuplicateKeepOption]] = { @@ -1173,9 +1211,8 @@ def __init__( } @classmethod - def evaluate_node( + def do_evaluate( cls, - schema: Schema, keep: plc.stream_compaction.DuplicateKeepOption, subset: frozenset[str] | None, zlice: tuple[int, int] | None, @@ -1253,12 +1290,18 @@ def __init__( self.null_order = tuple(null_order) self.stable = stable self.zlice = zlice + self._non_child_args = ( + self.by, + self.order, + self.null_order, + self.stable, + self.zlice, + ) self.children = (df,) @classmethod - def evaluate_node( + def do_evaluate( cls, - schema: Schema, by: Sequence[expr.NamedExpr], order: Sequence[plc.types.Order], null_order: Sequence[plc.types.NullOrder], @@ -1310,12 +1353,11 @@ def __init__(self, schema: Schema, offset: int, length: int, df: IR): self.schema = schema self.offset = offset self.length = length + self._non_child_args = (offset, length) self.children = (df,) @classmethod - def evaluate_node( - cls, schema: Schema, offset: int, length: int, df: DataFrame - ) -> DataFrame: + def do_evaluate(cls, offset: int, length: int, df: DataFrame) -> DataFrame: """Evaluate and return a dataframe.""" return df.slice((offset, length)) @@ -1331,12 +1373,11 @@ class Filter(IR): def __init__(self, schema: Schema, mask: expr.NamedExpr, df: IR): self.schema = schema self.mask = mask + self._non_child_args = (mask,) self.children = (df,) @classmethod - def evaluate_node( - cls, schema: Schema, mask_expr: expr.NamedExpr, df: DataFrame - ) -> DataFrame: + def do_evaluate(cls, mask_expr: expr.NamedExpr, df: DataFrame) -> DataFrame: """Evaluate and return a dataframe.""" (mask,) = broadcast(mask_expr.evaluate(df), target_length=df.num_rows) return df.filter(mask) @@ -1350,10 +1391,11 @@ class Projection(IR): def __init__(self, schema: Schema, df: IR): self.schema = schema + self._non_child_args = (schema,) self.children = (df,) @classmethod - def evaluate_node(cls, schema: Schema, df: DataFrame) -> DataFrame: + def do_evaluate(cls, schema: Schema, df: DataFrame) -> DataFrame: """Evaluate and return a dataframe.""" # This can reorder things. columns = broadcast( @@ -1420,12 +1462,16 @@ def __init__(self, schema: Schema, name: str, options: Any, df: IR): "Unpivot cannot cast all input columns to " f"{self.schema[value_name].id()}" ) - self.options = (tuple(indices), tuple(pivotees), variable_name, value_name) + self.options = ( + tuple(indices), + tuple(pivotees), + (variable_name, schema[variable_name]), + (value_name, schema[value_name]), + ) + self._non_child_args = (name, self.options) @classmethod - def evaluate_node( - cls, schema: Schema, name: str, options: Any, df: DataFrame - ) -> DataFrame: + def do_evaluate(cls, name: str, options: Any, df: DataFrame) -> DataFrame: """Evaluate and return a dataframe.""" if name == "rechunk": # No-op in our data model @@ -1444,7 +1490,12 @@ def evaluate_node( plc.lists.explode_outer(df.table, index), df.column_names ).sorted_like(df, subset=subset) elif name == "unpivot": - indices, pivotees, variable_name, value_name = options + ( + indices, + pivotees, + (variable_name, variable_dtype), + (value_name, value_dtype), + ) = options npiv = len(pivotees) index_columns = [ Column(col, name=name) @@ -1460,7 +1511,7 @@ def evaluate_node( plc.interop.from_arrow( pa.array( pivotees, - type=plc.interop.to_arrow(schema[variable_name]), + type=plc.interop.to_arrow(variable_dtype), ), ) ] @@ -1468,10 +1519,7 @@ def evaluate_node( df.num_rows, ).columns() value_column = plc.concatenate.concatenate( - [ - df.column_map[pivotee].astype(schema[value_name]).obj - for pivotee in pivotees - ] + [df.column_map[pivotee].astype(value_dtype).obj for pivotee in pivotees] ) return DataFrame( [ @@ -1495,19 +1543,19 @@ class Union(IR): def __init__(self, schema: Schema, zlice: tuple[int, int] | None, *children: IR): self.schema = schema self.zlice = zlice + self._non_child_args = (zlice,) self.children = children schema = self.children[0].schema if not all(s.schema == schema for s in self.children[1:]): raise NotImplementedError("Schema mismatch") @classmethod - def evaluate_node( - cls, schema: Schema, zlice: tuple[int, int] | None, *dfs: DataFrame - ) -> DataFrame: + def do_evaluate(cls, zlice: tuple[int, int] | None, *dfs: DataFrame) -> DataFrame: """Evaluate and return a dataframe.""" # TODO: only evaluate what we need if we have a slice? return DataFrame.from_table( - plc.concatenate.concatenate([df.table for df in dfs]), dfs[0].column_names + plc.concatenate.concatenate([df.table for df in dfs]), + dfs[0].column_names, ).slice(zlice) @@ -1519,6 +1567,7 @@ class HConcat(IR): def __init__(self, schema: Schema, *children: IR): self.schema = schema + self._non_child_args = () self.children = children @staticmethod @@ -1550,17 +1599,21 @@ def _extend_with_nulls(table: plc.Table, *, nrows: int) -> plc.Table: ) @classmethod - def evaluate_node(cls, schema: Schema, *dfs: DataFrame) -> DataFrame: + def do_evaluate(cls, *dfs: DataFrame) -> DataFrame: """Evaluate and return a dataframe.""" max_rows = max(df.num_rows for df in dfs) # Horizontal concatenation extends shorter tables with nulls - dfs = tuple( - df - if df.num_rows == max_rows - else DataFrame.from_table( - cls._extend_with_nulls(df.table, nrows=max_rows - df.num_rows), - df.column_names, + return DataFrame( + itertools.chain.from_iterable( + df.columns + for df in ( + df + if df.num_rows == max_rows + else DataFrame.from_table( + cls._extend_with_nulls(df.table, nrows=max_rows - df.num_rows), + df.column_names, + ) + for df in dfs + ) ) - for df in dfs ) - return DataFrame(itertools.chain.from_iterable(df.columns for df in dfs))