diff --git a/vortex-file/src/pruning.rs b/vortex-file/src/pruning.rs index 77fdb37009..fded160c73 100644 --- a/vortex-file/src/pruning.rs +++ b/vortex-file/src/pruning.rs @@ -2,6 +2,7 @@ #![allow(dead_code)] use std::fmt::Display; +use std::hash::Hash; use itertools::Itertools; use vortex_array::aliases::hash_map::HashMap; @@ -12,10 +13,49 @@ use vortex_dtype::Nullability; use vortex_expr::{BinaryExpr, Column, ExprRef, Literal, Not, Operator}; use vortex_scalar::Scalar; +#[derive(Debug, Clone)] +pub struct Relation { + map: HashMap>, +} + +impl Display for Relation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + self.map.iter().format_with(",", |(k, v), fmt| { + fmt(&format_args!("{k}: {{{}}}", v.iter().format(","))) + }) + ) + } +} + +impl Relation { + pub fn new() -> Self { + Relation { + map: HashMap::new(), + } + } + + pub fn extend(&mut self, other: Relation) { + for (l, rs) in other.map.into_iter() { + self.map.entry(l).or_default().extend(rs.into_iter()) + } + } + + pub fn insert(&mut self, k: K, v: V) { + self.map.entry(k).or_default().insert(v); + } + + pub fn into_map(self) -> HashMap> { + self.map + } +} + #[derive(Debug, Clone)] pub struct PruningPredicate { expr: ExprRef, - required_stats: HashMap>, + required_stats: Relation, } impl Display for PruningPredicate { @@ -23,10 +63,7 @@ impl Display for PruningPredicate { write!( f, "PruningPredicate({}, {{{}}})", - self.expr, - self.required_stats.iter().format_with(",", |(k, v), fmt| { - fmt(&format_args!("{k}: {{{}}}", v.iter().format(","))) - }) + self.expr, self.required_stats ) } } @@ -65,7 +102,7 @@ impl PruningPredicate { } pub fn required_stats(&self) -> &HashMap> { - &self.required_stats + &self.required_stats.map } } @@ -99,7 +136,7 @@ fn convert_to_pruning_expression(expr: &ExprRef) -> PruningPredicateStats { .unwrap_or_else(|| { ( Literal::new_expr(Scalar::bool(false, Nullability::NonNullable)), - HashMap::new(), + Relation::new(), ) }); }; @@ -114,7 +151,7 @@ fn convert_to_pruning_expression(expr: &ExprRef) -> PruningPredicateStats { .unwrap_or_else(|| { ( Literal::new_expr(Scalar::bool(false, Nullability::NonNullable)), - HashMap::new(), + Relation::new(), ) }); }; @@ -122,12 +159,12 @@ fn convert_to_pruning_expression(expr: &ExprRef) -> PruningPredicateStats { ( Literal::new_expr(Scalar::bool(false, Nullability::NonNullable)), - HashMap::new(), + Relation::new(), ) } fn convert_column_reference(expr: &ExprRef, invert: bool) -> PruningPredicateStats { - let mut refs = HashMap::new(); + let mut refs = Relation::new(); let min_expr = replace_column_with_stat(expr, Stat::Min, &mut refs); let max_expr = replace_column_with_stat(expr, Stat::Max, &mut refs); ( @@ -149,10 +186,10 @@ struct PruningPredicateRewriter<'a> { column: Field, operator: Operator, other_exp: &'a ExprRef, - stats_to_fetch: HashMap>, + stats_to_fetch: Relation, } -type PruningPredicateStats = (ExprRef, HashMap>); +type PruningPredicateStats = (ExprRef, Relation); impl<'a> PruningPredicateRewriter<'a> { pub fn try_new(column: Field, operator: Operator, other_exp: &'a ExprRef) -> Option { @@ -166,16 +203,13 @@ impl<'a> PruningPredicateRewriter<'a> { column, operator, other_exp, - stats_to_fetch: HashMap::new(), + stats_to_fetch: Relation::new(), }) } fn add_stat_reference(&mut self, stat: Stat) -> Field { let new_field = stat_column_name(&self.column, stat); - self.stats_to_fetch - .entry(self.column.clone()) - .or_default() - .insert(stat); + self.stats_to_fetch.insert(self.column.clone(), stat); new_field } @@ -243,14 +277,11 @@ impl<'a> PruningPredicateRewriter<'a> { fn replace_column_with_stat( expr: &ExprRef, stat: Stat, - stats_to_fetch: &mut HashMap>, + stats_to_fetch: &mut Relation, ) -> Option { if let Some(col) = expr.as_any().downcast_ref::() { let new_field = stat_column_name(col.field(), stat); - stats_to_fetch - .entry(col.field().clone()) - .or_default() - .insert(stat); + stats_to_fetch.insert(col.field().clone(), stat); return Some(Column::new_expr(new_field)); } @@ -303,7 +334,7 @@ mod tests { ); let (converted, refs) = convert_to_pruning_expression(&eq_expr); assert_eq!( - refs, + refs.into_map(), HashMap::from_iter([(column.clone(), HashSet::from_iter([Stat::Min, Stat::Max]))]) ); let expected_expr = BinaryExpr::new_expr( @@ -334,7 +365,7 @@ mod tests { let (converted, refs) = convert_to_pruning_expression(&eq_expr); assert_eq!( - refs, + refs.into_map(), HashMap::from_iter([ (column.clone(), HashSet::from_iter([Stat::Min, Stat::Max])), ( @@ -371,7 +402,7 @@ mod tests { let (converted, refs) = convert_to_pruning_expression(¬_eq_expr); assert_eq!( - refs, + refs.into_map(), HashMap::from_iter([ (column.clone(), HashSet::from_iter([Stat::Min, Stat::Max])), ( @@ -418,7 +449,7 @@ mod tests { let (converted, refs) = convert_to_pruning_expression(¬_eq_expr); assert_eq!( - refs, + refs.into_map(), HashMap::from_iter([ (column.clone(), HashSet::from_iter([Stat::Max])), (other_col.clone(), HashSet::from_iter([Stat::Min])) @@ -444,7 +475,7 @@ mod tests { let (converted, refs) = convert_to_pruning_expression(¬_eq_expr); assert_eq!( - refs, + refs.into_map(), HashMap::from_iter([(column.clone(), HashSet::from_iter([Stat::Max])),]) ); let expected_expr = BinaryExpr::new_expr( @@ -468,7 +499,7 @@ mod tests { let (converted, refs) = convert_to_pruning_expression(¬_eq_expr); assert_eq!( - refs, + refs.into_map(), HashMap::from_iter([ (column.clone(), HashSet::from_iter([Stat::Min])), (other_col.clone(), HashSet::from_iter([Stat::Max])) @@ -494,7 +525,7 @@ mod tests { let (converted, refs) = convert_to_pruning_expression(¬_eq_expr); assert_eq!( - refs, + refs.into_map(), HashMap::from_iter([(column.clone(), HashSet::from_iter([Stat::Min]))]) ); let expected_expr = BinaryExpr::new_expr( @@ -526,4 +557,38 @@ mod tests { "PruningPredicate(($a_min >= 42_i32), {$a: {min}})" ); } + + #[test] + fn or_required_stats_from_both_arms() { + let column = Column::new_expr(Field::from("a")); + let expr = BinaryExpr::new_expr( + BinaryExpr::new_expr(column.clone(), Operator::Lt, Literal::new_expr(10.into())), + Operator::Or, + BinaryExpr::new_expr(column, Operator::Gt, Literal::new_expr(50.into())), + ); + + let expected = HashMap::from([(Field::from("a"), HashSet::from([Stat::Min, Stat::Max]))]); + + assert_eq!( + PruningPredicate::try_new(&expr).unwrap().required_stats(), + &expected + ); + } + + #[test] + fn and_required_stats_from_both_arms() { + let column = Column::new_expr(Field::from("a")); + let expr = BinaryExpr::new_expr( + BinaryExpr::new_expr(column.clone(), Operator::Gt, Literal::new_expr(50.into())), + Operator::And, + BinaryExpr::new_expr(column, Operator::Lt, Literal::new_expr(10.into())), + ); + + let expected = HashMap::from([(Field::from("a"), HashSet::from([Stat::Min, Stat::Max]))]); + + assert_eq!( + PruningPredicate::try_new(&expr).unwrap().required_stats(), + &expected + ); + } }