From fde193fa994bed8589e9818810b0c0090ee34057 Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Wed, 17 Jul 2024 16:06:39 +0100 Subject: [PATCH] Basic predicate pushdown support for Datafusion (#472) Enables basic support for predicate pushdown over in-memory vortex arrays for `eq` operations under fairly limited conditions. --- Cargo.lock | 1 + Cargo.toml | 1 + bench-vortex/src/bin/tpch_benchmark.rs | 30 ++--- encodings/dict/src/compress.rs | 4 +- vortex-array/Cargo.toml | 1 + vortex-array/src/array/constant/canonical.rs | 17 ++- vortex-array/src/array/constant/mod.rs | 4 +- vortex-array/src/array/varbin/compute/mod.rs | 4 +- vortex-array/src/array/varbin/compute/take.rs | 2 +- vortex-array/src/array/varbin/mod.rs | 2 +- vortex-array/src/array/varbinview/compute.rs | 3 +- vortex-array/src/compute/compare.rs | 40 +++--- vortex-datafusion/Cargo.toml | 2 +- vortex-datafusion/src/eval.rs | 53 ++++++++ vortex-datafusion/src/lib.rs | 116 ++++++++---------- vortex-datafusion/src/plans.rs | 55 ++------- vortex-scalar/src/datafusion.rs | 29 +++++ 17 files changed, 203 insertions(+), 161 deletions(-) create mode 100644 vortex-datafusion/src/eval.rs diff --git a/Cargo.lock b/Cargo.lock index 4ad105a453..1c01d79f40 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3998,6 +3998,7 @@ dependencies = [ "arrow-array", "arrow-buffer", "arrow-cast", + "arrow-ord", "arrow-schema", "arrow-select", "build-vortex", diff --git a/Cargo.toml b/Cargo.toml index f9679c5221..b19de1fcd0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,7 @@ arrow-cast = "52.0.0" arrow-csv = "52.0.0" arrow-data = "52.0.0" arrow-ipc = "52.0.0" +arrow-ord = "52.0.0" arrow-schema = "52.0.0" arrow-select = "52.0.0" async-trait = "0.1" diff --git a/bench-vortex/src/bin/tpch_benchmark.rs b/bench-vortex/src/bin/tpch_benchmark.rs index e21e255e9d..90f392caad 100644 --- a/bench-vortex/src/bin/tpch_benchmark.rs +++ b/bench-vortex/src/bin/tpch_benchmark.rs @@ -3,9 +3,8 @@ use std::time::SystemTime; use bench_vortex::tpch::dbgen::{DBGen, DBGenOptions}; use bench_vortex::tpch::{load_datasets, tpch_queries, Format}; -use futures::future::join_all; +use futures::future::try_join_all; use indicatif::ProgressBar; -use itertools::Itertools; use prettytable::{Cell, Row, Table}; #[tokio::main(flavor = "multi_thread", worker_threads = 8)] @@ -23,21 +22,12 @@ async fn main() { Format::Vortex { disable_pushdown: false, }, - Format::Vortex { - disable_pushdown: true, - }, ]; // Load datasets - let ctxs = join_all( - formats - .iter() - .map(|format| load_datasets(&data_dir, *format)), - ) - .await - .into_iter() - .map(|r| r.unwrap()) - .collect_vec(); + let ctxs = try_join_all(formats.map(|format| load_datasets(&data_dir, format))) + .await + .unwrap(); // Set up a results table let mut table = Table::new(); @@ -53,9 +43,9 @@ async fn main() { // Send back a channel with the results of Row. let (rows_tx, rows_rx) = sync::mpsc::channel(); for (q, query) in tpch_queries() { - let _ctxs = ctxs.clone(); - let _tx = rows_tx.clone(); - let _progress = progress.clone(); + let ctxs = ctxs.clone(); + let tx = rows_tx.clone(); + let progress = progress.clone(); rayon::spawn_fifo(move || { let mut cells = Vec::with_capacity(formats.len()); cells.push(Cell::new(&format!("Q{}", q))); @@ -65,7 +55,7 @@ async fn main() { .enable_all() .build() .unwrap(); - for (ctx, format) in _ctxs.iter().zip(formats.iter()) { + for (ctx, format) in ctxs.iter().zip(formats.iter()) { for _ in 0..3 { // warmup rt.block_on(async { @@ -98,7 +88,7 @@ async fn main() { let fastest = measure.iter().cloned().min().unwrap(); elapsed_us.push(fastest); - _progress.inc(1); + progress.inc(1); } let baseline = elapsed_us.first().unwrap(); @@ -125,7 +115,7 @@ async fn main() { ); } - _tx.send((q, Row::new(cells))).unwrap(); + tx.send((q, Row::new(cells))).unwrap(); }); } diff --git a/encodings/dict/src/compress.rs b/encodings/dict/src/compress.rs index 7d60f2417f..d868f25f29 100644 --- a/encodings/dict/src/compress.rs +++ b/encodings/dict/src/compress.rs @@ -111,7 +111,7 @@ where let mut lookup_dict: HashMap = HashMap::with_hasher(()); let mut codes: Vec = Vec::with_capacity(lower); let mut bytes: Vec = Vec::new(); - let mut offsets: Vec = Vec::new(); + let mut offsets: Vec = Vec::new(); offsets.push(0); if dtype.is_nullable() { @@ -133,7 +133,7 @@ where RawEntryMut::Vacant(vac) => { let next_code = offsets.len() as u64 - 1; bytes.extend_from_slice(byte_ref); - offsets.push(bytes.len() as u64); + offsets.push(bytes.len() as u32); vac.insert_with_hasher(value_hash, next_code, (), |idx| { hasher.hash_one(lookup_bytes( offsets.as_slice(), diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml index adc760d068..3ae8c31444 100644 --- a/vortex-array/Cargo.toml +++ b/vortex-array/Cargo.toml @@ -24,6 +24,7 @@ arrow-buffer = { workspace = true } arrow-cast = { workspace = true } arrow-select = { workspace = true } arrow-schema = { workspace = true } +arrow-ord = { workspace = true } enum-iterator = { workspace = true } flatbuffers = { workspace = true } flexbuffers = { workspace = true } diff --git a/vortex-array/src/array/constant/canonical.rs b/vortex-array/src/array/constant/canonical.rs index 18fd8234c1..867341871c 100644 --- a/vortex-array/src/array/constant/canonical.rs +++ b/vortex-array/src/array/constant/canonical.rs @@ -1,10 +1,13 @@ -use vortex_dtype::{match_each_native_ptype, Nullability, PType}; +use std::iter; + +use vortex_dtype::{match_each_native_ptype, DType, Nullability, PType}; use vortex_error::{vortex_bail, VortexResult}; -use vortex_scalar::BoolScalar; +use vortex_scalar::{BoolScalar, Utf8Scalar}; use crate::array::bool::BoolArray; use crate::array::constant::ConstantArray; use crate::array::primitive::PrimitiveArray; +use crate::array::varbin::VarBinArray; use crate::validity::Validity; use crate::ArrayDType; use crate::{Canonical, IntoCanonical}; @@ -26,6 +29,16 @@ impl IntoCanonical for ConstantArray { ))); } + if let Ok(s) = Utf8Scalar::try_from(self.scalar()) { + let const_value = s.value().unwrap(); + let bytes = const_value.as_bytes(); + + return Ok(Canonical::VarBin(VarBinArray::from_iter( + iter::repeat(Some(bytes)).take(self.len()), + DType::Utf8(validity.nullability()), + ))); + } + if let Ok(ptype) = PType::try_from(self.scalar().dtype()) { return match_each_native_ptype!(ptype, |$P| { Ok(Canonical::Primitive(PrimitiveArray::from_vec::<$P>( diff --git a/vortex-array/src/array/constant/mod.rs b/vortex-array/src/array/constant/mod.rs index 005567a98c..a30ebf54ea 100644 --- a/vortex-array/src/array/constant/mod.rs +++ b/vortex-array/src/array/constant/mod.rs @@ -24,9 +24,9 @@ pub struct ConstantMetadata { impl ConstantArray { pub fn new(scalar: S, length: usize) -> Self where - Scalar: From, + S: Into, { - let scalar: Scalar = scalar.into(); + let scalar = scalar.into(); // TODO(aduffy): add stats for bools, ideally there should be a // StatsSet::constant(Scalar) constructor that does this for us, like StatsSet::nulls. let stats = StatsSet::from(HashMap::from([ diff --git a/vortex-array/src/array/varbin/compute/mod.rs b/vortex-array/src/array/varbin/compute/mod.rs index edeca4867f..3b0f4d09af 100644 --- a/vortex-array/src/array/varbin/compute/mod.rs +++ b/vortex-array/src/array/varbin/compute/mod.rs @@ -3,9 +3,7 @@ use vortex_scalar::Scalar; use crate::array::varbin::{varbin_scalar, VarBinArray}; use crate::compute::unary::scalar_at::ScalarAtFn; -use crate::compute::ArrayCompute; -use crate::compute::SliceFn; -use crate::compute::TakeFn; +use crate::compute::{ArrayCompute, SliceFn, TakeFn}; use crate::validity::ArrayValidity; use crate::ArrayDType; diff --git a/vortex-array/src/array/varbin/compute/take.rs b/vortex-array/src/array/varbin/compute/take.rs index 02b1eb466c..d57ea9241d 100644 --- a/vortex-array/src/array/varbin/compute/take.rs +++ b/vortex-array/src/array/varbin/compute/take.rs @@ -66,7 +66,7 @@ fn take_nullable( indices: &[I], null_buffer: NullBuffer, ) -> VarBinArray { - let mut builder = VarBinBuilder::::with_capacity(indices.len()); + let mut builder = VarBinBuilder::::with_capacity(indices.len()); for &idx in indices { let idx = idx.to_usize().unwrap(); if null_buffer.is_valid(idx) { diff --git a/vortex-array/src/array/varbin/mod.rs b/vortex-array/src/array/varbin/mod.rs index efd80e355d..16d9279395 100644 --- a/vortex-array/src/array/varbin/mod.rs +++ b/vortex-array/src/array/varbin/mod.rs @@ -132,7 +132,7 @@ impl VarBinArray { dtype: DType, ) -> Self { let iter = iter.into_iter(); - let mut builder = VarBinBuilder::::with_capacity(iter.size_hint().0); + let mut builder = VarBinBuilder::::with_capacity(iter.size_hint().0); for v in iter { builder.push(v.as_ref().map(|o| o.as_ref())); } diff --git a/vortex-array/src/array/varbinview/compute.rs b/vortex-array/src/array/varbinview/compute.rs index d6c380789b..8c410b1ba5 100644 --- a/vortex-array/src/array/varbinview/compute.rs +++ b/vortex-array/src/array/varbinview/compute.rs @@ -4,8 +4,7 @@ use vortex_scalar::Scalar; use crate::array::varbin::varbin_scalar; use crate::array::varbinview::{VarBinViewArray, VIEW_SIZE}; use crate::compute::unary::scalar_at::ScalarAtFn; -use crate::compute::ArrayCompute; -use crate::compute::{slice, SliceFn}; +use crate::compute::{slice, ArrayCompute, SliceFn}; use crate::validity::ArrayValidity; use crate::{Array, ArrayDType, IntoArray, IntoArrayData}; diff --git a/vortex-array/src/compute/compare.rs b/vortex-array/src/compute/compare.rs index 80459f887d..b566bb8885 100644 --- a/vortex-array/src/compute/compare.rs +++ b/vortex-array/src/compute/compare.rs @@ -1,30 +1,32 @@ -use vortex_dtype::DType; -use vortex_error::{vortex_err, VortexResult}; +use arrow_ord::cmp; +use vortex_error::VortexResult; use vortex_expr::Operator; -use crate::{Array, ArrayDType, IntoArrayVariant}; +use crate::{arrow::FromArrowArray, Array, ArrayData, IntoArray, IntoCanonical}; pub trait CompareFn { - fn compare(&self, array: &Array, predicate: Operator) -> VortexResult; + fn compare(&self, array: &Array, operator: Operator) -> VortexResult; } pub fn compare(left: &Array, right: &Array, operator: Operator) -> VortexResult { - if let Some(matching_indices) = - left.with_dyn(|lhs| lhs.compare().map(|rhs| rhs.compare(right, operator))) + if let Some(selection) = + left.with_dyn(|lhs| lhs.compare().map(|lhs| lhs.compare(right, operator))) { - return matching_indices; + return selection; } - // if compare is not implemented for the given array type, but the array has a numeric - // DType, we can flatten the array and apply filter to the flattened primitive array - match left.dtype() { - DType::Primitive(..) => { - let flat = left.clone().into_primitive()?; - flat.compare(right, operator) - } - _ => Err(vortex_err!( - NotImplemented: "compare", - left.encoding().id() - )), - } + // Fallback to arrow on canonical types + let lhs = left.clone().into_canonical()?.into_arrow(); + let rhs = right.clone().into_canonical()?.into_arrow(); + + let array = match operator { + Operator::Eq => cmp::eq(&lhs.as_ref(), &rhs.as_ref())?, + Operator::NotEq => cmp::neq(&lhs.as_ref(), &rhs.as_ref())?, + Operator::Gt => cmp::gt(&lhs.as_ref(), &rhs.as_ref())?, + Operator::Gte => cmp::gt_eq(&lhs.as_ref(), &rhs.as_ref())?, + Operator::Lt => cmp::lt(&lhs.as_ref(), &rhs.as_ref())?, + Operator::Lte => cmp::lt_eq(&lhs.as_ref(), &rhs.as_ref())?, + }; + + Ok(ArrayData::from_arrow(&array, true).into_array()) } diff --git a/vortex-datafusion/Cargo.toml b/vortex-datafusion/Cargo.toml index 615171e3f4..abe626d007 100644 --- a/vortex-datafusion/Cargo.toml +++ b/vortex-datafusion/Cargo.toml @@ -15,7 +15,7 @@ vortex-array = { path = "../vortex-array" } vortex-dtype = { path = "../vortex-dtype" } vortex-expr = { path = "../vortex-expr" } vortex-error = { path = "../vortex-error" } -vortex-scalar = { path = "../vortex-scalar" } +vortex-scalar = { path = "../vortex-scalar", features = ["datafusion"] } arrow-array = { workspace = true } arrow-schema = { workspace = true } diff --git a/vortex-datafusion/src/eval.rs b/vortex-datafusion/src/eval.rs new file mode 100644 index 0000000000..33a45b38d1 --- /dev/null +++ b/vortex-datafusion/src/eval.rs @@ -0,0 +1,53 @@ +use datafusion_expr::{Expr, Operator as DFOperator}; +use vortex::{ + array::{bool::BoolArray, constant::ConstantArray}, + compute::compare, + Array, IntoArray, IntoArrayVariant, +}; +use vortex_error::{vortex_bail, vortex_err, VortexResult}; +use vortex_expr::Operator; + +pub struct ExpressionEvaluator; + +impl ExpressionEvaluator { + pub fn eval(array: Array, expr: &Expr) -> VortexResult { + match expr { + Expr::BinaryExpr(expr) => { + let lhs = expr.left.as_ref(); + let rhs = expr.right.as_ref(); + + // TODO(adamg): turn and/or into more general compute functions + match expr.op { + DFOperator::And => { + let lhs = ExpressionEvaluator::eval(array.clone(), lhs)?.into_bool()?; + let rhs = ExpressionEvaluator::eval(array, rhs)?.into_bool()?; + let buffer = &lhs.boolean_buffer() & &rhs.boolean_buffer(); + Ok(BoolArray::from(buffer).into_array()) + } + DFOperator::Or => { + let lhs = ExpressionEvaluator::eval(array.clone(), lhs)?.into_bool()?; + let rhs = ExpressionEvaluator::eval(array.clone(), rhs)?.into_bool()?; + let buffer = &lhs.boolean_buffer() | &rhs.boolean_buffer(); + Ok(BoolArray::from(buffer).into_array()) + } + DFOperator::Eq => { + let lhs = ExpressionEvaluator::eval(array.clone(), lhs)?; + let rhs = ExpressionEvaluator::eval(array.clone(), rhs)?; + compare(&lhs, &rhs, Operator::Eq) + } + _ => vortex_bail!("{} is an unsupported operator", expr.op), + } + } + Expr::Column(col) => { + // TODO(adamg): Use variant trait once its merged + let array = array.clone().into_struct()?; + let name = col.name(); + array + .field_by_name(name) + .ok_or(vortex_err!("Missing field {name} in struct")) + } + Expr::Literal(lit) => Ok(ConstantArray::new(lit.clone(), array.len()).into_array()), + _ => unreachable!(), + } + } +} diff --git a/vortex-datafusion/src/lib.rs b/vortex-datafusion/src/lib.rs index 49ea4e404a..3447916426 100644 --- a/vortex-datafusion/src/lib.rs +++ b/vortex-datafusion/src/lib.rs @@ -1,5 +1,7 @@ //! Connectors to enable DataFusion to read Vortex data. +#![allow(clippy::nonminimal_bool)] + use std::any::Any; use std::collections::HashSet; use std::fmt::{Debug, Formatter}; @@ -8,14 +10,14 @@ use std::sync::Arc; use std::task::{Context, Poll}; use arrow_array::{RecordBatch, StructArray as ArrowStructArray}; -use arrow_schema::SchemaRef; +use arrow_schema::{DataType, SchemaRef}; use async_trait::async_trait; use datafusion::dataframe::DataFrame; use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext}; use datafusion::prelude::SessionContext; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{exec_datafusion_err, DataFusionError, Result as DFResult}; use datafusion_expr::{Expr, Operator, TableProviderFilterPushDown, TableType}; use datafusion_physical_expr::EquivalenceProperties; @@ -32,6 +34,7 @@ use crate::datatype::infer_schema; use crate::plans::{RowSelectorExec, TakeRowsExec}; mod datatype; +mod eval; mod expr; mod plans; @@ -246,7 +249,7 @@ impl TableProvider for VortexMemTable { filters .iter() .map(|expr| { - if can_be_pushed_down(expr)? { + if can_be_pushed_down(expr) { Ok(TableProviderFilterPushDown::Exact) } else { Ok(TableProviderFilterPushDown::Unsupported) @@ -286,72 +289,32 @@ fn make_filter_then_take_plan( } /// Check if the given expression tree can be pushed down into the scan. -fn can_be_pushed_down(expr: &Expr) -> DFResult { - // If the filter references a column not known to our schema, we reject the filter for pushdown. - fn is_supported(expr: &Expr) -> bool { - match expr { - Expr::BinaryExpr(binary_expr) => { - // Both the left and right sides must be column expressions, scalars, or casts. - - match binary_expr.op { - // Initially, we will only support pushdown for basic boolean operators - Operator::Eq - | Operator::NotEq - | Operator::Lt - | Operator::LtEq - | Operator::Gt - | Operator::GtEq => true, - - // TODO(aduffy): add support for LIKE - // TODO(aduffy): add support for basic mathematical ops +-*/ - // TODO(aduffy): add support for conjunctions, assuming all of the - // left and right are valid expressions. - _ => false, +fn can_be_pushed_down(expr: &Expr) -> bool { + match expr { + Expr::BinaryExpr(expr) if expr.op == Operator::Eq => { + let lhs = expr.left.as_ref(); + let rhs = expr.right.as_ref(); + + match (lhs, rhs) { + (Expr::Column(_), Expr::Column(_)) => true, + (Expr::Column(_), Expr::Literal(lit)) | (Expr::Literal(lit), Expr::Column(_)) => { + let dt = lit.data_type(); + dt.is_integer() + || dt.is_floating() + || dt.is_signed_integer() + || dt.is_null() + || dt == DataType::Binary + || dt == DataType::Utf8 + || dt == DataType::Binary + || dt == DataType::BinaryView + || dt == DataType::Utf8View } + _ => false, } - Expr::IsNotNull(_) - | Expr::IsNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::Column(_) - | Expr::Literal(_) - // TODO(aduffy): ensure that cast can be pushed down. - | Expr::Cast(_) => true, - _ => false, } - } - - // Visitor that traverses the expression tree and tracks if any unsupported expressions were - // encountered. - struct IsSupportedVisitor { - supported_expressions_only: bool, - } - - impl TreeNodeVisitor<'_> for IsSupportedVisitor { - type Node = Expr; - fn f_down(&mut self, node: &Self::Node) -> DFResult { - if !is_supported(node) { - self.supported_expressions_only = false; - return Ok(TreeNodeRecursion::Stop); - } - - Ok(TreeNodeRecursion::Continue) - } + _ => false, } - - let mut visitor = IsSupportedVisitor { - supported_expressions_only: true, - }; - - // Traverse the tree. - // At the end of the traversal, the internal state of `visitor` will indicate if there were - // unsupported expressions encountered. - expr.visit(&mut visitor)?; - - Ok(visitor.supported_expressions_only) } /// Extract out the columns from our table referenced by the expression. @@ -505,7 +468,8 @@ mod test { use datafusion::arrow::array::AsArray; use datafusion::functions_aggregate::count::count_distinct; use datafusion::prelude::SessionContext; - use datafusion_expr::{col, lit}; + use datafusion_common::{Column, TableReference}; + use datafusion_expr::{col, lit, BinaryExpr, Expr, Operator}; use vortex::array::primitive::PrimitiveArray; use vortex::array::struct_::StructArray; use vortex::array::varbin::VarBinArray; @@ -513,7 +477,7 @@ mod test { use vortex::{Array, IntoArray}; use vortex_dtype::{DType, Nullability}; - use crate::{SessionContextExt, VortexMemTableOptions}; + use crate::{can_be_pushed_down, SessionContextExt, VortexMemTableOptions}; fn presidents_array() -> Array { let names = VarBinArray::from_vec( @@ -603,4 +567,24 @@ mod test { 4i64 ); } + + #[test] + fn test_can_be_pushed_down() { + let e = BinaryExpr { + left: Box::new( + Column { + relation: Some(TableReference::Bare { + table: "orders".into(), + }), + name: "o_orderstatus".to_string(), + } + .into(), + ), + op: Operator::Eq, + right: Box::new(lit("F")), + }; + let e = Expr::BinaryExpr(e); + + assert!(can_be_pushed_down(&e)); + } } diff --git a/vortex-datafusion/src/plans.rs b/vortex-datafusion/src/plans.rs index ca00e782b6..04890c516d 100644 --- a/vortex-datafusion/src/plans.rs +++ b/vortex-datafusion/src/plans.rs @@ -10,10 +10,10 @@ use arrow_array::cast::AsArray; use arrow_array::types::UInt64Type; use arrow_array::{ArrayRef, RecordBatch, RecordBatchOptions, UInt64Array}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; -use datafusion_common::{DFSchema, Result as DFResult}; +use datafusion_common::Result as DFResult; use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext}; use datafusion_expr::Expr; -use datafusion_physical_expr::{create_physical_expr, EquivalenceProperties, Partitioning}; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; use datafusion_physical_plan::{ DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, PlanProperties, }; @@ -26,6 +26,7 @@ use vortex::compute::take; use vortex::{ArrayDType, ArrayData, IntoArray, IntoArrayVariant, IntoCanonical}; use crate::datatype::infer_schema; +use crate::eval::ExpressionEvaluator; use crate::expr::{make_conjunction, simplify_expr}; /// Physical plan operator that applies a set of [filters][Expr] against the input, producing a @@ -133,17 +134,14 @@ impl ExecutionPlan for RowSelectorExec { .unwrap(), ); - let conjunction_expr = simplify_expr( - &make_conjunction(&self.filter_exprs)?, - filter_schema.clone(), - )?; + let conjunction_expr = + simplify_expr(&make_conjunction(&self.filter_exprs)?, filter_schema)?; Ok(Box::pin(RowIndicesStream { chunked_array: self.chunked_array.clone(), chunk_idx: 0, filter_projection: self.filter_projection.clone(), conjunction_expr, - filter_schema, })) } } @@ -154,7 +152,6 @@ pub(crate) struct RowIndicesStream { chunk_idx: usize, conjunction_expr: Expr, filter_projection: Vec, - filter_schema: SchemaRef, } impl Stream for RowIndicesStream { @@ -182,37 +179,19 @@ impl Stream for RowIndicesStream { .project(this.filter_projection.as_slice()) .expect("projection should succeed"); - // Immediately convert to Arrow RecordBatch for processing. - // TODO(aduffy): attempt to pushdown the filter to Vortex without decoding. - let record_batch = RecordBatch::from( - vortex_struct - .into_canonical() - .unwrap() - .into_arrow() - .as_struct(), - ); - - // Generate a physical plan to execute the conjunction query against the filter columns. - // - // The result of a conjunction expression is a BooleanArray containing `true` for rows - // where the conjunction was satisfied, and `false` otherwise. - let df_schema = DFSchema::try_from(this.filter_schema.clone())?; - let physical_expr = - create_physical_expr(&this.conjunction_expr, &df_schema, &Default::default())?; - let selection = physical_expr - .evaluate(&record_batch)? - .into_array(record_batch.num_rows())?; + // TODO(adamg): Filter on vortex arrays + let array = + ExpressionEvaluator::eval(vortex_struct.into_array(), &this.conjunction_expr).unwrap(); + let selection = array.into_canonical().unwrap().into_arrow(); // Convert the `selection` BooleanArray into a UInt64Array of indices. - let selection_indices: Vec = selection + let selection_indices = selection .as_boolean() - .clone() .values() .set_indices() - .map(|idx| idx as u64) - .collect(); + .map(|idx| idx as u64); - let indices: ArrayRef = Arc::new(UInt64Array::from(selection_indices)); + let indices = Arc::new(UInt64Array::from_iter_values(selection_indices)) as ArrayRef; let indices_batch = RecordBatch::try_new(ROW_SELECTOR_SCHEMA_REF.clone(), vec![indices])?; Poll::Ready(Some(Ok(indices_batch))) @@ -422,7 +401,6 @@ mod test { use std::sync::Arc; use arrow_array::{RecordBatch, UInt64Array}; - use arrow_schema::{DataType, Field, Schema}; use datafusion_expr::{and, col, lit}; use itertools::Itertools; use vortex::array::bool::BoolArray; @@ -437,11 +415,6 @@ mod test { #[tokio::test] async fn test_filtering_stream() { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::UInt64, false), - Field::new("b", DataType::Boolean, false), - ])); - let chunk = StructArray::try_new( Arc::new([FieldName::from("a"), FieldName::from("b")]), vec![ @@ -458,13 +431,11 @@ mod test { let chunked_array = ChunkedArray::try_new(vec![chunk.clone(), chunk.clone()], dtype).unwrap(); - let _schema = schema.clone(); let filtering_stream = RowIndicesStream { chunked_array: chunked_array.clone(), chunk_idx: 0, - conjunction_expr: and((col("a") % lit(2u64)).eq(lit(0u64)), col("b").is_true()), + conjunction_expr: and((col("a")).eq(lit(2u64)), col("b").eq(lit(true))), filter_projection: vec![0, 1], - filter_schema: _schema, }; let rows: Vec = futures::executor::block_on_stream(filtering_stream) diff --git a/vortex-scalar/src/datafusion.rs b/vortex-scalar/src/datafusion.rs index 594a6c3c3d..21d4012243 100644 --- a/vortex-scalar/src/datafusion.rs +++ b/vortex-scalar/src/datafusion.rs @@ -66,3 +66,32 @@ impl From for ScalarValue { } } } + +impl From for Scalar { + fn from(value: ScalarValue) -> Scalar { + match value { + ScalarValue::Null => Some(Scalar::null(DType::Null)), + ScalarValue::Boolean(b) => b.map(Scalar::from), + ScalarValue::Float16(f) => f.map(Scalar::from), + ScalarValue::Float32(f) => f.map(Scalar::from), + ScalarValue::Float64(f) => f.map(Scalar::from), + ScalarValue::Int8(i) => i.map(Scalar::from), + ScalarValue::Int16(i) => i.map(Scalar::from), + ScalarValue::Int32(i) => i.map(Scalar::from), + ScalarValue::Int64(i) => i.map(Scalar::from), + ScalarValue::UInt8(i) => i.map(Scalar::from), + ScalarValue::UInt16(i) => i.map(Scalar::from), + ScalarValue::UInt32(i) => i.map(Scalar::from), + ScalarValue::UInt64(i) => i.map(Scalar::from), + ScalarValue::Utf8(s) => s.as_ref().map(|s| Scalar::from(s.as_str())), + ScalarValue::Utf8View(s) => s.as_ref().map(|s| Scalar::from(s.as_str())), + ScalarValue::LargeUtf8(s) => s.as_ref().map(|s| Scalar::from(s.as_str())), + ScalarValue::Binary(b) => b.as_ref().map(|b| Scalar::from(b.clone())), + ScalarValue::BinaryView(b) => b.as_ref().map(|b| Scalar::from(b.clone())), + ScalarValue::LargeBinary(b) => b.as_ref().map(|b| Scalar::from(b.clone())), + ScalarValue::FixedSizeBinary(_, b) => b.map(|b| Scalar::from(b.clone())), + _ => unimplemented!(), + } + .unwrap_or(Scalar::null(DType::Null)) + } +}