Skip to content

Commit

Permalink
refactor(core): add memo generic to cascades optimizer (#216)
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Chi <[email protected]>
  • Loading branch information
skyzh authored Nov 3, 2024
1 parent bd0dc0e commit 7005948
Show file tree
Hide file tree
Showing 14 changed files with 101 additions and 88 deletions.
15 changes: 9 additions & 6 deletions optd-core/src/cascades/memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ pub struct Group {
}

/// Trait for memo table implementations.
pub trait Memo<T: RelNodeTyp> {
pub trait Memo<T: RelNodeTyp>: 'static + Send + Sync {
/// Add an expression to the memo table. If the expression already exists, it will return the existing group id and
/// expr id. Otherwise, a new group and expr will be created.
fn add_new_expr(&mut self, rel_node: RelNodeRef<T>) -> (GroupId, ExprId);
Expand All @@ -130,6 +130,10 @@ pub trait Memo<T: RelNodeTyp> {
/// Update the group info.
fn update_group_info(&mut self, group_id: GroupId, group_info: GroupInfo);

/// Estimated plan space for the memo table; only useful when the plan exploration budget is enabled.
/// Returns the number of expressions in the memo table.
fn estimated_plan_space(&self) -> usize;

// The functions below can be overridden by the memo table implementation if there
// is a more efficient way to retrieve the information.

Expand Down Expand Up @@ -330,6 +334,10 @@ impl<T: RelNodeTyp> Memo<T> for NaiveMemo<T> {
let grp = self.groups.get_mut(&group_id);
grp.unwrap().info = group_info;
}

fn estimated_plan_space(&self) -> usize {
self.expr_id_to_expr_node.len()
}
}

impl<T: RelNodeTyp> NaiveMemo<T> {
Expand Down Expand Up @@ -605,11 +613,6 @@ impl<T: RelNodeTyp> NaiveMemo<T> {
group.info.winner = Winner::Unknown;
}
}

/// Return number of expressions in the memo table.
pub fn compute_plan_space(&self) -> usize {
self.expr_id_to_expr_node.len()
}
}

#[cfg(test)]
Expand Down
62 changes: 32 additions & 30 deletions optd-core/src/cascades/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ pub struct OptimizerProperties {
pub partial_explore_space: Option<usize>,
}

pub struct CascadesOptimizer<T: RelNodeTyp> {
memo: NaiveMemo<T>,
pub(super) tasks: VecDeque<Box<dyn Task<T>>>,
pub struct CascadesOptimizer<T: RelNodeTyp, M: Memo<T> = NaiveMemo<T>> {
memo: M,
pub(super) tasks: VecDeque<Box<dyn Task<T, M>>>,
explored_group: HashSet<GroupId>,
explored_expr: HashSet<ExprId>,
fired_rules: HashMap<ExprId, HashSet<RuleId>>,
rules: Arc<[Arc<RuleWrapper<T, Self>>]>,
disabled_rules: HashSet<usize>,
cost: Arc<dyn CostModel<T>>,
cost: Arc<dyn CostModel<T, M>>,
property_builders: Arc<[Box<dyn PropertyBuilderAny<T>>]>,
pub ctx: OptimizerContext,
pub prop: OptimizerProperties,
Expand Down Expand Up @@ -80,22 +80,18 @@ impl Display for ExprId {
}
}

impl<T: RelNodeTyp> CascadesOptimizer<T> {
impl<T: RelNodeTyp> CascadesOptimizer<T, NaiveMemo<T>> {
pub fn new(
rules: Vec<Arc<RuleWrapper<T, Self>>>,
cost: Box<dyn CostModel<T>>,
cost: Box<dyn CostModel<T, NaiveMemo<T>>>,
property_builders: Vec<Box<dyn PropertyBuilderAny<T>>>,
) -> Self {
Self::new_with_prop(rules, cost, property_builders, Default::default())
}

pub fn panic_on_explore_limit(&mut self, enabled: bool) {
self.prop.panic_on_budget = enabled;
}

pub fn new_with_prop(
rules: Vec<Arc<RuleWrapper<T, Self>>>,
cost: Box<dyn CostModel<T>>,
cost: Box<dyn CostModel<T, NaiveMemo<T>>>,
property_builders: Vec<Box<dyn PropertyBuilderAny<T>>>,
prop: OptimizerProperties,
) -> Self {
Expand All @@ -117,7 +113,27 @@ impl<T: RelNodeTyp> CascadesOptimizer<T> {
}
}

pub fn cost(&self) -> Arc<dyn CostModel<T>> {
/// Clear the memo table and all optimizer states.
pub fn step_clear(&mut self) {
self.memo = NaiveMemo::new(self.property_builders.clone());
self.fired_rules.clear();
self.explored_group.clear();
self.explored_expr.clear();
}

/// Clear the winner so that the optimizer can continue to explore the group.
pub fn step_clear_winner(&mut self) {
self.memo.clear_winner();
self.explored_expr.clear();
}
}

impl<T: RelNodeTyp, M: Memo<T>> CascadesOptimizer<T, M> {
pub fn panic_on_explore_limit(&mut self, enabled: bool) {
self.prop.panic_on_budget = enabled;
}

pub fn cost(&self) -> Arc<dyn CostModel<T, M>> {
self.cost.clone()
}

Expand Down Expand Up @@ -173,20 +189,6 @@ impl<T: RelNodeTyp> CascadesOptimizer<T> {
}
}

/// Clear the memo table and all optimizer states.
pub fn step_clear(&mut self) {
self.memo = NaiveMemo::new(self.property_builders.clone());
self.fired_rules.clear();
self.explored_group.clear();
self.explored_expr.clear();
}

/// Clear the winner so that the optimizer can continue to explore the group.
pub fn step_clear_winner(&mut self) {
self.memo.clear_winner();
self.explored_expr.clear();
}

/// Optimize a `RelNode`.
pub fn step_optimize_rel(&mut self, root_rel: RelNodeRef<T>) -> Result<GroupId> {
let (group_id, _) = self.add_new_expr(root_rel);
Expand Down Expand Up @@ -228,14 +230,14 @@ impl<T: RelNodeTyp> CascadesOptimizer<T> {
.push_back(Box::new(OptimizeGroupTask::new(group_id)));
// get the task from the stack
self.ctx.budget_used = false;
let plan_space_begin = self.memo.compute_plan_space();
let plan_space_begin = self.memo.estimated_plan_space();
let mut iter = 0;
while let Some(task) = self.tasks.pop_back() {
let new_tasks = task.execute(self)?;
self.tasks.extend(new_tasks);
iter += 1;
if !self.ctx.budget_used {
let plan_space = self.memo.compute_plan_space();
let plan_space = self.memo.estimated_plan_space();
if let Some(partial_explore_space) = self.prop.partial_explore_space {
if plan_space - plan_space_begin > partial_explore_space {
println!(
Expand Down Expand Up @@ -362,12 +364,12 @@ impl<T: RelNodeTyp> CascadesOptimizer<T> {
.insert(rule_id);
}

pub fn memo(&self) -> &NaiveMemo<T> {
pub fn memo(&self) -> &M {
&self.memo
}
}

impl<T: RelNodeTyp> Optimizer<T> for CascadesOptimizer<T> {
impl<T: RelNodeTyp, M: Memo<T>> Optimizer<T> for CascadesOptimizer<T, M> {
fn optimize(&mut self, root_rel: RelNodeRef<T>) -> Result<RelNodeRef<T>> {
self.optimize_inner(root_rel)
}
Expand Down
6 changes: 3 additions & 3 deletions optd-core/src/cascades/tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use anyhow::Result;

use crate::rel_node::RelNodeTyp;

use super::CascadesOptimizer;
use super::{CascadesOptimizer, Memo};

mod apply_rule;
mod explore_group;
Expand All @@ -16,7 +16,7 @@ pub use optimize_expression::OptimizeExpressionTask;
pub use optimize_group::OptimizeGroupTask;
pub use optimize_inputs::OptimizeInputsTask;

pub trait Task<T: RelNodeTyp>: 'static + Send + Sync {
fn execute(&self, optimizer: &mut CascadesOptimizer<T>) -> Result<Vec<Box<dyn Task<T>>>>;
pub trait Task<T: RelNodeTyp, M: Memo<T>>: 'static + Send + Sync {
fn execute(&self, optimizer: &mut CascadesOptimizer<T, M>) -> Result<Vec<Box<dyn Task<T, M>>>>;
fn describe(&self) -> String;
}
26 changes: 13 additions & 13 deletions optd-core/src/cascades/tasks/apply_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
memo::RelMemoNodeRef,
optimizer::{CascadesOptimizer, ExprId, RuleId},
tasks::{OptimizeExpressionTask, OptimizeInputsTask},
GroupId,
GroupId, Memo,
},
rel_node::{RelNode, RelNodeTyp},
rules::{OptimizeType, RuleMatcher},
Expand All @@ -33,12 +33,12 @@ impl ApplyRuleTask {
}
}

fn match_node<T: RelNodeTyp>(
fn match_node<T: RelNodeTyp, M: Memo<T>>(
typ: &T,
children: &[RuleMatcher<T>],
pick_to: Option<usize>,
node: RelMemoNodeRef<T>,
optimizer: &CascadesOptimizer<T>,
optimizer: &CascadesOptimizer<T, M>,
) -> Vec<HashMap<usize, RelNode<T>>> {
if let RuleMatcher::PickMany { .. } | RuleMatcher::IgnoreMany = children.last().unwrap() {
} else {
Expand Down Expand Up @@ -123,19 +123,19 @@ fn match_node<T: RelNodeTyp>(
picks
}

fn match_and_pick_expr<T: RelNodeTyp>(
fn match_and_pick_expr<T: RelNodeTyp, M: Memo<T>>(
matcher: &RuleMatcher<T>,
expr_id: ExprId,
optimizer: &CascadesOptimizer<T>,
optimizer: &CascadesOptimizer<T, M>,
) -> Vec<HashMap<usize, RelNode<T>>> {
let node = optimizer.get_expr_memoed(expr_id);
match_and_pick(matcher, node, optimizer)
}

fn match_and_pick_group<T: RelNodeTyp>(
fn match_and_pick_group<T: RelNodeTyp, M: Memo<T>>(
matcher: &RuleMatcher<T>,
group_id: GroupId,
optimizer: &CascadesOptimizer<T>,
optimizer: &CascadesOptimizer<T, M>,
) -> Vec<HashMap<usize, RelNode<T>>> {
let mut matches = vec![];
for expr_id in optimizer.get_all_exprs_in_group(group_id) {
Expand All @@ -145,10 +145,10 @@ fn match_and_pick_group<T: RelNodeTyp>(
matches
}

fn match_and_pick<T: RelNodeTyp>(
fn match_and_pick<T: RelNodeTyp, M: Memo<T>>(
matcher: &RuleMatcher<T>,
node: RelMemoNodeRef<T>,
optimizer: &CascadesOptimizer<T>,
optimizer: &CascadesOptimizer<T, M>,
) -> Vec<HashMap<usize, RelNode<T>>> {
match matcher {
RuleMatcher::MatchAndPickNode {
Expand All @@ -171,8 +171,8 @@ fn match_and_pick<T: RelNodeTyp>(
}
}

impl<T: RelNodeTyp> Task<T> for ApplyRuleTask {
fn execute(&self, optimizer: &mut CascadesOptimizer<T>) -> Result<Vec<Box<dyn Task<T>>>> {
impl<T: RelNodeTyp, M: Memo<T>> Task<T, M> for ApplyRuleTask {
fn execute(&self, optimizer: &mut CascadesOptimizer<T, M>) -> Result<Vec<Box<dyn Task<T, M>>>> {
if optimizer.is_rule_fired(self.expr_id, self.rule_id) {
return Ok(vec![]);
}
Expand Down Expand Up @@ -204,12 +204,12 @@ impl<T: RelNodeTyp> Task<T> for ApplyRuleTask {
if expr_typ.is_logical() {
tasks.push(
Box::new(OptimizeExpressionTask::new(expr_id, self.exploring))
as Box<dyn Task<T>>,
as Box<dyn Task<T, M>>,
);
} else {
tasks
.push(Box::new(OptimizeInputsTask::new(expr_id, true))
as Box<dyn Task<T>>);
as Box<dyn Task<T, M>>);
}
optimizer.unmark_expr_explored(expr_id);
trace!(event = "apply_rule", expr_id = %self.expr_id, rule_id = %self.rule_id, new_expr_id = %expr_id);
Expand Down
8 changes: 5 additions & 3 deletions optd-core/src/cascades/tasks/explore_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
cascades::{
optimizer::{CascadesOptimizer, GroupId},
tasks::OptimizeExpressionTask,
Memo,
},
rel_node::RelNodeTyp,
};
Expand All @@ -21,8 +22,8 @@ impl ExploreGroupTask {
}
}

impl<T: RelNodeTyp> Task<T> for ExploreGroupTask {
fn execute(&self, optimizer: &mut CascadesOptimizer<T>) -> Result<Vec<Box<dyn Task<T>>>> {
impl<T: RelNodeTyp, M: Memo<T>> Task<T, M> for ExploreGroupTask {
fn execute(&self, optimizer: &mut CascadesOptimizer<T, M>) -> Result<Vec<Box<dyn Task<T, M>>>> {
trace!(event = "task_begin", task = "explore_group", group_id = %self.group_id);
let mut tasks = vec![];
if optimizer.is_group_explored(self.group_id) {
Expand All @@ -34,7 +35,8 @@ impl<T: RelNodeTyp> Task<T> for ExploreGroupTask {
for expr in exprs {
let typ = optimizer.get_expr_memoed(expr).typ.clone();
if typ.is_logical() {
tasks.push(Box::new(OptimizeExpressionTask::new(expr, true)) as Box<dyn Task<T>>);
tasks
.push(Box::new(OptimizeExpressionTask::new(expr, true)) as Box<dyn Task<T, M>>);
}
}
optimizer.mark_group_explored(self.group_id);
Expand Down
11 changes: 7 additions & 4 deletions optd-core/src/cascades/tasks/optimize_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
cascades::{
optimizer::{CascadesOptimizer, ExprId},
tasks::{ApplyRuleTask, ExploreGroupTask},
Memo,
},
rel_node::{RelNodeTyp, Value},
rules::RuleMatcher,
Expand Down Expand Up @@ -35,8 +36,8 @@ fn top_matches<T: RelNodeTyp>(
}
}

impl<T: RelNodeTyp> Task<T> for OptimizeExpressionTask {
fn execute(&self, optimizer: &mut CascadesOptimizer<T>) -> Result<Vec<Box<dyn Task<T>>>> {
impl<T: RelNodeTyp, M: Memo<T>> Task<T, M> for OptimizeExpressionTask {
fn execute(&self, optimizer: &mut CascadesOptimizer<T, M>) -> Result<Vec<Box<dyn Task<T, M>>>> {
let expr = optimizer.get_expr_memoed(self.expr_id);
trace!(event = "task_begin", task = "optimize_expr", expr_id = %self.expr_id, expr = %expr);
let mut tasks = vec![];
Expand All @@ -56,10 +57,12 @@ impl<T: RelNodeTyp> Task<T> for OptimizeExpressionTask {
if top_matches(rule.matcher(), expr.typ.clone(), expr.data.clone()) {
tasks.push(
Box::new(ApplyRuleTask::new(rule_id, self.expr_id, self.exploring))
as Box<dyn Task<T>>,
as Box<dyn Task<T, M>>,
);
for &input_group_id in &expr.children {
tasks.push(Box::new(ExploreGroupTask::new(input_group_id)) as Box<dyn Task<T>>);
tasks.push(
Box::new(ExploreGroupTask::new(input_group_id)) as Box<dyn Task<T, M>>
);
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions optd-core/src/cascades/tasks/optimize_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
cascades::{
optimizer::GroupId,
tasks::{optimize_expression::OptimizeExpressionTask, OptimizeInputsTask},
CascadesOptimizer,
CascadesOptimizer, Memo,
},
rel_node::RelNodeTyp,
};
Expand All @@ -22,8 +22,8 @@ impl OptimizeGroupTask {
}
}

impl<T: RelNodeTyp> Task<T> for OptimizeGroupTask {
fn execute(&self, optimizer: &mut CascadesOptimizer<T>) -> Result<Vec<Box<dyn Task<T>>>> {
impl<T: RelNodeTyp, M: Memo<T>> Task<T, M> for OptimizeGroupTask {
fn execute(&self, optimizer: &mut CascadesOptimizer<T, M>) -> Result<Vec<Box<dyn Task<T, M>>>> {
trace!(event = "task_begin", task = "optimize_group", group_id = %self.group_id);
let group_info = optimizer.get_group_info(self.group_id);
if group_info.winner.has_decided() {
Expand All @@ -36,13 +36,13 @@ impl<T: RelNodeTyp> Task<T> for OptimizeGroupTask {
for &expr in &exprs {
let typ = optimizer.get_expr_memoed(expr).typ.clone();
if typ.is_logical() {
tasks.push(Box::new(OptimizeExpressionTask::new(expr, false)) as Box<dyn Task<T>>);
tasks.push(Box::new(OptimizeExpressionTask::new(expr, false)) as Box<dyn Task<T, M>>);
}
}
for &expr in &exprs {
let typ = optimizer.get_expr_memoed(expr).typ.clone();
if !typ.is_logical() {
tasks.push(Box::new(OptimizeInputsTask::new(expr, true)) as Box<dyn Task<T>>);
tasks.push(Box::new(OptimizeInputsTask::new(expr, true)) as Box<dyn Task<T, M>>);
}
}
trace!(event = "task_finish", task = "optimize_group", group_id = %self.group_id, exprs_cnt = exprs_cnt);
Expand Down
Loading

0 comments on commit 7005948

Please sign in to comment.