From 9361ee69251b1835567556ee6f8c3d56351076f6 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sun, 20 Oct 2024 11:29:11 +0200 Subject: [PATCH] Implement `PhysicalExpr` CSE --- datafusion-cli/Cargo.lock | 2 +- datafusion/common/Cargo.toml | 1 - datafusion/common/src/config.rs | 4 + datafusion/common/src/cse.rs | 45 +- .../core/src/physical_optimizer/optimizer.rs | 10 +- .../optimizer/src/common_subexpr_eliminate.rs | 66 ++- .../physical-expr-common/src/physical_expr.rs | 28 +- .../physical-expr-common/src/sort_expr.rs | 3 + .../physical-expr/src/expressions/binary.rs | 10 +- .../physical-expr/src/expressions/case.rs | 9 +- .../physical-expr/src/expressions/cast.rs | 10 +- .../physical-expr/src/expressions/column.rs | 10 +- .../physical-expr/src/expressions/in_list.rs | 8 + .../src/expressions/is_not_null.rs | 7 +- .../physical-expr/src/expressions/is_null.rs | 7 +- .../physical-expr/src/expressions/like.rs | 10 +- .../physical-expr/src/expressions/literal.rs | 9 +- .../physical-expr/src/expressions/negative.rs | 7 +- .../physical-expr/src/expressions/no_op.rs | 7 +- .../physical-expr/src/expressions/not.rs | 7 +- .../physical-expr/src/expressions/try_cast.rs | 9 +- .../src/expressions/unknown_column.rs | 7 + .../physical-expr/src/scalar_function.rs | 17 +- datafusion/physical-optimizer/Cargo.toml | 1 + .../src/eliminate_common_physical_subexprs.rs | 497 ++++++++++++++++++ datafusion/physical-optimizer/src/lib.rs | 1 + .../tests/cases/roundtrip_physical_plan.rs | 6 + .../sqllogictest/test_files/explain.slt | 3 + 28 files changed, 740 insertions(+), 61 deletions(-) create mode 100644 datafusion/physical-optimizer/src/eliminate_common_physical_subexprs.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 401f203dd931..419c7be81a1f 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1293,7 +1293,6 @@ dependencies = [ "chrono", "half", "hashbrown 0.14.5", - "indexmap", "instant", "libc", "num_cpus", @@ -1524,6 +1523,7 @@ dependencies = [ "arrow-schema", "datafusion-common", "datafusion-execution", + "datafusion-expr", "datafusion-expr-common", "datafusion-physical-expr", "datafusion-physical-plan", diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 0747672a18f6..1ac27b40c219 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -56,7 +56,6 @@ arrow-schema = { workspace = true } chrono = { workspace = true } half = { workspace = true } hashbrown = { workspace = true } -indexmap = { workspace = true } libc = "0.2.140" num_cpus = { workspace = true } object_store = { workspace = true, optional = true } diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 47ffe0b1c66b..b256e6aae703 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -22,9 +22,11 @@ use std::collections::{BTreeMap, HashMap}; use std::fmt::{self, Display}; use std::str::FromStr; +use crate::alias::AliasGenerator; use crate::error::_config_err; use crate::parsers::CompressionTypeVariant; use crate::{DataFusionError, Result}; +use std::sync::Arc; /// A macro that wraps a configuration struct and automatically derives /// [`Default`] and [`ConfigField`] for it, allowing it to be used @@ -674,6 +676,8 @@ pub struct ConfigOptions { pub explain: ExplainOptions, /// Optional extensions registered using [`Extensions::insert`] pub extensions: Extensions, + /// Return alias generator used to generate unique aliases + pub alias_generator: Arc, } impl ConfigField for ConfigOptions { diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs index 453ae26e7333..ff447d8e9207 100644 --- a/datafusion/common/src/cse.rs +++ b/datafusion/common/src/cse.rs @@ -25,7 +25,6 @@ use crate::tree_node::{ TreeNodeVisitor, }; use crate::Result; -use indexmap::IndexMap; use std::collections::HashMap; use std::hash::{BuildHasher, Hash, Hasher, RandomState}; use std::marker::PhantomData; @@ -123,11 +122,11 @@ type IdArray<'n, N> = Vec<(usize, Option>)>; /// A map that contains the number of normal and conditional occurrences of [`TreeNode`]s /// by their identifiers. -type NodeStats<'n, N> = HashMap, (usize, usize)>; +type NodeStats<'n, N> = HashMap, (usize, usize, Option)>; /// A map that contains the common [`TreeNode`]s and their alias by their identifiers, /// extracted during the second, rewriting traversal. -type CommonNodes<'n, N> = IndexMap, (N, String)>; +type CommonNodes<'n, N> = Vec<(N, String)>; type ChildrenList = (Vec, Vec); @@ -155,7 +154,7 @@ pub trait CSEController { fn generate_alias(&self) -> String; // Replaces a node to the generated alias. - fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node; + fn rewrite(&mut self, node: &Self::Node, alias: &str, index: usize) -> Self::Node; // A helper method called on each node during top-down traversal during the second, // rewriting traversal of CSE. @@ -331,8 +330,8 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController> TreeNodeVisito self.id_array[down_index].0 = self.up_index; if is_valid && !self.controller.is_ignored(node) { self.id_array[down_index].1 = Some(node_id); - let (count, conditional_count) = - self.node_stats.entry(node_id).or_insert((0, 0)); + let (count, conditional_count, _) = + self.node_stats.entry(node_id).or_insert((0, 0, None)); if self.conditional { *conditional_count += 1; } else { @@ -355,7 +354,7 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController> TreeNodeVisito /// replaced [`TreeNode`] tree. struct CSERewriter<'a, 'n, N, C: CSEController> { /// statistics of [`TreeNode`]s - node_stats: &'a NodeStats<'n, N>, + node_stats: &'a mut NodeStats<'n, N>, /// cache to speed up second traversal id_array: &'a IdArray<'n, N>, @@ -383,7 +382,8 @@ impl> TreeNodeRewriter // Handle nodes with identifiers only if let Some(node_id) = node_id { - let (count, conditional_count) = self.node_stats.get(&node_id).unwrap(); + let (count, conditional_count, common_index) = + self.node_stats.get_mut(&node_id).unwrap(); if *count > 1 || *count == 1 && *conditional_count > 0 { // step index to skip all sub-node (which has smaller series number). while self.down_index < self.id_array.len() @@ -392,13 +392,15 @@ impl> TreeNodeRewriter self.down_index += 1; } - let (node, alias) = - self.common_nodes.entry(node_id).or_insert_with(|| { - let node_alias = self.controller.generate_alias(); - (node, node_alias) - }); + let index = *common_index.get_or_insert_with(|| { + let index = self.common_nodes.len(); + let node_alias = self.controller.generate_alias(); + self.common_nodes.push((node, node_alias)); + index + }); + let (node, alias) = self.common_nodes.get(index).unwrap(); - let rewritten = self.controller.rewrite(node, alias); + let rewritten = self.controller.rewrite(node, alias, index); return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump)); } @@ -491,7 +493,7 @@ impl> CSE &mut self, node: N, id_array: &IdArray<'n, N>, - node_stats: &NodeStats<'n, N>, + node_stats: &mut NodeStats<'n, N>, common_nodes: &mut CommonNodes<'n, N>, ) -> Result { if id_array.is_empty() { @@ -514,7 +516,7 @@ impl> CSE &mut self, nodes_list: Vec>, arrays_list: &[Vec>], - node_stats: &NodeStats<'n, N>, + node_stats: &mut NodeStats<'n, N>, common_nodes: &mut CommonNodes<'n, N>, ) -> Result>> { nodes_list @@ -559,13 +561,13 @@ impl> CSE // nodes so we have to keep them intact. nodes_list.clone(), &id_arrays_list, - &node_stats, + &mut node_stats, &mut common_nodes, )?; assert!(!common_nodes.is_empty()); Ok(FoundCommonNodes::Yes { - common_nodes: common_nodes.into_values().collect(), + common_nodes, new_nodes_list, original_nodes_list: nodes_list, }) @@ -635,7 +637,12 @@ mod test { self.alias_generator.next(CSE_PREFIX) } - fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node { + fn rewrite( + &mut self, + node: &Self::Node, + alias: &str, + _index: usize, + ) -> Self::Node { TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias)) } } diff --git a/datafusion/core/src/physical_optimizer/optimizer.rs b/datafusion/core/src/physical_optimizer/optimizer.rs index 7a6f991121ef..4d4a3689d31c 100644 --- a/datafusion/core/src/physical_optimizer/optimizer.rs +++ b/datafusion/core/src/physical_optimizer/optimizer.rs @@ -17,9 +17,6 @@ //! Physical optimizer traits -use datafusion_physical_optimizer::PhysicalOptimizerRule; -use std::sync::Arc; - use super::projection_pushdown::ProjectionPushdown; use super::update_aggr_exprs::OptimizeAggregateOrder; use crate::physical_optimizer::aggregate_statistics::AggregateStatistics; @@ -33,6 +30,9 @@ use crate::physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggr use crate::physical_optimizer::output_requirements::OutputRequirements; use crate::physical_optimizer::sanity_checker::SanityCheckPlan; use crate::physical_optimizer::topk_aggregation::TopKAggregation; +use datafusion_physical_optimizer::eliminate_common_physical_subexprs::EliminateCommonPhysicalSubexprs; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use std::sync::Arc; /// A rule-based physical optimizer. #[derive(Clone, Debug)] @@ -103,6 +103,10 @@ impl PhysicalOptimizer { // replacing operators with fetching variants, or adding limits // past operators that support limit pushdown. Arc::new(LimitPushdown::new()), + // The EliminateCommonPhysicalSubExprs rule extracts common physical + // subexpression trees into a `ProjectionExec` node under the actual node to + // calculate the common values only once. + Arc::new(EliminateCommonPhysicalSubexprs::new()), // The SanityCheckPlan rule checks whether the order and // distribution requirements of each node in the plan // is satisfied. It will also reject non-runnable query diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 921011d33fc4..d3eecfbdc3b6 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -387,8 +387,7 @@ impl CommonSubexprEliminate { // Since `group_expr` may have changed, schema may also. // Use `try_new()` method. Aggregate::try_new(new_input, new_group_expr, new_aggr_expr) - .map(LogicalPlan::Aggregate) - .map(Transformed::no) + .map(|p| Transformed::no(LogicalPlan::Aggregate(p))) } else { Aggregate::try_new_with_schema( new_input, @@ -396,8 +395,7 @@ impl CommonSubexprEliminate { rewritten_aggr_expr, schema, ) - .map(LogicalPlan::Aggregate) - .map(Transformed::no) + .map(|p| Transformed::no(LogicalPlan::Aggregate(p))) } } } @@ -617,9 +615,7 @@ impl CSEController for ExprCSEController<'_> { fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> { match node { - // In case of `ScalarFunction`s we don't know which children are surely - // executed so start visiting all children conditionally and stop the - // recursion with `TreeNodeRecursion::Jump`. + // In case of `ScalarFunction`s all children can be conditionally executed. Expr::ScalarFunction(ScalarFunction { func, args }) if func.short_circuits() => { @@ -689,7 +685,7 @@ impl CSEController for ExprCSEController<'_> { self.alias_generator.next(CSE_PREFIX) } - fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node { + fn rewrite(&mut self, node: &Self::Node, alias: &str, _index: usize) -> Self::Node { // alias the expressions without an `Alias` ancestor node if self.alias_counter > 0 { col(alias) @@ -1019,10 +1015,14 @@ mod test { fn subexpr_in_same_order() -> Result<()> { let table_scan = test_table_scan()?; + let a = col("a"); + let lit_1 = lit(1); + let _1_plus_a = lit_1 + a; + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![ - (lit(1) + col("a")).alias("first"), - (lit(1) + col("a")).alias("second"), + _1_plus_a.clone().alias("first"), + _1_plus_a.alias("second"), ])? .build()?; @@ -1039,8 +1039,13 @@ mod test { fn subexpr_in_different_order() -> Result<()> { let table_scan = test_table_scan()?; + let a = col("a"); + let lit_1 = lit(1); + let _1_plus_a = lit_1.clone() + a.clone(); + let a_plus_1 = a + lit_1; + let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![lit(1) + col("a"), col("a") + lit(1)])? + .project(vec![_1_plus_a, a_plus_1])? .build()?; let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\ @@ -1055,6 +1060,8 @@ mod test { fn cross_plans_subexpr() -> Result<()> { let table_scan = test_table_scan()?; + let _1_plus_col_a = lit(1) + col("a"); + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![lit(1) + col("a"), col("a")])? .project(vec![lit(1) + col("a")])? @@ -1307,9 +1314,12 @@ mod test { fn test_volatile() -> Result<()> { let table_scan = test_table_scan()?; - let extracted_child = col("a") + col("b"); - let rand = rand_func().call(vec![]); + let a = col("a"); + let b = col("b"); + let extracted_child = a + b; + let rand = rand_expr(); let not_extracted_volatile = extracted_child + rand; + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![ not_extracted_volatile.clone().alias("c1"), @@ -1330,13 +1340,19 @@ mod test { fn test_volatile_short_circuits() -> Result<()> { let table_scan = test_table_scan()?; - let rand = rand_func().call(vec![]); - let extracted_short_circuit_leg_1 = col("a").eq(lit(0)); + let a = col("a"); + let b = col("b"); + let rand = rand_expr(); + let rand_eq_0 = rand.eq(lit(0)); + + let extracted_short_circuit_leg_1 = a.eq(lit(0)); let not_extracted_volatile_short_circuit_1 = - extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0))); - let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0)); + extracted_short_circuit_leg_1.or(rand_eq_0.clone()); + + let not_extracted_short_circuit_leg_2 = b.eq(lit(0)); let not_extracted_volatile_short_circuit_2 = - rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2); + rand_eq_0.or(not_extracted_short_circuit_leg_2); + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![ not_extracted_volatile_short_circuit_1.clone().alias("c1"), @@ -1359,7 +1375,10 @@ mod test { fn test_non_top_level_common_expression() -> Result<()> { let table_scan = test_table_scan()?; - let common_expr = col("a") + col("b"); + let a = col("a"); + let b = col("b"); + let common_expr = a + b; + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![ common_expr.clone().alias("c1"), @@ -1382,8 +1401,11 @@ mod test { fn test_nested_common_expression() -> Result<()> { let table_scan = test_table_scan()?; - let nested_common_expr = col("a") + col("b"); + let a = col("a"); + let b = col("b"); + let nested_common_expr = a + b; let common_expr = nested_common_expr.clone() * nested_common_expr; + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![ common_expr.clone().alias("c1"), @@ -1406,8 +1428,8 @@ mod test { /// /// Does not use datafusion_functions::rand to avoid introducing a /// dependency on that crate. - fn rand_func() -> ScalarUDF { - ScalarUDF::new_from_impl(RandomStub::new()) + fn rand_expr() -> Expr { + ScalarUDF::new_from_impl(RandomStub::new()).call(vec![]) } #[derive(Debug)] diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 41e46d261508..783acba138b7 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -26,6 +26,7 @@ use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; +use datafusion_common::cse::HashNode; use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; @@ -52,7 +53,9 @@ use datafusion_expr_common::sort_properties::ExprProperties; /// [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html /// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html /// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html -pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { +pub trait PhysicalExpr: + Send + Sync + Display + Debug + DynEq + DynHash + DynHashNode +{ /// Returns the physical expression as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -149,6 +152,10 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { fn get_properties(&self, _children: &[ExprProperties]) -> Result { Ok(ExprProperties::new_unknown()) } + + fn is_volatile(&self) -> bool { + false + } } pub trait DynEq { @@ -174,7 +181,7 @@ impl Eq for dyn PhysicalExpr {} /// Note: [`PhysicalExpr`] is not constrained by [`Hash`] directly because it must remain /// object safe. pub trait DynHash { - fn dyn_hash(&self, _state: &mut dyn Hasher); + fn dyn_hash(&self, state: &mut dyn Hasher); } impl DynHash for T { @@ -190,6 +197,23 @@ impl Hash for dyn PhysicalExpr { } } +pub trait DynHashNode { + fn dyn_hash_node(&self, state: &mut dyn Hasher); +} + +impl DynHashNode for T { + fn dyn_hash_node(&self, mut state: &mut dyn Hasher) { + self.type_id().hash(&mut state); + self.hash_node(&mut state) + } +} + +impl HashNode for dyn PhysicalExpr { + fn hash_node(&self, state: &mut H) { + self.dyn_hash_node(state); + } +} + /// Returns a copy of this expr if we change any child according to the pointer comparison. /// The size of `children` must be equal to the size of `PhysicalExpr::children()`. pub fn with_new_children_if_necessary( diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index c8f741b2522f..8c5801e8e71f 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -40,6 +40,7 @@ use datafusion_expr_common::columnar_value::ColumnarValue; /// # use std::sync::Arc; /// # use arrow::array::RecordBatch; /// # use datafusion_common::Result; +/// # use datafusion_common::cse::HashNode; /// # use arrow::compute::SortOptions; /// # use arrow::datatypes::{DataType, Schema}; /// # use datafusion_expr_common::columnar_value::ColumnarValue; @@ -60,6 +61,8 @@ use datafusion_expr_common::columnar_value::ColumnarValue; /// # impl Display for MyPhysicalExpr { /// # fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "a") } /// # } +/// # impl HashNode for MyPhysicalExpr {fn hash_node(&self, _state: &mut H) {} +/// # } /// # fn col(name: &str) -> Arc { Arc::new(MyPhysicalExpr) } /// // Sort by a ASC /// let options = SortOptions::default(); diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 7cdac6cbb1f9..a2640ef208ad 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -17,7 +17,7 @@ mod kernels; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; @@ -32,6 +32,7 @@ use arrow::compute::{cast, ilike, like, nilike, nlike}; use arrow::datatypes::*; use arrow_schema::ArrowError; use datafusion_common::cast::as_boolean_array; +use datafusion_common::cse::HashNode; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::{apply_operator, Interval}; use datafusion_expr::sort_properties::ExprProperties; @@ -660,6 +661,13 @@ impl BinaryExpr { } } +impl HashNode for BinaryExpr { + fn hash_node(&self, state: &mut H) { + self.op.hash(state); + self.fail_on_overflow.hash(state); + } +} + fn concat_elements(left: Arc, right: Arc) -> Result { Ok(match left.data_type() { DataType::Utf8 => Arc::new(concat_elements_utf8( diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 61facedc826f..320f25bc353c 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -16,7 +16,7 @@ // under the License. use std::borrow::Cow; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; use crate::expressions::try_cast; @@ -31,6 +31,7 @@ use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarV use datafusion_expr::ColumnarValue; use super::{Column, Literal}; +use datafusion_common::cse::HashNode; use datafusion_physical_expr_common::datum::compare_with_eq; use itertools::Itertools; @@ -507,6 +508,12 @@ impl PhysicalExpr for CaseExpr { } } +impl HashNode for CaseExpr { + fn hash_node(&self, state: &mut H) { + self.eval_method.hash(state); + } +} + /// Create a CASE expression pub fn case( expr: Option>, diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 6feabb3b9a2c..76d9ea725c9a 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::fmt; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; @@ -25,6 +25,7 @@ use crate::physical_expr::PhysicalExpr; use arrow::compute::{can_cast_types, CastOptions}; use arrow::datatypes::{DataType, DataType::*, Schema}; use arrow::record_batch::RecordBatch; +use datafusion_common::cse::HashNode; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result}; use datafusion_expr_common::columnar_value::ColumnarValue; @@ -179,6 +180,13 @@ impl PhysicalExpr for CastExpr { } } +impl HashNode for CastExpr { + fn hash_node(&self, state: &mut H) { + self.cast_type.hash(state); + self.cast_options.hash(state); + } +} + /// Return a PhysicalExpression representing `expr` casted to /// `cast_type`, if any casting is needed. /// diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 5f6932f6d725..b0046ff9afdf 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -18,7 +18,7 @@ //! Physical column reference: [`Column`] use std::any::Any; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; @@ -27,6 +27,7 @@ use arrow::{ record_batch::RecordBatch, }; use arrow_schema::SchemaRef; +use datafusion_common::cse::HashNode; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, plan_err, Result}; use datafusion_expr::ColumnarValue; @@ -140,6 +141,13 @@ impl PhysicalExpr for Column { } } +impl HashNode for Column { + fn hash_node(&self, state: &mut H) { + self.name.hash(state); + self.index.hash(state); + } +} + impl Column { fn bounds_check(&self, input_schema: &Schema) -> Result<()> { if self.index < input_schema.fields.len() { diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index bd0b61a56d66..f97c19a261a2 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -44,6 +44,7 @@ use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::datum::compare_with_eq; use ahash::RandomState; +use datafusion_common::cse::HashNode; use hashbrown::hash_map::RawEntryMut; use hashbrown::HashMap; @@ -419,6 +420,13 @@ impl Hash for InListExpr { } } +impl HashNode for InListExpr { + fn hash_node(&self, state: &mut H) { + self.negated.hash(state); + // Add `self.static_filter` when hash is available + } +} + /// Creates a unary expression InList pub fn in_list( expr: Arc, diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 0e1fe80ca22c..d02ad32daffb 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -17,7 +17,7 @@ //! IS NOT NULL expression -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; use crate::PhysicalExpr; @@ -25,6 +25,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion_common::cse::HashNode; use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; @@ -93,6 +94,10 @@ impl PhysicalExpr for IsNotNullExpr { } } +impl HashNode for IsNotNullExpr { + fn hash_node(&self, _state: &mut H) {} +} + /// Create an IS NOT NULL expression pub fn is_not_null(arg: Arc) -> Result> { Ok(Arc::new(IsNotNullExpr::new(arg))) diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 2a6fa5d6dbb2..731cda5ff41c 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -17,7 +17,7 @@ //! IS NULL expression -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; use crate::PhysicalExpr; @@ -25,6 +25,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion_common::cse::HashNode; use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; @@ -92,6 +93,10 @@ impl PhysicalExpr for IsNullExpr { } } +impl HashNode for IsNullExpr { + fn hash_node(&self, _state: &mut H) {} +} + /// Create an IS NULL expression pub fn is_null(arg: Arc) -> Result> { Ok(Arc::new(IsNullExpr::new(arg))) diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index fef86a0a2758..835c3d3c80c7 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; use crate::PhysicalExpr; use arrow::record_batch::RecordBatch; use arrow_schema::{DataType, Schema}; +use datafusion_common::cse::HashNode; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::datum::apply_cmp; @@ -128,6 +129,13 @@ impl PhysicalExpr for LikeExpr { } } +impl HashNode for LikeExpr { + fn hash_node(&self, state: &mut H) { + self.negated.hash(state); + self.case_insensitive.hash(state); + } +} + /// used for optimize Dictionary like fn can_like_type(from_type: &DataType) -> bool { match from_type { diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index f0d02eb605b2..c1d42cafb2dd 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -18,7 +18,7 @@ //! Literal expressions for physical operations use std::any::Any; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; @@ -27,6 +27,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion_common::cse::HashNode; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Expr; use datafusion_expr_common::columnar_value::ColumnarValue; @@ -94,6 +95,12 @@ impl PhysicalExpr for Literal { } } +impl HashNode for Literal { + fn hash_node(&self, state: &mut H) { + self.value.hash(state); + } +} + /// Create a literal expression pub fn lit(value: T) -> Arc { match value.lit() { diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index ee653e31f2c1..22a79164d507 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -18,7 +18,7 @@ //! Negation (-) expression use std::any::Any; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::PhysicalExpr; @@ -28,6 +28,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion_common::cse::HashNode; use datafusion_common::{plan_err, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; @@ -136,6 +137,10 @@ impl PhysicalExpr for NegativeExpr { } } +impl HashNode for NegativeExpr { + fn hash_node(&self, _state: &mut H) {} +} + /// Creates a unary expression NEGATIVE /// /// # Errors diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs index c17b52f5cdff..042bcb5b625d 100644 --- a/datafusion/physical-expr/src/expressions/no_op.rs +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -18,7 +18,7 @@ //! NoOp placeholder for physical operations use std::any::Any; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use arrow::{ @@ -27,6 +27,7 @@ use arrow::{ }; use crate::PhysicalExpr; +use datafusion_common::cse::HashNode; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; @@ -78,3 +79,7 @@ impl PhysicalExpr for NoOp { Ok(self) } } + +impl HashNode for NoOp { + fn hash_node(&self, _state: &mut H) {} +} diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index 0d5f1966af36..75be906b54ee 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -19,12 +19,13 @@ use std::any::Any; use std::fmt; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; +use datafusion_common::cse::HashNode; use datafusion_common::{cast::as_boolean_array, Result, ScalarValue}; use datafusion_expr::ColumnarValue; @@ -100,6 +101,10 @@ impl PhysicalExpr for NotExpr { } } +impl HashNode for NotExpr { + fn hash_node(&self, _state: &mut H) {} +} + /// Creates a unary expression NOT pub fn not(arg: Arc) -> Result> { Ok(Arc::new(NotExpr::new(arg))) diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index 03e995948cae..6449b8a8906e 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::fmt; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::PhysicalExpr; @@ -26,6 +26,7 @@ use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use compute::can_cast_types; +use datafusion_common::cse::HashNode; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result, ScalarValue}; use datafusion_expr::ColumnarValue; @@ -111,6 +112,12 @@ impl PhysicalExpr for TryCastExpr { } } +impl HashNode for TryCastExpr { + fn hash_node(&self, state: &mut H) { + self.cast_type().hash(state); + } +} + /// Return a PhysicalExpression representing `expr` casted to /// `cast_type`, if any casting is needed. /// diff --git a/datafusion/physical-expr/src/expressions/unknown_column.rs b/datafusion/physical-expr/src/expressions/unknown_column.rs index f64bea842e6b..ad8af9f38c55 100644 --- a/datafusion/physical-expr/src/expressions/unknown_column.rs +++ b/datafusion/physical-expr/src/expressions/unknown_column.rs @@ -26,6 +26,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion_common::cse::HashNode; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; @@ -93,6 +94,12 @@ impl Hash for UnKnownColumn { } } +impl HashNode for UnKnownColumn { + fn hash_node(&self, state: &mut H) { + self.name.hash(state); + } +} + impl PartialEq for UnKnownColumn { fn eq(&self, _other: &Self) -> bool { // UnknownColumn is not a valid expression, so it should not be equal to any other expression. diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index eadd3ba26235..d8b5051b705d 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -31,7 +31,7 @@ use std::any::Any; use std::fmt::{self, Debug, Formatter}; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::PhysicalExpr; @@ -39,11 +39,13 @@ use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use arrow_array::Array; +use datafusion_common::cse::HashNode; use datafusion_common::{internal_err, DFSchema, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, ScalarUDF}; +use datafusion_expr_common::signature::Volatility; /// Physical expression of a scalar function #[derive(Eq, PartialEq, Hash)] @@ -210,6 +212,19 @@ impl PhysicalExpr for ScalarFunctionExpr { range, }) } + + fn is_volatile(&self) -> bool { + self.fun.signature().volatility == Volatility::Volatile + } +} + +impl HashNode for ScalarFunctionExpr { + fn hash_node(&self, state: &mut H) { + self.name.hash(state); + self.return_type.hash(state); + self.nullable.hash(state); + // Add `self.fun` when hash is available + } } /// Create a physical expression for the UDF. diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml index e7bf4a80fc45..02b8374b0870 100644 --- a/datafusion/physical-optimizer/Cargo.toml +++ b/datafusion/physical-optimizer/Cargo.toml @@ -36,6 +36,7 @@ arrow = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true, default-features = true } datafusion-physical-expr = { workspace = true } datafusion-physical-plan = { workspace = true } diff --git a/datafusion/physical-optimizer/src/eliminate_common_physical_subexprs.rs b/datafusion/physical-optimizer/src/eliminate_common_physical_subexprs.rs new file mode 100644 index 000000000000..3360532edd76 --- /dev/null +++ b/datafusion/physical-optimizer/src/eliminate_common_physical_subexprs.rs @@ -0,0 +1,497 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`EliminateCommonPhysicalSubexprs`] to avoid redundant computation of common physical +//! sub-expressions. + +use datafusion_common::alias::AliasGenerator; +use datafusion_common::config::ConfigOptions; +use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE}; +use datafusion_common::Result; +use datafusion_physical_plan::ExecutionPlan; +use std::sync::Arc; + +use crate::PhysicalOptimizerRule; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_expr_common::operator::Operator; +use datafusion_physical_expr::expressions::{BinaryExpr, CaseExpr, Column}; +use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr}; +use datafusion_physical_plan::projection::ProjectionExec; + +const CSE_PREFIX: &str = "__common_physical_expr"; + +// Optimizer rule to avoid redundant computation of common physical subexpressions +#[derive(Default, Debug)] +pub struct EliminateCommonPhysicalSubexprs {} + +impl EliminateCommonPhysicalSubexprs { + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for EliminateCommonPhysicalSubexprs { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + plan.transform_down(|plan| { + let plan_any = plan.as_any(); + if let Some(p) = plan_any.downcast_ref::() { + match CSE::new(PhysicalExprCSEController::new( + config.alias_generator.as_ref(), + p.input().schema().fields().len(), + )) + .extract_common_nodes(vec![p + .expr() + .iter() + .map(|(e, _)| e) + .cloned() + .collect()])? + { + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: mut new_exprs_list, + original_nodes_list: _, + } => { + let common_exprs = p + .input() + .schema() + .fields() + .iter() + .enumerate() + .map(|(i, field)| { + ( + Arc::new(Column::new(field.name(), i)) + as Arc, + field.name().to_string(), + ) + }) + .chain(common_exprs) + .collect(); + let common = Arc::new(ProjectionExec::try_new( + common_exprs, + Arc::clone(p.input()), + )?); + + let new_exprs = new_exprs_list + .pop() + .unwrap() + .into_iter() + .zip(p.expr().iter().map(|(_, alias)| alias.to_string())) + .collect(); + let new_project = + Arc::new(ProjectionExec::try_new(new_exprs, common)?) + as Arc; + + Ok(Transformed::yes(new_project)) + } + FoundCommonNodes::No { .. } => Ok(Transformed::no(plan)), + } + } else { + Ok(Transformed::no(plan)) + } + }) + .data() + } + + fn name(&self) -> &str { + "eliminate_common_physical_subexpressions" + } + + /// This rule will change the nullable properties of the schema, disable the schema check. + fn schema_check(&self) -> bool { + false + } +} + +pub struct PhysicalExprCSEController<'a> { + alias_generator: &'a AliasGenerator, + base_index: usize, +} + +impl<'a> PhysicalExprCSEController<'a> { + fn new(alias_generator: &'a AliasGenerator, base_index: usize) -> Self { + Self { + alias_generator, + base_index, + } + } +} + +impl CSEController for PhysicalExprCSEController<'_> { + type Node = Arc; + + fn conditional_children( + node: &Self::Node, + ) -> Option<(Vec<&Self::Node>, Vec<&Self::Node>)> { + if let Some(s) = node.as_any().downcast_ref::() { + // In case of `ScalarFunction`s all children can be conditionally executed. + if s.fun().short_circuits() { + Some((vec![], s.args().iter().collect())) + } else { + None + } + } else if let Some(b) = node.as_any().downcast_ref::() { + // In case of `And` and `Or` the first child is surely executed, but we + // account subexpressions as conditional in the second. + if *b.op() == Operator::And || *b.op() == Operator::Or { + Some((vec![b.left()], vec![b.right()])) + } else { + None + } + } else { + node.as_any().downcast_ref::().map(|c| { + ( + // In case of `Case` the optional base expression and the first when + // expressions are surely executed, but we account subexpressions as + // conditional in the others. + c.expr() + .into_iter() + .chain(c.when_then_expr().iter().take(1).map(|(when, _)| when)) + .collect(), + c.when_then_expr() + .iter() + .take(1) + .map(|(_, then)| then) + .chain( + c.when_then_expr() + .iter() + .skip(1) + .flat_map(|(when, then)| [when, then]), + ) + .chain(c.else_expr()) + .collect(), + ) + }) + } + } + + fn is_valid(node: &Self::Node) -> bool { + !node.is_volatile() + } + + fn is_ignored(&self, node: &Self::Node) -> bool { + node.children().is_empty() + } + + fn generate_alias(&self) -> String { + self.alias_generator.next(CSE_PREFIX) + } + + fn rewrite(&mut self, _node: &Self::Node, alias: &str, index: usize) -> Self::Node { + Arc::new(Column::new(alias, self.base_index + index)) + } + + fn rewrite_f_down(&mut self, _node: &Self::Node) {} + + fn rewrite_f_up(&mut self, _node: &Self::Node) {} +} + +#[cfg(test)] +mod tests { + use crate::eliminate_common_physical_subexprs::EliminateCommonPhysicalSubexprs; + use crate::optimizer::PhysicalOptimizerRule; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::config::ConfigOptions; + use datafusion_common::Result; + use datafusion_expr::{ScalarUDF, ScalarUDFImpl}; + use datafusion_expr_common::columnar_value::ColumnarValue; + use datafusion_expr_common::operator::Operator; + use datafusion_expr_common::signature::{Signature, Volatility}; + use datafusion_physical_expr::expressions::{binary, col, lit}; + use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr}; + use datafusion_physical_plan::memory::MemoryExec; + use datafusion_physical_plan::projection::ProjectionExec; + use datafusion_physical_plan::{get_plan_string, ExecutionPlan}; + use std::any::Any; + use std::sync::Arc; + + fn mock_data() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + Arc::new(MemoryExec::try_new(&[vec![]], Arc::clone(&schema), None).unwrap()) + } + + #[test] + fn subexpr_in_same_order() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let lit_1 = lit(1); + let _1_plus_a = binary(lit_1, Operator::Plus, a, &table_scan.schema())?; + + let exprs = vec![ + (Arc::clone(&_1_plus_a), "first".to_string()), + (_1_plus_a, "second".to_string()), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[__common_physical_expr_1@2 as first, __common_physical_expr_1@2 as second]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, 1 + a@0 as __common_physical_expr_1]", + " MemoryExec: partitions=1, partition_sizes=[0]"]; + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn subexpr_in_different_order() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let lit_1 = lit(1); + let _1_plus_a = binary( + Arc::clone(&lit_1), + Operator::Plus, + Arc::clone(&a), + &table_scan.schema(), + )?; + let a_plus_1 = binary(a, Operator::Plus, lit_1, &table_scan.schema())?; + + let exprs = vec![ + (_1_plus_a, "first".to_string()), + (a_plus_1, "second".to_string()), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[1 + a@0 as first, a@0 + 1 as second]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ]; + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn test_volatile() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let b = col("b", &table_scan.schema())?; + let extracted_child = binary(a, Operator::Plus, b, &table_scan.schema())?; + let rand = rand_expr(); + let not_extracted_volatile = + binary(extracted_child, Operator::Plus, rand, &table_scan.schema())?; + + let exprs = vec![ + (Arc::clone(¬_extracted_volatile), "c1".to_string()), + (not_extracted_volatile, "c2".to_string()), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[__common_physical_expr_1@2 + random() as c1, __common_physical_expr_1@2 + random() as c2]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, a@0 + b@1 as __common_physical_expr_1]", + " MemoryExec: partitions=1, partition_sizes=[0]"]; + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn test_volatile_short_circuits() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let b = col("b", &table_scan.schema())?; + let rand = rand_expr(); + let rand_eq_0 = binary(rand, Operator::Eq, lit(0), &table_scan.schema())?; + + let extracted_short_circuit_leg_1 = + binary(a, Operator::Eq, lit(0), &table_scan.schema())?; + let not_extracted_volatile_short_circuit_1 = binary( + extracted_short_circuit_leg_1, + Operator::Or, + Arc::clone(&rand_eq_0), + &table_scan.schema(), + )?; + + let not_extracted_short_circuit_leg_2 = + binary(b, Operator::Eq, lit(0), &table_scan.schema())?; + let not_extracted_volatile_short_circuit_2 = binary( + rand_eq_0, + Operator::Or, + not_extracted_short_circuit_leg_2, + &table_scan.schema(), + )?; + + let exprs = vec![ + ( + Arc::clone(¬_extracted_volatile_short_circuit_1), + "c1".to_string(), + ), + (not_extracted_volatile_short_circuit_1, "c2".to_string()), + ( + Arc::clone(¬_extracted_volatile_short_circuit_2), + "c3".to_string(), + ), + (not_extracted_volatile_short_circuit_2, "c4".to_string()), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[__common_physical_expr_1@2 OR random() = 0 as c1, __common_physical_expr_1@2 OR random() = 0 as c2, random() = 0 OR b@1 = 0 as c3, random() = 0 OR b@1 = 0 as c4]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, a@0 = 0 as __common_physical_expr_1]", + " MemoryExec: partitions=1, partition_sizes=[0]"]; + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn test_non_top_level_common_expression() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let b = col("b", &table_scan.schema())?; + let common_expr = binary(a, Operator::Plus, b, &table_scan.schema())?; + + let exprs = vec![ + (Arc::clone(&common_expr), "c1".to_string()), + (common_expr, "c2".to_string()), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let c1 = col("c1", &plan.schema())?; + let c2 = col("c2", &plan.schema())?; + + let exprs = vec![(c1, "c1".to_string()), (c2, "c2".to_string())]; + let plan = Arc::new(ProjectionExec::try_new(exprs, plan)?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2]", + " ProjectionExec: expr=[__common_physical_expr_1@2 as c1, __common_physical_expr_1@2 as c2]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, a@0 + b@1 as __common_physical_expr_1]", + " MemoryExec: partitions=1, partition_sizes=[0]"]; + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn test_nested_common_expression() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let b = col("b", &table_scan.schema())?; + let nested_common_expr = binary(a, Operator::Plus, b, &table_scan.schema())?; + let common_expr = binary( + Arc::clone(&nested_common_expr), + Operator::Multiply, + nested_common_expr, + &table_scan.schema(), + )?; + + let exprs = vec![ + (Arc::clone(&common_expr), "c1".to_string()), + (common_expr, "c2".to_string()), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[__common_physical_expr_1@2 as c1, __common_physical_expr_1@2 as c2]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, __common_physical_expr_2@2 * __common_physical_expr_2@2 as __common_physical_expr_1]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, a@0 + b@1 as __common_physical_expr_2]", + " MemoryExec: partitions=1, partition_sizes=[0]"]; + assert_eq!(actual, expected); + + Ok(()) + } + + fn rand_expr() -> Arc { + let r = RandomStub::new(); + let n = r.name().to_string(); + let t = r.return_type(&[]).unwrap(); + Arc::new(ScalarFunctionExpr::new( + &n, + Arc::new(ScalarUDF::new_from_impl(r)), + vec![], + t, + )) + } + + #[derive(Debug)] + struct RandomStub { + signature: Signature, + } + + impl RandomStub { + fn new() -> Self { + Self { + signature: Signature::exact(vec![], Volatility::Volatile), + } + } + } + impl ScalarUDFImpl for RandomStub { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "random" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!() + } + } +} diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs index 439f1dc873d1..bd81c5e0d0a0 100644 --- a/datafusion/physical-optimizer/src/lib.rs +++ b/datafusion/physical-optimizer/src/lib.rs @@ -19,6 +19,7 @@ pub mod aggregate_statistics; pub mod combine_partial_final_agg; +pub mod eliminate_common_physical_subexprs; pub mod limit_pushdown; pub mod limited_distinct_aggregation; mod optimizer; diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index e22eed73b394..d0c06d969d58 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -17,6 +17,7 @@ use std::any::Any; use std::fmt::Display; +use std::hash::Hasher; use std::ops::Deref; use std::sync::Arc; use std::vec; @@ -81,6 +82,7 @@ use datafusion::physical_plan::{ExecutionPlan, Partitioning, PhysicalExpr, Stati use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; use datafusion_common::config::TableParquetOptions; +use datafusion_common::cse::HashNode; use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; @@ -785,6 +787,10 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { } } + impl HashNode for CustomPredicateExpr { + fn hash_node(&self, _state: &mut H) {} + } + #[derive(Debug)] struct CustomPhysicalExtensionCodec; impl PhysicalExtensionCodec for CustomPhysicalExtensionCodec { diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index b1962ffcc116..ba46848c030f 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -254,6 +254,7 @@ physical_plan after OutputRequirements CsvExec: file_groups={1 group: [[WORKSPAC physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan after LimitPushdown SAME TEXT AS ABOVE +physical_plan after eliminate_common_physical_subexpressions SAME TEXT AS ABOVE physical_plan after SanityCheckPlan SAME TEXT AS ABOVE physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true physical_plan_with_stats CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] @@ -330,6 +331,7 @@ physical_plan after OutputRequirements physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan after LimitPushdown ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after eliminate_common_physical_subexpressions SAME TEXT AS ABOVE physical_plan after SanityCheckPlan SAME TEXT AS ABOVE physical_plan ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] physical_plan_with_schema ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:Binary;N, string_col:Binary;N, timestamp_col:Timestamp(Nanosecond, None);N] @@ -370,6 +372,7 @@ physical_plan after OutputRequirements physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan after LimitPushdown ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +physical_plan after eliminate_common_physical_subexpressions SAME TEXT AS ABOVE physical_plan after SanityCheckPlan SAME TEXT AS ABOVE physical_plan ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 physical_plan_with_stats ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]]