Skip to content

Commit

Permalink
feat: keep both input and output correlation
Browse files Browse the repository at this point in the history
  • Loading branch information
Gun9niR committed Apr 25, 2024
1 parent 1a29aaa commit ff9ce90
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 89 deletions.
2 changes: 1 addition & 1 deletion optd-datafusion-repr/src/cost/base_cost/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl<
let group_col_refs = optimizer
.get_property_by_group::<ColumnRefPropertyBuilder>(context.group_id, 1);
group_col_refs
.column_refs
.column_refs()
.iter()
.take(group_by.len())
.map(|col_ref| match col_ref {
Expand Down
44 changes: 22 additions & 22 deletions optd-datafusion-repr/src/cost/base_cost/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,14 +459,14 @@ mod tests {
assert_approx_eq::assert_approx_eq!(
cost_model.get_filter_selectivity(
cnst(Value::Bool(true)),
&GroupColumnRefs::new(vec![], None)
&GroupColumnRefs::new_test(vec![], None)
),
1.0
);
assert_approx_eq::assert_approx_eq!(
cost_model.get_filter_selectivity(
cnst(Value::Bool(false)),
&GroupColumnRefs::new(vec![], None)
&GroupColumnRefs::new_test(vec![], None)
),
0.0
);
Expand All @@ -482,7 +482,7 @@ mod tests {
));
let expr_tree = bin_op(BinOpType::Eq, col_ref(0), cnst(Value::Int32(1)));
let expr_tree_rev = bin_op(BinOpType::Eq, cnst(Value::Int32(1)), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand All @@ -509,7 +509,7 @@ mod tests {
));
let expr_tree = bin_op(BinOpType::Eq, col_ref(0), cnst(Value::Int32(2)));
let expr_tree_rev = bin_op(BinOpType::Eq, cnst(Value::Int32(2)), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand All @@ -536,7 +536,7 @@ mod tests {
));
let expr_tree = bin_op(BinOpType::Eq, col_ref(0), cnst(Value::Int32(2)));
let expr_tree_rev = bin_op(BinOpType::Eq, cnst(Value::Int32(2)), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand Down Expand Up @@ -564,7 +564,7 @@ mod tests {
));
let expr_tree = bin_op(BinOpType::Neq, col_ref(0), cnst(Value::Int32(1)));
let expr_tree_rev = bin_op(BinOpType::Neq, cnst(Value::Int32(1)), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand All @@ -591,7 +591,7 @@ mod tests {
));
let expr_tree = bin_op(BinOpType::Leq, col_ref(0), cnst(Value::Int32(15)));
let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand All @@ -618,7 +618,7 @@ mod tests {
));
let expr_tree = bin_op(BinOpType::Leq, col_ref(0), cnst(Value::Int32(15)));
let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand Down Expand Up @@ -654,7 +654,7 @@ mod tests {
));
let expr_tree = bin_op(BinOpType::Leq, col_ref(0), cnst(Value::Int32(15)));
let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand Down Expand Up @@ -686,7 +686,7 @@ mod tests {
));
let expr_tree = bin_op(BinOpType::Leq, col_ref(0), cnst(Value::Int32(15)));
let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand All @@ -713,7 +713,7 @@ mod tests {
));
let expr_tree = bin_op(BinOpType::Lt, col_ref(0), cnst(Value::Int32(15)));
let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand All @@ -740,7 +740,7 @@ mod tests {
));
let expr_tree = bin_op(BinOpType::Lt, col_ref(0), cnst(Value::Int32(15)));
let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand Down Expand Up @@ -776,7 +776,7 @@ mod tests {
));
let expr_tree = bin_op(BinOpType::Lt, col_ref(0), cnst(Value::Int32(15)));
let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand Down Expand Up @@ -812,7 +812,7 @@ mod tests {
));
let expr_tree = bin_op(BinOpType::Lt, col_ref(0), cnst(Value::Int32(15)));
let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand Down Expand Up @@ -841,7 +841,7 @@ mod tests {
));
let expr_tree = bin_op(BinOpType::Gt, col_ref(0), cnst(Value::Int32(15)));
let expr_tree_rev = bin_op(BinOpType::Leq, cnst(Value::Int32(15)), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand All @@ -868,7 +868,7 @@ mod tests {
));
let expr_tree = bin_op(BinOpType::Gt, col_ref(0), cnst(Value::Int32(15)));
let expr_tree_rev = bin_op(BinOpType::Leq, cnst(Value::Int32(15)), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand Down Expand Up @@ -896,7 +896,7 @@ mod tests {
));
let expr_tree = bin_op(BinOpType::Geq, col_ref(0), cnst(Value::Int32(15)));
let expr_tree_rev = bin_op(BinOpType::Lt, cnst(Value::Int32(15)), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand All @@ -923,7 +923,7 @@ mod tests {
));
let expr_tree = bin_op(BinOpType::Geq, col_ref(0), cnst(Value::Int32(15)));
let expr_tree_rev = bin_op(BinOpType::Lt, cnst(Value::Int32(15)), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand Down Expand Up @@ -963,7 +963,7 @@ mod tests {
let expr_tree = log_op(LogOpType::And, vec![eq1.clone(), eq5.clone(), eq8.clone()]);
let expr_tree_shift1 = log_op(LogOpType::And, vec![eq5.clone(), eq8.clone(), eq1.clone()]);
let expr_tree_shift2 = log_op(LogOpType::And, vec![eq8.clone(), eq1.clone(), eq5.clone()]);
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand Down Expand Up @@ -1006,7 +1006,7 @@ mod tests {
let expr_tree = log_op(LogOpType::Or, vec![eq1.clone(), eq5.clone(), eq8.clone()]);
let expr_tree_shift1 = log_op(LogOpType::Or, vec![eq5.clone(), eq8.clone(), eq1.clone()]);
let expr_tree_shift2 = log_op(LogOpType::Or, vec![eq8.clone(), eq1.clone(), eq5.clone()]);
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand Down Expand Up @@ -1039,7 +1039,7 @@ mod tests {
UnOpType::Not,
bin_op(BinOpType::Eq, col_ref(0), cnst(Value::Int32(1))),
);
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand All @@ -1064,7 +1064,7 @@ mod tests {
UnOpType::Not,
bin_op(BinOpType::Eq, col_ref(0), cnst(Value::Int32(1))),
);
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand Down
2 changes: 1 addition & 1 deletion optd-datafusion-repr/src/cost/base_cost/filter/in_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ mod tests {
0.0,
Some(TestDistribution::empty()),
));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand Down
4 changes: 2 additions & 2 deletions optd-datafusion-repr/src/cost/base_cost/filter/like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ mod tests {
0.0,
Some(TestDistribution::empty()),
));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand Down Expand Up @@ -167,7 +167,7 @@ mod tests {
null_frac,
Some(TestDistribution::empty()),
));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![ColumnRef::base_table_column_ref(
String::from(TABLE1_NAME),
0,
Expand Down
24 changes: 12 additions & 12 deletions optd-datafusion-repr/src/cost/base_cost/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ impl<
let right_keys_group_id = context.children_group_ids[3];
let left_col_cnt = optimizer
.get_property_by_group::<ColumnRefPropertyBuilder>(context.children_group_ids[0], 1)
.column_refs
.column_refs()
.len();
let left_keys_list = optimizer.get_all_group_bindings(left_keys_group_id, false);
let right_keys_list = optimizer.get_all_group_bindings(right_keys_group_id, false);
Expand Down Expand Up @@ -380,7 +380,7 @@ mod tests {
cost_model.get_join_selectivity_from_expr_tree(
JoinType::Inner,
cnst(Value::Bool(true)),
&GroupColumnRefs::new(vec![], None),
&GroupColumnRefs::new_test(vec![], None),
f64::NAN,
f64::NAN
),
Expand All @@ -390,7 +390,7 @@ mod tests {
cost_model.get_join_selectivity_from_expr_tree(
JoinType::Inner,
cnst(Value::Bool(false)),
&GroupColumnRefs::new(vec![], None),
&GroupColumnRefs::new_test(vec![], None),
f64::NAN,
f64::NAN
),
Expand All @@ -416,7 +416,7 @@ mod tests {
);
let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(1));
let expr_tree_rev = bin_op(BinOpType::Eq, col_ref(1), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![
ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0),
ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0),
Expand Down Expand Up @@ -459,7 +459,7 @@ mod tests {
let eq1and0 = bin_op(BinOpType::Eq, col_ref(1), col_ref(0));
let expr_tree = log_op(LogOpType::And, vec![eq0and1.clone(), eq1and0.clone()]);
let expr_tree_rev = log_op(LogOpType::And, vec![eq1and0.clone(), eq0and1.clone()]);
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![
ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0),
ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0),
Expand Down Expand Up @@ -502,7 +502,7 @@ mod tests {
let eq100 = bin_op(BinOpType::Eq, col_ref(1), cnst(Value::Int32(100)));
let expr_tree = log_op(LogOpType::And, vec![eq0and1.clone(), eq100.clone()]);
let expr_tree_rev = log_op(LogOpType::And, vec![eq100.clone(), eq0and1.clone()]);
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![
ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0),
ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0),
Expand Down Expand Up @@ -545,7 +545,7 @@ mod tests {
let eq100 = bin_op(BinOpType::Eq, col_ref(1), cnst(Value::Int32(100)));
let expr_tree = log_op(LogOpType::And, vec![neq12.clone(), eq100.clone()]);
let expr_tree_rev = log_op(LogOpType::And, vec![eq100.clone(), neq12.clone()]);
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![
ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0),
ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0),
Expand Down Expand Up @@ -585,7 +585,7 @@ mod tests {
),
);
let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![
ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0),
ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0),
Expand Down Expand Up @@ -717,7 +717,7 @@ mod tests {
// the left/right of the join refers to the tables, not the order of columns in the predicate
let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(1));
let expr_tree_rev = bin_op(BinOpType::Eq, col_ref(1), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![
ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0),
ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0),
Expand Down Expand Up @@ -780,7 +780,7 @@ mod tests {
// the left/right of the join refers to the tables, not the order of columns in the predicate
let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(1));
let expr_tree_rev = bin_op(BinOpType::Eq, col_ref(1), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![
ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0),
ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0),
Expand Down Expand Up @@ -844,7 +844,7 @@ mod tests {
// the left/right of the join refers to the tables, not the order of columns in the predicate
let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(1));
let expr_tree_rev = bin_op(BinOpType::Eq, col_ref(1), col_ref(0));
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![
ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0),
ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0),
Expand Down Expand Up @@ -912,7 +912,7 @@ mod tests {
let expr_tree = log_op(LogOpType::And, vec![eq0and1, filter.clone()]);
// inner rev means its the inner expr (the eq op) whose children are being reversed, as opposed to the and op
let expr_tree_inner_rev = log_op(LogOpType::And, vec![eq1and0, filter.clone()]);
let column_refs = GroupColumnRefs::new(
let column_refs = GroupColumnRefs::new_test(
vec![
ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0),
ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0),
Expand Down
Loading

0 comments on commit ff9ce90

Please sign in to comment.