Skip to content

Commit

Permalink
Implement PhysicalExpr CSE
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-toth committed Oct 21, 2024
1 parent 8065fb2 commit 9361ee6
Show file tree
Hide file tree
Showing 28 changed files with 740 additions and 61 deletions.
2 changes: 1 addition & 1 deletion datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion datafusion/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ arrow-schema = { workspace = true }
chrono = { workspace = true }
half = { workspace = true }
hashbrown = { workspace = true }
indexmap = { workspace = true }
libc = "0.2.140"
num_cpus = { workspace = true }
object_store = { workspace = true, optional = true }
Expand Down
4 changes: 4 additions & 0 deletions datafusion/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ use std::collections::{BTreeMap, HashMap};
use std::fmt::{self, Display};
use std::str::FromStr;

use crate::alias::AliasGenerator;
use crate::error::_config_err;
use crate::parsers::CompressionTypeVariant;
use crate::{DataFusionError, Result};
use std::sync::Arc;

/// A macro that wraps a configuration struct and automatically derives
/// [`Default`] and [`ConfigField`] for it, allowing it to be used
Expand Down Expand Up @@ -674,6 +676,8 @@ pub struct ConfigOptions {
pub explain: ExplainOptions,
/// Optional extensions registered using [`Extensions::insert`]
pub extensions: Extensions,
/// Return alias generator used to generate unique aliases
pub alias_generator: Arc<AliasGenerator>,
}

impl ConfigField for ConfigOptions {
Expand Down
45 changes: 26 additions & 19 deletions datafusion/common/src/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use crate::tree_node::{
TreeNodeVisitor,
};
use crate::Result;
use indexmap::IndexMap;
use std::collections::HashMap;
use std::hash::{BuildHasher, Hash, Hasher, RandomState};
use std::marker::PhantomData;
Expand Down Expand Up @@ -123,11 +122,11 @@ type IdArray<'n, N> = Vec<(usize, Option<Identifier<'n, N>>)>;

/// A map that contains the number of normal and conditional occurrences of [`TreeNode`]s
/// by their identifiers.
type NodeStats<'n, N> = HashMap<Identifier<'n, N>, (usize, usize)>;
type NodeStats<'n, N> = HashMap<Identifier<'n, N>, (usize, usize, Option<usize>)>;

/// A map that contains the common [`TreeNode`]s and their alias by their identifiers,
/// extracted during the second, rewriting traversal.
type CommonNodes<'n, N> = IndexMap<Identifier<'n, N>, (N, String)>;
type CommonNodes<'n, N> = Vec<(N, String)>;

type ChildrenList<N> = (Vec<N>, Vec<N>);

