From efe5708978a480d11d5406a7d7df76d73e15c5d7 Mon Sep 17 00:00:00 2001 From: jcsherin Date: Fri, 18 Oct 2024 16:56:41 +0530 Subject: [PATCH 01/17] Convert `BuiltInWindowFunction::{Lead, Lag}` to a user defined window function (#12857) * Move `lead-lag` to `functions-window` package * Builds with warnings * Adds `PartitionEvaluatorArgs` * Extracts `shift_offset` from input expressions * Computes shift offset * Get default value from input expression * Implements `partition_evaluator` * Fixes compiler warnings * Comments out failing tests * Fixes `cargo test` errors and warnings * Minor: taplo formatting * Delete code * Define `lead`, `lag` user-defined window functions * Fixes `cargo build` errors * Export udwf and expression public APIs * Mark result field as nullable * Delete `return_type` tests for `lead` and `lag` * Disables test: window function case insensitive * Fixes: lowercase name in logical plan * Reverts to old methods for computing `shift_offset`, `default_value` * Implements expression reversal * Fixes: lowercase name in logical plans * Fixes: doc test compilation errors Fixes: doc test build errors * Temporarily quite clippy errors * Fixes proto defintion * Minor: fixes formatting * Fixes: doc tests * Uses macro for defining `lag_udwf()` and `leag_udwf()` * Fixes: window fuzz test cases * Copies doc comments verbatim from `BuiltInWindowFunction` enum * Deletes from window function case insensitive test * Deletes `BuiltInWindowFunction` expression APIs * Delete from `create_built_in_window_expr` * Deletes proto serialization * Delete from `BuiltInWindowFunction` enum * Deletes test for finding built-in window function * Fixes build errors + deletes redundant code * Deletes more code * Delete unnecessary structs * Refactors shift offset computation * Passes range unit test * Fixes: clippy::get-first error * Rewrite unit tests for WindowUDF * Fixes: unit test for lag with default value * Consistent input expressions and data types in unit tests * Minor: fixes formatting * Restore original helper method for unit tests * Revert "Refactors shift offset computation" This reverts commit 000ceb76409e66230f9c5017a30fa3c9bb1e6575. * Moves helper functions into `functions-window-common` package * Uses common helper functions in `{lead, lag}` * Minor: formatting * Revert "Moves helper functions into `functions-window-common` package" This reverts commit ab8a83c9c11ca3a245278f6f300438feaacb0978. * Moves common functions to utils * Minor: formatting fixes * Update lowercase names in explain output * Adds doc for `lead()` and `lag()` expression functions * Add doc for `WindowShiftKind::shift_offset` * Remove `arrow` dev dependency * Minor: formatting * Update inner doc comment * Serialize 1 or more window function arguments * Adds logical plan roundtrip test cases * Refactor: readability of unit tests * Minor: rename variable bindings * Minor: copy edit * Revert "Remove `arrow` dev dependency" This reverts commit 3eb09856c8ec4ddce20472deee2df590c2fd3f35. * Move null argument handling helper to utils * Disable failing sqllogic tests for handling NULL input * Revert "Disable failing sqllogic tests for handling NULL input" This reverts commit 270a2030637012d549c001e973a0a1bb6b3d4dd0. * Fixes: incorrect NULL handling in `lead`/`lag` window function * Adds more tests cases --------- Co-authored-by: Andrew Lamb --- datafusion-cli/Cargo.lock | 1 + .../core/tests/fuzz_cases/window_fuzz.rs | 13 +- .../expr/src/built_in_window_function.rs | 32 +- datafusion/expr/src/expr.rs | 38 -- datafusion/expr/src/udwf.rs | 23 + datafusion/expr/src/window_function.rs | 34 -- .../functions-window-common/src/expr.rs | 64 +++ datafusion/functions-window-common/src/lib.rs | 1 + datafusion/functions-window/Cargo.toml | 1 + .../src}/lead_lag.rs | 392 ++++++++++++------ datafusion/functions-window/src/lib.rs | 8 + datafusion/functions-window/src/utils.rs | 53 +++ .../physical-expr/src/expressions/mod.rs | 1 - datafusion/physical-expr/src/window/mod.rs | 1 - datafusion/physical-plan/src/windows/mod.rs | 88 +--- datafusion/proto/proto/datafusion.proto | 6 +- datafusion/proto/src/generated/pbjson.rs | 30 +- datafusion/proto/src/generated/prost.rs | 14 +- .../proto/src/logical_plan/from_proto.rs | 17 +- datafusion/proto/src/logical_plan/to_proto.rs | 14 +- .../proto/src/physical_plan/to_proto.rs | 20 - .../tests/cases/roundtrip_logical_plan.rs | 12 +- datafusion/sqllogictest/test_files/union.slt | 8 +- datafusion/sqllogictest/test_files/window.slt | 56 ++- 24 files changed, 520 insertions(+), 407 deletions(-) create mode 100644 datafusion/functions-window-common/src/expr.rs rename datafusion/{physical-expr/src/window => functions-window/src}/lead_lag.rs (59%) create mode 100644 datafusion/functions-window/src/utils.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index aa64e14fca8e..dfd07a7658ff 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1445,6 +1445,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-functions-window-common", + "datafusion-physical-expr", "datafusion-physical-expr-common", "log", "paste", diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 4a33334770a0..d649919f1b6a 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -45,6 +45,7 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use test_utils::add_empty_batches; use datafusion::functions_window::row_number::row_number_udwf; +use datafusion_functions_window::lead_lag::{lag_udwf, lead_udwf}; use datafusion_functions_window::rank::{dense_rank_udwf, rank_udwf}; use hashbrown::HashMap; use rand::distributions::Alphanumeric; @@ -197,7 +198,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Lag), + WindowFunctionDefinition::WindowUDF(lag_udwf()), // its name "LAG", // no argument @@ -211,7 +212,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Lead), + WindowFunctionDefinition::WindowUDF(lead_udwf()), // its name "LEAD", // no argument @@ -393,9 +394,7 @@ fn get_random_function( window_fn_map.insert( "lead", ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::Lead, - ), + WindowFunctionDefinition::WindowUDF(lead_udwf()), vec![ arg.clone(), lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), @@ -406,9 +405,7 @@ fn get_random_function( window_fn_map.insert( "lag", ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::Lag, - ), + WindowFunctionDefinition::WindowUDF(lag_udwf()), vec![ arg.clone(), lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs index 6a30080fb38b..2c70a07a4e15 100644 --- a/datafusion/expr/src/built_in_window_function.rs +++ b/datafusion/expr/src/built_in_window_function.rs @@ -22,7 +22,7 @@ use std::str::FromStr; use crate::type_coercion::functions::data_types; use crate::utils; -use crate::{Signature, TypeSignature, Volatility}; +use crate::{Signature, Volatility}; use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; use arrow::datatypes::DataType; @@ -44,17 +44,7 @@ pub enum BuiltInWindowFunction { CumeDist, /// Integer ranging from 1 to the argument value, dividing the partition as equally as possible Ntile, - /// Returns value evaluated at the row that is offset rows before the current row within the partition; - /// If there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lag, - /// Returns value evaluated at the row that is offset rows after the current row within the partition; - /// If there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lead, - /// Returns value evaluated at the row that is the first row of the window frame + /// returns value evaluated at the row that is the first row of the window frame FirstValue, /// Returns value evaluated at the row that is the last row of the window frame LastValue, @@ -68,8 +58,6 @@ impl BuiltInWindowFunction { match self { CumeDist => "CUME_DIST", Ntile => "NTILE", - Lag => "LAG", - Lead => "LEAD", FirstValue => "first_value", LastValue => "last_value", NthValue => "NTH_VALUE", @@ -83,8 +71,6 @@ impl FromStr for BuiltInWindowFunction { Ok(match name.to_uppercase().as_str() { "CUME_DIST" => BuiltInWindowFunction::CumeDist, "NTILE" => BuiltInWindowFunction::Ntile, - "LAG" => BuiltInWindowFunction::Lag, - "LEAD" => BuiltInWindowFunction::Lead, "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, "LAST_VALUE" => BuiltInWindowFunction::LastValue, "NTH_VALUE" => BuiltInWindowFunction::NthValue, @@ -117,9 +103,7 @@ impl BuiltInWindowFunction { match self { BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), BuiltInWindowFunction::CumeDist => Ok(DataType::Float64), - BuiltInWindowFunction::Lag - | BuiltInWindowFunction::Lead - | BuiltInWindowFunction::FirstValue + BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), } @@ -130,16 +114,6 @@ impl BuiltInWindowFunction { // Note: The physical expression must accept the type returned by this function or the execution panics. match self { BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), - BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { - Signature::one_of( - vec![ - TypeSignature::Any(1), - TypeSignature::Any(2), - TypeSignature::Any(3), - ], - Volatility::Immutable, - ) - } BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { Signature::any(1, Volatility::Immutable) } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 3e692189e488..f3f71a87278b 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2560,30 +2560,6 @@ mod test { Ok(()) } - #[test] - fn test_lead_return_type() -> Result<()> { - let fun = find_df_window_func("lead").unwrap(); - let observed = fun.return_type(&[DataType::Utf8], &[true], "")?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64], &[true], "")?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lag_return_type() -> Result<()> { - let fun = find_df_window_func("lag").unwrap(); - let observed = fun.return_type(&[DataType::Utf8], &[true], "")?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64], &[true], "")?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - #[test] fn test_nth_value_return_type() -> Result<()> { let fun = find_df_window_func("nth_value").unwrap(); @@ -2621,8 +2597,6 @@ mod test { let names = vec![ "cume_dist", "ntile", - "lag", - "lead", "first_value", "last_value", "nth_value", @@ -2660,18 +2634,6 @@ mod test { built_in_window_function::BuiltInWindowFunction::LastValue )) ); - assert_eq!( - find_df_window_func("LAG"), - Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::Lag - )) - ); - assert_eq!( - find_df_window_func("LEAD"), - Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::Lead - )) - ); assert_eq!(find_df_window_func("not_exist"), None) } diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 6d8f2be97e02..6ab94c1e841a 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -34,8 +34,10 @@ use crate::{ Signature, }; use datafusion_common::{not_impl_err, Result}; +use datafusion_functions_window_common::expr::ExpressionArgs; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// Logical representation of a user-defined window function (UDWF) /// A UDWF is different from a UDF in that it is stateful across batches. @@ -149,6 +151,12 @@ impl WindowUDF { self.inner.simplify() } + /// Expressions that are passed to the [`PartitionEvaluator`]. + /// + /// See [`WindowUDFImpl::expressions`] for more details. + pub fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + self.inner.expressions(expr_args) + } /// Return a `PartitionEvaluator` for evaluating this window function pub fn partition_evaluator_factory( &self, @@ -302,6 +310,14 @@ pub trait WindowUDFImpl: Debug + Send + Sync { /// types are accepted and the function's Volatility. fn signature(&self) -> &Signature; + /// Returns the expressions that are passed to the [`PartitionEvaluator`]. + fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + expr_args + .input_exprs() + .first() + .map_or(vec![], |expr| vec![Arc::clone(expr)]) + } + /// Invoke the function, returning the [`PartitionEvaluator`] instance fn partition_evaluator( &self, @@ -480,6 +496,13 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { self.inner.signature() } + fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + expr_args + .input_exprs() + .first() + .map_or(vec![], |expr| vec![Arc::clone(expr)]) + } + fn partition_evaluator( &self, partition_evaluator_args: PartitionEvaluatorArgs, diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index 7ac6fb7d167c..3e1870c59c15 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::ScalarValue; - use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; /// Create an expression to represent the `cume_dist` window function @@ -29,38 +27,6 @@ pub fn ntile(arg: Expr) -> Expr { Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Ntile, vec![arg])) } -/// Create an expression to represent the `lag` window function -pub fn lag( - arg: Expr, - shift_offset: Option, - default_value: Option, -) -> Expr { - let shift_offset_lit = shift_offset - .map(|v| v.lit()) - .unwrap_or(ScalarValue::Null.lit()); - let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); - Expr::WindowFunction(WindowFunction::new( - BuiltInWindowFunction::Lag, - vec![arg, shift_offset_lit, default_lit], - )) -} - -/// Create an expression to represent the `lead` window function -pub fn lead( - arg: Expr, - shift_offset: Option, - default_value: Option, -) -> Expr { - let shift_offset_lit = shift_offset - .map(|v| v.lit()) - .unwrap_or(ScalarValue::Null.lit()); - let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); - Expr::WindowFunction(WindowFunction::new( - BuiltInWindowFunction::Lead, - vec![arg, shift_offset_lit, default_lit], - )) -} - /// Create an expression to represent the `nth_value` window function pub fn nth_value(arg: Expr, n: i64) -> Expr { Expr::WindowFunction(WindowFunction::new( diff --git a/datafusion/functions-window-common/src/expr.rs b/datafusion/functions-window-common/src/expr.rs new file mode 100644 index 000000000000..1d99fe7acf15 --- /dev/null +++ b/datafusion/functions-window-common/src/expr.rs @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::arrow::datatypes::DataType; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +/// Arguments passed to user-defined window function +#[derive(Debug, Default)] +pub struct ExpressionArgs<'a> { + /// The expressions passed as arguments to the user-defined window + /// function. + input_exprs: &'a [Arc], + /// The corresponding data types of expressions passed as arguments + /// to the user-defined window function. + input_types: &'a [DataType], +} + +impl<'a> ExpressionArgs<'a> { + /// Create an instance of [`ExpressionArgs`]. + /// + /// # Arguments + /// + /// * `input_exprs` - The expressions passed as arguments + /// to the user-defined window function. + /// * `input_types` - The data types corresponding to the + /// arguments to the user-defined window function. + /// + pub fn new( + input_exprs: &'a [Arc], + input_types: &'a [DataType], + ) -> Self { + Self { + input_exprs, + input_types, + } + } + + /// Returns the expressions passed as arguments to the user-defined + /// window function. + pub fn input_exprs(&self) -> &'a [Arc] { + self.input_exprs + } + + /// Returns the [`DataType`]s corresponding to the input expressions + /// to the user-defined window function. + pub fn input_types(&self) -> &'a [DataType] { + self.input_types + } +} diff --git a/datafusion/functions-window-common/src/lib.rs b/datafusion/functions-window-common/src/lib.rs index 53f9eb1c9ac6..da8d096da562 100644 --- a/datafusion/functions-window-common/src/lib.rs +++ b/datafusion/functions-window-common/src/lib.rs @@ -18,5 +18,6 @@ //! Common user-defined window functionality for [DataFusion] //! //! [DataFusion]: +pub mod expr; pub mod field; pub mod partition; diff --git a/datafusion/functions-window/Cargo.toml b/datafusion/functions-window/Cargo.toml index 952e5720c77c..262c21fcec65 100644 --- a/datafusion/functions-window/Cargo.toml +++ b/datafusion/functions-window/Cargo.toml @@ -41,6 +41,7 @@ path = "src/lib.rs" datafusion-common = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions-window-common = { workspace = true } +datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } log = { workspace = true } paste = "1.0.15" diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/functions-window/src/lead_lag.rs similarity index 59% rename from datafusion/physical-expr/src/window/lead_lag.rs rename to datafusion/functions-window/src/lead_lag.rs index 1656b7c3033a..f81521099751 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/functions-window/src/lead_lag.rs @@ -15,125 +15,275 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expression for `lead` and `lag` that can evaluated -//! at runtime during query execution -use crate::window::BuiltInWindowFunctionExpr; -use crate::PhysicalExpr; -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use arrow_array::Array; +//! `lead` and `lag` window function implementations + +use crate::utils::{get_scalar_value_from_args, get_signed_integer}; +use datafusion_common::arrow::array::ArrayRef; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::Field; use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::PartitionEvaluator; +use datafusion_expr::{ + Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature, Volatility, + WindowUDFImpl, +}; +use datafusion_functions_window_common::expr::ExpressionArgs; +use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::any::Any; use std::cmp::min; use std::collections::VecDeque; use std::ops::{Neg, Range}; use std::sync::Arc; -/// window shift expression +get_or_init_udwf!( + Lag, + lag, + "Returns the row value that precedes the current row by a specified \ + offset within partition. If no such row exists, then returns the \ + default value.", + WindowShift::lag +); +get_or_init_udwf!( + Lead, + lead, + "Returns the value from a row that follows the current row by a \ + specified offset within the partition. If no such row exists, then \ + returns the default value.", + WindowShift::lead +); + +/// Create an expression to represent the `lag` window function +/// +/// returns value evaluated at the row that is offset rows before the current row within the partition; +/// if there is no such row, instead return default (which must be of the same type as value). +/// Both offset and default are evaluated with respect to the current row. +/// If omitted, offset defaults to 1 and default to null +pub fn lag( + arg: datafusion_expr::Expr, + shift_offset: Option, + default_value: Option, +) -> datafusion_expr::Expr { + let shift_offset_lit = shift_offset + .map(|v| v.lit()) + .unwrap_or(ScalarValue::Null.lit()); + let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); + + lag_udwf().call(vec![arg, shift_offset_lit, default_lit]) +} + +/// Create an expression to represent the `lead` window function +/// +/// returns value evaluated at the row that is offset rows after the current row within the partition; +/// if there is no such row, instead return default (which must be of the same type as value). +/// Both offset and default are evaluated with respect to the current row. +/// If omitted, offset defaults to 1 and default to null +pub fn lead( + arg: datafusion_expr::Expr, + shift_offset: Option, + default_value: Option, +) -> datafusion_expr::Expr { + let shift_offset_lit = shift_offset + .map(|v| v.lit()) + .unwrap_or(ScalarValue::Null.lit()); + let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); + + lead_udwf().call(vec![arg, shift_offset_lit, default_lit]) +} + #[derive(Debug)] -pub struct WindowShift { - name: String, - /// Output data type - data_type: DataType, - shift_offset: i64, - expr: Arc, - default_value: ScalarValue, - ignore_nulls: bool, +enum WindowShiftKind { + Lag, + Lead, } -impl WindowShift { - /// Get shift_offset of window shift expression - pub fn get_shift_offset(&self) -> i64 { - self.shift_offset +impl WindowShiftKind { + fn name(&self) -> &'static str { + match self { + WindowShiftKind::Lag => "lag", + WindowShiftKind::Lead => "lead", + } } - /// Get the default_value for window shift expression. - pub fn get_default_value(&self) -> ScalarValue { - self.default_value.clone() + /// In [`WindowShiftEvaluator`] a positive offset is used to signal + /// computation of `lag()`. So here we negate the input offset + /// value when computing `lead()`. + fn shift_offset(&self, value: Option) -> i64 { + match self { + WindowShiftKind::Lag => value.unwrap_or(1), + WindowShiftKind::Lead => value.map(|v| v.neg()).unwrap_or(-1), + } } } -/// lead() window function -pub fn lead( - name: String, - data_type: DataType, - expr: Arc, - shift_offset: Option, - default_value: ScalarValue, - ignore_nulls: bool, -) -> WindowShift { - WindowShift { - name, - data_type, - shift_offset: shift_offset.map(|v| v.neg()).unwrap_or(-1), - expr, - default_value, - ignore_nulls, - } +/// window shift expression +#[derive(Debug)] +pub struct WindowShift { + signature: Signature, + kind: WindowShiftKind, } -/// lag() window function -pub fn lag( - name: String, - data_type: DataType, - expr: Arc, - shift_offset: Option, - default_value: ScalarValue, - ignore_nulls: bool, -) -> WindowShift { - WindowShift { - name, - data_type, - shift_offset: shift_offset.unwrap_or(1), - expr, - default_value, - ignore_nulls, +impl WindowShift { + fn new(kind: WindowShiftKind) -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], + Volatility::Immutable, + ), + kind, + } + } + + pub fn lag() -> Self { + Self::new(WindowShiftKind::Lag) + } + + pub fn lead() -> Self { + Self::new(WindowShiftKind::Lead) } } -impl BuiltInWindowFunctionExpr for WindowShift { - /// Return a reference to Any that can be used for downcasting +impl WindowUDFImpl for WindowShift { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - let nullable = true; - Ok(Field::new(&self.name, self.data_type.clone(), nullable)) + fn name(&self) -> &str { + self.kind.name() } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] + fn signature(&self) -> &Signature { + &self.signature } - fn name(&self) -> &str { - &self.name + /// Handles the case where `NULL` expression is passed as an + /// argument to `lead`/`lag`. The type is refined depending + /// on the default value argument. + /// + /// For more details see: + fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + parse_expr(expr_args.input_exprs(), expr_args.input_types()) + .into_iter() + .collect::>() } - fn create_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + let shift_offset = + get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)? + .map(get_signed_integer) + .map_or(Ok(None), |v| v.map(Some)) + .map(|n| self.kind.shift_offset(n)) + .map(|offset| { + if partition_evaluator_args.is_reversed() { + -offset + } else { + offset + } + })?; + let default_value = parse_default_value( + partition_evaluator_args.input_exprs(), + partition_evaluator_args.input_types(), + )?; + Ok(Box::new(WindowShiftEvaluator { - shift_offset: self.shift_offset, - default_value: self.default_value.clone(), - ignore_nulls: self.ignore_nulls, + shift_offset, + default_value, + ignore_nulls: partition_evaluator_args.ignore_nulls(), non_null_offsets: VecDeque::new(), })) } - fn reverse_expr(&self) -> Option> { - Some(Arc::new(Self { - name: self.name.clone(), - data_type: self.data_type.clone(), - shift_offset: -self.shift_offset, - expr: Arc::clone(&self.expr), - default_value: self.default_value.clone(), - ignore_nulls: self.ignore_nulls, - })) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + let return_type = parse_expr_type(field_args.input_types())?; + + Ok(Field::new(field_args.name(), return_type, true)) } + + fn reverse_expr(&self) -> ReversedUDWF { + match self.kind { + WindowShiftKind::Lag => ReversedUDWF::Reversed(lag_udwf()), + WindowShiftKind::Lead => ReversedUDWF::Reversed(lead_udwf()), + } + } +} + +/// When `lead`/`lag` is evaluated on a `NULL` expression we attempt to +/// refine it by matching it with the type of the default value. +/// +/// For e.g. in `lead(NULL, 1, false)` the generic `ScalarValue::Null` +/// is refined into `ScalarValue::Boolean(None)`. Only the type is +/// refined, the expression value remains `NULL`. +/// +/// When the window function is evaluated with `NULL` expression +/// this guarantees that the type matches with that of the default +/// value. +/// +/// For more details see: +fn parse_expr( + input_exprs: &[Arc], + input_types: &[DataType], +) -> Result> { + assert!(!input_exprs.is_empty()); + assert!(!input_types.is_empty()); + + let expr = Arc::clone(input_exprs.first().unwrap()); + let expr_type = input_types.first().unwrap(); + + // Handles the most common case where NULL is unexpected + if !expr_type.is_null() { + return Ok(expr); + } + + let default_value = get_scalar_value_from_args(input_exprs, 2)?; + default_value.map_or(Ok(expr), |value| { + ScalarValue::try_from(&value.data_type()).map(|v| { + Arc::new(datafusion_physical_expr::expressions::Literal::new(v)) + as Arc + }) + }) +} + +/// Returns the data type of the default value(if provided) when the +/// expression is `NULL`. +/// +/// Otherwise, returns the expression type unchanged. +fn parse_expr_type(input_types: &[DataType]) -> Result { + assert!(!input_types.is_empty()); + let expr_type = input_types.first().unwrap_or(&DataType::Null); + + // Handles the most common case where NULL is unexpected + if !expr_type.is_null() { + return Ok(expr_type.clone()); + } + + let default_value_type = input_types.get(2).unwrap_or(&DataType::Null); + Ok(default_value_type.clone()) +} + +/// Handles type coercion and null value refinement for default value +/// argument depending on the data type of the input expression. +fn parse_default_value( + input_exprs: &[Arc], + input_types: &[DataType], +) -> Result { + let expr_type = parse_expr_type(input_types)?; + let unparsed = get_scalar_value_from_args(input_exprs, 2)?; + + unparsed + .filter(|v| !v.data_type().is_null()) + .map(|v| v.cast_to(&expr_type)) + .unwrap_or(ScalarValue::try_from(expr_type)) } #[derive(Debug)] -pub(crate) struct WindowShiftEvaluator { +struct WindowShiftEvaluator { shift_offset: i64, default_value: ScalarValue, ignore_nulls: bool, @@ -205,7 +355,7 @@ fn shift_with_default_value( offset: i64, default_value: &ScalarValue, ) -> Result { - use arrow::compute::concat; + use datafusion_common::arrow::compute::concat; let value_len = array.len() as i64; if offset == 0 { @@ -402,19 +552,22 @@ impl PartitionEvaluator for WindowShiftEvaluator { #[cfg(test)] mod tests { use super::*; - use crate::expressions::Column; - use arrow::{array::*, datatypes::*}; + use arrow::array::*; use datafusion_common::cast::as_int32_array; - - fn test_i32_result(expr: WindowShift, expected: Int32Array) -> Result<()> { + use datafusion_physical_expr::expressions::{Column, Literal}; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + + fn test_i32_result( + expr: WindowShift, + partition_evaluator_args: PartitionEvaluatorArgs, + expected: Int32Array, + ) -> Result<()> { let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8])); let values = vec![arr]; - let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]); - let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?; - let values = expr.evaluate_args(&batch)?; + let num_rows = values.len(); let result = expr - .create_evaluator()? - .evaluate_all(&values, batch.num_rows())?; + .partition_evaluator(partition_evaluator_args)? + .evaluate_all(&values, num_rows)?; let result = as_int32_array(&result)?; assert_eq!(expected, *result); Ok(()) @@ -466,16 +619,12 @@ mod tests { } #[test] - fn lead_lag_window_shift() -> Result<()> { + fn test_lead_window_shift() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; + test_i32_result( - lead( - "lead".to_owned(), - DataType::Int32, - Arc::new(Column::new("c3", 0)), - None, - ScalarValue::Null.cast_to(&DataType::Int32)?, - false, - ), + WindowShift::lead(), + PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), [ Some(-2), Some(3), @@ -488,17 +637,16 @@ mod tests { ] .iter() .collect::(), - )?; + ) + } + + #[test] + fn test_lag_window_shift() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; test_i32_result( - lag( - "lead".to_owned(), - DataType::Int32, - Arc::new(Column::new("c3", 0)), - None, - ScalarValue::Null.cast_to(&DataType::Int32)?, - false, - ), + WindowShift::lag(), + PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), [ None, Some(1), @@ -511,17 +659,24 @@ mod tests { ] .iter() .collect::(), - )?; + ) + } + + #[test] + fn test_lag_with_default() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; + let shift_offset = + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + let default_value = Arc::new(Literal::new(ScalarValue::Int32(Some(100)))) + as Arc; + + let input_exprs = &[expr, shift_offset, default_value]; + let input_types: &[DataType] = + &[DataType::Int32, DataType::Int32, DataType::Int32]; test_i32_result( - lag( - "lead".to_owned(), - DataType::Int32, - Arc::new(Column::new("c3", 0)), - None, - ScalarValue::Int32(Some(100)), - false, - ), + WindowShift::lag(), + PartitionEvaluatorArgs::new(input_exprs, input_types, false, false), [ Some(100), Some(1), @@ -534,7 +689,6 @@ mod tests { ] .iter() .collect::(), - )?; - Ok(()) + ) } } diff --git a/datafusion/functions-window/src/lib.rs b/datafusion/functions-window/src/lib.rs index ef624e13e61c..5a2aafa2892e 100644 --- a/datafusion/functions-window/src/lib.rs +++ b/datafusion/functions-window/src/lib.rs @@ -31,11 +31,17 @@ use datafusion_expr::WindowUDF; #[macro_use] pub mod macros; + +pub mod lead_lag; + pub mod rank; pub mod row_number; +mod utils; /// Fluent-style API for creating `Expr`s pub mod expr_fn { + pub use super::lead_lag::lag; + pub use super::lead_lag::lead; pub use super::rank::{dense_rank, percent_rank, rank}; pub use super::row_number::row_number; } @@ -44,6 +50,8 @@ pub mod expr_fn { pub fn all_default_window_functions() -> Vec> { vec![ row_number::row_number_udwf(), + lead_lag::lead_udwf(), + lead_lag::lag_udwf(), rank::rank_udwf(), rank::dense_rank_udwf(), rank::percent_rank_udwf(), diff --git a/datafusion/functions-window/src/utils.rs b/datafusion/functions-window/src/utils.rs new file mode 100644 index 000000000000..69f68aa78f2c --- /dev/null +++ b/datafusion/functions-window/src/utils.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +pub(crate) fn get_signed_integer(value: ScalarValue) -> Result { + if value.is_null() { + return Ok(0); + } + + if !value.data_type().is_integer() { + return exec_err!("Expected an integer value"); + } + + value.cast_to(&DataType::Int64)?.try_into() +} + +pub(crate) fn get_scalar_value_from_args( + args: &[Arc], + index: usize, +) -> Result> { + Ok(if let Some(field) = args.get(index) { + let tmp = field + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::NotImplemented( + format!("There is only support Literal types for field at idx: {index} in Window Function"), + ))? + .value() + .clone(); + Some(tmp) + } else { + None + }) +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index e07e11e43199..54b8aafdb4da 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -36,7 +36,6 @@ mod unknown_column; /// Module with some convenient methods used in expression building pub use crate::aggregate::stats::StatsType; pub use crate::window::cume_dist::{cume_dist, CumeDist}; -pub use crate::window::lead_lag::{lag, lead, WindowShift}; pub use crate::window::nth_value::NthValue; pub use crate::window::ntile::Ntile; pub use crate::PhysicalSortExpr; diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs index 938bdac50f97..c0fe3c2933a7 100644 --- a/datafusion/physical-expr/src/window/mod.rs +++ b/datafusion/physical-expr/src/window/mod.rs @@ -19,7 +19,6 @@ mod aggregate; mod built_in; mod built_in_window_function_expr; pub(crate) mod cume_dist; -pub(crate) mod lead_lag; pub(crate) mod nth_value; pub(crate) mod ntile; mod sliding_aggregate; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index e6a773f6b1ea..adf61f27bc6f 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -21,7 +21,7 @@ use std::borrow::Borrow; use std::sync::Arc; use crate::{ - expressions::{cume_dist, lag, lead, Literal, NthValue, Ntile, PhysicalSortExpr}, + expressions::{cume_dist, Literal, NthValue, Ntile, PhysicalSortExpr}, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr, }; @@ -48,6 +48,7 @@ mod utils; mod window_agg_exec; pub use bounded_window_agg_exec::BoundedWindowAggExec; +use datafusion_functions_window_common::expr::ExpressionArgs; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_physical_expr::expressions::Column; @@ -206,52 +207,6 @@ fn get_unsigned_integer(value: ScalarValue) -> Result { value.cast_to(&DataType::UInt64)?.try_into() } -fn get_casted_value( - default_value: Option, - dtype: &DataType, -) -> Result { - match default_value { - Some(v) if !v.data_type().is_null() => v.cast_to(dtype), - // If None or Null datatype - _ => ScalarValue::try_from(dtype), - } -} - -/// Rewrites the NULL expression (1st argument) with an expression -/// which is the same data type as the default value (3rd argument). -/// Also rewrites the return type with the same data type as the -/// default value. -/// -/// If a default value is not provided, or it is NULL the original -/// expression (1st argument) and return type is returned without -/// any modifications. -fn rewrite_null_expr_and_data_type( - args: &[Arc], - expr_type: &DataType, -) -> Result<(Arc, DataType)> { - assert!(!args.is_empty()); - let expr = Arc::clone(&args[0]); - - // The input expression and the return is type is unchanged - // when the input expression is not NULL. - if !expr_type.is_null() { - return Ok((expr, expr_type.clone())); - } - - get_scalar_value_from_args(args, 2)? - .and_then(|value| { - ScalarValue::try_from(value.data_type().clone()) - .map(|sv| { - Ok(( - Arc::new(Literal::new(sv)) as Arc, - value.data_type().clone(), - )) - }) - .ok() - }) - .unwrap_or(Ok((expr, expr_type.clone()))) -} - fn create_built_in_window_expr( fun: &BuiltInWindowFunction, args: &[Arc], @@ -286,42 +241,6 @@ fn create_built_in_window_expr( Arc::new(Ntile::new(name, n as u64, out_data_type)) } } - BuiltInWindowFunction::Lag => { - // rewrite NULL expression and the return datatype - let (arg, out_data_type) = - rewrite_null_expr_and_data_type(args, out_data_type)?; - let shift_offset = get_scalar_value_from_args(args, 1)? - .map(get_signed_integer) - .map_or(Ok(None), |v| v.map(Some))?; - let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, &out_data_type)?; - Arc::new(lag( - name, - default_value.data_type().clone(), - arg, - shift_offset, - default_value, - ignore_nulls, - )) - } - BuiltInWindowFunction::Lead => { - // rewrite NULL expression and the return datatype - let (arg, out_data_type) = - rewrite_null_expr_and_data_type(args, out_data_type)?; - let shift_offset = get_scalar_value_from_args(args, 1)? - .map(get_signed_integer) - .map_or(Ok(None), |v| v.map(Some))?; - let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, &out_data_type)?; - Arc::new(lead( - name, - default_value.data_type().clone(), - arg, - shift_offset, - default_value, - ignore_nulls, - )) - } BuiltInWindowFunction::NthValue => { let arg = Arc::clone(&args[0]); let n = get_signed_integer( @@ -415,7 +334,8 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { } fn expressions(&self) -> Vec> { - self.args.clone() + self.fun + .expressions(ExpressionArgs::new(&self.args, &self.input_types)) } fn create_evaluator(&self) -> Result> { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 5256f7473c95..9964ab498fb1 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -515,8 +515,8 @@ enum BuiltInWindowFunction { // PERCENT_RANK = 3; CUME_DIST = 4; NTILE = 5; - LAG = 6; - LEAD = 7; + // LAG = 6; + // LEAD = 7; FIRST_VALUE = 8; LAST_VALUE = 9; NTH_VALUE = 10; @@ -528,7 +528,7 @@ message WindowExprNode { string udaf = 3; string udwf = 9; } - LogicalExprNode expr = 4; + repeated LogicalExprNode exprs = 4; repeated LogicalExprNode partition_by = 5; repeated SortExprNode order_by = 6; // repeated LogicalExprNode filter = 7; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index e876008e853f..4417d1149681 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -1664,8 +1664,6 @@ impl serde::Serialize for BuiltInWindowFunction { Self::Unspecified => "UNSPECIFIED", Self::CumeDist => "CUME_DIST", Self::Ntile => "NTILE", - Self::Lag => "LAG", - Self::Lead => "LEAD", Self::FirstValue => "FIRST_VALUE", Self::LastValue => "LAST_VALUE", Self::NthValue => "NTH_VALUE", @@ -1683,8 +1681,6 @@ impl<'de> serde::Deserialize<'de> for BuiltInWindowFunction { "UNSPECIFIED", "CUME_DIST", "NTILE", - "LAG", - "LEAD", "FIRST_VALUE", "LAST_VALUE", "NTH_VALUE", @@ -1731,8 +1727,6 @@ impl<'de> serde::Deserialize<'de> for BuiltInWindowFunction { "UNSPECIFIED" => Ok(BuiltInWindowFunction::Unspecified), "CUME_DIST" => Ok(BuiltInWindowFunction::CumeDist), "NTILE" => Ok(BuiltInWindowFunction::Ntile), - "LAG" => Ok(BuiltInWindowFunction::Lag), - "LEAD" => Ok(BuiltInWindowFunction::Lead), "FIRST_VALUE" => Ok(BuiltInWindowFunction::FirstValue), "LAST_VALUE" => Ok(BuiltInWindowFunction::LastValue), "NTH_VALUE" => Ok(BuiltInWindowFunction::NthValue), @@ -21157,7 +21151,7 @@ impl serde::Serialize for WindowExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if !self.exprs.is_empty() { len += 1; } if !self.partition_by.is_empty() { @@ -21176,8 +21170,8 @@ impl serde::Serialize for WindowExprNode { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.WindowExprNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if !self.exprs.is_empty() { + struct_ser.serialize_field("exprs", &self.exprs)?; } if !self.partition_by.is_empty() { struct_ser.serialize_field("partitionBy", &self.partition_by)?; @@ -21218,7 +21212,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "exprs", "partition_by", "partitionBy", "order_by", @@ -21235,7 +21229,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Exprs, PartitionBy, OrderBy, WindowFrame, @@ -21264,7 +21258,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "exprs" => Ok(GeneratedField::Exprs), "partitionBy" | "partition_by" => Ok(GeneratedField::PartitionBy), "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), @@ -21291,7 +21285,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut exprs__ = None; let mut partition_by__ = None; let mut order_by__ = None; let mut window_frame__ = None; @@ -21299,11 +21293,11 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { let mut window_function__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Exprs => { + if exprs__.is_some() { + return Err(serde::de::Error::duplicate_field("exprs")); } - expr__ = map_.next_value()?; + exprs__ = Some(map_.next_value()?); } GeneratedField::PartitionBy => { if partition_by__.is_some() { @@ -21352,7 +21346,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { } } Ok(WindowExprNode { - expr: expr__, + exprs: exprs__.unwrap_or_default(), partition_by: partition_by__.unwrap_or_default(), order_by: order_by__.unwrap_or_default(), window_frame: window_frame__, diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 2aa14f7e80b0..d3fe031a48c9 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -538,7 +538,7 @@ pub mod logical_expr_node { TryCast(::prost::alloc::boxed::Box), /// window expressions #[prost(message, tag = "18")] - WindowExpr(::prost::alloc::boxed::Box), + WindowExpr(super::WindowExprNode), /// AggregateUDF expressions #[prost(message, tag = "19")] AggregateUdfExpr(::prost::alloc::boxed::Box), @@ -735,8 +735,8 @@ pub struct ScalarUdfExprNode { } #[derive(Clone, PartialEq, ::prost::Message)] pub struct WindowExprNode { - #[prost(message, optional, boxed, tag = "4")] - pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "4")] + pub exprs: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "5")] pub partition_by: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "6")] @@ -1828,8 +1828,8 @@ pub enum BuiltInWindowFunction { /// PERCENT_RANK = 3; CumeDist = 4, Ntile = 5, - Lag = 6, - Lead = 7, + /// LAG = 6; + /// LEAD = 7; FirstValue = 8, LastValue = 9, NthValue = 10, @@ -1844,8 +1844,6 @@ impl BuiltInWindowFunction { Self::Unspecified => "UNSPECIFIED", Self::CumeDist => "CUME_DIST", Self::Ntile => "NTILE", - Self::Lag => "LAG", - Self::Lead => "LEAD", Self::FirstValue => "FIRST_VALUE", Self::LastValue => "LAST_VALUE", Self::NthValue => "NTH_VALUE", @@ -1857,8 +1855,6 @@ impl BuiltInWindowFunction { "UNSPECIFIED" => Some(Self::Unspecified), "CUME_DIST" => Some(Self::CumeDist), "NTILE" => Some(Self::Ntile), - "LAG" => Some(Self::Lag), - "LEAD" => Some(Self::Lead), "FIRST_VALUE" => Some(Self::FirstValue), "LAST_VALUE" => Some(Self::LastValue), "NTH_VALUE" => Some(Self::NthValue), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 32e1b68203ce..20d007048a00 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -142,8 +142,6 @@ impl From for BuiltInWindowFunction { fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self { match built_in_function { protobuf::BuiltInWindowFunction::Unspecified => todo!(), - protobuf::BuiltInWindowFunction::Lag => Self::Lag, - protobuf::BuiltInWindowFunction::Lead => Self::Lead, protobuf::BuiltInWindowFunction::FirstValue => Self::FirstValue, protobuf::BuiltInWindowFunction::CumeDist => Self::CumeDist, protobuf::BuiltInWindowFunction::Ntile => Self::Ntile, @@ -286,10 +284,7 @@ pub fn parse_expr( .map_err(|_| Error::unknown("BuiltInWindowFunction", *i))? .into(); - let args = - parse_optional_expr(expr.expr.as_deref(), registry, codec)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = parse_exprs(&expr.exprs, registry, codec)?; Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::BuiltInWindowFunction( @@ -309,10 +304,7 @@ pub fn parse_expr( None => registry.udaf(udaf_name)?, }; - let args = - parse_optional_expr(expr.expr.as_deref(), registry, codec)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = parse_exprs(&expr.exprs, registry, codec)?; Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), args, @@ -329,10 +321,7 @@ pub fn parse_expr( None => registry.udwf(udwf_name)?, }; - let args = - parse_optional_expr(expr.expr.as_deref(), registry, codec)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = parse_exprs(&expr.exprs, registry, codec)?; Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), args, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 07823b422f71..15fec3a8b2a8 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -119,8 +119,6 @@ impl From<&BuiltInWindowFunction> for protobuf::BuiltInWindowFunction { BuiltInWindowFunction::NthValue => Self::NthValue, BuiltInWindowFunction::Ntile => Self::Ntile, BuiltInWindowFunction::CumeDist => Self::CumeDist, - BuiltInWindowFunction::Lag => Self::Lag, - BuiltInWindowFunction::Lead => Self::Lead, } } } @@ -333,25 +331,19 @@ pub fn serialize_expr( ) } }; - let arg_expr: Option> = if !args.is_empty() { - let arg = &args[0]; - Some(Box::new(serialize_expr(arg, codec)?)) - } else { - None - }; let partition_by = serialize_exprs(partition_by, codec)?; let order_by = serialize_sorts(order_by, codec)?; let window_frame: Option = Some(window_frame.try_into()?); - let window_expr = Box::new(protobuf::WindowExprNode { - expr: arg_expr, + let window_expr = protobuf::WindowExprNode { + exprs: serialize_exprs(args, codec)?, window_function: Some(window_function), partition_by, order_by, window_frame, fun_definition, - }); + }; protobuf::LogicalExprNode { expr_type: Some(ExprType::WindowExpr(window_expr)), } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 85d4fe8a16d0..6072baca688c 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -25,7 +25,6 @@ use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, InListExpr, IsNotNullExpr, IsNullExpr, Literal, NegativeExpr, NotExpr, NthValue, Ntile, TryCastExpr, - WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -119,25 +118,6 @@ pub fn serialize_physical_window_expr( )))), ); protobuf::BuiltInWindowFunction::Ntile - } else if let Some(window_shift_expr) = - built_in_fn_expr.downcast_ref::() - { - args.insert( - 1, - Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some( - window_shift_expr.get_shift_offset(), - )))), - ); - args.insert( - 2, - Arc::new(Literal::new(window_shift_expr.get_default_value())), - ); - - if window_shift_expr.get_shift_offset() >= 0 { - protobuf::BuiltInWindowFunction::Lag - } else { - protobuf::BuiltInWindowFunction::Lead - } } else if let Some(nth_value_expr) = built_in_fn_expr.downcast_ref::() { match nth_value_expr.get_kind() { NthValueKind::First => protobuf::BuiltInWindowFunction::FirstValue, diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index ffa8fc1eefe9..c017395d979f 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -47,8 +47,10 @@ use datafusion::functions_aggregate::expr_fn::{ }; use datafusion::functions_aggregate::min_max::max_udaf; use datafusion::functions_nested::map::map; -use datafusion::functions_window::rank::{dense_rank, percent_rank, rank, rank_udwf}; -use datafusion::functions_window::row_number::row_number; +use datafusion::functions_window::expr_fn::{ + dense_rank, lag, lead, percent_rank, rank, row_number, +}; +use datafusion::functions_window::rank::rank_udwf; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::TableOptions; @@ -942,6 +944,12 @@ async fn roundtrip_expr_api() -> Result<()> { rank(), dense_rank(), percent_rank(), + lead(col("b"), None, None), + lead(col("b"), Some(2), None), + lead(col("b"), Some(2), Some(ScalarValue::from(100))), + lag(col("b"), None, None), + lag(col("b"), Some(2), None), + lag(col("b"), Some(2), Some(ScalarValue::from(100))), nth_value(col("b"), 1, vec![]), nth_value( col("b"), diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index a3d0ff4383ae..fb7afdda2ea8 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -503,9 +503,9 @@ logical_plan 12)----Projection: Int64(1) AS cnt 13)------Limit: skip=0, fetch=3 14)--------EmptyRelation -15)----Projection: LEAD(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS cnt +15)----Projection: lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS cnt 16)------Limit: skip=0, fetch=3 -17)--------WindowAggr: windowExpr=[[LEAD(b.c1, Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +17)--------WindowAggr: windowExpr=[[lead(b.c1, Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] 18)----------SubqueryAlias: b 19)------------Projection: Int64(1) AS c1 20)--------------EmptyRelation @@ -528,8 +528,8 @@ physical_plan 16)------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], has_header=true 17)------ProjectionExec: expr=[1 as cnt] 18)--------PlaceholderRowExec -19)------ProjectionExec: expr=[LEAD(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as cnt] -20)--------BoundedWindowAggExec: wdw=[LEAD(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "LEAD(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +19)------ProjectionExec: expr=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as cnt] +20)--------BoundedWindowAggExec: wdw=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] 21)----------ProjectionExec: expr=[1 as c1] 22)------------PlaceholderRowExec diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 1b612f921262..b3f2786d3dba 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -1376,16 +1376,16 @@ EXPLAIN SELECT LIMIT 5 ---- logical_plan -01)Projection: aggregate_test_100.c9, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv1, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2 +01)Projection: aggregate_test_100.c9, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv1, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv2, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lag1, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lead1, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2 02)--Limit: skip=0, fetch=5 -03)----WindowAggr: windowExpr=[[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, LAG(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, LEAD(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -04)------WindowAggr: windowExpr=[[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, LAG(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +03)----WindowAggr: windowExpr=[[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, lag(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, lead(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------WindowAggr: windowExpr=[[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, lag(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lead(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] 05)--------TableScan: aggregate_test_100 projection=[c9] physical_plan -01)ProjectionExec: expr=[c9@0 as c9, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as fv1, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as lead2] +01)ProjectionExec: expr=[c9@0 as c9, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as fv1, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as fv2, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as lag1, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as lag2, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as lead1, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as lead2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] 05)--------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true @@ -2636,15 +2636,15 @@ EXPLAIN SELECT ---- logical_plan 01)Sort: annotated_data_finite.ts DESC NULLS FIRST, fetch=5 -02)--Projection: annotated_data_finite.ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rn2, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rank1, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rank2, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS dense_rank1, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS leadr2 -03)----WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] -04)------WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +02)--Projection: annotated_data_finite.ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rn2, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rank1, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rank2, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS dense_rank1, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS dense_rank2, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lag1, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lead1, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr2, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lagr1, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lagr2, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS leadr1, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS leadr2 +03)----WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lag(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, lag(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lead(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, lead(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +04)------WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lag(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, lag(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lead(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, lead(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] 05)--------TableScan: annotated_data_finite projection=[ts, inc_col] physical_plan 01)SortExec: TopK(fetch=5), expr=[ts@0 DESC], preserve_partitioning=[false] -02)--ProjectionExec: expr=[ts@0 as ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as dense_rank1, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2] -03)----BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }], mode=[Sorted] +02)--ProjectionExec: expr=[ts@0 as ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as dense_rank1, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2] +03)----BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }], mode=[Sorted] 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true query IIIIIIIIIIIIIIIIIIIIIIIII @@ -4971,6 +4971,26 @@ SELECT LAG(NULL, 1, false) OVER () FROM t1; ---- false +query B +SELECT LEAD(NULL, 0, true) OVER () FROM t1; +---- +NULL + +query B +SELECT LAG(NULL, 0, true) OVER () FROM t1; +---- +NULL + +query B +SELECT LEAD(NULL, 1, true) OVER () FROM t1; +---- +true + +query B +SELECT LAG(NULL, 1, true) OVER () FROM t1; +---- +true + statement ok insert into t1 values (2); @@ -4986,6 +5006,18 @@ SELECT LAG(NULL, 1, false) OVER () FROM t1; false NULL +query B +SELECT LEAD(NULL, 1, true) OVER () FROM t1; +---- +NULL +true + +query B +SELECT LAG(NULL, 1, true) OVER () FROM t1; +---- +true +NULL + statement ok DROP TABLE t1; From 24148bd65fdf61fba340b69dc87a7920850cb19f Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Fri, 18 Oct 2024 13:28:03 +0200 Subject: [PATCH 02/17] Add links to new_constraint_from_table_constraints doc (#12995) --- datafusion/sql/src/statement.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 4109f1371187..60e3413b836f 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -1263,7 +1263,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } - /// Convert each `TableConstraint` to corresponding `Constraint` + /// Convert each [TableConstraint] to corresponding [Constraint] fn new_constraint_from_table_constraints( constraints: &[TableConstraint], df_schema: &DFSchemaRef, From 87e931c976a7aa24cecaa9bf3658b42bba12a51e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alihan=20=C3=87elikcan?= Date: Fri, 18 Oct 2024 14:34:42 +0300 Subject: [PATCH 03/17] Split output batches of joins that do not respect batch size (#12969) * Add BatchSplitter to joins that do not respect batch size * Group relevant imports * Update configs.md * Update SQL logic tests for config * Review * Use PrimitiveBuilder for PrimitiveArray concatenation * Fix into_builder() bug * Apply suggestions from code review Co-authored-by: Andrew Lamb * Update config docs * Format * Update config SQL Logic Test --------- Co-authored-by: Mehmet Ozan Kabak Co-authored-by: Andrew Lamb --- datafusion/common/src/config.rs | 26 +- datafusion/execution/src/config.rs | 14 + .../physical-plan/src/joins/cross_join.rs | 84 +++-- .../physical-plan/src/joins/hash_join.rs | 2 +- .../src/joins/nested_loop_join.rs | 356 ++++++++++++------ .../src/joins/stream_join_utils.rs | 83 ++-- .../src/joins/symmetric_hash_join.rs | 252 +++++++------ datafusion/physical-plan/src/joins/utils.rs | 220 +++++++++-- .../test_files/information_schema.slt | 2 + docs/source/user-guide/configs.md | 1 + 10 files changed, 709 insertions(+), 331 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 1e1c5d5424b0..47ffe0b1c66b 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -338,6 +338,12 @@ config_namespace! { /// if the source of statistics is accurate. /// We plan to make this the default in the future. pub use_row_number_estimates_to_optimize_partitioning: bool, default = false + + /// Should DataFusion enforce batch size in joins or not. By default, + /// DataFusion will not enforce batch size in joins. Enforcing batch size + /// in joins can reduce memory usage when joining large + /// tables with a highly-selective join filter, but is also slightly slower. + pub enforce_batch_size_in_joins: bool, default = false } } @@ -1222,16 +1228,18 @@ impl ConfigField for TableOptions { fn set(&mut self, key: &str, value: &str) -> Result<()> { // Extensions are handled in the public `ConfigOptions::set` let (key, rem) = key.split_once('.').unwrap_or((key, "")); - let Some(format) = &self.current_format else { - return _config_err!("Specify a format for TableOptions"); - }; match key { - "format" => match format { - #[cfg(feature = "parquet")] - ConfigFileType::PARQUET => self.parquet.set(rem, value), - ConfigFileType::CSV => self.csv.set(rem, value), - ConfigFileType::JSON => self.json.set(rem, value), - }, + "format" => { + let Some(format) = &self.current_format else { + return _config_err!("Specify a format for TableOptions"); + }; + match format { + #[cfg(feature = "parquet")] + ConfigFileType::PARQUET => self.parquet.set(rem, value), + ConfigFileType::CSV => self.csv.set(rem, value), + ConfigFileType::JSON => self.json.set(rem, value), + } + } _ => _config_err!("Config value \"{key}\" not found on TableOptions"), } } diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index cede75d21ca4..53646dc5b468 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -432,6 +432,20 @@ impl SessionConfig { self } + /// Enables or disables the enforcement of batch size in joins + pub fn with_enforce_batch_size_in_joins( + mut self, + enforce_batch_size_in_joins: bool, + ) -> Self { + self.options.execution.enforce_batch_size_in_joins = enforce_batch_size_in_joins; + self + } + + /// Returns true if the joins will be enforced to output batches of the configured size + pub fn enforce_batch_size_in_joins(&self) -> bool { + self.options.execution.enforce_batch_size_in_joins + } + /// Convert configuration options to name-value pairs with values /// converted to strings. /// diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index a70645f3d6c0..8f2bef56da76 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -19,7 +19,8 @@ //! and producing batches in parallel for the right partitions use super::utils::{ - adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync, OnceFut, + adjust_right_output_partitioning, BatchSplitter, BatchTransformer, + BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut, StatefulStreamResult, }; use crate::coalesce_partitions::CoalescePartitionsExec; @@ -86,6 +87,7 @@ impl CrossJoinExec { let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata)); let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)); + CrossJoinExec { left, right, @@ -246,6 +248,10 @@ impl ExecutionPlan for CrossJoinExec { let reservation = MemoryConsumer::new("CrossJoinExec").register(context.memory_pool()); + let batch_size = context.session_config().batch_size(); + let enforce_batch_size_in_joins = + context.session_config().enforce_batch_size_in_joins(); + let left_fut = self.left_fut.once(|| { load_left_input( Arc::clone(&self.left), @@ -255,15 +261,29 @@ impl ExecutionPlan for CrossJoinExec { ) }); - Ok(Box::pin(CrossJoinStream { - schema: Arc::clone(&self.schema), - left_fut, - right: stream, - left_index: 0, - join_metrics, - state: CrossJoinStreamState::WaitBuildSide, - left_data: RecordBatch::new_empty(self.left().schema()), - })) + if enforce_batch_size_in_joins { + Ok(Box::pin(CrossJoinStream { + schema: Arc::clone(&self.schema), + left_fut, + right: stream, + left_index: 0, + join_metrics, + state: CrossJoinStreamState::WaitBuildSide, + left_data: RecordBatch::new_empty(self.left().schema()), + batch_transformer: BatchSplitter::new(batch_size), + })) + } else { + Ok(Box::pin(CrossJoinStream { + schema: Arc::clone(&self.schema), + left_fut, + right: stream, + left_index: 0, + join_metrics, + state: CrossJoinStreamState::WaitBuildSide, + left_data: RecordBatch::new_empty(self.left().schema()), + batch_transformer: NoopBatchTransformer::new(), + })) + } } fn statistics(&self) -> Result { @@ -319,7 +339,7 @@ fn stats_cartesian_product( } /// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct CrossJoinStream { +struct CrossJoinStream { /// Input schema schema: Arc, /// Future for data from left side @@ -334,9 +354,11 @@ struct CrossJoinStream { state: CrossJoinStreamState, /// Left data left_data: RecordBatch, + /// Batch transformer + batch_transformer: T, } -impl RecordBatchStream for CrossJoinStream { +impl RecordBatchStream for CrossJoinStream { fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } @@ -390,7 +412,7 @@ fn build_batch( } #[async_trait] -impl Stream for CrossJoinStream { +impl Stream for CrossJoinStream { type Item = Result; fn poll_next( @@ -401,7 +423,7 @@ impl Stream for CrossJoinStream { } } -impl CrossJoinStream { +impl CrossJoinStream { /// Separate implementation function that unpins the [`CrossJoinStream`] so /// that partial borrows work correctly fn poll_next_impl( @@ -470,21 +492,33 @@ impl CrossJoinStream { fn build_batches(&mut self) -> Result>> { let right_batch = self.state.try_as_record_batch()?; if self.left_index < self.left_data.num_rows() { - let join_timer = self.join_metrics.join_time.timer(); - let result = - build_batch(self.left_index, right_batch, &self.left_data, &self.schema); - join_timer.done(); - - if let Ok(ref batch) = result { - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); + match self.batch_transformer.next() { + None => { + let join_timer = self.join_metrics.join_time.timer(); + let result = build_batch( + self.left_index, + right_batch, + &self.left_data, + &self.schema, + ); + join_timer.done(); + + self.batch_transformer.set_batch(result?); + } + Some((batch, last)) => { + if last { + self.left_index += 1; + } + + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + return Ok(StatefulStreamResult::Ready(Some(batch))); + } } - self.left_index += 1; - result.map(|r| StatefulStreamResult::Ready(Some(r))) } else { self.state = CrossJoinStreamState::FetchProbeBatch; - Ok(StatefulStreamResult::Continue) } + Ok(StatefulStreamResult::Continue) } } diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 74a45a7e4761..3b730c01291c 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -1438,7 +1438,7 @@ impl HashJoinStream { index_alignment_range_start..index_alignment_range_end, self.join_type, self.right_side_ordered, - ); + )?; let result = build_batch_from_indices( &self.schema, diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 6068e7526316..358ff02473a6 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -25,7 +25,10 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::task::Poll; -use super::utils::{asymmetric_join_output_partitioning, need_produce_result_in_final}; +use super::utils::{ + asymmetric_join_output_partitioning, need_produce_result_in_final, BatchSplitter, + BatchTransformer, NoopBatchTransformer, StatefulStreamResult, +}; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, @@ -35,8 +38,8 @@ use crate::joins::utils::{ }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::{ - execution_mode_from_children, DisplayAs, DisplayFormatType, Distribution, - ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + execution_mode_from_children, handle_state, DisplayAs, DisplayFormatType, + Distribution, ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, }; @@ -45,7 +48,9 @@ use arrow::compute::concat_batches; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; -use datafusion_common::{exec_datafusion_err, JoinSide, Result, Statistics}; +use datafusion_common::{ + exec_datafusion_err, internal_err, JoinSide, Result, Statistics, +}; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_expr::JoinType; @@ -230,10 +235,11 @@ impl NestedLoopJoinExec { asymmetric_join_output_partitioning(left, right, &join_type); // Determine execution mode: - let mut mode = execution_mode_from_children([left, right]); - if mode.is_unbounded() { - mode = ExecutionMode::PipelineBreaking; - } + let mode = if left.execution_mode().is_unbounded() { + ExecutionMode::PipelineBreaking + } else { + execution_mode_from_children([left, right]) + }; PlanProperties::new(eq_properties, output_partitioning, mode) } @@ -345,6 +351,10 @@ impl ExecutionPlan for NestedLoopJoinExec { ) }); + let batch_size = context.session_config().batch_size(); + let enforce_batch_size_in_joins = + context.session_config().enforce_batch_size_in_joins(); + let outer_table = self.right.execute(partition, context)?; let indices_cache = (UInt64Array::new_null(0), UInt32Array::new_null(0)); @@ -352,18 +362,38 @@ impl ExecutionPlan for NestedLoopJoinExec { // Right side has an order and it is maintained during operation. let right_side_ordered = self.maintains_input_order()[1] && self.right.output_ordering().is_some(); - Ok(Box::pin(NestedLoopJoinStream { - schema: Arc::clone(&self.schema), - filter: self.filter.clone(), - join_type: self.join_type, - outer_table, - inner_table, - is_exhausted: false, - column_indices: self.column_indices.clone(), - join_metrics, - indices_cache, - right_side_ordered, - })) + + if enforce_batch_size_in_joins { + Ok(Box::pin(NestedLoopJoinStream { + schema: Arc::clone(&self.schema), + filter: self.filter.clone(), + join_type: self.join_type, + outer_table, + inner_table, + column_indices: self.column_indices.clone(), + join_metrics, + indices_cache, + right_side_ordered, + state: NestedLoopJoinStreamState::WaitBuildSide, + batch_transformer: BatchSplitter::new(batch_size), + left_data: None, + })) + } else { + Ok(Box::pin(NestedLoopJoinStream { + schema: Arc::clone(&self.schema), + filter: self.filter.clone(), + join_type: self.join_type, + outer_table, + inner_table, + column_indices: self.column_indices.clone(), + join_metrics, + indices_cache, + right_side_ordered, + state: NestedLoopJoinStreamState::WaitBuildSide, + batch_transformer: NoopBatchTransformer::new(), + left_data: None, + })) + } } fn metrics(&self) -> Option { @@ -442,8 +472,37 @@ async fn collect_left_input( )) } +/// This enumeration represents various states of the nested loop join algorithm. +#[derive(Debug, Clone)] +enum NestedLoopJoinStreamState { + /// The initial state, indicating that build-side data not collected yet + WaitBuildSide, + /// Indicates that build-side has been collected, and stream is ready for + /// fetching probe-side + FetchProbeBatch, + /// Indicates that a non-empty batch has been fetched from probe-side, and + /// is ready to be processed + ProcessProbeBatch(RecordBatch), + /// Indicates that probe-side has been fully processed + ExhaustedProbeSide, + /// Indicates that NestedLoopJoinStream execution is completed + Completed, +} + +impl NestedLoopJoinStreamState { + /// Tries to extract a `ProcessProbeBatchState` from the + /// `NestedLoopJoinStreamState` enum. Returns an error if state is not + /// `ProcessProbeBatchState`. + fn try_as_process_probe_batch(&mut self) -> Result<&RecordBatch> { + match self { + NestedLoopJoinStreamState::ProcessProbeBatch(state) => Ok(state), + _ => internal_err!("Expected join stream in ProcessProbeBatch state"), + } + } +} + /// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct NestedLoopJoinStream { +struct NestedLoopJoinStream { /// Input schema schema: Arc, /// join filter @@ -454,8 +513,6 @@ struct NestedLoopJoinStream { outer_table: SendableRecordBatchStream, /// the inner table data of the nested loop join inner_table: OnceFut, - /// There is nothing to process anymore and left side is processed in case of full join - is_exhausted: bool, /// Information of index and left / right placement of columns column_indices: Vec, // TODO: support null aware equal @@ -466,6 +523,12 @@ struct NestedLoopJoinStream { indices_cache: (UInt64Array, UInt32Array), /// Whether the right side is ordered right_side_ordered: bool, + /// Current state of the stream + state: NestedLoopJoinStreamState, + /// Transforms the output batch before returning. + batch_transformer: T, + /// Result of the left data future + left_data: Option>, } /// Creates a Cartesian product of two input batches, preserving the order of the right batch, @@ -544,107 +607,164 @@ fn build_join_indices( } } -impl NestedLoopJoinStream { +impl NestedLoopJoinStream { fn poll_next_impl( &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>> { - // all left row + loop { + return match self.state { + NestedLoopJoinStreamState::WaitBuildSide => { + handle_state!(ready!(self.collect_build_side(cx))) + } + NestedLoopJoinStreamState::FetchProbeBatch => { + handle_state!(ready!(self.fetch_probe_batch(cx))) + } + NestedLoopJoinStreamState::ProcessProbeBatch(_) => { + handle_state!(self.process_probe_batch()) + } + NestedLoopJoinStreamState::ExhaustedProbeSide => { + handle_state!(self.process_unmatched_build_batch()) + } + NestedLoopJoinStreamState::Completed => Poll::Ready(None), + }; + } + } + + fn collect_build_side( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { let build_timer = self.join_metrics.build_time.timer(); - let left_data = match ready!(self.inner_table.get_shared(cx)) { - Ok(data) => data, - Err(e) => return Poll::Ready(Some(Err(e))), - }; + // build hash table from left (build) side, if not yet done + self.left_data = Some(ready!(self.inner_table.get_shared(cx))?); build_timer.done(); - // Get or initialize visited_left_side bitmap if required by join type + self.state = NestedLoopJoinStreamState::FetchProbeBatch; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Fetches next batch from probe-side + /// + /// If a non-empty batch has been fetched, updates state to + /// `ProcessProbeBatchState`, otherwise updates state to `ExhaustedProbeSide`. + fn fetch_probe_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + match ready!(self.outer_table.poll_next_unpin(cx)) { + None => { + self.state = NestedLoopJoinStreamState::ExhaustedProbeSide; + } + Some(Ok(right_batch)) => { + self.state = NestedLoopJoinStreamState::ProcessProbeBatch(right_batch); + } + Some(Err(err)) => return Poll::Ready(Err(err)), + }; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Joins current probe batch with build-side data and produces batch with + /// matched output, updates state to `FetchProbeBatch`. + fn process_probe_batch( + &mut self, + ) -> Result>> { + let Some(left_data) = self.left_data.clone() else { + return internal_err!( + "Expected left_data to be Some in ProcessProbeBatch state" + ); + }; let visited_left_side = left_data.bitmap(); + let batch = self.state.try_as_process_probe_batch()?; + + match self.batch_transformer.next() { + None => { + // Setting up timer & updating input metrics + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + let timer = self.join_metrics.join_time.timer(); + + let result = join_left_and_right_batch( + left_data.batch(), + batch, + self.join_type, + self.filter.as_ref(), + &self.column_indices, + &self.schema, + visited_left_side, + &mut self.indices_cache, + self.right_side_ordered, + ); + timer.done(); + + self.batch_transformer.set_batch(result?); + Ok(StatefulStreamResult::Continue) + } + Some((batch, last)) => { + if last { + self.state = NestedLoopJoinStreamState::FetchProbeBatch; + } - // Check is_exhausted before polling the outer_table, such that when the outer table - // does not support `FusedStream`, Self will not poll it again - if self.is_exhausted { - return Poll::Ready(None); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + Ok(StatefulStreamResult::Ready(Some(batch))) + } } + } - self.outer_table - .poll_next_unpin(cx) - .map(|maybe_batch| match maybe_batch { - Some(Ok(right_batch)) => { - // Setting up timer & updating input metrics - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(right_batch.num_rows()); - let timer = self.join_metrics.join_time.timer(); - - let result = join_left_and_right_batch( - left_data.batch(), - &right_batch, - self.join_type, - self.filter.as_ref(), - &self.column_indices, - &self.schema, - visited_left_side, - &mut self.indices_cache, - self.right_side_ordered, - ); - - // Recording time & updating output metrics - if let Ok(batch) = &result { - timer.done(); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - - Some(result) - } - Some(err) => Some(err), - None => { - if need_produce_result_in_final(self.join_type) { - // At this stage `visited_left_side` won't be updated, so it's - // safe to report about probe completion. - // - // Setting `is_exhausted` / returning None will prevent from - // multiple calls of `report_probe_completed()` - if !left_data.report_probe_completed() { - self.is_exhausted = true; - return None; - }; - - // Only setting up timer, input is exhausted - let timer = self.join_metrics.join_time.timer(); - // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = - get_final_indices_from_shared_bitmap( - visited_left_side, - self.join_type, - ); - let empty_right_batch = - RecordBatch::new_empty(self.outer_table.schema()); - // use the left and right indices to produce the batch result - let result = build_batch_from_indices( - &self.schema, - left_data.batch(), - &empty_right_batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - self.is_exhausted = true; - - // Recording time & updating output metrics - if let Ok(batch) = &result { - timer.done(); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - - Some(result) - } else { - // end of the join loop - None - } - } - }) + /// Processes unmatched build-side rows for certain join types and produces + /// output batch, updates state to `Completed`. + fn process_unmatched_build_batch( + &mut self, + ) -> Result>> { + let Some(left_data) = self.left_data.clone() else { + return internal_err!( + "Expected left_data to be Some in ExhaustedProbeSide state" + ); + }; + let visited_left_side = left_data.bitmap(); + if need_produce_result_in_final(self.join_type) { + // At this stage `visited_left_side` won't be updated, so it's + // safe to report about probe completion. + // + // Setting `is_exhausted` / returning None will prevent from + // multiple calls of `report_probe_completed()` + if !left_data.report_probe_completed() { + self.state = NestedLoopJoinStreamState::Completed; + return Ok(StatefulStreamResult::Ready(None)); + }; + + // Only setting up timer, input is exhausted + let timer = self.join_metrics.join_time.timer(); + // use the global left bitmap to produce the left indices and right indices + let (left_side, right_side) = + get_final_indices_from_shared_bitmap(visited_left_side, self.join_type); + let empty_right_batch = RecordBatch::new_empty(self.outer_table.schema()); + // use the left and right indices to produce the batch result + let result = build_batch_from_indices( + &self.schema, + left_data.batch(), + &empty_right_batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + self.state = NestedLoopJoinStreamState::Completed; + + // Recording time + if result.is_ok() { + timer.done(); + } + + Ok(StatefulStreamResult::Ready(Some(result?))) + } else { + // end of the join loop + self.state = NestedLoopJoinStreamState::Completed; + Ok(StatefulStreamResult::Ready(None)) + } } } @@ -684,7 +804,7 @@ fn join_left_and_right_batch( 0..right_batch.num_rows(), join_type, right_side_ordered, - ); + )?; build_batch_from_indices( schema, @@ -705,7 +825,7 @@ fn get_final_indices_from_shared_bitmap( get_final_indices_from_bit_map(&bitmap, join_type) } -impl Stream for NestedLoopJoinStream { +impl Stream for NestedLoopJoinStream { type Item = Result; fn poll_next( @@ -716,14 +836,14 @@ impl Stream for NestedLoopJoinStream { } } -impl RecordBatchStream for NestedLoopJoinStream { +impl RecordBatchStream for NestedLoopJoinStream { fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } } #[cfg(test)] -mod tests { +pub(crate) mod tests { use super::*; use crate::{ common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, @@ -850,7 +970,7 @@ mod tests { JoinFilter::new(filter_expression, column_indices, intermediate_schema) } - async fn multi_partitioned_join_collect( + pub(crate) async fn multi_partitioned_join_collect( left: Arc, right: Arc, join_type: &JoinType, diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index ba9384aef1a6..bddd152341da 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -31,8 +31,7 @@ use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder}; use arrow_schema::{Schema, SchemaRef}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ - arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide, Result, - ScalarValue, + arrow_datafusion_err, DataFusionError, JoinSide, Result, ScalarValue, }; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; @@ -369,34 +368,40 @@ impl SortedFilterExpr { filter_expr: Arc, filter_schema: &Schema, ) -> Result { - let dt = &filter_expr.data_type(filter_schema)?; + let dt = filter_expr.data_type(filter_schema)?; Ok(Self { origin_sorted_expr, filter_expr, - interval: Interval::make_unbounded(dt)?, + interval: Interval::make_unbounded(&dt)?, node_index: 0, }) } + /// Get origin expr information pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr { &self.origin_sorted_expr } + /// Get filter expr information pub fn filter_expr(&self) -> &Arc { &self.filter_expr } + /// Get interval information pub fn interval(&self) -> &Interval { &self.interval } + /// Sets interval pub fn set_interval(&mut self, interval: Interval) { self.interval = interval; } + /// Node index in ExprIntervalGraph pub fn node_index(&self) -> usize { self.node_index } + /// Node index setter in ExprIntervalGraph pub fn set_node_index(&mut self, node_index: usize) { self.node_index = node_index; @@ -409,41 +414,45 @@ impl SortedFilterExpr { /// on the first or the last value of the expression in `build_input_buffer` /// and `probe_batch`. /// -/// # Arguments +/// # Parameters /// /// * `build_input_buffer` - The [RecordBatch] on the build side of the join. /// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update. /// * `probe_batch` - The `RecordBatch` on the probe side of the join. /// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update. /// -/// ### Note -/// ```text +/// ## Note /// -/// Interval arithmetic is used to calculate viable join ranges for build-side -/// pruning. This is done by first creating an interval for join filter values in -/// the build side of the join, which spans [-∞, FV] or [FV, ∞] depending on the -/// ordering (descending/ascending) of the filter expression. Here, FV denotes the -/// first value on the build side. This range is then compared with the probe side -/// interval, which either spans [-∞, LV] or [LV, ∞] depending on the ordering -/// (ascending/descending) of the probe side. Here, LV denotes the last value on -/// the probe side. +/// Utilizing interval arithmetic, this function computes feasible join intervals +/// on the pruning side by evaluating the prospective value ranges that might +/// emerge in subsequent data batches from the enforcer side. This is done by +/// first creating an interval for join filter values in the pruning side of the +/// join, which spans `[-∞, FV]` or `[FV, ∞]` depending on the ordering (descending/ +/// ascending) of the filter expression. Here, `FV` denotes the first value on the +/// pruning side. This range is then compared with the enforcer side interval, +/// which either spans `[-∞, LV]` or `[LV, ∞]` depending on the ordering (ascending/ +/// descending) of the probe side. Here, `LV` denotes the last value on the enforcer +/// side. /// /// As a concrete example, consider the following query: /// +/// ```text /// SELECT * FROM left_table, right_table /// WHERE /// left_key = right_key AND /// a > b - 3 AND /// a < b + 10 +/// ``` /// -/// where columns "a" and "b" come from tables "left_table" and "right_table", +/// where columns `a` and `b` come from tables `left_table` and `right_table`, /// respectively. When a new `RecordBatch` arrives at the right side, the -/// condition a > b - 3 will possibly indicate a prunable range for the left +/// condition `a > b - 3` will possibly indicate a prunable range for the left /// side. Conversely, when a new `RecordBatch` arrives at the left side, the -/// condition a < b + 10 will possibly indicate prunability for the right side. -/// Let’s inspect what happens when a new RecordBatch` arrives at the right +/// condition `a < b + 10` will possibly indicate prunability for the right side. +/// Let’s inspect what happens when a new `RecordBatch` arrives at the right /// side (i.e. when the left side is the build side): /// +/// ```text /// Build Probe /// +-------+ +-------+ /// | a | z | | b | y | @@ -456,13 +465,13 @@ impl SortedFilterExpr { /// |+--|--+| |+--|--+| /// | 7 | 1 | | 6 | 3 | /// +-------+ +-------+ +/// ``` /// /// In this case, the interval representing viable (i.e. joinable) values for -/// column "a" is [1, ∞], and the interval representing possible future values -/// for column "b" is [6, ∞]. With these intervals at hand, we next calculate +/// column `a` is `[1, ∞]`, and the interval representing possible future values +/// for column `b` is `[6, ∞]`. With these intervals at hand, we next calculate /// intervals for the whole filter expression and propagate join constraint by /// traversing the expression graph. -/// ``` pub fn calculate_filter_expr_intervals( build_input_buffer: &RecordBatch, build_sorted_filter_expr: &mut SortedFilterExpr, @@ -710,13 +719,21 @@ fn update_sorted_exprs_with_node_indices( } } -/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions. +/// Prepares and sorts expressions based on a given filter, left and right schemas, +/// and sort expressions. /// -/// # Arguments +/// This function prepares sorted filter expressions for both the left and right +/// sides of a join operation. It first builds the filter order for each side +/// based on the provided `ExecutionPlan`. If both sides have valid sorted filter +/// expressions, the function then constructs an expression interval graph and +/// updates the sorted expressions with node indices. The final sorted filter +/// expressions for both sides are then returned. +/// +/// # Parameters /// /// * `filter` - The join filter to base the sorting on. -/// * `left` - The left execution plan. -/// * `right` - The right execution plan. +/// * `left` - The `ExecutionPlan` for the left side of the join. +/// * `right` - The `ExecutionPlan` for the right side of the join. /// * `left_sort_exprs` - The expressions to sort on the left side. /// * `right_sort_exprs` - The expressions to sort on the right side. /// @@ -730,9 +747,11 @@ pub fn prepare_sorted_exprs( left_sort_exprs: &[PhysicalSortExpr], right_sort_exprs: &[PhysicalSortExpr], ) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> { - // Build the filter order for the left side - let err = || plan_datafusion_err!("Filter does not include the child order"); + let err = || { + datafusion_common::plan_datafusion_err!("Filter does not include the child order") + }; + // Build the filter order for the left side: let left_temp_sorted_filter_expr = build_filter_input_order( JoinSide::Left, filter, @@ -741,7 +760,7 @@ pub fn prepare_sorted_exprs( )? .ok_or_else(err)?; - // Build the filter order for the right side + // Build the filter order for the right side: let right_temp_sorted_filter_expr = build_filter_input_order( JoinSide::Right, filter, @@ -952,15 +971,15 @@ pub mod tests { let filter_expr = complicated_filter(&intermediate_schema)?; let column_indices = vec![ ColumnIndex { - index: 0, + index: left_schema.index_of("la1")?, side: JoinSide::Left, }, ColumnIndex { - index: 4, + index: left_schema.index_of("la2")?, side: JoinSide::Left, }, ColumnIndex { - index: 0, + index: right_schema.index_of("ra1")?, side: JoinSide::Right, }, ]; diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index ac718a95e9f4..70ada3892aea 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -32,7 +32,6 @@ use std::task::{Context, Poll}; use std::vec; use crate::common::SharedMemoryReservation; -use crate::handle_state; use crate::joins::hash_join::{equal_rows_arr, update_hash}; use crate::joins::stream_join_utils::{ calculate_filter_expr_intervals, combine_two_batches, @@ -42,8 +41,9 @@ use crate::joins::stream_join_utils::{ }; use crate::joins::utils::{ apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, - check_join_is_valid, symmetric_join_output_partitioning, ColumnIndex, JoinFilter, - JoinHashMapType, JoinOn, JoinOnRef, StatefulStreamResult, + check_join_is_valid, symmetric_join_output_partitioning, BatchSplitter, + BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn, JoinOnRef, + NoopBatchTransformer, StatefulStreamResult, }; use crate::{ execution_mode_from_children, @@ -465,23 +465,27 @@ impl ExecutionPlan for SymmetricHashJoinExec { consider using RepartitionExec" ); } - // If `filter_state` and `filter` are both present, then calculate sorted filter expressions - // for both sides, and build an expression graph. - let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = - match (&self.left_sort_exprs, &self.right_sort_exprs, &self.filter) { - (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => { - let (left, right, graph) = prepare_sorted_exprs( - filter, - &self.left, - &self.right, - left_sort_exprs, - right_sort_exprs, - )?; - (Some(left), Some(right), Some(graph)) - } - // If `filter_state` or `filter` is not present, then return None for all three values: - _ => (None, None, None), - }; + // If `filter_state` and `filter` are both present, then calculate sorted + // filter expressions for both sides, and build an expression graph. + let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match ( + self.left_sort_exprs(), + self.right_sort_exprs(), + &self.filter, + ) { + (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => { + let (left, right, graph) = prepare_sorted_exprs( + filter, + &self.left, + &self.right, + left_sort_exprs, + right_sort_exprs, + )?; + (Some(left), Some(right), Some(graph)) + } + // If `filter_state` or `filter` is not present, then return None + // for all three values: + _ => (None, None, None), + }; let (on_left, on_right) = self.on.iter().cloned().unzip(); @@ -494,6 +498,10 @@ impl ExecutionPlan for SymmetricHashJoinExec { let right_stream = self.right.execute(partition, Arc::clone(&context))?; + let batch_size = context.session_config().batch_size(); + let enforce_batch_size_in_joins = + context.session_config().enforce_batch_size_in_joins(); + let reservation = Arc::new(Mutex::new( MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]")) .register(context.memory_pool()), @@ -502,29 +510,52 @@ impl ExecutionPlan for SymmetricHashJoinExec { reservation.lock().try_grow(g.size())?; } - Ok(Box::pin(SymmetricHashJoinStream { - left_stream, - right_stream, - schema: self.schema(), - filter: self.filter.clone(), - join_type: self.join_type, - random_state: self.random_state.clone(), - left: left_side_joiner, - right: right_side_joiner, - column_indices: self.column_indices.clone(), - metrics: StreamJoinMetrics::new(partition, &self.metrics), - graph, - left_sorted_filter_expr, - right_sorted_filter_expr, - null_equals_null: self.null_equals_null, - state: SHJStreamState::PullRight, - reservation, - })) + if enforce_batch_size_in_joins { + Ok(Box::pin(SymmetricHashJoinStream { + left_stream, + right_stream, + schema: self.schema(), + filter: self.filter.clone(), + join_type: self.join_type, + random_state: self.random_state.clone(), + left: left_side_joiner, + right: right_side_joiner, + column_indices: self.column_indices.clone(), + metrics: StreamJoinMetrics::new(partition, &self.metrics), + graph, + left_sorted_filter_expr, + right_sorted_filter_expr, + null_equals_null: self.null_equals_null, + state: SHJStreamState::PullRight, + reservation, + batch_transformer: BatchSplitter::new(batch_size), + })) + } else { + Ok(Box::pin(SymmetricHashJoinStream { + left_stream, + right_stream, + schema: self.schema(), + filter: self.filter.clone(), + join_type: self.join_type, + random_state: self.random_state.clone(), + left: left_side_joiner, + right: right_side_joiner, + column_indices: self.column_indices.clone(), + metrics: StreamJoinMetrics::new(partition, &self.metrics), + graph, + left_sorted_filter_expr, + right_sorted_filter_expr, + null_equals_null: self.null_equals_null, + state: SHJStreamState::PullRight, + reservation, + batch_transformer: NoopBatchTransformer::new(), + })) + } } } /// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct SymmetricHashJoinStream { +struct SymmetricHashJoinStream { /// Input streams left_stream: SendableRecordBatchStream, right_stream: SendableRecordBatchStream, @@ -556,15 +587,19 @@ struct SymmetricHashJoinStream { reservation: SharedMemoryReservation, /// State machine for input execution state: SHJStreamState, + /// Transforms the output batch before returning. + batch_transformer: T, } -impl RecordBatchStream for SymmetricHashJoinStream { +impl RecordBatchStream + for SymmetricHashJoinStream +{ fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } } -impl Stream for SymmetricHashJoinStream { +impl Stream for SymmetricHashJoinStream { type Item = Result; fn poll_next( @@ -1140,7 +1175,7 @@ impl OneSideHashJoiner { /// - Transition to `BothExhausted { final_result: true }`: /// - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are /// exhausted, indicating completion of processing and availability of final results. -impl SymmetricHashJoinStream { +impl SymmetricHashJoinStream { /// Implements the main polling logic for the join stream. /// /// This method continuously checks the state of the join stream and @@ -1159,26 +1194,45 @@ impl SymmetricHashJoinStream { cx: &mut Context<'_>, ) -> Poll>> { loop { - return match self.state() { - SHJStreamState::PullRight => { - handle_state!(ready!(self.fetch_next_from_right_stream(cx))) - } - SHJStreamState::PullLeft => { - handle_state!(ready!(self.fetch_next_from_left_stream(cx))) + match self.batch_transformer.next() { + None => { + let result = match self.state() { + SHJStreamState::PullRight => { + ready!(self.fetch_next_from_right_stream(cx)) + } + SHJStreamState::PullLeft => { + ready!(self.fetch_next_from_left_stream(cx)) + } + SHJStreamState::RightExhausted => { + ready!(self.handle_right_stream_end(cx)) + } + SHJStreamState::LeftExhausted => { + ready!(self.handle_left_stream_end(cx)) + } + SHJStreamState::BothExhausted { + final_result: false, + } => self.prepare_for_final_results_after_exhaustion(), + SHJStreamState::BothExhausted { final_result: true } => { + return Poll::Ready(None); + } + }; + + match result? { + StatefulStreamResult::Ready(None) => { + return Poll::Ready(None); + } + StatefulStreamResult::Ready(Some(batch)) => { + self.batch_transformer.set_batch(batch); + } + _ => {} + } } - SHJStreamState::RightExhausted => { - handle_state!(ready!(self.handle_right_stream_end(cx))) - } - SHJStreamState::LeftExhausted => { - handle_state!(ready!(self.handle_left_stream_end(cx))) - } - SHJStreamState::BothExhausted { - final_result: false, - } => { - handle_state!(self.prepare_for_final_results_after_exhaustion()) + Some((batch, _)) => { + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); + return Poll::Ready(Some(Ok(batch))); } - SHJStreamState::BothExhausted { final_result: true } => Poll::Ready(None), - }; + } } } /// Asynchronously pulls the next batch from the right stream. @@ -1384,11 +1438,8 @@ impl SymmetricHashJoinStream { // Combine the left and right results: let result = combine_two_batches(&self.schema, left_result, right_result)?; - // Update the metrics and return the result: - if let Some(batch) = &result { - // Update the metrics: - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); + // Return the result: + if result.is_some() { return Ok(StatefulStreamResult::Ready(result)); } Ok(StatefulStreamResult::Continue) @@ -1523,11 +1574,6 @@ impl SymmetricHashJoinStream { let capacity = self.size(); self.metrics.stream_memory_usage.set(capacity); self.reservation.lock().try_resize(capacity)?; - // Update the metrics if we have a batch; otherwise, continue the loop. - if let Some(batch) = &result { - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); - } Ok(result) } } @@ -1716,15 +1762,15 @@ mod tests { let filter_expr = complicated_filter(&intermediate_schema)?; let column_indices = vec![ ColumnIndex { - index: 0, + index: left_schema.index_of("la1")?, side: JoinSide::Left, }, ColumnIndex { - index: 4, + index: left_schema.index_of("la2")?, side: JoinSide::Left, }, ColumnIndex { - index: 0, + index: right_schema.index_of("ra1")?, side: JoinSide::Right, }, ]; @@ -1771,10 +1817,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1825,10 +1868,7 @@ mod tests { let (left, right) = create_memory_table(left_partition, right_partition, vec![], vec![])?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1877,10 +1917,7 @@ mod tests { let (left, right) = create_memory_table(left_partition, right_partition, vec![], vec![])?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; experiment(left, right, None, join_type, on, task_ctx).await?; Ok(()) } @@ -1926,10 +1963,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1987,10 +2021,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -2048,10 +2079,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -2111,10 +2139,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -2170,10 +2195,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("0", DataType::Int32, true), @@ -2237,10 +2259,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("0", DataType::Int32, true), @@ -2296,10 +2315,7 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let left_sorted = vec![PhysicalSortExpr { expr: col("lt1", left_schema)?, options: SortOptions { @@ -2380,10 +2396,7 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let left_sorted = vec![PhysicalSortExpr { expr: col("li1", left_schema)?, options: SortOptions { @@ -2473,10 +2486,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Float64, true), diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 89f3feaf07be..c520e4271416 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -546,15 +546,16 @@ pub struct ColumnIndex { pub side: JoinSide, } -/// Filter applied before join output +/// Filter applied before join output. Fields are crate-public to allow +/// downstream implementations to experiment with custom joins. #[derive(Debug, Clone)] pub struct JoinFilter { /// Filter expression - expression: Arc, + pub(crate) expression: Arc, /// Column indices required to construct intermediate batch for filtering - column_indices: Vec, + pub(crate) column_indices: Vec, /// Physical schema of intermediate batch - schema: Schema, + pub(crate) schema: Schema, } impl JoinFilter { @@ -1280,15 +1281,15 @@ pub(crate) fn adjust_indices_by_join_type( adjust_range: Range, join_type: JoinType, preserve_order_for_right: bool, -) -> (UInt64Array, UInt32Array) { +) -> Result<(UInt64Array, UInt32Array)> { match join_type { JoinType::Inner => { // matched - (left_indices, right_indices) + Ok((left_indices, right_indices)) } JoinType::Left => { // matched - (left_indices, right_indices) + Ok((left_indices, right_indices)) // unmatched left row will be produced in the end of loop, and it has been set in the left visited bitmap } JoinType::Right => { @@ -1307,22 +1308,22 @@ pub(crate) fn adjust_indices_by_join_type( // need to remove the duplicated record in the right side let right_indices = get_semi_indices(adjust_range, &right_indices); // the left_indices will not be used later for the `right semi` join - (left_indices, right_indices) + Ok((left_indices, right_indices)) } JoinType::RightAnti => { // need to remove the duplicated record in the right side // get the anti index for the right side let right_indices = get_anti_indices(adjust_range, &right_indices); // the left_indices will not be used later for the `right anti` join - (left_indices, right_indices) + Ok((left_indices, right_indices)) } JoinType::LeftSemi | JoinType::LeftAnti => { // matched or unmatched left row will be produced in the end of loop // When visit the right batch, we can output the matched left row and don't need to wait the end of loop - ( + Ok(( UInt64Array::from_iter_values(vec![]), UInt32Array::from_iter_values(vec![]), - ) + )) } } } @@ -1347,27 +1348,64 @@ pub(crate) fn append_right_indices( right_indices: UInt32Array, adjust_range: Range, preserve_order_for_right: bool, -) -> (UInt64Array, UInt32Array) { +) -> Result<(UInt64Array, UInt32Array)> { if preserve_order_for_right { - append_probe_indices_in_order(left_indices, right_indices, adjust_range) + Ok(append_probe_indices_in_order( + left_indices, + right_indices, + adjust_range, + )) } else { let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices); if right_unmatched_indices.is_empty() { - (left_indices, right_indices) + Ok((left_indices, right_indices)) } else { - let unmatched_size = right_unmatched_indices.len(); + // `into_builder()` can fail here when there is nothing to be filtered and + // left_indices or right_indices has the same reference to the cached indices. + // In that case, we use a slower alternative. + // the new left indices: left_indices + null array + let mut new_left_indices_builder = + left_indices.into_builder().unwrap_or_else(|left_indices| { + let mut builder = UInt64Builder::with_capacity( + left_indices.len() + right_unmatched_indices.len(), + ); + debug_assert_eq!( + left_indices.null_count(), + 0, + "expected left indices to have no nulls" + ); + builder.append_slice(left_indices.values()); + builder + }); + new_left_indices_builder.append_nulls(right_unmatched_indices.len()); + let new_left_indices = UInt64Array::from(new_left_indices_builder.finish()); + // the new right indices: right_indices + right_unmatched_indices - let new_left_indices = left_indices - .iter() - .chain(std::iter::repeat(None).take(unmatched_size)) - .collect(); - let new_right_indices = right_indices - .iter() - .chain(right_unmatched_indices.iter()) - .collect(); - (new_left_indices, new_right_indices) + let mut new_right_indices_builder = right_indices + .into_builder() + .unwrap_or_else(|right_indices| { + let mut builder = UInt32Builder::with_capacity( + right_indices.len() + right_unmatched_indices.len(), + ); + debug_assert_eq!( + right_indices.null_count(), + 0, + "expected right indices to have no nulls" + ); + builder.append_slice(right_indices.values()); + builder + }); + debug_assert_eq!( + right_unmatched_indices.null_count(), + 0, + "expected right unmatched indices to have no nulls" + ); + new_right_indices_builder.append_slice(right_unmatched_indices.values()); + let new_right_indices = UInt32Array::from(new_right_indices_builder.finish()); + + Ok((new_left_indices, new_right_indices)) } } } @@ -1635,6 +1673,91 @@ pub(crate) fn asymmetric_join_output_partitioning( } } +/// Trait for incrementally generating Join output. +/// +/// This trait is used to limit some join outputs +/// so it does not produce single large batches +pub(crate) trait BatchTransformer: Debug + Clone { + /// Sets the next `RecordBatch` to be processed. + fn set_batch(&mut self, batch: RecordBatch); + + /// Retrieves the next `RecordBatch` from the transformer. + /// Returns `None` if all batches have been produced. + /// The boolean flag indicates whether the batch is the last one. + fn next(&mut self) -> Option<(RecordBatch, bool)>; +} + +#[derive(Debug, Clone)] +/// A batch transformer that does nothing. +pub(crate) struct NoopBatchTransformer { + /// RecordBatch to be processed + batch: Option, +} + +impl NoopBatchTransformer { + pub fn new() -> Self { + Self { batch: None } + } +} + +impl BatchTransformer for NoopBatchTransformer { + fn set_batch(&mut self, batch: RecordBatch) { + self.batch = Some(batch); + } + + fn next(&mut self) -> Option<(RecordBatch, bool)> { + self.batch.take().map(|batch| (batch, true)) + } +} + +#[derive(Debug, Clone)] +/// Splits large batches into smaller batches with a maximum number of rows. +pub(crate) struct BatchSplitter { + /// RecordBatch to be split + batch: Option, + /// Maximum number of rows in a split batch + batch_size: usize, + /// Current row index + row_index: usize, +} + +impl BatchSplitter { + /// Creates a new `BatchSplitter` with the specified batch size. + pub(crate) fn new(batch_size: usize) -> Self { + Self { + batch: None, + batch_size, + row_index: 0, + } + } +} + +impl BatchTransformer for BatchSplitter { + fn set_batch(&mut self, batch: RecordBatch) { + self.batch = Some(batch); + self.row_index = 0; + } + + fn next(&mut self) -> Option<(RecordBatch, bool)> { + let Some(batch) = &self.batch else { + return None; + }; + + let remaining_rows = batch.num_rows() - self.row_index; + let rows_to_slice = remaining_rows.min(self.batch_size); + let sliced_batch = batch.slice(self.row_index, rows_to_slice); + self.row_index += rows_to_slice; + + let mut last = false; + if self.row_index >= batch.num_rows() { + self.batch = None; + last = true; + } + + Some((sliced_batch, last)) + } +} + #[cfg(test)] mod tests { use std::pin::Pin; @@ -1643,11 +1766,13 @@ mod tests { use arrow::datatypes::{DataType, Fields}; use arrow::error::{ArrowError, Result as ArrowResult}; + use arrow_array::Int32Array; use arrow_schema::SortOptions; - use datafusion_common::stats::Precision::{Absent, Exact, Inexact}; use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; + use rstest::rstest; + fn check( left: &[Column], right: &[Column], @@ -2554,4 +2679,49 @@ mod tests { Ok(()) } + + fn create_test_batch(num_rows: usize) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let data = Arc::new(Int32Array::from_iter_values(0..num_rows as i32)); + RecordBatch::try_new(schema, vec![data]).unwrap() + } + + fn assert_split_batches( + batches: Vec<(RecordBatch, bool)>, + batch_size: usize, + num_rows: usize, + ) { + let mut row_count = 0; + for (batch, last) in batches.into_iter() { + assert_eq!(batch.num_rows(), (num_rows - row_count).min(batch_size)); + let column = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + assert_eq!(column.value(i), i as i32 + row_count as i32); + } + row_count += batch.num_rows(); + assert_eq!(last, row_count == num_rows); + } + } + + #[rstest] + #[test] + fn test_batch_splitter( + #[values(1, 3, 11)] batch_size: usize, + #[values(1, 6, 50)] num_rows: usize, + ) { + let mut splitter = BatchSplitter::new(batch_size); + splitter.set_batch(create_test_batch(num_rows)); + + let mut batches = Vec::with_capacity(num_rows.div_ceil(batch_size)); + while let Some(batch) = splitter.next() { + batches.push(batch); + } + + assert!(splitter.next().is_none()); + assert_split_batches(batches, batch_size, num_rows); + } } diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 7acdf25b6596..57bf029a63c1 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -173,6 +173,7 @@ datafusion.execution.batch_size 8192 datafusion.execution.coalesce_batches true datafusion.execution.collect_statistics false datafusion.execution.enable_recursive_ctes true +datafusion.execution.enforce_batch_size_in_joins false datafusion.execution.keep_partition_by_columns false datafusion.execution.listing_table_ignore_subdirectory true datafusion.execution.max_buffered_batches_per_output_file 2 @@ -263,6 +264,7 @@ datafusion.execution.batch_size 8192 Default batch size while creating new batch datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting datafusion.execution.collect_statistics false Should DataFusion collect statistics after listing files datafusion.execution.enable_recursive_ctes true Should DataFusion support recursive CTEs +datafusion.execution.enforce_batch_size_in_joins false Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. datafusion.execution.keep_partition_by_columns false Should DataFusion keep the columns used for partition_by in the output RecordBatches datafusion.execution.listing_table_ignore_subdirectory true Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index f34d148f092f..c61a7b673334 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -91,6 +91,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.skip_partial_aggregation_probe_ratio_threshold | 0.8 | Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. If the value is greater then partial aggregation will skip aggregation for further input | | datafusion.execution.skip_partial_aggregation_probe_rows_threshold | 100000 | Number of input rows partial aggregation partition should process, before aggregation ratio check and trying to switch to skipping aggregation mode | | datafusion.execution.use_row_number_estimates_to_optimize_partitioning | false | Should DataFusion use row number estimates at the input to decide whether increasing parallelism is beneficial or not. By default, only exact row numbers (not estimates) are used for this decision. Setting this flag to `true` will likely produce better plans. if the source of statistics is accurate. We plan to make this the default in the future. | +| datafusion.execution.enforce_batch_size_in_joins | false | Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. | | datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | | datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | | datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | From e9435a920ed84a1956b23e7ab6d13fe833cce3eb Mon Sep 17 00:00:00 2001 From: yi wang <48236141+my-vegetable-has-exploded@users.noreply.github.com> Date: Sat, 19 Oct 2024 00:52:23 +0800 Subject: [PATCH 04/17] =?UTF-8?q?Fix=EF=BC=9Afix=20HashJoin=20projection?= =?UTF-8?q?=20swap=20(#12967)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * swap_hash_join works with joins with projections * use non swapped hash join's projection * clean up * fix hashjoin projection swap. * assert hashjoinexec. * Update datafusion/core/src/physical_optimizer/join_selection.rs Co-authored-by: Eduard Karacharov * fix clippy. --------- Co-authored-by: Onur Satici Co-authored-by: Eduard Karacharov --- .../src/physical_optimizer/join_selection.rs | 31 ++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 499fb9cbbcf0..dfaa7dbb8910 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -183,13 +183,15 @@ pub fn swap_hash_join( partition_mode, hash_join.null_equals_null(), )?; + // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again if matches!( hash_join.join_type(), JoinType::LeftSemi | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti - ) { + ) || hash_join.projection.is_some() + { Ok(Arc::new(new_join)) } else { // TODO avoid adding ProjectionExec again and again, only adding Final Projection @@ -1287,6 +1289,33 @@ mod tests_statistical { ); } + #[tokio::test] + async fn test_hash_join_swap_on_joins_with_projections() -> Result<()> { + let (big, small) = create_big_and_small(); + let join = Arc::new(HashJoinExec::try_new( + Arc::clone(&big), + Arc::clone(&small), + vec![( + Arc::new(Column::new_with_schema("big_col", &big.schema())?), + Arc::new(Column::new_with_schema("small_col", &small.schema())?), + )], + None, + &JoinType::Inner, + Some(vec![1]), + PartitionMode::Partitioned, + false, + )?); + let swapped = swap_hash_join(&join.clone(), PartitionMode::Partitioned) + .expect("swap_hash_join must support joins with projections"); + let swapped_join = swapped.as_any().downcast_ref::().expect( + "ProjectionExec won't be added above if HashJoinExec contains embedded projection", + ); + assert_eq!(swapped_join.projection, Some(vec![0_usize])); + assert_eq!(swapped.schema().fields.len(), 1); + assert_eq!(swapped.schema().fields[0].name(), "small_col"); + Ok(()) + } + #[tokio::test] async fn test_swap_reverting_projection() { let left_schema = Schema::new(vec![ From 97f7491ed62ed7643b8b466237fd1ceb19a54431 Mon Sep 17 00:00:00 2001 From: Tornike Gurgenidze Date: Fri, 18 Oct 2024 23:06:45 +0400 Subject: [PATCH 05/17] refactor(substrait): refactor ReadRel consumer (#12983) --- .../substrait/src/logical_plan/consumer.rs | 181 +++++++++--------- 1 file changed, 87 insertions(+), 94 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 4af02858e65a..08e54166d39a 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -794,60 +794,61 @@ pub async fn from_substrait_rel( let (left, right) = requalify_sides_if_needed(left, right)?; left.cross_join(right.build()?)?.build() } - Some(RelType::Read(read)) => match &read.as_ref().read_type { - Some(ReadType::NamedTable(nt)) => { - let named_struct = read.base_schema.as_ref().ok_or_else(|| { - substrait_datafusion_err!("No base schema provided for Named Table") - })?; + Some(RelType::Read(read)) => { + fn read_with_schema( + df: DataFrame, + schema: DFSchema, + projection: &Option, + ) -> Result { + ensure_schema_compatability(df.schema().to_owned(), schema.clone())?; - let table_reference = match nt.names.len() { - 0 => { - return plan_err!("No table name found in NamedTable"); - } - 1 => TableReference::Bare { - table: nt.names[0].clone().into(), - }, - 2 => TableReference::Partial { - schema: nt.names[0].clone().into(), - table: nt.names[1].clone().into(), - }, - _ => TableReference::Full { - catalog: nt.names[0].clone().into(), - schema: nt.names[1].clone().into(), - table: nt.names[2].clone().into(), - }, - }; + let schema = apply_masking(schema, projection)?; - let t = ctx.table(table_reference.clone()).await?; + apply_projection(df, schema) + } - let substrait_schema = - from_substrait_named_struct(named_struct, extensions)? - .replace_qualifier(table_reference); + let named_struct = read.base_schema.as_ref().ok_or_else(|| { + substrait_datafusion_err!("No base schema provided for Read Relation") + })?; - ensure_schema_compatability( - t.schema().to_owned(), - substrait_schema.clone(), - )?; + let substrait_schema = from_substrait_named_struct(named_struct, extensions)?; - let substrait_schema = apply_masking(substrait_schema, &read.projection)?; + match &read.as_ref().read_type { + Some(ReadType::NamedTable(nt)) => { + let table_reference = match nt.names.len() { + 0 => { + return plan_err!("No table name found in NamedTable"); + } + 1 => TableReference::Bare { + table: nt.names[0].clone().into(), + }, + 2 => TableReference::Partial { + schema: nt.names[0].clone().into(), + table: nt.names[1].clone().into(), + }, + _ => TableReference::Full { + catalog: nt.names[0].clone().into(), + schema: nt.names[1].clone().into(), + table: nt.names[2].clone().into(), + }, + }; - apply_projection(t, substrait_schema) - } - Some(ReadType::VirtualTable(vt)) => { - let base_schema = read.base_schema.as_ref().ok_or_else(|| { - substrait_datafusion_err!("No base schema provided for Virtual Table") - })?; + let t = ctx.table(table_reference.clone()).await?; - let schema = from_substrait_named_struct(base_schema, extensions)?; + let substrait_schema = + substrait_schema.replace_qualifier(table_reference); - if vt.values.is_empty() { - return Ok(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: DFSchemaRef::new(schema), - })); + read_with_schema(t, substrait_schema, &read.projection) } + Some(ReadType::VirtualTable(vt)) => { + if vt.values.is_empty() { + return Ok(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: DFSchemaRef::new(substrait_schema), + })); + } - let values = vt + let values = vt .values .iter() .map(|row| { @@ -860,79 +861,71 @@ pub async fn from_substrait_rel( Ok(Expr::Literal(from_substrait_literal( lit, extensions, - &base_schema.names, + &named_struct.names, &mut name_idx, )?)) }) .collect::>()?; - if name_idx != base_schema.names.len() { + if name_idx != named_struct.names.len() { return substrait_err!( "Names list must match exactly to nested schema, but found {} uses for {} names", name_idx, - base_schema.names.len() + named_struct.names.len() ); } Ok(lits) }) .collect::>()?; - Ok(LogicalPlan::Values(Values { - schema: DFSchemaRef::new(schema), - values, - })) - } - Some(ReadType::LocalFiles(lf)) => { - let named_struct = read.base_schema.as_ref().ok_or_else(|| { - substrait_datafusion_err!("No base schema provided for LocalFiles") - })?; - - fn extract_filename(name: &str) -> Option { - let corrected_url = - if name.starts_with("file://") && !name.starts_with("file:///") { + Ok(LogicalPlan::Values(Values { + schema: DFSchemaRef::new(substrait_schema), + values, + })) + } + Some(ReadType::LocalFiles(lf)) => { + fn extract_filename(name: &str) -> Option { + let corrected_url = if name.starts_with("file://") + && !name.starts_with("file:///") + { name.replacen("file://", "file:///", 1) } else { name.to_string() }; - Url::parse(&corrected_url).ok().and_then(|url| { - let path = url.path(); - std::path::Path::new(path) - .file_name() - .map(|filename| filename.to_string_lossy().to_string()) - }) - } - - // we could use the file name to check the original table provider - // TODO: currently does not support multiple local files - let filename: Option = - lf.items.first().and_then(|x| match x.path_type.as_ref() { - Some(UriFile(name)) => extract_filename(name), - _ => None, - }); - - if lf.items.len() > 1 || filename.is_none() { - return not_impl_err!("Only single file reads are supported"); - } - let name = filename.unwrap(); - // directly use unwrap here since we could determine it is a valid one - let table_reference = TableReference::Bare { table: name.into() }; - let t = ctx.table(table_reference.clone()).await?; + Url::parse(&corrected_url).ok().and_then(|url| { + let path = url.path(); + std::path::Path::new(path) + .file_name() + .map(|filename| filename.to_string_lossy().to_string()) + }) + } - let substrait_schema = - from_substrait_named_struct(named_struct, extensions)? - .replace_qualifier(table_reference); + // we could use the file name to check the original table provider + // TODO: currently does not support multiple local files + let filename: Option = + lf.items.first().and_then(|x| match x.path_type.as_ref() { + Some(UriFile(name)) => extract_filename(name), + _ => None, + }); - ensure_schema_compatability( - t.schema().to_owned(), - substrait_schema.clone(), - )?; + if lf.items.len() > 1 || filename.is_none() { + return not_impl_err!("Only single file reads are supported"); + } + let name = filename.unwrap(); + // directly use unwrap here since we could determine it is a valid one + let table_reference = TableReference::Bare { table: name.into() }; + let t = ctx.table(table_reference.clone()).await?; - let substrait_schema = apply_masking(substrait_schema, &read.projection)?; + let substrait_schema = + substrait_schema.replace_qualifier(table_reference); - apply_projection(t, substrait_schema) + read_with_schema(t, substrait_schema, &read.projection) + } + _ => { + not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type) + } } - _ => not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type), - }, + } Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) { Ok(set_op) => match set_op { set_rel::SetOp::UnionAll => { From 42f906072a3000d005b8ced97654aaec2828a878 Mon Sep 17 00:00:00 2001 From: Tornike Gurgenidze Date: Fri, 18 Oct 2024 23:06:58 +0400 Subject: [PATCH 06/17] feat(substrait): add wildcard handling to producer (#12987) * feat(substrait): add wildcard expand rule in producer * add comment describing need for ExpandWildcardRule --- .../substrait/src/logical_plan/producer.rs | 10 +++++- .../tests/cases/roundtrip_logical_plan.rs | 34 ++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 0e1375a8e0ea..7504a287c055 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +use datafusion::config::ConfigOptions; +use datafusion::optimizer::analyzer::expand_wildcard_rule::ExpandWildcardRule; +use datafusion::optimizer::AnalyzerRule; use std::sync::Arc; use substrait::proto::expression_reference::ExprType; @@ -103,9 +106,14 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result Result<()> { #[tokio::test] async fn wildcard_select() -> Result<()> { - roundtrip("SELECT * FROM data").await + assert_expected_plan_unoptimized( + "SELECT * FROM data", + "Projection: data.a, data.b, data.c, data.d, data.e, data.f\ + \n TableScan: data", + true, + ) + .await } #[tokio::test] @@ -1174,6 +1180,32 @@ async fn verify_post_join_filter_value(proto: Box) -> Result<()> { Ok(()) } +async fn assert_expected_plan_unoptimized( + sql: &str, + expected_plan_str: &str, + assert_schema: bool, +) -> Result<()> { + let ctx = create_context().await?; + let df = ctx.sql(sql).await?; + let plan = df.into_unoptimized_plan(); + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + + println!("{plan}"); + println!("{plan2}"); + + println!("{proto:?}"); + + if assert_schema { + assert_eq!(plan.schema(), plan2.schema()); + } + + let plan2str = format!("{plan2}"); + assert_eq!(expected_plan_str, &plan2str); + + Ok(()) +} + async fn assert_expected_plan( sql: &str, expected_plan_str: &str, From 3405234836be98860ce1516ed2263c163ada5535 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Fri, 18 Oct 2024 12:26:48 -0700 Subject: [PATCH 07/17] Move SMJ join filtered part out of join_output stage. LeftOuter, LeftSemi (#12764) * WIP: move filtered join out of join_output stage * WIP: move filtered join out of join_output stage * WIP: move filtered join out of join_output stage * cleanup * cleanup * Move Left/LeftAnti filtered SMJ join out of join partial stage * Move Left/LeftAnti filtered SMJ join out of join partial stage * Address comments --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 12 +- .../src/joins/sort_merge_join.rs | 1095 ++++++++++++----- .../test_files/sort_merge_join.slt | 478 +++---- 3 files changed, 1061 insertions(+), 524 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 96aa1be181f5..2eab45256dbb 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -125,8 +125,6 @@ async fn test_left_join_1k() { } #[tokio::test] -// flaky for HjSmj case -// https://github.com/apache/datafusion/issues/12359 async fn test_left_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -134,7 +132,7 @@ async fn test_left_join_1k_filtered() { JoinType::Left, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -229,6 +227,7 @@ async fn test_anti_join_1k() { #[tokio::test] // flaky for HjSmj case, giving 1 rows difference sometimes // https://github.com/apache/datafusion/issues/11555 +#[ignore] async fn test_anti_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -515,14 +514,11 @@ impl JoinFuzzTestCase { "input2", ); - if join_tests.contains(&JoinTestType::NljHj) - && join_tests.contains(&JoinTestType::NljHj) - && nlj_rows != hj_rows - { + if join_tests.contains(&JoinTestType::NljHj) && nlj_rows != hj_rows { println!("=============== HashJoinExec =================="); hj_formatted_sorted.iter().for_each(|s| println!("{}", s)); println!("=============== NestedLoopJoinExec =================="); - smj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + nlj_formatted_sorted.iter().for_each(|s| println!("{}", s)); Self::save_partitioned_batches_as_parquet( &nlj_collected, diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 2118c1a5266f..5e77becd1c5e 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -29,18 +29,17 @@ use std::io::BufReader; use std::mem; use std::ops::Range; use std::pin::Pin; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; use std::task::{Context, Poll}; use arrow::array::*; -use arrow::compute::{self, concat_batches, take, SortOptions}; +use arrow::compute::{self, concat_batches, filter_record_batch, take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; use arrow::ipc::reader::FileReader; use arrow_array::types::UInt64Type; -use futures::{Stream, StreamExt}; -use hashbrown::HashSet; - use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, @@ -52,6 +51,8 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; use datafusion_physical_expr_common::sort_expr::LexRequirement; +use futures::{Stream, StreamExt}; +use hashbrown::HashSet; use crate::expressions::PhysicalSortExpr; use crate::joins::utils::{ @@ -687,7 +688,7 @@ struct SMJStream { /// optional join filter pub filter: Option, /// Staging output array builders - pub output_record_batches: Vec, + pub output_record_batches: JoinedRecordBatches, /// Staging output size, including output batches and staging joined results. /// Increased when we put rows into buffer and decreased after we actually output batches. /// Used to trigger output when sufficient rows are ready @@ -702,6 +703,22 @@ struct SMJStream { pub reservation: MemoryReservation, /// Runtime env pub runtime_env: Arc, + /// A unique number for each batch + pub streamed_batch_counter: AtomicUsize, +} + +/// Joined batches with attached join filter information +struct JoinedRecordBatches { + /// Joined batches. Each batch is already joined columns from left and right sources + pub batches: Vec, + /// Filter match mask for each row(matched/non-matched) + pub filter_mask: BooleanBuilder, + /// Row indices to glue together rows in `batches` and `filter_mask` + pub row_indices: UInt64Builder, + /// Which unique batch id the row belongs to + /// It is necessary to differentiate rows that are distributed the way when they point to the same + /// row index but in not the same batches + pub batch_ids: Vec, } impl RecordBatchStream for SMJStream { @@ -710,6 +727,82 @@ impl RecordBatchStream for SMJStream { } } +#[inline(always)] +fn last_index_for_row( + row_index: usize, + indices: &UInt64Array, + ids: &[usize], + indices_len: usize, +) -> bool { + row_index == indices_len - 1 + || ids[row_index] != ids[row_index + 1] + || indices.value(row_index) != indices.value(row_index + 1) +} + +// Returns a corrected boolean bitmask for the given join type +// Values in the corrected bitmask can be: true, false, null +// `true` - the row found its match and sent to the output +// `null` - the row ignored, no output +// `false` - the row sent as NULL joined row +fn get_corrected_filter_mask( + join_type: JoinType, + indices: &UInt64Array, + ids: &[usize], + filter_mask: &BooleanArray, + expected_size: usize, +) -> Option { + let streamed_indices_length = indices.len(); + let mut corrected_mask: BooleanBuilder = + BooleanBuilder::with_capacity(streamed_indices_length); + let mut seen_true = false; + + match join_type { + JoinType::Left => { + for i in 0..streamed_indices_length { + let last_index = + last_index_for_row(i, indices, ids, streamed_indices_length); + if filter_mask.value(i) { + seen_true = true; + corrected_mask.append_value(true); + } else if seen_true || !filter_mask.value(i) && !last_index { + corrected_mask.append_null(); // to be ignored and not set to output + } else { + corrected_mask.append_value(false); // to be converted to null joined row + } + + if last_index { + seen_true = false; + } + } + + // Generate null joined rows for records which have no matching join key + let null_matched = expected_size - corrected_mask.len(); + corrected_mask.extend(vec![Some(false); null_matched]); + Some(corrected_mask.finish()) + } + JoinType::LeftSemi => { + for i in 0..streamed_indices_length { + let last_index = + last_index_for_row(i, indices, ids, streamed_indices_length); + if filter_mask.value(i) && !seen_true { + seen_true = true; + corrected_mask.append_value(true); + } else { + corrected_mask.append_null(); // to be ignored and not set to output + } + + if last_index { + seen_true = false; + } + } + + Some(corrected_mask.finish()) + } + // Only outer joins needs to keep track of processed rows and apply corrected filter mask + _ => None, + } +} + impl Stream for SMJStream { type Item = Result; @@ -719,7 +812,6 @@ impl Stream for SMJStream { ) -> Poll> { let join_time = self.join_metrics.join_time.clone(); let _timer = join_time.timer(); - loop { match &self.state { SMJState::Init => { @@ -733,6 +825,22 @@ impl Stream for SMJStream { match self.current_ordering { Ordering::Less | Ordering::Equal => { if !streamed_exhausted { + if self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left | JoinType::LeftSemi + ) + { + self.freeze_all()?; + + if !self.output_record_batches.batches.is_empty() + && self.buffered_data.scanning_finished() + { + let out_batch = self.filter_joined_batch()?; + return Poll::Ready(Some(Ok(out_batch))); + } + } + self.streamed_joined = false; self.streamed_state = StreamedState::Init; } @@ -786,8 +894,23 @@ impl Stream for SMJStream { } } else { self.freeze_all()?; - if !self.output_record_batches.is_empty() { + if !self.output_record_batches.batches.is_empty() { let record_batch = self.output_record_batch_and_reset()?; + // For non-filtered join output whenever the target output batch size + // is hit. For filtered join its needed to output on later phase + // because target output batch size can be hit in the middle of + // filtering causing the filtering to be incomplete and causing + // correctness issues + let record_batch = if !(self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left | JoinType::LeftSemi + )) { + record_batch + } else { + continue; + }; + return Poll::Ready(Some(Ok(record_batch))); } return Poll::Pending; @@ -795,11 +918,23 @@ impl Stream for SMJStream { } SMJState::Exhausted => { self.freeze_all()?; - if !self.output_record_batches.is_empty() { - let record_batch = self.output_record_batch_and_reset()?; - return Poll::Ready(Some(Ok(record_batch))); + + if !self.output_record_batches.batches.is_empty() { + if self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left | JoinType::LeftSemi + ) + { + let out = self.filter_joined_batch()?; + return Poll::Ready(Some(Ok(out))); + } else { + let record_batch = self.output_record_batch_and_reset()?; + return Poll::Ready(Some(Ok(record_batch))); + } + } else { + return Poll::Ready(None); } - return Poll::Ready(None); } } } @@ -844,13 +979,19 @@ impl SMJStream { on_streamed, on_buffered, filter, - output_record_batches: vec![], + output_record_batches: JoinedRecordBatches { + batches: vec![], + filter_mask: BooleanBuilder::new(), + row_indices: UInt64Builder::new(), + batch_ids: vec![], + }, output_size: 0, batch_size, join_type, join_metrics, reservation, runtime_env, + streamed_batch_counter: AtomicUsize::new(0), }) } @@ -882,6 +1023,10 @@ impl SMJStream { self.join_metrics.input_rows.add(batch.num_rows()); self.streamed_batch = StreamedBatch::new(batch, &self.on_streamed); + // Every incoming streaming batch should have its unique id + // Check `JoinedRecordBatches.self.streamed_batch_counter` documentation + self.streamed_batch_counter + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); self.streamed_state = StreamedState::Ready; } } @@ -1062,14 +1207,14 @@ impl SMJStream { return Ok(Ordering::Less); } - return compare_join_arrays( + compare_join_arrays( &self.streamed_batch.join_arrays, self.streamed_batch.idx, &self.buffered_data.head_batch().join_arrays, self.buffered_data.head_batch().range.start, &self.sort_options, self.null_equals_null, - ); + ) } /// Produce join and fill output buffer until reaching target batch size @@ -1228,7 +1373,7 @@ impl SMJStream { &buffered_indices, buffered_batch, )? { - self.output_record_batches.push(record_batch); + self.output_record_batches.batches.push(record_batch); } buffered_batch.null_joined.clear(); @@ -1251,7 +1396,7 @@ impl SMJStream { &buffered_indices, buffered_batch, )? { - self.output_record_batches.push(record_batch); + self.output_record_batches.batches.push(record_batch); } buffered_batch.join_filter_failed_map.clear(); } @@ -1329,15 +1474,14 @@ impl SMJStream { }; let columns = if matches!(self.join_type, JoinType::Right) { - buffered_columns.extend(streamed_columns.clone()); + buffered_columns.extend(streamed_columns); buffered_columns } else { streamed_columns.extend(buffered_columns); streamed_columns }; - let output_batch = - RecordBatch::try_new(Arc::clone(&self.schema), columns.clone())?; + let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?; // Apply join filter if any if !filter_columns.is_empty() { @@ -1367,59 +1511,46 @@ impl SMJStream { pre_mask.clone() }; - // For certain join types, we need to adjust the initial mask to handle the join filter. - let maybe_filtered_join_mask: Option<(BooleanArray, Vec)> = - get_filtered_join_mask( - self.join_type, - &streamed_indices, - &mask, - &self.streamed_batch.join_filter_matched_idxs, - &self.buffered_data.scanning_offset, - ); - - let mask = - if let Some(ref filtered_join_mask) = maybe_filtered_join_mask { - self.streamed_batch - .join_filter_matched_idxs - .extend(&filtered_join_mask.1); - &filtered_join_mask.0 - } else { - &mask - }; - // Push the filtered batch which contains rows passing join filter to the output - let filtered_batch = - compute::filter_record_batch(&output_batch, mask)?; - self.output_record_batches.push(filtered_batch); + if matches!(self.join_type, JoinType::Left | JoinType::LeftSemi) { + self.output_record_batches + .batches + .push(output_batch.clone()); + } else { + let filtered_batch = filter_record_batch(&output_batch, &mask)?; + self.output_record_batches.batches.push(filtered_batch); + } + + self.output_record_batches.filter_mask.extend(&mask); + self.output_record_batches + .row_indices + .extend(&streamed_indices); + self.output_record_batches.batch_ids.extend(vec![ + self.streamed_batch_counter.load(Relaxed); + streamed_indices.len() + ]); // For outer joins, we need to push the null joined rows to the output if // all joined rows are failed on the join filter. // I.e., if all rows joined from a streamed row are failed with the join filter, // we need to join it with nulls as buffered side. - if matches!( - self.join_type, - JoinType::Left | JoinType::Right | JoinType::Full - ) { + if matches!(self.join_type, JoinType::Right | JoinType::Full) { // We need to get the mask for row indices that the joined rows are failed // on the join filter. I.e., for a row in streamed side, if all joined rows // between it and all buffered rows are failed on the join filter, we need to // output it with null columns from buffered side. For the mask here, it // behaves like LeftAnti join. - let null_mask: BooleanArray = get_filtered_join_mask( - // Set a mask slot as true only if all joined rows of same streamed index - // are failed on the join filter. - // The masking behavior is like LeftAnti join. - JoinType::LeftAnti, - &streamed_indices, - mask, - &self.streamed_batch.join_filter_matched_idxs, - &self.buffered_data.scanning_offset, - ) - .unwrap() - .0; + let not_mask = if mask.null_count() > 0 { + // If the mask contains nulls, we need to use `prep_null_mask_filter` to + // handle the nulls in the mask as false to produce rows where the mask + // was null itself. + compute::not(&compute::prep_null_mask_filter(&mask))? + } else { + compute::not(&mask)? + }; let null_joined_batch = - compute::filter_record_batch(&output_batch, &null_mask)?; + filter_record_batch(&output_batch, ¬_mask)?; let mut buffered_columns = self .buffered_schema @@ -1457,11 +1588,11 @@ impl SMJStream { }; // Push the streamed/buffered batch joined nulls to the output - let null_joined_streamed_batch = RecordBatch::try_new( - Arc::clone(&self.schema), - columns.clone(), - )?; - self.output_record_batches.push(null_joined_streamed_batch); + let null_joined_streamed_batch = + RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + self.output_record_batches + .batches + .push(null_joined_streamed_batch); // For full join, we also need to output the null joined rows from the buffered side. // Usually this is done by `freeze_buffered`. However, if a buffered row is joined with @@ -1494,10 +1625,10 @@ impl SMJStream { } } } else { - self.output_record_batches.push(output_batch); + self.output_record_batches.batches.push(output_batch); } } else { - self.output_record_batches.push(output_batch); + self.output_record_batches.batches.push(output_batch); } } @@ -1507,7 +1638,8 @@ impl SMJStream { } fn output_record_batch_and_reset(&mut self) -> Result { - let record_batch = concat_batches(&self.schema, &self.output_record_batches)?; + let record_batch = + concat_batches(&self.schema, &self.output_record_batches.batches)?; self.join_metrics.output_batches.add(1); self.join_metrics.output_rows.add(record_batch.num_rows()); // If join filter exists, `self.output_size` is not accurate as we don't know the exact @@ -1520,9 +1652,92 @@ impl SMJStream { } else { self.output_size -= record_batch.num_rows(); } - self.output_record_batches.clear(); + + if !(self.filter.is_some() + && matches!(self.join_type, JoinType::Left | JoinType::LeftSemi)) + { + self.output_record_batches.batches.clear(); + } Ok(record_batch) } + + fn filter_joined_batch(&mut self) -> Result { + let record_batch = self.output_record_batch_and_reset()?; + let out_indices = self.output_record_batches.row_indices.finish(); + let out_mask = self.output_record_batches.filter_mask.finish(); + let maybe_corrected_mask = get_corrected_filter_mask( + self.join_type, + &out_indices, + &self.output_record_batches.batch_ids, + &out_mask, + record_batch.num_rows(), + ); + + let corrected_mask = if let Some(ref filtered_join_mask) = maybe_corrected_mask { + filtered_join_mask + } else { + &out_mask + }; + + let mut filtered_record_batch = + filter_record_batch(&record_batch, corrected_mask)?; + let buffered_columns_length = self.buffered_schema.fields.len(); + let streamed_columns_length = self.streamed_schema.fields.len(); + + if matches!(self.join_type, JoinType::Left | JoinType::Right) { + let null_mask = compute::not(corrected_mask)?; + let null_joined_batch = filter_record_batch(&record_batch, &null_mask)?; + + let mut buffered_columns = self + .buffered_schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), null_joined_batch.num_rows())) + .collect::>(); + + let columns = if matches!(self.join_type, JoinType::Right) { + let streamed_columns = null_joined_batch + .columns() + .iter() + .skip(buffered_columns_length) + .cloned() + .collect::>(); + + buffered_columns.extend(streamed_columns); + buffered_columns + } else { + // Left join or full outer join + let mut streamed_columns = null_joined_batch + .columns() + .iter() + .take(streamed_columns_length) + .cloned() + .collect::>(); + + streamed_columns.extend(buffered_columns); + streamed_columns + }; + + // Push the streamed/buffered batch joined nulls to the output + let null_joined_streamed_batch = + RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + + filtered_record_batch = concat_batches( + &self.schema, + &[filtered_record_batch, null_joined_streamed_batch], + )?; + } else if matches!(self.join_type, JoinType::LeftSemi) { + let output_column_indices = (0..streamed_columns_length).collect::>(); + filtered_record_batch = + filtered_record_batch.project(&output_column_indices)?; + } + + self.output_record_batches.batches.clear(); + self.output_record_batches.batch_ids = vec![]; + self.output_record_batches.filter_mask = BooleanBuilder::new(); + self.output_record_batches.row_indices = UInt64Builder::new(); + Ok(filtered_record_batch) + } } /// Gets the arrays which join filters are applied on. @@ -1631,101 +1846,6 @@ fn get_buffered_columns_from_batch( } } -/// Calculate join filter bit mask considering join type specifics -/// `streamed_indices` - array of streamed datasource JOINED row indices -/// `mask` - array booleans representing computed join filter expression eval result: -/// true = the row index matches the join filter -/// false = the row index doesn't match the join filter -/// `streamed_indices` have the same length as `mask` -/// `matched_indices` array of streaming indices that already has a join filter match -/// `scanning_buffered_offset` current buffered offset across batches -/// -/// This return a tuple of: -/// - corrected mask with respect to the join type -/// - indices of rows in streamed batch that have a join filter match -fn get_filtered_join_mask( - join_type: JoinType, - streamed_indices: &UInt64Array, - mask: &BooleanArray, - matched_indices: &HashSet, - scanning_buffered_offset: &usize, -) -> Option<(BooleanArray, Vec)> { - let mut seen_as_true: bool = false; - let streamed_indices_length = streamed_indices.len(); - let mut corrected_mask: BooleanBuilder = - BooleanBuilder::with_capacity(streamed_indices_length); - - let mut filter_matched_indices: Vec = vec![]; - - #[allow(clippy::needless_range_loop)] - match join_type { - // for LeftSemi Join the filter mask should be calculated in its own way: - // if we find at least one matching row for specific streaming index - // we don't need to check any others for the same index - JoinType::LeftSemi => { - // have we seen a filter match for a streaming index before - for i in 0..streamed_indices_length { - // LeftSemi respects only first true values for specific streaming index, - // others true values for the same index must be false - let streamed_idx = streamed_indices.value(i); - if mask.value(i) - && !seen_as_true - && !matched_indices.contains(&streamed_idx) - { - seen_as_true = true; - corrected_mask.append_value(true); - filter_matched_indices.push(streamed_idx); - } else { - corrected_mask.append_value(false); - } - - // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag - if i < streamed_indices_length - 1 - && streamed_idx != streamed_indices.value(i + 1) - { - seen_as_true = false; - } - } - Some((corrected_mask.finish(), filter_matched_indices)) - } - // LeftAnti semantics: return true if for every x in the collection the join matching filter is false. - // `filter_matched_indices` needs to be set once per streaming index - // to prevent duplicates in the output - JoinType::LeftAnti => { - // have we seen a filter match for a streaming index before - for i in 0..streamed_indices_length { - let streamed_idx = streamed_indices.value(i); - if mask.value(i) - && !seen_as_true - && !matched_indices.contains(&streamed_idx) - { - seen_as_true = true; - filter_matched_indices.push(streamed_idx); - } - - // Reset `seen_as_true` flag and calculate mask for the current streaming index - // - if within the batch it switched to next streaming index(e.g. from 0 to 1, or from 1 to 2) - // - if it is at the end of the all buffered batches for the given streaming index, 0 index comes last - if (i < streamed_indices_length - 1 - && streamed_idx != streamed_indices.value(i + 1)) - || (i == streamed_indices_length - 1 - && *scanning_buffered_offset == 0) - { - corrected_mask.append_value( - !matched_indices.contains(&streamed_idx) && !seen_as_true, - ); - seen_as_true = false; - } else { - corrected_mask.append_value(false); - } - } - - Some((corrected_mask.finish(), filter_matched_indices)) - } - _ => None, - } -} - /// Buffered data contains all buffered batches with one unique join key #[derive(Debug, Default)] struct BufferedData { @@ -1966,13 +2086,13 @@ mod tests { use std::sync::Arc; use arrow::array::{Date32Array, Date64Array, Int32Array}; - use arrow::compute::SortOptions; + use arrow::compute::{concat_batches, filter_record_batch, SortOptions}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; + use arrow_array::builder::{BooleanBuilder, UInt64Builder}; use arrow_array::{BooleanArray, UInt64Array}; - use hashbrown::HashSet; - use datafusion_common::JoinType::{LeftAnti, LeftSemi}; + use datafusion_common::JoinType::*; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, }; @@ -1982,7 +2102,7 @@ mod tests { use datafusion_execution::TaskContext; use crate::expressions::Column; - use crate::joins::sort_merge_join::get_filtered_join_mask; + use crate::joins::sort_merge_join::{get_corrected_filter_mask, JoinedRecordBatches}; use crate::joins::utils::JoinOn; use crate::joins::SortMergeJoinExec; use crate::memory::MemoryExec; @@ -3214,170 +3334,573 @@ mod tests { } #[tokio::test] - async fn left_semi_join_filtered_mask() -> Result<()> { + async fn test_left_outer_join_filtered_mask() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + ])); + + let mut tb = JoinedRecordBatches { + batches: vec![], + filter_mask: BooleanBuilder::new(), + row_indices: UInt64Builder::new(), + batch_ids: vec![], + }; + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![10, 10])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![11, 9])), + ], + )?); + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![11])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![12])), + ], + )?); + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![12, 12])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![11, 13])), + ], + )?); + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![13])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![12])), + ], + )?); + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![14, 14])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![12, 11])), + ], + )?); + + let streamed_indices = vec![0, 0]; + tb.batch_ids.extend(vec![0; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![1]; + tb.batch_ids.extend(vec![0; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0, 0]; + tb.batch_ids.extend(vec![1; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0]; + tb.batch_ids.extend(vec![2; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0, 0]; + tb.batch_ids.extend(vec![3; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + tb.filter_mask + .extend(&BooleanArray::from(vec![true, false])); + tb.filter_mask.extend(&BooleanArray::from(vec![true])); + tb.filter_mask + .extend(&BooleanArray::from(vec![false, true])); + tb.filter_mask.extend(&BooleanArray::from(vec![false])); + tb.filter_mask + .extend(&BooleanArray::from(vec![false, false])); + + let output = concat_batches(&schema, &tb.batches)?; + let out_mask = tb.filter_mask.finish(); + let out_indices = tb.row_indices.finish(); + assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 1, 1]), - &BooleanArray::from(vec![true, true, false, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, false, false, false]), vec![0])) + get_corrected_filter_mask( + JoinType::Left, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + true, false, false, false, false, false, false, false + ]) ); assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 1]), + get_corrected_filter_mask( + JoinType::Left, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + false, false, false, false, false, false, false, false + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + JoinType::Left, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], &BooleanArray::from(vec![true, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, true]), vec![0, 1])) + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + true, true, false, false, false, false, false, false + ]) ); assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![false, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, true]), vec![1])) + get_corrected_filter_mask( + JoinType::Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true, true, true, false, false, false, false, false]) ); assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![true, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, false]), vec![0])) + get_corrected_filter_mask( + JoinType::Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) ); assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, true, true, true, true, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, true, false, true, false, false]), - vec![0, 1] - )) + get_corrected_filter_mask( + JoinType::Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + None, + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) ); assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, false, false, false, false, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, false, false, false, false, true]), - vec![1] - )) + get_corrected_filter_mask( + JoinType::Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + Some(true), + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) ); assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![true, false, false, false, false, true]), - &HashSet::from_iter(vec![1]), - &0, - ), - Some(( - BooleanArray::from(vec![true, false, false, false, false, false]), - vec![0] - )) + get_corrected_filter_mask( + JoinType::Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + None, + Some(false), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) + ); + + let corrected_mask = get_corrected_filter_mask( + JoinType::Left, + &out_indices, + &tb.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + None, + Some(true), + Some(false), + None, + Some(false) + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 10 | 1 | 11 |", + "| 1 | 11 | 1 | 12 |", + "| 1 | 12 | 1 | 13 |", + "+---+----+---+----+", + ], + &[filtered_rb] ); + // output null rows + + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + Some(false), + None, + Some(false), + None, + Some(false), + Some(true), + None, + Some(true) + ]) + ); + + let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 13 | 1 | 12 |", + "| 1 | 14 | 1 | 11 |", + "+---+----+---+----+", + ], + &[null_joined_batch] + ); Ok(()) } #[tokio::test] - async fn left_anti_join_filtered_mask() -> Result<()> { + async fn test_left_semi_join_filtered_mask() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + ])); + + let mut tb = JoinedRecordBatches { + batches: vec![], + filter_mask: BooleanBuilder::new(), + row_indices: UInt64Builder::new(), + batch_ids: vec![], + }; + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![10, 10])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![11, 9])), + ], + )?); + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![11])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![12])), + ], + )?); + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![12, 12])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![11, 13])), + ], + )?); + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![13])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![12])), + ], + )?); + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![14, 14])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![12, 11])), + ], + )?); + + let streamed_indices = vec![0, 0]; + tb.batch_ids.extend(vec![0; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![1]; + tb.batch_ids.extend(vec![0; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0, 0]; + tb.batch_ids.extend(vec![1; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0]; + tb.batch_ids.extend(vec![2; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0, 0]; + tb.batch_ids.extend(vec![3; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + tb.filter_mask + .extend(&BooleanArray::from(vec![true, false])); + tb.filter_mask.extend(&BooleanArray::from(vec![true])); + tb.filter_mask + .extend(&BooleanArray::from(vec![false, true])); + tb.filter_mask.extend(&BooleanArray::from(vec![false])); + tb.filter_mask + .extend(&BooleanArray::from(vec![false, false])); + + let output = concat_batches(&schema, &tb.batches)?; + let out_mask = tb.filter_mask.finish(); + let out_indices = tb.row_indices.finish(); + assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 0, 1, 1]), - &BooleanArray::from(vec![true, true, false, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, false, false, true]), vec![0])) + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None]) ); assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 1]), + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], &BooleanArray::from(vec![true, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, false]), vec![0, 1])) + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None]) ); assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![false, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, false]), vec![1])) + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None, None]) ); assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![true, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, true]), vec![0])) + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None, None]) ); assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, true, true, true, true, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, false, false, false, false, false]), - vec![0, 1] - )) + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, None, Some(true),]) ); assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, false, false, false, false, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, false, true, false, false, false]), - vec![1] - )) + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, Some(true), None]) ); + assert_eq!( + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + let corrected_mask = get_corrected_filter_mask( + LeftSemi, + &out_indices, + &tb.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + None, + Some(true), + None, + None, + None + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 10 | 1 | 11 |", + "| 1 | 11 | 1 | 12 |", + "| 1 | 12 | 1 | 13 |", + "+---+----+---+----+", + ], + &[filtered_rb] + ); + + // output null rows + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + Some(false), + None, + Some(false), + None, + Some(false), + None, + None, + None + ]) + ); + + let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + assert_batches_eq!( + &[ + "+---+---+---+---+", + "| a | b | x | y |", + "+---+---+---+---+", + "+---+---+---+---+", + ], + &[null_joined_batch] + ); Ok(()) } diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index ebd53e9690fc..d00b7d6f0a52 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -100,13 +100,14 @@ Alice 100 Alice 2 Alice 50 Alice 1 Alice 50 Alice 2 +# Uncomment when filtered RIGHT moved # right join with join filter -query TITI rowsort -SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b ----- -Alice 100 Alice 1 -Alice 100 Alice 2 -Alice 50 Alice 1 +#query TITI rowsort +#SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b +#---- +#Alice 100 Alice 1 +#Alice 100 Alice 2 +#Alice 50 Alice 1 query TITI rowsort SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t1.b > t2.b @@ -126,22 +127,24 @@ Alice 50 Alice 1 Alice 50 Alice 2 Bob 1 NULL NULL +# Uncomment when filtered FULL moved # full join with join filter -query TITI rowsort -SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b ----- -Alice 100 NULL NULL -Alice 50 Alice 2 -Bob 1 NULL NULL -NULL NULL Alice 1 - -query TITI rowsort -SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 ----- -Alice 100 Alice 1 -Alice 100 Alice 2 -Alice 50 NULL NULL -Bob 1 NULL NULL +#query TITI rowsort +#SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b +#---- +#Alice 100 NULL NULL +#Alice 50 Alice 2 +#Bob 1 NULL NULL +#NULL NULL Alice 1 + +# Uncomment when filtered RIGHT moved +#query TITI rowsort +#SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 +#---- +#Alice 100 Alice 1 +#Alice 100 Alice 2 +#Alice 50 NULL NULL +#Bob 1 NULL NULL statement ok DROP TABLE t1; @@ -405,221 +408,236 @@ select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != statement ok set datafusion.execution.batch_size = 10; -query II -select * from ( -with -t1 as ( - select 11 a, 12 b), -t2 as ( - select 11 a, 13 c union all - select 11 a, 14 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -) order by 1, 2 ----- -11 12 - -query III -select * from ( -with -t1 as ( - select 11 a, 12 b, 1 c union all - select 11 a, 13 b, 2 c), -t2 as ( - select 11 a, 12 b, 3 c union all - select 11 a, 14 b, 4 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) -) order by 1, 2; ----- -11 12 1 -11 13 2 - -query III -select * from ( -with -t1 as ( - select 11 a, 12 b, 1 c union all - select 11 a, 13 b, 2 c), -t2 as ( - select 11 a, 12 b, 3 c where false - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) -) order by 1, 2; ----- -11 12 1 -11 13 2 - -query II -select * from ( -with -t1 as ( - select 11 a, 12 b), -t2 as ( - select 11 a, 13 c union all - select 11 a, 14 c union all - select 11 a, 15 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -) order by 1, 2 ----- -11 12 - -query II -select * from ( -with -t1 as ( - select 11 a, 12 b), -t2 as ( - select 11 a, 11 c union all - select 11 a, 14 c union all - select 11 a, 15 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -) order by 1, 2 ----- - -query II -select * from ( -with -t1 as ( - select 11 a, 12 b), -t2 as ( - select 11 a, 12 c union all - select 11 a, 11 c union all - select 11 a, 15 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -) order by 1, 2 ----- - -query II -select * from ( -with -t1 as ( - select 11 a, 12 b), -t2 as ( - select 11 a, 12 c union all - select 11 a, 14 c union all - select 11 a, 11 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -) order by 1, 2 ----- +# Uncomment when filtered LEFTANTI moved +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b), +#t2 as ( +# select 11 a, 13 c union all +# select 11 a, 14 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +#) order by 1, 2 +#---- +#11 12 + +# Uncomment when filtered LEFTANTI moved +#query III +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b, 1 c union all +# select 11 a, 13 b, 2 c), +#t2 as ( +# select 11 a, 12 b, 3 c union all +# select 11 a, 14 b, 4 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +#) order by 1, 2; +#---- +#11 12 1 +#11 13 2 + +# Uncomment when filtered LEFTANTI moved +#query III +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b, 1 c union all +# select 11 a, 13 b, 2 c), +#t2 as ( +# select 11 a, 12 b, 3 c where false +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +#) order by 1, 2; +#---- +#11 12 1 +#11 13 2 + +# Uncomment when filtered LEFTANTI moved +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b), +#t2 as ( +# select 11 a, 13 c union all +# select 11 a, 14 c union all +# select 11 a, 15 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +#) order by 1, 2 +#---- +#11 12 + +# Uncomment when filtered LEFTANTI moved +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b), +#t2 as ( +# select 11 a, 11 c union all +# select 11 a, 14 c union all +# select 11 a, 15 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +#) order by 1, 2 +#---- + +# Uncomment when filtered LEFTANTI moved +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b), +#t2 as ( +# select 11 a, 12 c union all +# select 11 a, 11 c union all +# select 11 a, 15 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +#) order by 1, 2 +#---- + + +# Uncomment when filtered LEFTANTI moved +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b), +#t2 as ( +# select 11 a, 12 c union all +# select 11 a, 14 c union all +# select 11 a, 11 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +#) order by 1, 2 +#---- # Test LEFT ANTI with cross batch data distribution statement ok set datafusion.execution.batch_size = 1; -query II -select * from ( -with -t1 as ( - select 11 a, 12 b), -t2 as ( - select 11 a, 13 c union all - select 11 a, 14 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -) order by 1, 2 ----- -11 12 - -query III -select * from ( -with -t1 as ( - select 11 a, 12 b, 1 c union all - select 11 a, 13 b, 2 c), -t2 as ( - select 11 a, 12 b, 3 c union all - select 11 a, 14 b, 4 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) -) order by 1, 2; ----- -11 12 1 -11 13 2 - -query III -select * from ( -with -t1 as ( - select 11 a, 12 b, 1 c union all - select 11 a, 13 b, 2 c), -t2 as ( - select 11 a, 12 b, 3 c where false - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) -) order by 1, 2; ----- -11 12 1 -11 13 2 - -query II -select * from ( -with -t1 as ( - select 11 a, 12 b), -t2 as ( - select 11 a, 13 c union all - select 11 a, 14 c union all - select 11 a, 15 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -) order by 1, 2 ----- -11 12 - -query II -select * from ( -with -t1 as ( - select 11 a, 12 b), -t2 as ( - select 11 a, 12 c union all - select 11 a, 11 c union all - select 11 a, 15 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -) order by 1, 2 ----- - -query II -select * from ( -with -t1 as ( - select 11 a, 12 b), -t2 as ( - select 11 a, 12 c union all - select 11 a, 14 c union all - select 11 a, 11 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -) order by 1, 2 ----- - -query IIII -select * from ( -with t as ( - select id, id % 5 id1 from (select unnest(range(0,10)) id) -), t1 as ( - select id % 10 id, id + 2 id1 from (select unnest(range(0,10)) id) -) -select * from t right join t1 on t.id1 = t1.id and t.id > t1.id1 -) order by 1, 2, 3, 4 ----- -5 0 0 2 -6 1 1 3 -7 2 2 4 -8 3 3 5 -9 4 4 6 -NULL NULL 5 7 -NULL NULL 6 8 -NULL NULL 7 9 -NULL NULL 8 10 -NULL NULL 9 11 +# Uncomment when filtered LEFTANTI moved +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b), +#t2 as ( +# select 11 a, 13 c union all +# select 11 a, 14 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +#) order by 1, 2 +#---- +#11 12 + +# Uncomment when filtered LEFTANTI moved +#query III +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b, 1 c union all +# select 11 a, 13 b, 2 c), +#t2 as ( +# select 11 a, 12 b, 3 c union all +# select 11 a, 14 b, 4 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +#) order by 1, 2; +#---- +#11 12 1 +#11 13 2 + +# Uncomment when filtered LEFTANTI moved +#query III +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b, 1 c union all +# select 11 a, 13 b, 2 c), +#t2 as ( +# select 11 a, 12 b, 3 c where false +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +#) order by 1, 2; +#---- +#11 12 1 +#11 13 2 + +# Uncomment when filtered LEFTANTI moved +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b), +#t2 as ( +# select 11 a, 13 c union all +# select 11 a, 14 c union all +# select 11 a, 15 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +#) order by 1, 2 +#---- +#11 12 + +# Uncomment when filtered LEFTANTI moved +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b), +#t2 as ( +# select 11 a, 12 c union all +# select 11 a, 11 c union all +# select 11 a, 15 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +#) order by 1, 2 +#---- + +# Uncomment when filtered LEFTANTI moved +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b), +#t2 as ( +# select 11 a, 12 c union all +# select 11 a, 14 c union all +# select 11 a, 11 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +#) order by 1, 2 +#---- + +# Uncomment when filtered RIGHT moved +#query IIII +#select * from ( +#with t as ( +# select id, id % 5 id1 from (select unnest(range(0,10)) id) +#), t1 as ( +# select id % 10 id, id + 2 id1 from (select unnest(range(0,10)) id) +#) +#select * from t right join t1 on t.id1 = t1.id and t.id > t1.id1 +#) order by 1, 2, 3, 4 +#---- +#5 0 0 2 +#6 1 1 3 +#7 2 2 4 +#8 3 3 5 +#9 4 4 6 +#NULL NULL 5 7 +#NULL NULL 6 8 +#NULL NULL 7 9 +#NULL NULL 8 10 +#NULL NULL 9 11 query IIII select * from ( From 73ba4c45ff44e7c3c697aa8fea7bb019bb76711a Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Fri, 18 Oct 2024 16:19:48 -0400 Subject: [PATCH 08/17] feat: Add regexp_count function (#12970) * Implement regexp_ccount * Update document * fix check * add more tests * Update the world to 1.80 * Fix doc format * Add null tests * Add uft8 support and bench * Refactoring regexp_count * Refactoring regexp_count * Revert ci change * Fix ci * Updates for documentation, minor improvements. * Updates for documentation, minor improvements. * updates to fix scalar tests, doc updates. * updated regex and string features to remove deps on other features. --------- Co-authored-by: Xin Li --- datafusion/functions/Cargo.toml | 2 +- datafusion/functions/benches/regx.rs | 54 +- datafusion/functions/src/regex/mod.rs | 27 +- datafusion/functions/src/regex/regexpcount.rs | 951 ++++++++++++++++++ datafusion/sqllogictest/test_files/regexp.slt | 331 +++++- .../user-guide/sql/scalar_functions_new.md | 32 + 6 files changed, 1382 insertions(+), 15 deletions(-) create mode 100644 datafusion/functions/src/regex/regexpcount.rs diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 6099ad62c1d9..70a988dbfefb 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -54,7 +54,7 @@ math_expressions = [] # enable regular expressions regex_expressions = ["regex"] # enable string functions -string_expressions = ["regex_expressions", "uuid"] +string_expressions = ["uuid"] # enable unicode functions unicode_expressions = ["hashbrown", "unicode-segmentation"] diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index c9a9c1dfb19e..468d3d548bcf 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -18,8 +18,11 @@ extern crate criterion; use arrow::array::builder::StringBuilder; -use arrow::array::{ArrayRef, AsArray, StringArray}; +use arrow::array::{ArrayRef, AsArray, Int64Array, StringArray}; +use arrow::compute::cast; +use arrow::datatypes::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_functions::regex::regexpcount::regexp_count_func; use datafusion_functions::regex::regexplike::regexp_like; use datafusion_functions::regex::regexpmatch::regexp_match; use datafusion_functions::regex::regexpreplace::regexp_replace; @@ -59,6 +62,15 @@ fn regex(rng: &mut ThreadRng) -> StringArray { StringArray::from(data) } +fn start(rng: &mut ThreadRng) -> Int64Array { + let mut data: Vec = vec![]; + for _ in 0..1000 { + data.push(rng.gen_range(1..5)); + } + + Int64Array::from(data) +} + fn flags(rng: &mut ThreadRng) -> StringArray { let samples = [Some("i".to_string()), Some("im".to_string()), None]; let mut sb = StringBuilder::new(); @@ -75,6 +87,46 @@ fn flags(rng: &mut ThreadRng) -> StringArray { } fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("regexp_count_1000 string", |b| { + let mut rng = rand::thread_rng(); + let data = Arc::new(data(&mut rng)) as ArrayRef; + let regex = Arc::new(regex(&mut rng)) as ArrayRef; + let start = Arc::new(start(&mut rng)) as ArrayRef; + let flags = Arc::new(flags(&mut rng)) as ArrayRef; + + b.iter(|| { + black_box( + regexp_count_func(&[ + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&start), + Arc::clone(&flags), + ]) + .expect("regexp_count should work on utf8"), + ) + }) + }); + + c.bench_function("regexp_count_1000 utf8view", |b| { + let mut rng = rand::thread_rng(); + let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); + let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); + let start = Arc::new(start(&mut rng)) as ArrayRef; + let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap(); + + b.iter(|| { + black_box( + regexp_count_func(&[ + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&start), + Arc::clone(&flags), + ]) + .expect("regexp_count should work on utf8view"), + ) + }) + }); + c.bench_function("regexp_like_1000", |b| { let mut rng = rand::thread_rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index cde777311aa1..803f51e915a9 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -19,11 +19,13 @@ use std::sync::Arc; +pub mod regexpcount; pub mod regexplike; pub mod regexpmatch; pub mod regexpreplace; // create UDFs +make_udf_function!(regexpcount::RegexpCountFunc, REGEXP_COUNT, regexp_count); make_udf_function!(regexpmatch::RegexpMatchFunc, REGEXP_MATCH, regexp_match); make_udf_function!(regexplike::RegexpLikeFunc, REGEXP_LIKE, regexp_like); make_udf_function!( @@ -35,6 +37,24 @@ make_udf_function!( pub mod expr_fn { use datafusion_expr::Expr; + /// Returns the number of consecutive occurrences of a regular expression in a string. + pub fn regexp_count( + values: Expr, + regex: Expr, + start: Option, + flags: Option, + ) -> Expr { + let mut args = vec![values, regex]; + if let Some(start) = start { + args.push(start); + }; + + if let Some(flags) = flags { + args.push(flags); + }; + super::regexp_count().call(args) + } + /// Returns a list of regular expression matches in a string. pub fn regexp_match(values: Expr, regex: Expr, flags: Option) -> Expr { let mut args = vec![values, regex]; @@ -70,5 +90,10 @@ pub mod expr_fn { /// Returns all DataFusion functions defined in this package pub fn functions() -> Vec> { - vec![regexp_match(), regexp_like(), regexp_replace()] + vec![ + regexp_count(), + regexp_match(), + regexp_like(), + regexp_replace(), + ] } diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs new file mode 100644 index 000000000000..880c91094555 --- /dev/null +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -0,0 +1,951 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::strings::StringArrayType; +use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array}; +use arrow::datatypes::{DataType, Int64Type}; +use arrow::datatypes::{ + DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View, +}; +use arrow::error::ArrowError; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact, + TypeSignature::Uniform, Volatility, +}; +use itertools::izip; +use regex::Regex; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::sync::{Arc, OnceLock}; + +#[derive(Debug)] +pub struct RegexpCountFunc { + signature: Signature, +} + +impl Default for RegexpCountFunc { + fn default() -> Self { + Self::new() + } +} + +impl RegexpCountFunc { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + Uniform(2, vec![Utf8View, LargeUtf8, Utf8]), + Exact(vec![Utf8View, Utf8View, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![Utf8View, Utf8View, Int64, Utf8View]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, LargeUtf8]), + Exact(vec![Utf8, Utf8, Int64, Utf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RegexpCountFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "regexp_count" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Int64) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + let inferred_length = len.unwrap_or(1); + let args = args + .iter() + .map(|arg| arg.clone().into_array(inferred_length)) + .collect::>>()?; + + let result = regexp_count_func(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_regexp_count_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_regexp_count_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_REGEX) + .with_description("Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string.") + .with_syntax_example("regexp_count(str, regexp[, start, flags])") + .with_sql_example(r#"```sql +> select regexp_count('abcAbAbc', 'abc', 2, 'i'); ++---------------------------------------------------------------+ +| regexp_count(Utf8("abcAbAbc"),Utf8("abc"),Int64(2),Utf8("i")) | ++---------------------------------------------------------------+ +| 1 | ++---------------------------------------------------------------+ +```"#) + .with_standard_argument("str", "String") + .with_standard_argument("regexp","Regular") + .with_argument("start", "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function.") + .with_argument("flags", + r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*?"#) + .build() + .unwrap() + }) +} + +pub fn regexp_count_func(args: &[ArrayRef]) -> Result { + let args_len = args.len(); + if !(2..=4).contains(&args_len) { + return exec_err!("regexp_count was called with {args_len} arguments. It requires at least 2 and at most 4."); + } + + let values = &args[0]; + match values.data_type() { + Utf8 | LargeUtf8 | Utf8View => (), + other => { + return internal_err!( + "Unsupported data type {other:?} for function regexp_count" + ); + } + } + + regexp_count( + values, + &args[1], + if args_len > 2 { Some(&args[2]) } else { None }, + if args_len > 3 { Some(&args[3]) } else { None }, + ) + .map_err(|e| e.into()) +} + +/// `arrow-rs` style implementation of `regexp_count` function. +/// This function `regexp_count` is responsible for counting the occurrences of a regular expression pattern +/// within a string array. It supports optional start positions and flags for case insensitivity. +/// +/// The function accepts a variable number of arguments: +/// - `values`: The array of strings to search within. +/// - `regex_array`: The array of regular expression patterns to search for. +/// - `start_array` (optional): The array of start positions for the search. +/// - `flags_array` (optional): The array of flags to modify the search behavior (e.g., case insensitivity). +/// +/// The function handles different combinations of scalar and array inputs for the regex patterns, start positions, +/// and flags. It uses a cache to store compiled regular expressions for efficiency. +/// +/// # Errors +/// Returns an error if the input arrays have mismatched lengths or if the regular expression fails to compile. +pub fn regexp_count( + values: &dyn Array, + regex_array: &dyn Datum, + start_array: Option<&dyn Datum>, + flags_array: Option<&dyn Datum>, +) -> Result { + let (regex_array, is_regex_scalar) = regex_array.get(); + let (start_array, is_start_scalar) = start_array.map_or((None, true), |start| { + let (start, is_start_scalar) = start.get(); + (Some(start), is_start_scalar) + }); + let (flags_array, is_flags_scalar) = flags_array.map_or((None, true), |flags| { + let (flags, is_flags_scalar) = flags.get(); + (Some(flags), is_flags_scalar) + }); + + match (values.data_type(), regex_array.data_type(), flags_array) { + (Utf8, Utf8, None) => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + None, + is_flags_scalar, + ), + (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + Some(flags_array.as_string::()), + is_flags_scalar, + ), + (LargeUtf8, LargeUtf8, None) => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + None, + is_flags_scalar, + ), + (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + Some(flags_array.as_string::()), + is_flags_scalar, + ), + (Utf8View, Utf8View, None) => regexp_count_inner( + values.as_string_view(), + regex_array.as_string_view(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + None, + is_flags_scalar, + ), + (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_count_inner( + values.as_string_view(), + regex_array.as_string_view(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + Some(flags_array.as_string_view()), + is_flags_scalar, + ), + _ => Err(ArrowError::ComputeError( + "regexp_count() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(), + )), + } +} + +pub fn regexp_count_inner<'a, S>( + values: S, + regex_array: S, + is_regex_scalar: bool, + start_array: Option<&Int64Array>, + is_start_scalar: bool, + flags_array: Option, + is_flags_scalar: bool, +) -> Result +where + S: StringArrayType<'a>, +{ + let (regex_scalar, is_regex_scalar) = if is_regex_scalar || regex_array.len() == 1 { + (Some(regex_array.value(0)), true) + } else { + (None, false) + }; + + let (start_array, start_scalar, is_start_scalar) = + if let Some(start_array) = start_array { + if is_start_scalar || start_array.len() == 1 { + (None, Some(start_array.value(0)), true) + } else { + (Some(start_array), None, false) + } + } else { + (None, Some(1), true) + }; + + let (flags_array, flags_scalar, is_flags_scalar) = + if let Some(flags_array) = flags_array { + if is_flags_scalar || flags_array.len() == 1 { + (None, Some(flags_array.value(0)), true) + } else { + (Some(flags_array), None, false) + } + } else { + (None, None, true) + }; + + let mut regex_cache = HashMap::new(); + + match (is_regex_scalar, is_start_scalar, is_flags_scalar) { + (true, true, true) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let pattern = compile_regex(regex, flags_scalar)?; + + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .map(|value| count_matches(value, &pattern, start_scalar)) + .collect::, ArrowError>>()?, + ))) + } + (true, true, false) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .zip(flags_array.iter()) + .map(|(value, flags)| { + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + count_matches(value, &pattern, start_scalar) + }) + .collect::, ArrowError>>()?, + ))) + } + (true, false, true) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let pattern = compile_regex(regex, flags_scalar)?; + + let start_array = start_array.unwrap(); + + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .zip(start_array.iter()) + .map(|(value, start)| count_matches(value, &pattern, start)) + .collect::, ArrowError>>()?, + ))) + } + (true, false, false) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + izip!( + values.iter(), + start_array.unwrap().iter(), + flags_array.iter() + ) + .map(|(value, start, flags)| { + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + + count_matches(value, &pattern, start) + }) + .collect::, ArrowError>>()?, + ))) + } + (false, true, true) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .zip(regex_array.iter()) + .map(|(value, regex)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = compile_and_cache_regex( + regex, + flags_scalar, + &mut regex_cache, + )?; + count_matches(value, &pattern, start_scalar) + }) + .collect::, ArrowError>>()?, + ))) + } + (false, true, false) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + izip!(values.iter(), regex_array.iter(), flags_array.iter()) + .map(|(value, regex, flags)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + + count_matches(value, &pattern, start_scalar) + }) + .collect::, ArrowError>>()?, + ))) + } + (false, false, true) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + let start_array = start_array.unwrap(); + if values.len() != start_array.len() { + return Err(ArrowError::ComputeError(format!( + "start_array must be the same length as values array; got {} and {}", + start_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + izip!(values.iter(), regex_array.iter(), start_array.iter()) + .map(|(value, regex, start)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = compile_and_cache_regex( + regex, + flags_scalar, + &mut regex_cache, + )?; + count_matches(value, &pattern, start) + }) + .collect::, ArrowError>>()?, + ))) + } + (false, false, false) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + let start_array = start_array.unwrap(); + if values.len() != start_array.len() { + return Err(ArrowError::ComputeError(format!( + "start_array must be the same length as values array; got {} and {}", + start_array.len(), + values.len(), + ))); + } + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + izip!( + values.iter(), + regex_array.iter(), + start_array.iter(), + flags_array.iter() + ) + .map(|(value, regex, start, flags)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + count_matches(value, &pattern, start) + }) + .collect::, ArrowError>>()?, + ))) + } + } +} + +fn compile_and_cache_regex( + regex: &str, + flags: Option<&str>, + regex_cache: &mut HashMap, +) -> Result { + match regex_cache.entry(regex.to_string()) { + Entry::Vacant(entry) => { + let compiled = compile_regex(regex, flags)?; + entry.insert(compiled.clone()); + Ok(compiled) + } + Entry::Occupied(entry) => Ok(entry.get().to_owned()), + } +} + +fn compile_regex(regex: &str, flags: Option<&str>) -> Result { + let pattern = match flags { + None | Some("") => regex.to_string(), + Some(flags) => { + if flags.contains("g") { + return Err(ArrowError::ComputeError( + "regexp_count() does not support global flag".to_string(), + )); + } + format!("(?{}){}", flags, regex) + } + }; + + Regex::new(&pattern).map_err(|_| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {}", + pattern + )) + }) +} + +fn count_matches( + value: Option<&str>, + pattern: &Regex, + start: Option, +) -> Result { + let value = match value { + None | Some("") => return Ok(0), + Some(value) => value, + }; + + if let Some(start) = start { + if start < 1 { + return Err(ArrowError::ComputeError( + "regexp_count() requires start to be 1 based".to_string(), + )); + } + + let find_slice = value.chars().skip(start as usize - 1).collect::(); + let count = pattern.find_iter(find_slice.as_str()).count(); + Ok(count as i64) + } else { + let count = pattern.find_iter(value).count(); + Ok(count as i64) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{GenericStringArray, StringViewArray}; + + #[test] + fn test_regexp_count() { + test_case_sensitive_regexp_count_scalar(); + test_case_sensitive_regexp_count_scalar_start(); + test_case_insensitive_regexp_count_scalar_flags(); + test_case_sensitive_regexp_count_start_scalar_complex(); + + test_case_sensitive_regexp_count_array::>(); + test_case_sensitive_regexp_count_array::>(); + test_case_sensitive_regexp_count_array::(); + + test_case_sensitive_regexp_count_array_start::>(); + test_case_sensitive_regexp_count_array_start::>(); + test_case_sensitive_regexp_count_array_start::(); + + test_case_insensitive_regexp_count_array_flags::>(); + test_case_insensitive_regexp_count_array_flags::>(); + test_case_insensitive_regexp_count_array_flags::(); + + test_case_sensitive_regexp_count_array_complex::>(); + test_case_sensitive_regexp_count_array_complex::>(); + test_case_sensitive_regexp_count_array_complex::(); + } + + fn test_case_sensitive_regexp_count_scalar() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = "abc"; + let expected: Vec = vec![0, 1, 2, 1, 3]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); + let expected = expected.get(pos).cloned(); + + let re = RegexpCountFunc::new() + .invoke(&[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); + + let re = RegexpCountFunc::new() + .invoke(&[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); + + let re = RegexpCountFunc::new() + .invoke(&[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_count_scalar_start() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = "abc"; + let start = 2; + let expected: Vec = vec![0, 1, 1, 0, 2]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); + let start_sv = ScalarValue::Int64(Some(start)); + let expected = expected.get(pos).cloned(); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_insensitive_regexp_count_scalar_flags() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = "abc"; + let start = 1; + let flags = "i"; + let expected: Vec = vec![0, 1, 2, 2, 3]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); + let start_sv = ScalarValue::Int64(Some(start)); + let flags_sv = ScalarValue::Utf8(Some(flags.to_string())); + let expected = expected.get(pos).cloned(); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); + let flags_sv = ScalarValue::LargeUtf8(Some(flags.to_string())); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); + let flags_sv = ScalarValue::Utf8View(Some(flags.to_string())); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv), + ColumnarValue::Scalar(flags_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_count_array() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aabca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + + let expected = Int64Array::from(vec![0, 1, 2, 2, 2]); + + let re = regexp_count_func(&[Arc::new(values), Arc::new(regex)]).unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_array_start() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1, 2, 3, 4, 5]); + + let expected = Int64Array::from(vec![0, 0, 1, 1, 0]); + + let re = regexp_count_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_insensitive_regexp_count_array_flags() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1]); + let flags = A::from(vec!["", "i", "", "", "i"]); + + let expected = Int64Array::from(vec![0, 1, 2, 2, 3]); + + let re = regexp_count_func(&[ + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(flags), + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_start_scalar_complex() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = ["", "abc", "a", "bc", "ab"]; + let start = 5; + let flags = ["", "i", "", "", "i"]; + let expected: Vec = vec![0, 0, 0, 1, 1]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(regex.get(pos).map(|s| s.to_string())); + let start_sv = ScalarValue::Int64(Some(start)); + let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| f.to_string())); + let expected = expected.get(pos).cloned(); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(regex.get(pos).map(|s| s.to_string())); + let flags_sv = ScalarValue::LargeUtf8(flags.get(pos).map(|f| f.to_string())); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(regex.get(pos).map(|s| s.to_string())); + let flags_sv = ScalarValue::Utf8View(flags.get(pos).map(|f| f.to_string())); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv), + ColumnarValue::Scalar(flags_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_count_array_complex() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1, 2, 3, 4, 5]); + let flags = A::from(vec!["", "i", "", "", "i"]); + + let expected = Int64Array::from(vec![0, 1, 1, 1, 1]); + + let re = regexp_count_func(&[ + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(flags), + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } +} diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt index eedc3ddb6d59..800026dd766d 100644 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ b/datafusion/sqllogictest/test_files/regexp.slt @@ -16,18 +16,18 @@ # under the License. statement ok -CREATE TABLE t (str varchar, pattern varchar, flags varchar) AS VALUES - ('abc', '^(a)', 'i'), - ('ABC', '^(A).*', 'i'), - ('aBc', '(b|d)', 'i'), - ('AbC', '(B|D)', null), - ('aBC', '^(b|c)', null), - ('4000', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', null), - ('4010', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', null), - ('Düsseldorf','[\p{Letter}-]+', null), - ('Москва', '[\p{L}-]+', null), - ('Köln', '[a-zA-Z]ö[a-zA-Z]{2}', null), - ('إسرائيل', '^\p{Arabic}+$', null); +CREATE TABLE t (str varchar, pattern varchar, start int, flags varchar) AS VALUES + ('abc', '^(a)', 1, 'i'), + ('ABC', '^(A).*', 1, 'i'), + ('aBc', '(b|d)', 1, 'i'), + ('AbC', '(B|D)', 2, null), + ('aBC', '^(b|c)', 3, null), + ('4000', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 1, null), + ('4010', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 2, null), + ('Düsseldorf','[\p{Letter}-]+', 3, null), + ('Москва', '[\p{L}-]+', 4, null), + ('Köln', '[a-zA-Z]ö[a-zA-Z]{2}', 1, null), + ('إسرائيل', '^\p{Arabic}+$', 2, null); # # regexp_like tests @@ -460,6 +460,313 @@ SELECT NULL not iLIKE NULL; ---- NULL +# regexp_count tests + +# regexp_count tests from postgresql +# https://github.com/postgres/postgres/blob/56d23855c864b7384970724f3ad93fb0fc319e51/src/test/regress/sql/strings.sql#L226-L235 + +query I +SELECT regexp_count('123123123123123', '(12)3'); +---- +5 + +query I +SELECT regexp_count('123123123123', '123', 1); +---- +4 + +query I +SELECT regexp_count('123123123123', '123', 3); +---- +3 + +query I +SELECT regexp_count('123123123123', '123', 33); +---- +0 + +query I +SELECT regexp_count('ABCABCABCABC', 'Abc', 1, ''); +---- +0 + +query I +SELECT regexp_count('ABCABCABCABC', 'Abc', 1, 'i'); +---- +4 + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based +SELECT regexp_count('123123123123', '123', 0); + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based +SELECT regexp_count('123123123123', '123', -3); + +statement error +External error: statement failed: DataFusion error: Arrow error: Compute error: regexp_count() does not support global flag +SELECT regexp_count('123123123123', '123', 1, 'g'); + +query I +SELECT regexp_count(str, '\w') from t; +---- +3 +3 +3 +3 +3 +4 +4 +10 +6 +4 +7 + +query I +SELECT regexp_count(str, '\w{2}', start) from t; +---- +1 +1 +1 +1 +0 +2 +1 +4 +1 +2 +3 + +query I +SELECT regexp_count(str, 'ab', 1, 'i') from t; +---- +1 +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 + + +query I +SELECT regexp_count(str, pattern) from t; +---- +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start) from t; +---- +1 +1 +0 +0 +0 +0 +0 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start, flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test type coercion +query I +SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test string views + +statement ok +CREATE TABLE t_stringview AS +SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(start, 'Int64') as start, arrow_cast(flags, 'Utf8View') as flags FROM t; + +query I +SELECT regexp_count(str, '\w') from t; +---- +3 +3 +3 +3 +3 +4 +4 +10 +6 +4 +7 + +query I +SELECT regexp_count(str, '\w{2}', start) from t; +---- +1 +1 +1 +1 +0 +2 +1 +4 +1 +2 +3 + +query I +SELECT regexp_count(str, 'ab', 1, 'i') from t; +---- +1 +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 + + +query I +SELECT regexp_count(str, pattern) from t; +---- +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start) from t; +---- +1 +1 +0 +0 +0 +0 +0 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start, flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test type coercion +query I +SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# NULL tests + +query I +SELECT regexp_count(NULL, NULL); +---- +0 + +query I +SELECT regexp_count(NULL, 'a'); +---- +0 + +query I +SELECT regexp_count('a', NULL); +---- +0 + +query I +SELECT regexp_count(NULL, NULL, NULL, NULL); +---- +0 + +statement ok +CREATE TABLE empty_table (str varchar, pattern varchar, start int, flags varchar); + +query I +SELECT regexp_count(str, pattern, start, flags) from empty_table; +---- + +statement ok +INSERT INTO empty_table VALUES ('a', NULL, 1, 'i'), (NULL, 'a', 1, 'i'), (NULL, NULL, 1, 'i'), (NULL, NULL, NULL, 'i'); + +query I +SELECT regexp_count(str, pattern, start, flags) from empty_table; +---- +0 +0 +0 +0 + statement ok drop table t; diff --git a/docs/source/user-guide/sql/scalar_functions_new.md b/docs/source/user-guide/sql/scalar_functions_new.md index ffc2b680b5c5..ca70c83e58f9 100644 --- a/docs/source/user-guide/sql/scalar_functions_new.md +++ b/docs/source/user-guide/sql/scalar_functions_new.md @@ -1676,10 +1676,42 @@ regular expression [syntax](https://docs.rs/regex/latest/regex/#syntax) (minus support for several features including look-around and backreferences). The following regular expression functions are supported: +- [regexp_count](#regexp_count) - [regexp_like](#regexp_like) - [regexp_match](#regexp_match) - [regexp_replace](#regexp_replace) +### `regexp_count` + +Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string. + +``` +regexp_count(str, regexp[, start, flags]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **regexp**: Regular expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **start**: - **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. +- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*? + +#### Example + +```sql +> select regexp_count('abcAbAbc', 'abc', 2, 'i'); ++---------------------------------------------------------------+ +| regexp_count(Utf8("abcAbAbc"),Utf8("abc"),Int64(2),Utf8("i")) | ++---------------------------------------------------------------+ +| 1 | ++---------------------------------------------------------------+ +``` + ### `regexp_like` Returns true if a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has at least one match in a string, false otherwise. From 8c9b9152c8201d8b75d8e0b9b85b85d3199c94d8 Mon Sep 17 00:00:00 2001 From: Jonathan Chen <86070045+jonathanc-n@users.noreply.github.com> Date: Fri, 18 Oct 2024 16:41:04 -0400 Subject: [PATCH 09/17] Minor: Fixed regexpr_match docs (#13008) * regexp_match * update generated docs --------- Co-authored-by: Andrew Lamb --- datafusion/functions/src/regex/regexpmatch.rs | 2 +- docs/source/user-guide/sql/scalar_functions_new.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs index 443e50533268..4a86adbe683a 100644 --- a/datafusion/functions/src/regex/regexpmatch.rs +++ b/datafusion/functions/src/regex/regexpmatch.rs @@ -119,7 +119,7 @@ fn get_regexp_match_doc() -> &'static Documentation { DOCUMENTATION.get_or_init(|| { Documentation::builder() .with_doc_section(DOC_SECTION_REGEX) - .with_description("Returns a list of [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string.") + .with_description("Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matche in a string.") .with_syntax_example("regexp_match(str, regexp[, flags])") .with_sql_example(r#"```sql > select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); diff --git a/docs/source/user-guide/sql/scalar_functions_new.md b/docs/source/user-guide/sql/scalar_functions_new.md index ca70c83e58f9..1915623012f4 100644 --- a/docs/source/user-guide/sql/scalar_functions_new.md +++ b/docs/source/user-guide/sql/scalar_functions_new.md @@ -1752,7 +1752,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo ### `regexp_match` -Returns a list of [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string. +Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matche in a string. ``` regexp_match(str, regexp[, flags]) From 10af8a73662f4f6aac09a34157b7cf5fee034502 Mon Sep 17 00:00:00 2001 From: Albert Skalt <133099191+askalt@users.noreply.github.com> Date: Fri, 18 Oct 2024 23:41:53 +0300 Subject: [PATCH 10/17] Improve performance for physical plan creation with many columns (#12950) * Add a benchmark for physical plan creation with many aggregates * Wrap AggregateFunctionExpr with Arc Patch f5c47fa274d53c1d524a1fb788d9a063bf5240ef removed Arc wrappers for AggregateFunctionExpr. But, it can be inefficient. When physical optimizer decides to replace a node child to other, it clones the node (with `with_new_children`). Assume, that node is `AggregateExec` than contains hundreds aggregates and these aggregates are cloned each time. This patch returns a Arc wrapping to not clone AggregateFunctionExpr itself but clone a pointer. * Do not build mapping if parent does not require any This patch adds a small optimization that can soft the edges on some queries. If there are no parent requirements we do not need to build column mapping. --- datafusion/core/benches/sql_planner.rs | 14 +++ .../physical_optimizer/update_aggr_exprs.rs | 10 +- datafusion/core/src/physical_planner.rs | 5 +- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 1 + .../combine_partial_final_agg.rs | 8 +- .../limited_distinct_aggregation.rs | 16 +-- datafusion/physical-expr/src/aggregate.rs | 2 +- datafusion/physical-expr/src/utils/mod.rs | 4 + .../physical-expr/src/window/aggregate.rs | 8 +- .../src/window/sliding_aggregate.rs | 13 +- .../src/aggregate_statistics.rs | 24 ++-- .../src/combine_partial_final_agg.rs | 2 +- .../physical-plan/src/aggregates/mod.rs | 119 ++++++++++-------- .../physical-plan/src/aggregates/row_hash.rs | 4 +- datafusion/physical-plan/src/windows/mod.rs | 5 +- datafusion/proto/src/physical_plan/mod.rs | 3 +- .../proto/src/physical_plan/to_proto.rs | 2 +- .../tests/cases/roundtrip_physical_plan.rs | 44 ++++--- 18 files changed, 165 insertions(+), 119 deletions(-) diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 00f6d5916751..e7c35c8d86d6 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -144,6 +144,20 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + c.bench_function("physical_select_aggregates_from_200", |b| { + let mut aggregates = String::new(); + for i in 0..200 { + if i > 0 { + aggregates.push_str(", "); + } + aggregates.push_str(format!("MAX(a{})", i).as_str()); + } + let query = format!("SELECT {} FROM t1", aggregates); + b.iter(|| { + physical_plan(&ctx, &query); + }); + }); + // --- TPC-H --- let tpch_ctx = register_defs(SessionContext::new(), tpch_schemas()); diff --git a/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs b/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs index c0d9140c025e..26cdd65883e4 100644 --- a/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs +++ b/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs @@ -131,10 +131,10 @@ impl PhysicalOptimizerRule for OptimizeAggregateOrder { /// successfully. Any errors occurring during the conversion process are /// passed through. fn try_convert_aggregate_if_better( - aggr_exprs: Vec, + aggr_exprs: Vec>, prefix_requirement: &[PhysicalSortRequirement], eq_properties: &EquivalenceProperties, -) -> Result> { +) -> Result>> { aggr_exprs .into_iter() .map(|aggr_expr| { @@ -154,7 +154,7 @@ fn try_convert_aggregate_if_better( let reqs = concat_slices(prefix_requirement, &aggr_sort_reqs); if eq_properties.ordering_satisfy_requirement(&reqs) { // Existing ordering satisfies the aggregator requirements: - aggr_expr.with_beneficial_ordering(true)? + aggr_expr.with_beneficial_ordering(true)?.map(Arc::new) } else if eq_properties.ordering_satisfy_requirement(&concat_slices( prefix_requirement, &reverse_aggr_req, @@ -163,12 +163,14 @@ fn try_convert_aggregate_if_better( // given the existing ordering (if possible): aggr_expr .reverse_expr() + .map(Arc::new) .unwrap_or(aggr_expr) .with_beneficial_ordering(true)? + .map(Arc::new) } else { // There is no beneficial ordering present -- aggregation // will still work albeit in a less efficient mode. - aggr_expr.with_beneficial_ordering(false)? + aggr_expr.with_beneficial_ordering(false)?.map(Arc::new) } .ok_or_else(|| { plan_datafusion_err!( diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index cf2a157b04b6..a4dffd3d0208 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1523,7 +1523,7 @@ pub fn create_window_expr( } type AggregateExprWithOptionalArgs = ( - AggregateFunctionExpr, + Arc, // The filter clause, if any Option>, // Ordering requirements, if any @@ -1587,7 +1587,8 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( .alias(name) .with_ignore_nulls(ignore_nulls) .with_distinct(*distinct) - .build()?; + .build() + .map(Arc::new)?; (agg_expr, filter, physical_sort_exprs) }; diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 34061a64d783..ff512829333a 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -405,6 +405,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str .schema(Arc::clone(&schema)) .alias("sum1") .build() + .map(Arc::new) .unwrap(), ]; let expr = group_by_columns diff --git a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs index 24e46b3ad97c..85076abdaf29 100644 --- a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs @@ -84,7 +84,7 @@ fn parquet_exec(schema: &SchemaRef) -> Arc { fn partial_aggregate_exec( input: Arc, group_by: PhysicalGroupBy, - aggr_expr: Vec, + aggr_expr: Vec>, ) -> Arc { let schema = input.schema(); let n_aggr = aggr_expr.len(); @@ -104,7 +104,7 @@ fn partial_aggregate_exec( fn final_aggregate_exec( input: Arc, group_by: PhysicalGroupBy, - aggr_expr: Vec, + aggr_expr: Vec>, ) -> Arc { let schema = input.schema(); let n_aggr = aggr_expr.len(); @@ -130,11 +130,12 @@ fn count_expr( expr: Arc, name: &str, schema: &Schema, -) -> AggregateFunctionExpr { +) -> Arc { AggregateExprBuilder::new(count_udaf(), vec![expr]) .schema(Arc::new(schema.clone())) .alias(name) .build() + .map(Arc::new) .unwrap() } @@ -218,6 +219,7 @@ fn aggregations_with_group_combined() -> datafusion_common::Result<()> { .schema(Arc::clone(&schema)) .alias("Sum(b)") .build() + .map(Arc::new) .unwrap(), ]; let groups: Vec<(Arc, String)> = diff --git a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs index 042f6d622565..d6991711f581 100644 --- a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs @@ -347,10 +347,10 @@ fn test_has_aggregate_expression() -> Result<()> { let single_agg = AggregateExec::try_new( AggregateMode::Single, build_group_by(&schema, vec!["a".to_string()]), - vec![agg.count_expr(&schema)], /* aggr_expr */ - vec![None], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ + vec![Arc::new(agg.count_expr(&schema))], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ )?; let limit_exec = LocalLimitExec::new( Arc::new(single_agg), @@ -384,10 +384,10 @@ fn test_has_filter() -> Result<()> { let single_agg = AggregateExec::try_new( AggregateMode::Single, build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![agg.count_expr(&schema)], /* aggr_expr */ - vec![filter_expr], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ + vec![Arc::new(agg.count_expr(&schema))], /* aggr_expr */ + vec![filter_expr], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ )?; let limit_exec = LocalLimitExec::new( Arc::new(single_agg), diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index 866596d0b690..6330c240241a 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -328,7 +328,7 @@ impl AggregateFunctionExpr { /// not implement the method, returns an error. Order insensitive and hard /// requirement aggregators return `Ok(None)`. pub fn with_beneficial_ordering( - self, + self: Arc, beneficial_ordering: bool, ) -> Result> { let Some(updated_fn) = self diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 4c37db4849a7..4bd022975ac3 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -86,6 +86,10 @@ pub fn map_columns_before_projection( parent_required: &[Arc], proj_exprs: &[(Arc, String)], ) -> Vec> { + if parent_required.is_empty() { + // No need to build mapping. + return vec![]; + } let column_mapping = proj_exprs .iter() .filter_map(|(expr, name)| { diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index d012fef93b67..3fe5d842dfd1 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -41,7 +41,7 @@ use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr}; /// See comments on [`WindowExpr`] for more details. #[derive(Debug)] pub struct PlainAggregateWindowExpr { - aggregate: AggregateFunctionExpr, + aggregate: Arc, partition_by: Vec>, order_by: Vec, window_frame: Arc, @@ -50,7 +50,7 @@ pub struct PlainAggregateWindowExpr { impl PlainAggregateWindowExpr { /// Create a new aggregate window function expression pub fn new( - aggregate: AggregateFunctionExpr, + aggregate: Arc, partition_by: &[Arc], order_by: &[PhysicalSortExpr], window_frame: Arc, @@ -137,14 +137,14 @@ impl WindowExpr for PlainAggregateWindowExpr { let reverse_window_frame = self.window_frame.reverse(); if reverse_window_frame.start_bound.is_unbounded() { Arc::new(PlainAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), )) as _ } else { Arc::new(SlidingAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 143d59eb4495..b889ec8c5d98 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -41,7 +41,7 @@ use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr}; /// See comments on [`WindowExpr`] for more details. #[derive(Debug)] pub struct SlidingAggregateWindowExpr { - aggregate: AggregateFunctionExpr, + aggregate: Arc, partition_by: Vec>, order_by: Vec, window_frame: Arc, @@ -50,7 +50,7 @@ pub struct SlidingAggregateWindowExpr { impl SlidingAggregateWindowExpr { /// Create a new (sliding) aggregate window function expression. pub fn new( - aggregate: AggregateFunctionExpr, + aggregate: Arc, partition_by: &[Arc], order_by: &[PhysicalSortExpr], window_frame: Arc, @@ -121,14 +121,14 @@ impl WindowExpr for SlidingAggregateWindowExpr { let reverse_window_frame = self.window_frame.reverse(); if reverse_window_frame.start_bound.is_unbounded() { Arc::new(PlainAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), )) as _ } else { Arc::new(SlidingAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), @@ -159,7 +159,10 @@ impl WindowExpr for SlidingAggregateWindowExpr { }) .collect::>(); Some(Arc::new(SlidingAggregateWindowExpr { - aggregate: self.aggregate.with_new_expressions(args, vec![])?, + aggregate: self + .aggregate + .with_new_expressions(args, vec![]) + .map(Arc::new)?, partition_by: partition_bys, order_by: new_order_by, window_frame: Arc::clone(&self.window_frame), diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs b/datafusion/physical-optimizer/src/aggregate_statistics.rs index fd21362fd3eb..27870c7865f3 100644 --- a/datafusion/physical-optimizer/src/aggregate_statistics.rs +++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs @@ -312,7 +312,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], source, Arc::clone(&schema), @@ -321,7 +321,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], Arc::new(partial_agg), Arc::clone(&schema), @@ -342,7 +342,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], source, Arc::clone(&schema), @@ -351,7 +351,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], Arc::new(partial_agg), Arc::clone(&schema), @@ -371,7 +371,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], source, Arc::clone(&schema), @@ -383,7 +383,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], Arc::new(coalesce), Arc::clone(&schema), @@ -403,7 +403,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], source, Arc::clone(&schema), @@ -415,7 +415,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], Arc::new(coalesce), Arc::clone(&schema), @@ -446,7 +446,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], filter, Arc::clone(&schema), @@ -455,7 +455,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], Arc::new(partial_agg), Arc::clone(&schema), @@ -491,7 +491,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], filter, Arc::clone(&schema), @@ -500,7 +500,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], Arc::new(partial_agg), Arc::clone(&schema), diff --git a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs index 4e352e25b52c..86f7e73e9e35 100644 --- a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs +++ b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs @@ -125,7 +125,7 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { type GroupExprsRef<'a> = ( &'a PhysicalGroupBy, - &'a [AggregateFunctionExpr], + &'a [Arc], &'a [Option>], ); diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 296c5811e577..f36bd920e83c 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -351,7 +351,7 @@ pub struct AggregateExec { /// Group by expressions group_by: PhysicalGroupBy, /// Aggregate expressions - aggr_expr: Vec, + aggr_expr: Vec>, /// FILTER (WHERE clause) expression for each aggregate expression filter_expr: Vec>>, /// Set if the output of this aggregation is truncated by a upstream sort/limit clause @@ -378,7 +378,10 @@ impl AggregateExec { /// Function used in `OptimizeAggregateOrder` optimizer rule, /// where we need parts of the new value, others cloned from the old one /// Rewrites aggregate exec with new aggregate expressions. - pub fn with_new_aggr_exprs(&self, aggr_expr: Vec) -> Self { + pub fn with_new_aggr_exprs( + &self, + aggr_expr: Vec>, + ) -> Self { Self { aggr_expr, // clone the rest of the fields @@ -404,7 +407,7 @@ impl AggregateExec { pub fn try_new( mode: AggregateMode, group_by: PhysicalGroupBy, - aggr_expr: Vec, + aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, @@ -435,7 +438,7 @@ impl AggregateExec { fn try_new_with_schema( mode: AggregateMode, group_by: PhysicalGroupBy, - mut aggr_expr: Vec, + mut aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, @@ -545,7 +548,7 @@ impl AggregateExec { } /// Aggregate expressions - pub fn aggr_expr(&self) -> &[AggregateFunctionExpr] { + pub fn aggr_expr(&self) -> &[Arc] { &self.aggr_expr } @@ -876,7 +879,7 @@ impl ExecutionPlan for AggregateExec { fn create_schema( input_schema: &Schema, group_by: &PhysicalGroupBy, - aggr_expr: &[AggregateFunctionExpr], + aggr_expr: &[Arc], mode: AggregateMode, ) -> Result { let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len()); @@ -1006,7 +1009,7 @@ pub fn concat_slices(lhs: &[T], rhs: &[T]) -> Vec { /// A `LexRequirement` instance, which is the requirement that satisfies all the /// aggregate requirements. Returns an error in case of conflicting requirements. pub fn get_finer_aggregate_exprs_requirement( - aggr_exprs: &mut [AggregateFunctionExpr], + aggr_exprs: &mut [Arc], group_by: &PhysicalGroupBy, eq_properties: &EquivalenceProperties, agg_mode: &AggregateMode, @@ -1034,7 +1037,7 @@ pub fn get_finer_aggregate_exprs_requirement( // Reverse requirement is satisfied by exiting ordering. // Hence reverse the aggregator requirement = finer_ordering; - *aggr_expr = reverse_aggr_expr; + *aggr_expr = Arc::new(reverse_aggr_expr); continue; } } @@ -1058,7 +1061,7 @@ pub fn get_finer_aggregate_exprs_requirement( // There is a requirement that both satisfies existing requirement and reverse // aggregate requirement. Use updated requirement requirement = finer_ordering; - *aggr_expr = reverse_aggr_expr; + *aggr_expr = Arc::new(reverse_aggr_expr); continue; } } @@ -1080,7 +1083,7 @@ pub fn get_finer_aggregate_exprs_requirement( /// * Partial: AggregateFunctionExpr::expressions /// * Final: columns of `AggregateFunctionExpr::state_fields()` pub fn aggregate_expressions( - aggr_expr: &[AggregateFunctionExpr], + aggr_expr: &[Arc], mode: &AggregateMode, col_idx_base: usize, ) -> Result>>> { @@ -1135,7 +1138,7 @@ fn merge_expressions( pub type AccumulatorItem = Box; pub fn create_accumulators( - aggr_expr: &[AggregateFunctionExpr], + aggr_expr: &[Arc], ) -> Result> { aggr_expr .iter() @@ -1458,10 +1461,12 @@ mod tests { ], ); - let aggregates = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)]) - .schema(Arc::clone(&input_schema)) - .alias("COUNT(1)") - .build()?]; + let aggregates = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)]) + .schema(Arc::clone(&input_schema)) + .alias("COUNT(1)") + .build()?, + )]; let task_ctx = if spill { // adjust the max memory size to have the partial aggregate result for spill mode. @@ -1596,13 +1601,12 @@ mod tests { vec![vec![false]], ); - let aggregates: Vec = - vec![ - AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) - .schema(Arc::clone(&input_schema)) - .alias("AVG(b)") - .build()?, - ]; + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) + .schema(Arc::clone(&input_schema)) + .alias("AVG(b)") + .build()?, + )]; let task_ctx = if spill { // set to an appropriate value to trigger spill @@ -1925,17 +1929,16 @@ mod tests { ); // something that allocates within the aggregator - let aggregates_v0: Vec = - vec![test_median_agg_expr(Arc::clone(&input_schema))?]; + let aggregates_v0: Vec> = + vec![Arc::new(test_median_agg_expr(Arc::clone(&input_schema))?)]; // use fast-path in `row_hash.rs`. - let aggregates_v2: Vec = - vec![ - AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) - .schema(Arc::clone(&input_schema)) - .alias("AVG(b)") - .build()?, - ]; + let aggregates_v2: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) + .schema(Arc::clone(&input_schema)) + .alias("AVG(b)") + .build()?, + )]; for (version, groups, aggregates) in [ (0, groups_none, aggregates_v0), @@ -1989,13 +1992,12 @@ mod tests { let groups = PhysicalGroupBy::default(); - let aggregates: Vec = - vec![ - AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?]) - .schema(Arc::clone(&schema)) - .alias("AVG(a)") - .build()?, - ]; + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("AVG(a)") + .build()?, + )]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); @@ -2029,13 +2031,12 @@ mod tests { let groups = PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); - let aggregates: Vec = - vec![ - AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) - .schema(Arc::clone(&schema)) - .alias("AVG(b)") - .build()?, - ]; + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("AVG(b)") + .build()?, + )]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); @@ -2080,7 +2081,7 @@ mod tests { fn test_first_value_agg_expr( schema: &Schema, sort_options: SortOptions, - ) -> Result { + ) -> Result> { let ordering_req = [PhysicalSortExpr { expr: col("b", schema)?, options: sort_options, @@ -2092,13 +2093,14 @@ mod tests { .schema(Arc::new(schema.clone())) .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]")) .build() + .map(Arc::new) } // LAST_VALUE(b ORDER BY b ) fn test_last_value_agg_expr( schema: &Schema, sort_options: SortOptions, - ) -> Result { + ) -> Result> { let ordering_req = [PhysicalSortExpr { expr: col("b", schema)?, options: sort_options, @@ -2109,6 +2111,7 @@ mod tests { .schema(Arc::new(schema.clone())) .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]")) .build() + .map(Arc::new) } // This function either constructs the physical plan below, @@ -2153,7 +2156,7 @@ mod tests { descending: false, nulls_first: false, }; - let aggregates: Vec = if is_first_acc { + let aggregates: Vec> = if is_first_acc { vec![test_first_value_agg_expr(&schema, sort_options)?] } else { vec![test_last_value_agg_expr(&schema, sort_options)?] @@ -2289,6 +2292,7 @@ mod tests { .order_by(ordering_req.to_vec()) .schema(Arc::clone(&test_schema)) .build() + .map(Arc::new) .unwrap() }) .collect::>(); @@ -2318,7 +2322,7 @@ mod tests { }; let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]); - let aggregates: Vec = vec![ + let aggregates: Vec> = vec![ test_first_value_agg_expr(&schema, option_desc)?, test_last_value_agg_expr(&schema, option_desc)?, ]; @@ -2376,11 +2380,12 @@ mod tests { ], ); - let aggregates: Vec = + let aggregates: Vec> = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) .schema(Arc::clone(&schema)) .alias("1") - .build()?]; + .build() + .map(Arc::new)?]; let input_batches = (0..4) .map(|_| { @@ -2512,7 +2517,8 @@ mod tests { ) .schema(Arc::clone(&batch.schema())) .alias(String::from("SUM(value)")) - .build()?]; + .build() + .map(Arc::new)?]; let input = Arc::new(MemoryExec::try_new( &[vec![batch.clone()]], @@ -2560,7 +2566,8 @@ mod tests { AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?]) .schema(Arc::clone(&schema)) .alias(String::from("COUNT(val)")) - .build()?, + .build() + .map(Arc::new)?, ]; let input_data = vec![ @@ -2641,7 +2648,8 @@ mod tests { AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?]) .schema(Arc::clone(&schema)) .alias(String::from("COUNT(val)")) - .build()?, + .build() + .map(Arc::new)?, ]; let input_data = vec![ @@ -2728,7 +2736,8 @@ mod tests { AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?]) .schema(Arc::clone(&input_schema)) .alias("COUNT(a)") - .build()?, + .build() + .map(Arc::new)?, ]; let grouping_set = PhysicalGroupBy::new( diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 624844b6b985..7d21cc2f1944 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -591,7 +591,7 @@ impl GroupedHashAggregateStream { /// that is supported by the aggregate, or a /// [`GroupsAccumulatorAdapter`] if not. pub(crate) fn create_group_accumulator( - agg_expr: &AggregateFunctionExpr, + agg_expr: &Arc, ) -> Result> { if agg_expr.groups_accumulator_supported() { agg_expr.create_groups_accumulator() @@ -601,7 +601,7 @@ pub(crate) fn create_group_accumulator( "Creating GroupsAccumulatorAdapter for {}: {agg_expr:?}", agg_expr.name() ); - let agg_expr_captured = agg_expr.clone(); + let agg_expr_captured = Arc::clone(agg_expr); let factory = move || agg_expr_captured.create_accumulator(); Ok(Box::new(GroupsAccumulatorAdapter::new(factory))) } diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index adf61f27bc6f..f6902fcbe2e7 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -119,7 +119,8 @@ pub fn create_window_expr( .schema(Arc::new(input_schema.clone())) .alias(name) .with_ignore_nulls(ignore_nulls) - .build()?; + .build() + .map(Arc::new)?; window_expr_from_aggregate_expr( partition_by, order_by, @@ -142,7 +143,7 @@ fn window_expr_from_aggregate_expr( partition_by: &[Arc], order_by: &[PhysicalSortExpr], window_frame: Arc, - aggregate: AggregateFunctionExpr, + aggregate: Arc, ) -> Arc { // Is there a potentially unlimited sized window frame? let unbounded_window = window_frame.start_bound.is_unbounded(); diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 9a6850cb2108..634ae284c955 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -488,7 +488,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { }) .collect::, _>>()?; - let physical_aggr_expr: Vec = hash_agg + let physical_aggr_expr: Vec> = hash_agg .aggr_expr .iter() .zip(hash_agg.aggr_expr_name.iter()) @@ -518,6 +518,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .with_distinct(agg_node.distinct) .order_by(ordering_req) .build() + .map(Arc::new) } } }).transpose()?.ok_or_else(|| { diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 6072baca688c..33eca0723103 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -48,7 +48,7 @@ use crate::protobuf::{ use super::PhysicalExtensionCodec; pub fn serialize_physical_aggr_expr( - aggr_expr: AggregateFunctionExpr, + aggr_expr: Arc, codec: &dyn PhysicalExtensionCodec, ) -> Result { let expressions = serialize_physical_exprs(&aggr_expr.expressions(), codec)?; diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 025676f790a8..4a9bf6afb49e 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -73,7 +73,6 @@ use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; -use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec}; use datafusion::physical_plan::windows::{ @@ -305,7 +304,8 @@ fn roundtrip_window() -> Result<()> { ) .schema(Arc::clone(&schema)) .alias("avg(b)") - .build()?, + .build() + .map(Arc::new)?, &[], &[], Arc::new(WindowFrame::new(None)), @@ -321,7 +321,8 @@ fn roundtrip_window() -> Result<()> { let sum_expr = AggregateExprBuilder::new(sum_udaf(), args) .schema(Arc::clone(&schema)) .alias("SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING") - .build()?; + .build() + .map(Arc::new)?; let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new( sum_expr, @@ -367,13 +368,13 @@ fn rountrip_aggregate() -> Result<()> { .alias("NTH_VALUE(b, 1)") .build()?; - let test_cases: Vec> = vec![ + let test_cases = vec![ // AVG - vec![avg_expr], + vec![Arc::new(avg_expr)], // NTH_VALUE - vec![nth_expr], + vec![Arc::new(nth_expr)], // STRING_AGG - vec![str_agg_expr], + vec![Arc::new(str_agg_expr)], ]; for aggregates in test_cases { @@ -400,12 +401,13 @@ fn rountrip_aggregate_with_limit() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates: Vec = + let aggregates = vec![ AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) .schema(Arc::clone(&schema)) .alias("AVG(b)") - .build()?, + .build() + .map(Arc::new)?, ]; let agg = AggregateExec::try_new( @@ -429,13 +431,14 @@ fn rountrip_aggregate_with_approx_pencentile_cont() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates: Vec = vec![AggregateExprBuilder::new( + let aggregates = vec![AggregateExprBuilder::new( approx_percentile_cont_udaf(), vec![col("b", &schema)?, lit(0.5)], ) .schema(Arc::clone(&schema)) .alias("APPROX_PERCENTILE_CONT(b, 0.5)") - .build()?]; + .build() + .map(Arc::new)?]; let agg = AggregateExec::try_new( AggregateMode::Final, @@ -464,13 +467,14 @@ fn rountrip_aggregate_with_sort() -> Result<()> { }, }]; - let aggregates: Vec = + let aggregates = vec![ AggregateExprBuilder::new(array_agg_udaf(), vec![col("b", &schema)?]) .schema(Arc::clone(&schema)) .alias("ARRAY_AGG(b)") .order_by(sort_exprs) - .build()?, + .build() + .map(Arc::new)?, ]; let agg = AggregateExec::try_new( @@ -531,12 +535,13 @@ fn roundtrip_aggregate_udaf() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates: Vec = + let aggregates = vec![ AggregateExprBuilder::new(Arc::new(udaf), vec![col("b", &schema)?]) .schema(Arc::clone(&schema)) .alias("example_agg") - .build()?, + .build() + .map(Arc::new)?, ]; roundtrip_test_with_context( @@ -1001,7 +1006,8 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { AggregateExprBuilder::new(max_udaf(), vec![udf_expr as Arc]) .schema(schema.clone()) .alias("max") - .build()?; + .build() + .map(Arc::new)?; let window = Arc::new(WindowAggExec::try_new( vec![Arc::new(PlainAggregateWindowExpr::new( @@ -1052,7 +1058,8 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { let aggr_expr = AggregateExprBuilder::new(Arc::clone(&udaf), aggr_args.clone()) .schema(Arc::clone(&schema)) .alias("aggregate_udf") - .build()?; + .build() + .map(Arc::new)?; let filter = Arc::new(FilterExec::try_new( Arc::new(BinaryExpr::new( @@ -1079,7 +1086,8 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { .alias("aggregate_udf") .distinct() .ignore_nulls() - .build()?; + .build() + .map(Arc::new)?; let aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, From 34bd8237d2189eca5b560c034d15e63d97a15fa0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Fri, 18 Oct 2024 23:00:24 +0200 Subject: [PATCH 11/17] Remove logical cross join in planning (#12985) * Remove logical cross join in planning * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * Implement some more substrait pieces * Update datafusion/core/src/physical_planner.rs Co-authored-by: Oleks V * Remove incorrect comment --------- Co-authored-by: Oleks V --- datafusion/core/src/physical_planner.rs | 22 ++++--- datafusion/expr/src/logical_plan/builder.rs | 11 +++- datafusion/expr/src/logical_plan/plan.rs | 6 ++ .../optimizer/src/eliminate_cross_join.rs | 25 +++++--- datafusion/optimizer/src/eliminate_join.rs | 26 +------- datafusion/optimizer/src/push_down_filter.rs | 4 +- datafusion/optimizer/src/push_down_limit.rs | 7 +-- datafusion/sql/src/relation/join.rs | 4 +- datafusion/sql/tests/cases/plan_to_sql.rs | 2 +- datafusion/sql/tests/sql_integration.rs | 30 ++++----- datafusion/sqllogictest/test_files/cte.slt | 2 +- .../sqllogictest/test_files/group_by.slt | 2 +- datafusion/sqllogictest/test_files/join.slt | 4 +- datafusion/sqllogictest/test_files/joins.slt | 2 +- datafusion/sqllogictest/test_files/select.slt | 2 +- datafusion/sqllogictest/test_files/update.slt | 4 +- .../substrait/src/logical_plan/consumer.rs | 12 +++- .../tests/cases/consumer_integration.rs | 62 +++++++++---------- 18 files changed, 117 insertions(+), 110 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index a4dffd3d0208..918ebccbeb70 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -78,7 +78,7 @@ use datafusion_expr::expr::{ use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, SortExpr, + DescribeTable, DmlStatement, Extension, Filter, JoinType, RecursiveQuery, SortExpr, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; @@ -1045,14 +1045,18 @@ impl DefaultPhysicalPlanner { session_state.config_options().optimizer.prefer_hash_join; let join: Arc = if join_on.is_empty() { - // there is no equal join condition, use the nested loop join - // TODO optimize the plan, and use the config of `target_partitions` and `repartition_joins` - Arc::new(NestedLoopJoinExec::try_new( - physical_left, - physical_right, - join_filter, - join_type, - )?) + if join_filter.is_none() && matches!(join_type, JoinType::Inner) { + // cross join if there is no join conditions and no join filter set + Arc::new(CrossJoinExec::new(physical_left, physical_right)) + } else { + // there is no equal join condition, use the nested loop join + Arc::new(NestedLoopJoinExec::try_new( + physical_left, + physical_right, + join_filter, + join_type, + )?) + } } else if session_state.config().target_partitions() > 1 && session_state.config().repartition_joins() && !prefer_hash_join diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index da2a96327ce5..6ab50440ec5b 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -30,8 +30,8 @@ use crate::expr_rewriter::{ rewrite_sort_cols_by_aggs, }; use crate::logical_plan::{ - Aggregate, Analyze, CrossJoin, Distinct, DistinctOn, EmptyRelation, Explain, Filter, - Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, + Aggregate, Analyze, Distinct, DistinctOn, EmptyRelation, Explain, Filter, Join, + JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, Window, }; @@ -950,9 +950,14 @@ impl LogicalPlanBuilder { pub fn cross_join(self, right: LogicalPlan) -> Result { let join_schema = build_join_schema(self.plan.schema(), right.schema(), &JoinType::Inner)?; - Ok(Self::new(LogicalPlan::CrossJoin(CrossJoin { + Ok(Self::new(LogicalPlan::Join(Join { left: self.plan, right: Arc::new(right), + on: vec![], + filter: None, + join_type: JoinType::Inner, + join_constraint: JoinConstraint::On, + null_equals_null: false, schema: DFSchemaRef::new(join_schema), }))) } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9bd57d22128d..10a99c9e78da 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -222,6 +222,7 @@ pub enum LogicalPlan { Join(Join), /// Apply Cross Join to two logical plans. /// This is used to implement SQL `CROSS JOIN` + /// Deprecated: use [LogicalPlan::Join] instead with empty `on` / no filter CrossJoin(CrossJoin), /// Repartitions the input based on a partitioning scheme. This is /// used to add parallelism and is sometimes referred to as an @@ -1873,6 +1874,11 @@ impl LogicalPlan { .as_ref() .map(|expr| format!(" Filter: {expr}")) .unwrap_or_else(|| "".to_string()); + let join_type = if filter.is_none() && keys.is_empty() && matches!(join_type, JoinType::Inner) { + "Cross".to_string() + } else { + join_type.to_string() + }; match join_constraint { JoinConstraint::On => { write!( diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 550728ddd3f9..bce5c77ca674 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -25,7 +25,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, Result}; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ - CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, + Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, }; use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; use datafusion_expr::{build_join_schema, ExprSchemable, Operator}; @@ -51,7 +51,7 @@ impl EliminateCrossJoin { /// Looks like this: /// ```text /// Filter(a.x = b.y AND b.xx = 100) -/// CrossJoin +/// Cross Join /// TableScan a /// TableScan b /// ``` @@ -351,10 +351,15 @@ fn find_inner_join( &JoinType::Inner, )?); - Ok(LogicalPlan::CrossJoin(CrossJoin { + Ok(LogicalPlan::Join(Join { left: Arc::new(left_input), right: Arc::new(right), schema: join_schema, + on: vec![], + filter: None, + join_type: JoinType::Inner, + join_constraint: JoinConstraint::On, + null_equals_null: false, })) } @@ -513,7 +518,7 @@ mod tests { let expected = vec![ "Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; @@ -601,7 +606,7 @@ mod tests { let expected = vec![ "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; @@ -627,7 +632,7 @@ mod tests { let expected = vec![ "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; @@ -843,7 +848,7 @@ mod tests { let expected = vec![ "Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", @@ -924,7 +929,7 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", " Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; @@ -999,7 +1004,7 @@ mod tests { "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", @@ -1238,7 +1243,7 @@ mod tests { let expected = vec![ "Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index f9b79e036f9b..789235595dab 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -23,7 +23,7 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::JoinType::Inner; use datafusion_expr::{ logical_plan::{EmptyRelation, LogicalPlan}, - CrossJoin, Expr, + Expr, }; /// Eliminates joins when join condition is false. @@ -54,13 +54,6 @@ impl OptimizerRule for EliminateJoin { match plan { LogicalPlan::Join(join) if join.join_type == Inner && join.on.is_empty() => { match join.filter { - Some(Expr::Literal(ScalarValue::Boolean(Some(true)))) => { - Ok(Transformed::yes(LogicalPlan::CrossJoin(CrossJoin { - left: join.left, - right: join.right, - schema: join.schema, - }))) - } Some(Expr::Literal(ScalarValue::Boolean(Some(false)))) => Ok( Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -105,21 +98,4 @@ mod tests { let expected = "EmptyRelation"; assert_optimized_plan_equal(plan, expected) } - - #[test] - fn join_on_true() -> Result<()> { - let plan = LogicalPlanBuilder::empty(false) - .join_on( - LogicalPlanBuilder::empty(false).build()?, - Inner, - Some(lit(true)), - )? - .build()?; - - let expected = "\ - CrossJoin:\ - \n EmptyRelation\ - \n EmptyRelation"; - assert_optimized_plan_equal(plan, expected) - } } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 6e2cc0cbdbcb..2e3bca5b0bbd 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1727,7 +1727,7 @@ mod tests { .build()?; let expected = "Projection: test.a, test1.d\ - \n CrossJoin:\ + \n Cross Join: \ \n Projection: test.a, test.b, test.c\ \n TableScan: test, full_filters=[test.a = Int32(1)]\ \n Projection: test1.d, test1.e, test1.f\ @@ -1754,7 +1754,7 @@ mod tests { .build()?; let expected = "Projection: test.a, test1.a\ - \n CrossJoin:\ + \n Cross Join: \ \n Projection: test.a, test.b, test.c\ \n TableScan: test, full_filters=[test.a = Int32(1)]\ \n Projection: test1.a, test1.b, test1.c\ diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 47fce64ae00e..6ed77387046e 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -254,10 +254,9 @@ fn push_down_join(mut join: Join, limit: usize) -> Transformed { let (left_limit, right_limit) = if is_no_join_condition(&join) { match join.join_type { - Left | Right | Full => (Some(limit), Some(limit)), + Left | Right | Full | Inner => (Some(limit), Some(limit)), LeftAnti | LeftSemi => (Some(limit), None), RightAnti | RightSemi => (None, Some(limit)), - Inner => (None, None), } } else { match join.join_type { @@ -1116,7 +1115,7 @@ mod test { .build()?; let expected = "Limit: skip=0, fetch=1000\ - \n CrossJoin:\ + \n Cross Join: \ \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000\ \n Limit: skip=0, fetch=1000\ @@ -1136,7 +1135,7 @@ mod test { .build()?; let expected = "Limit: skip=1000, fetch=1000\ - \n CrossJoin:\ + \n Cross Join: \ \n Limit: skip=0, fetch=2000\ \n TableScan: test, fetch=2000\ \n Limit: skip=0, fetch=2000\ diff --git a/datafusion/sql/src/relation/join.rs b/datafusion/sql/src/relation/join.rs index 409533a3eaa5..3f34608e3756 100644 --- a/datafusion/sql/src/relation/join.rs +++ b/datafusion/sql/src/relation/join.rs @@ -151,7 +151,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .build() } } - JoinConstraint::None => not_impl_err!("NONE constraint is not supported"), + JoinConstraint::None => LogicalPlanBuilder::from(left) + .join_on(right, join_type, [])? + .build(), } } } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 74abdf075f23..2a3c5b5f6b2b 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -243,7 +243,7 @@ fn roundtrip_crossjoin() -> Result<()> { .unwrap(); let expected = "Projection: j1.j1_id, j2.j2_string\ - \n Inner Join: Filter: Boolean(true)\ + \n Cross Join: \ \n TableScan: j1\ \n TableScan: j2"; diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 19f3d31321ce..edb614493b38 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -898,7 +898,7 @@ fn natural_right_join() { fn natural_join_no_common_becomes_cross_join() { let sql = "SELECT * FROM person a NATURAL JOIN lineitem b"; let expected = "Projection: *\ - \n CrossJoin:\ + \n Cross Join: \ \n SubqueryAlias: a\ \n TableScan: person\ \n SubqueryAlias: b\ @@ -2744,8 +2744,8 @@ fn cross_join_not_to_inner_join() { "select person.id from person, orders, lineitem where person.id = person.age;"; let expected = "Projection: person.id\ \n Filter: person.id = person.age\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: person\ \n TableScan: orders\ \n TableScan: lineitem"; @@ -2842,11 +2842,11 @@ fn exists_subquery_schema_outer_schema_overlap() { \n Subquery:\ \n Projection: person.first_name\ \n Filter: person.id = p2.id AND person.last_name = outer_ref(p.last_name) AND person.state = outer_ref(p.state)\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: person\ \n SubqueryAlias: p2\ \n TableScan: person\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: person\ \n SubqueryAlias: p\ \n TableScan: person"; @@ -2934,10 +2934,10 @@ fn scalar_subquery_reference_outer_field() { \n Projection: count(*)\ \n Aggregate: groupBy=[[]], aggr=[[count(*)]]\ \n Filter: outer_ref(j2.j2_id) = j1.j1_id AND j1.j1_id = j3.j3_id\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n TableScan: j3\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n TableScan: j2"; @@ -3123,7 +3123,7 @@ fn join_on_complex_condition() { fn lateral_constant() { let sql = "SELECT * FROM j1, LATERAL (SELECT 1) AS j2"; let expected = "Projection: *\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n SubqueryAlias: j2\ \n Subquery:\ @@ -3138,7 +3138,7 @@ fn lateral_comma_join() { j1, \ LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2"; let expected = "Projection: j1.j1_string, j2.j2_string\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n SubqueryAlias: j2\ \n Subquery:\ @@ -3154,7 +3154,7 @@ fn lateral_comma_join_referencing_join_rhs() { \n j1 JOIN (j2 JOIN j3 ON(j2_id = j3_id - 2)) ON(j1_id = j2_id),\ \n LATERAL (SELECT * FROM j3 WHERE j3_string = j2_string) as j4;"; let expected = "Projection: *\ - \n CrossJoin:\ + \n Cross Join: \ \n Inner Join: Filter: j1.j1_id = j2.j2_id\ \n TableScan: j1\ \n Inner Join: Filter: j2.j2_id = j3.j3_id - Int64(2)\ @@ -3178,12 +3178,12 @@ fn lateral_comma_join_with_shadowing() { ) as j2\ ) as j2;"; let expected = "Projection: *\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n SubqueryAlias: j2\ \n Subquery:\ \n Projection: *\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n SubqueryAlias: j2\ \n Subquery:\ @@ -3215,7 +3215,7 @@ fn lateral_nested_left_join() { j1, \ (j2 LEFT JOIN LATERAL (SELECT * FROM j3 WHERE j1_id + j2_id = j3_id) AS j3 ON(true))"; let expected = "Projection: *\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n Left Join: Filter: Boolean(true)\ \n TableScan: j2\ @@ -4281,7 +4281,7 @@ fn test_table_alias() { let expected = "Projection: *\ \n SubqueryAlias: f\ - \n CrossJoin:\ + \n Cross Join: \ \n SubqueryAlias: t1\ \n Projection: person.id\ \n TableScan: person\ @@ -4299,7 +4299,7 @@ fn test_table_alias() { let expected = "Projection: *\ \n SubqueryAlias: f\ \n Projection: t1.id AS c1, t2.age AS c2\ - \n CrossJoin:\ + \n Cross Join: \ \n SubqueryAlias: t1\ \n Projection: person.id\ \n TableScan: person\ diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index e9fcf07e7739..60569803322c 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -722,7 +722,7 @@ logical_plan 03)----Projection: Int64(1) AS val 04)------EmptyRelation 05)----Projection: Int64(2) AS val -06)------CrossJoin: +06)------Cross Join: 07)--------Filter: recursive_cte.val < Int64(2) 08)----------TableScan: recursive_cte 09)--------SubqueryAlias: sub_cte diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 8202b806a755..4f2778b5c0d1 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -4050,7 +4050,7 @@ EXPLAIN SELECT lhs.c, rhs.c, lhs.sum1, rhs.sum1 ---- logical_plan 01)Projection: lhs.c, rhs.c, lhs.sum1, rhs.sum1 -02)--CrossJoin: +02)--Cross Join: 03)----SubqueryAlias: lhs 04)------Projection: multiple_ordered_table_with_pk.c, sum(multiple_ordered_table_with_pk.d) AS sum1 05)--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[sum(CAST(multiple_ordered_table_with_pk.d AS Int64))]] diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 519fbb887c7e..fe9ceaa7907a 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -671,7 +671,7 @@ query TT explain select * from t1 inner join t2 on true; ---- logical_plan -01)CrossJoin: +01)Cross Join: 02)--TableScan: t1 projection=[t1_id, t1_name, t1_int] 03)--TableScan: t2 projection=[t2_id, t2_name, t2_int] physical_plan @@ -905,7 +905,7 @@ JOIN department AS d ON (e.name = 'Alice' OR e.name = 'Bob'); ---- logical_plan -01)CrossJoin: +01)Cross Join: 02)--SubqueryAlias: e 03)----Filter: employees.name = Utf8("Alice") OR employees.name = Utf8("Bob") 04)------TableScan: employees projection=[emp_id, name] diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index be9321ddb945..558a9170c7d3 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4050,7 +4050,7 @@ query TT explain select t1_id, t1_name, i from join_t1 t1 cross join lateral (select * from unnest(generate_series(1, t1_int))) as series(i); ---- logical_plan -01)CrossJoin: +01)Cross Join: 02)--SubqueryAlias: t1 03)----TableScan: join_t1 projection=[t1_id, t1_name] 04)--SubqueryAlias: series diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 0fef56aeea5c..9910ca8da71f 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -558,7 +558,7 @@ EXPLAIN SELECT * FROM ((SELECT column1 FROM foo) "T1" CROSS JOIN (SELECT column2 ---- logical_plan 01)SubqueryAlias: F -02)--CrossJoin: +02)--Cross Join: 03)----SubqueryAlias: T1 04)------TableScan: foo projection=[column1] 05)----SubqueryAlias: T2 diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt index 59133379d443..aaba6998ee63 100644 --- a/datafusion/sqllogictest/test_files/update.slt +++ b/datafusion/sqllogictest/test_files/update.slt @@ -67,7 +67,7 @@ logical_plan 01)Dml: op=[Update] table=[t1] 02)--Projection: t1.a AS a, t2.b AS b, CAST(t2.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d 03)----Filter: t1.a = t2.a AND t1.b > Utf8("foo") AND t2.c > Float64(1) -04)------CrossJoin: +04)------Cross Join: 05)--------TableScan: t1 06)--------TableScan: t2 @@ -86,7 +86,7 @@ logical_plan 01)Dml: op=[Update] table=[t1] 02)--Projection: t.a AS a, t2.b AS b, CAST(t.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d 03)----Filter: t.a = t2.a AND t.b > Utf8("foo") AND t2.c > Float64(1) -04)------CrossJoin: +04)------Cross Join: 05)--------SubqueryAlias: t 06)----------TableScan: t1 07)--------TableScan: t2 diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 08e54166d39a..5f1824bc4b30 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -780,7 +780,17 @@ pub async fn from_substrait_rel( )? .build() } - None => plan_err!("JoinRel without join condition is not allowed"), + None => { + let on: Vec = vec![]; + left.join_detailed( + right.build()?, + join_type, + (on.clone(), on), + None, + false, + )? + .build() + } } } Some(RelType::Cross(cross)) => { diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index fffa29df1db5..bc38ef82977f 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -73,17 +73,17 @@ mod tests { \n Aggregate: groupBy=[[]], aggr=[[min(PARTSUPP.PS_SUPPLYCOST)]]\ \n Projection: PARTSUPP.PS_SUPPLYCOST\ \n Filter: PARTSUPP.PS_PARTKEY = PARTSUPP.PS_PARTKEY AND SUPPLIER.S_SUPPKEY = PARTSUPP.PS_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8(\"EUROPE\")\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: PARTSUPP\ \n TableScan: SUPPLIER\ \n TableScan: NATION\ \n TableScan: REGION\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: PART\ \n TableScan: SUPPLIER\ \n TableScan: PARTSUPP\ @@ -105,8 +105,8 @@ mod tests { \n Aggregate: groupBy=[[LINEITEM.L_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]]\ \n Projection: LINEITEM.L_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ \n Filter: CUSTOMER.C_MKTSEGMENT = Utf8(\"BUILDING\") AND CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1995-03-15\") AS Date32) AND LINEITEM.L_SHIPDATE > CAST(Utf8(\"1995-03-15\") AS Date32)\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: LINEITEM\ \n TableScan: CUSTOMER\ \n TableScan: ORDERS" @@ -142,11 +142,11 @@ mod tests { \n Aggregate: groupBy=[[NATION.N_NAME]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]]\ \n Projection: NATION.N_NAME, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ \n Filter: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND LINEITEM.L_SUPPKEY = SUPPLIER.S_SUPPKEY AND CUSTOMER.C_NATIONKEY = SUPPLIER.S_NATIONKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8(\"ASIA\") AND ORDERS.O_ORDERDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1995-01-01\") AS Date32)\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: CUSTOMER\ \n TableScan: ORDERS\ \n TableScan: LINEITEM\ @@ -206,9 +206,9 @@ mod tests { \n Aggregate: groupBy=[[CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, CUSTOMER.C_ACCTBAL, CUSTOMER.C_PHONE, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_COMMENT]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]]\ \n Projection: CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, CUSTOMER.C_ACCTBAL, CUSTOMER.C_PHONE, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_COMMENT, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ \n Filter: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND ORDERS.O_ORDERDATE >= CAST(Utf8(\"1993-10-01\") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1994-01-01\") AS Date32) AND LINEITEM.L_RETURNFLAG = Utf8(\"R\") AND CUSTOMER.C_NATIONKEY = NATION.N_NATIONKEY\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: CUSTOMER\ \n TableScan: ORDERS\ \n TableScan: LINEITEM\ @@ -230,16 +230,16 @@ mod tests { \n Aggregate: groupBy=[[]], aggr=[[sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY)]]\ \n Projection: PARTSUPP.PS_SUPPLYCOST * CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0))\ \n Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"JAPAN\")\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: PARTSUPP\ \n TableScan: SUPPLIER\ \n TableScan: NATION\ \n Aggregate: groupBy=[[PARTSUPP.PS_PARTKEY]], aggr=[[sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY)]]\ \n Projection: PARTSUPP.PS_PARTKEY, PARTSUPP.PS_SUPPLYCOST * CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0))\ \n Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"JAPAN\")\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: PARTSUPP\ \n TableScan: SUPPLIER\ \n TableScan: NATION" @@ -257,7 +257,7 @@ mod tests { \n Aggregate: groupBy=[[LINEITEM.L_SHIPMODE]], aggr=[[sum(CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8(\"1-URGENT\") OR ORDERS.O_ORDERPRIORITY = Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END), sum(CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8(\"1-URGENT\") AND ORDERS.O_ORDERPRIORITY != Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END)]]\ \n Projection: LINEITEM.L_SHIPMODE, CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8(\"1-URGENT\") OR ORDERS.O_ORDERPRIORITY = Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END, CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8(\"1-URGENT\") AND ORDERS.O_ORDERPRIORITY != Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END\ \n Filter: ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"MAIL\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"SHIP\") AS Utf8)) AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE AND LINEITEM.L_SHIPDATE < LINEITEM.L_COMMITDATE AND LINEITEM.L_RECEIPTDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND LINEITEM.L_RECEIPTDATE < CAST(Utf8(\"1995-01-01\") AS Date32)\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: ORDERS\ \n TableScan: LINEITEM" ); @@ -292,7 +292,7 @@ mod tests { \n Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN PART.P_TYPE LIKE Utf8(\"PROMO%\") THEN LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT ELSE Decimal128(Some(0),19,4) END), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]]\ \n Projection: CASE WHEN PART.P_TYPE LIKE CAST(Utf8(\"PROMO%\") AS Utf8) THEN LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) ELSE Decimal128(Some(0),19,4) END, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ \n Filter: LINEITEM.L_PARTKEY = PART.P_PARTKEY AND LINEITEM.L_SHIPDATE >= Date32(\"1995-09-01\") AND LINEITEM.L_SHIPDATE < CAST(Utf8(\"1995-10-01\") AS Date32)\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: LINEITEM\ \n TableScan: PART" ); @@ -321,7 +321,7 @@ mod tests { \n Projection: SUPPLIER.S_SUPPKEY\ \n Filter: SUPPLIER.S_COMMENT LIKE CAST(Utf8(\"%Customer%Complaints%\") AS Utf8)\ \n TableScan: SUPPLIER\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: PARTSUPP\ \n TableScan: PART" ); @@ -353,8 +353,8 @@ mod tests { \n Aggregate: groupBy=[[LINEITEM.L_ORDERKEY]], aggr=[[sum(LINEITEM.L_QUANTITY)]]\ \n Projection: LINEITEM.L_ORDERKEY, LINEITEM.L_QUANTITY\ \n TableScan: LINEITEM\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: CUSTOMER\ \n TableScan: ORDERS\ \n TableScan: LINEITEM" @@ -369,7 +369,7 @@ mod tests { "Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS REVENUE]]\ \n Projection: LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ \n Filter: PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8(\"Brand#12\") AND (PART.P_CONTAINER = CAST(Utf8(\"SM CASE\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"SM BOX\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"SM PACK\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"SM PKG\") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(1) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(1) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(5) AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR REG\") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8(\"DELIVER IN PERSON\") OR PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8(\"Brand#23\") AND (PART.P_CONTAINER = CAST(Utf8(\"MED BAG\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"MED BOX\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"MED PKG\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"MED PACK\") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(10) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(10) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(10) AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR REG\") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8(\"DELIVER IN PERSON\") OR PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8(\"Brand#34\") AND (PART.P_CONTAINER = CAST(Utf8(\"LG CASE\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"LG BOX\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"LG PACK\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"LG PKG\") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(20) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(20) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(15) AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR REG\") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8(\"DELIVER IN PERSON\")\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: LINEITEM\ \n TableScan: PART" ); @@ -398,7 +398,7 @@ mod tests { \n Filter: LINEITEM.L_PARTKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_SUPPKEY = LINEITEM.L_PARTKEY AND LINEITEM.L_SHIPDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8(\"1995-01-01\") AS Date32)\ \n TableScan: LINEITEM\ \n TableScan: PARTSUPP\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: SUPPLIER\ \n TableScan: NATION" ); @@ -422,9 +422,9 @@ mod tests { \n Subquery:\ \n Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE\ \n TableScan: LINEITEM\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: SUPPLIER\ \n TableScan: LINEITEM\ \n TableScan: ORDERS\ From 12568bf1bd9b3a6cd1ea1b0632dfd5bdbc00bea1 Mon Sep 17 00:00:00 2001 From: Jonathan Chen <86070045+jonathanc-n@users.noreply.github.com> Date: Sat, 19 Oct 2024 07:42:16 -0400 Subject: [PATCH 12/17] fix spelling (#13014) --- datafusion/functions/src/regex/regexpmatch.rs | 2 +- docs/source/user-guide/sql/scalar_functions_new.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs index 4a86adbe683a..a458b205f4e3 100644 --- a/datafusion/functions/src/regex/regexpmatch.rs +++ b/datafusion/functions/src/regex/regexpmatch.rs @@ -119,7 +119,7 @@ fn get_regexp_match_doc() -> &'static Documentation { DOCUMENTATION.get_or_init(|| { Documentation::builder() .with_doc_section(DOC_SECTION_REGEX) - .with_description("Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matche in a string.") + .with_description("Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string.") .with_syntax_example("regexp_match(str, regexp[, flags])") .with_sql_example(r#"```sql > select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); diff --git a/docs/source/user-guide/sql/scalar_functions_new.md b/docs/source/user-guide/sql/scalar_functions_new.md index 1915623012f4..ac6e56a44c10 100644 --- a/docs/source/user-guide/sql/scalar_functions_new.md +++ b/docs/source/user-guide/sql/scalar_functions_new.md @@ -1752,7 +1752,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo ### `regexp_match` -Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matche in a string. +Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string. ``` regexp_match(str, regexp[, flags]) From 7a3414774cb7858d9649820ddffa59f5712a3153 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <33904309+akurmustafa@users.noreply.github.com> Date: Sat, 19 Oct 2024 04:49:14 -0700 Subject: [PATCH 13/17] replace take_array with arrow util (#13013) --- datafusion/common/src/utils/mod.rs | 57 +------------------ .../src/aggregate/groups_accumulator.rs | 7 +-- .../functions-aggregate/src/first_last.rs | 8 +-- .../physical-plan/src/repartition/mod.rs | 5 +- datafusion/physical-plan/src/sorts/sort.rs | 5 +- .../src/windows/bounded_window_agg_exec.rs | 8 ++- 6 files changed, 19 insertions(+), 71 deletions(-) diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 5bf0f08b092a..def1def9853c 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -26,8 +26,7 @@ use crate::error::{_internal_datafusion_err, _internal_err}; use crate::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; use arrow::array::{ArrayRef, PrimitiveArray}; use arrow::buffer::OffsetBuffer; -use arrow::compute; -use arrow::compute::{partition, SortColumn, SortOptions}; +use arrow::compute::{partition, take_arrays, SortColumn, SortOptions}; use arrow::datatypes::{Field, SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; use arrow_array::cast::AsArray; @@ -98,7 +97,7 @@ pub fn get_record_batch_at_indices( record_batch: &RecordBatch, indices: &PrimitiveArray, ) -> Result { - let new_columns = take_arrays(record_batch.columns(), indices)?; + let new_columns = take_arrays(record_batch.columns(), indices, None)?; RecordBatch::try_new_with_options( record_batch.schema(), new_columns, @@ -290,24 +289,6 @@ pub(crate) fn parse_identifiers(s: &str) -> Result> { Ok(idents) } -/// Construct a new [`Vec`] of [`ArrayRef`] from the rows of the `arrays` at the `indices`. -/// -/// TODO: use implementation in arrow-rs when available: -/// -pub fn take_arrays(arrays: &[ArrayRef], indices: &dyn Array) -> Result> { - arrays - .iter() - .map(|array| { - compute::take( - array.as_ref(), - indices, - None, // None: no index check - ) - .map_err(|e| arrow_datafusion_err!(e)) - }) - .collect() -} - pub(crate) fn parse_identifiers_normalized(s: &str, ignore_case: bool) -> Vec { parse_identifiers(s) .unwrap_or_default() @@ -1003,40 +984,6 @@ mod tests { Ok(()) } - #[test] - fn test_take_arrays() -> Result<()> { - let arrays: Vec = vec![ - Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 9., 10.])), - Arc::new(Float64Array::from(vec![2.0, 3.0, 3.0, 4.0, 5.0])), - Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 10., 11.0])), - Arc::new(Float64Array::from(vec![15.0, 13.0, 8.0, 5., 0.0])), - ]; - - let row_indices_vec: Vec> = vec![ - // Get rows 0 and 1 - vec![0, 1], - // Get rows 0 and 1 - vec![0, 2], - // Get rows 1 and 3 - vec![1, 3], - // Get rows 2 and 4 - vec![2, 4], - ]; - for row_indices in row_indices_vec { - let indices: PrimitiveArray = - PrimitiveArray::from_iter_values(row_indices.iter().cloned()); - let chunk = take_arrays(&arrays, &indices)?; - for (arr_orig, arr_chunk) in arrays.iter().zip(&chunk) { - for (idx, orig_idx) in row_indices.iter().enumerate() { - let res1 = ScalarValue::try_from_array(arr_orig, *orig_idx as usize)?; - let res2 = ScalarValue::try_from_array(arr_chunk, idx)?; - assert_eq!(res1, res2); - } - } - } - Ok(()) - } - #[test] fn test_get_at_indices() -> Result<()> { let in_vec = vec![1, 2, 3, 4, 5, 6, 7]; diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index b03df0224089..c936c80cbed7 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -27,11 +27,10 @@ use arrow::array::new_empty_array; use arrow::{ array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray}, compute, + compute::take_arrays, datatypes::UInt32Type, }; -use datafusion_common::{ - arrow_datafusion_err, utils::take_arrays, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; use datafusion_expr_common::accumulator::Accumulator; use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; @@ -239,7 +238,7 @@ impl GroupsAccumulatorAdapter { // reorder the values and opt_filter by batch_indices so that // all values for each group are contiguous, then invoke the // accumulator once per group with values - let values = take_arrays(values, &batch_indices)?; + let values = take_arrays(values, &batch_indices, None)?; let opt_filter = get_filter_at_indices(opt_filter, &batch_indices)?; // invoke each accumulator with the appropriate rows, first diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index f6a84c84dcb0..2a3fc623657a 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -22,9 +22,9 @@ use std::fmt::Debug; use std::sync::{Arc, OnceLock}; use arrow::array::{ArrayRef, AsArray, BooleanArray}; -use arrow::compute::{self, lexsort_to_indices, SortColumn}; +use arrow::compute::{self, lexsort_to_indices, take_arrays, SortColumn}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::utils::{compare_rows, get_row_at_idx, take_arrays}; +use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; @@ -340,7 +340,7 @@ impl Accumulator for FirstValueAccumulator { filtered_states } else { let indices = lexsort_to_indices(&sort_cols, None)?; - take_arrays(&filtered_states, &indices)? + take_arrays(&filtered_states, &indices, None)? }; if !ordered_states[0].is_empty() { let first_row = get_row_at_idx(&ordered_states, 0)?; @@ -670,7 +670,7 @@ impl Accumulator for LastValueAccumulator { filtered_states } else { let indices = lexsort_to_indices(&sort_cols, None)?; - take_arrays(&filtered_states, &indices)? + take_arrays(&filtered_states, &indices, None)? }; if !ordered_states[0].is_empty() { diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 902d9f4477bc..90e62d6f11f8 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -38,10 +38,11 @@ use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::stream::RecordBatchStreamAdapter; use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics}; +use arrow::compute::take_arrays; use arrow::datatypes::{SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; use arrow_array::{PrimitiveArray, RecordBatchOptions}; -use datafusion_common::utils::{take_arrays, transpose}; +use datafusion_common::utils::transpose; use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::memory_pool::MemoryConsumer; @@ -300,7 +301,7 @@ impl BatchPartitioner { let _timer = partitioner_timer.timer(); // Produce batches based on indices - let columns = take_arrays(batch.columns(), &indices)?; + let columns = take_arrays(batch.columns(), &indices, None)?; let mut options = RecordBatchOptions::new(); options = options.with_row_count(Some(indices.len())); diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 5d86c2183b9e..8e13a2e07e49 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -40,13 +40,12 @@ use crate::{ SendableRecordBatchStream, Statistics, }; -use arrow::compute::{concat_batches, lexsort_to_indices, SortColumn}; +use arrow::compute::{concat_batches, lexsort_to_indices, take_arrays, SortColumn}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, SortField}; use arrow_array::{Array, RecordBatchOptions, UInt32Array}; use arrow_schema::DataType; -use datafusion_common::utils::take_arrays; use datafusion_common::{internal_err, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; @@ -618,7 +617,7 @@ pub fn sort_batch( lexsort_to_indices(&sort_columns, fetch)? }; - let columns = take_arrays(batch.columns(), &indices)?; + let columns = take_arrays(batch.columns(), &indices, None)?; let options = RecordBatchOptions::new().with_row_count(Some(indices.len())); Ok(RecordBatch::try_new_with_options( diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 4a4c940b22e2..6254ae139a00 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -42,7 +42,7 @@ use crate::{ use ahash::RandomState; use arrow::{ array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder}, - compute::{concat, concat_batches, sort_to_indices}, + compute::{concat, concat_batches, sort_to_indices, take_arrays}, datatypes::SchemaRef, record_batch::RecordBatch, }; @@ -50,7 +50,7 @@ use datafusion_common::hash_utils::create_hashes; use datafusion_common::stats::Precision; use datafusion_common::utils::{ evaluate_partition_ranges, get_at_indices, get_record_batch_at_indices, - get_row_at_idx, take_arrays, + get_row_at_idx, }; use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; use datafusion_execution::TaskContext; @@ -536,7 +536,9 @@ impl PartitionSearcher for LinearSearch { // We should emit columns according to row index ordering. let sorted_indices = sort_to_indices(&all_indices, None, None)?; // Construct new column according to row ordering. This fixes ordering - take_arrays(&new_columns, &sorted_indices).map(Some) + take_arrays(&new_columns, &sorted_indices, None) + .map(Some) + .map_err(|e| arrow_datafusion_err!(e)) } fn evaluate_partition_batches( From c7e5d8db453cf1b9d98aae520563a5ea67cdca4c Mon Sep 17 00:00:00 2001 From: Duong Cong Toai <35887761+duongcongtoai@users.noreply.github.com> Date: Sun, 20 Oct 2024 12:53:15 +0200 Subject: [PATCH 14/17] Improve recursive `unnest` options API (#12836) * refactor * refactor unnest options * more test * resolve comments * add back doc * fix proto * flaky test * clippy * use indexmap * chore: compile err * chore: update cargo * chore: fmt cargotoml --------- Co-authored-by: Andrew Lamb --- datafusion-cli/Cargo.lock | 1 + datafusion/common/src/lib.rs | 2 +- datafusion/common/src/unnest.rs | 26 ++ datafusion/expr/src/logical_plan/builder.rs | 186 +++++------- datafusion/expr/src/logical_plan/mod.rs | 4 +- datafusion/expr/src/logical_plan/plan.rs | 37 +-- datafusion/expr/src/logical_plan/tree_node.rs | 2 +- datafusion/physical-plan/src/unnest.rs | 70 ++--- datafusion/proto/proto/datafusion.proto | 18 +- datafusion/proto/src/generated/pbjson.rs | 285 +++++++++--------- datafusion/proto/src/generated/prost.rs | 32 +- .../proto/src/logical_plan/from_proto.rs | 13 +- datafusion/proto/src/logical_plan/mod.rs | 63 +--- datafusion/proto/src/logical_plan/to_proto.rs | 10 + datafusion/sql/Cargo.toml | 1 + datafusion/sql/src/select.rs | 57 +++- datafusion/sql/src/utils.rs | 226 ++++++++------ 17 files changed, 504 insertions(+), 529 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index dfd07a7658ff..08d5d4843c62 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1571,6 +1571,7 @@ dependencies = [ "arrow-schema", "datafusion-common", "datafusion-expr", + "indexmap", "log", "regex", "sqlparser", diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 10541e01914a..8323f5efc86d 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -70,7 +70,7 @@ pub use scalar::{ScalarType, ScalarValue}; pub use schema_reference::SchemaReference; pub use stats::{ColumnStatistics, Statistics}; pub use table_reference::{ResolvedTableReference, TableReference}; -pub use unnest::UnnestOptions; +pub use unnest::{RecursionUnnestOption, UnnestOptions}; pub use utils::project_schema; // These are hidden from docs purely to avoid polluting the public view of what this crate exports. diff --git a/datafusion/common/src/unnest.rs b/datafusion/common/src/unnest.rs index fd92267f9b4c..db48edd06160 100644 --- a/datafusion/common/src/unnest.rs +++ b/datafusion/common/src/unnest.rs @@ -17,6 +17,8 @@ //! [`UnnestOptions`] for unnesting structured types +use crate::Column; + /// Options for unnesting a column that contains a list type, /// replicating values in the other, non nested rows. /// @@ -60,10 +62,27 @@ /// └─────────┘ └─────┘ └─────────┘ └─────┘ /// c1 c2 c1 c2 /// ``` +/// +/// `recursions` instruct how a column should be unnested (e.g unnesting a column multiple +/// time, with depth = 1 and depth = 2). Any unnested column not being mentioned inside this +/// options is inferred to be unnested with depth = 1 #[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq)] pub struct UnnestOptions { /// Should nulls in the input be preserved? Defaults to true pub preserve_nulls: bool, + /// If specific columns need to be unnested multiple times (e.g at different depth), + /// declare them here. Any unnested columns not being mentioned inside this option + /// will be unnested with depth = 1 + pub recursions: Vec, +} + +/// Instruction on how to unnest a column (mostly with a list type) +/// such as how to name the output, and how many level it should be unnested +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)] +pub struct RecursionUnnestOption { + pub input_column: Column, + pub output_column: Column, + pub depth: usize, } impl Default for UnnestOptions { @@ -71,6 +90,7 @@ impl Default for UnnestOptions { Self { // default to true to maintain backwards compatible behavior preserve_nulls: true, + recursions: vec![], } } } @@ -87,4 +107,10 @@ impl UnnestOptions { self.preserve_nulls = preserve_nulls; self } + + /// Set the recursions for the unnest operation + pub fn with_recursions(mut self, recursion: RecursionUnnestOption) -> Self { + self.recursions.push(recursion); + self + } } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 6ab50440ec5b..f119a2ade827 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -44,6 +44,8 @@ use crate::{ TableProviderFilterPushDown, TableSource, WriteOp, }; +use super::dml::InsertOp; +use super::plan::ColumnUnnestList; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::file_options::file_type::FileType; @@ -54,9 +56,6 @@ use datafusion_common::{ }; use datafusion_expr_common::type_coercion::binary::type_union_resolution; -use super::dml::InsertOp; -use super::plan::{ColumnUnnestList, ColumnUnnestType}; - /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; @@ -1186,7 +1185,7 @@ impl LogicalPlanBuilder { ) -> Result { unnest_with_options( Arc::unwrap_or_clone(self.plan), - vec![(column.into(), ColumnUnnestType::Inferred)], + vec![column.into()], options, ) .map(Self::new) @@ -1197,26 +1196,6 @@ impl LogicalPlanBuilder { self, columns: Vec, options: UnnestOptions, - ) -> Result { - unnest_with_options( - Arc::unwrap_or_clone(self.plan), - columns - .into_iter() - .map(|c| (c, ColumnUnnestType::Inferred)) - .collect(), - options, - ) - .map(Self::new) - } - - /// Unnest the given columns with the given [`UnnestOptions`] - /// if one column is a list type, it can be recursively and simultaneously - /// unnested into the desired recursion levels - /// e.g select unnest(list_col,depth=1), unnest(list_col,depth=2) - pub fn unnest_columns_recursive_with_options( - self, - columns: Vec<(Column, ColumnUnnestType)>, - options: UnnestOptions, ) -> Result { unnest_with_options(Arc::unwrap_or_clone(self.plan), columns, options) .map(Self::new) @@ -1594,14 +1573,12 @@ impl TableSource for LogicalTableSource { /// Create a [`LogicalPlan::Unnest`] plan pub fn unnest(input: LogicalPlan, columns: Vec) -> Result { - let unnestings = columns - .into_iter() - .map(|c| (c, ColumnUnnestType::Inferred)) - .collect(); - unnest_with_options(input, unnestings, UnnestOptions::default()) + unnest_with_options(input, columns, UnnestOptions::default()) } -pub fn get_unnested_list_datatype_recursive( +// Get the data type of a multi-dimensional type after unnesting it +// with a given depth +fn get_unnested_list_datatype_recursive( data_type: &DataType, depth: usize, ) -> Result { @@ -1620,27 +1597,6 @@ pub fn get_unnested_list_datatype_recursive( internal_err!("trying to unnest on invalid data type {:?}", data_type) } -/// Infer the unnest type based on the data type: -/// - list type: infer to unnest(list(col, depth=1)) -/// - struct type: infer to unnest(struct) -fn infer_unnest_type( - col_name: &String, - data_type: &DataType, -) -> Result { - match data_type { - DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => { - Ok(ColumnUnnestType::List(vec![ColumnUnnestList { - output_column: Column::from_name(col_name), - depth: 1, - }])) - } - DataType::Struct(_) => Ok(ColumnUnnestType::Struct), - _ => { - internal_err!("trying to unnest on invalid data type {:?}", data_type) - } - } -} - pub fn get_struct_unnested_columns( col_name: &String, inner_fields: &Fields, @@ -1729,20 +1685,15 @@ pub fn get_unnested_columns( /// ``` pub fn unnest_with_options( input: LogicalPlan, - columns_to_unnest: Vec<(Column, ColumnUnnestType)>, + columns_to_unnest: Vec, options: UnnestOptions, ) -> Result { let mut list_columns: Vec<(usize, ColumnUnnestList)> = vec![]; let mut struct_columns = vec![]; let indices_to_unnest = columns_to_unnest .iter() - .map(|col_unnesting| { - Ok(( - input.schema().index_of_column(&col_unnesting.0)?, - col_unnesting, - )) - }) - .collect::>>()?; + .map(|c| Ok((input.schema().index_of_column(c)?, c))) + .collect::>>()?; let input_schema = input.schema(); @@ -1767,51 +1718,59 @@ pub fn unnest_with_options( .enumerate() .map(|(index, (original_qualifier, original_field))| { match indices_to_unnest.get(&index) { - Some((column_to_unnest, unnest_type)) => { - let mut inferred_unnest_type = unnest_type.clone(); - if let ColumnUnnestType::Inferred = unnest_type { - inferred_unnest_type = infer_unnest_type( + Some(column_to_unnest) => { + let recursions_on_column = options + .recursions + .iter() + .filter(|p| -> bool { &p.input_column == *column_to_unnest }) + .collect::>(); + let mut transformed_columns = recursions_on_column + .iter() + .map(|r| { + list_columns.push(( + index, + ColumnUnnestList { + output_column: r.output_column.clone(), + depth: r.depth, + }, + )); + Ok(get_unnested_columns( + &r.output_column.name, + original_field.data_type(), + r.depth, + )? + .into_iter() + .next() + .unwrap()) // because unnesting a list column always result into one result + }) + .collect::)>>>()?; + if transformed_columns.is_empty() { + transformed_columns = get_unnested_columns( &column_to_unnest.name, original_field.data_type(), + 1, )?; - } - let transformed_columns: Vec<(Column, Arc)> = - match inferred_unnest_type { - ColumnUnnestType::Struct => { + match original_field.data_type() { + DataType::Struct(_) => { struct_columns.push(index); - get_unnested_columns( - &column_to_unnest.name, - original_field.data_type(), - 1, - )? } - ColumnUnnestType::List(unnest_lists) => { - list_columns.extend( - unnest_lists - .iter() - .map(|ul| (index, ul.to_owned().clone())), - ); - unnest_lists - .iter() - .map( - |ColumnUnnestList { - output_column, - depth, - }| { - get_unnested_columns( - &output_column.name, - original_field.data_type(), - *depth, - ) - }, - ) - .collect::)>>>>()? - .into_iter() - .flatten() - .collect::>() + DataType::List(_) + | DataType::FixedSizeList(_, _) + | DataType::LargeList(_) => { + list_columns.push(( + index, + ColumnUnnestList { + output_column: Column::from_name( + &column_to_unnest.name, + ), + depth: 1, + }, + )); } - _ => return internal_err!("Invalid unnest type"), + _ => {} }; + } + // new columns dependent on the same original index dependency_indices .extend(std::iter::repeat(index).take(transformed_columns.len())); @@ -1860,7 +1819,7 @@ mod tests { use crate::logical_plan::StringifiedPlan; use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery}; - use datafusion_common::SchemaError; + use datafusion_common::{RecursionUnnestOption, SchemaError}; #[test] fn plan_builder_simple() -> Result<()> { @@ -2268,24 +2227,19 @@ mod tests { // Simultaneously unnesting a list (with different depth) and a struct column let plan = nested_table_scan("test_table")? - .unnest_columns_recursive_with_options( - vec![ - ( - "stringss".into(), - ColumnUnnestType::List(vec![ - ColumnUnnestList { - output_column: Column::from_name("stringss_depth_1"), - depth: 1, - }, - ColumnUnnestList { - output_column: Column::from_name("stringss_depth_2"), - depth: 2, - }, - ]), - ), - ("struct_singular".into(), ColumnUnnestType::Inferred), - ], - UnnestOptions::default(), + .unnest_columns_with_options( + vec!["stringss".into(), "struct_singular".into()], + UnnestOptions::default() + .with_recursions(RecursionUnnestOption { + input_column: "stringss".into(), + output_column: "stringss_depth_1".into(), + depth: 1, + }) + .with_recursions(RecursionUnnestOption { + input_column: "stringss".into(), + output_column: "stringss_depth_2".into(), + depth: 2, + }), )? .build()?; diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index a189d4635e00..da44cfb010d7 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -35,8 +35,8 @@ pub use ddl::{ }; pub use dml::{DmlStatement, WriteOp}; pub use plan::{ - projection_schema, Aggregate, Analyze, ColumnUnnestList, ColumnUnnestType, CrossJoin, - DescribeTable, Distinct, DistinctOn, EmptyRelation, Explain, Extension, Filter, Join, + projection_schema, Aggregate, Analyze, ColumnUnnestList, CrossJoin, DescribeTable, + Distinct, DistinctOn, EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, RecursiveQuery, Repartition, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 10a99c9e78da..72d8f7158be2 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -3367,39 +3367,6 @@ pub enum Partitioning { DistributeBy(Vec), } -/// Represents the unnesting operation on a column based on the context (a known struct -/// column, a list column, or let the planner infer the unnesting type). -/// -/// The inferred unnesting type works for both struct and list column, but the unnesting -/// will only be done once (depth = 1). In case recursion is needed on a multi-dimensional -/// list type, use [`ColumnUnnestList`] -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)] -pub enum ColumnUnnestType { - // Unnesting a list column, a vector of ColumnUnnestList is used because - // a column can be unnested at different levels, resulting different output columns - List(Vec), - // for struct, there can only be one unnest performed on one column at a time - Struct, - // Infer the unnest type based on column schema - // If column is a list column, the unnest depth will be 1 - // This value is to support sugar syntax of old api in Dataframe (unnest(either_list_or_struct_column)) - Inferred, -} - -impl fmt::Display for ColumnUnnestType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - ColumnUnnestType::List(lists) => { - let list_strs: Vec = - lists.iter().map(|list| list.to_string()).collect(); - write!(f, "List([{}])", list_strs.join(", ")) - } - ColumnUnnestType::Struct => write!(f, "Struct"), - ColumnUnnestType::Inferred => write!(f, "Inferred"), - } - } -} - /// Represent the unnesting operation on a list column, such as the recursion depth and /// the output column name after unnesting /// @@ -3438,7 +3405,7 @@ pub struct Unnest { /// The incoming logical plan pub input: Arc, /// Columns to run unnest on, can be a list of (List/Struct) columns - pub exec_columns: Vec<(Column, ColumnUnnestType)>, + pub exec_columns: Vec, /// refer to the indices(in the input schema) of columns /// that have type list to run unnest on pub list_type_columns: Vec<(usize, ColumnUnnestList)>, @@ -3462,7 +3429,7 @@ impl PartialOrd for Unnest { /// The incoming logical plan pub input: &'a Arc, /// Columns to run unnest on, can be a list of (List/Struct) columns - pub exec_columns: &'a Vec<(Column, ColumnUnnestType)>, + pub exec_columns: &'a Vec, /// refer to the indices(in the input schema) of columns /// that have type list to run unnest on pub list_type_columns: &'a Vec<(usize, ColumnUnnestList)>, diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 83206a2b2af5..606868e75abf 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -501,7 +501,7 @@ impl LogicalPlan { let exprs = columns .iter() - .map(|(c, _)| Expr::Column(c.clone())) + .map(|c| Expr::Column(c.clone())) .collect::>(); exprs.iter().apply_until_stop(f) } diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index 50af6b4960a5..2311541816f3 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -905,12 +905,10 @@ fn repeat_arrs_from_indices( #[cfg(test)] mod tests { use super::*; - use arrow::{ - datatypes::{Field, Int32Type}, - util::pretty::pretty_format_batches, - }; + use arrow::datatypes::{Field, Int32Type}; use arrow_array::{GenericListArray, OffsetSizeTrait, StringArray}; use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; + use datafusion_common::assert_batches_eq; // Create a GenericListArray with the following list values: // [A, B, C], [], NULL, [D], NULL, [NULL, F] @@ -1092,38 +1090,37 @@ mod tests { &HashSet::default(), &UnnestOptions { preserve_nulls: true, + recursions: vec![], }, )?; - let actual = - format!("{}", pretty_format_batches(vec![ret].as_ref())?).to_lowercase(); - let expected = r#" -+---------------------------------+---------------------------------+---------------------------------+ -| col1_unnest_placeholder_depth_1 | col1_unnest_placeholder_depth_2 | col2_unnest_placeholder_depth_1 | -+---------------------------------+---------------------------------+---------------------------------+ -| [1, 2, 3] | 1 | a | -| | 2 | b | -| [4, 5] | 3 | | -| [1, 2, 3] | | a | -| | | b | -| [4, 5] | | | -| [1, 2, 3] | 4 | a | -| | 5 | b | -| [4, 5] | | | -| [7, 8, 9, 10] | 7 | c | -| | 8 | d | -| [11, 12, 13] | 9 | | -| | 10 | | -| [7, 8, 9, 10] | | c | -| | | d | -| [11, 12, 13] | | | -| [7, 8, 9, 10] | 11 | c | -| | 12 | d | -| [11, 12, 13] | 13 | | -| | | e | -+---------------------------------+---------------------------------+---------------------------------+ - "# - .trim(); - assert_eq!(actual, expected); + + let expected = &[ +"+---------------------------------+---------------------------------+---------------------------------+", +"| col1_unnest_placeholder_depth_1 | col1_unnest_placeholder_depth_2 | col2_unnest_placeholder_depth_1 |", +"+---------------------------------+---------------------------------+---------------------------------+", +"| [1, 2, 3] | 1 | a |", +"| | 2 | b |", +"| [4, 5] | 3 | |", +"| [1, 2, 3] | | a |", +"| | | b |", +"| [4, 5] | | |", +"| [1, 2, 3] | 4 | a |", +"| | 5 | b |", +"| [4, 5] | | |", +"| [7, 8, 9, 10] | 7 | c |", +"| | 8 | d |", +"| [11, 12, 13] | 9 | |", +"| | 10 | |", +"| [7, 8, 9, 10] | | c |", +"| | | d |", +"| [11, 12, 13] | | |", +"| [7, 8, 9, 10] | 11 | c |", +"| | 12 | d |", +"| [11, 12, 13] | 13 | |", +"| | | e |", +"+---------------------------------+---------------------------------+---------------------------------+", + ]; + assert_batches_eq!(expected, &[ret]); Ok(()) } @@ -1177,7 +1174,10 @@ mod tests { preserve_nulls: bool, expected: Vec, ) -> datafusion_common::Result<()> { - let options = UnnestOptions { preserve_nulls }; + let options = UnnestOptions { + preserve_nulls, + recursions: vec![], + }; let longest_length = find_longest_length(list_arrays, &options)?; let expected_array = Int64Array::from(expected); assert_eq!( diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 9964ab498fb1..a15fa2c5f9c6 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -264,7 +264,7 @@ message CopyToNode { message UnnestNode { LogicalPlanNode input = 1; - repeated ColumnUnnestExec exec_columns = 2; + repeated datafusion_common.Column exec_columns = 2; repeated ColumnUnnestListItem list_type_columns = 3; repeated uint64 struct_type_columns = 4; repeated uint64 dependency_indices = 5; @@ -285,17 +285,15 @@ message ColumnUnnestListRecursion { uint32 depth = 2; } -message ColumnUnnestExec { - datafusion_common.Column column = 1; - oneof UnnestType { - ColumnUnnestListRecursions list = 2; - datafusion_common.EmptyMessage struct = 3; - datafusion_common.EmptyMessage inferred = 4; - } -} - message UnnestOptions { bool preserve_nulls = 1; + repeated RecursionUnnestOption recursions = 2; +} + +message RecursionUnnestOption { + datafusion_common.Column output_column = 1; + datafusion_common.Column input_column = 2; + uint32 depth = 3; } message UnionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 4417d1149681..d223e3646b51 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -2306,145 +2306,6 @@ impl<'de> serde::Deserialize<'de> for ColumnIndex { deserializer.deserialize_struct("datafusion.ColumnIndex", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ColumnUnnestExec { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.column.is_some() { - len += 1; - } - if self.unnest_type.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ColumnUnnestExec", len)?; - if let Some(v) = self.column.as_ref() { - struct_ser.serialize_field("column", v)?; - } - if let Some(v) = self.unnest_type.as_ref() { - match v { - column_unnest_exec::UnnestType::List(v) => { - struct_ser.serialize_field("list", v)?; - } - column_unnest_exec::UnnestType::Struct(v) => { - struct_ser.serialize_field("struct", v)?; - } - column_unnest_exec::UnnestType::Inferred(v) => { - struct_ser.serialize_field("inferred", v)?; - } - } - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ColumnUnnestExec { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "column", - "list", - "struct", - "inferred", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Column, - List, - Struct, - Inferred, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "column" => Ok(GeneratedField::Column), - "list" => Ok(GeneratedField::List), - "struct" => Ok(GeneratedField::Struct), - "inferred" => Ok(GeneratedField::Inferred), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ColumnUnnestExec; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ColumnUnnestExec") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut column__ = None; - let mut unnest_type__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Column => { - if column__.is_some() { - return Err(serde::de::Error::duplicate_field("column")); - } - column__ = map_.next_value()?; - } - GeneratedField::List => { - if unnest_type__.is_some() { - return Err(serde::de::Error::duplicate_field("list")); - } - unnest_type__ = map_.next_value::<::std::option::Option<_>>()?.map(column_unnest_exec::UnnestType::List) -; - } - GeneratedField::Struct => { - if unnest_type__.is_some() { - return Err(serde::de::Error::duplicate_field("struct")); - } - unnest_type__ = map_.next_value::<::std::option::Option<_>>()?.map(column_unnest_exec::UnnestType::Struct) -; - } - GeneratedField::Inferred => { - if unnest_type__.is_some() { - return Err(serde::de::Error::duplicate_field("inferred")); - } - unnest_type__ = map_.next_value::<::std::option::Option<_>>()?.map(column_unnest_exec::UnnestType::Inferred) -; - } - } - } - Ok(ColumnUnnestExec { - column: column__, - unnest_type: unnest_type__, - }) - } - } - deserializer.deserialize_struct("datafusion.ColumnUnnestExec", FIELDS, GeneratedVisitor) - } -} impl serde::Serialize for ColumnUnnestListItem { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -17489,6 +17350,135 @@ impl<'de> serde::Deserialize<'de> for ProjectionNode { deserializer.deserialize_struct("datafusion.ProjectionNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for RecursionUnnestOption { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.output_column.is_some() { + len += 1; + } + if self.input_column.is_some() { + len += 1; + } + if self.depth != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.RecursionUnnestOption", len)?; + if let Some(v) = self.output_column.as_ref() { + struct_ser.serialize_field("outputColumn", v)?; + } + if let Some(v) = self.input_column.as_ref() { + struct_ser.serialize_field("inputColumn", v)?; + } + if self.depth != 0 { + struct_ser.serialize_field("depth", &self.depth)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for RecursionUnnestOption { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "output_column", + "outputColumn", + "input_column", + "inputColumn", + "depth", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + OutputColumn, + InputColumn, + Depth, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "outputColumn" | "output_column" => Ok(GeneratedField::OutputColumn), + "inputColumn" | "input_column" => Ok(GeneratedField::InputColumn), + "depth" => Ok(GeneratedField::Depth), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = RecursionUnnestOption; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.RecursionUnnestOption") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut output_column__ = None; + let mut input_column__ = None; + let mut depth__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::OutputColumn => { + if output_column__.is_some() { + return Err(serde::de::Error::duplicate_field("outputColumn")); + } + output_column__ = map_.next_value()?; + } + GeneratedField::InputColumn => { + if input_column__.is_some() { + return Err(serde::de::Error::duplicate_field("inputColumn")); + } + input_column__ = map_.next_value()?; + } + GeneratedField::Depth => { + if depth__.is_some() { + return Err(serde::de::Error::duplicate_field("depth")); + } + depth__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(RecursionUnnestOption { + output_column: output_column__, + input_column: input_column__, + depth: depth__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.RecursionUnnestOption", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for RepartitionExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -20411,10 +20401,16 @@ impl serde::Serialize for UnnestOptions { if self.preserve_nulls { len += 1; } + if !self.recursions.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.UnnestOptions", len)?; if self.preserve_nulls { struct_ser.serialize_field("preserveNulls", &self.preserve_nulls)?; } + if !self.recursions.is_empty() { + struct_ser.serialize_field("recursions", &self.recursions)?; + } struct_ser.end() } } @@ -20427,11 +20423,13 @@ impl<'de> serde::Deserialize<'de> for UnnestOptions { const FIELDS: &[&str] = &[ "preserve_nulls", "preserveNulls", + "recursions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { PreserveNulls, + Recursions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -20454,6 +20452,7 @@ impl<'de> serde::Deserialize<'de> for UnnestOptions { { match value { "preserveNulls" | "preserve_nulls" => Ok(GeneratedField::PreserveNulls), + "recursions" => Ok(GeneratedField::Recursions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -20474,6 +20473,7 @@ impl<'de> serde::Deserialize<'de> for UnnestOptions { V: serde::de::MapAccess<'de>, { let mut preserve_nulls__ = None; + let mut recursions__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::PreserveNulls => { @@ -20482,10 +20482,17 @@ impl<'de> serde::Deserialize<'de> for UnnestOptions { } preserve_nulls__ = Some(map_.next_value()?); } + GeneratedField::Recursions => { + if recursions__.is_some() { + return Err(serde::de::Error::duplicate_field("recursions")); + } + recursions__ = Some(map_.next_value()?); + } } } Ok(UnnestOptions { preserve_nulls: preserve_nulls__.unwrap_or_default(), + recursions: recursions__.unwrap_or_default(), }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index d3fe031a48c9..6b234be57a92 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -400,7 +400,7 @@ pub struct UnnestNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "2")] - pub exec_columns: ::prost::alloc::vec::Vec, + pub exec_columns: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "3")] pub list_type_columns: ::prost::alloc::vec::Vec, #[prost(uint64, repeated, tag = "4")] @@ -432,28 +432,20 @@ pub struct ColumnUnnestListRecursion { pub depth: u32, } #[derive(Clone, PartialEq, ::prost::Message)] -pub struct ColumnUnnestExec { - #[prost(message, optional, tag = "1")] - pub column: ::core::option::Option, - #[prost(oneof = "column_unnest_exec::UnnestType", tags = "2, 3, 4")] - pub unnest_type: ::core::option::Option, -} -/// Nested message and enum types in `ColumnUnnestExec`. -pub mod column_unnest_exec { - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum UnnestType { - #[prost(message, tag = "2")] - List(super::ColumnUnnestListRecursions), - #[prost(message, tag = "3")] - Struct(super::super::datafusion_common::EmptyMessage), - #[prost(message, tag = "4")] - Inferred(super::super::datafusion_common::EmptyMessage), - } -} -#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct UnnestOptions { #[prost(bool, tag = "1")] pub preserve_nulls: bool, + #[prost(message, repeated, tag = "2")] + pub recursions: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct RecursionUnnestOption { + #[prost(message, optional, tag = "1")] + pub output_column: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub input_column: ::core::option::Option, + #[prost(uint32, tag = "3")] + pub depth: u32, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionNode { diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 20d007048a00..99b11939e95b 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -19,8 +19,8 @@ use std::sync::Arc; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ - exec_datafusion_err, internal_err, plan_datafusion_err, Result, ScalarValue, - TableReference, UnnestOptions, + exec_datafusion_err, internal_err, plan_datafusion_err, RecursionUnnestOption, + Result, ScalarValue, TableReference, UnnestOptions, }; use datafusion_expr::expr::{Alias, Placeholder, Sort}; use datafusion_expr::expr::{Unnest, WildcardOptions}; @@ -56,6 +56,15 @@ impl From<&protobuf::UnnestOptions> for UnnestOptions { fn from(opts: &protobuf::UnnestOptions) -> Self { Self { preserve_nulls: opts.preserve_nulls, + recursions: opts + .recursions + .iter() + .map(|r| RecursionUnnestOption { + input_column: r.input_column.as_ref().unwrap().into(), + output_column: r.output_column.as_ref().unwrap().into(), + depth: r.depth as usize, + }) + .collect::>(), } } } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 6061a7a0619a..f57910b09ade 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -19,11 +19,10 @@ use std::collections::HashMap; use std::fmt::Debug; use std::sync::Arc; -use crate::protobuf::column_unnest_exec::UnnestType; use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; use crate::protobuf::{ - ColumnUnnestExec, ColumnUnnestListItem, ColumnUnnestListRecursion, - ColumnUnnestListRecursions, CustomTableScanNode, SortExprNodeCollection, + ColumnUnnestListItem, ColumnUnnestListRecursion, CustomTableScanNode, + SortExprNodeCollection, }; use crate::{ convert_required, into_required, @@ -69,8 +68,7 @@ use datafusion_expr::{ DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, SortExpr, WindowUDF, }; -use datafusion_expr::{AggregateUDF, ColumnUnnestList, ColumnUnnestType, Unnest}; -use datafusion_proto_common::EmptyMessage; +use datafusion_expr::{AggregateUDF, ColumnUnnestList, Unnest}; use self::to_proto::{serialize_expr, serialize_exprs}; use crate::logical_plan::to_proto::serialize_sorts; @@ -875,33 +873,7 @@ impl AsLogicalPlan for LogicalPlanNode { into_logical_plan!(unnest.input, ctx, extension_codec)?; Ok(datafusion_expr::LogicalPlan::Unnest(Unnest { input: Arc::new(input), - exec_columns: unnest - .exec_columns - .iter() - .map(|c| { - ( - c.column.as_ref().unwrap().to_owned().into(), - match c.unnest_type.as_ref().unwrap() { - UnnestType::Inferred(_) => ColumnUnnestType::Inferred, - UnnestType::Struct(_) => ColumnUnnestType::Struct, - UnnestType::List(l) => ColumnUnnestType::List( - l.recursions - .iter() - .map(|ul| ColumnUnnestList { - output_column: ul - .output_column - .as_ref() - .unwrap() - .to_owned() - .into(), - depth: ul.depth as usize, - }) - .collect(), - ), - }, - ) - }) - .collect(), + exec_columns: unnest.exec_columns.iter().map(|c| c.into()).collect(), list_type_columns: unnest .list_type_columns .iter() @@ -1610,32 +1582,7 @@ impl AsLogicalPlan for LogicalPlanNode { input: Some(Box::new(input)), exec_columns: exec_columns .iter() - .map(|(col, unnesting)| ColumnUnnestExec { - column: Some(col.into()), - unnest_type: Some(match unnesting { - ColumnUnnestType::Inferred => { - UnnestType::Inferred(EmptyMessage {}) - } - ColumnUnnestType::Struct => { - UnnestType::Struct(EmptyMessage {}) - } - ColumnUnnestType::List(list) => { - UnnestType::List(ColumnUnnestListRecursions { - recursions: list - .iter() - .map(|ul| ColumnUnnestListRecursion { - output_column: Some( - ul.output_column - .to_owned() - .into(), - ), - depth: ul.depth as _, - }) - .collect(), - }) - } - }), - }) + .map(|col| col.into()) .collect(), list_type_columns: proto_unnest_list_items, struct_type_columns: struct_type_columns diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 15fec3a8b2a8..a34a220e490c 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -30,6 +30,7 @@ use datafusion_expr::{ WindowFrameUnits, WindowFunctionDefinition, }; +use crate::protobuf::RecursionUnnestOption; use crate::protobuf::{ self, plan_type::PlanTypeEnum::{ @@ -49,6 +50,15 @@ impl From<&UnnestOptions> for protobuf::UnnestOptions { fn from(opts: &UnnestOptions) -> Self { Self { preserve_nulls: opts.preserve_nulls, + recursions: opts + .recursions + .iter() + .map(|r| RecursionUnnestOption { + input_column: Some((&r.input_column).into()), + output_column: Some((&r.output_column).into()), + depth: r.depth as u32, + }) + .collect(), } } } diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 5c4b83fe38e1..90be576a884e 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -46,6 +46,7 @@ arrow-array = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +indexmap = { workspace = true } log = { workspace = true } regex = { workspace = true } sqlparser = { workspace = true } diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index c665dec21df4..80a08da5e35d 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -25,8 +25,8 @@ use crate::utils::{ }; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::UnnestOptions; use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; +use datafusion_common::{RecursionUnnestOption, UnnestOptions}; use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions}; use datafusion_expr::expr_rewriter::{ normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_sorts, @@ -38,6 +38,7 @@ use datafusion_expr::{ qualified_wildcard_with_options, wildcard_with_options, Aggregate, Expr, Filter, GroupingSet, LogicalPlan, LogicalPlanBuilder, Partitioning, }; +use indexmap::IndexMap; use sqlparser::ast::{ Distinct, Expr as SQLExpr, GroupByExpr, NamedWindowExpr, OrderByExpr, WildcardAdditionalOptions, WindowType, @@ -301,7 +302,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // The transformation happen bottom up, one at a time for each iteration // Only exhaust the loop if no more unnest transformation is found for i in 0.. { - let mut unnest_columns = vec![]; + let mut unnest_columns = IndexMap::new(); // from which column used for projection, before the unnest happen // including non unnest column and unnest column let mut inner_projection_exprs = vec![]; @@ -329,14 +330,27 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { break; } else { // Set preserve_nulls to false to ensure compatibility with DuckDB and PostgreSQL - let unnest_options = UnnestOptions::new().with_preserve_nulls(false); - + let mut unnest_options = UnnestOptions::new().with_preserve_nulls(false); + let mut unnest_col_vec = vec![]; + + for (col, maybe_list_unnest) in unnest_columns.into_iter() { + if let Some(list_unnest) = maybe_list_unnest { + unnest_options = list_unnest.into_iter().fold( + unnest_options, + |options, unnest_list| { + options.with_recursions(RecursionUnnestOption { + input_column: col.clone(), + output_column: unnest_list.output_column, + depth: unnest_list.depth, + }) + }, + ); + } + unnest_col_vec.push(col); + } let plan = LogicalPlanBuilder::from(intermediate_plan) .project(inner_projection_exprs)? - .unnest_columns_recursive_with_options( - unnest_columns, - unnest_options, - )? + .unnest_columns_with_options(unnest_col_vec, unnest_options)? .build()?; intermediate_plan = plan; intermediate_select_exprs = outer_projection_exprs; @@ -405,7 +419,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut intermediate_select_exprs = group_expr; loop { - let mut unnest_columns = vec![]; + let mut unnest_columns = IndexMap::new(); let mut inner_projection_exprs = vec![]; let outer_projection_exprs = rewrite_recursive_unnests_bottom_up( @@ -418,7 +432,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if unnest_columns.is_empty() { break; } else { - let unnest_options = UnnestOptions::new().with_preserve_nulls(false); + let mut unnest_options = UnnestOptions::new().with_preserve_nulls(false); let mut projection_exprs = match &aggr_expr_using_columns { Some(exprs) => (*exprs).clone(), @@ -440,12 +454,27 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; projection_exprs.extend(inner_projection_exprs); + let mut unnest_col_vec = vec![]; + + for (col, maybe_list_unnest) in unnest_columns.into_iter() { + if let Some(list_unnest) = maybe_list_unnest { + unnest_options = list_unnest.into_iter().fold( + unnest_options, + |options, unnest_list| { + options.with_recursions(RecursionUnnestOption { + input_column: col.clone(), + output_column: unnest_list.output_column, + depth: unnest_list.depth, + }) + }, + ); + } + unnest_col_vec.push(col); + } + intermediate_plan = LogicalPlanBuilder::from(intermediate_plan) .project(projection_exprs)? - .unnest_columns_recursive_with_options( - unnest_columns, - unnest_options, - )? + .unnest_columns_with_options(unnest_col_vec, unnest_options)? .build()?; intermediate_select_exprs = outer_projection_exprs; diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 787bc6634355..14436de01843 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -34,9 +34,9 @@ use datafusion_expr::builder::get_struct_unnested_columns; use datafusion_expr::expr::{Alias, GroupingSet, Unnest, WindowFunction}; use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs}; use datafusion_expr::{ - col, expr_vec_fmt, ColumnUnnestList, ColumnUnnestType, Expr, ExprSchemable, - LogicalPlan, + col, expr_vec_fmt, ColumnUnnestList, Expr, ExprSchemable, LogicalPlan, }; +use indexmap::IndexMap; use sqlparser::ast::{Ident, Value}; /// Make a best-effort attempt at resolving all columns in the expression tree @@ -295,7 +295,7 @@ pub(crate) fn value_to_string(value: &Value) -> Option { pub(crate) fn rewrite_recursive_unnests_bottom_up( input: &LogicalPlan, - unnest_placeholder_columns: &mut Vec<(Column, ColumnUnnestType)>, + unnest_placeholder_columns: &mut IndexMap>>, inner_projection_exprs: &mut Vec, original_exprs: &[Expr], ) -> Result> { @@ -326,7 +326,7 @@ struct RecursiveUnnestRewriter<'a> { top_most_unnest: Option, consecutive_unnest: Vec>, inner_projection_exprs: &'a mut Vec, - columns_unnestings: &'a mut Vec<(Column, ColumnUnnestType)>, + columns_unnestings: &'a mut IndexMap>>, transformed_root_exprs: Option>, } impl<'a> RecursiveUnnestRewriter<'a> { @@ -360,13 +360,11 @@ impl<'a> RecursiveUnnestRewriter<'a> { // Full context, we are trying to plan the execution as InnerProjection->Unnest->OuterProjection // inside unnest execution, each column inside the inner projection // will be transformed into new columns. Thus we need to keep track of these placeholding column names - // let placeholder_name = unnest_expr.display_name()?; let placeholder_name = format!("unnest_placeholder({})", inner_expr_name); let post_unnest_name = format!("unnest_placeholder({},depth={})", inner_expr_name, level); // This is due to the fact that unnest transformation should keep the original // column name as is, to comply with group by and order by - // let post_unnest_alias = print_unnest(&inner_expr_name, level); let placeholder_column = Column::from_name(placeholder_name.clone()); let (data_type, _) = expr_in_unnest.data_type_and_nullable(self.input_schema)?; @@ -380,10 +378,8 @@ impl<'a> RecursiveUnnestRewriter<'a> { self.inner_projection_exprs, expr_in_unnest.clone().alias(placeholder_name.clone()), ); - self.columns_unnestings.push(( - Column::from_name(placeholder_name.clone()), - ColumnUnnestType::Struct, - )); + self.columns_unnestings + .insert(Column::from_name(placeholder_name.clone()), None); Ok( get_struct_unnested_columns(&placeholder_name, &inner_fields) .into_iter() @@ -399,39 +395,18 @@ impl<'a> RecursiveUnnestRewriter<'a> { expr_in_unnest.clone().alias(placeholder_name.clone()), ); - // Let post_unnest_column = Column::from_name(post_unnest_name); let post_unnest_expr = col(post_unnest_name.clone()).alias(alias_name); - match self + let list_unnesting = self .columns_unnestings - .iter_mut() - .find(|(inner_col, _)| inner_col == &placeholder_column) - { - // There is not unnesting done on this column yet - None => { - self.columns_unnestings.push(( - Column::from_name(placeholder_name.clone()), - ColumnUnnestType::List(vec![ColumnUnnestList { - output_column: Column::from_name(post_unnest_name), - depth: level, - }]), - )); - } - // Some unnesting(at some level) has been done on this column - // e.g select unnest(column3), unnest(unnest(column3)) - Some((_, unnesting)) => match unnesting { - ColumnUnnestType::List(list) => { - let unnesting = ColumnUnnestList { - output_column: Column::from_name(post_unnest_name), - depth: level, - }; - if !list.contains(&unnesting) { - list.push(unnesting); - } - } - _ => { - return internal_err!("not reached"); - } - }, + .entry(placeholder_column) + .or_insert(Some(vec![])); + let unnesting = ColumnUnnestList { + output_column: Column::from_name(post_unnest_name), + depth: level, + }; + let list_unnestings = list_unnesting.as_mut().unwrap(); + if !list_unnestings.contains(&unnesting) { + list_unnestings.push(unnesting); } Ok(vec![post_unnest_expr]) } @@ -478,8 +453,7 @@ impl<'a> TreeNodeRewriter for RecursiveUnnestRewriter<'a> { } /// The rewriting only happens when the traversal has reached the top-most unnest expr - /// within a sequence of consecutive unnest exprs. - /// node, for example given a stack of expr + /// within a sequence of consecutive unnest exprs node /// /// For example an expr of **unnest(unnest(column1)) + unnest(unnest(unnest(column2)))** /// ```text @@ -560,7 +534,7 @@ impl<'a> TreeNodeRewriter for RecursiveUnnestRewriter<'a> { // For column exprs that are not descendants of any unnest node // retain their projection // e.g given expr tree unnest(col_a) + col_b, we have to retain projection of col_b - // this condition can be checked by maintaining an Option + // this condition can be checked by maintaining an Option if matches!(&expr, Expr::Column(_)) && self.top_most_unnest.is_none() { push_projection_dedupl(self.inner_projection_exprs, expr.clone()); } @@ -589,7 +563,7 @@ fn push_projection_dedupl(projection: &mut Vec, expr: Expr) { /// is done only for the bottom expression pub(crate) fn rewrite_recursive_unnest_bottom_up( input: &LogicalPlan, - unnest_placeholder_columns: &mut Vec<(Column, ColumnUnnestType)>, + unnest_placeholder_columns: &mut IndexMap>>, inner_projection_exprs: &mut Vec, original_expr: &Expr, ) -> Result> { @@ -610,8 +584,8 @@ pub(crate) fn rewrite_recursive_unnest_bottom_up( // TODO: This can be resolved after this issue is resolved: https://github.com/apache/datafusion/issues/10102 // // The transformation looks like: - // - unnest(array_col) will be transformed into unnest(array_col) - // - unnest(array_col) + 1 will be transformed into unnest(array_col) + 1 + // - unnest(array_col) will be transformed into Column("unnest_place_holder(array_col)") + // - unnest(array_col) + 1 will be transformed into Column("unnest_place_holder(array_col) + 1") let Transformed { data: transformed_expr, transformed, @@ -647,17 +621,33 @@ mod tests { use arrow_schema::Fields; use datafusion_common::{Column, DFSchema, Result}; use datafusion_expr::{ - col, lit, unnest, ColumnUnnestType, EmptyRelation, LogicalPlan, + col, lit, unnest, ColumnUnnestList, EmptyRelation, LogicalPlan, }; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_functions_aggregate::expr_fn::count; + use indexmap::IndexMap; use crate::utils::{resolve_positions_to_exprs, rewrite_recursive_unnest_bottom_up}; - fn column_unnests_eq(l: Vec<(&str, &str)>, r: &[(Column, ColumnUnnestType)]) { - let r_formatted: Vec = - r.iter().map(|i| format!("{}|{}", i.0, i.1)).collect(); - let l_formatted: Vec = - l.iter().map(|i| format!("{}|{}", i.0, i.1)).collect(); + + fn column_unnests_eq( + l: Vec<&str>, + r: &IndexMap>>, + ) { + let r_formatted: Vec = r + .iter() + .map(|i| match i.1 { + None => format!("{}", i.0), + Some(vec) => format!( + "{}=>[{}]", + i.0, + vec.iter() + .map(|i| format!("{}", i)) + .collect::>() + .join(", ") + ), + }) + .collect(); + let l_formatted: Vec = l.iter().map(|i| i.to_string()).collect(); assert_eq!(l_formatted, r_formatted); } @@ -687,7 +677,7 @@ mod tests { schema: Arc::new(dfschema), }); - let mut unnest_placeholder_columns = vec![]; + let mut unnest_placeholder_columns = IndexMap::new(); let mut inner_projection_exprs = vec![]; // unnest(unnest(3d_col)) + unnest(unnest(3d_col)) @@ -712,10 +702,9 @@ mod tests { .add(col("i64_col"))] ); column_unnests_eq( - vec![( - "unnest_placeholder(3d_col)", - "List([unnest_placeholder(3d_col,depth=2)|depth=2])", - )], + vec![ + "unnest_placeholder(3d_col)=>[unnest_placeholder(3d_col,depth=2)|depth=2]", + ], &unnest_placeholder_columns, ); @@ -746,9 +735,7 @@ mod tests { ] ); column_unnests_eq( - vec![("unnest_placeholder(3d_col)", - "List([unnest_placeholder(3d_col,depth=2)|depth=2, unnest_placeholder(3d_col,depth=1)|depth=1])"), - ], + vec!["unnest_placeholder(3d_col)=>[unnest_placeholder(3d_col,depth=2)|depth=2, unnest_placeholder(3d_col,depth=1)|depth=1]"], &unnest_placeholder_columns, ); // Still reference struct_col in original schema but with alias, @@ -794,7 +781,7 @@ mod tests { schema: Arc::new(dfschema), }); - let mut unnest_placeholder_columns = vec![]; + let mut unnest_placeholder_columns = IndexMap::new(); let mut inner_projection_exprs = vec![]; // unnest(struct_col) @@ -813,7 +800,7 @@ mod tests { ] ); column_unnests_eq( - vec![("unnest_placeholder(struct_col)", "Struct")], + vec!["unnest_placeholder(struct_col)"], &unnest_placeholder_columns, ); // Still reference struct_col in original schema but with alias, @@ -833,11 +820,8 @@ mod tests { )?; column_unnests_eq( vec![ - ("unnest_placeholder(struct_col)", "Struct"), - ( - "unnest_placeholder(array_col)", - "List([unnest_placeholder(array_col,depth=1)|depth=1])", - ), + "unnest_placeholder(struct_col)", + "unnest_placeholder(array_col)=>[unnest_placeholder(array_col,depth=1)|depth=1]", ], &unnest_placeholder_columns, ); @@ -860,24 +844,44 @@ mod tests { ] ); - // A nested structure struct[[]] + Ok(()) + } + + // Unnest -> field access -> unnest + #[test] + fn test_transform_non_consecutive_unnests() -> Result<()> { + // List of struct + // [struct{'subfield1':list(i64), 'subfield2':list(utf8)}] let schema = Schema::new(vec![ Field::new( - "struct_col", // {array_col: [1,2,3]} - ArrowDataType::Struct(Fields::from(vec![Field::new( - "matrix", - ArrowDataType::List(Arc::new(Field::new( - "matrix_row", - ArrowDataType::List(Arc::new(Field::new( - "item", - ArrowDataType::Int64, + "struct_list", + ArrowDataType::List(Arc::new(Field::new( + "element", + ArrowDataType::Struct(Fields::from(vec![ + Field::new( + // list of i64 + "subfield1", + ArrowDataType::List(Arc::new(Field::new( + "i64_element", + ArrowDataType::Int64, + true, + ))), true, - ))), - true, - ))), + ), + Field::new( + // list of utf8 + "subfield2", + ArrowDataType::List(Arc::new(Field::new( + "utf8_element", + ArrowDataType::Utf8, + true, + ))), + true, + ), + ])), true, - )])), - false, + ))), + true, ), Field::new("int_col", ArrowDataType::Int32, false), ]); @@ -889,39 +893,69 @@ mod tests { schema: Arc::new(dfschema), }); - let mut unnest_placeholder_columns = vec![]; + let mut unnest_placeholder_columns = IndexMap::new(); let mut inner_projection_exprs = vec![]; // An expr with multiple unnest - let original_expr = unnest(unnest(col("struct_col").field("matrix"))); + let select_expr1 = unnest(unnest(col("struct_list")).field("subfield1")); let transformed_exprs = rewrite_recursive_unnest_bottom_up( &input, &mut unnest_placeholder_columns, &mut inner_projection_exprs, - &original_expr, + &select_expr1, )?; // Only the inner most/ bottom most unnest is transformed assert_eq!( transformed_exprs, - vec![col("unnest_placeholder(struct_col[matrix],depth=2)") - .alias("UNNEST(UNNEST(struct_col[matrix]))")] + vec![unnest( + col("unnest_placeholder(struct_list,depth=1)") + .alias("UNNEST(struct_list)") + .field("subfield1") + )] ); - // TODO: add a test case where - // unnest -> field access -> unnest column_unnests_eq( - vec![( - "unnest_placeholder(struct_col[matrix])", - "List([unnest_placeholder(struct_col[matrix],depth=2)|depth=2])", - )], + vec![ + "unnest_placeholder(struct_list)=>[unnest_placeholder(struct_list,depth=1)|depth=1]", + ], + &unnest_placeholder_columns, + ); + + assert_eq!( + inner_projection_exprs, + vec![col("struct_list").alias("unnest_placeholder(struct_list)")] + ); + + // continue rewrite another expr in select + let select_expr2 = unnest(unnest(col("struct_list")).field("subfield2")); + let transformed_exprs = rewrite_recursive_unnest_bottom_up( + &input, + &mut unnest_placeholder_columns, + &mut inner_projection_exprs, + &select_expr2, + )?; + // Only the inner most/ bottom most unnest is transformed + assert_eq!( + transformed_exprs, + vec![unnest( + col("unnest_placeholder(struct_list,depth=1)") + .alias("UNNEST(struct_list)") + .field("subfield2") + )] + ); + + // unnest place holder columns remain the same + // because expr1 and expr2 derive from the same unnest result + column_unnests_eq( + vec![ + "unnest_placeholder(struct_list)=>[unnest_placeholder(struct_list,depth=1)|depth=1]", + ], &unnest_placeholder_columns, ); assert_eq!( inner_projection_exprs, - vec![col("struct_col") - .field("matrix") - .alias("unnest_placeholder(struct_col[matrix])"),] + vec![col("struct_list").alias("unnest_placeholder(struct_list)")] ); Ok(()) From 373fe23733d97dfac6195d77ebca0646fe9c37d0 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 20 Oct 2024 08:45:51 -0400 Subject: [PATCH 15/17] Update version to 42.1.0, add CHANGELOG (#12986) (#12989) * Update version to 42.1.0, add CHANGELOG (#12986) * CHANGELOG for 42.1.0 * Update version to 42.1.0 * Update datafusion-cli/Cargo.lock * update config docs * update datafusion-cli --- Cargo.toml | 48 ++++----- datafusion-cli/Cargo.lock | 166 +++++++++++++++--------------- datafusion-cli/Cargo.toml | 4 +- dev/changelog/42.1.0.md | 42 ++++++++ docs/source/user-guide/configs.md | 2 +- 5 files changed, 152 insertions(+), 110 deletions(-) create mode 100644 dev/changelog/42.1.0.md diff --git a/Cargo.toml b/Cargo.toml index 2c142c87c892..63bfb7fce413 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,7 +59,7 @@ license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/datafusion" rust-version = "1.79" -version = "42.0.0" +version = "42.1.0" [workspace.dependencies] # We turn off default-features for some dependencies here so the workspaces which inherit them can @@ -92,29 +92,29 @@ bytes = "1.4" chrono = { version = "0.4.38", default-features = false } ctor = "0.2.0" dashmap = "6.0.1" -datafusion = { path = "datafusion/core", version = "42.0.0", default-features = false } -datafusion-catalog = { path = "datafusion/catalog", version = "42.0.0" } -datafusion-common = { path = "datafusion/common", version = "42.0.0", default-features = false } -datafusion-common-runtime = { path = "datafusion/common-runtime", version = "42.0.0" } -datafusion-execution = { path = "datafusion/execution", version = "42.0.0" } -datafusion-expr = { path = "datafusion/expr", version = "42.0.0" } -datafusion-expr-common = { path = "datafusion/expr-common", version = "42.0.0" } -datafusion-functions = { path = "datafusion/functions", version = "42.0.0" } -datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "42.0.0" } -datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "42.0.0" } -datafusion-functions-nested = { path = "datafusion/functions-nested", version = "42.0.0" } -datafusion-functions-window = { path = "datafusion/functions-window", version = "42.0.0" } -datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "42.0.0" } -datafusion-optimizer = { path = "datafusion/optimizer", version = "42.0.0", default-features = false } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "42.0.0", default-features = false } -datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "42.0.0", default-features = false } -datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "42.0.0" } -datafusion-physical-plan = { path = "datafusion/physical-plan", version = "42.0.0" } -datafusion-proto = { path = "datafusion/proto", version = "42.0.0" } -datafusion-proto-common = { path = "datafusion/proto-common", version = "42.0.0" } -datafusion-sql = { path = "datafusion/sql", version = "42.0.0" } -datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "42.0.0" } -datafusion-substrait = { path = "datafusion/substrait", version = "42.0.0" } +datafusion = { path = "datafusion/core", version = "42.1.0", default-features = false } +datafusion-catalog = { path = "datafusion/catalog", version = "42.1.0" } +datafusion-common = { path = "datafusion/common", version = "42.1.0", default-features = false } +datafusion-common-runtime = { path = "datafusion/common-runtime", version = "42.1.0" } +datafusion-execution = { path = "datafusion/execution", version = "42.1.0" } +datafusion-expr = { path = "datafusion/expr", version = "42.1.0" } +datafusion-expr-common = { path = "datafusion/expr-common", version = "42.1.0" } +datafusion-functions = { path = "datafusion/functions", version = "42.1.0" } +datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "42.1.0" } +datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "42.1.0" } +datafusion-functions-nested = { path = "datafusion/functions-nested", version = "42.1.0" } +datafusion-functions-window = { path = "datafusion/functions-window", version = "42.1.0" } +datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "42.1.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "42.1.0", default-features = false } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "42.1.0", default-features = false } +datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "42.1.0", default-features = false } +datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "42.1.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "42.1.0" } +datafusion-proto = { path = "datafusion/proto", version = "42.1.0" } +datafusion-proto-common = { path = "datafusion/proto-common", version = "42.1.0" } +datafusion-sql = { path = "datafusion/sql", version = "42.1.0" } +datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "42.1.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "42.1.0" } doc-comment = "0.3" env_logger = "0.11" futures = "0.3" diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 08d5d4843c62..612209fdd922 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -406,9 +406,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.13" +version = "0.4.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e614738943d3f68c628ae3dbce7c3daffb196665f82f8c8ea6b65de73c79429" +checksum = "103db485efc3e41214fe4fda9f3dbeae2eb9082f48fd236e6095627a9422066e" dependencies = [ "bzip2", "flate2", @@ -523,9 +523,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.45.0" +version = "1.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e33ae899566f3d395cbf42858e433930682cc9c1889fa89318896082fef45efb" +checksum = "0dc2faec3205d496c7e57eff685dd944203df7ce16a4116d0281c44021788a7b" dependencies = [ "aws-credential-types", "aws-runtime", @@ -545,9 +545,9 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.46.0" +version = "1.47.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f39c09e199ebd96b9f860b0fce4b6625f211e064ad7c8693b72ecf7ef03881e0" +checksum = "c93c241f52bc5e0476e259c953234dab7e2a35ee207ee202e86c0095ec4951dc" dependencies = [ "aws-credential-types", "aws-runtime", @@ -567,9 +567,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.45.0" +version = "1.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d95f93a98130389eb6233b9d615249e543f6c24a68ca1f109af9ca5164a8765" +checksum = "b259429be94a3459fa1b00c5684faee118d74f9577cc50aebadc36e507c63b5f" dependencies = [ "aws-credential-types", "aws-runtime", @@ -663,9 +663,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.1" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1ce695746394772e7000b39fe073095db6d45a862d0767dd5ad0ac0d7f8eb87" +checksum = "a065c0fe6fdbdf9f11817eb68582b2ab4aff9e9c39e986ae48f7ec576c6322db" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -678,7 +678,7 @@ dependencies = [ "http-body 0.4.6", "http-body 1.0.1", "httparse", - "hyper 0.14.30", + "hyper 0.14.31", "hyper-rustls 0.24.2", "once_cell", "pin-project-lite", @@ -917,9 +917,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.28" +version = "1.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e80e3b6a3ab07840e1cae9b0666a63970dc28e8ed5ffbcdacbfc760c281bfc1" +checksum = "b16803a61b81d9eabb7eae2588776c4c1e584b738ede45fdbb4c972cec1e9945" dependencies = [ "jobserver", "libc", @@ -974,9 +974,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.19" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7be5744db7978a28d9df86a214130d106a89ce49644cbc4e3f0c22c3fba30615" +checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" dependencies = [ "clap_builder", "clap_derive", @@ -984,9 +984,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.19" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5fbc17d3ef8278f55b282b2a2e75ae6f6c7d4bb70ed3d0382375104bfafdb4b" +checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" dependencies = [ "anstream", "anstyle", @@ -1162,9 +1162,9 @@ dependencies = [ [[package]] name = "dary_heap" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7762d17f1241643615821a8455a0b2c3e803784b058693d990b11f2dce25a0ca" +checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" [[package]] name = "dashmap" @@ -1182,7 +1182,7 @@ dependencies = [ [[package]] name = "datafusion" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "apache-avro", @@ -1239,7 +1239,7 @@ dependencies = [ [[package]] name = "datafusion-catalog" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow-schema", "async-trait", @@ -1252,7 +1252,7 @@ dependencies = [ [[package]] name = "datafusion-cli" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "assert_cmd", @@ -1282,7 +1282,7 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "apache-avro", @@ -1305,7 +1305,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "42.0.0" +version = "42.1.0" dependencies = [ "log", "tokio", @@ -1313,7 +1313,7 @@ dependencies = [ [[package]] name = "datafusion-execution" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "chrono", @@ -1332,7 +1332,7 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "arrow", @@ -1354,7 +1354,7 @@ dependencies = [ [[package]] name = "datafusion-expr-common" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "datafusion-common", @@ -1363,7 +1363,7 @@ dependencies = [ [[package]] name = "datafusion-functions" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "arrow-buffer", @@ -1388,7 +1388,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "arrow", @@ -1407,7 +1407,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate-common" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "arrow", @@ -1419,7 +1419,7 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "arrow-array", @@ -1440,7 +1440,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window" -version = "42.0.0" +version = "42.1.0" dependencies = [ "datafusion-common", "datafusion-expr", @@ -1453,7 +1453,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" -version = "42.0.0" +version = "42.1.0" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -1461,7 +1461,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "async-trait", @@ -1479,7 +1479,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "arrow", @@ -1505,7 +1505,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "arrow", @@ -1517,7 +1517,7 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "arrow-schema", @@ -1531,7 +1531,7 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "arrow", @@ -1564,7 +1564,7 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "arrow-array", @@ -2066,9 +2066,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.30" +version = "0.14.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" +checksum = "8c08302e8fa335b151b788c775ff56e7a03ae64ff85c548ee820fecb70356e85" dependencies = [ "bytes", "futures-channel", @@ -2090,9 +2090,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.4.1" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" +checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a" dependencies = [ "bytes", "futures-channel", @@ -2116,7 +2116,7 @@ checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", "http 0.2.12", - "hyper 0.14.30", + "hyper 0.14.31", "log", "rustls 0.21.12", "rustls-native-certs 0.6.3", @@ -2132,9 +2132,9 @@ checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.4.1", + "hyper 1.5.0", "hyper-util", - "rustls 0.23.14", + "rustls 0.23.15", "rustls-native-certs 0.8.0", "rustls-pki-types", "tokio", @@ -2153,7 +2153,7 @@ dependencies = [ "futures-util", "http 1.1.0", "http-body 1.0.1", - "hyper 1.4.1", + "hyper 1.5.0", "pin-project-lite", "socket2", "tokio", @@ -2260,9 +2260,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ "wasm-bindgen", ] @@ -2339,9 +2339,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.159" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "libflate" @@ -2625,7 +2625,7 @@ dependencies = [ "chrono", "futures", "humantime", - "hyper 1.4.1", + "hyper 1.5.0", "itertools", "md-5", "parking_lot", @@ -2879,9 +2879,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.87" +version = "1.0.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a" +checksum = "7c3a7fc5db1e57d5a779a352c8cdb57b29aa4c40cc69c3a68a7fedc815fbf2f9" dependencies = [ "unicode-ident", ] @@ -2913,7 +2913,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls 0.23.14", + "rustls 0.23.15", "socket2", "thiserror", "tokio", @@ -2930,7 +2930,7 @@ dependencies = [ "rand", "ring", "rustc-hash", - "rustls 0.23.14", + "rustls 0.23.15", "slab", "thiserror", "tinyvec", @@ -3074,7 +3074,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.4.1", + "hyper 1.5.0", "hyper-rustls 0.27.3", "hyper-util", "ipnet", @@ -3085,7 +3085,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.14", + "rustls 0.23.15", "rustls-native-certs 0.8.0", "rustls-pemfile 2.2.0", "rustls-pki-types", @@ -3204,9 +3204,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.14" +version = "0.23.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "415d9944693cb90382053259f89fbb077ea730ad7273047ec63b19bc9b160ba8" +checksum = "5fbb44d7acc4e873d613422379f69f237a1b141928c02f6bc6ccfddddc2d7993" dependencies = [ "once_cell", "ring", @@ -3261,9 +3261,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e696e35370c65c9c541198af4543ccd580cf17fc25d8e05c5a242b202488c55" +checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" [[package]] name = "rustls-webpki" @@ -3288,9 +3288,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" +checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" [[package]] name = "rustyline" @@ -3411,9 +3411,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.130" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "610f75ff4a8e3cb29b85da56eabdd1bff5b06739059a4b8e2967fef32e5d9944" dependencies = [ "itoa", "memchr", @@ -3772,7 +3772,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.14", + "rustls 0.23.15", "rustls-pki-types", "tokio", ] @@ -3950,9 +3950,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ "getrandom", "serde", @@ -4006,9 +4006,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ "cfg-if", "once_cell", @@ -4017,9 +4017,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ "bumpalo", "log", @@ -4032,9 +4032,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.43" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" +checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" dependencies = [ "cfg-if", "js-sys", @@ -4044,9 +4044,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4054,9 +4054,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", @@ -4067,9 +4067,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] name = "wasm-streams" @@ -4086,9 +4086,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index fe929495aae6..8e4352612889 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "datafusion-cli" description = "Command Line Client for DataFusion query engine." -version = "42.0.0" +version = "42.1.0" authors = ["Apache DataFusion "] edition = "2021" keywords = ["arrow", "datafusion", "query", "sql"] @@ -39,7 +39,7 @@ aws-sdk-sts = "1.43.0" # end pin aws-sdk crates aws-credential-types = "1.2.0" clap = { version = "4.5.16", features = ["derive", "cargo"] } -datafusion = { path = "../datafusion/core", version = "42.0.0", features = [ +datafusion = { path = "../datafusion/core", version = "42.1.0", features = [ "avro", "crypto_expressions", "datetime_expressions", diff --git a/dev/changelog/42.1.0.md b/dev/changelog/42.1.0.md new file mode 100644 index 000000000000..cf4f911150ac --- /dev/null +++ b/dev/changelog/42.1.0.md @@ -0,0 +1,42 @@ + + +# Apache DataFusion 42.1.0 Changelog + +This release consists of 5 commits from 4 contributors. See credits at the end of this changelog for more information. + +**Other:** + +- Backport update to arrow 53.1.0 on branch-42 [#12977](https://github.com/apache/datafusion/pull/12977) (alamb) +- Backport "Provide field and schema metadata missing on cross joins, and union with null fields" (#12729) [#12974](https://github.com/apache/datafusion/pull/12974) (matthewmturner) +- Backport "physical-plan: Cast nested group values back to dictionary if necessary" (#12586) [#12976](https://github.com/apache/datafusion/pull/12976) (matthewmturner) +- backport-to-DF-42: Provide field and schema metadata missing on distinct aggregations [#12975](https://github.com/apache/datafusion/pull/12975) (Xuanwo) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 2 Matthew Turner + 1 Andrew Lamb + 1 Andy Grove + 1 Xuanwo +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index c61a7b673334..10917932482c 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -66,7 +66,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.parquet.statistics_enabled | page | (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.max_statistics_size | 4096 | (writing) Sets max statistics size for any column. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.max_row_group_size | 1048576 | (writing) Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. | -| datafusion.execution.parquet.created_by | datafusion version 42.0.0 | (writing) Sets "created by" property | +| datafusion.execution.parquet.created_by | datafusion version 42.1.0 | (writing) Sets "created by" property | | datafusion.execution.parquet.column_index_truncate_length | 64 | (writing) Sets column index truncate length | | datafusion.execution.parquet.data_page_row_count_limit | 20000 | (writing) Sets best effort maximum number of rows in data page | | datafusion.execution.parquet.encoding | NULL | (writing) Sets default encoding for any column. Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | From 8d4614d6c43104a13b42c062d957843f25ee32db Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Sun, 20 Oct 2024 05:46:25 -0700 Subject: [PATCH 16/17] Don't preserve functional dependency when generating UNION logical plan (#44) (#12979) * Don't preserve functional dependency when generating UNION logical plan * Remove extra lines --- datafusion/core/src/dataframe/mod.rs | 48 +++++++++++++++++++++ datafusion/expr/src/logical_plan/builder.rs | 11 +++-- 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 8a0829cd5e4b..4feadd260d7f 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -2623,6 +2623,54 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_aggregate_with_union() -> Result<()> { + let df = test_table().await?; + + let df1 = df + .clone() + // GROUP BY `c1` + .aggregate(vec![col("c1")], vec![min(col("c2"))])? + // SELECT `c1` , min(c2) as `result` + .select(vec![col("c1"), min(col("c2")).alias("result")])?; + let df2 = df + .clone() + // GROUP BY `c1` + .aggregate(vec![col("c1")], vec![max(col("c3"))])? + // SELECT `c1` , max(c3) as `result` + .select(vec![col("c1"), max(col("c3")).alias("result")])?; + + let df_union = df1.union(df2)?; + let df = df_union + // GROUP BY `c1` + .aggregate( + vec![col("c1")], + vec![sum(col("result")).alias("sum_result")], + )? + // SELECT `c1`, sum(result) as `sum_result` + .select(vec![(col("c1")), col("sum_result")])?; + + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + [ + "+----+------------+", + "| c1 | sum_result |", + "+----+------------+", + "| a | 84 |", + "| b | 69 |", + "| c | 124 |", + "| d | 126 |", + "| e | 121 |", + "+----+------------+" + ], + &df_results + ); + + Ok(()) + } + #[tokio::test] async fn test_aggregate_subexpr() -> Result<()> { let df = test_table().await?; diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index f119a2ade827..21304068a8ab 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -51,8 +51,8 @@ use datafusion_common::display::ToStringifiedPlan; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ get_target_functional_dependencies, internal_err, not_impl_err, plan_datafusion_err, - plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, - TableReference, ToDFSchema, UnnestOptions, + plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies, + Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; use datafusion_expr_common::type_coercion::binary::type_union_resolution; @@ -1386,7 +1386,12 @@ pub fn validate_unique_names<'a>( pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result { // Temporarily use the schema from the left input and later rely on the analyzer to // coerce the two schemas into a common one. - let schema = Arc::clone(left_plan.schema()); + + // Functional Dependencies doesn't preserve after UNION operation + let schema = (**left_plan.schema()).clone(); + let schema = + Arc::new(schema.with_functional_dependencies(FunctionalDependencies::empty())?); + Ok(LogicalPlan::Union(Union { inputs: vec![Arc::new(left_plan), Arc::new(right_plan)], schema, From 972e3abea4286b0d06c44498d576c8498ddd3be2 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Sun, 20 Oct 2024 14:47:59 +0200 Subject: [PATCH 17/17] feat: Decorrelate more predicate subqueries (#12945) * Decorrelate more predicate subqueries * Added sqllogictest explain tests --- datafusion/core/tests/tpcds_planning.rs | 3 - .../src/decorrelate_predicate_subquery.rs | 500 ++++++++---------- .../sqllogictest/test_files/subquery.slt | 170 +++++- .../sqllogictest/test_files/tpch/q20.slt.part | 8 +- .../tests/cases/roundtrip_logical_plan.rs | 18 +- 5 files changed, 405 insertions(+), 294 deletions(-) diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index b99bc2680044..6beb29183483 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -571,7 +571,6 @@ async fn tpcds_physical_q9() -> Result<()> { create_physical_plan(9).await } -#[ignore] // Physical plan does not support logical expression Exists() #[tokio::test] async fn tpcds_physical_q10() -> Result<()> { create_physical_plan(10).await @@ -697,7 +696,6 @@ async fn tpcds_physical_q34() -> Result<()> { create_physical_plan(34).await } -#[ignore] // Physical plan does not support logical expression Exists() #[tokio::test] async fn tpcds_physical_q35() -> Result<()> { create_physical_plan(35).await @@ -750,7 +748,6 @@ async fn tpcds_physical_q44() -> Result<()> { create_physical_plan(44).await } -#[ignore] // Physical plan does not support logical expression () #[tokio::test] async fn tpcds_physical_q45() -> Result<()> { create_physical_plan(45).await diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index d1ac80003ba7..cdffa8c645ea 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -17,6 +17,7 @@ //! [`DecorrelatePredicateSubquery`] converts `IN`/`EXISTS` subquery predicates to `SEMI`/`ANTI` joins use std::collections::BTreeSet; +use std::iter; use std::ops::Deref; use std::sync::Arc; @@ -27,16 +28,17 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{internal_err, plan_err, Result}; +use datafusion_common::{internal_err, plan_err, Column, Result}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; -use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned}; +use datafusion_expr::utils::{conjunction, split_conjunction_owned}; use datafusion_expr::{ - exists, in_subquery, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, + exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, }; +use itertools::chain; use log::debug; /// Optimizer rule for rewriting predicate(IN/EXISTS) subquery to left semi/anti joins @@ -48,79 +50,6 @@ impl DecorrelatePredicateSubquery { pub fn new() -> Self { Self::default() } - - fn rewrite_subquery( - &self, - mut subquery: Subquery, - config: &dyn OptimizerConfig, - ) -> Result { - subquery.subquery = Arc::new( - self.rewrite(Arc::unwrap_or_clone(subquery.subquery), config)? - .data, - ); - Ok(subquery) - } - - /// Finds expressions that have the predicate subqueries (and recurses when found) - /// - /// # Arguments - /// - /// * `predicate` - A conjunction to split and search - /// * `optimizer_config` - For generating unique subquery aliases - /// - /// Returns a tuple (subqueries, non-subquery expressions) - fn extract_subquery_exprs( - &self, - predicate: Expr, - config: &dyn OptimizerConfig, - ) -> Result<(Vec, Vec)> { - let filters = split_conjunction_owned(predicate); // TODO: add ExistenceJoin to support disjunctions - - let mut subqueries = vec![]; - let mut others = vec![]; - for it in filters.into_iter() { - match it { - Expr::Not(not_expr) => match *not_expr { - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { - let new_subquery = self.rewrite_subquery(subquery, config)?; - subqueries.push(SubqueryInfo::new_with_in_expr( - new_subquery, - *expr, - !negated, - )); - } - Expr::Exists(Exists { subquery, negated }) => { - let new_subquery = self.rewrite_subquery(subquery, config)?; - subqueries.push(SubqueryInfo::new(new_subquery, !negated)); - } - expr => others.push(not(expr)), - }, - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { - let new_subquery = self.rewrite_subquery(subquery, config)?; - subqueries.push(SubqueryInfo::new_with_in_expr( - new_subquery, - *expr, - negated, - )); - } - Expr::Exists(Exists { subquery, negated }) => { - let new_subquery = self.rewrite_subquery(subquery, config)?; - subqueries.push(SubqueryInfo::new(new_subquery, negated)); - } - expr => others.push(expr), - } - } - - Ok((subqueries, others)) - } } impl OptimizerRule for DecorrelatePredicateSubquery { @@ -133,69 +62,51 @@ impl OptimizerRule for DecorrelatePredicateSubquery { plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { + let plan = plan + .map_subqueries(|subquery| { + subquery.transform_down(|p| self.rewrite(p, config)) + })? + .data; + let LogicalPlan::Filter(filter) = plan else { return Ok(Transformed::no(plan)); }; - // if there are no subqueries in the predicate, return the original plan - let has_subqueries = - split_conjunction(&filter.predicate) - .iter() - .any(|expr| match expr { - Expr::Not(not_expr) => { - matches!(not_expr.as_ref(), Expr::InSubquery(_) | Expr::Exists(_)) - } - Expr::InSubquery(_) | Expr::Exists(_) => true, - _ => false, - }); - - if !has_subqueries { + if !has_subquery(&filter.predicate) { return Ok(Transformed::no(LogicalPlan::Filter(filter))); } - let Filter { - predicate, input, .. - } = filter; - let (subqueries, mut other_exprs) = - self.extract_subquery_exprs(predicate, config)?; - if subqueries.is_empty() { + let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = + split_conjunction_owned(filter.predicate) + .into_iter() + .partition(has_subquery); + + if with_subqueries.is_empty() { return internal_err!( "can not find expected subqueries in DecorrelatePredicateSubquery" ); } // iterate through all exists clauses in predicate, turning each into a join - let mut cur_input = Arc::unwrap_or_clone(input); - for subquery in subqueries { - if let Some(plan) = - build_join(&subquery, &cur_input, config.alias_generator())? - { - cur_input = plan; - } else { - // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter - let sub_query_expr = match subquery { - SubqueryInfo { - query, - where_in_expr: Some(expr), - negated: false, - } => in_subquery(expr, query.subquery), - SubqueryInfo { - query, - where_in_expr: Some(expr), - negated: true, - } => not_in_subquery(expr, query.subquery), - SubqueryInfo { - query, - where_in_expr: None, - negated: false, - } => exists(query.subquery), - SubqueryInfo { - query, - where_in_expr: None, - negated: true, - } => not_exists(query.subquery), - }; - other_exprs.push(sub_query_expr); + let mut cur_input = Arc::unwrap_or_clone(filter.input); + for subquery_expr in with_subqueries { + match extract_subquery_info(subquery_expr) { + // The subquery expression is at the top level of the filter + SubqueryPredicate::Top(subquery) => { + match build_join_top(&subquery, &cur_input, config.alias_generator())? + { + Some(plan) => cur_input = plan, + // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter + None => other_exprs.push(subquery.expr()), + } + } + // The subquery expression is embedded within another expression + SubqueryPredicate::Embedded(expr) => { + let (plan, expr_without_subqueries) = + rewrite_inner_subqueries(cur_input, expr, config)?; + cur_input = plan; + other_exprs.push(expr_without_subqueries); + } } } @@ -216,6 +127,104 @@ impl OptimizerRule for DecorrelatePredicateSubquery { } } +fn rewrite_inner_subqueries( + outer: LogicalPlan, + expr: Expr, + config: &dyn OptimizerConfig, +) -> Result<(LogicalPlan, Expr)> { + let mut cur_input = outer; + let alias = config.alias_generator(); + let expr_without_subqueries = expr.transform(|e| match e { + Expr::Exists(Exists { + subquery: Subquery { subquery, .. }, + negated, + }) => { + match existence_join(&cur_input, Arc::clone(&subquery), None, negated, alias)? + { + Some((plan, exists_expr)) => { + cur_input = plan; + Ok(Transformed::yes(exists_expr)) + } + None if negated => Ok(Transformed::no(not_exists(subquery))), + None => Ok(Transformed::no(exists(subquery))), + } + } + Expr::InSubquery(InSubquery { + expr, + subquery: Subquery { subquery, .. }, + negated, + }) => { + let in_predicate = subquery + .head_output_expr()? + .map_or(plan_err!("single expression required."), |output_expr| { + Ok(Expr::eq(*expr.clone(), output_expr)) + })?; + match existence_join( + &cur_input, + Arc::clone(&subquery), + Some(in_predicate), + negated, + alias, + )? { + Some((plan, exists_expr)) => { + cur_input = plan; + Ok(Transformed::yes(exists_expr)) + } + None if negated => Ok(Transformed::no(not_in_subquery(*expr, subquery))), + None => Ok(Transformed::no(in_subquery(*expr, subquery))), + } + } + _ => Ok(Transformed::no(e)), + })?; + Ok((cur_input, expr_without_subqueries.data)) +} + +enum SubqueryPredicate { + // The subquery expression is at the top level of the filter and can be fully replaced by a + // semi/anti join + Top(SubqueryInfo), + // The subquery expression is embedded within another expression and is replaced using an + // existence join + Embedded(Expr), +} + +fn extract_subquery_info(expr: Expr) -> SubqueryPredicate { + match expr { + Expr::Not(not_expr) => match *not_expr { + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr( + subquery, *expr, !negated, + )), + Expr::Exists(Exists { subquery, negated }) => { + SubqueryPredicate::Top(SubqueryInfo::new(subquery, !negated)) + } + expr => SubqueryPredicate::Embedded(not(expr)), + }, + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr( + subquery, *expr, negated, + )), + Expr::Exists(Exists { subquery, negated }) => { + SubqueryPredicate::Top(SubqueryInfo::new(subquery, negated)) + } + expr => SubqueryPredicate::Embedded(expr), + } +} + +fn has_subquery(expr: &Expr) -> bool { + expr.exists(|e| match e { + Expr::InSubquery(_) | Expr::Exists(_) => Ok(true), + _ => Ok(false), + }) + .unwrap() +} + /// Optimize the subquery to left-anti/left-semi join. /// If the subquery is a correlated subquery, we need extract the join predicate from the subquery. /// @@ -246,7 +255,7 @@ impl OptimizerRule for DecorrelatePredicateSubquery { /// Projection: t2.id /// TableScan: t2 /// ``` -fn build_join( +fn build_join_top( query_info: &SubqueryInfo, left: &LogicalPlan, alias: &Arc, @@ -265,9 +274,70 @@ fn build_join( }) .map_or(Ok(None), |v| v.map(Some))?; + let join_type = match query_info.negated { + true => JoinType::LeftAnti, + false => JoinType::LeftSemi, + }; let subquery = query_info.query.subquery.as_ref(); let subquery_alias = alias.next("__correlated_sq"); + build_join(left, subquery, in_predicate_opt, join_type, subquery_alias) +} + +/// Existence join is emulated by adding a non-nullable column to the subquery and using a left join +/// and checking if the column is null or not. If native support is added for Existence/Mark then +/// we should use that instead. +/// +/// This is used to handle the case when the subquery is embedded in a more complex boolean +/// expression like and OR. For example +/// +/// `select t1.id from t1 where t1.id < 0 OR exists(SELECT t2.id FROM t2 WHERE t1.id = t2.id)` +/// +/// The optimized plan will be: +/// +/// ```text +/// Projection: t1.id +/// Filter: t1.id < 0 OR __correlated_sq_1.__exists IS NOT NULL +/// Left Join: Filter: t1.id = __correlated_sq_1.id +/// TableScan: t1 +/// SubqueryAlias: __correlated_sq_1 +/// Projection: t2.id, true as __exists +/// TableScan: t2 +fn existence_join( + left: &LogicalPlan, + subquery: Arc, + in_predicate_opt: Option, + negated: bool, + alias_generator: &Arc, +) -> Result> { + // Add non nullable column to emulate existence join + let always_true_expr = lit(true).alias("__exists"); + let cols = chain( + subquery.schema().columns().into_iter().map(Expr::Column), + iter::once(always_true_expr), + ); + let subquery = LogicalPlanBuilder::from(subquery).project(cols)?.build()?; + let alias = alias_generator.next("__correlated_sq"); + + let exists_col = Expr::Column(Column::new(Some(alias.clone()), "__exists")); + let exists_expr = if negated { + exists_col.is_null() + } else { + exists_col.is_not_null() + }; + + Ok( + build_join(left, &subquery, in_predicate_opt, JoinType::Left, alias)? + .map(|plan| (plan, exists_expr)), + ) +} +fn build_join( + left: &LogicalPlan, + subquery: &LogicalPlan, + in_predicate_opt: Option, + join_type: JoinType, + alias: String, +) -> Result> { let mut pull_up = PullUpCorrelatedExpr::new() .with_in_predicate_opt(in_predicate_opt.clone()) .with_exists_sub_query(in_predicate_opt.is_none()); @@ -278,7 +348,7 @@ fn build_join( } let sub_query_alias = LogicalPlanBuilder::from(new_plan) - .alias(subquery_alias.to_string())? + .alias(alias.to_string())? .build()?; let mut all_correlated_cols = BTreeSet::new(); pull_up @@ -289,8 +359,7 @@ fn build_join( // alias the join filter let join_filter_opt = conjunction(pull_up.join_filters).map_or(Ok(None), |filter| { - replace_qualified_name(filter, &all_correlated_cols, &subquery_alias) - .map(Option::Some) + replace_qualified_name(filter, &all_correlated_cols, &alias).map(Option::Some) })?; if let Some(join_filter) = match (join_filter_opt, in_predicate_opt) { @@ -302,7 +371,7 @@ fn build_join( right, })), ) => { - let right_col = create_col_from_scalar_expr(right.deref(), subquery_alias)?; + let right_col = create_col_from_scalar_expr(right.deref(), alias)?; let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col)); Some(in_predicate.and(join_filter)) } @@ -315,17 +384,13 @@ fn build_join( right, })), ) => { - let right_col = create_col_from_scalar_expr(right.deref(), subquery_alias)?; + let right_col = create_col_from_scalar_expr(right.deref(), alias)?; let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col)); Some(in_predicate) } _ => None, } { // join our sub query into the main plan - let join_type = match query_info.negated { - true => JoinType::LeftAnti, - false => JoinType::LeftSemi, - }; let new_plan = LogicalPlanBuilder::from(left.clone()) .join_on(sub_query_alias, join_type, Some(join_filter))? .build()?; @@ -361,6 +426,19 @@ impl SubqueryInfo { negated, } } + + pub fn expr(self) -> Expr { + match self.where_in_expr { + Some(expr) => match self.negated { + true => not_in_subquery(expr, self.query.subquery), + false => in_subquery(expr, self.query.subquery), + }, + None => match self.negated { + true => not_exists(self.query.subquery), + false => exists(self.query.subquery), + }, + } + } } #[cfg(test)] @@ -371,7 +449,7 @@ mod tests { use crate::test::*; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_expr::{and, binary_expr, col, lit, not, or, out_ref_col, table_scan}; + use datafusion_expr::{and, binary_expr, col, lit, not, out_ref_col, table_scan}; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( @@ -442,60 +520,6 @@ mod tests { assert_optimized_plan_equal(plan, expected) } - /// Test for IN subquery with additional OR filter - /// filter expression not modified - #[test] - fn in_subquery_with_or_filters() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .filter(or( - and( - binary_expr(col("a"), Operator::Eq, lit(1_u32)), - binary_expr(col("b"), Operator::Lt, lit(30_u32)), - ), - in_subquery(col("c"), test_subquery_with_name("sq")?), - ))? - .project(vec![col("test.b")])? - .build()?; - - let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.a = UInt32(1) AND test.b < UInt32(30) OR test.c IN () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) - } - - #[test] - fn in_subquery_with_and_or_filters() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .filter(and( - or( - binary_expr(col("a"), Operator::Eq, lit(1_u32)), - in_subquery(col("b"), test_subquery_with_name("sq1")?), - ), - in_subquery(col("c"), test_subquery_with_name("sq2")?), - ))? - .project(vec![col("test.b")])? - .build()?; - - let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.a = UInt32(1) OR test.b IN () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [c:UInt32]\ - \n Projection: sq1.c [c:UInt32]\ - \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq2.c [c:UInt32]\ - \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) - } - /// Test for nested IN subqueries #[test] fn in_subquery_nested() -> Result<()> { @@ -512,51 +536,19 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.b = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.b = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [a:UInt32]\ + \n SubqueryAlias: __correlated_sq_2 [a:UInt32]\ \n Projection: sq.a [a:UInt32]\ - \n LeftSemi Join: Filter: sq.a = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: sq.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_2 [c:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ \n Projection: sq_nested.c [c:UInt32]\ \n TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) } - /// Test for filter input modification in case filter not supported - /// Outer filter expression not modified while inner converted to join - #[test] - fn in_subquery_input_modified() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .filter(in_subquery(col("c"), test_subquery_with_name("sq_inner")?))? - .project(vec![col("b"), col("c")])? - .alias("wrapped")? - .filter(or( - binary_expr(col("b"), Operator::Lt, lit(30_u32)), - in_subquery(col("c"), test_subquery_with_name("sq_outer")?), - ))? - .project(vec![col("b")])? - .build()?; - - let expected = "Projection: wrapped.b [b:UInt32]\ - \n Filter: wrapped.b < UInt32(30) OR wrapped.c IN () [b:UInt32, c:UInt32]\ - \n Subquery: [c:UInt32]\ - \n Projection: sq_outer.c [c:UInt32]\ - \n TableScan: sq_outer [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: wrapped [b:UInt32, c:UInt32]\ - \n Projection: test.b, test.c [b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq_inner.c [c:UInt32]\ - \n TableScan: sq_inner [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) - } - /// Test multiple correlated subqueries /// See subqueries.rs where_in_multiple() #[test] @@ -630,13 +622,13 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ + \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\ + \n SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]\ \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; @@ -1003,44 +995,6 @@ mod tests { Ok(()) } - /// Test for correlated IN subquery filter with disjustions - #[test] - fn in_subquery_disjunction() -> Result<()> { - let sq = Arc::new( - LogicalPlanBuilder::from(scan_tpch_table("orders")) - .filter( - out_ref_col(DataType::Int64, "customer.c_custkey") - .eq(col("orders.o_custkey")), - )? - .project(vec![col("orders.o_custkey")])? - .build()?, - ); - - let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) - .filter( - in_subquery(col("customer.c_custkey"), sq) - .or(col("customer.c_custkey").eq(lit(1))), - )? - .project(vec![col("customer.c_custkey")])? - .build()?; - - // TODO: support disjunction - for now expect unaltered plan - let expected = r#"Projection: customer.c_custkey [c_custkey:Int64] - Filter: customer.c_custkey IN () OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] - Subquery: [o_custkey:Int64] - Projection: orders.o_custkey [o_custkey:Int64] - Filter: outer_ref(customer.c_custkey) = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8]"#; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), - plan, - expected, - ); - Ok(()) - } - /// Test for correlated IN subquery filter #[test] fn in_subquery_correlated() -> Result<()> { @@ -1407,13 +1361,13 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: __correlated_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ + \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: Filter: __correlated_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n LeftSemi Join: Filter: __correlated_sq_1.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\ + \n SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]\ \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; assert_optimized_plan_equal(plan, expected) diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 30b3631681e7..22857dd285c2 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -415,13 +415,13 @@ query TT explain SELECT t1_id, t1_name, t1_int FROM t1 WHERE t1_id IN(SELECT t2_id FROM t2 WHERE EXISTS(select * from t1 WHERE t1.t1_int > t2.t2_int)) ---- logical_plan -01)LeftSemi Join: t1.t1_id = __correlated_sq_1.t2_id +01)LeftSemi Join: t1.t1_id = __correlated_sq_2.t2_id 02)--TableScan: t1 projection=[t1_id, t1_name, t1_int] -03)--SubqueryAlias: __correlated_sq_1 +03)--SubqueryAlias: __correlated_sq_2 04)----Projection: t2.t2_id -05)------LeftSemi Join: Filter: __correlated_sq_2.t1_int > t2.t2_int +05)------LeftSemi Join: Filter: __correlated_sq_1.t1_int > t2.t2_int 06)--------TableScan: t2 projection=[t2_id, t2_int] -07)--------SubqueryAlias: __correlated_sq_2 +07)--------SubqueryAlias: __correlated_sq_1 08)----------TableScan: t1 projection=[t1_int] #invalid_scalar_subquery @@ -1028,6 +1028,168 @@ false true true +# in_subquery_to_join_with_correlated_outer_filter_disjunction +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NOT NULL +03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists +04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id Filter: t1.t1_int > Int32(0) +05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] +06)--------SubqueryAlias: __correlated_sq_1 +07)----------Projection: t2.t2_id, Boolean(true) AS __exists +08)------------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0) +---- +11 a 1 +22 b 2 +44 d 4 + +# not_in_subquery_to_join_with_correlated_outer_filter_disjunction +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id = 11 or t1.t1_id + 12 not in (select t2.t2_id + 1 from t2 where t1.t1_int > 0) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id = Int32(11) OR __correlated_sq_1.__exists IS NULL +03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists +04)------Left Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.t2.t2_id + Int64(1) Filter: t1.t1_int > Int32(0) +05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] +06)--------SubqueryAlias: __correlated_sq_1 +07)----------Projection: CAST(t2.t2_id AS Int64) + Int64(1), Boolean(true) AS __exists +08)------------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id = 11 or t1.t1_id + 12 not in (select t2.t2_id + 1 from t2 where t1.t1_int > 0) +---- +11 a 1 +22 b 2 + +# exists_subquery_to_join_with_correlated_outer_filter_disjunction +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or exists (select * from t2 where t1.t1_id = t2.t2_id) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NOT NULL +03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists +04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id +05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] +06)--------SubqueryAlias: __correlated_sq_1 +07)----------Projection: t2.t2_id, Boolean(true) AS __exists +08)------------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or exists (select * from t2 where t1.t1_id = t2.t2_id) +---- +11 a 1 +22 b 2 +44 d 4 + +# not_exists_subquery_to_join_with_correlated_outer_filter_disjunction +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or not exists (select * from t2 where t1.t1_id = t2.t2_id) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NULL +03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists +04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id +05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] +06)--------SubqueryAlias: __correlated_sq_1 +07)----------Projection: t2.t2_id, Boolean(true) AS __exists +08)------------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or not exists (select * from t2 where t1.t1_id = t2.t2_id) +---- +33 c 3 +44 d 4 + +# in_subquery_to_join_with_correlated_outer_filter_and_or +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id in (select t3.t3_id from t3) and (t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0)) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_2.__exists IS NOT NULL +03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_2.__exists +04)------Left Join: t1.t1_id = __correlated_sq_2.t2_id Filter: t1.t1_int > Int32(0) +05)--------LeftSemi Join: t1.t1_id = __correlated_sq_1.t3_id +06)----------TableScan: t1 projection=[t1_id, t1_name, t1_int] +07)----------SubqueryAlias: __correlated_sq_1 +08)------------TableScan: t3 projection=[t3_id] +09)--------SubqueryAlias: __correlated_sq_2 +10)----------Projection: t2.t2_id, Boolean(true) AS __exists +11)------------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id in (select t3.t3_id from t3) and (t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0)) +---- +11 a 1 +22 b 2 +44 d 4 + +# Nested subqueries +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where exists ( + select * from t2 where t1.t1_id = t2.t2_id OR exists ( + select * from t3 where t2.t2_id = t3.t3_id + ) +) +---- +11 a 1 +22 b 2 +33 c 3 +44 d 4 # issue: https://github.com/apache/datafusion/issues/7027 query TTTT rowsort diff --git a/datafusion/sqllogictest/test_files/tpch/q20.slt.part b/datafusion/sqllogictest/test_files/tpch/q20.slt.part index 67ea87b6ee61..177e38e51ca4 100644 --- a/datafusion/sqllogictest/test_files/tpch/q20.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q20.slt.part @@ -58,19 +58,19 @@ order by logical_plan 01)Sort: supplier.s_name ASC NULLS LAST 02)--Projection: supplier.s_name, supplier.s_address -03)----LeftSemi Join: supplier.s_suppkey = __correlated_sq_1.ps_suppkey +03)----LeftSemi Join: supplier.s_suppkey = __correlated_sq_2.ps_suppkey 04)------Projection: supplier.s_suppkey, supplier.s_name, supplier.s_address 05)--------Inner Join: supplier.s_nationkey = nation.n_nationkey 06)----------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey] 07)----------Projection: nation.n_nationkey 08)------------Filter: nation.n_name = Utf8("CANADA") 09)--------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("CANADA")] -10)------SubqueryAlias: __correlated_sq_1 +10)------SubqueryAlias: __correlated_sq_2 11)--------Projection: partsupp.ps_suppkey 12)----------Inner Join: partsupp.ps_partkey = __scalar_sq_3.l_partkey, partsupp.ps_suppkey = __scalar_sq_3.l_suppkey Filter: CAST(partsupp.ps_availqty AS Float64) > __scalar_sq_3.Float64(0.5) * sum(lineitem.l_quantity) -13)------------LeftSemi Join: partsupp.ps_partkey = __correlated_sq_2.p_partkey +13)------------LeftSemi Join: partsupp.ps_partkey = __correlated_sq_1.p_partkey 14)--------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] -15)--------------SubqueryAlias: __correlated_sq_2 +15)--------------SubqueryAlias: __correlated_sq_1 16)----------------Projection: part.p_partkey 17)------------------Filter: part.p_name LIKE Utf8("forest%") 18)--------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("forest%")] diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index ae67b6924436..06a047b108bd 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -474,16 +474,14 @@ async fn roundtrip_inlist_5() -> Result<()> { // using assert_expected_plan here as a workaround assert_expected_plan( "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))", - "Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN ()\ - \n Subquery:\ - \n Projection: data2.a\ - \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ - \n TableScan: data2\ - \n TableScan: data projection=[a, f], partial_filters=[data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN ()]\ - \n Subquery:\ - \n Projection: data2.a\ - \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ - \n TableScan: data2", + "Projection: data.a, data.f\ + \n Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR Boolean(true) IS NOT NULL\ + \n Projection: data.a, data.f, Boolean(true)\ + \n Left Join: data.a = data2.a\ + \n TableScan: data projection=[a, f]\ + \n Projection: data2.a, Boolean(true)\ + \n Filter: data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")\ + \n TableScan: data2 projection=[a, f], partial_filters=[data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")]", true).await }