Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: multi column stats #160

Merged
merged 13 commits into from
Apr 16, 2024
33 changes: 22 additions & 11 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions optd-datafusion-repr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ async-trait = "0.1"
datafusion = "32.0.0"
assert_approx_eq = "1.1.0"
serde = { version = "1.0", features = ["derive"] }
serde_with = {version = "3.7.0", features = ["json"]}
bincode = "1.3.3"
9 changes: 7 additions & 2 deletions optd-datafusion-repr/src/bin/test_optimize.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};

use optd_core::{
cascades::CascadesOptimizer,
Expand Down Expand Up @@ -45,7 +45,12 @@ pub fn main() {
Box::new(OptCostModel::new(
[("t1", 1000), ("t2", 100), ("t3", 10000)]
.into_iter()
.map(|(x, y)| (x.to_string(), DataFusionPerTableStats::new(y, vec![])))
.map(|(x, y)| {
(
x.to_string(),
DataFusionPerTableStats::new(y, HashMap::new()),
)
})
.collect(),
)),
vec![],
Expand Down
18 changes: 15 additions & 3 deletions optd-datafusion-repr/src/cost/adaptive_cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use optd_core::{
cost::{Cost, CostModel},
rel_node::{RelNode, Value},
};
use serde::{de::DeserializeOwned, Serialize};

use super::base_cost::stats::{
BaseTableStats, DataFusionDistribution, DataFusionMostCommonValues, Distribution,
Expand All @@ -27,13 +28,20 @@ pub struct RuntimeAdaptionStorageInner {

pub const DEFAULT_DECAY: usize = 50;

pub struct AdaptiveCostModel<M: MostCommonValues, D: Distribution> {
pub struct AdaptiveCostModel<
M: MostCommonValues + Serialize + DeserializeOwned,
D: Distribution + Serialize + DeserializeOwned,
> {
runtime_row_cnt: RuntimeAdaptionStorage,
base_model: OptCostModel<M, D>,
decay: usize,
}

impl<M: MostCommonValues, D: Distribution> CostModel<OptRelNodeTyp> for AdaptiveCostModel<M, D> {
impl<
M: MostCommonValues + Serialize + DeserializeOwned,
D: Distribution + Serialize + DeserializeOwned,
> CostModel<OptRelNodeTyp> for AdaptiveCostModel<M, D>
{
fn explain(&self, cost: &Cost) -> String {
self.base_model.explain(cost)
}
Expand Down Expand Up @@ -87,7 +95,11 @@ impl<M: MostCommonValues, D: Distribution> CostModel<OptRelNodeTyp> for Adaptive
}
}

impl<M: MostCommonValues, D: Distribution> AdaptiveCostModel<M, D> {
impl<
M: MostCommonValues + Serialize + DeserializeOwned,
D: Distribution + Serialize + DeserializeOwned,
> AdaptiveCostModel<M, D>
{
pub fn new(decay: usize, stats: BaseTableStats<M, D>) -> Self {
Self {
runtime_row_cnt: RuntimeAdaptionStorage::default(),
Expand Down
72 changes: 53 additions & 19 deletions optd-datafusion-repr/src/cost/base_cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ use optd_core::{
cost::{Cost, CostModel},
rel_node::{RelNode, RelNodeTyp, Value},
};
use serde::{de::DeserializeOwned, Serialize};

use super::base_cost::stats::{BaseTableStats, Distribution, MostCommonValues, PerColumnStats};
use super::base_cost::stats::{
BaseTableStats, ColumnCombValueStats, Distribution, MostCommonValues,
};

fn compute_plan_node_cost<T: RelNodeTyp, C: CostModel<T>>(
model: &C,
Expand All @@ -29,7 +32,10 @@ fn compute_plan_node_cost<T: RelNodeTyp, C: CostModel<T>>(
cost
}

pub struct OptCostModel<M: MostCommonValues, D: Distribution> {
pub struct OptCostModel<
M: MostCommonValues + Serialize + DeserializeOwned,
D: Distribution + Serialize + DeserializeOwned,
> {
per_table_stats_map: BaseTableStats<M, D>,
}

Expand All @@ -52,7 +58,11 @@ pub const ROW_COUNT: usize = 1;
pub const COMPUTE_COST: usize = 2;
pub const IO_COST: usize = 3;

impl<M: MostCommonValues, D: Distribution> OptCostModel<M, D> {
impl<
M: MostCommonValues + Serialize + DeserializeOwned,
D: Distribution + Serialize + DeserializeOwned,
> OptCostModel<M, D>
{
pub fn row_cnt(Cost(cost): &Cost) -> f64 {
cost[ROW_COUNT]
}
Expand Down Expand Up @@ -84,7 +94,11 @@ impl<M: MostCommonValues, D: Distribution> OptCostModel<M, D> {
}
}

impl<M: MostCommonValues, D: Distribution> CostModel<OptRelNodeTyp> for OptCostModel<M, D> {
impl<
M: MostCommonValues + Serialize + DeserializeOwned,
D: Distribution + Serialize + DeserializeOwned,
> CostModel<OptRelNodeTyp> for OptCostModel<M, D>
{
fn explain(&self, cost: &Cost) -> String {
format!(
"weighted={},row_cnt={},compute={},io={}",
Expand Down Expand Up @@ -180,28 +194,36 @@ impl<M: MostCommonValues, D: Distribution> CostModel<OptRelNodeTyp> for OptCostM
}
}

impl<M: MostCommonValues, D: Distribution> OptCostModel<M, D> {
impl<
M: MostCommonValues + Serialize + DeserializeOwned,
D: Distribution + Serialize + DeserializeOwned,
> OptCostModel<M, D>
{
pub fn new(per_table_stats_map: BaseTableStats<M, D>) -> Self {
Self {
per_table_stats_map,
}
}

fn get_per_column_stats_from_col_ref(
fn get_single_column_stats_from_col_ref(
&self,
col_ref: &ColumnRef,
) -> Option<&PerColumnStats<M, D>> {
) -> Option<&ColumnCombValueStats<M, D>> {
if let ColumnRef::BaseTableColumnRef { table, col_idx } = col_ref {
self.get_per_column_stats(table, *col_idx)
self.get_column_comb_stats(table, &[*col_idx])
} else {
None
}
}

fn get_per_column_stats(&self, table: &str, col_idx: usize) -> Option<&PerColumnStats<M, D>> {
fn get_column_comb_stats(
&self,
table: &str,
col_comb: &[usize],
) -> Option<&ColumnCombValueStats<M, D>> {
self.per_table_stats_map
.get(table)
.and_then(|per_table_stats| per_table_stats.per_column_stats_vec[col_idx].as_ref())
.and_then(|per_table_stats| per_table_stats.column_comb_stats.get(col_comb))
}
}

Expand All @@ -212,6 +234,7 @@ impl<M: MostCommonValues, D: Distribution> OptCostModel<M, D> {
mod tests {
use itertools::Itertools;
use optd_core::rel_node::Value;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::{
Expand All @@ -223,21 +246,26 @@ mod tests {
};

use super::*;
pub type TestPerColumnStats = PerColumnStats<TestMostCommonValues, TestDistribution>;
pub type TestPerColumnStats = ColumnCombValueStats<TestMostCommonValues, TestDistribution>;
pub type TestOptCostModel = OptCostModel<TestMostCommonValues, TestDistribution>;

#[derive(Serialize, Deserialize)]
pub struct TestMostCommonValues {
pub mcvs: HashMap<Value, f64>,
pub mcvs: HashMap<Vec<Option<Value>>, f64>,
}

#[derive(Serialize, Deserialize)]
pub struct TestDistribution {
cdfs: HashMap<Value, f64>,
}

impl TestMostCommonValues {
pub fn new(mcvs_vec: Vec<(Value, f64)>) -> Self {
Self {
mcvs: mcvs_vec.into_iter().collect(),
mcvs: mcvs_vec
.into_iter()
.map(|(v, freq)| (vec![Some(v)], freq))
.collect(),
}
}

Expand All @@ -247,15 +275,15 @@ mod tests {
}

impl MostCommonValues for TestMostCommonValues {
fn freq(&self, value: &Value) -> Option<f64> {
fn freq(&self, value: &ColumnCombValue) -> Option<f64> {
self.mcvs.get(value).copied()
}

fn total_freq(&self) -> f64 {
self.mcvs.values().sum()
}

fn freq_over_pred(&self, pred: Box<dyn Fn(&Value) -> bool>) -> f64 {
fn freq_over_pred(&self, pred: Box<dyn Fn(&ColumnCombValue) -> bool>) -> f64 {
self.mcvs
.iter()
.filter(|(val, _)| pred(val))
Expand Down Expand Up @@ -294,7 +322,7 @@ mod tests {
OptCostModel::new(
vec![(
String::from(TABLE1_NAME),
PerTableStats::new(100, vec![Some(per_column_stats)]),
TableStats::new(100, vec![(vec![0], per_column_stats)].into_iter().collect()),
)]
.into_iter()
.collect(),
Expand Down Expand Up @@ -325,11 +353,17 @@ mod tests {
vec![
(
String::from(TABLE1_NAME),
PerTableStats::new(tbl1_row_cnt, vec![Some(tbl1_per_column_stats)]),
TableStats::new(
tbl1_row_cnt,
vec![(vec![0], tbl1_per_column_stats)].into_iter().collect(),
),
),
(
String::from(TABLE2_NAME),
PerTableStats::new(tbl2_row_cnt, vec![Some(tbl2_per_column_stats)]),
TableStats::new(
tbl2_row_cnt,
vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(),
),
),
]
.into_iter()
Expand Down Expand Up @@ -399,7 +433,7 @@ mod tests {
TestMostCommonValues::empty(),
0,
0.0,
TestDistribution::empty(),
Some(TestDistribution::empty()),
)
}
}
Loading
Loading