Skip to content

Commit

Permalink
fix: required stats are relations not maps (#1432)
Browse files Browse the repository at this point in the history
Extend overwrites keys when the right-hand side contains a key that is
already present in the left-hand side.
  • Loading branch information
danking authored Nov 21, 2024
1 parent 30e8a21 commit 682e281
Showing 1 changed file with 94 additions and 29 deletions.
123 changes: 94 additions & 29 deletions vortex-file/src/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#![allow(dead_code)]

use std::fmt::Display;
use std::hash::Hash;

use itertools::Itertools;
use vortex_array::aliases::hash_map::HashMap;
Expand All @@ -12,21 +13,57 @@ use vortex_dtype::Nullability;
use vortex_expr::{BinaryExpr, Column, ExprRef, Literal, Not, Operator};
use vortex_scalar::Scalar;

/// A relation between keys of type `K` and values of type `V`.
///
/// Each key is associated with a set of values, stored as a map from the key
/// to a `HashSet` of its values.
#[derive(Debug, Clone)]
pub struct Relation<K, V> {
// For every key, the set of values related to it.
map: HashMap<K, HashSet<V>>,
}

/// Renders the relation as comma-separated `key: {v1,v2,...}` entries, in the
/// iteration order of the underlying map and sets.
impl<K: Display, V: Display> Display for Relation<K, V> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Build the lazy formatting adaptor first, then write it in one go.
        let entries = self.map.iter().format_with(",", |(key, values), emit| {
            let joined = values.iter().format(",");
            emit(&format_args!("{key}: {{{joined}}}"))
        });
        write!(f, "{entries}")
    }
}

impl<K: Hash + Eq, V: Hash + Eq> Relation<K, V> {
    /// Creates a relation containing no keys.
    pub fn new() -> Self {
        Relation {
            map: HashMap::new(),
        }
    }

    /// Merges every key/value pair of `other` into `self`.
    ///
    /// When a key is present on both sides, the two value sets are unioned;
    /// the values already stored in `self` are kept rather than replaced.
    pub fn extend(&mut self, other: Relation<K, V>) {
        for (key, values) in other.map {
            self.map.entry(key).or_default().extend(values);
        }
    }

    /// Relates `v` to `k`, creating the value set for `k` if it does not exist yet.
    pub fn insert(&mut self, k: K, v: V) {
        self.map.entry(k).or_default().insert(v);
    }

    /// Consumes the relation and returns the underlying map from each key to
    /// its set of values.
    pub fn into_map(self) -> HashMap<K, HashSet<V>> {
        self.map
    }
}

/// The default relation is the empty relation, identical to [`Relation::new`].
impl<K: Hash + Eq, V: Hash + Eq> Default for Relation<K, V> {
    fn default() -> Self {
        Self::new()
    }
}

#[derive(Debug, Clone)]
pub struct PruningPredicate {
expr: ExprRef,
required_stats: HashMap<Field, HashSet<Stat>>,
required_stats: Relation<Field, Stat>,
}

impl Display for PruningPredicate {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"PruningPredicate({}, {{{}}})",
self.expr,
self.required_stats.iter().format_with(",", |(k, v), fmt| {
fmt(&format_args!("{k}: {{{}}}", v.iter().format(",")))
})
self.expr, self.required_stats
)
}
}
Expand Down Expand Up @@ -65,7 +102,7 @@ impl PruningPredicate {
}

pub fn required_stats(&self) -> &HashMap<Field, HashSet<Stat>> {
&self.required_stats
&self.required_stats.map
}
}

Expand Down Expand Up @@ -99,7 +136,7 @@ fn convert_to_pruning_expression(expr: &ExprRef) -> PruningPredicateStats {
.unwrap_or_else(|| {
(
Literal::new_expr(Scalar::bool(false, Nullability::NonNullable)),
HashMap::new(),
Relation::new(),
)
});
};
Expand All @@ -114,20 +151,20 @@ fn convert_to_pruning_expression(expr: &ExprRef) -> PruningPredicateStats {
.unwrap_or_else(|| {
(
Literal::new_expr(Scalar::bool(false, Nullability::NonNullable)),
HashMap::new(),
Relation::new(),
)
});
};
}

(
Literal::new_expr(Scalar::bool(false, Nullability::NonNullable)),
HashMap::new(),
Relation::new(),
)
}