Expand Down Expand Up @@ -155,7 +154,7 @@ pub trait CSEController {
fn generate_alias(&self) -> String;

// Replaces a node to the generated alias.
fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node;
fn rewrite(&mut self, node: &Self::Node, alias: &str, index: usize) -> Self::Node;

// A helper method called on each node during top-down traversal during the second,
// rewriting traversal of CSE.
Expand Down Expand Up @@ -331,8 +330,8 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
self.id_array[down_index].0 = self.up_index;
if is_valid && !self.controller.is_ignored(node) {
self.id_array[down_index].1 = Some(node_id);
let (count, conditional_count) =
self.node_stats.entry(node_id).or_insert((0, 0));
let (count, conditional_count, _) =
self.node_stats.entry(node_id).or_insert((0, 0, None));
if self.conditional {
*conditional_count += 1;
} else {
Expand All @@ -355,7 +354,7 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
/// replaced [`TreeNode`] tree.
struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
/// statistics of [`TreeNode`]s
node_stats: &'a NodeStats<'n, N>,
node_stats: &'a mut NodeStats<'n, N>,

/// cache to speed up second traversal
id_array: &'a IdArray<'n, N>,
Expand Down Expand Up @@ -383,7 +382,8 @@ impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter

// Handle nodes with identifiers only
if let Some(node_id) = node_id {
let (count, conditional_count) = self.node_stats.get(&node_id).unwrap();
let (count, conditional_count, common_index) =
self.node_stats.get_mut(&node_id).unwrap();
if *count > 1 || *count == 1 && *conditional_count > 0 {
// step index to skip all sub-node (which has smaller series number).
while self.down_index < self.id_array.len()
Expand All @@ -392,13 +392,15 @@ impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
self.down_index += 1;
}

let (node, alias) =
self.common_nodes.entry(node_id).or_insert_with(|| {
let node_alias = self.controller.generate_alias();
(node, node_alias)
});
let index = *common_index.get_or_insert_with(|| {
let index = self.common_nodes.len();
let node_alias = self.controller.generate_alias();
self.common_nodes.push((node, node_alias));
index
});
let (node, alias) = self.common_nodes.get(index).unwrap();

let rewritten = self.controller.rewrite(node, alias);
let rewritten = self.controller.rewrite(node, alias, index);

return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump));
}
Expand Down Expand Up @@ -491,7 +493,7 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
&mut self,
node: N,
id_array: &IdArray<'n, N>,
node_stats: &NodeStats<'n, N>,
node_stats: &mut NodeStats<'n, N>,
common_nodes: &mut CommonNodes<'n, N>,
) -> Result<N> {
if id_array.is_empty() {
Expand All @@ -514,7 +516,7 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
&mut self,
nodes_list: Vec<Vec<N>>,
arrays_list: &[Vec<IdArray<'n, N>>],
node_stats: &NodeStats<'n, N>,
node_stats: &mut NodeStats<'n, N>,
common_nodes: &mut CommonNodes<'n, N>,
) -> Result<Vec<Vec<N>>> {
nodes_list
Expand Down Expand Up @@ -559,13 +561,13 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
// nodes so we have to keep them intact.
nodes_list.clone(),
&id_arrays_list,
&node_stats,
&mut node_stats,
&mut common_nodes,
)?;
assert!(!common_nodes.is_empty());

Ok(FoundCommonNodes::Yes {
common_nodes: common_nodes.into_values().collect(),
common_nodes,
new_nodes_list,
original_nodes_list: nodes_list,
})
Expand Down Expand Up @@ -635,7 +637,12 @@ mod test {
self.alias_generator.next(CSE_PREFIX)
}

fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
fn rewrite(
&mut self,
node: &Self::Node,
alias: &str,
_index: usize,
) -> Self::Node {
TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias))
}
}
Expand Down
10 changes: 7 additions & 3 deletions datafusion/core/src/physical_optimizer/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@

//! Physical optimizer traits

use datafusion_physical_optimizer::PhysicalOptimizerRule;
use std::sync::Arc;

use super::projection_pushdown::ProjectionPushdown;
use super::update_aggr_exprs::OptimizeAggregateOrder;
use crate::physical_optimizer::aggregate_statistics::AggregateStatistics;
Expand All @@ -33,6 +30,9 @@ use crate::physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggr
use crate::physical_optimizer::output_requirements::OutputRequirements;
use crate::physical_optimizer::sanity_checker::SanityCheckPlan;
use crate::physical_optimizer::topk_aggregation::TopKAggregation;
use datafusion_physical_optimizer::eliminate_common_physical_subexprs::EliminateCommonPhysicalSubexprs;
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use std::sync::Arc;

/// A rule-based physical optimizer.
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -103,6 +103,10 @@ impl PhysicalOptimizer {
// replacing operators with fetching variants, or adding limits
// past operators that support limit pushdown.
Arc::new(LimitPushdown::new()),
// The EliminateCommonPhysicalSubExprs rule extracts common physical
// subexpression trees into a `ProjectionExec` node under the actual node to
// calculate the common values only once.
Arc::new(EliminateCommonPhysicalSubexprs::new()),
// The SanityCheckPlan rule checks whether the order and
// distribution requirements of each node in the plan
// is satisfied. It will also reject non-runnable query
Expand Down
66 changes: 44 additions & 22 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,17 +387,15 @@ impl CommonSubexprEliminate {
// Since `group_expr` may have changed, schema may also.
// Use `try_new()` method.
Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)
.map(LogicalPlan::Aggregate)
.map(Transformed::no)
.map(|p| Transformed::no(LogicalPlan::Aggregate(p)))
} else {
Aggregate::try_new_with_schema(
new_input,
new_group_expr,
rewritten_aggr_expr,
schema,
)
.map(LogicalPlan::Aggregate)
.map(Transformed::no)
.map(|p| Transformed::no(LogicalPlan::Aggregate(p)))
}
}
}
Expand Down Expand Up @@ -617,9 +615,7 @@ impl CSEController for ExprCSEController<'_> {

fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
match node {
// In case of `ScalarFunction`s we don't know which children are surely
// executed so start visiting all children conditionally and stop the
// recursion with `TreeNodeRecursion::Jump`.
// In case of `ScalarFunction`s all children can be conditionally executed.
Expr::ScalarFunction(ScalarFunction { func, args })
if func.short_circuits() =>
{
Expand Down Expand Up @@ -689,7 +685,7 @@ impl CSEController for ExprCSEController<'_> {
self.alias_generator.next(CSE_PREFIX)
}

fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
fn rewrite(&mut self, node: &Self::Node, alias: &str, _index: usize) -> Self::Node {
// alias the expressions without an `Alias` ancestor node
if self.alias_counter > 0 {
col(alias)
Expand Down Expand Up @@ -1019,10 +1015,14 @@ mod test {
fn subexpr_in_same_order() -> Result<()> {
let table_scan = test_table_scan()?;

let a = col("a");
let lit_1 = lit(1);
let _1_plus_a = lit_1 + a;

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
(lit(1) + col("a")).alias("first"),
(lit(1) + col("a")).alias("second"),
_1_plus_a.clone().alias("first"),
_1_plus_a.alias("second"),
])?
.build()?;

Expand All @@ -1039,8 +1039,13 @@ mod test {
fn subexpr_in_different_order() -> Result<()> {
let table_scan = test_table_scan()?;

let a = col("a");
let lit_1 = lit(1);
let _1_plus_a = lit_1.clone() + a.clone();
let a_plus_1 = a + lit_1;

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![lit(1) + col("a"), col("a") + lit(1)])?
.project(vec![_1_plus_a, a_plus_1])?
.build()?;

let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\
Expand All @@ -1055,6 +1060,8 @@ mod test {
fn cross_plans_subexpr() -> Result<()> {
let table_scan = test_table_scan()?;

let _1_plus_col_a = lit(1) + col("a");

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![lit(1) + col("a"), col("a")])?
.project(vec![lit(1) + col("a")])?
Expand Down Expand Up @@ -1307,9 +1314,12 @@ mod test {
fn test_volatile() -> Result<()> {
let table_scan = test_table_scan()?;

let extracted_child = col("a") + col("b");
let rand = rand_func().call(vec![]);
let a = col("a");
let b = col("b");
let extracted_child = a + b;
let rand = rand_expr();
let not_extracted_volatile = extracted_child + rand;

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
not_extracted_volatile.clone().alias("c1"),
Expand All @@ -1330,13 +1340,19 @@ mod test {
fn test_volatile_short_circuits() -> Result<()> {
let table_scan = test_table_scan()?;

let rand = rand_func().call(vec![]);
let extracted_short_circuit_leg_1 = col("a").eq(lit(0));
let a = col("a");
let b = col("b");
let rand = rand_expr();
let rand_eq_0 = rand.eq(lit(0));

let extracted_short_circuit_leg_1 = a.eq(lit(0));
let not_extracted_volatile_short_circuit_1 =
extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0)));
let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0));
extracted_short_circuit_leg_1.or(rand_eq_0.clone());

let not_extracted_short_circuit_leg_2 = b.eq(lit(0));
let not_extracted_volatile_short_circuit_2 =
rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2);
rand_eq_0.or(not_extracted_short_circuit_leg_2);

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
not_extracted_volatile_short_circuit_1.clone().alias("c1"),
Expand All @@ -1359,7 +1375,10 @@ mod test {
fn test_non_top_level_common_expression() -> Result<()> {
let table_scan = test_table_scan()?;

let common_expr = col("a") + col("b");
let a = col("a");
let b = col("b");
let common_expr = a + b;

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
common_expr.clone().alias("c1"),
Expand All @@ -1382,8 +1401,11 @@ mod test {
fn test_nested_common_expression() -> Result<()> {
let table_scan = test_table_scan()?;

let nested_common_expr = col("a") + col("b");
let a = col("a");
let b = col("b");
let nested_common_expr = a + b;
let common_expr = nested_common_expr.clone() * nested_common_expr;

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
common_expr.clone().alias("c1"),
Expand All @@ -1406,8 +1428,8 @@ mod test {
///
/// Does not use datafusion_functions::rand to avoid introducing a
/// dependency on that crate.
fn rand_func() -> ScalarUDF {
ScalarUDF::new_from_impl(RandomStub::new())
fn rand_expr() -> Expr {
ScalarUDF::new_from_impl(RandomStub::new()).call(vec![])
}

#[derive(Debug)]
Expand Down
Loading

0 comments on commit 9361ee6

Please sign in to comment.