Skip to content

Commit

Permalink
feat(core): support cost-based pruning
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Chi <[email protected]>
  • Loading branch information
skyzh committed Nov 3, 2024
1 parent 29647f2 commit 5604cf8
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion optd-core/src/cascades/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ pub type RuleId = usize;

#[derive(Default, Clone, Debug)]
pub struct OptimizerContext {
pub upper_bound: Option<f64>,
pub budget_used: bool,
pub rules_applied: usize,
}
Expand Down
54 changes: 49 additions & 5 deletions optd-core/src/cascades/tasks/optimize_inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,23 @@ struct ContinueTask {
return_from_optimize_group: bool,
}

struct ContinueTaskDisplay<'a>(&'a Option<ContinueTask>);

impl std::fmt::Display for ContinueTaskDisplay<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.0 {
Some(x) => {
if x.return_from_optimize_group {
write!(f, "return,next_group_idx={}", x.next_group_idx)
} else {
write!(f, "enter,next_group_idx={}", x.next_group_idx)
}
}
None => write!(f, "none"),
}
}
}

pub struct OptimizeInputsTask {
expr_id: ExprId,
continue_from: Option<ContinueTask>,
Expand Down Expand Up @@ -124,7 +141,7 @@ impl<T: RelNodeTyp> Task<T> for OptimizeInputsTask {
let children_group_ids = &expr.children;
let cost = optimizer.cost();

trace!(event = "task_begin", task = "optimize_inputs", expr_id = %self.expr_id, continue_from = ?self.continue_from, total_children = %children_group_ids.len());
trace!(event = "task_begin", task = "optimize_inputs", expr_id = %self.expr_id, continue_from = %ContinueTaskDisplay(&self.continue_from), total_children = %children_group_ids.len());

if let Some(ContinueTask {
next_group_idx,
Expand Down Expand Up @@ -170,11 +187,38 @@ impl<T: RelNodeTyp> Task<T> for OptimizeInputsTask {
Some(optimizer),
);
let total_cost = cost.sum(&operation_cost, &input_cost);

if self.pruning {
let group_info = optimizer.get_group_info(group_id);
fn trace_fmt(winner: &Winner) -> String {
match winner {
Winner::Full(winner) => winner.total_weighted_cost.to_string(),
Winner::Impossible => "impossible".to_string(),
Winner::Unknown => "unknown".to_string(),
}
}
trace!(
event = "compute_cost",
task = "optimize_inputs",
expr_id = %self.expr_id,
weighted_cost_so_far = cost.weighted_cost(&total_cost),
winner_weighted_cost = %trace_fmt(&group_info.winner),
current_processing = %next_group_idx,
total_child_groups = %children_group_ids.len());
if let Some(winner) = group_info.winner.as_full_winner() {
let cost_so_far = cost.weighted_cost(&total_cost);
if winner.total_weighted_cost <= cost_so_far {
trace!(event = "task_finish", task = "optimize_inputs", expr_id = %self.expr_id, result = "pruned");
return Ok(vec![]);
}
}
}

if next_group_idx < children_group_ids.len() {
let child_group_id = children_group_ids[next_group_idx];
let group_idx = next_group_idx;
let group_info = optimizer.get_group_info(child_group_id);
if !group_info.winner.has_full_winner() {
let child_group_info = optimizer.get_group_info(child_group_id);
if !child_group_info.winner.has_full_winner() {
if !return_from_optimize_group {
trace!(event = "task_yield", task = "optimize_inputs", expr_id = %self.expr_id, group_idx = %group_idx, yield_to = "optimize_group", optimize_group_id = %child_group_id);
return Ok(vec![
Expand All @@ -189,7 +233,7 @@ impl<T: RelNodeTyp> Task<T> for OptimizeInputsTask {
]);
} else {
self.update_winner_impossible(optimizer);
trace!(event = "task_finish", task = "optimize_inputs", expr_id = %self.expr_id, "result" = "impossible");
trace!(event = "task_finish", task = "optimize_inputs", expr_id = %self.expr_id, result = "impossible");
return Ok(vec![]);
}
}
Expand All @@ -203,7 +247,7 @@ impl<T: RelNodeTyp> Task<T> for OptimizeInputsTask {
)) as Box<dyn Task<T>>])
} else {
self.update_winner(input_statistics_ref, operation_cost, total_cost, optimizer);
trace!(event = "task_finish", task = "optimize_inputs", expr_id = %self.expr_id, "result" = "optimized");
trace!(event = "task_finish", task = "optimize_inputs", expr_id = %self.expr_id, result = "optimized");
Ok(vec![])
}
} else {
Expand Down
13 changes: 11 additions & 2 deletions optd-sqlplannertest/src/bin/planner_test_apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@ use std::path::Path;
use anyhow::Result;

use clap::Parser;
use sqlplannertest::PlannerTestApplyOptions;

#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Cli {
/// Optional list of directories to apply the test; if empty, apply all tests
directories: Vec<String>,
/// Use the advanced cost model
#[clap(long)]
enable_advanced_cost_model: bool,
/// Execute tests in serial
#[clap(long)]
serial: bool,
}

#[tokio::main]
Expand All @@ -20,9 +25,11 @@ async fn main() -> Result<()> {
let cli = Cli::parse();

let enable_advanced_cost_model = cli.enable_advanced_cost_model;
let opts = PlannerTestApplyOptions { serial: cli.serial };

if cli.directories.is_empty() {
println!("Running all tests");
sqlplannertest::planner_test_apply(
sqlplannertest::planner_test_apply_with_options(
Path::new(env!("CARGO_MANIFEST_DIR")).join("tests"),
move || async move {
if enable_advanced_cost_model {
Expand All @@ -31,12 +38,13 @@ async fn main() -> Result<()> {
optd_sqlplannertest::DatafusionDBMS::new().await
}
},
opts,
)
.await?;
} else {
for directory in cli.directories {
println!("Running tests in {}", directory);
sqlplannertest::planner_test_apply(
sqlplannertest::planner_test_apply_with_options(
Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests")
.join(directory),
Expand All @@ -47,6 +55,7 @@ async fn main() -> Result<()> {
optd_sqlplannertest::DatafusionDBMS::new().await
}
},
opts.clone(),
)
.await?;
}
Expand Down

0 comments on commit 5604cf8

Please sign in to comment.