Skip to content

Commit

Permalink
Fallback to BoolArray stats and add some other tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamGS committed Jun 24, 2024
1 parent 0c5b203 commit 7a1a434
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 28 deletions.
59 changes: 58 additions & 1 deletion encodings/byte_bool/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ impl FillForwardFn for ByteBoolArray {
#[cfg(test)]
mod tests {
use vortex::{
compute::{scalar_at::scalar_at, slice::slice},
compute::{compare::compare, scalar_at::scalar_at, slice::slice},
AsArray as _,
};

Expand All @@ -213,4 +213,61 @@ mod tests {
let s = scalar_at(sliced_arr.as_array_ref(), 2).unwrap();
assert_eq!(s.into_value().as_bool().unwrap(), Some(false));
}

/// Element-wise `Eq` comparison of two all-`true` arrays: every slot of the
/// flattened result is asserted to be valid and to hold `true`.
#[test]
fn test_compare_all_equal() {
    let left = ByteBoolArray::from(vec![true; 5]);
    let right = ByteBoolArray::from(vec![true; 5]);

    let compared = compare(left.as_array_ref(), right.as_array_ref(), Operator::Eq).unwrap();
    let result = compared.flatten_bool().unwrap();

    let expected = ScalarValue::Bool(true);
    (0..result.len()).for_each(|idx| {
        assert!(result.is_valid(idx));
        let scalar = scalar_at(result.as_array_ref(), idx).unwrap();
        assert_eq!(scalar.value(), &expected);
    });
}

/// Element-wise `Eq` comparison of an all-`false` array against an all-`true`
/// array: every slot of the flattened result is asserted valid and `false`.
#[test]
fn test_compare_all_different() {
    let left = ByteBoolArray::from(vec![false; 5]);
    let right = ByteBoolArray::from(vec![true; 5]);

    let compared = compare(left.as_array_ref(), right.as_array_ref(), Operator::Eq).unwrap();
    let result = compared.flatten_bool().unwrap();

    let expected = ScalarValue::Bool(false);
    (0..result.len()).for_each(|idx| {
        assert!(result.is_valid(idx));
        let scalar = scalar_at(result.as_array_ref(), idx).unwrap();
        assert_eq!(scalar.value(), &expected);
    });
}

/// `Eq` comparison of an all-`true` array against a nullable array whose last
/// entry is `None`: slots 0..=3 are asserted valid with the expected boolean,
/// and slot 4 is asserted to be invalid.
#[test]
fn test_compare_with_nulls() {
    let left = ByteBoolArray::from(vec![true; 5]);
    let right =
        ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

    let result = compare(left.as_array_ref(), right.as_array_ref(), Operator::Eq)
        .unwrap()
        .flatten_bool()
        .unwrap();

    // Expected outcome for each of the four positions compared against a non-null value.
    let expected_valid = [true, true, true, false];
    for (idx, expected) in expected_valid.iter().copied().enumerate() {
        assert!(result.is_valid(idx));
        let scalar = scalar_at(result.as_array_ref(), idx).unwrap();
        assert_eq!(scalar.value(), &ScalarValue::Bool(expected));
    }

    // The position compared against `None` must not be valid.
    assert!(!result.is_valid(4));
}
}
108 changes: 84 additions & 24 deletions encodings/byte_bool/src/stats.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use std::collections::HashMap;

use vortex::{
stats::{ArrayStatisticsCompute, Stat, StatsSet},
validity::{ArrayValidity, LogicalValidity},
ArrayDType, ArrayTrait, AsArray,
ArrayTrait, AsArray,
};
use vortex_error::VortexResult;

Expand All @@ -15,27 +12,90 @@ impl ArrayStatisticsCompute for ByteBoolArray {
return Ok(StatsSet::new());
}

match self.logical_validity() {
LogicalValidity::AllValid(len) => Ok(all_true_bool_stats(len)),
LogicalValidity::AllInvalid(len) => Ok(StatsSet::nulls(len, self.dtype())),
LogicalValidity::Array(_) => {
let bools = self.as_array_ref().clone().flatten_bool()?;
bools.compute_statistics(stat)
}
}
let bools = self.as_array_ref().clone().flatten_bool()?;
bools.compute_statistics(stat)
}
}

