diff --git a/Cargo.lock b/Cargo.lock index f145e032c8..7540bd7e06 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5114,6 +5114,15 @@ dependencies = [ "thiserror", ] +[[package]] +name = "vortex-expr" +version = "0.1.0" +dependencies = [ + "vortex-dtype", + "vortex-error", + "vortex-scalar", +] + [[package]] name = "vortex-fastlanes" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 38b88ffa01..b12b7dfd22 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ members = [ "vortex-datetime-parts", "vortex-dict", "vortex-error", + "vortex-expr", "vortex-fastlanes", "vortex-flatbuffers", "vortex-ipc", diff --git a/vortex-dtype/src/dtype.rs b/vortex-dtype/src/dtype.rs index f767b26562..69b212ecf9 100644 --- a/vortex-dtype/src/dtype.rs +++ b/vortex-dtype/src/dtype.rs @@ -8,7 +8,8 @@ use DType::*; use crate::nullability::Nullability; use crate::{ExtDType, PType}; -pub type FieldNames = Arc<[Arc]>; +pub type FieldName = Arc; +pub type FieldNames = Arc<[FieldName]>; pub type Metadata = Vec; diff --git a/vortex-dtype/src/ptype.rs b/vortex-dtype/src/ptype.rs index edb3235f7b..3c1bf3df87 100644 --- a/vortex-dtype/src/ptype.rs +++ b/vortex-dtype/src/ptype.rs @@ -90,6 +90,26 @@ macro_rules! match_each_native_ptype { }) } +#[macro_export] +macro_rules! match_each_signed_ptype { + ($self:expr, | $_:tt $enc:ident | $($body:tt)*) => ({ + macro_rules! __with__ {( $_ $enc:ident ) => ( $($body)* )} + use $crate::PType; + use $crate::half::f16; + match $self { + PType::I8 => __with__! { i8 }, + PType::I16 => __with__! { i16 }, + PType::I32 => __with__! { i32 }, + PType::I64 => __with__! { i64 }, + PType::F16 => __with__! { f16 }, + PType::F32 => __with__! { f32 }, + PType::F64 => __with__! { f64 }, + _ => panic!("Unsupported ptype {}", $self), + + } + }) +} + #[macro_export] macro_rules! match_each_integer_ptype { ($self:expr, | $_:tt $enc:ident | $($body:tt)*) => ({ diff --git a/vortex-expr/Cargo.toml b/vortex-expr/Cargo.toml new file mode 100644 index 0000000000..5340700a4d --- /dev/null +++ b/vortex-expr/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "vortex-expr" +version = { workspace = true } +description = "Vortex Expressions" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lints] +workspace = true + +[dependencies] +vortex-dtype = { path = "../vortex-dtype" } +vortex-error = { path = "../vortex-error" } +vortex-scalar = { path = "../vortex-scalar" } + + +[dev-dependencies] diff --git a/vortex-expr/README.md b/vortex-expr/README.md new file mode 100644 index 0000000000..0a059346f1 --- /dev/null +++ b/vortex-expr/README.md @@ -0,0 +1,8 @@ +# Vortex Expressions + +Expressions for querying vortex arrays. The query algebra is designed to express a minimal +superset of linear operations that can be pushed down to vortex metadata. Conversely, not all +operations that can be expressed in this algebra can be pushed down to metadata. + +Takes inspiration from postgres https://www.postgresql.org/docs/current/sql-expressions.html +and datafusion https://github.com/apache/datafusion/tree/5fac581efbaffd0e6a9edf931182517524526afd/datafusion/expr diff --git a/vortex-expr/src/display.rs b/vortex-expr/src/display.rs new file mode 100644 index 0000000000..89cc596339 --- /dev/null +++ b/vortex-expr/src/display.rs @@ -0,0 +1,133 @@ +use core::fmt; +use std::fmt::{Display, Formatter}; + +use vortex_dtype::{match_each_native_ptype, DType}; +use vortex_scalar::{BoolScalar, PrimitiveScalar}; + +use crate::expressions::{BinaryExpr, Expr}; +use crate::operators::Operator; +use crate::scalar::ScalarDisplayWrapper; + +impl Display for Expr { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match self { + Expr::BinaryExpr(expr) => write!(f, "{expr}"), + Expr::Field(d) => write!(f, "{d}"), + Expr::Literal(v) => { + let wrapped = ScalarDisplayWrapper(v); + write!(f, "{wrapped}") + } + Expr::Not(expr) => write!(f, "NOT {expr}"), + Expr::Minus(expr) => write!(f, "(- {expr})"), + Expr::IsNull(expr) => write!(f, "{expr} IS NULL"), + } + } +} + +impl Display for BinaryExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + fn write_inner( + f: &mut Formatter<'_>, + outer: &Expr, + outer_op_precedence: u8, + ) -> fmt::Result { + if let Expr::BinaryExpr(inner) = outer { + let inner_op_precedence = inner.op.precedence(); + // if the child operator has lower precedence than the outer expression, + // wrap it in parentheses to prevent inversion of priority + if inner_op_precedence == 0 || inner_op_precedence < outer_op_precedence { + write!(f, "({inner})")?; + } else { + write!(f, "{inner}")?; + } + } else if let Expr::Literal(scalar) = outer { + // use alternative formatting for scalars + let wrapped = ScalarDisplayWrapper(scalar); + write!(f, "{wrapped}")?; + } else { + write!(f, "{outer}")?; + } + Ok(()) + } + + write_inner(f, self.left.as_ref(), self.op.precedence())?; + write!(f, " {} ", self.op)?; + write_inner(f, self.right.as_ref(), self.op.precedence()) + } +} + +/// Alternative display for scalars +impl Display for ScalarDisplayWrapper<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self.0.dtype() { + DType::Null => write!(f, "null"), + DType::Bool(_) => match BoolScalar::try_from(self.0).expect("bool").value() { + None => write!(f, "null"), + Some(b) => write!(f, "{}", b), + }, + DType::Primitive(ptype, _) => match_each_native_ptype!(ptype, |$T| { + match PrimitiveScalar::try_from(self.0).expect("primitive").typed_value::<$T>() { + None => write!(f, "null"), + Some(v) => write!(f, "{}{}", v, std::any::type_name::<$T>()), + } + }), + DType::Utf8(_) => todo!(), + DType::Binary(_) => todo!(), + DType::Struct(..) => todo!(), + DType::List(..) => todo!(), + DType::Extension(..) => todo!(), + } + } +} + +impl Display for Operator { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + let display = match &self { + Operator::And => "AND", + Operator::Or => "OR", + Operator::EqualTo => "=", + Operator::NotEqualTo => "!=", + Operator::GreaterThan => ">", + Operator::GreaterThanOrEqualTo => ">=", + Operator::LessThan => "<", + Operator::LessThanOrEqualTo => "<=", + Operator::Plus => "+", + Operator::Minus | Operator::UnaryMinus => "-", + Operator::Multiplication => "*", + Operator::Division => "/", + Operator::Modulo => "%", + }; + write!(f, "{display}") + } +} + +#[cfg(test)] +mod tests { + use crate::literal::lit; + + #[test] + fn test_formatting() { + // Addition + assert_eq!(format!("{}", lit(1u32) + lit(2u32)), "1u32 + 2u32"); + // Subtraction + assert_eq!(format!("{}", lit(1u32) - lit(2u32)), "1u32 - 2u32"); + // Multiplication + assert_eq!(format!("{}", lit(1u32) * lit(2u32)), "1u32 * 2u32"); + // Division + assert_eq!(format!("{}", lit(1u32) / lit(2u32)), "1u32 / 2u32"); + // Modulus + assert_eq!(format!("{}", lit(1u32) % lit(2u32)), "1u32 % 2u32"); + // Negate + assert_eq!(format!("{}", -lit(1u32)), "(- 1u32)"); + + // And + let string = format!("{}", lit(true).and(lit(false))); + assert_eq!(string, "true AND false"); + // Or + let string = format!("{}", lit(true).or(lit(false))); + assert_eq!(string, "true OR false"); + // Not + let string = format!("{}", !lit(1u32)); + assert_eq!(string, "NOT 1u32"); + } +} diff --git a/vortex-expr/src/expression_fns.rs b/vortex-expr/src/expression_fns.rs new file mode 100644 index 0000000000..1eb92af8f9 --- /dev/null +++ b/vortex-expr/src/expression_fns.rs @@ -0,0 +1,15 @@ +use vortex_dtype::FieldName; + +use crate::expressions::BinaryExpr; +use crate::expressions::Expr; +use crate::operators::Operator; + +pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr { + Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) +} + +#[allow(dead_code)] +/// Create a field expression based on a qualified field name. +pub fn field(field: impl Into) -> Expr { + Expr::Field(field.into()) +} diff --git a/vortex-expr/src/expressions.rs b/vortex-expr/src/expressions.rs new file mode 100644 index 0000000000..6388a1b6e5 --- /dev/null +++ b/vortex-expr/src/expressions.rs @@ -0,0 +1,115 @@ +use vortex_dtype::FieldName; +use vortex_error::{vortex_bail, VortexResult}; +use vortex_scalar::Scalar; + +use crate::expression_fns::binary_expr; +use crate::operators::Operator; + +#[derive(Clone, Debug, PartialEq)] +pub enum Expr { + /// A binary expression such as "duration_seconds == 100" + BinaryExpr(BinaryExpr), + + /// A named reference to a qualified field in a dtype. + Field(FieldName), + + /// True if argument is NULL, false otherwise. This expression itself is never NULL. + IsNull(Box), + + /// A constant scalar value. + Literal(Scalar), + + /// Additive inverse of an expression. The expression's type must be numeric. + Minus(Box), + + /// Negation of an expression. The expression's type must be a boolean. + Not(Box), +} + +impl Expr { + // binary logic + + pub fn and(self, other: Expr) -> Expr { + binary_expr(self, Operator::And, other) + } + + pub fn or(self, other: Expr) -> Expr { + binary_expr(self, Operator::Or, other) + } + + // comparisons + + pub fn eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::EqualTo, other) + } + + pub fn not_eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::NotEqualTo, other) + } + + pub fn gt(self, other: Expr) -> Expr { + binary_expr(self, Operator::GreaterThan, other) + } + + pub fn gt_eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::GreaterThanOrEqualTo, other) + } + + pub fn lt(self, other: Expr) -> Expr { + binary_expr(self, Operator::LessThan, other) + } + + pub fn lt_eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::LessThanOrEqualTo, other) + } + + // misc + pub fn is_null(self) -> Expr { + Expr::IsNull(Box::new(self)) + } + + pub fn minus(&self) -> VortexResult { + Ok(match self { + Expr::Literal(scalar) => Expr::Literal(scalar.invert()?), + _ => { + vortex_bail!("Can only negate numeric literals") + } + }) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct BinaryExpr { + pub left: Box, + pub op: Operator, + pub right: Box, +} + +impl BinaryExpr { + pub fn new(left: Box, op: Operator, right: Box) -> Self { + Self { left, op, right } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::expression_fns::field; + use crate::literal::lit; + + #[test] + fn test_lit() { + let exp = field("id").eq(lit(1)); + let exp = exp.eq(lit(2).minus().unwrap()); + let s = format!("{}", exp); + assert_eq!(s, "id = Scalar { dtype: Primitive(I32, NonNullable), \ + value: Primitive(I32(1)) } = Scalar { dtype: Primitive(I32, NonNullable), value: Primitive(I32(-2)) }") + } + + #[test] + fn test_negative() { + let scalar: Scalar = 1.into(); + let rhs: Expr = lit(scalar); + field("id").eq(rhs); + } +} diff --git a/vortex-expr/src/fields.rs b/vortex-expr/src/fields.rs new file mode 100644 index 0000000000..230e509bcd --- /dev/null +++ b/vortex-expr/src/fields.rs @@ -0,0 +1,9 @@ +#[cfg(test)] +mod test { + use crate::expression_fns::field; + + #[test] + fn test_cant_negate_field() { + field("id").minus().expect_err("cannot negate field"); + } +} diff --git a/vortex-expr/src/lib.rs b/vortex-expr/src/lib.rs new file mode 100644 index 0000000000..4ba5bfaa8d --- /dev/null +++ b/vortex-expr/src/lib.rs @@ -0,0 +1,9 @@ +extern crate core; + +mod display; +mod expression_fns; +pub mod expressions; +mod fields; +mod literal; +pub mod operators; +pub mod scalar; diff --git a/vortex-expr/src/literal.rs b/vortex-expr/src/literal.rs new file mode 100644 index 0000000000..2f6bc37946 --- /dev/null +++ b/vortex-expr/src/literal.rs @@ -0,0 +1,43 @@ +use vortex_scalar::Scalar; + +use crate::expressions::Expr; + +pub trait Literal { + fn lit(&self) -> Expr; +} + +#[allow(dead_code)] +pub fn lit>(n: T) -> Expr { + n.into().lit() +} + +#[cfg(test)] +mod test { + use super::*; + use crate::expression_fns::field; + + #[test] + fn test_lit() { + let scalar: Scalar = 1.into(); + let rhs: Expr = lit(scalar); + let expr = field("id").eq(rhs); + assert_eq!( + format!("{}", expr), + "id = Scalar { dtype: Primitive(I32, NonNullable), value: Primitive(I32(1)) }" + ); + } + + #[test] + fn test_invert_lit() { + let scalar: Scalar = 1.into(); + let lit: Expr = lit(scalar); + assert_eq!( + format!("{}", -lit.clone()), + "(- Scalar { dtype: Primitive(I32, NonNullable), value: Primitive(I32(1)) })" + ); + assert_eq!( + format!("{}", lit.minus().unwrap()), + "Scalar { dtype: Primitive(I32, NonNullable), value: Primitive(I32(-1)) }" + ); + } +} diff --git a/vortex-expr/src/operators.rs b/vortex-expr/src/operators.rs new file mode 100644 index 0000000000..ff8709df2e --- /dev/null +++ b/vortex-expr/src/operators.rs @@ -0,0 +1,127 @@ +use std::ops; + +use crate::expression_fns::binary_expr; +use crate::expressions::Expr; + +#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd)] +pub enum Operator { + // arithmetic + Plus, + Minus, + UnaryMinus, + Multiplication, + Division, + Modulo, + // binary logic + And, + Or, + // comparison + EqualTo, + NotEqualTo, + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, +} + +pub enum Associativity { + Left, + Right, + Neither, +} + +/// Magic numbers from postgres docs: +/// +impl Operator { + pub fn precedence(&self) -> u8 { + match self { + Operator::Or => 1, + Operator::And => 2, + Operator::NotEqualTo + | Operator::EqualTo + | Operator::LessThan + | Operator::LessThanOrEqualTo + | Operator::GreaterThan + | Operator::GreaterThanOrEqualTo => 4, + Operator::Plus | Operator::Minus => 13, + Operator::Multiplication | Operator::Division | Operator::Modulo => 14, + Operator::UnaryMinus => 17, + } + } + + pub fn associativity(&self) -> Associativity { + match self { + Operator::Or + | Operator::And + | Operator::Plus + | Operator::Minus + | Operator::Multiplication + | Operator::Division + | Operator::Modulo => Associativity::Left, + Operator::NotEqualTo + | Operator::EqualTo + | Operator::LessThan + | Operator::LessThanOrEqualTo + | Operator::GreaterThan + | Operator::GreaterThanOrEqualTo => Associativity::Neither, + Operator::UnaryMinus => Associativity::Right, + } + } +} + +/// Various operator support +impl ops::Add for Expr { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + binary_expr(self, Operator::Plus, rhs) + } +} + +impl ops::Sub for Expr { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + binary_expr(self, Operator::Minus, rhs) + } +} + +impl ops::Mul for Expr { + type Output = Self; + + fn mul(self, rhs: Self) -> Self { + binary_expr(self, Operator::Multiplication, rhs) + } +} + +impl ops::Div for Expr { + type Output = Self; + + fn div(self, rhs: Self) -> Self { + binary_expr(self, Operator::Division, rhs) + } +} + +impl ops::Rem for Expr { + type Output = Self; + + fn rem(self, rhs: Self) -> Self { + binary_expr(self, Operator::Modulo, rhs) + } +} + +impl ops::Neg for Expr { + type Output = Self; + + fn neg(self) -> Self::Output { + Expr::Minus(Box::new(self)) + } +} + +impl ops::Not for Expr { + type Output = Self; + + fn not(self) -> Self::Output { + Expr::Not(Box::new(self)) + } +} diff --git a/vortex-expr/src/scalar.rs b/vortex-expr/src/scalar.rs new file mode 100644 index 0000000000..e285d53095 --- /dev/null +++ b/vortex-expr/src/scalar.rs @@ -0,0 +1,12 @@ +use vortex_scalar::Scalar; + +use crate::expressions::Expr; +use crate::literal::Literal; + +pub struct ScalarDisplayWrapper<'a>(pub &'a Scalar); + +impl Literal for Scalar { + fn lit(&self) -> Expr { + Expr::Literal(self.clone()) + } +} diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index faeee748cd..7c49d9d6a9 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -1,6 +1,7 @@ use std::cmp::Ordering; +use std::ops::Neg; -use vortex_dtype::DType; +use vortex_dtype::{match_each_signed_ptype, DType, PType}; mod binary; mod bool; @@ -114,6 +115,53 @@ impl Scalar { DType::Extension(..) => ExtScalar::try_from(self).and_then(|s| s.cast(dtype)), } } + + pub fn invert(&self) -> VortexResult { + if !can_invert(self.dtype()) { + vortex_bail!("Cannot invert type") + } + + let scalar_value: Scalar = match &self.value { + ScalarValue::Bool(bool) => (!bool).into(), + ScalarValue::List(list) => { + let inner_dtype = match self.dtype.clone() { + DType::List(dtype, _) => dtype, + _ => vortex_bail!("Invalid dtype"), + }; + let scalar: VortexResult> = list + .iter() + .map(|v| { + Ok(Scalar::new((*inner_dtype).clone(), v.clone()) + .invert()? + .value) + }) + .collect(); + Scalar::list((*inner_dtype).clone(), scalar?) + } + ScalarValue::Primitive(ps) => { + match_each_signed_ptype!(ps.ptype(), |$T| { + let p: $T = $T::try_from(ps.clone())?; + let sc: Scalar = p.neg().into(); + sc + }) + } + _ => { + unreachable!("Should have bailed already due to uninvertible type") + } + }; + Ok(scalar_value) + } +} + +fn can_invert(dtype: &DType) -> bool { + match dtype { + DType::Bool(_) => true, + DType::Primitive(ptype, _) => { + !matches!(ptype, PType::U8 | PType::U16 | PType::U32 | PType::U64) + } + DType::List(dtype, _) => can_invert(dtype.as_ref()), + _ => false, + } } impl PartialEq for Scalar {