Skip to content

Commit

Permalink
Consistent API to set parameters of aggregate and window functions (`…
Browse files Browse the repository at this point in the history
…AggregateExt` --> `ExprFunctionExt`) (#11550)

* Moving over AggregateExt to ExprFunctionExt and adding in function settings for window functions

* Switch WindowFrame to only need the window function definition and arguments. Other parameters will be set via the ExprFuncBuilder

* Changing null_treatment to take an option, but this is mostly for code cleanliness and not strictly required

* Moving functions in ExprFuncBuilder over to be explicitly implementing ExprFunctionExt trait so we can guarantee a consistent user experience no matter which they call on the Expr and which on the builder

* Apply cargo fmt

* Add deprecated trait AggregateExt so that users get a warning but still builds

* Window helper functions should return Expr

* Update documentation to show window function example

* Add license info

* Update comments that are no longer applicable

* Remove first_value and last_value since these are already implemented in the aggregate functions

* Update  to use WindowFunction::new to set additional parameters for order_by using ExprFunctionExt

* Apply cargo fmt

* Fix up clippy

* fix doc example

* fmt

* doc tweaks

* more doc tweaks

* fix up links

* fix integration test

* fix another doc example

---------

Co-authored-by: Tim Saucer <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
3 people authored Jul 24, 2024
1 parent 76039fa commit 886e8ac
Show file tree
Hide file tree
Showing 26 changed files with 657 additions and 444 deletions.
12 changes: 6 additions & 6 deletions datafusion-examples/examples/advanced_udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,12 @@ async fn main() -> Result<()> {
df.show().await?;

// Now, run the function using the DataFrame API:
let window_expr = smooth_it.call(
vec![col("speed")], // smooth_it(speed)
vec![col("car")], // PARTITION BY car
vec![col("time").sort(true, true)], // ORDER BY time ASC
WindowFrame::new(None),
);
let window_expr = smooth_it
.call(vec![col("speed")]) // smooth_it(speed)
.partition_by(vec![col("car")]) // PARTITION BY car
.order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC
.window_frame(WindowFrame::new(None))
.build()?;
let df = ctx.table("cars").await?.window(vec![window_expr])?;

// print the results
Expand Down
4 changes: 2 additions & 2 deletions datafusion-examples/examples/expr_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr::BinaryExpr;
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::{AggregateExt, ColumnarValue, ExprSchemable, Operator};
use datafusion_expr::{ColumnarValue, ExprFunctionExt, ExprSchemable, Operator};

/// This example demonstrates the DataFusion [`Expr`] API.
///
Expand Down Expand Up @@ -95,7 +95,7 @@ fn expr_fn_demo() -> Result<()> {
let agg = first_value.call(vec![col("price")]);
assert_eq!(agg.to_string(), "first_value(price)");

// You can use the AggregateExt trait to create more complex aggregates
// You can use the ExprFunctionExt trait to create more complex aggregates
// such as `FIRST_VALUE(price FILTER quantity > 100 ORDER BY ts)`
let agg = first_value
.call(vec![col("price")])
Expand Down
12 changes: 6 additions & 6 deletions datafusion-examples/examples/simple_udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,12 @@ async fn main() -> Result<()> {
df.show().await?;

// Now, run the function using the DataFrame API:
let window_expr = smooth_it.call(
vec![col("speed")], // smooth_it(speed)
vec![col("car")], // PARTITION BY car
vec![col("time").sort(true, true)], // ORDER BY time ASC
WindowFrame::new(None),
);
let window_expr = smooth_it
.call(vec![col("speed")]) // smooth_it(speed)
.partition_by(vec![col("car")]) // PARTITION BY car
.order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC
.window_frame(WindowFrame::new(None))
.build()?;
let df = ctx.table("cars").await?.window(vec![window_expr])?;

// print the results
Expand Down
13 changes: 6 additions & 7 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1696,8 +1696,8 @@ mod tests {
use datafusion_common::{Constraint, Constraints, ScalarValue};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::{
cast, create_udf, expr, lit, BuiltInWindowFunction, ScalarFunctionImplementation,
Volatility, WindowFrame, WindowFunctionDefinition,
cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt,
ScalarFunctionImplementation, Volatility, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct};
use datafusion_physical_expr::expressions::Column;
Expand Down Expand Up @@ -1867,11 +1867,10 @@ mod tests {
BuiltInWindowFunction::FirstValue,
),
vec![col("aggregate_test_100.c1")],
vec![col("aggregate_test_100.c2")],
vec![],
WindowFrame::new(None),
None,
));
))
.partition_by(vec![col("aggregate_test_100.c2")])
.build()
.unwrap();
let t2 = t.select(vec![col("c1"), first_row])?;
let plan = t2.plan.clone();

Expand Down
22 changes: 11 additions & 11 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder,
scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame,
WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, sum};

