Skip to content

Commit

Permalink
DataFusion expr conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
gatesn committed Jun 12, 2024
1 parent d416f69 commit 2fef12f
Show file tree
Hide file tree
Showing 16 changed files with 714 additions and 274 deletions.
711 changes: 506 additions & 205 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ cargo_metadata = "0.18.1"
criterion = { version = "0.5.1", features = ["html_reports"] }
croaring = "1.0.1"
csv = "1.3.0"
datafusion-common = "39.0.0"
datafusion-expr = "39.0.0"
derive_builder = "0.20.0"
divan = "0.1.14"
duckdb = { version = "0.10.1", features = ["bundled"] }
Expand Down
8 changes: 2 additions & 6 deletions vortex-array/benches/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ fn filter_bool_indices(c: &mut Criterion) {

group.bench_function("compare_bool", |b| {
b.iter(|| {
let indices =
vortex::compute::compare::compare(&arr, &arr2, Operator::GreaterThanOrEqualTo)
.unwrap();
let indices = vortex::compute::compare::compare(&arr, &arr2, Operator::Gte).unwrap();
black_box(indices);
Ok::<(), VortexError>(())
});
Expand All @@ -53,9 +51,7 @@ fn filter_indices(c: &mut Criterion) {

group.bench_function("compare_int", |b| {
b.iter(|| {
let indices =
vortex::compute::compare::compare(&arr, &arr2, Operator::GreaterThanOrEqualTo)
.unwrap();
let indices = vortex::compute::compare::compare(&arr, &arr2, Operator::Gte).unwrap();
black_box(indices);
Ok::<(), VortexError>(())
});
Expand Down
24 changes: 12 additions & 12 deletions vortex-array/src/array/bool/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ impl CompareFn for BoolArray {
let lhs = self.boolean_buffer();
let rhs = flattened.boolean_buffer();
let result_buf = match op {
Operator::EqualTo => lhs.bitxor(&rhs).not(),
Operator::NotEqualTo => lhs.bitxor(&rhs),
Operator::Eq => lhs.bitxor(&rhs).not(),
Operator::NotEq => lhs.bitxor(&rhs),

Operator::GreaterThan => lhs.bitand(&rhs.not()),
Operator::GreaterThanOrEqualTo => lhs.bitor(&rhs.not()),
Operator::LessThan => lhs.not().bitand(&rhs),
Operator::LessThanOrEqualTo => lhs.not().bitor(&rhs),
Operator::Gt => lhs.bitand(&rhs.not()),
Operator::Gte => lhs.bitor(&rhs.not()),
Operator::Lt => lhs.not().bitand(&rhs),
Operator::Lte => lhs.not().bitor(&rhs),
};
Ok(BoolArray::from(
self.validity()
Expand Down Expand Up @@ -58,10 +58,10 @@ mod test {
)
.into_array();

let matches = compare(&arr, &arr, Operator::EqualTo)?.flatten_bool()?;
let matches = compare(&arr, &arr, Operator::Eq)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [1u64, 2, 3, 4]);

let matches = compare(&arr, &arr, Operator::NotEqualTo)?.flatten_bool()?;
let matches = compare(&arr, &arr, Operator::NotEq)?.flatten_bool()?;
let empty: [u64; 0] = [];
assert_eq!(to_int_indices(matches), empty);

Expand All @@ -71,16 +71,16 @@ mod test {
)
.into_array();

let matches = compare(&arr, &other, Operator::LessThanOrEqualTo)?.flatten_bool()?;
let matches = compare(&arr, &other, Operator::Lte)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [2u64, 3, 4]);

let matches = compare(&arr, &other, Operator::LessThan)?.flatten_bool()?;
let matches = compare(&arr, &other, Operator::Lt)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [4u64]);

let matches = compare(&other, &arr, Operator::GreaterThanOrEqualTo)?.flatten_bool()?;
let matches = compare(&other, &arr, Operator::Gte)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [2u64, 3, 4]);

let matches = compare(&other, &arr, Operator::GreaterThan)?.flatten_bool()?;
let matches = compare(&other, &arr, Operator::Gt)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [4u64]);
Ok(())
}
Expand Down
12 changes: 6 additions & 6 deletions vortex-array/src/array/primitive/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ mod test {
])
.into_array();

let matches = compare(&arr, &arr, Operator::EqualTo)?.flatten_bool()?;
let matches = compare(&arr, &arr, Operator::Eq)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]);

