Skip to content

Commit

Permalink
More specialized compare functions (#488)
Browse files Browse the repository at this point in the history
Doesn't quite get us where we want performance-wise, but does seem much
better.
Follow up of #481.

---------

Co-authored-by: Robert Kruszewski <[email protected]>
  • Loading branch information
AdamGS and robert3005 authored Jul 22, 2024
1 parent 87221d3 commit 0e19847
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 17 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 39 additions & 2 deletions vortex-array/src/array/constant/compute.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
use std::cmp::Ordering;
use std::sync::Arc;

use arrow_array::Datum;
use vortex_dtype::Nullability;
use vortex_error::{vortex_bail, VortexResult};
use vortex_expr::Operator;
use vortex_scalar::Scalar;

use crate::array::constant::ConstantArray;
use crate::arrow::FromArrowArray;
use crate::compute::unary::{scalar_at, ScalarAtFn};
use crate::compute::{
AndFn, ArrayCompute, OrFn, SearchResult, SearchSortedFn, SearchSortedSide, SliceFn, TakeFn,
scalar_cmp, AndFn, ArrayCompute, CompareFn, OrFn, SearchResult, SearchSortedFn,
SearchSortedSide, SliceFn, TakeFn,
};
use crate::stats::{ArrayStatistics, Stat};
use crate::{Array, ArrayDType, AsArray, IntoArray};
use crate::{Array, ArrayDType, ArrayData, AsArray, IntoArray, IntoCanonical};

impl ArrayCompute for ConstantArray {
fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
Expand All @@ -29,6 +34,10 @@ impl ArrayCompute for ConstantArray {
Some(self)
}

fn compare(&self) -> Option<&dyn CompareFn> {
Some(self)
}

fn and(&self) -> Option<&dyn AndFn> {
Some(self)
}
Expand Down Expand Up @@ -69,6 +78,34 @@ impl SearchSortedFn for ConstantArray {
}
}

impl CompareFn for ConstantArray {
fn compare(&self, rhs: &Array, operator: Operator) -> VortexResult<Array> {
if let Some(true) = rhs.statistics().get_as::<bool>(Stat::IsConstant) {
let lhs = self.scalar();
let rhs = scalar_at(rhs, 0)?;

let scalar = scalar_cmp(lhs, &rhs, operator);

Ok(ConstantArray::new(scalar, self.len()).into_array())
} else {
let datum = Arc::<dyn Datum>::from(self.scalar());
let rhs = rhs.clone().into_canonical()?.into_arrow();
let rhs = rhs.as_ref();

let boolean_array = match operator {
Operator::Eq => arrow_ord::cmp::eq(datum.as_ref(), &rhs)?,
Operator::NotEq => arrow_ord::cmp::neq(datum.as_ref(), &rhs)?,
Operator::Gt => arrow_ord::cmp::gt(datum.as_ref(), &rhs)?,
Operator::Gte => arrow_ord::cmp::gt_eq(datum.as_ref(), &rhs)?,
Operator::Lt => arrow_ord::cmp::lt(datum.as_ref(), &rhs)?,
Operator::Lte => arrow_ord::cmp::lt_eq(datum.as_ref(), &rhs)?,
};

Ok(ArrayData::from_arrow(&boolean_array, true).into_array())
}
}
}

impl AndFn for ConstantArray {
fn and(&self, array: &Array) -> VortexResult<Array> {
constant_array_bool_impl(
Expand Down
1 change: 0 additions & 1 deletion vortex-array/src/array/varbin/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ impl<O: NativePType> VarBinBuilder<O> {
pub fn finish(mut self, dtype: DType) -> VarBinArray {
let offsets = PrimitiveArray::from(self.offsets);
let data = PrimitiveArray::from_bytes(self.data.freeze(), Validity::NonNullable);

let nulls = self.validity.finish();

let validity = if dtype.is_nullable() {
Expand Down
2 changes: 2 additions & 0 deletions vortex-array/src/compute/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub fn and(lhs: &Array, rhs: &Array) -> VortexResult<Array> {
return selection;
}

// If neither side implements `AndFn`, we try to expand the left-hand side into a `BoolArray`, which we know does implement it, and call into that implementation.
let lhs = lhs.clone().into_bool()?;

lhs.and(rhs)
Expand All @@ -49,6 +50,7 @@ pub fn or(lhs: &Array, rhs: &Array) -> VortexResult<Array> {
return selection;
}

// If neither side implements `OrFn`, we try to expand the left-hand side into a `BoolArray`, which we know does implement it, and call into that implementation.
let lhs = lhs.clone().into_bool()?;

lhs.or(rhs)
Expand Down
39 changes: 37 additions & 2 deletions vortex-array/src/compute/compare.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,39 @@
use arrow_ord::cmp;
use vortex_error::VortexResult;
use vortex_dtype::{DType, Nullability};
use vortex_error::{vortex_bail, VortexResult};
use vortex_expr::Operator;
use vortex_scalar::Scalar;

use crate::arrow::FromArrowArray;
use crate::{Array, ArrayData, IntoArray, IntoCanonical};
use crate::{Array, ArrayDType, ArrayData, IntoArray, IntoCanonical};

pub trait CompareFn {
fn compare(&self, array: &Array, operator: Operator) -> VortexResult<Array>;
}

pub fn compare(left: &Array, right: &Array, operator: Operator) -> VortexResult<Array> {
if left.len() != right.len() {
vortex_bail!("Compare operations only support arrays of the same length");
}

// TODO(adamg): This is a placeholder until we figure out type coercion and casting
if !left.dtype().eq_ignore_nullability(right.dtype()) {
vortex_bail!("Compare operations only support arrays of the same type");
}

if let Some(selection) =
left.with_dyn(|lhs| lhs.compare().map(|lhs| lhs.compare(right, operator)))
{
return selection;
}

if let Some(selection) = right.with_dyn(|rhs| {
rhs.compare()
.map(|rhs| rhs.compare(left, operator.inverse()))
}) {
return selection;
}

// Fallback to arrow on canonical types
let lhs = left.clone().into_canonical()?.into_arrow();
let rhs = right.clone().into_canonical()?.into_arrow();
Expand All @@ -31,3 +49,20 @@ pub fn compare(left: &Array, right: &Array, operator: Operator) -> VortexResult<

Ok(ArrayData::from_arrow(&array, true).into_array())
}

pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
if lhs.is_null() | rhs.is_null() {
Scalar::null(DType::Bool(Nullability::Nullable))
} else {
let b = match operator {
Operator::Eq => lhs == rhs,
Operator::NotEq => lhs != rhs,
Operator::Gt => lhs > rhs,
Operator::Gte => lhs >= rhs,
Operator::Lt => lhs < rhs,
Operator::Lte => lhs <= rhs,
};

Scalar::bool(b, Nullability::Nullable)
}
}
2 changes: 1 addition & 1 deletion vortex-array/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
//! from Arrow.
pub use boolean::{and, or, AndFn, OrFn};
pub use compare::{compare, CompareFn};
pub use compare::{compare, scalar_cmp, CompareFn};
pub use filter::{filter, FilterFn};
pub use filter_indices::{filter_indices, FilterIndicesFn};
pub use search_sorted::*;
Expand Down
15 changes: 4 additions & 11 deletions vortex-scalar/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ edition = { workspace = true }
rust-version = { workspace = true }

[dependencies]
arrow-array = { workspace = true }
datafusion-common = { workspace = true, optional = true }
flatbuffers = { workspace = true, optional = true }
flexbuffers = { workspace = true, optional = true }
Expand Down Expand Up @@ -42,15 +43,7 @@ flatbuffers = [
"dep:serde",
"vortex-buffer/flexbuffers",
"vortex-error/flexbuffers",
"vortex-dtype/flatbuffers"
]
proto = [
"dep:prost",
"dep:prost-types",
"vortex-dtype/proto",
]
serde = [
"dep:serde",
"serde/derive",
"vortex-dtype/serde"
"vortex-dtype/flatbuffers",
]
proto = ["dep:prost", "dep:prost-types", "vortex-dtype/proto"]
serde = ["dep:serde", "serde/derive", "vortex-dtype/serde"]
78 changes: 78 additions & 0 deletions vortex-scalar/src/arrow.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use std::sync::Arc;

use arrow_array::*;
use vortex_dtype::{DType, PType};

use crate::{PValue, Scalar};

impl From<&Scalar> for Arc<dyn Datum> {
fn from(value: &Scalar) -> Arc<dyn Datum> {
match value.dtype {
DType::Null => Arc::new(NullArray::new(1)),
DType::Bool(_) => match value.value.as_bool().expect("should be bool") {
Some(b) => Arc::new(BooleanArray::new_scalar(b)),
None => Arc::new(BooleanArray::new_null(1)),
},
DType::Primitive(ptype, _) => {
let pvalue = value.value.as_pvalue().expect("should be pvalue");
match pvalue {
None => match ptype {
PType::U8 => Arc::new(UInt8Array::new_null(1)),
PType::U16 => Arc::new(UInt16Array::new_null(1)),
PType::U32 => Arc::new(UInt32Array::new_null(1)),
PType::U64 => Arc::new(UInt64Array::new_null(1)),
PType::I8 => Arc::new(Int8Array::new_null(1)),
PType::I16 => Arc::new(Int16Array::new_null(1)),
PType::I32 => Arc::new(Int32Array::new_null(1)),
PType::I64 => Arc::new(Int64Array::new_null(1)),
PType::F16 => Arc::new(Float16Array::new_null(1)),
PType::F32 => Arc::new(Float32Array::new_null(1)),
PType::F64 => Arc::new(Float64Array::new_null(1)),
},
Some(pvalue) => match pvalue {
PValue::U8(v) => Arc::new(UInt8Array::new_scalar(v)),
PValue::U16(v) => Arc::new(UInt16Array::new_scalar(v)),
PValue::U32(v) => Arc::new(UInt32Array::new_scalar(v)),
PValue::U64(v) => Arc::new(UInt64Array::new_scalar(v)),
PValue::I8(v) => Arc::new(Int8Array::new_scalar(v)),
PValue::I16(v) => Arc::new(Int16Array::new_scalar(v)),
PValue::I32(v) => Arc::new(Int32Array::new_scalar(v)),
PValue::I64(v) => Arc::new(Int64Array::new_scalar(v)),
PValue::F16(v) => Arc::new(Float16Array::new_scalar(v)),
PValue::F32(v) => Arc::new(Float32Array::new_scalar(v)),
PValue::F64(v) => Arc::new(Float64Array::new_scalar(v)),
},
}
}
DType::Utf8(_) => {
match value
.value
.as_buffer_string()
.expect("should be buffer string")
{
Some(s) => Arc::new(StringArray::new_scalar(s.as_str())),
None => Arc::new(StringArray::new_null(1)),
}
}
DType::Binary(_) => {
match value
.value
.as_buffer_string()
.expect("should be buffer string")
{
Some(s) => Arc::new(BinaryArray::new_scalar(s.as_bytes())),
None => Arc::new(BinaryArray::new_null(1)),
}
}
DType::Struct(..) => {
todo!("struct scalar conversion")
}
DType::List(..) => {
todo!("list scalar conversion")
}
DType::Extension(..) => {
todo!("extension scalar conversion")
}
}
}
}
1 change: 1 addition & 0 deletions vortex-scalar/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::cmp::Ordering;

use vortex_dtype::DType;

mod arrow;
mod binary;
mod bool;
mod datafusion;
Expand Down

0 comments on commit 0e19847

Please sign in to comment.