fn all_true_bool_stats(len: usize) -> StatsSet {
let stats = HashMap::from([
(Stat::Min, true.into()),
(Stat::Min, true.into()),
(Stat::IsConstant, true.into()),
(Stat::IsSorted, true.into()),
(Stat::IsStrictSorted, (len < 2).into()),
(Stat::RunCount, 1.into()),
(Stat::NullCount, 0.into()),
]);

StatsSet::from(stats)
#[cfg(test)]
mod tests {
use vortex::stats::ArrayStatistics;
use vortex_dtype::{DType, Nullability};
use vortex_scalar::Scalar;

use super::*;

/// Statistics of a non-nullable array containing five runs of mixed
/// `false`/`true` values.
#[test]
fn bool_stats() {
    let values = vec![false, false, true, true, false, true, true, false];
    let array = ByteBoolArray::from(values);
    let stats = array.statistics();

    // Ordering / constancy flags.
    assert!(!stats.compute_is_strict_sorted().unwrap());
    assert!(!stats.compute_is_sorted().unwrap());
    assert!(!stats.compute_is_constant().unwrap());

    // Extremes: min is `false`, max is `true`.
    assert!(!stats.compute_min::<bool>().unwrap());
    assert!(stats.compute_max::<bool>().unwrap());

    // Counts.
    assert_eq!(stats.compute_run_count().unwrap(), 5);
    assert_eq!(stats.compute_true_count().unwrap(), 4);
}

/// Checks the `is_strict_sorted` and `is_sorted` statistics for several small
/// non-nullable arrays.
#[test]
fn strict_sorted() {
    // (input values, expected is_strict_sorted, expected is_sorted)
    let cases = vec![
        (vec![false, true], true, true),
        (vec![true], true, true),
        (vec![false], true, true),
        (vec![true, false], false, false),
        (vec![false, true, true], false, true),
    ];

    for (values, expect_strict, expect_sorted) in cases {
        let array = ByteBoolArray::from(values);
        assert_eq!(
            array.statistics().compute_is_strict_sorted().unwrap(),
            expect_strict
        );
        assert_eq!(
            array.statistics().compute_is_sorted().unwrap(),
            expect_sorted
        );
    }
}

/// Statistics of a nullable array that mixes `Some(false)`, `Some(true)` and
/// `None` entries.
#[test]
fn nullable_stats() {
    let values = vec![
        Some(false),
        Some(true),
        None,
        Some(true),
        Some(false),
        None,
        None,
    ];
    let array = ByteBoolArray::from(values);
    let stats = array.statistics();

    // Ordering / constancy flags.
    assert!(!stats.compute_is_strict_sorted().unwrap());
    assert!(!stats.compute_is_sorted().unwrap());
    assert!(!stats.compute_is_constant().unwrap());

    // Extremes: min is `false`, max is `true`.
    assert!(!stats.compute_min::<bool>().unwrap());
    assert!(stats.compute_max::<bool>().unwrap());

    // Counts.
    assert_eq!(stats.compute_run_count().unwrap(), 3);
    assert_eq!(stats.compute_true_count().unwrap(), 2);
}

/// Statistics of an array made up solely of `None` entries: min and max are
/// asserted to be the null scalar of nullable bool type.
#[test]
fn all_nullable_stats() {
    let array = ByteBoolArray::from(vec![None, None, None, None, None]);
    let stats = array.statistics();
    let null_bool = Scalar::null(DType::Bool(Nullability::Nullable));

    // Ordering / constancy flags.
    assert!(!stats.compute_is_strict_sorted().unwrap());
    assert!(stats.compute_is_sorted().unwrap());
    assert!(stats.compute_is_constant().unwrap());

    // Both extremes are null.
    assert_eq!(stats.compute(Stat::Min).unwrap(), null_bool);
    assert_eq!(stats.compute(Stat::Max).unwrap(), null_bool);

    // Counts.
    assert_eq!(stats.compute_run_count().unwrap(), 1);
    assert_eq!(stats.compute_true_count().unwrap(), 0);
}
}
6 changes: 3 additions & 3 deletions vortex-array/src/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ pub trait CompareFn {
fn compare(&self, array: &Array, predicate: Operator) -> VortexResult<Array>;
}

pub fn compare(array: &Array, other: &Array, predicate: Operator) -> VortexResult<Array> {
pub fn compare(array: &Array, other: &Array, operator: Operator) -> VortexResult<Array> {
if let Some(matching_indices) =
array.with_dyn(|c| c.compare().map(|t| t.compare(other, predicate)))
array.with_dyn(|c| c.compare().map(|t| t.compare(other, operator)))
{
return matching_indices;
}
Expand All @@ -19,7 +19,7 @@ pub fn compare(array: &Array, other: &Array, predicate: Operator) -> VortexResul
match array.dtype() {
DType::Primitive(..) => {
let flat = array.clone().flatten_primitive()?;
flat.compare(other, predicate)
flat.compare(other, operator)
}
_ => Err(vortex_err!(
NotImplemented: "compare",
Expand Down

0 comments on commit 7a1a434

Please sign in to comment.