From cdfb190763f13a62699d3ed05cc104f08768b007 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Fri, 30 Aug 2024 18:16:12 +0100 Subject: [PATCH] Add method for converting VortexExpr into equivalent pruning expression (#701) Pruning expression when evaluated on the statistics of the block of data will tell us if the block MIGHT contain matching values. --- vortex-expr/src/datafusion.rs | 2 +- vortex-expr/src/expr.rs | 127 +++++- vortex-expr/src/operators.rs | 2 +- vortex-serde/src/layouts/mod.rs | 1 + vortex-serde/src/layouts/pruning.rs | 443 +++++++++++++++++++++ vortex-serde/src/layouts/read/filtering.rs | 6 +- vortex-serde/src/layouts/write/writer.rs | 8 +- 7 files changed, 556 insertions(+), 33 deletions(-) create mode 100644 vortex-serde/src/layouts/pruning.rs diff --git a/vortex-expr/src/datafusion.rs b/vortex-expr/src/datafusion.rs index 764a9b6404..22bee491c9 100644 --- a/vortex-expr/src/datafusion.rs +++ b/vortex-expr/src/datafusion.rs @@ -28,7 +28,7 @@ pub fn convert_expr_to_vortex( .as_any() .downcast_ref::() { - let expr = Column::new(col_expr.name().to_owned()); + let expr = Column::from(col_expr.name().to_owned()); return Ok(Arc::new(expr) as _); } diff --git a/vortex-expr/src/expr.rs b/vortex-expr/src/expr.rs index b80b4dd6b9..7e4b47fa1c 100644 --- a/vortex-expr/src/expr.rs +++ b/vortex-expr/src/expr.rs @@ -1,3 +1,4 @@ +use std::any::Any; use std::collections::HashSet; use std::fmt::Debug; use std::sync::Arc; @@ -12,46 +13,85 @@ use vortex_scalar::Scalar; use crate::Operator; -pub trait VortexExpr: Debug + Send + Sync { +pub trait VortexExpr: Debug + Send + Sync + PartialEq { + fn as_any(&self) -> &dyn Any; + fn evaluate(&self, array: &Array) -> VortexResult; fn references(&self) -> HashSet; } -#[derive(Debug)] +// Taken from apache-datafusion, necessary since you can't require VortexExpr implement PartialEq +fn unbox_any(any: &dyn Any) -> &dyn Any { + if any.is::>() { + any.downcast_ref::>().unwrap().as_any() + } else if any.is::>() { + 
any.downcast_ref::>().unwrap().as_any() + } else { + any + } +} + +#[derive(Debug, PartialEq, Hash, Clone)] pub struct NoOp; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct BinaryExpr { - left: Arc, - right: Arc, + lhs: Arc, operator: Operator, + rhs: Arc, } impl BinaryExpr { - pub fn new(left: Arc, operator: Operator, right: Arc) -> Self { - Self { - left, - right, - operator, - } + pub fn new(lhs: Arc, operator: Operator, rhs: Arc) -> Self { + Self { lhs, rhs, operator } + } + + pub fn lhs(&self) -> &Arc { + &self.lhs + } + + pub fn rhs(&self) -> &Arc { + &self.rhs + } + + pub fn op(&self) -> Operator { + self.operator } } -#[derive(Debug)] +#[derive(Debug, PartialEq, Hash, Clone)] pub struct Column { field: Field, } impl Column { - pub fn new(field: String) -> Self { - Self { - field: Field::from(field), - } + pub fn new(field: Field) -> Self { + Self { field } + } + + pub fn field(&self) -> &Field { + &self.field + } +} + +impl From for Column { + fn from(value: String) -> Self { + Column::new(value.into()) + } +} + +impl From for Column { + fn from(value: usize) -> Self { + Column::new(value.into()) } } impl VortexExpr for Column { + fn as_any(&self) -> &dyn Any { + self + } + fn evaluate(&self, array: &Array) -> VortexResult { let s = StructArray::try_from(array)?; @@ -68,7 +108,16 @@ impl VortexExpr for Column { } } -#[derive(Debug)] +impl PartialEq for Column { + fn eq(&self, other: &dyn Any) -> bool { + unbox_any(other) + .downcast_ref::() + .map(|x| x == self) + .unwrap_or(false) + } +} + +#[derive(Debug, PartialEq)] pub struct Literal { value: Scalar, } @@ -80,6 +129,10 @@ impl Literal { } impl VortexExpr for Literal { + fn as_any(&self) -> &dyn Any { + self + } + fn evaluate(&self, array: &Array) -> VortexResult { Ok(ConstantArray::new(self.value.clone(), array.len()).into_array()) } @@ -89,10 +142,23 @@ impl VortexExpr for Literal { } } +impl PartialEq for Literal { + fn eq(&self, other: &dyn Any) -> bool { + unbox_any(other) + .downcast_ref::() 
+ .map(|x| x == self) + .unwrap_or(false) + } +} + impl VortexExpr for BinaryExpr { + fn as_any(&self) -> &dyn Any { + self + } + fn evaluate(&self, array: &Array) -> VortexResult { - let lhs = self.left.evaluate(array)?; - let rhs = self.right.evaluate(array)?; + let lhs = self.lhs.evaluate(array)?; + let rhs = self.rhs.evaluate(array)?; let array = match self.operator { Operator::Eq => compare(&lhs, &rhs, ArrayOperator::Eq)?, @@ -109,13 +175,26 @@ impl VortexExpr for BinaryExpr { } fn references(&self) -> HashSet { - let mut res = self.left.references(); - res.extend(self.right.references()); + let mut res = self.lhs.references(); + res.extend(self.rhs.references()); res } } +impl PartialEq for BinaryExpr { + fn eq(&self, other: &dyn Any) -> bool { + unbox_any(other) + .downcast_ref::() + .map(|x| x.operator == self.operator && x.lhs.eq(&self.lhs) && x.rhs.eq(&self.rhs)) + .unwrap_or(false) + } +} + impl VortexExpr for NoOp { + fn as_any(&self) -> &dyn Any { + self + } + fn evaluate(&self, _array: &Array) -> VortexResult { vortex_bail!("NoOp::evaluate() should not be called") } @@ -124,3 +203,9 @@ impl VortexExpr for NoOp { HashSet::new() } } + +impl PartialEq for NoOp { + fn eq(&self, other: &dyn Any) -> bool { + unbox_any(other).downcast_ref::().is_some() + } +} diff --git a/vortex-expr/src/operators.rs b/vortex-expr/src/operators.rs index 7368317ed5..a475a09fb2 100644 --- a/vortex-expr/src/operators.rs +++ b/vortex-expr/src/operators.rs @@ -1,7 +1,7 @@ use core::fmt; use std::fmt::{Display, Formatter}; -#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum Operator { // comparison diff --git a/vortex-serde/src/layouts/mod.rs b/vortex-serde/src/layouts/mod.rs index 03d6b856f7..b7b6459438 100644 --- a/vortex-serde/src/layouts/mod.rs +++ b/vortex-serde/src/layouts/mod.rs @@ -1,6 +1,7 @@ mod read; mod write; +mod 
pruning; #[cfg(test)] mod tests; diff --git a/vortex-serde/src/layouts/pruning.rs b/vortex-serde/src/layouts/pruning.rs new file mode 100644 index 0000000000..2730fa2ac1 --- /dev/null +++ b/vortex-serde/src/layouts/pruning.rs @@ -0,0 +1,443 @@ +use std::collections::hash_map::Entry; +use std::sync::Arc; + +use ahash::{HashMap, HashMapExt}; +use vortex::stats::Stat; +use vortex_dtype::field::Field; +use vortex_dtype::Nullability; +use vortex_expr::{BinaryExpr, Column, Literal, Operator, VortexExpr}; +use vortex_scalar::Scalar; + +#[allow(dead_code)] +pub struct PruningPredicate { + expr: Arc, + stats_to_fetch: HashMap>, +} + +impl PruningPredicate { + pub fn new(original_expr: &Arc) -> Self { + let (expr, stats_to_fetch) = convert_to_pruning_expression(original_expr); + Self { + expr, + stats_to_fetch, + } + } +} + +fn convert_to_pruning_expression( + expr: &Arc, +) -> (Arc, HashMap>) { + // Anything that can't be translated has to be represented as + // boolean true expression, i.e. the value might be in that chunk + let fallback = Arc::new(Literal::new(Scalar::bool(true, Nullability::NonNullable))); + // TODO(robert): Add support for boolean column expressions, + // i.e. 
if column is of bool dtype it's valid to filter on it directly as a predicate + if expr.as_any().downcast_ref::().is_some() { + return (fallback, HashMap::new()); + } + + if let Some(bexp) = expr.as_any().downcast_ref::() { + if bexp.op() == Operator::Or || bexp.op() == Operator::And { + let (rewritten_left, mut refs_lhs) = convert_to_pruning_expression(bexp.lhs()); + let (rewritten_right, refs_rhs) = convert_to_pruning_expression(bexp.rhs()); + refs_lhs.extend(refs_rhs); + return ( + Arc::new(BinaryExpr::new(rewritten_left, bexp.op(), rewritten_right)), + refs_lhs, + ); + } + + if let Some(col) = bexp.lhs().as_any().downcast_ref::() { + return PruningPredicateRewriter::try_new(col.field().clone(), bexp.op(), bexp.rhs()) + .and_then(PruningPredicateRewriter::rewrite) + .unwrap_or_else(|| (fallback, HashMap::new())); + }; + + if let Some(col) = bexp.rhs().as_any().downcast_ref::() { + return PruningPredicateRewriter::try_new( + col.field().clone(), + bexp.op().swap(), + bexp.lhs(), + ) + .and_then(PruningPredicateRewriter::rewrite) + .unwrap_or_else(|| (fallback, HashMap::new())); + }; + } + + (fallback, HashMap::new()) +} + +struct PruningPredicateRewriter<'a> { + column: Field, + operator: Operator, + other_exp: &'a Arc, + stats_to_fetch: HashMap>, +} + +type PruningPredicateStats = (Arc, HashMap>); + +impl<'a> PruningPredicateRewriter<'a> { + pub fn try_new( + column: Field, + operator: Operator, + other_exp: &'a Arc, + ) -> Option { + // TODO(robert): Simplify expression to guarantee that each column is not compared to itself + // For majority of cases self column references are likely not prunable + if other_exp.references().contains(&column) { + return None; + } + + Some(Self { + column, + operator, + other_exp, + stats_to_fetch: HashMap::new(), + }) + } + + fn add_stat_reference(&mut self, stat: Stat) -> Field { + let new_field = stat_column_name(&self.column, stat); + match self.stats_to_fetch.entry(self.column.clone()) { + Entry::Occupied(o) => 
o.into_mut().push(stat), + Entry::Vacant(v) => { + v.insert(vec![stat]); + } + } + new_field + } + + fn rewrite_other_exp(&mut self, stat: Stat) -> Arc { + replace_column_with_stat(self.other_exp, stat, &mut self.stats_to_fetch) + .unwrap_or_else(|| self.other_exp.clone()) + } + + fn rewrite(mut self) -> Option { + let expr: Option> = match self.operator { + Operator::Eq => { + let min_col = Arc::new(Column::new(self.add_stat_reference(Stat::Min))); + let max_col = Arc::new(Column::new(self.add_stat_reference(Stat::Max))); + let replaced_max = self.rewrite_other_exp(Stat::Max); + let replaced_min = self.rewrite_other_exp(Stat::Min); + + Some(Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new(min_col, Operator::Lte, replaced_max)), + Operator::And, + Arc::new(BinaryExpr::new(replaced_min, Operator::Lte, max_col)), + ))) + } + Operator::NotEq => { + let min_col = Arc::new(Column::new(self.add_stat_reference(Stat::Min))); + let max_col = Arc::new(Column::new(self.add_stat_reference(Stat::Max))); + let replaced_max = self.rewrite_other_exp(Stat::Max); + let replaced_min = self.rewrite_other_exp(Stat::Min); + + // In case of other_exp is literal both sides of AND will be the same expression + Some(Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + min_col.clone(), + Operator::NotEq, + replaced_min.clone(), + )), + Operator::Or, + Arc::new(BinaryExpr::new( + replaced_min, + Operator::NotEq, + max_col.clone(), + )), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + min_col, + Operator::NotEq, + replaced_max.clone(), + )), + Operator::Or, + Arc::new(BinaryExpr::new(replaced_max, Operator::NotEq, max_col)), + )), + ))) + } + op @ Operator::Gt | op @ Operator::Gte => { + let max_col = Arc::new(Column::new(self.add_stat_reference(Stat::Max))); + let replaced_min = self.rewrite_other_exp(Stat::Min); + + Some(Arc::new(BinaryExpr::new(max_col, op, replaced_min))) + } + op @ Operator::Lt | op @ Operator::Lte => { + 
let min_col = Arc::new(Column::new(self.add_stat_reference(Stat::Min))); + let replaced_max = self.rewrite_other_exp(Stat::Max); + + Some(Arc::new(BinaryExpr::new(min_col, op, replaced_max))) + } + _ => None, + }; + expr.map(|e| (e, self.stats_to_fetch)) + } +} + +fn replace_column_with_stat( + expr: &Arc, + stat: Stat, + stats_to_fetch: &mut HashMap>, +) -> Option> { + if let Some(col) = expr.as_any().downcast_ref::() { + let new_field = stat_column_name(col.field(), stat); + match stats_to_fetch.entry(col.field().clone()) { + Entry::Occupied(o) => o.into_mut().push(stat), + Entry::Vacant(v) => { + v.insert(vec![stat]); + } + } + return Some(Arc::new(Column::new(new_field))); + } + + if let Some(bexp) = expr.as_any().downcast_ref::() { + let rewritten_lhs = replace_column_with_stat(bexp.lhs(), stat, stats_to_fetch); + let rewritten_rhs = replace_column_with_stat(bexp.rhs(), stat, stats_to_fetch); + if rewritten_lhs.is_none() && rewritten_rhs.is_none() { + return None; + } + + let lhs = rewritten_lhs.unwrap_or_else(|| bexp.lhs().clone()); + let rhs = rewritten_rhs.unwrap_or_else(|| bexp.rhs().clone()); + + return Some(Arc::new(BinaryExpr::new(lhs, bexp.op(), rhs))); + } + + None +} + +fn stat_column_name(field: &Field, stat: Stat) -> Field { + match field { + Field::Name(n) => Field::Name(format!("{n}_{stat}")), + Field::Index(i) => Field::Name(format!("{i}_{stat}")), + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use ahash::HashMap; + use vortex::stats::Stat; + use vortex_dtype::field::Field; + use vortex_expr::{BinaryExpr, Column, Literal, Operator, VortexExpr}; + + use crate::layouts::pruning::{convert_to_pruning_expression, stat_column_name}; + + #[test] + pub fn pruning_equals() { + let column = Field::from("a"); + let literal_eq = Arc::new(Literal::new(42.into())); + let eq_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new(column.clone())), + Operator::Eq, + literal_eq.clone(), + )) as _; + let (converted, refs) = 
convert_to_pruning_expression(&eq_expr); + assert_eq!( + refs, + HashMap::from_iter([(column.clone(), vec![Stat::Min, Stat::Max])]) + ); + let expected_expr: Arc = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new(stat_column_name(&column, Stat::Min))), + Operator::Lte, + literal_eq.clone(), + )), + Operator::And, + Arc::new(BinaryExpr::new( + literal_eq, + Operator::Lte, + Arc::new(Column::new(stat_column_name(&column, Stat::Max))), + )), + )); + assert_eq!(*converted, *expected_expr.as_any()); + } + + #[test] + pub fn pruning_equals_column() { + let column = Field::from("a"); + let other_col = Field::from("b"); + let eq_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new(column.clone())), + Operator::Eq, + Arc::new(Column::new(other_col.clone())), + )) as _; + + let (converted, refs) = convert_to_pruning_expression(&eq_expr); + assert_eq!( + refs, + HashMap::from_iter([ + (column.clone(), vec![Stat::Min, Stat::Max]), + (other_col.clone(), vec![Stat::Max, Stat::Min]) + ]) + ); + let expected_expr: Arc = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new(stat_column_name(&column, Stat::Min))), + Operator::Lte, + Arc::new(Column::new(stat_column_name(&other_col, Stat::Max))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Column::new(stat_column_name(&other_col, Stat::Min))), + Operator::Lte, + Arc::new(Column::new(stat_column_name(&column, Stat::Max))), + )), + )); + assert_eq!(*converted, *expected_expr.as_any()); + } + + #[test] + pub fn pruning_not_equals_column() { + let column = Field::from("a"); + let other_col = Field::from("b"); + let not_eq_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new(column.clone())), + Operator::NotEq, + Arc::new(Column::new(other_col.clone())), + )) as _; + + let (converted, refs) = convert_to_pruning_expression(&not_eq_expr); + assert_eq!( + refs, + HashMap::from_iter([ + (column.clone(), vec![Stat::Min, Stat::Max]), + (other_col.clone(), vec![Stat::Max, 
Stat::Min]) + ]) + ); + let expected_expr: Arc = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new(stat_column_name(&column, Stat::Min))), + Operator::NotEq, + Arc::new(Column::new(stat_column_name(&other_col, Stat::Min))), + )), + Operator::Or, + Arc::new(BinaryExpr::new( + Arc::new(Column::new(stat_column_name(&other_col, Stat::Min))), + Operator::NotEq, + Arc::new(Column::new(stat_column_name(&column, Stat::Max))), + )), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new(stat_column_name(&column, Stat::Min))), + Operator::NotEq, + Arc::new(Column::new(stat_column_name(&other_col, Stat::Max))), + )), + Operator::Or, + Arc::new(BinaryExpr::new( + Arc::new(Column::new(stat_column_name(&other_col, Stat::Max))), + Operator::NotEq, + Arc::new(Column::new(stat_column_name(&column, Stat::Max))), + )), + )), + )); + assert_eq!(*converted, *expected_expr.as_any()); + } + + #[test] + pub fn pruning_gt_column() { + let column = Field::from("a"); + let other_col = Field::from("b"); + let other_expr = Arc::new(Column::new(other_col.clone())); + let not_eq_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new(column.clone())), + Operator::Gt, + other_expr.clone(), + )) as _; + + let (converted, refs) = convert_to_pruning_expression(&not_eq_expr); + assert_eq!( + refs, + HashMap::from_iter([ + (column.clone(), vec![Stat::Max]), + (other_col.clone(), vec![Stat::Min]) + ]) + ); + let expected_expr: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new(stat_column_name(&column, Stat::Max))), + Operator::Gt, + Arc::new(Column::new(stat_column_name(&other_col, Stat::Min))), + )); + assert_eq!(*converted, *expected_expr.as_any()); + } + + #[test] + pub fn pruning_gt_value() { + let column = Field::from("a"); + let other_col = Arc::new(Literal::new(42.into())); + let not_eq_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new(column.clone())), + Operator::Gt, + 
other_col.clone(), + )) as _; + + let (converted, refs) = convert_to_pruning_expression(&not_eq_expr); + assert_eq!( + refs, + HashMap::from_iter([(column.clone(), vec![Stat::Max]),]) + ); + let expected_expr: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new(stat_column_name(&column, Stat::Max))), + Operator::Gt, + other_col.clone(), + )); + assert_eq!(*converted, *expected_expr.as_any()); + } + + #[test] + pub fn pruning_lt_column() { + let column = Field::from("a"); + let other_col = Field::from("b"); + let other_expr = Arc::new(Column::new(other_col.clone())); + let not_eq_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new(column.clone())), + Operator::Lt, + other_expr.clone(), + )) as _; + + let (converted, refs) = convert_to_pruning_expression(&not_eq_expr); + assert_eq!( + refs, + HashMap::from_iter([ + (column.clone(), vec![Stat::Min]), + (other_col.clone(), vec![Stat::Max]) + ]) + ); + let expected_expr: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new(stat_column_name(&column, Stat::Min))), + Operator::Lt, + Arc::new(Column::new(stat_column_name(&other_col, Stat::Max))), + )); + assert_eq!(*converted, *expected_expr.as_any()); + } + + #[test] + pub fn pruning_lt_value() { + let column = Field::from("a"); + let other_col = Arc::new(Literal::new(42.into())); + let not_eq_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new(column.clone())), + Operator::Lt, + other_col.clone(), + )) as _; + + let (converted, refs) = convert_to_pruning_expression(&not_eq_expr); + assert_eq!( + refs, + HashMap::from_iter([(column.clone(), vec![Stat::Min]),]) + ); + let expected_expr: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new(stat_column_name(&column, Stat::Min))), + Operator::Lt, + other_col.clone(), + )); + assert_eq!(*converted, *expected_expr.as_any()); + } +} diff --git a/vortex-serde/src/layouts/read/filtering.rs b/vortex-serde/src/layouts/read/filtering.rs index b566613399..b4c6e18c87 100644 --- a/vortex-serde/src/layouts/read/filtering.rs +++ 
b/vortex-serde/src/layouts/read/filtering.rs @@ -10,10 +10,8 @@ pub struct RowFilter { } impl RowFilter { - pub fn new(disjunction: Arc) -> Self { - Self { - filter: disjunction, - } + pub fn new(filter: Arc) -> Self { + Self { filter } } pub fn project(&self, _fields: &[FieldPath]) -> Self { diff --git a/vortex-serde/src/layouts/write/writer.rs b/vortex-serde/src/layouts/write/writer.rs index 70dd49c60f..eef6afabbf 100644 --- a/vortex-serde/src/layouts/write/writer.rs +++ b/vortex-serde/src/layouts/write/writer.rs @@ -126,15 +126,11 @@ impl LayoutWriter { .zip(chunk.byte_offsets.iter().skip(1)) .map(|(begin, end)| Layout::Flat(FlatLayout::new(*begin, *end))) .collect(); - chunk.byte_offsets.truncate(len); chunk.row_offsets.truncate(len); let metadata_array = StructArray::try_new( - ["byte_offset".into(), "row_offset".into()].into(), - vec![ - chunk.byte_offsets.into_array(), - chunk.row_offsets.into_array(), - ], + ["row_offset".into()].into(), + vec![chunk.row_offsets.into_array()], len, Validity::NonNullable, )?;