Expand Down Expand Up @@ -183,15 +183,15 @@ async fn test_count_wildcard_on_window() -> Result<()> {
.select(vec![Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateUDF(count_udaf()),
vec![wildcard()],
vec![],
vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
WindowFrame::new_bounds(
WindowFrameUnits::Range,
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
),
None,
))])?
))
.order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))])
.window_frame(WindowFrame::new_bounds(
WindowFrameUnits::Range,
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
))
.build()
.unwrap()])?
.explain(false, false)?
.collect()
.await?;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/expr_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StructArray};
use arrow_schema::{DataType, Field};
use datafusion::prelude::*;
use datafusion_common::{assert_contains, DFSchema, ScalarValue};
use datafusion_expr::AggregateExt;
use datafusion_expr::ExprFunctionExt;
use datafusion_functions::core::expr_ext::FieldAccessor;
use datafusion_functions_aggregate::first_last::first_value_udaf;
use datafusion_functions_aggregate::sum::sum_udaf;
Expand Down
85 changes: 65 additions & 20 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ use crate::expr_fn::binary_expr;
use crate::logical_plan::Subquery;
use crate::utils::expr_to_columns;
use crate::{
aggregate_function, built_in_window_function, udaf, ExprSchemable, Operator,
Signature,
aggregate_function, built_in_window_function, udaf, BuiltInWindowFunction,
ExprSchemable, Operator, Signature, WindowFrame, WindowUDF,
};
use crate::{window_frame, Volatility};

Expand Down Expand Up @@ -60,6 +60,10 @@ use sqlparser::ast::NullTreatment;
/// use the fluent APIs in [`crate::expr_fn`] such as [`col`] and [`lit`], or
/// methods such as [`Expr::alias`], [`Expr::cast_to`], and [`Expr::Like`]).
///
/// See also [`ExprFunctionExt`] for creating aggregate and window functions.
///
/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt
///
/// # Schema Access
///
/// See [`ExprSchemable::get_type`] to access the [`DataType`] and nullability
Expand Down Expand Up @@ -283,15 +287,17 @@ pub enum Expr {
/// This expression is guaranteed to have a fixed type.
TryCast(TryCast),
/// A sort expression, that can be used to sort values.
///
/// See [Expr::sort] for more details
Sort(Sort),
/// Represents the call of a scalar function with a set of arguments.
ScalarFunction(ScalarFunction),
/// Calls an aggregate function with arguments, and optional
/// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`.
///
/// See also [`AggregateExt`] to set these fields.
/// See also [`ExprFunctionExt`] to set these fields.
///
/// [`AggregateExt`]: crate::udaf::AggregateExt
/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt
AggregateFunction(AggregateFunction),
/// Represents the call of a window function with arguments.
WindowFunction(WindowFunction),
Expand Down Expand Up @@ -641,9 +647,9 @@ impl AggregateFunctionDefinition {

/// Aggregate function
///
/// See also [`AggregateExt`] to set these fields on `Expr`
/// See also [`ExprFunctionExt`] to set these fields on `Expr`
///
/// [`AggregateExt`]: crate::udaf::AggregateExt
/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct AggregateFunction {
/// Name of the function
Expand Down Expand Up @@ -769,7 +775,52 @@ impl fmt::Display for WindowFunctionDefinition {
}
}

impl From<aggregate_function::AggregateFunction> for WindowFunctionDefinition {
fn from(value: aggregate_function::AggregateFunction) -> Self {
Self::AggregateFunction(value)
}
}

impl From<BuiltInWindowFunction> for WindowFunctionDefinition {
fn from(value: BuiltInWindowFunction) -> Self {
Self::BuiltInWindowFunction(value)
}
}

impl From<Arc<crate::AggregateUDF>> for WindowFunctionDefinition {
fn from(value: Arc<crate::AggregateUDF>) -> Self {
Self::AggregateUDF(value)
}
}

impl From<Arc<WindowUDF>> for WindowFunctionDefinition {
fn from(value: Arc<WindowUDF>) -> Self {
Self::WindowUDF(value)
}
}

/// Window function
///
/// Holds the actual function to call [`WindowFunction`] as well as its
/// arguments (`args`) and the contents of the `OVER` clause:
///
/// 1. `PARTITION BY`
/// 2. `ORDER BY`
/// 3. Window frame (e.g. `ROWS 1 PRECEDING AND 1 FOLLOWING`)
///
/// # Example
/// ```
/// # use datafusion_expr::{Expr, BuiltInWindowFunction, col, ExprFunctionExt};
/// # use datafusion_expr::expr::WindowFunction;
/// // Create FIRST_VALUE(a) OVER (PARTITION BY b ORDER BY c)
/// let expr = Expr::WindowFunction(
/// WindowFunction::new(BuiltInWindowFunction::FirstValue, vec![col("a")])
/// )
/// .partition_by(vec![col("b")])
/// .order_by(vec![col("c").sort(true, true)])
/// .build()
/// .unwrap();
/// ```
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct WindowFunction {
/// Name of the function
Expand All @@ -787,22 +838,16 @@ pub struct WindowFunction {
}

impl WindowFunction {
/// Create a new Window expression
pub fn new(
fun: WindowFunctionDefinition,
args: Vec<Expr>,
partition_by: Vec<Expr>,
order_by: Vec<Expr>,
window_frame: window_frame::WindowFrame,
null_treatment: Option<NullTreatment>,
) -> Self {
/// Create a new Window expression with the specified arguments and an
/// empty `OVER` clause
pub fn new(fun: impl Into<WindowFunctionDefinition>, args: Vec<Expr>) -> Self {
Self {
fun,
fun: fun.into(),
args,
partition_by,
order_by,
window_frame,
null_treatment,
partition_by: Vec::default(),
order_by: Vec::default(),
window_frame: WindowFrame::new(None),
null_treatment: None,
}
}
}
Expand Down
Loading

0 comments on commit 886e8ac

Please sign in to comment.