Skip to content

Commit

Permalink
feat(df-repr): add back join order enumeration (#204)
Browse files Browse the repository at this point in the history
ref #194

after the memo table refactor, adding back a more efficient join order
enumeration implementation.

---------

Signed-off-by: Alex Chi <[email protected]>
  • Loading branch information
skyzh authored Oct 30, 2024
1 parent ae425ed commit 4096073
Show file tree
Hide file tree
Showing 14 changed files with 241 additions and 161 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions optd-core/src/cascades.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ mod memo;
mod optimizer;
mod tasks;

use memo::Memo;
pub use optimizer::{CascadesOptimizer, GroupId, OptimizerProperties, RelNodeContext};
pub use memo::Memo;
pub use optimizer::{CascadesOptimizer, ExprId, GroupId, OptimizerProperties, RelNodeContext};
use tasks::Task;
19 changes: 17 additions & 2 deletions optd-core/src/cascades/memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,20 @@ pub struct RelMemoNode<T: RelNodeTyp> {
pub data: Option<Value>,
}

impl<T: RelNodeTyp> RelMemoNode<T> {
pub fn into_rel_node(self) -> RelNode<T> {
RelNode {
typ: self.typ,
children: self
.children
.into_iter()
.map(|x| Arc::new(RelNode::new_group(x)))
.collect(),
data: self.data,
}
}
}

impl<T: RelNodeTyp> std::fmt::Display for RelMemoNode<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "({}", self.typ)?;
Expand Down Expand Up @@ -401,7 +415,7 @@ impl<T: RelNodeTyp> Memo<T> {
}

/// Get the memoized representation of a node, only for debugging purpose
pub(crate) fn get_expr_memoed(&self, mut expr_id: ExprId) -> RelMemoNodeRef<T> {
pub fn get_expr_memoed(&self, mut expr_id: ExprId) -> RelMemoNodeRef<T> {
while let Some(new_expr_id) = self.dup_expr_mapping.get(&expr_id) {
expr_id = *new_expr_id;
}
Expand All @@ -411,7 +425,8 @@ impl<T: RelNodeTyp> Memo<T> {
.clone()
}

pub(crate) fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec<ExprId> {
pub fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec<ExprId> {
let group_id = self.reduce_group(group_id);
let group = self.groups.get(&group_id).expect("group not found");
let mut exprs = group.group_exprs.iter().copied().collect_vec();
exprs.sort();
Expand Down
4 changes: 4 additions & 0 deletions optd-core/src/cascades/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,10 @@ impl<T: RelNodeTyp> CascadesOptimizer<T> {
.map(|x| x.cost.0[0])
.unwrap_or(0.0)
}

pub fn memo(&self) -> &Memo<T> {
&self.memo
}
}

impl<T: RelNodeTyp> Optimizer<T> for CascadesOptimizer<T> {
Expand Down
1 change: 1 addition & 0 deletions optd-datafusion-bridge/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ anyhow = "1"
async-recursion = "1"
futures-lite = "2"
futures-util = "0.3"
itertools = "0.11"
152 changes: 14 additions & 138 deletions optd-datafusion-bridge/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@ use datafusion::{
physical_plan::{displayable, explain::ExplainExec, ExecutionPlan},
physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner},
};
use itertools::Itertools;
use optd_datafusion_repr::{
plan_nodes::{
ConstantType, OptRelNode, OptRelNodeRef, OptRelNodeTyp, PhysicalHashJoin,
PhysicalNestedLoopJoin, PlanNode,
},
plan_nodes::{ConstantType, OptRelNode, PlanNode},
properties::schema::Catalog,
DatafusionOptimizer,
DatafusionOptimizer, MemoExt,
};
use std::{
collections::HashMap,
Expand Down Expand Up @@ -89,93 +87,6 @@ pub struct OptdQueryPlanner {
pub optimizer: Arc<Mutex<Option<Box<DatafusionOptimizer>>>>,
}

#[derive(Debug, Eq, PartialEq, Hash, PartialOrd, Ord)]
enum JoinOrder {
Table(String),
HashJoin(Box<Self>, Box<Self>),
NestedLoopJoin(Box<Self>, Box<Self>),
}

#[allow(dead_code)]
impl JoinOrder {
pub fn conv_into_logical_join_order(&self) -> LogicalJoinOrder {
match self {
JoinOrder::Table(name) => LogicalJoinOrder::Table(name.clone()),
JoinOrder::HashJoin(left, right) => LogicalJoinOrder::Join(
Box::new(left.conv_into_logical_join_order()),
Box::new(right.conv_into_logical_join_order()),
),
JoinOrder::NestedLoopJoin(left, right) => LogicalJoinOrder::Join(
Box::new(left.conv_into_logical_join_order()),
Box::new(right.conv_into_logical_join_order()),
),
}
}
}

#[allow(unused)]
#[derive(Debug, Eq, PartialEq, Hash, PartialOrd, Ord)]
enum LogicalJoinOrder {
Table(String),
Join(Box<Self>, Box<Self>),
}

#[allow(dead_code)]
fn get_join_order(rel_node: OptRelNodeRef) -> Option<JoinOrder> {
match rel_node.typ {
OptRelNodeTyp::PhysicalHashJoin(_) => {
let join = PhysicalHashJoin::from_rel_node(rel_node.clone()).unwrap();
let left = get_join_order(join.left().into_rel_node())?;
let right = get_join_order(join.right().into_rel_node())?;
Some(JoinOrder::HashJoin(Box::new(left), Box::new(right)))
}
OptRelNodeTyp::PhysicalNestedLoopJoin(_) => {
let join = PhysicalNestedLoopJoin::from_rel_node(rel_node.clone()).unwrap();
let left = get_join_order(join.left().into_rel_node())?;
let right = get_join_order(join.right().into_rel_node())?;
Some(JoinOrder::NestedLoopJoin(Box::new(left), Box::new(right)))
}
OptRelNodeTyp::PhysicalScan => {
let scan =
optd_datafusion_repr::plan_nodes::PhysicalScan::from_rel_node(rel_node).unwrap();
Some(JoinOrder::Table(scan.table().to_string()))
}
_ => {
for child in &rel_node.children {
if let Some(res) = get_join_order(child.clone()) {
return Some(res);
}
}
None
}
}
}

impl std::fmt::Display for LogicalJoinOrder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LogicalJoinOrder::Table(name) => write!(f, "{}", name),
LogicalJoinOrder::Join(left, right) => {
write!(f, "(Join {} {})", left, right)
}
}
}
}

impl std::fmt::Display for JoinOrder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
JoinOrder::Table(name) => write!(f, "{}", name),
JoinOrder::HashJoin(left, right) => {
write!(f, "(HashJoin {} {})", left, right)
}
JoinOrder::NestedLoopJoin(left, right) => {
write!(f, "(NLJ {} {})", left, right)
}
}
}
}

