diff --git a/vortex-array/benches/compare.rs b/vortex-array/benches/compare.rs index 5728cc6984..fb5b3bcc7b 100644 --- a/vortex-array/benches/compare.rs +++ b/vortex-array/benches/compare.rs @@ -12,23 +12,12 @@ fn compare_bool(c: &mut Criterion) { let mut group = c.benchmark_group("compare"); let mut rng = thread_rng(); - let range = Uniform::new(0u8, 1); - let arr = BoolArray::from( - (0..10_000_000) - .map(|_| rng.sample(range) == 0) - .collect_vec(), - ) - .into_array(); - let arr2 = BoolArray::from( - (0..10_000_000) - .map(|_| rng.sample(range) == 0) - .collect_vec(), - ) - .into_array(); + let arr = BoolArray::from((0..10_000_000).map(|_| rng.gen()).collect_vec()).into_array(); + let arr2 = BoolArray::from((0..10_000_000).map(|_| rng.gen()).collect_vec()).into_array(); group.bench_function("compare_bool", |b| { b.iter(|| { - let indices = compare(&arr, &arr2, Operator::GreaterThanOrEqualTo).unwrap(); + let indices = compare(&arr, &arr2, Operator::LessThan).unwrap(); black_box(indices); Ok::<(), VortexError>(()) }); @@ -52,7 +41,7 @@ fn compare_int(c: &mut Criterion) { group.bench_function("compare_int", |b| { b.iter(|| { - let indices = compare(&arr, &arr2, Operator::GreaterThanOrEqualTo).unwrap(); + let indices = compare(&arr, &arr2, Operator::LessThan).unwrap(); black_box(indices); Ok::<(), VortexError>(()) }); diff --git a/vortex-array/benches/compare_scalar.rs b/vortex-array/benches/compare_scalar.rs index e6076f70f6..616e12f552 100644 --- a/vortex-array/benches/compare_scalar.rs +++ b/vortex-array/benches/compare_scalar.rs @@ -12,18 +12,11 @@ fn compare_bool_scalar(c: &mut Criterion) { let mut group = c.benchmark_group("compare_scalar"); let mut rng = thread_rng(); - let range = Uniform::new(0u8, 1); - let arr = BoolArray::from( - (0..10_000_000) - .map(|_| rng.sample(range) == 0) - .collect_vec(), - ) - .into_array(); + let arr = BoolArray::from((0..10_000_000).map(|_| rng.gen()).collect_vec()).into_array(); group.bench_function("compare_bool", |b| { b.iter(|| { - let indices = - compare_scalar(&arr, Operator::GreaterThanOrEqualTo, &false.into()).unwrap(); + let indices = compare_scalar(&arr, Operator::LessThan, &false.into()).unwrap(); black_box(indices); Ok::<(), VortexError>(()) }); @@ -42,8 +35,7 @@ fn compare_int_scalar(c: &mut Criterion) { group.bench_function("compare_int", |b| { b.iter(|| { - let indices = - compare_scalar(&arr, Operator::GreaterThanOrEqualTo, &50_000_000.into()).unwrap(); + let indices = compare_scalar(&arr, Operator::LessThan, &50_000_000.into()).unwrap(); black_box(indices); Ok::<(), VortexError>(()) }); diff --git a/vortex-array/benches/filter_indices.rs b/vortex-array/benches/filter_indices.rs index 7f5e7127ff..efb7b6fd5f 100644 --- a/vortex-array/benches/filter_indices.rs +++ b/vortex-array/benches/filter_indices.rs @@ -1,15 +1,20 @@ +use std::sync::Arc; + use criterion::{black_box, criterion_group, criterion_main, Criterion}; use itertools::Itertools; use rand::distributions::Uniform; use rand::{thread_rng, Rng}; +use vortex::array::r#struct::StructArray; +use vortex::validity::Validity; use vortex::IntoArray; -use vortex_dtype::field_paths::FieldPath; +use vortex_dtype::field_paths::{field, FieldPath}; use vortex_error::VortexError; use vortex_expr::expressions::{lit, Conjunction, Disjunction}; use vortex_expr::field_paths::FieldPathOperations; +use vortex_expr::operators::{field_comparison, Operator}; -fn filter_indices(c: &mut Criterion) { - let mut group = c.benchmark_group("filter_indices"); +fn filter_indices_primitive(c: &mut Criterion) { + let mut group = c.benchmark_group("filter_indices_primitive"); let mut rng = thread_rng(); let range = Uniform::new(0i64, 100_000_000); @@ -34,5 +39,39 @@ fn filter_indices(c: &mut Criterion) { }); } -criterion_group!(benches, filter_indices); +fn filter_indices_struct(c: &mut Criterion) { + let mut group = c.benchmark_group("filter_indices_struct"); + + let mut rng = thread_rng(); + let range = Uniform::new(0i64, 100_000_000); + let arr = (0..10_000_000) + .map(|_| rng.sample(range)) + .collect_vec() + .into_array(); + let arr2 = (0..10_000_000) + .map(|_| rng.sample(range)) + .collect_vec() + .into_array(); + + let structs = StructArray::try_new( + Arc::new([Arc::from("field_a"), Arc::from("field_b")]), + vec![arr, arr2.clone()], + arr2.len(), + Validity::AllValid, + ) + .unwrap() + .into_array(); + let predicate = field_comparison(Operator::LessThan, field("field_a"), field("field_b")); + + group.bench_function("vortex", |b| { + b.iter(|| { + let indices = + vortex::compute::filter_indices::filter_indices(&structs, &predicate).unwrap(); + black_box(indices); + Ok::<(), VortexError>(()) + }); + }); +} + +criterion_group!(benches, filter_indices_primitive, filter_indices_struct); criterion_main!(benches); diff --git a/vortex-array/src/array/bool/compute/compare.rs b/vortex-array/src/array/bool/compute/compare.rs index 8417edf7cc..9b66e77ad8 100644 --- a/vortex-array/src/array/bool/compute/compare.rs +++ b/vortex-array/src/array/bool/compute/compare.rs @@ -1,9 +1,9 @@ -use std::ops::{BitAnd, BitOr, BitXor, Not}; +use std::ops::BitAnd; use vortex_error::{vortex_err, VortexResult}; use vortex_expr::operators::Operator; -use crate::array::bool::BoolArray; +use crate::array::bool::{apply_comparison_op, BoolArray}; use crate::compute::compare::CompareFn; use crate::{Array, ArrayTrait, IntoArray}; @@ -15,21 +15,14 @@ impl CompareFn for BoolArray { .map_err(|_| vortex_err!("Cannot compare boolean array with non-boolean array"))?; let lhs = self.boolean_buffer(); let rhs = flattened.boolean_buffer(); - let result_buf = match op { - Operator::EqualTo => lhs.bitxor(&rhs).not(), - Operator::NotEqualTo => lhs.bitxor(&rhs), + let comparison_result = apply_comparison_op(lhs, rhs, op); - Operator::GreaterThan => lhs.bitand(&rhs.not()), - Operator::GreaterThanOrEqualTo => lhs.bitor(&rhs.not()), - Operator::LessThan => lhs.not().bitand(&rhs), - Operator::LessThanOrEqualTo => lhs.not().bitor(&rhs), - }; Ok(BoolArray::from( self.validity() .to_logical(self.len()) .to_null_buffer()? - .map(|nulls| result_buf.bitand(&nulls.into_inner())) - .unwrap_or(result_buf), + .map(|nulls| comparison_result.bitand(&nulls.into_inner())) + .unwrap_or(comparison_result), ) .into_array()) } diff --git a/vortex-array/src/array/bool/compute/compare_scalar.rs b/vortex-array/src/array/bool/compute/compare_scalar.rs index 75ad1cc3c8..b7d0cdbe39 100644 --- a/vortex-array/src/array/bool/compute/compare_scalar.rs +++ b/vortex-array/src/array/bool/compute/compare_scalar.rs @@ -1,4 +1,4 @@ -use std::ops::{BitAnd, BitOr, BitXor, Not}; +use std::ops::BitAnd; use arrow_buffer::BooleanBufferBuilder; use vortex_dtype::DType; @@ -6,20 +6,18 @@ use vortex_error::{vortex_bail, vortex_err, VortexResult}; use vortex_expr::operators::Operator; use vortex_scalar::Scalar; -use crate::array::bool::BoolArray; +use crate::array::bool::{apply_comparison_op, BoolArray}; use crate::compute::compare_scalar::CompareScalarFn; use crate::{Array, ArrayTrait, IntoArray}; impl CompareScalarFn for BoolArray { fn compare_scalar(&self, op: Operator, scalar: &Scalar) -> VortexResult { - match scalar.dtype() { - DType::Bool(_) => {} - _ => { - vortex_bail!("Invalid dtype for boolean scalar comparison") - } + if let DType::Bool(_) = scalar.dtype() { + } else { + vortex_bail!("Invalid dtype for boolean scalar comparison") } - let lhs = self.boolean_buffer(); + let lhs = self.boolean_buffer(); let scalar_val = scalar .value() .as_bool()? @@ -28,22 +26,14 @@ impl CompareScalarFn for BoolArray { let mut rhs = BooleanBufferBuilder::new(self.len()); rhs.append_n(self.len(), scalar_val); let rhs = rhs.finish(); - let result_buf = match op { - Operator::EqualTo => lhs.bitxor(&rhs).not(), - Operator::NotEqualTo => lhs.bitxor(&rhs), - Operator::GreaterThan => lhs.bitand(&rhs.not()), - Operator::GreaterThanOrEqualTo => lhs.bitor(&rhs.not()), - Operator::LessThan => lhs.not().bitand(&rhs), - Operator::LessThanOrEqualTo => lhs.not().bitor(&rhs), - }; - - let present = self - .validity() - .to_logical(self.len()) - .to_present_null_buffer()? - .into_inner(); - - Ok(BoolArray::from(result_buf.bitand(&present)).into_array()) + let comparison_result = apply_comparison_op(lhs, rhs, op); + + let present = self.validity().to_logical(self.len()).to_null_buffer()?; + let with_validity_applied = present + .map(|p| comparison_result.bitand(&p.into_inner())) + .unwrap_or(comparison_result); + + Ok(BoolArray::from(with_validity_applied).into_array()) } } @@ -78,6 +68,21 @@ mod test { let matches = compare_scalar(&arr, Operator::NotEqualTo, &false.into())?.flatten_bool()?; assert_eq!(to_int_indices(matches), [1u64, 3]); + + let matches = compare_scalar(&arr, Operator::GreaterThan, &false.into())?.flatten_bool()?; + assert_eq!(to_int_indices(matches), [1u64, 3]); + + let matches = + compare_scalar(&arr, Operator::GreaterThanOrEqualTo, &false.into())?.flatten_bool()?; + assert_eq!(to_int_indices(matches), [1u64, 2, 3]); + + let matches = compare_scalar(&arr, Operator::LessThan, &false.into())?.flatten_bool()?; + let empty: [u64; 0] = []; + assert_eq!(to_int_indices(matches), empty); + + let matches = + compare_scalar(&arr, Operator::LessThanOrEqualTo, &false.into())?.flatten_bool()?; + assert_eq!(to_int_indices(matches), [2u64]); Ok(()) } } diff --git a/vortex-array/src/array/bool/mod.rs b/vortex-array/src/array/bool/mod.rs index e9b93c1f56..52081625d7 100644 --- a/vortex-array/src/array/bool/mod.rs +++ b/vortex-array/src/array/bool/mod.rs @@ -1,7 +1,10 @@ +use std::ops::{BitAnd, BitOr, BitXor, Not}; + use arrow_buffer::BooleanBuffer; use itertools::Itertools; use serde::{Deserialize, Serialize}; use vortex_buffer::Buffer; +use vortex_expr::operators::Operator; use crate::validity::{ArrayValidity, ValidityMetadata}; use crate::validity::{LogicalValidity, Validity}; @@ -57,6 +60,17 @@ impl BoolArray { } } +pub fn apply_comparison_op(lhs: BooleanBuffer, rhs: BooleanBuffer, op: Operator) -> BooleanBuffer { + match op { + Operator::EqualTo => lhs.bitxor(&rhs).not(), + Operator::NotEqualTo => lhs.bitxor(&rhs), + Operator::GreaterThan => lhs.bitand(&rhs.not()), + Operator::GreaterThanOrEqualTo => lhs.bitor(&rhs.not()), + Operator::LessThan => lhs.not().bitand(&rhs), + Operator::LessThanOrEqualTo => lhs.not().bitor(&rhs), + } +} + impl From for BoolArray { fn from(value: BooleanBuffer) -> Self { Self::try_new(value, Validity::NonNullable).unwrap() diff --git a/vortex-array/src/array/primitive/compute/compare.rs b/vortex-array/src/array/primitive/compute/compare.rs index 67eac4b9c5..bf6e2276b1 100644 --- a/vortex-array/src/array/primitive/compute/compare.rs +++ b/vortex-array/src/array/primitive/compute/compare.rs @@ -11,6 +11,8 @@ use crate::compute::compare::CompareFn; use crate::{Array, ArrayTrait, IntoArray}; impl CompareFn for PrimitiveArray { + // @TODO(@jcasale) take stats into account here, which may allow us to elide some comparison + // work based on sortedness/min/max/etc. fn compare(&self, other: &Array, predicate: Operator) -> VortexResult { let flattened = other .clone() @@ -22,18 +24,21 @@ impl CompareFn for PrimitiveArray { apply_predicate(self.typed_data::<$T>(), flattened.typed_data::<$T>(), predicate_fn) }); - let present = self - .validity() - .to_logical(self.len()) - .to_present_null_buffer()? - .into_inner(); + let present = self.validity().to_logical(self.len()).to_null_buffer()?; + let with_validity_applied = present + .map(|p| matching_idxs.bitand(&p.into_inner())) + .unwrap_or(matching_idxs); + let present_other = flattened .validity() .to_logical(self.len()) - .to_present_null_buffer()? - .into_inner(); + .to_null_buffer()?; + + let with_other_validity_applied = present_other + .map(|p| with_validity_applied.bitand(&p.into_inner())) + .unwrap_or(with_validity_applied); - Ok(BoolArray::from(matching_idxs.bitand(&present).bitand(&present_other)).into_array()) + Ok(BoolArray::from(with_other_validity_applied).into_array()) } } diff --git a/vortex-array/src/array/primitive/compute/compare_scalar.rs b/vortex-array/src/array/primitive/compute/compare_scalar.rs index 4d39477ec9..98ba504595 100644 --- a/vortex-array/src/array/primitive/compute/compare_scalar.rs +++ b/vortex-array/src/array/primitive/compute/compare_scalar.rs @@ -9,16 +9,17 @@ use vortex_scalar::Scalar; use crate::array::bool::BoolArray; use crate::array::primitive::PrimitiveArray; use crate::compute::compare_scalar::CompareScalarFn; -use crate::{Array, ArrayDType, ArrayTrait, IntoArray}; +use crate::{Array, ArrayTrait, IntoArray}; impl CompareScalarFn for PrimitiveArray { + // @TODO(@jcasale) take stats into account here, which may allow us to elide some comparison + // work based on sortedness/min/max/etc. fn compare_scalar(&self, op: Operator, scalar: &Scalar) -> VortexResult { - match self.dtype() { - DType::Primitive(..) => {} - _ => { - vortex_bail!("Invalid scalar dtype for primitive comparison") - } + if let DType::Primitive(..) = scalar.dtype() { + } else { + vortex_bail!("Invalid scalar dtype for boolean scalar comparison") } + let p_val = scalar .value() .as_pvalue()? @@ -29,13 +30,12 @@ impl CompareScalarFn for PrimitiveArray { apply_predicate(self.typed_data::<$T>(), &rhs, predicate_fn) }); - let present = self - .validity() - .to_logical(self.len()) - .to_present_null_buffer()? - .into_inner(); + let present = self.validity().to_logical(self.len()).to_null_buffer()?; + let with_validity_applied = present + .map(|p| matching_idxs.bitand(&p.into_inner())) + .unwrap_or(matching_idxs); - Ok(BoolArray::from(matching_idxs.bitand(&present)).into_array()) + Ok(BoolArray::from(with_validity_applied).into_array()) } } @@ -86,6 +86,9 @@ mod test { let matches = compare_scalar(&arr, Operator::EqualTo, &5.into())?.flatten_bool()?; assert_eq!(to_int_indices(matches), [5u64]); + let matches = compare_scalar(&arr, Operator::NotEqualTo, &5.into())?.flatten_bool()?; + assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 6, 7, 8, 10]); + let matches = compare_scalar(&arr, Operator::EqualTo, &11.into())?.flatten_bool()?; let empty: [u64; 0] = []; assert_eq!(to_int_indices(matches), empty); diff --git a/vortex-array/src/array/primitive/compute/filter_indices.rs b/vortex-array/src/array/primitive/compute/filter_indices.rs index a49834d5d3..ec219c8e0d 100644 --- a/vortex-array/src/array/primitive/compute/filter_indices.rs +++ b/vortex-array/src/array/primitive/compute/filter_indices.rs @@ -1,14 +1,15 @@ use std::ops::{BitAnd, BitOr}; use arrow_buffer::BooleanBuffer; -use vortex_dtype::{match_each_native_ptype, NativePType}; use vortex_error::{vortex_bail, VortexResult}; use vortex_expr::expressions::{Disjunction, Predicate, Value}; use crate::array::bool::BoolArray; use crate::array::primitive::PrimitiveArray; +use crate::compute::compare::CompareFn; +use crate::compute::compare_scalar::CompareScalarFn; use crate::compute::filter_indices::FilterIndicesFn; -use crate::{Array, ArrayTrait, IntoArray}; +use crate::{Array, ArrayDType, ArrayTrait, IntoArray}; impl FilterIndicesFn for PrimitiveArray { fn filter_indices(&self, predicate: &Disjunction) -> VortexResult { @@ -42,39 +43,32 @@ fn indices_matching_predicate( vortex_bail!("Invalid path for primitive array") } - let rhs = match &predicate.right { - Value::Field(_) => { - vortex_bail!("Cannot apply field reference to primitive array") + match &predicate.right { + Value::Field(path) => { + let rhs = arr.clone().into_array().resolve_field(arr.dtype(), path)?; + arr.compare(&rhs, predicate.op)? + .flatten_bool() + .map(|arr| arr.boolean_buffer()) } - Value::Literal(scalar) => scalar, - }; - - let matching_idxs = match_each_native_ptype!(arr.ptype(), |$T| { - let rhs_typed: $T = rhs.try_into().unwrap(); - let predicate_fn = &predicate.op.to_predicate::<$T>(); - apply_predicate(arr.typed_data::<$T>(), &rhs_typed, predicate_fn) - }); - - Ok(matching_idxs) -} - -fn apply_predicate bool>( - lhs: &[T], - rhs: &T, - f: F, -) -> BooleanBuffer { - let matches = lhs.iter().map(|lhs| f(lhs, rhs)); - BooleanBuffer::from_iter(matches) + Value::Literal(scalar) => arr + .compare_scalar(predicate.op, scalar)? + .flatten_bool() + .map(|arr| arr.boolean_buffer()), + } } #[cfg(test)] mod test { + use std::sync::Arc; + use itertools::Itertools; - use vortex_dtype::field_paths::FieldPathBuilder; + use vortex_dtype::field_paths::{field, FieldPathBuilder}; use vortex_expr::expressions::{lit, Conjunction}; use vortex_expr::field_paths::FieldPathOperations; + use vortex_expr::operators::{field_comparison, Operator}; use super::*; + use crate::array::r#struct::StructArray; use crate::validity::Validity; fn apply_conjunctive_filter(arr: &PrimitiveArray, conj: Conjunction) -> VortexResult { @@ -242,4 +236,47 @@ mod test { ) .expect_err("Cannot apply field reference to primitive array"); } + + #[test] + fn test_basic_field_comparisons() -> VortexResult<()> { + let ints = + PrimitiveArray::from_nullable_vec(vec![Some(0u64), Some(1), None, Some(3), Some(4)]); + let other = + PrimitiveArray::from_nullable_vec(vec![Some(0u64), Some(2), None, Some(5), Some(1)]); + + let structs = StructArray::try_new( + Arc::new([Arc::from("field_a"), Arc::from("field_b")]), + vec![ints.into_array(), other.clone().into_array()], + 5, + Validity::AllValid, + )?; + + fn comparison(op: Operator) -> Disjunction { + field_comparison(op, field("field_a"), field("field_b")) + } + + let matches = FilterIndicesFn::filter_indices(&structs, &comparison(Operator::EqualTo))? + .flatten_bool()?; + assert_eq!(to_int_indices(matches), [0]); + + let matches = FilterIndicesFn::filter_indices(&structs, &comparison(Operator::LessThan))? + .flatten_bool()?; + assert_eq!(to_int_indices(matches), [1, 3]); + + let matches = + FilterIndicesFn::filter_indices(&structs, &comparison(Operator::LessThanOrEqualTo))? + .flatten_bool()?; + assert_eq!(to_int_indices(matches), [0, 1, 3]); + + let matches = + FilterIndicesFn::filter_indices(&structs, &comparison(Operator::GreaterThan))? + .flatten_bool()?; + assert_eq!(to_int_indices(matches), [4]); + + let matches = + FilterIndicesFn::filter_indices(&structs, &comparison(Operator::GreaterThanOrEqualTo))? + .flatten_bool()?; + assert_eq!(to_int_indices(matches), [0, 4]); + Ok(()) + } } diff --git a/vortex-array/src/array/struct/compute.rs b/vortex-array/src/array/struct/compute.rs index 392ebe88a9..e7f1b751c6 100644 --- a/vortex-array/src/array/struct/compute.rs +++ b/vortex-array/src/array/struct/compute.rs @@ -7,9 +7,7 @@ use arrow_array::{ use arrow_buffer::BooleanBuffer; use arrow_schema::{Field, Fields}; use itertools::Itertools; -use vortex_dtype::field_paths::{FieldIdentifier, FieldPath}; -use vortex_dtype::DType; -use vortex_error::{vortex_bail, vortex_err, VortexResult}; +use vortex_error::VortexResult; use vortex_expr::expressions::{Conjunction, Disjunction, Predicate, Value}; use vortex_scalar::Scalar; @@ -179,57 +177,41 @@ impl FilterIndicesFn for StructArray { } fn indices_matching_predicate(arr: &StructArray, pred: &Predicate) -> VortexResult { - let inner = resolve_field(arr.clone().into_array(), arr.dtype(), &pred.left)?; + let inner = arr + .clone() + .into_array() + .resolve_field(arr.dtype(), &pred.left)?; match &pred.right { Value::Field(rh_field) => { - let rhs = resolve_field(arr.clone().into_array(), arr.dtype(), rh_field)?; + let rhs = arr + .clone() + .into_array() + .resolve_field(arr.dtype(), rh_field)?; Ok(compare(&inner, &rhs, pred.op)? .flatten_bool()? .boolean_buffer()) } Value::Literal(_) => { - let conjunction = Conjunction { + let conj = Conjunction { predicates: vec![pred.clone()], }; - let d = Disjunction { - conjunctions: vec![conjunction], + let disj = Disjunction { + conjunctions: vec![conj], }; - Ok(filter_indices(&inner, &d)?.flatten_bool()?.boolean_buffer()) - } - } -} - -fn resolve_field(array: Array, dtype: &DType, path: &FieldPath) -> VortexResult { - match dtype { - DType::Struct(struct_dtype, _) => { - let current = path.head().ok_or_else(|| vortex_err!(""))?; - if let FieldIdentifier::Name(field_name) = current { - let idx = struct_dtype - .find_name(field_name.as_str()) - .ok_or_else(|| vortex_err!("Query not compatible with dtype"))?; - let inner_dtype = struct_dtype.dtypes().get(idx).unwrap(); - let inner_name = path.tail().ok_or_else(|| vortex_err!(""))?; - resolve_field( - array.child(idx, inner_dtype).unwrap(), - inner_dtype, - &inner_name, - ) - } else { - vortex_bail!("Query not compatible with dtype") - } + Ok(filter_indices(&inner, &disj)? + .flatten_bool()? + .boolean_buffer()) } - _ => Ok(array), } } #[cfg(test)] mod test { use itertools::Itertools; - use vortex_dtype::field_paths::field; - use vortex_dtype::{Nullability, PType, StructDType}; - use vortex_expr::expressions::Value::Field; - use vortex_expr::operators::Operator; + use vortex_dtype::field_paths::{field, FieldPath}; + use vortex_dtype::{DType, Nullability, PType, StructDType}; + use vortex_expr::operators::{field_comparison, Operator}; use super::*; use crate::array::primitive::PrimitiveArray; @@ -248,16 +230,8 @@ mod test { filtered } - fn comparison(op: Operator, left: FieldPath, right: FieldPath) -> Disjunction { - Disjunction { - conjunctions: vec![Conjunction { - predicates: vec![Predicate { - left, - op, - right: Field(right), - }], - }], - } + fn comparison(op: Operator) -> Disjunction { + field_comparison(op, field("field_a"), field("field_b")) } #[test] @@ -273,9 +247,9 @@ mod test { Nullability::NonNullable, ); - let ints = + let ints_a = PrimitiveArray::from_nullable_vec(vec![Some(0u64), Some(1), None, Some(3), Some(4)]); - let other = + let ints_b = PrimitiveArray::from_nullable_vec(vec![Some(0u64), Some(2), None, Some(5), Some(1)]); let structs = StructArray::try_from_parts( @@ -285,24 +259,12 @@ mod test { validity: ValidityMetadata::AllValid, }, Arc::new([ - ints.clone().into_array_data(), - other.clone().into_array_data(), + ints_a.clone().into_array_data(), + ints_b.clone().into_array_data(), ]), StatsSet::new(), )?; - fn comparison(op: Operator) -> Disjunction { - Disjunction { - conjunctions: vec![Conjunction { - predicates: vec![Predicate { - left: field("field_a"), - op, - right: Field(field("field_b")), - }], - }], - } - } - let matches = FilterIndicesFn::filter_indices(&structs, &comparison(Operator::EqualTo))? .flatten_bool()?; assert_eq!(to_int_indices(matches), [0]); @@ -352,9 +314,9 @@ mod test { Nullability::NonNullable, ); - let ints = + let ints_a = PrimitiveArray::from_nullable_vec(vec![Some(0u64), Some(1), None, Some(3), Some(4)]); - let other = + let other_b = PrimitiveArray::from_nullable_vec(vec![Some(0u64), Some(2), None, Some(5), Some(1)]); let structs = StructArray::try_from_parts( @@ -364,8 +326,8 @@ mod test { validity: ValidityMetadata::AllValid, }, Arc::new([ - ints.clone().into_array_data(), - other.clone().into_array_data(), + ints_a.clone().into_array_data(), + other_b.clone().into_array_data(), ]), StatsSet::new(), )?; @@ -378,7 +340,7 @@ mod test { }, Arc::new([ structs.clone().into_array_data(), - other.clone().into_array_data(), + other_b.clone().into_array_data(), ]), StatsSet::new(), )?; @@ -393,7 +355,7 @@ mod test { let mixed_level_cmp = |op: Operator| -> VortexResult { FilterIndicesFn::filter_indices( top_level_structs, - &comparison( + &field_comparison( op, FieldPath::builder().join("struct").join("field_a").build(), field("flat"), @@ -422,7 +384,7 @@ mod test { let nested_cmp = |op: Operator| -> VortexResult { FilterIndicesFn::filter_indices( top_level_structs, - &comparison( + &field_comparison( op, FieldPath::builder().join("struct").join("field_a").build(), FieldPath::builder().join("struct").join("field_b").build(), diff --git a/vortex-array/src/compute/compare_scalar.rs b/vortex-array/src/compute/compare_scalar.rs index 517118322f..11244ddaea 100644 --- a/vortex-array/src/compute/compare_scalar.rs +++ b/vortex-array/src/compute/compare_scalar.rs @@ -16,7 +16,7 @@ pub fn compare_scalar(array: &Array, comparator: Operator, scalar: &Scalar) -> V }) { return matching_indices; } - // if compare is not implemented for the given array type, but the array has a numeric + // if compare_scalar is not implemented for the given array type, but the array has a numeric // DType, we can flatten the array and apply filter to the flattened primitive array match array.dtype() { DType::Primitive(..) => { @@ -24,7 +24,7 @@ pub fn compare_scalar(array: &Array, comparator: Operator, scalar: &Scalar) -> V flat.compare_scalar(comparator, scalar) } _ => Err(vortex_err!( - NotImplemented: "compare", + NotImplemented: "compare_scalar", array.encoding().id() )), } diff --git a/vortex-array/src/lib.rs b/vortex-array/src/lib.rs index 134d91781f..fca45ec1b9 100644 --- a/vortex-array/src/lib.rs +++ b/vortex-array/src/lib.rs @@ -32,8 +32,9 @@ pub use metadata::*; pub use typed::*; pub use view::*; use vortex_buffer::Buffer; +use vortex_dtype::field_paths::{FieldIdentifier, FieldPath}; use vortex_dtype::DType; -use vortex_error::VortexResult; +use vortex_error::{vortex_bail, vortex_err, VortexResult}; use crate::compute::ArrayCompute; use crate::encoding::{ArrayEncodingRef, EncodingRef}; @@ -59,6 +60,7 @@ pub mod flatbuffers { #[allow(unused_imports)] pub use vortex_dtype::flatbuffers as dtype; } + pub mod scalar { #[allow(unused_imports)] pub use vortex_scalar::flatbuffers as scalar; @@ -123,6 +125,44 @@ impl Array { futures_util::stream::once(ready(Ok(self))), ) } + + pub fn resolve_field(self, dtype: &DType, path: &FieldPath) -> VortexResult { + match dtype { + DType::Struct(struct_dtype, _) => { + let current = path + .head() + .ok_or_else(|| vortex_err!("Invalid path for struct array"))?; + if let FieldIdentifier::Name(field_name) = current { + let idx = struct_dtype + .find_name(field_name.as_str()) + .ok_or_else(|| vortex_err!("Query not compatible with dtype"))?; + let inner_dtype = struct_dtype + .dtypes() + .get(idx) + .expect("Looking up known index should never fail"); + let inner_name = path + .tail() + .ok_or_else(|| vortex_err!("Invalid path for struct array"))?; + self.child(idx, inner_dtype) + .ok_or_else(|| vortex_err!("Invalid dtype for array"))? + .resolve_field(inner_dtype, &inner_name) + } else { + vortex_bail!("Query not compatible with dtype") + } + } + DType::List(..) => { + // TODO(@jcasale): resolve list fields in a follow-on + vortex_bail!(NotImplemented: "Resolving list fields not yet implemented", self.dtype()) + } + _ => { + if path.head().is_none() { + Ok(self) + } else { + vortex_bail!("Invalid path for non-nested array") + } + } + } + } } pub trait ToArray { @@ -177,6 +217,7 @@ pub trait ArrayDType { } struct NBytesVisitor(usize); + impl ArrayVisitor for NBytesVisitor { fn visit_child(&mut self, _name: &str, array: &Array) -> VortexResult<()> { self.0 += array.with_dyn(|a| a.nbytes()); diff --git a/vortex-expr/src/operators.rs b/vortex-expr/src/operators.rs index b6c6c289af..932260cdb4 100644 --- a/vortex-expr/src/operators.rs +++ b/vortex-expr/src/operators.rs @@ -1,8 +1,9 @@ use std::ops; +use vortex_dtype::field_paths::FieldPath; use vortex_dtype::NativePType; -use crate::expressions::Predicate; +use crate::expressions::{Conjunction, Disjunction, Predicate, Value}; #[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -59,3 +60,15 @@ impl Operator { } } } + +pub fn field_comparison(op: Operator, left: FieldPath, right: FieldPath) -> Disjunction { + Disjunction { + conjunctions: vec![Conjunction { + predicates: vec![Predicate { + left, + op, + right: Value::Field(right), + }], + }], + } +}