let matches = compare(&arr, &arr, Operator::NotEqualTo)?.flatten_bool()?;
let matches = compare(&arr, &arr, Operator::NotEq)?.flatten_bool()?;
let empty: [u64; 0] = [];
assert_eq!(to_int_indices(matches), empty);

Expand All @@ -101,16 +101,16 @@ mod test {
])
.into_array();

let matches = compare(&arr, &other, Operator::LessThanOrEqualTo)?.flatten_bool()?;
let matches = compare(&arr, &other, Operator::Lte)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]);

let matches = compare(&arr, &other, Operator::LessThan)?.flatten_bool()?;
let matches = compare(&arr, &other, Operator::Lt)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [5u64, 6, 7, 8, 10]);

let matches = compare(&other, &arr, Operator::GreaterThanOrEqualTo)?.flatten_bool()?;
let matches = compare(&other, &arr, Operator::Gte)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]);

let matches = compare(&other, &arr, Operator::GreaterThan)?.flatten_bool()?;
let matches = compare(&other, &arr, Operator::Gt)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [5u64, 6, 7, 8, 10]);
Ok(())
}
Expand Down
4 changes: 4 additions & 0 deletions vortex-dtype/src/field_paths.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ impl FieldPath {
Some(Self::builder().join_all(new_field_names).build())
}
}

pub fn parts(&self) -> &[FieldIdentifier] {
&self.field_names
}
}

#[derive(Clone, Debug, PartialEq)]
Expand Down
6 changes: 4 additions & 2 deletions vortex-expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@ rust-version = { workspace = true }
workspace = true

[dependencies]
datafusion-common = { workspace = true, optional = true }
datafusion-expr = { workspace = true, optional = true }
vortex-dtype = { path = "../vortex-dtype" }
vortex-error = { path = "../vortex-error" }
vortex-scalar = { path = "../vortex-scalar" }
serde = { workspace = true, optional = true, features = ["derive"] }


[dev-dependencies]


[features]
default = []
datafusion = ["dep:datafusion-common", "dep:datafusion-expr", "vortex-scalar/datafusion"]
serde = ["dep:serde", "vortex-dtype/serde", "vortex-scalar/serde"]
63 changes: 63 additions & 0 deletions vortex-expr/src/datafusion.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#![cfg(feature = "datafusion")]
use datafusion_common::Column;
use datafusion_expr::{BinaryExpr, Expr};
use vortex_dtype::field_paths::{FieldIdentifier, FieldPath};
use vortex_scalar::Scalar;

use crate::expressions::{Predicate, Value};
use crate::operators::Operator;

impl From<Predicate> for Expr {
fn from(value: Predicate) -> Self {
Expr::BinaryExpr(BinaryExpr::new(
Box::new(FieldPathWrapper(value.left).into()),
value.op.into(),
Box::new(value.right.into()),
))
}
}

impl From<Operator> for datafusion_expr::Operator {
fn from(value: Operator) -> Self {
match value {
Operator::Eq => datafusion_expr::Operator::Eq,
Operator::NotEq => datafusion_expr::Operator::NotEq,
Operator::Gt => datafusion_expr::Operator::Gt,
Operator::Gte => datafusion_expr::Operator::GtEq,
Operator::Lt => datafusion_expr::Operator::Lt,
Operator::Lte => datafusion_expr::Operator::LtEq,
}
}
}

impl From<Value> for Expr {
fn from(value: Value) -> Self {
match value {
Value::Field(field_path) => FieldPathWrapper(field_path).into(),
Value::Literal(literal) => ScalarWrapper(literal).into(),
}
}
}

struct FieldPathWrapper(FieldPath);
impl From<FieldPathWrapper> for Expr {
fn from(value: FieldPathWrapper) -> Self {
let mut field = String::new();
for part in value.0.parts() {
match part {
// TODO(ngates): escape quotes?
FieldIdentifier::Name(identifier) => field.push_str(&format!("\"{}\"", identifier)),
FieldIdentifier::ListIndex(idx) => field.push_str(&format!("[{}]", idx)),
}
}

Expr::Column(Column::from(field))
}
}

