From 5207c57c474a313659111ccc5794d31d72c0f84f Mon Sep 17 00:00:00 2001 From: Benjamin O Date: Sun, 16 Jun 2024 00:01:18 -0700 Subject: [PATCH] Correctness fixes --- optd-datafusion-repr/src/lib.rs | 4 ++ optd-datafusion-repr/src/plan_nodes.rs | 13 +++++- .../src/properties/column_ref.rs | 11 +++-- .../src/rules/subquery/depjoin_pushdown.rs | 44 +++++++++++++++++-- 4 files changed, 62 insertions(+), 10 deletions(-) diff --git a/optd-datafusion-repr/src/lib.rs b/optd-datafusion-repr/src/lib.rs index 09d2211b..f489347c 100644 --- a/optd-datafusion-repr/src/lib.rs +++ b/optd-datafusion-repr/src/lib.rs @@ -220,6 +220,10 @@ impl DatafusionOptimizer { } pub fn heuristic_optimize(&mut self, root_rel: OptRelNodeRef) -> OptRelNodeRef { + println!( + "{}", + PlanNode::from_group(root_rel.clone()).explain_to_string(None) + ); let res = self .hueristic_optimizer .optimize(root_rel) diff --git a/optd-datafusion-repr/src/plan_nodes.rs b/optd-datafusion-repr/src/plan_nodes.rs index bd40e4f3..a243e3bc 100644 --- a/optd-datafusion-repr/src/plan_nodes.rs +++ b/optd-datafusion-repr/src/plan_nodes.rs @@ -322,8 +322,17 @@ impl Expr { .into_iter() .map(|child| { if child.typ == OptRelNodeTyp::List { - // TODO: What should we do with List? - return Some(child); + return Some( + ExprList::new( + ExprList::from_rel_node(child.clone()) + .unwrap() + .to_vec() + .into_iter() + .map(|x| x.rewrite_column_refs(rewrite_fn).unwrap()) + .collect(), + ) + .into_rel_node(), + ); } Expr::from_rel_node(child.clone()) .unwrap() diff --git a/optd-datafusion-repr/src/properties/column_ref.rs b/optd-datafusion-repr/src/properties/column_ref.rs index 1f552147..7a69d7d5 100644 --- a/optd-datafusion-repr/src/properties/column_ref.rs +++ b/optd-datafusion-repr/src/properties/column_ref.rs @@ -81,14 +81,17 @@ impl PropertyBuilder for ColumnRefPropertyBuilder { // Concatentate the children properties. Self::concat_children_properties(children) } - OptRelNodeTyp::Projection => children[1] + OptRelNodeTyp::Projection => { + children[1] .iter() - .map(|p| match p { + .map(|p| { + match p { ColumnRef::ChildColumnRef { col_idx } => children[0][*col_idx].clone(), ColumnRef::Derived => ColumnRef::Derived, _ => panic!("projection expr must be Derived or ChildColumnRef"), - }) - .collect(), + }}) + .collect() + } // Should account for all physical join types. OptRelNodeTyp::Join(_) | OptRelNodeTyp::DepJoin(_) diff --git a/optd-datafusion-repr/src/rules/subquery/depjoin_pushdown.rs b/optd-datafusion-repr/src/rules/subquery/depjoin_pushdown.rs index 7f696f9e..e8758da0 100644 --- a/optd-datafusion-repr/src/rules/subquery/depjoin_pushdown.rs +++ b/optd-datafusion-repr/src/rules/subquery/depjoin_pushdown.rs @@ -1,5 +1,4 @@ // TODO: No push past join -// TODO: support multiple depjoin (move to rewriting pass, probably) // TODO: Sideways information passing?? use itertools::Itertools; use optd_core::rel_node::Value; @@ -167,7 +166,24 @@ fn apply_dep_initial_distinct( JoinType::Inner, ); - vec![new_join.into_rel_node().as_ref().clone()] + // Ensure that the schema above the new_join is the same as it was before + // for correctness (Project the left side of the new join, + // plus the *right side of the right side*) + let new_proj = LogicalProjection::new( + PlanNode::from_rel_node(new_join.into_rel_node()).unwrap(), + ExprList::new( + (0..left_schema_size) + .chain( + (left_schema_size + correlated_col_indices.len()) + ..(left_schema_size + correlated_col_indices.len() + right_schema_size), + ) + .into_iter() + .map(|x| ColumnRefExpr::new(x).into_expr()) + .collect(), + ), + ); + + vec![new_proj.into_rel_node().as_ref().clone()] } define_rule!( @@ -365,7 +381,27 @@ fn apply_dep_join_past_agg( groups .to_vec() .into_iter() - .chain(correlated_col_indices.clone()) + .map(|x| { + x.rewrite_column_refs(&mut |col| Some(col + correlated_col_indices.len())) + .unwrap() + }) + .chain(correlated_col_indices.iter().map(|x| { + x.rewrite_column_refs(&mut |col| Some(col + correlated_col_indices.len())) + .unwrap() + })) + .collect(), + ); + + let exprs = ExprList::from_rel_node(exprs.into()).unwrap(); + + let new_exprs = ExprList::new( + exprs + .to_vec() + .into_iter() + .map(|x| { + x.rewrite_column_refs(&mut |col| Some(col + correlated_col_indices.len())) + .unwrap() + }) .collect(), ); @@ -379,7 +415,7 @@ fn apply_dep_join_past_agg( let new_agg = LogicalAgg::new( PlanNode::from_rel_node(new_dep_join.into_rel_node()).unwrap(), - ExprList::from_rel_node(exprs.into()).unwrap(), + new_exprs, new_groups, );