fn convert_column_reference(expr: &ExprRef, invert: bool) -> PruningPredicateStats {
let mut refs = HashMap::new();
let mut refs = Relation::new();
let min_expr = replace_column_with_stat(expr, Stat::Min, &mut refs);
let max_expr = replace_column_with_stat(expr, Stat::Max, &mut refs);
(
Expand All @@ -149,10 +186,10 @@ struct PruningPredicateRewriter<'a> {
column: Field,
operator: Operator,
other_exp: &'a ExprRef,
stats_to_fetch: HashMap<Field, HashSet<Stat>>,
stats_to_fetch: Relation<Field, Stat>,
}

type PruningPredicateStats = (ExprRef, HashMap<Field, HashSet<Stat>>);
type PruningPredicateStats = (ExprRef, Relation<Field, Stat>);

impl<'a> PruningPredicateRewriter<'a> {
pub fn try_new(column: Field, operator: Operator, other_exp: &'a ExprRef) -> Option<Self> {
Expand All @@ -166,16 +203,13 @@ impl<'a> PruningPredicateRewriter<'a> {
column,
operator,
other_exp,
stats_to_fetch: HashMap::new(),
stats_to_fetch: Relation::new(),
})
}

fn add_stat_reference(&mut self, stat: Stat) -> Field {
let new_field = stat_column_name(&self.column, stat);
self.stats_to_fetch
.entry(self.column.clone())
.or_default()
.insert(stat);
self.stats_to_fetch.insert(self.column.clone(), stat);
new_field
}

Expand Down Expand Up @@ -243,14 +277,11 @@ impl<'a> PruningPredicateRewriter<'a> {
fn replace_column_with_stat(
expr: &ExprRef,
stat: Stat,
stats_to_fetch: &mut HashMap<Field, HashSet<Stat>>,
stats_to_fetch: &mut Relation<Field, Stat>,
) -> Option<ExprRef> {
if let Some(col) = expr.as_any().downcast_ref::<Column>() {
let new_field = stat_column_name(col.field(), stat);
stats_to_fetch
.entry(col.field().clone())
.or_default()
.insert(stat);
stats_to_fetch.insert(col.field().clone(), stat);
return Some(Column::new_expr(new_field));
}

Expand Down Expand Up @@ -303,7 +334,7 @@ mod tests {
);
let (converted, refs) = convert_to_pruning_expression(&eq_expr);
assert_eq!(
refs,
refs.into_map(),
HashMap::from_iter([(column.clone(), HashSet::from_iter([Stat::Min, Stat::Max]))])
);
let expected_expr = BinaryExpr::new_expr(
Expand Down Expand Up @@ -334,7 +365,7 @@ mod tests {

let (converted, refs) = convert_to_pruning_expression(&eq_expr);
assert_eq!(
refs,
refs.into_map(),
HashMap::from_iter([
(column.clone(), HashSet::from_iter([Stat::Min, Stat::Max])),
(
Expand Down Expand Up @@ -371,7 +402,7 @@ mod tests {

let (converted, refs) = convert_to_pruning_expression(&not_eq_expr);
assert_eq!(
refs,
refs.into_map(),
HashMap::from_iter([
(column.clone(), HashSet::from_iter([Stat::Min, Stat::Max])),
(
Expand Down Expand Up @@ -418,7 +449,7 @@ mod tests {

let (converted, refs) = convert_to_pruning_expression(&not_eq_expr);
assert_eq!(
refs,
refs.into_map(),
HashMap::from_iter([
(column.clone(), HashSet::from_iter([Stat::Max])),
(other_col.clone(), HashSet::from_iter([Stat::Min]))
Expand All @@ -444,7 +475,7 @@ mod tests {

let (converted, refs) = convert_to_pruning_expression(&not_eq_expr);
assert_eq!(
refs,
refs.into_map(),
HashMap::from_iter([(column.clone(), HashSet::from_iter([Stat::Max])),])
);
let expected_expr = BinaryExpr::new_expr(
Expand All @@ -468,7 +499,7 @@ mod tests {

let (converted, refs) = convert_to_pruning_expression(&not_eq_expr);
assert_eq!(
refs,
refs.into_map(),
HashMap::from_iter([
(column.clone(), HashSet::from_iter([Stat::Min])),
(other_col.clone(), HashSet::from_iter([Stat::Max]))
Expand All @@ -494,7 +525,7 @@ mod tests {

let (converted, refs) = convert_to_pruning_expression(&not_eq_expr);
assert_eq!(
refs,
refs.into_map(),
HashMap::from_iter([(column.clone(), HashSet::from_iter([Stat::Min]))])
);
let expected_expr = BinaryExpr::new_expr(
Expand Down Expand Up @@ -526,4 +557,38 @@ mod tests {
"PruningPredicate(($a_min >= 42_i32), {$a: {min}})"
);
}

#[test]
fn or_required_stats_from_both_arms() {
    // `a < 10 OR a > 50`: both operands reference the same field `a`.
    let col_a = Column::new_expr(Field::from("a"));
    let below = BinaryExpr::new_expr(col_a.clone(), Operator::Lt, Literal::new_expr(10.into()));
    let above = BinaryExpr::new_expr(col_a, Operator::Gt, Literal::new_expr(50.into()));
    let disjunction = BinaryExpr::new_expr(below, Operator::Or, above);

    // The required stats for `a` must be the full set {Min, Max}.
    let expected = HashMap::from([(Field::from("a"), HashSet::from([Stat::Min, Stat::Max]))]);

    let predicate = PruningPredicate::try_new(&disjunction).unwrap();
    assert_eq!(predicate.required_stats(), &expected);
}

#[test]
fn and_required_stats_from_both_arms() {
    // `a > 50 AND a < 10`: both operands reference the same field `a`.
    let col_a = Column::new_expr(Field::from("a"));
    let above = BinaryExpr::new_expr(col_a.clone(), Operator::Gt, Literal::new_expr(50.into()));
    let below = BinaryExpr::new_expr(col_a, Operator::Lt, Literal::new_expr(10.into()));
    let conjunction = BinaryExpr::new_expr(above, Operator::And, below);

    // The required stats for `a` must be the full set {Min, Max}.
    let expected = HashMap::from([(Field::from("a"), HashSet::from([Stat::Min, Stat::Max]))]);

    let predicate = PruningPredicate::try_new(&conjunction).unwrap();
    assert_eq!(predicate.required_stats(), &expected);
}
}

0 comments on commit 682e281

Please sign in to comment.