struct ScalarWrapper(Scalar);
impl From<ScalarWrapper> for Expr {
fn from(value: ScalarWrapper) -> Self {
Expr::Literal(value.0.into())
}
}
12 changes: 6 additions & 6 deletions vortex-expr/src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ impl Display for Value {
impl Display for Operator {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
let display = match &self {
Operator::EqualTo => "=",
Operator::NotEqualTo => "!=",
Operator::GreaterThan => ">",
Operator::GreaterThanOrEqualTo => ">=",
Operator::LessThan => "<",
Operator::LessThanOrEqualTo => "<=",
Operator::Eq => "=",
Operator::NotEq => "!=",
Operator::Gt => ">",
Operator::Gte => ">=",
Operator::Lt => "<",
Operator::Lte => "<=",
};
write!(f, "{display}")
}
Expand Down
14 changes: 7 additions & 7 deletions vortex-expr/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,47 +50,47 @@ impl Value {
pub fn eq(self, field: impl Into<FieldPath>) -> Predicate {
Predicate {
left: field.into(),
op: Operator::EqualTo,
op: Operator::Eq,
right: self,
}
}

pub fn not_eq(self, field: impl Into<FieldPath>) -> Predicate {
Predicate {
left: field.into(),
op: Operator::NotEqualTo.inverse(),
op: Operator::NotEq.inverse(),
right: self,
}
}

pub fn gt(self, field: impl Into<FieldPath>) -> Predicate {
Predicate {
left: field.into(),
op: Operator::GreaterThan.inverse(),
op: Operator::Gt.inverse(),
right: self,
}
}

pub fn gte(self, field: impl Into<FieldPath>) -> Predicate {
Predicate {
left: field.into(),
op: Operator::GreaterThanOrEqualTo.inverse(),
op: Operator::Gte.inverse(),
right: self,
}
}

pub fn lt(self, field: impl Into<FieldPath>) -> Predicate {
Predicate {
left: field.into(),
op: Operator::LessThan.inverse(),
op: Operator::Lt.inverse(),
right: self,
}
}

pub fn lte(self, field: impl Into<FieldPath>) -> Predicate {
Predicate {
left: field.into(),
op: Operator::LessThanOrEqualTo.inverse(),
op: Operator::Lte.inverse(),
right: self,
}
}
Expand All @@ -109,7 +109,7 @@ mod test {
let field = field("id");
let expr = Predicate {
left: field,
op: Operator::EqualTo,
op: Operator::Eq,
right: value,
};
assert_eq!(format!("{}", expr), "($id = 1)");
Expand Down
12 changes: 6 additions & 6 deletions vortex-expr/src/field_paths.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,47 +17,47 @@ impl FieldPathOperations for FieldPath {
fn eq(self, other: Value) -> Predicate {
Predicate {
left: self,
op: Operator::EqualTo,
op: Operator::Eq,
right: other,
}
}

fn not_eq(self, other: Value) -> Predicate {
Predicate {
left: self,
op: Operator::NotEqualTo,
op: Operator::NotEq,
right: other,
}
}

fn gt(self, other: Value) -> Predicate {
Predicate {
left: self,
op: Operator::GreaterThan,
op: Operator::Gt,
right: other,
}
}

fn gte(self, other: Value) -> Predicate {
Predicate {
left: self,
op: Operator::GreaterThanOrEqualTo,
op: Operator::Gte,
right: other,
}
}

fn lt(self, other: Value) -> Predicate {
Predicate {
left: self,
op: Operator::LessThan,
op: Operator::Lt,
right: other,
}
}

fn lte(self, other: Value) -> Predicate {
Predicate {
left: self,
op: Operator::LessThanOrEqualTo,
op: Operator::Lte,
right: other,
}
}
Expand Down
1 change: 1 addition & 0 deletions vortex-expr/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![feature(iter_intersperse)]
extern crate core;

mod datafusion;
mod display;
pub mod expressions;
pub mod field_paths;
Expand Down
Loading

0 comments on commit 2fef12f

Please sign in to comment.