From 7a1a434072e19ef25ee47d2f8873ba810bd07750 Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Mon, 24 Jun 2024 18:43:14 +0300 Subject: [PATCH] Fallback to BoolArray stats and add some other tests --- encodings/byte_bool/src/compute/mod.rs | 59 +++++++++++++- encodings/byte_bool/src/stats.rs | 108 +++++++++++++++++++------ vortex-array/src/compute/compare.rs | 6 +- 3 files changed, 145 insertions(+), 28 deletions(-) diff --git a/encodings/byte_bool/src/compute/mod.rs b/encodings/byte_bool/src/compute/mod.rs index 58f75ec88..89bd14330 100644 --- a/encodings/byte_bool/src/compute/mod.rs +++ b/encodings/byte_bool/src/compute/mod.rs @@ -188,7 +188,7 @@ impl FillForwardFn for ByteBoolArray { #[cfg(test)] mod tests { use vortex::{ - compute::{scalar_at::scalar_at, slice::slice}, + compute::{compare::compare, scalar_at::scalar_at, slice::slice}, AsArray as _, }; @@ -213,4 +213,61 @@ mod tests { let s = scalar_at(sliced_arr.as_array_ref(), 2).unwrap(); assert_eq!(s.into_value().as_bool().unwrap(), Some(false)); } + + #[test] + fn test_compare_all_equal() { + let lhs = ByteBoolArray::from(vec![true; 5]); + let rhs = ByteBoolArray::from(vec![true; 5]); + + let arr = compare(lhs.as_array_ref(), rhs.as_array_ref(), Operator::Eq) + .unwrap() + .flatten_bool() + .unwrap(); + + for i in 0..arr.len() { + assert!(arr.is_valid(i)); + let s = scalar_at(arr.as_array_ref(), i).unwrap(); + assert_eq!(s.value(), &ScalarValue::Bool(true)); + } + } + + #[test] + fn test_compare_all_different() { + let lhs = ByteBoolArray::from(vec![false; 5]); + let rhs = ByteBoolArray::from(vec![true; 5]); + + let arr = compare(lhs.as_array_ref(), rhs.as_array_ref(), Operator::Eq) + .unwrap() + .flatten_bool() + .unwrap(); + + for i in 0..arr.len() { + assert!(arr.is_valid(i)); + let s = scalar_at(arr.as_array_ref(), i).unwrap(); + assert_eq!(s.value(), &ScalarValue::Bool(false)); + } + } + + #[test] + fn test_compare_with_nulls() { + let lhs = ByteBoolArray::from(vec![true; 5]); + let rhs = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]); + + let arr = compare(lhs.as_array_ref(), rhs.as_array_ref(), Operator::Eq) + .unwrap() + .flatten_bool() + .unwrap(); + + for i in 0..3 { + assert!(arr.is_valid(i)); + let s = scalar_at(arr.as_array_ref(), i).unwrap(); + assert_eq!(s.value(), &ScalarValue::Bool(true)); + } + + assert!(arr.is_valid(3)); + let s = scalar_at(arr.as_array_ref(), 3).unwrap(); + assert_eq!(s.value(), &ScalarValue::Bool(false)); + + assert!(!arr.is_valid(4)); + } } diff --git a/encodings/byte_bool/src/stats.rs b/encodings/byte_bool/src/stats.rs index bb02ef968..15ec5fa68 100644 --- a/encodings/byte_bool/src/stats.rs +++ b/encodings/byte_bool/src/stats.rs @@ -1,9 +1,6 @@ -use std::collections::HashMap; - use vortex::{ stats::{ArrayStatisticsCompute, Stat, StatsSet}, - validity::{ArrayValidity, LogicalValidity}, - ArrayDType, ArrayTrait, AsArray, + ArrayTrait, AsArray, }; use vortex_error::VortexResult; @@ -15,27 +12,90 @@ impl ArrayStatisticsCompute for ByteBoolArray { return Ok(StatsSet::new()); } - match self.logical_validity() { - LogicalValidity::AllValid(len) => Ok(all_true_bool_stats(len)), - LogicalValidity::AllInvalid(len) => Ok(StatsSet::nulls(len, self.dtype())), - LogicalValidity::Array(_) => { - let bools = self.as_array_ref().clone().flatten_bool()?; - bools.compute_statistics(stat) - } - } + let bools = self.as_array_ref().clone().flatten_bool()?; + bools.compute_statistics(stat) } } -fn all_true_bool_stats(len: usize) -> StatsSet { - let stats = HashMap::from([ - (Stat::Min, true.into()), - (Stat::Min, true.into()), - (Stat::IsConstant, true.into()), - (Stat::IsSorted, true.into()), - (Stat::IsStrictSorted, (len < 2).into()), - (Stat::RunCount, 1.into()), - (Stat::NullCount, 0.into()), - ]); - - StatsSet::from(stats) +#[cfg(test)] +mod tests { + use vortex::stats::ArrayStatistics; + use vortex_dtype::{DType, Nullability}; + use vortex_scalar::Scalar; + + use super::*; + + #[test] + fn bool_stats() { + let bool_arr = + ByteBoolArray::from(vec![false, false, true, true, false, true, true, false]); + assert!(!bool_arr.statistics().compute_is_strict_sorted().unwrap()); + assert!(!bool_arr.statistics().compute_is_sorted().unwrap()); + assert!(!bool_arr.statistics().compute_is_constant().unwrap()); + assert!(!bool_arr.statistics().compute_min::().unwrap()); + assert!(bool_arr.statistics().compute_max::().unwrap()); + assert_eq!(bool_arr.statistics().compute_run_count().unwrap(), 5); + assert_eq!(bool_arr.statistics().compute_true_count().unwrap(), 4); + } + + #[test] + fn strict_sorted() { + let bool_arr_1 = ByteBoolArray::from(vec![false, true]); + assert!(bool_arr_1.statistics().compute_is_strict_sorted().unwrap()); + assert!(bool_arr_1.statistics().compute_is_sorted().unwrap()); + + let bool_arr_2 = ByteBoolArray::from(vec![true]); + assert!(bool_arr_2.statistics().compute_is_strict_sorted().unwrap()); + assert!(bool_arr_2.statistics().compute_is_sorted().unwrap()); + + let bool_arr_3 = ByteBoolArray::from(vec![false]); + assert!(bool_arr_3.statistics().compute_is_strict_sorted().unwrap()); + assert!(bool_arr_3.statistics().compute_is_sorted().unwrap()); + + let bool_arr_4 = ByteBoolArray::from(vec![true, false]); + assert!(!bool_arr_4.statistics().compute_is_strict_sorted().unwrap()); + assert!(!bool_arr_4.statistics().compute_is_sorted().unwrap()); + + let bool_arr_5 = ByteBoolArray::from(vec![false, true, true]); + assert!(!bool_arr_5.statistics().compute_is_strict_sorted().unwrap()); + assert!(bool_arr_5.statistics().compute_is_sorted().unwrap()); + } + + #[test] + fn nullable_stats() { + let bool_arr = ByteBoolArray::from(vec![ + Some(false), + Some(true), + None, + Some(true), + Some(false), + None, + None, + ]); + assert!(!bool_arr.statistics().compute_is_strict_sorted().unwrap()); + assert!(!bool_arr.statistics().compute_is_sorted().unwrap()); + assert!(!bool_arr.statistics().compute_is_constant().unwrap()); + assert!(!bool_arr.statistics().compute_min::().unwrap()); + assert!(bool_arr.statistics().compute_max::().unwrap()); + assert_eq!(bool_arr.statistics().compute_run_count().unwrap(), 3); + assert_eq!(bool_arr.statistics().compute_true_count().unwrap(), 2); + } + + #[test] + fn all_nullable_stats() { + let bool_arr = ByteBoolArray::from(vec![None, None, None, None, None]); + assert!(!bool_arr.statistics().compute_is_strict_sorted().unwrap()); + assert!(bool_arr.statistics().compute_is_sorted().unwrap()); + assert!(bool_arr.statistics().compute_is_constant().unwrap()); + assert_eq!( + bool_arr.statistics().compute(Stat::Min).unwrap(), + Scalar::null(DType::Bool(Nullability::Nullable)) + ); + assert_eq!( + bool_arr.statistics().compute(Stat::Max).unwrap(), + Scalar::null(DType::Bool(Nullability::Nullable)) + ); + assert_eq!(bool_arr.statistics().compute_run_count().unwrap(), 1); + assert_eq!(bool_arr.statistics().compute_true_count().unwrap(), 0); + } } diff --git a/vortex-array/src/compute/compare.rs b/vortex-array/src/compute/compare.rs index ceeede9fb..9529d7ede 100644 --- a/vortex-array/src/compute/compare.rs +++ b/vortex-array/src/compute/compare.rs @@ -8,9 +8,9 @@ pub trait CompareFn { fn compare(&self, array: &Array, predicate: Operator) -> VortexResult; } -pub fn compare(array: &Array, other: &Array, predicate: Operator) -> VortexResult { +pub fn compare(array: &Array, other: &Array, operator: Operator) -> VortexResult { if let Some(matching_indices) = - array.with_dyn(|c| c.compare().map(|t| t.compare(other, predicate))) + array.with_dyn(|c| c.compare().map(|t| t.compare(other, operator))) { return matching_indices; } @@ -19,7 +19,7 @@ pub fn compare(array: &Array, other: &Array, predicate: Operator) -> VortexResul match array.dtype() { DType::Primitive(..) => { let flat = array.clone().flatten_primitive()?; - flat.compare(other, predicate) + flat.compare(other, operator) } _ => Err(vortex_err!( NotImplemented: "compare",