Skip to content

Commit

Permalink
stash
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Chi <[email protected]>
  • Loading branch information
skyzh committed Nov 5, 2024
1 parent 524a17c commit 5c9739b
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 41 deletions.
4 changes: 2 additions & 2 deletions optd-datafusion-repr/src/rules/filter_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ fn merge_conds(first: Expr, second: Expr) -> Expr {
LogOpExpr::new_flattened_nested_logical(LogOpType::And, new_expr_list).into_expr();
let mut changed = false;
// TODO: such simplifications should be invoked from optd-core, instead of ad-hoc
Expr::ensures_interpret(simplify_log_expr(flattened.strip(), &mut changed)).unwrap()
Expr::ensures_interpret(simplify_log_expr(flattened.unwrap_rel_node(), &mut changed))
}

#[derive(Debug, Clone, Copy)]
Expand All @@ -65,7 +65,7 @@ fn determine_join_cond_dep(
let mut right_col = false;
for child in children {
if child.typ() == OptRelNodeTyp::ColumnRef {
let col_ref = ColumnRefExpr::ensures_interpret(child.clone().strip()).unwrap();
let col_ref = ColumnRefExpr::ensures_interpret(child.clone().strip());
let index = col_ref.index();
if index < left_schema_size {
left_col = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ mod tests {

assert_eq!(plan.typ, OptRelNodeTyp::Projection);
assert_eq!(plan.child(1), res_proj_exprs);
assert!(matches!(plan.child(0).typ, OptRelNodeTyp::Scan));
assert!(matches!(plan.child_rel(0).typ, OptRelNodeTyp::Scan));
}

#[test]
Expand Down
64 changes: 26 additions & 38 deletions optd-datafusion-repr/src/rules/subquery/depjoin_pushdown.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use optd_core::rel_node::MaybeRelNode;
use optd_core::rel_node::{MaybeRelNode, RelNodeRef};
// TODO: No push past join
// TODO: Sideways information passing??
use optd_core::rules::{Rule, RuleMatcher};
Expand Down Expand Up @@ -122,8 +122,8 @@ fn apply_dep_initial_distinct(
let new_dep_join = DependentJoin::new(
distinct_agg_node.into_plan_node(),
PlanNode::from_group(right.into()),
Expr::ensures_interpret(cond.into()),
ExprList::ensures_interpret(extern_cols.into()),
Expr::ensures_interpret(cond),
ExprList::ensures_interpret(extern_cols),
JoinType::Cross,
);

Expand Down Expand Up @@ -201,7 +201,7 @@ fn apply_dep_join_past_proj(
) -> Vec<MaybeRelNode<OptRelNodeTyp>> {
// TODO: can we have external columns in projection node? I don't think so?
// Cross join should always have true cond
assert!(cond == *ConstantExpr::bool(true).strip());
assert!(cond == ConstantExpr::bool(true).strip());
let left_schema_len = optimizer
.get_property::<SchemaPropertyBuilder>(left.clone().into(), 0)
.len();
Expand All @@ -223,12 +223,12 @@ fn apply_dep_join_past_proj(
let new_dep_join = DependentJoin::new(
PlanNode::from_group(left.into()),
PlanNode::from_group(right.into()),
Expr::ensures_interpret(cond.into()).unwrap(),
ExprList::ensures_interpret(extern_cols.into()).unwrap(),
Expr::ensures_interpret(cond),
ExprList::ensures_interpret(extern_cols),
JoinType::Cross,
);
let new_proj = LogicalProjection::new(
PlanNode::ensures_interpret(new_dep_join.strip()).unwrap(),
PlanNode::ensures_interpret(new_dep_join.strip()),
new_proj_exprs,
);

Expand Down Expand Up @@ -261,24 +261,18 @@ fn apply_dep_join_past_filter(
}: DepJoinPastFilterPicks,
) -> Vec<MaybeRelNode<OptRelNodeTyp>> {
// Cross join should always have true cond
assert!(cond == *ConstantExpr::bool(true).strip());
assert!(cond == ConstantExpr::bool(true).strip());
let left_schema_len = optimizer
.get_property::<SchemaPropertyBuilder>(left.clone().into(), 0)
.len();

let correlated_col_indices = ExprList::ensures_interpret(extern_cols.clone().into())
.unwrap()
let correlated_col_indices = ExprList::ensures_interpret(extern_cols)
.to_vec()
.into_iter()
.map(|x| {
ExternColumnRefExpr::ensures_interpret(x.strip())
.unwrap()
.index()
})
.map(|x| ExternColumnRefExpr::ensures_interpret(x.strip()).index())
.collect::<Vec<usize>>();

let rewritten_expr = Expr::ensures_interpret(filter_cond.into())
.unwrap()
let rewritten_expr = Expr::ensures_interpret(filter_cond)
.rewrite_column_refs(&mut |col| Some(col + left_schema_len))
.unwrap();

Expand All @@ -294,7 +288,7 @@ fn apply_dep_join_past_filter(
let new_dep_join = DependentJoin::new(
PlanNode::from_group(left.into()),
PlanNode::from_group(right.into()),
Expr::ensures_interpret(cond.into()).unwrap(),
Expr::ensures_interpret(cond),
ExprList::new(
correlated_col_indices
.into_iter()
Expand All @@ -305,7 +299,7 @@ fn apply_dep_join_past_filter(
);

let new_filter = LogicalFilter::new(
PlanNode::ensures_interpret(new_dep_join.strip()).unwrap(),
PlanNode::ensures_interpret(new_dep_join.strip()),
rewritten_expr,
);

Expand Down Expand Up @@ -345,25 +339,21 @@ fn apply_dep_join_past_agg(
}: DepJoinPastAggPicks,
) -> Vec<MaybeRelNode<OptRelNodeTyp>> {
// Cross join should always have true cond
assert!(cond == *ConstantExpr::bool(true).strip());
assert!(cond == ConstantExpr::bool(true).strip());

// TODO: OUTER JOIN TRANSFORMATION

let extern_cols = ExprList::ensures_interpret(extern_cols.into()).unwrap();
let extern_cols = ExprList::ensures_interpret(extern_cols);
let correlated_col_indices = extern_cols
.to_vec()
.into_iter()
.map(|x| {
ColumnRefExpr::new(
ExternColumnRefExpr::ensures_interpret(x.strip())
.unwrap()
.index(),
)
.into_expr()
ColumnRefExpr::new(ExternColumnRefExpr::ensures_interpret(x.strip()).index())
.into_expr()
})
.collect::<Vec<Expr>>();

let groups = ExprList::ensures_interpret(groups.clone().into()).unwrap();
let groups = ExprList::ensures_interpret(groups);

let new_groups = ExprList::new(
groups
Expand All @@ -380,7 +370,7 @@ fn apply_dep_join_past_agg(
.collect(),
);

let exprs = ExprList::ensures_interpret(exprs.into()).unwrap();
let exprs = ExprList::ensures_interpret(exprs);

let new_exprs = ExprList::new(
exprs
Expand All @@ -396,13 +386,13 @@ fn apply_dep_join_past_agg(
let new_dep_join = DependentJoin::new(
PlanNode::from_group(left.into()),
PlanNode::from_group(right.into()),
Expr::ensures_interpret(cond.into()).unwrap(),
Expr::ensures_interpret(cond),
extern_cols,
JoinType::Cross,
);

let new_agg = LogicalAgg::new(
PlanNode::ensures_interpret(new_dep_join.strip()).unwrap(),
PlanNode::ensures_interpret(new_dep_join.strip()),
new_exprs,
new_groups,
);
Expand Down Expand Up @@ -430,24 +420,22 @@ fn apply_dep_join_eliminate_at_scan(
}: DepJoinEliminatePicks,
) -> Vec<MaybeRelNode<OptRelNodeTyp>> {
// Cross join should always have true cond
assert!(cond == *ConstantExpr::bool(true).strip());
assert!(cond == ConstantExpr::bool(true).strip());

fn inspect(node: &RelNode<OptRelNodeTyp>) -> bool {
if matches!(node.typ, OptRelNodeTyp::Placeholder(_)) {
unimplemented!("this is a heuristics rule");
}
fn inspect(node: RelNodeRef<OptRelNodeTyp>) -> bool {
if node.typ == OptRelNodeTyp::ExternColumnRef {
return false;
}
for child in &node.children {
if !inspect(child) {
if !inspect(child.unwrap_rel_node()) {
/* if it panics it's a heuristic rule */
return false;
}
}
true
}

if inspect(&right) {
if inspect(right.unwrap_rel_node()) {
let new_join = LogicalJoin::new(
PlanNode::from_group(left.into()),
PlanNode::from_group(right.into()),
Expand Down

0 comments on commit 5c9739b

Please sign in to comment.