diff --git a/Cargo.lock b/Cargo.lock index 59eefc6eec..da31a1e724 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5503,8 +5503,10 @@ dependencies = [ name = "vortex-expr" version = "0.1.0" dependencies = [ + "build-vortex", "datafusion-common 39.0.0", "datafusion-expr 39.0.0", + "prost", "serde", "vortex-dtype", "vortex-error", diff --git a/vortex-array/benches/compare.rs b/vortex-array/benches/compare.rs index fe2135a2e7..6555ca32bf 100644 --- a/vortex-array/benches/compare.rs +++ b/vortex-array/benches/compare.rs @@ -5,7 +5,7 @@ use rand::{thread_rng, Rng}; use vortex::array::bool::BoolArray; use vortex::IntoArray; use vortex_error::VortexError; -use vortex_expr::operators::Operator; +use vortex_expr::Operator; fn filter_bool_indices(c: &mut Criterion) { let mut group = c.benchmark_group("compare"); diff --git a/vortex-array/benches/filter_indices.rs b/vortex-array/benches/filter_indices.rs index 7f5e7127ff..57f8c68279 100644 --- a/vortex-array/benches/filter_indices.rs +++ b/vortex-array/benches/filter_indices.rs @@ -5,8 +5,8 @@ use rand::{thread_rng, Rng}; use vortex::IntoArray; use vortex_dtype::field_paths::FieldPath; use vortex_error::VortexError; -use vortex_expr::expressions::{lit, Conjunction, Disjunction}; -use vortex_expr::field_paths::FieldPathOperations; +use vortex_expr::FieldPathOperations; +use vortex_expr::{lit, Conjunction, Disjunction}; fn filter_indices(c: &mut Criterion) { let mut group = c.benchmark_group("filter_indices"); diff --git a/vortex-array/src/array/bool/compute/compare.rs b/vortex-array/src/array/bool/compute/compare.rs index fc26d3d8c7..d359cfa40d 100644 --- a/vortex-array/src/array/bool/compute/compare.rs +++ b/vortex-array/src/array/bool/compute/compare.rs @@ -1,7 +1,7 @@ use std::ops::{BitAnd, BitOr, BitXor, Not}; use vortex_error::VortexResult; -use vortex_expr::operators::Operator; +use vortex_expr::Operator; use crate::array::bool::BoolArray; use crate::compute::compare::CompareFn; diff --git a/vortex-array/src/array/primitive/compute/compare.rs b/vortex-array/src/array/primitive/compute/compare.rs index 53b0a70740..0d74e39bc7 100644 --- a/vortex-array/src/array/primitive/compute/compare.rs +++ b/vortex-array/src/array/primitive/compute/compare.rs @@ -3,7 +3,7 @@ use std::ops::BitAnd; use arrow_buffer::BooleanBuffer; use vortex_dtype::{match_each_native_ptype, NativePType}; use vortex_error::VortexResult; -use vortex_expr::operators::Operator; +use vortex_expr::Operator; use crate::array::bool::BoolArray; use crate::array::primitive::PrimitiveArray; diff --git a/vortex-array/src/array/primitive/compute/filter_indices.rs b/vortex-array/src/array/primitive/compute/filter_indices.rs index a49834d5d3..06676bdb5a 100644 --- a/vortex-array/src/array/primitive/compute/filter_indices.rs +++ b/vortex-array/src/array/primitive/compute/filter_indices.rs @@ -3,7 +3,7 @@ use std::ops::{BitAnd, BitOr}; use arrow_buffer::BooleanBuffer; use vortex_dtype::{match_each_native_ptype, NativePType}; use vortex_error::{vortex_bail, VortexResult}; -use vortex_expr::expressions::{Disjunction, Predicate, Value}; +use vortex_expr::{Disjunction, Predicate, Value}; use crate::array::bool::BoolArray; use crate::array::primitive::PrimitiveArray; @@ -71,8 +71,8 @@ fn apply_predicate bool>( mod test { use itertools::Itertools; use vortex_dtype::field_paths::FieldPathBuilder; - use vortex_expr::expressions::{lit, Conjunction}; - use vortex_expr::field_paths::FieldPathOperations; + use vortex_expr::FieldPathOperations; + use vortex_expr::{lit, Conjunction}; use super::*; use crate::validity::Validity; diff --git a/vortex-array/src/compute/compare.rs b/vortex-array/src/compute/compare.rs index c1b00a056c..ceeede9fb7 100644 --- a/vortex-array/src/compute/compare.rs +++ b/vortex-array/src/compute/compare.rs @@ -1,6 +1,6 @@ use vortex_dtype::DType; use vortex_error::{vortex_err, VortexResult}; -use vortex_expr::operators::Operator; +use vortex_expr::Operator; use crate::{Array, ArrayDType}; diff --git a/vortex-array/src/compute/filter_indices.rs b/vortex-array/src/compute/filter_indices.rs index 0f399786eb..cf40e30c0e 100644 --- a/vortex-array/src/compute/filter_indices.rs +++ b/vortex-array/src/compute/filter_indices.rs @@ -1,6 +1,6 @@ use vortex_dtype::DType; use vortex_error::{vortex_err, VortexResult}; -use vortex_expr::expressions::Disjunction; +use vortex_expr::Disjunction; use crate::{Array, ArrayDType}; diff --git a/vortex-dtype/proto/vortex/dtype/dtype.proto b/vortex-dtype/proto/vortex/dtype/dtype.proto index d3f82c9bd9..db24114223 100644 --- a/vortex-dtype/proto/vortex/dtype/dtype.proto +++ b/vortex-dtype/proto/vortex/dtype/dtype.proto @@ -59,7 +59,7 @@ message Extension { } message DType { - oneof type { + oneof dtype_type { Null null = 1; Bool bool = 2; Primitive primitive = 3; diff --git a/vortex-dtype/proto/vortex/dtype/field_path.proto b/vortex-dtype/proto/vortex/dtype/field_path.proto new file mode 100644 index 0000000000..010acb88c1 --- /dev/null +++ b/vortex-dtype/proto/vortex/dtype/field_path.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package vortex.dtype; + +message FieldPath { + repeated Part parts = 1; + + message Part { + oneof part_type { + string name = 1; + int32 index = 2; + } + } +} \ No newline at end of file diff --git a/vortex-dtype/src/field_paths.rs b/vortex-dtype/src/field_paths.rs index 07368a0952..6eefddd8c2 100644 --- a/vortex-dtype/src/field_paths.rs +++ b/vortex-dtype/src/field_paths.rs @@ -4,7 +4,7 @@ use std::fmt::{Display, Formatter}; #[derive(Clone, Debug, PartialEq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct FieldPath { - field_names: Vec, + parts: Vec, } impl FieldPath { @@ -13,20 +13,20 @@ impl FieldPath { } pub fn head(&self) -> Option<&FieldIdentifier> { - self.field_names.first() + self.parts.first() } pub fn tail(&self) -> Option { if self.head().is_none() { None } else { - let new_field_names = self.field_names[1..self.field_names.len()].to_vec(); - Some(Self::builder().join_all(new_field_names).build()) + let new_parts = self.parts[1..self.parts.len()].to_vec(); + Some(Self::builder().join_all(new_parts).build()) } } pub fn parts(&self) -> &[FieldIdentifier] { - &self.field_names + &self.parts } } @@ -38,31 +38,30 @@ pub enum FieldIdentifier { } pub struct FieldPathBuilder { - field_names: Vec, + parts: Vec, } impl FieldPathBuilder { pub fn new() -> Self { - Self { - field_names: Vec::new(), - } + Self { parts: Vec::new() } + } + + pub fn push>(&mut self, identifier: T) { + self.parts.push(identifier.into()); } pub fn join>(mut self, identifier: T) -> Self { - self.field_names.push(identifier.into()); + self.push(identifier); self } pub fn join_all(mut self, identifiers: Vec>) -> Self { - self.field_names - .extend(identifiers.into_iter().map(|v| v.into())); + self.parts.extend(identifiers.into_iter().map(|v| v.into())); self } pub fn build(self) -> FieldPath { - FieldPath { - field_names: self.field_names, - } + FieldPath { parts: self.parts } } } @@ -78,9 +77,7 @@ pub fn field(x: impl Into) -> FieldPath { impl From for FieldPath { fn from(value: FieldIdentifier) -> Self { - FieldPath { - field_names: vec![value], - } + FieldPath { parts: vec![value] } } } @@ -108,7 +105,7 @@ impl Display for FieldIdentifier { impl Display for FieldPath { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let formatted = self - .field_names + .parts .iter() .map(|fid| format!("{fid}")) .collect::>() diff --git a/vortex-dtype/src/serde/proto.rs b/vortex-dtype/src/serde/proto.rs index ab2f42ea9b..aaf587b927 100644 --- a/vortex-dtype/src/serde/proto.rs +++ b/vortex-dtype/src/serde/proto.rs @@ -4,7 +4,9 @@ use std::sync::Arc; use vortex_error::{vortex_err, VortexError, VortexResult}; -use crate::proto::dtype::d_type::Type; +use crate::field_paths::{FieldPath, FieldPathBuilder}; +use crate::proto::dtype::d_type::DtypeType; +use crate::proto::dtype::field_path::part::PartType; use crate::{proto::dtype as pb, DType, ExtDType, ExtID, ExtMetadata, PType, StructDType}; impl TryFrom<&pb::DType> for DType { @@ -12,17 +14,17 @@ impl TryFrom<&pb::DType> for DType { fn try_from(value: &pb::DType) -> Result { match value - .r#type + .dtype_type .as_ref() .ok_or_else(|| vortex_err!(InvalidSerde: "Unrecognized DType"))? { - Type::Null(_) => Ok(Self::Null), - Type::Bool(b) => Ok(Self::Bool(b.nullable.into())), - Type::Primitive(p) => Ok(Self::Primitive(p.r#type().into(), p.nullable.into())), - Type::Decimal(_) => todo!("Not Implemented"), - Type::Utf8(u) => Ok(Self::Utf8(u.nullable.into())), - Type::Binary(b) => Ok(Self::Binary(b.nullable.into())), - Type::Struct(s) => Ok(Self::Struct( + DtypeType::Null(_) => Ok(Self::Null), + DtypeType::Bool(b) => Ok(Self::Bool(b.nullable.into())), + DtypeType::Primitive(p) => Ok(Self::Primitive(p.r#type().into(), p.nullable.into())), + DtypeType::Decimal(_) => todo!("Not Implemented"), + DtypeType::Utf8(u) => Ok(Self::Utf8(u.nullable.into())), + DtypeType::Binary(b) => Ok(Self::Binary(b.nullable.into())), + DtypeType::Struct(s) => Ok(Self::Struct( StructDType::new( s.names.iter().map(|s| s.as_str().into()).collect(), s.dtypes @@ -32,7 +34,7 @@ impl TryFrom<&pb::DType> for DType { ), s.nullable.into(), )), - Type::List(l) => { + DtypeType::List(l) => { let nullable = l.nullable.into(); Ok(Self::List( l.element_type @@ -44,7 +46,7 @@ impl TryFrom<&pb::DType> for DType { nullable, )) } - Type::Extension(e) => Ok(Self::Extension( + DtypeType::Extension(e) => Ok(Self::Extension( ExtDType::new( ExtID::from(e.id.as_str()), e.metadata.as_ref().map(|m| ExtMetadata::from(m.as_ref())), @@ -58,31 +60,31 @@ impl TryFrom<&pb::DType> for DType { impl From<&DType> for pb::DType { fn from(value: &DType) -> Self { Self { - r#type: Some(match value { - DType::Null => Type::Null(pb::Null {}), - DType::Bool(n) => Type::Bool(pb::Bool { + dtype_type: Some(match value { + DType::Null => DtypeType::Null(pb::Null {}), + DType::Bool(n) => DtypeType::Bool(pb::Bool { nullable: (*n).into(), }), - DType::Primitive(ptype, n) => Type::Primitive(pb::Primitive { + DType::Primitive(ptype, n) => DtypeType::Primitive(pb::Primitive { r#type: pb::PType::from(*ptype).into(), nullable: (*n).into(), }), - DType::Utf8(n) => Type::Utf8(pb::Utf8 { + DType::Utf8(n) => DtypeType::Utf8(pb::Utf8 { nullable: (*n).into(), }), - DType::Binary(n) => Type::Binary(pb::Binary { + DType::Binary(n) => DtypeType::Binary(pb::Binary { nullable: (*n).into(), }), - DType::Struct(s, n) => Type::Struct(pb::Struct { + DType::Struct(s, n) => DtypeType::Struct(pb::Struct { names: s.names().iter().map(|s| s.as_ref().to_string()).collect(), dtypes: s.dtypes().iter().map(Into::into).collect(), nullable: (*n).into(), }), - DType::List(l, n) => Type::List(Box::new(pb::List { + DType::List(l, n) => DtypeType::List(Box::new(pb::List { element_type: Some(Box::new(l.as_ref().into())), nullable: (*n).into(), })), - DType::Extension(e, n) => Type::Extension(pb::Extension { + DType::Extension(e, n) => DtypeType::Extension(pb::Extension { id: e.id().as_ref().into(), metadata: e.metadata().map(|m| m.as_ref().into()), nullable: (*n).into(), @@ -129,3 +131,22 @@ impl From for pb::PType { } } } + +impl TryFrom<&pb::FieldPath> for FieldPath { + type Error = VortexError; + + fn try_from(value: &pb::FieldPath) -> Result { + let mut builder = FieldPathBuilder::new(); + for part in value.parts.iter() { + match part + .part_type + .as_ref() + .ok_or_else(|| vortex_err!(InvalidSerde: "FieldPath part missing type"))? + { + PartType::Name(name) => builder.push(name.as_str()), + PartType::Index(idx) => builder.push(*idx as u64), + } + } + Ok(builder.build()) + } +} diff --git a/vortex-expr/Cargo.toml b/vortex-expr/Cargo.toml index b59f87119a..1fda5a040d 100644 --- a/vortex-expr/Cargo.toml +++ b/vortex-expr/Cargo.toml @@ -17,14 +17,17 @@ workspace = true [dependencies] datafusion-common = { workspace = true, optional = true } datafusion-expr = { workspace = true, optional = true } +prost = { workspace = true, optional = true } vortex-dtype = { path = "../vortex-dtype" } vortex-error = { path = "../vortex-error" } vortex-scalar = { path = "../vortex-scalar" } serde = { workspace = true, optional = true, features = ["derive"] } -[dev-dependencies] +[build-dependencies] +build-vortex = { path = "../build-vortex" } [features] -default = [] +default = ["proto"] datafusion = ["dep:datafusion-common", "dep:datafusion-expr", "vortex-scalar/datafusion"] +proto = ["dep:prost", "vortex-dtype/proto", "vortex-scalar/proto"] serde = ["dep:serde", "vortex-dtype/serde", "vortex-scalar/serde"] \ No newline at end of file diff --git a/vortex-expr/build.rs b/vortex-expr/build.rs new file mode 100644 index 0000000000..3ce2fd1cb5 --- /dev/null +++ b/vortex-expr/build.rs @@ -0,0 +1,3 @@ +pub fn main() { + build_vortex::build(); +} diff --git a/vortex-expr/proto/vortex/expr/expr.proto b/vortex-expr/proto/vortex/expr/expr.proto new file mode 100644 index 0000000000..72b8be9c2d --- /dev/null +++ b/vortex-expr/proto/vortex/expr/expr.proto @@ -0,0 +1,33 @@ +syntax = "proto3"; + +package vortex.expr; + +import "vortex/dtype/field_path.proto"; +import "vortex/scalar/scalar.proto"; + +message Disjunction { + repeated Conjunction conjunctions = 1; +} + +message Conjunction { + repeated Predicate predicates = 1; +} + +message Predicate { + vortex.dtype.FieldPath left = 1; + Operator op = 2; + oneof right { + vortex.dtype.FieldPath field = 3; + vortex.scalar.Scalar scalar = 4; + } +} + +enum Operator { + UNKNOWN = 0; + EQ = 1; + NEQ = 2; + LT = 3; + LTE = 4; + GT = 5; + GTE = 6; +} diff --git a/vortex-expr/src/display.rs b/vortex-expr/src/display.rs index a3d61323ea..0dde56ef11 100644 --- a/vortex-expr/src/display.rs +++ b/vortex-expr/src/display.rs @@ -2,7 +2,6 @@ use core::fmt; use std::fmt::{Display, Formatter}; use crate::expressions::{Conjunction, Disjunction, Predicate, Value}; -use crate::operators::Operator; impl Display for Disjunction { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { @@ -39,20 +38,6 @@ impl Display for Value { } } -impl Display for Operator { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - let display = match &self { - Operator::Eq => "=", - Operator::NotEq => "!=", - Operator::Gt => ">", - Operator::Gte => ">=", - Operator::Lt => "<", - Operator::Lte => "<=", - }; - write!(f, "{display}") - } -} - #[cfg(test)] mod tests { use vortex_dtype::field_paths::{field, FieldPath}; diff --git a/vortex-expr/src/lib.rs b/vortex-expr/src/lib.rs index d512ed175b..e8b2fb2174 100644 --- a/vortex-expr/src/lib.rs +++ b/vortex-expr/src/lib.rs @@ -1,8 +1,22 @@ #![feature(iter_intersperse)] -extern crate core; mod datafusion; mod display; -pub mod expressions; -pub mod field_paths; -pub mod operators; +mod expressions; +mod field_paths; +mod operators; +mod serde_proto; + +pub use expressions::*; +pub use field_paths::*; +pub use operators::*; + +#[cfg(feature = "proto")] +pub mod proto { + pub mod expr { + include!(concat!(env!("OUT_DIR"), "/proto/vortex.expr.rs")); + } + + pub use vortex_dtype::proto::dtype; + pub use vortex_scalar::proto::scalar; +} diff --git a/vortex-expr/src/operators.rs b/vortex-expr/src/operators.rs index 12b9166c4c..76ee36ccfe 100644 --- a/vortex-expr/src/operators.rs +++ b/vortex-expr/src/operators.rs @@ -1,3 +1,5 @@ +use core::fmt; +use std::fmt::{Display, Formatter}; use std::ops; use vortex_dtype::NativePType; @@ -16,6 +18,20 @@ pub enum Operator { Lte, } +impl Display for Operator { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + let display = match &self { + Operator::Eq => "=", + Operator::NotEq => "!=", + Operator::Gt => ">", + Operator::Gte => ">=", + Operator::Lt => "<", + Operator::Lte => "<=", + }; + write!(f, "{display}") + } +} + impl ops::Not for Predicate { type Output = Self; diff --git a/vortex-expr/src/serde_proto.rs b/vortex-expr/src/serde_proto.rs new file mode 100644 index 0000000000..c493d27e8d --- /dev/null +++ b/vortex-expr/src/serde_proto.rs @@ -0,0 +1,48 @@ +#![cfg(feature = "proto")] + +use vortex_error::{vortex_bail, vortex_err, VortexError}; + +use crate::proto::expr as pb; +use crate::proto::expr::predicate::Right; +use crate::{Operator, Predicate, Value}; + +impl TryFrom<&pb::Predicate> for Predicate { + type Error = VortexError; + + fn try_from(value: &pb::Predicate) -> Result { + Ok(Predicate { + left: value + .left + .as_ref() + .ok_or_else(|| vortex_err!(InvalidSerde: "Lhs is missing"))? + .try_into()?, + op: value.op().try_into()?, + right: match value + .right + .as_ref() + .ok_or_else(|| vortex_err!(InvalidSerde: "Rhs is missing"))? + { + Right::Field(f) => Value::Field(f.try_into()?), + Right::Scalar(scalar) => Value::Literal(scalar.try_into()?), + }, + }) + } +} + +impl TryFrom for Operator { + type Error = VortexError; + + fn try_from(value: pb::Operator) -> Result { + match value { + pb::Operator::Unknown => { + vortex_bail!(InvalidSerde: "Unknown operator {}", value.as_str_name()) + } + pb::Operator::Eq => Ok(Self::Eq), + pb::Operator::Neq => Ok(Self::NotEq), + pb::Operator::Lt => Ok(Self::Lt), + pb::Operator::Lte => Ok(Self::Lte), + pb::Operator::Gt => Ok(Self::Gt), + pb::Operator::Gte => Ok(Self::Gte), + } + } +}