impl OptdQueryPlanner {
pub fn enable_adaptive(&self) {
self.optimizer
Expand Down Expand Up @@ -247,7 +158,7 @@ impl OptdQueryPlanner {
}
}

let (_, optimized_rel, meta) = optimizer.cascades_optimize(optd_rel)?;
let (group_id, optimized_rel, meta) = optimizer.cascades_optimize(optd_rel)?;

if let Some(explains) = &mut explains {
explains.push(StringifiedPlan::new(
Expand All @@ -258,52 +169,17 @@ impl OptdQueryPlanner {
.unwrap()
.explain_to_string(if verbose { Some(&meta) } else { None }),
));

// const ENABLE_JOIN_ORDER: bool = false;

// if ENABLE_JOIN_ORDER {
// let join_order = get_join_order(optimized_rel.clone());
// explains.push(StringifiedPlan::new(
// PlanType::OptimizedPhysicalPlan {
// optimizer_name: "optd-join-order".to_string(),
// },
// if let Some(join_order) = join_order {
// join_order.to_string()
// } else {
// "None".to_string()
// },
// ));
// let bindings = optimizer
// .optd_cascades_optimizer()
// .get_all_group_bindings(group_id, true);
// let mut join_orders = BTreeSet::new();
// let mut logical_join_orders = BTreeSet::new();
// for binding in bindings {
// if let Some(join_order) = get_join_order(binding) {
// logical_join_orders.insert(join_order.conv_into_logical_join_order());
// join_orders.insert(join_order);
// }
// }
// explains.push(StringifiedPlan::new(
// PlanType::OptimizedPhysicalPlan {
// optimizer_name: "optd-all-join-orders".to_string(),
// },
// join_orders.iter().map(|x| x.to_string()).join("\n"),
// ));
// explains.push(StringifiedPlan::new(
// PlanType::OptimizedPhysicalPlan {
// optimizer_name: "optd-all-logical-join-orders".to_string(),
// },
// logical_join_orders.iter().map(|x| x.to_string()).join("\n"),
// ));
// }
let join_orders = optimizer
.optd_cascades_optimizer()
.memo()
.enumerate_join_order(group_id);
explains.push(StringifiedPlan::new(
PlanType::OptimizedPhysicalPlan {
optimizer_name: "optd-all-logical-join-orders".to_string(),
},
join_orders.iter().map(|x| x.to_string()).join("\n"),
));
}
// println!(
// "{} cost={}",
// get_join_order(optimized_rel.clone()).unwrap(),
// optimizer.optd_optimizer().get_cost_of(group_id)
// );
// optimizer.dump(Some(group_id));
ctx.optimizer = Some(&optimizer);
let physical_plan = ctx.conv_from_optd(optimized_rel, meta).await?;
if let Some(explains) = &mut explains {
Expand Down
4 changes: 3 additions & 1 deletion optd-datafusion-repr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,16 @@ use crate::rules::{
DepInitialDistinct, DepJoinEliminateAtScan, DepJoinPastAgg, DepJoinPastFilter, DepJoinPastProj,
};

pub use memo_ext::{LogicalJoinOrder, MemoExt};

pub mod cost;
mod explain;
mod memo_ext;
pub mod plan_nodes;
pub mod properties;
pub mod rules;
#[cfg(test)]
mod testing;
// mod expand;

pub struct DatafusionOptimizer {
heuristic_optimizer: HeuristicsOptimizer<OptRelNodeTyp>,
Expand Down
Loading

0 comments on commit 4096073

Please sign in to comment.