From 5fc024cf2480a7ef018c4682748a7cce4b075996 Mon Sep 17 00:00:00 2001
From: Jonathan Chen
Date: Fri, 29 Nov 2024 03:22:09 -0500
Subject: [PATCH 1/6] feat: support filter pushdown for datafusion

---
 crates/core/src/exprs/filter.rs             | 164 +++++++++
 crates/core/src/exprs/mod.rs                |  97 +++++
 crates/core/src/lib.rs                      |   1 +
 crates/core/src/table/fs_view.rs            |  22 +-
 crates/core/src/table/mod.rs                |  92 +++--
 crates/core/src/table/partition.rs          | 235 +++----
 crates/datafusion/Cargo.toml                |   2 +
 crates/datafusion/src/lib.rs                |  20 +-
 .../datafusion/src/utils/exprs_to_filter.rs | 331 ++++++++++++++++++
 crates/datafusion/src/utils/mod.rs          |  20 ++
 python/src/internal.rs                      |  52 ++-
 11 files changed, 800 insertions(+), 236 deletions(-)
 create mode 100644 crates/core/src/exprs/filter.rs
 create mode 100644 crates/core/src/exprs/mod.rs
 create mode 100644 crates/datafusion/src/utils/exprs_to_filter.rs
 create mode 100644 crates/datafusion/src/utils/mod.rs

diff --git a/crates/core/src/exprs/filter.rs b/crates/core/src/exprs/filter.rs
new file mode 100644
index 0000000..d3b4328
--- /dev/null
+++ b/crates/core/src/exprs/filter.rs
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use crate::exprs::HudiOperator;
+
+use anyhow::{Context, Result};
+use arrow_array::{ArrayRef, Scalar, StringArray};
+use arrow_cast::{cast_with_options, CastOptions};
+use arrow_schema::{DataType, Field, Schema};
+use std::str::FromStr;
+
+/// A partition filter that represents a filter expression for partition pruning.
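+/// A filter is built from a `(field, operator, value)` string triple together with the
+/// partition schema, e.g. `("date", ">", "2023-01-01")`; the value string is cast to the
+/// field's data type up front, so invalid filters are rejected at construction time.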
+#[derive(Debug, Clone)] +pub struct PartitionFilter { + pub field: Field, + pub operator: HudiOperator, + pub value: Scalar, +} + +impl TryFrom<((&str, &str, &str), &Schema)> for PartitionFilter { + type Error = anyhow::Error; + + fn try_from((filter, partition_schema): ((&str, &str, &str), &Schema)) -> Result { + let (field_name, operator_str, value_str) = filter; + + let field: &Field = partition_schema + .field_with_name(field_name) + .with_context(|| format!("Field '{}' not found in partition schema", field_name))?; + + let operator = HudiOperator::from_str(operator_str) + .with_context(|| format!("Unsupported operator: {}", operator_str))?; + + let value = &[value_str]; + let value = Self::cast_value(value, field.data_type()) + .with_context(|| format!("Unable to cast {:?} as {:?}", value, field.data_type()))?; + + let field = field.clone(); + Ok(PartitionFilter { + field, + operator, + value, + }) + } +} + +impl PartitionFilter { + pub fn cast_value(value: &[&str; 1], data_type: &DataType) -> Result> { + let cast_options = CastOptions { + safe: false, + format_options: Default::default(), + }; + + let value = StringArray::from(Vec::from(value)); + + Ok(Scalar::new(cast_with_options( + &value, + data_type, + &cast_options, + )?)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::exprs::HudiOperator; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_array::Datum; + use std::str::FromStr; + + fn create_test_schema() -> Schema { + Schema::new(vec![ + Field::new("date", DataType::Date32, false), + Field::new("category", DataType::Utf8, false), + Field::new("count", DataType::Int32, false), + ]) + } + + #[test] + fn test_partition_filter_try_from_valid() { + let schema = create_test_schema(); + let filter_tuple = ("date", "=", "2023-01-01"); + let filter = PartitionFilter::try_from((filter_tuple, &schema)); + assert!(filter.is_ok()); + let filter = filter.unwrap(); + assert_eq!(filter.field.name(), "date"); + assert_eq!(filter.operator, HudiOperator::Eq); + assert_eq!(filter.value.get().0.len(), 1); + + let filter_tuple = ("category", "!=", "foo"); + let filter = PartitionFilter::try_from((filter_tuple, &schema)); + assert!(filter.is_ok()); + let filter = filter.unwrap(); + assert_eq!(filter.field.name(), "category"); + assert_eq!(filter.operator, HudiOperator::Ne); + assert_eq!(filter.value.get().0.len(), 1); + assert_eq!( + StringArray::from(filter.value.into_inner().to_data()).value(0), + "foo" + ) + } + + #[test] + fn test_partition_filter_try_from_invalid_field() { + let schema = create_test_schema(); + let filter_tuple = ("invalid_field", "=", "2023-01-01"); + let filter = PartitionFilter::try_from((filter_tuple, &schema)); + assert!(filter.is_err()); + assert!(filter + .unwrap_err() + .to_string() + .contains("not found in partition schema")); + } + + #[test] + fn test_partition_filter_try_from_invalid_operator() { + let schema = create_test_schema(); + let filter_tuple = ("date", "??", "2023-01-01"); + let filter = PartitionFilter::try_from((filter_tuple, &schema)); + assert!(filter.is_err()); + assert!(filter + .unwrap_err() + .to_string() + .contains("Unsupported operator: ??")); + } + + #[test] + fn test_partition_filter_try_from_invalid_value() { + let schema = create_test_schema(); + let filter_tuple = ("count", "=", "not_a_number"); + let filter = PartitionFilter::try_from((filter_tuple, &schema)); + assert!(filter.is_err()); + assert!(filter.unwrap_err().to_string().contains("Unable to cast")); + } + + #[test] + fn 
test_partition_filter_try_from_all_operators() { + let schema = create_test_schema(); + for (op, _) in HudiOperator::TOKEN_OP_PAIRS { + let filter_tuple = ("count", op, "10"); + let filter = PartitionFilter::try_from((filter_tuple, &schema)); + assert!(filter.is_ok(), "Failed for operator: {}", op); + let filter = filter.unwrap(); + assert_eq!(filter.field.name(), "count"); + assert_eq!(filter.operator, HudiOperator::from_str(op).unwrap()); + } + } +} diff --git a/crates/core/src/exprs/mod.rs b/crates/core/src/exprs/mod.rs new file mode 100644 index 0000000..2803dce --- /dev/null +++ b/crates/core/src/exprs/mod.rs @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +pub mod filter; + +use anyhow::{anyhow, Error}; +use std::cmp::PartialEq; +use std::fmt::{Display, Formatter, Result as FmtResult}; +use std::str::FromStr; + +pub use filter::*; + +/// An operator that represents a comparison operation used in a partition filter expression. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum HudiOperator { + Eq, + Ne, + Lt, + Lte, + Gt, + Gte, +} + +impl Display for HudiOperator { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + match self { + // Binary Operators + HudiOperator::Eq => write!(f, "="), + HudiOperator::Ne => write!(f, "!="), + HudiOperator::Lt => write!(f, "<"), + HudiOperator::Lte => write!(f, "<="), + HudiOperator::Gt => write!(f, ">"), + HudiOperator::Gte => write!(f, ">="), + } + } +} + +// TODO: Add more operators +impl HudiOperator { + pub const TOKEN_OP_PAIRS: [(&'static str, HudiOperator); 6] = [ + ("=", HudiOperator::Eq), + ("!=", HudiOperator::Ne), + ("<", HudiOperator::Lt), + ("<=", HudiOperator::Lte), + (">", HudiOperator::Gt), + (">=", HudiOperator::Gte), + ]; +} + +impl FromStr for HudiOperator { + type Err = Error; + + fn from_str(s: &str) -> Result { + HudiOperator::TOKEN_OP_PAIRS + .iter() + .find_map(|&(token, op)| { + if token.eq_ignore_ascii_case(s) { + Some(op) + } else { + None + } + }) + .ok_or_else(|| anyhow!("Unsupported operator: {}", s)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_operator_from_str() { + assert_eq!(HudiOperator::from_str("=").unwrap(), HudiOperator::Eq); + assert_eq!(HudiOperator::from_str("!=").unwrap(), HudiOperator::Ne); + assert_eq!(HudiOperator::from_str("<").unwrap(), HudiOperator::Lt); + assert_eq!(HudiOperator::from_str("<=").unwrap(), HudiOperator::Lte); + assert_eq!(HudiOperator::from_str(">").unwrap(), HudiOperator::Gt); + assert_eq!(HudiOperator::from_str(">=").unwrap(), HudiOperator::Gte); + assert!(HudiOperator::from_str("??").is_err()); + } +} diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 3190dec..710ec5c 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -44,6 +44,7 @@ //! 
``` pub mod config; +pub mod exprs; pub mod file_group; pub mod storage; pub mod table; diff --git a/crates/core/src/table/fs_view.rs b/crates/core/src/table/fs_view.rs index f2309e3..955081b 100644 --- a/crates/core/src/table/fs_view.rs +++ b/crates/core/src/table/fs_view.rs @@ -181,12 +181,22 @@ mod tests { use crate::storage::Storage; use crate::table::fs_view::FileSystemView; use crate::table::partition::PartitionPruner; - use crate::table::Table; + use crate::table::{PartitionFilter, Table}; + + use anyhow::anyhow; + use arrow::datatypes::{DataType, Field, Schema}; use hudi_tests::TestTable; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use url::Url; + fn create_test_schema() -> Schema { + Schema::new(vec![ + Field::new("byteField", DataType::Int32, false), + Field::new("shortField", DataType::Int32, false), + ]) + } + async fn create_test_fs_view(base_url: Url) -> FileSystemView { FileSystemView::new( Arc::new(HudiConfigs::new([(HudiTableConfig::BasePath, base_url)])), @@ -296,8 +306,16 @@ mod tests { .await .unwrap(); let partition_schema = hudi_table.get_partition_schema().await.unwrap(); + + let schema = create_test_schema(); + let filter_lt_20 = PartitionFilter::try_from((("byteField", "<", "20"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); + let filter_eq_300 = PartitionFilter::try_from((("shortField", "=", "300"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); let partition_pruner = PartitionPruner::new( - &[("byteField", "<", "20"), ("shortField", "=", "300")], + &[filter_lt_20, filter_eq_300], &partition_schema, hudi_table.hudi_configs.as_ref(), ) diff --git a/crates/core/src/table/mod.rs b/crates/core/src/table/mod.rs index 4d69bbf..dee36ad 100644 --- a/crates/core/src/table/mod.rs +++ b/crates/core/src/table/mod.rs @@ -96,6 +96,7 @@ use crate::config::read::HudiReadConfig::AsOfTimestamp; use crate::config::table::HudiTableConfig; use crate::config::table::HudiTableConfig::PartitionFields; use crate::config::HudiConfigs; +use crate::exprs::PartitionFilter; use crate::file_group::reader::FileGroupReader; use crate::file_group::FileSlice; use crate::table::builder::TableBuilder; @@ -194,7 +195,7 @@ impl Table { pub async fn get_file_slices_splits( &self, n: usize, - filters: &[(&str, &str, &str)], + filters: &[PartitionFilter], ) -> Result>> { let file_slices = self.get_file_slices(filters).await?; if file_slices.is_empty() { @@ -213,7 +214,7 @@ impl Table { /// Get all the [FileSlice]s in the table. /// /// If the [AsOfTimestamp] configuration is set, the file slices at the specified timestamp will be returned. - pub async fn get_file_slices(&self, filters: &[(&str, &str, &str)]) -> Result> { + pub async fn get_file_slices(&self, filters: &[PartitionFilter]) -> Result> { if let Some(timestamp) = self.hudi_configs.try_get(AsOfTimestamp) { self.get_file_slices_as_of(timestamp.to::().as_str(), filters) .await @@ -228,7 +229,7 @@ impl Table { async fn get_file_slices_as_of( &self, timestamp: &str, - filters: &[(&str, &str, &str)], + filters: &[PartitionFilter], ) -> Result> { let excludes = self.timeline.get_replaced_file_groups().await?; let partition_schema = self.get_partition_schema().await?; @@ -242,7 +243,7 @@ impl Table { /// Get all the latest records in the table. /// /// If the [AsOfTimestamp] configuration is set, the records at the specified timestamp will be returned. 
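+    /// Filters are combined conjunctively: a file slice is read only if its partition
+    /// satisfies every filter provided.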
- pub async fn read_snapshot(&self, filters: &[(&str, &str, &str)]) -> Result> { + pub async fn read_snapshot(&self, filters: &[PartitionFilter]) -> Result> { if let Some(timestamp) = self.hudi_configs.try_get(AsOfTimestamp) { self.read_snapshot_as_of(timestamp.to::().as_str(), filters) .await @@ -257,7 +258,7 @@ impl Table { async fn read_snapshot_as_of( &self, timestamp: &str, - filters: &[(&str, &str, &str)], + filters: &[PartitionFilter], ) -> Result> { let file_slices = self .get_file_slices_as_of(timestamp, filters) @@ -277,7 +278,7 @@ impl Table { #[cfg(test)] async fn get_file_paths_with_filters( &self, - filters: &[(&str, &str, &str)], + filters: &[PartitionFilter], ) -> Result> { let mut file_paths = Vec::new(); for f in self.get_file_slices(filters).await? { @@ -293,6 +294,7 @@ impl Table { #[cfg(test)] mod tests { + use arrow::datatypes::{DataType, Field, Schema}; use arrow_array::StringArray; use std::collections::HashSet; use std::fs::canonicalize; @@ -300,6 +302,8 @@ mod tests { use std::{env, panic}; use url::Url; + use crate::exprs::PartitionFilter; + use hudi_tests::{assert_not, TestTable}; use crate::config::read::HudiReadConfig::AsOfTimestamp; @@ -312,7 +316,14 @@ mod tests { use crate::config::HUDI_CONF_DIR; use crate::storage::utils::join_url_segments; use crate::storage::Storage; - use crate::table::Table; + use crate::table::{anyhow, Table}; + + fn create_test_schema() -> Schema { + Schema::new(vec![ + Field::new("byteField", DataType::Int32, false), + Field::new("shortField", DataType::Int32, false), + ]) + } /// Test helper to create a new `Table` instance without validating the configuration. /// @@ -722,9 +733,17 @@ mod tests { .collect::>(); assert_eq!(actual, expected); - let partition_filters = &[("byteField", ">=", "10"), ("byteField", "<", "30")]; + let schema = create_test_schema(); + let filter_ge_10 = PartitionFilter::try_from((("byteField", ">=", "10"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); + + let filter_lt_30 = PartitionFilter::try_from((("byteField", "<", "30"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); + let actual = hudi_table - .get_file_paths_with_filters(partition_filters) + .get_file_paths_with_filters(&[filter_ge_10, filter_lt_30]) .await .unwrap() .into_iter() @@ -738,9 +757,11 @@ mod tests { .collect::>(); assert_eq!(actual, expected); - let partition_filters = &[("byteField", ">", "30")]; + let filter_gt_30 = PartitionFilter::try_from((("byteField", ">", "30"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); let actual = hudi_table - .get_file_paths_with_filters(partition_filters) + .get_file_paths_with_filters(&[filter_gt_30]) .await .unwrap() .into_iter() @@ -772,13 +793,19 @@ mod tests { .collect::>(); assert_eq!(actual, expected); - let partition_filters = &[ - ("byteField", ">=", "10"), - ("byteField", "<", "20"), - ("shortField", "!=", "100"), - ]; + let schema = create_test_schema(); + let filter_gte_10 = PartitionFilter::try_from((("byteField", ">=", "10"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); + let filter_lt_20 = PartitionFilter::try_from((("byteField", "<", "20"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); + let filter_ne_100 = PartitionFilter::try_from((("shortField", "!=", "100"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); + let actual = 
hudi_table - .get_file_paths_with_filters(partition_filters) + .get_file_paths_with_filters(&[filter_gte_10, filter_lt_20, filter_ne_100]) .await .unwrap() .into_iter() @@ -790,10 +817,15 @@ mod tests { .into_iter() .collect::>(); assert_eq!(actual, expected); + let filter_lt_20 = PartitionFilter::try_from((("byteField", ">", "20"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); + let filter_eq_300 = PartitionFilter::try_from((("shortField", "=", "300"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); - let partition_filters = &[("byteField", ">", "20"), ("shortField", "=", "300")]; let actual = hudi_table - .get_file_paths_with_filters(partition_filters) + .get_file_paths_with_filters(&[filter_lt_20, filter_eq_300]) .await .unwrap() .into_iter() @@ -806,12 +838,22 @@ mod tests { async fn hudi_table_read_snapshot_for_complex_keygen_hive_style() { let base_url = TestTable::V6ComplexkeygenHivestyle.url(); let hudi_table = Table::new(base_url.path()).await.unwrap(); - let partition_filters = &[ - ("byteField", ">=", "10"), - ("byteField", "<", "20"), - ("shortField", "!=", "100"), - ]; - let records = hudi_table.read_snapshot(partition_filters).await.unwrap(); + + let schema = create_test_schema(); + let filter_gte_10 = PartitionFilter::try_from((("byteField", ">=", "10"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); + let filter_lt_20 = PartitionFilter::try_from((("byteField", "<", "20"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); + let filter_ne_100 = PartitionFilter::try_from((("shortField", "!=", "100"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); + + let records = hudi_table + .read_snapshot(&[filter_gte_10, filter_lt_20, filter_ne_100]) + .await + .unwrap(); assert_eq!(records.len(), 1); assert_eq!(records[0].num_rows(), 2); let actual_partition_paths: HashSet<&str> = HashSet::from_iter( diff --git a/crates/core/src/table/partition.rs b/crates/core/src/table/partition.rs index 17927eb..7218097 100644 --- a/crates/core/src/table/partition.rs +++ b/crates/core/src/table/partition.rs @@ -18,15 +18,14 @@ */ use crate::config::table::HudiTableConfig; use crate::config::HudiConfigs; +use crate::exprs::{HudiOperator, PartitionFilter}; +use anyhow::anyhow; use anyhow::Result; -use anyhow::{anyhow, Context}; -use arrow_array::{ArrayRef, Scalar, StringArray}; -use arrow_cast::{cast_with_options, CastOptions}; +use arrow_array::{ArrayRef, Scalar}; use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq}; -use arrow_schema::{DataType, Field, Schema}; -use std::cmp::PartialEq; +use arrow_schema::Schema; + use std::collections::HashMap; -use std::str::FromStr; use std::sync::Arc; /// A partition pruner that filters partitions based on the partition path and its filters. 
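With this change the pruner receives ready-made `PartitionFilter`s instead of raw string triples. A minimal sketch of the resulting flow, written as crate-internal code (the schema, config values, and partition paths are illustrative; `PartitionPruner` is not visible outside the crate until a later patch in this series):

```rust
use arrow_schema::{DataType, Field, Schema};

use crate::config::table::HudiTableConfig::{IsHiveStylePartitioning, IsPartitionPathUrlencoded};
use crate::config::HudiConfigs;
use crate::exprs::PartitionFilter;
use crate::table::partition::PartitionPruner;

fn prune_partitions() -> anyhow::Result<()> {
    // Schema of the partition columns only, not the full table schema.
    let schema = Schema::new(vec![
        Field::new("byteField", DataType::Int32, false),
        Field::new("shortField", DataType::Int32, false),
    ]);

    // The value string is cast to Int32 at construction; a non-numeric value would fail here.
    let filter = PartitionFilter::try_from((("byteField", "<", "20"), &schema))?;

    let configs = HudiConfigs::new([
        (IsHiveStylePartitioning, "true".to_string()),
        (IsPartitionPathUrlencoded, "false".to_string()),
    ]);
    let pruner = PartitionPruner::new(&[filter], &schema, &configs)?;

    // Partition path segments are parsed, cast, and compared with AND semantics.
    assert!(pruner.should_include("byteField=10/shortField=300"));
    assert!(!pruner.should_include("byteField=30/shortField=100"));
    Ok(())
}
```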
@@ -40,14 +39,11 @@ pub struct PartitionPruner { impl PartitionPruner { pub fn new( - and_filters: &[(&str, &str, &str)], + and_filters: &[PartitionFilter], partition_schema: &Schema, hudi_configs: &HudiConfigs, ) -> Result { - let and_filters = and_filters - .iter() - .map(|filter| PartitionFilter::try_from((*filter, partition_schema))) - .collect::>>()?; + let and_filters = and_filters.to_vec(); let schema = Arc::new(partition_schema.clone()); let is_hive_style: bool = hudi_configs @@ -90,12 +86,12 @@ impl PartitionPruner { match segments.get(filter.field.name()) { Some(segment_value) => { let comparison_result = match filter.operator { - Operator::Eq => eq(segment_value, &filter.value), - Operator::Ne => neq(segment_value, &filter.value), - Operator::Lt => lt(segment_value, &filter.value), - Operator::Lte => lt_eq(segment_value, &filter.value), - Operator::Gt => gt(segment_value, &filter.value), - Operator::Gte => gt_eq(segment_value, &filter.value), + HudiOperator::Eq => eq(segment_value, &filter.value), + HudiOperator::Ne => neq(segment_value, &filter.value), + HudiOperator::Lt => lt(segment_value, &filter.value), + HudiOperator::Lte => lt_eq(segment_value, &filter.value), + HudiOperator::Gt => gt(segment_value, &filter.value), + HudiOperator::Gte => gt_eq(segment_value, &filter.value), }; match comparison_result { @@ -154,90 +150,6 @@ impl PartitionPruner { } } -/// An operator that represents a comparison operation used in a partition filter expression. -#[derive(Debug, Clone, Copy, PartialEq)] -enum Operator { - Eq, - Ne, - Lt, - Lte, - Gt, - Gte, -} - -impl Operator { - const TOKEN_OP_PAIRS: [(&'static str, Operator); 6] = [ - ("=", Operator::Eq), - ("!=", Operator::Ne), - ("<", Operator::Lt), - ("<=", Operator::Lte), - (">", Operator::Gt), - (">=", Operator::Gte), - ]; -} - -impl FromStr for Operator { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - Operator::TOKEN_OP_PAIRS - .iter() - .find_map(|&(token, op)| if token == s { Some(op) } else { None }) - .ok_or_else(|| anyhow!("Unsupported operator: {}", s)) - } -} - -/// A partition filter that represents a filter expression for partition pruning. 
-#[derive(Debug, Clone)] -pub struct PartitionFilter { - field: Field, - operator: Operator, - value: Scalar, -} - -impl TryFrom<((&str, &str, &str), &Schema)> for PartitionFilter { - type Error = anyhow::Error; - - fn try_from((filter, partition_schema): ((&str, &str, &str), &Schema)) -> Result { - let (field_name, operator_str, value_str) = filter; - - let field: &Field = partition_schema - .field_with_name(field_name) - .with_context(|| format!("Field '{}' not found in partition schema", field_name))?; - - let operator = Operator::from_str(operator_str) - .with_context(|| format!("Unsupported operator: {}", operator_str))?; - - let value = &[value_str]; - let value = Self::cast_value(value, field.data_type()) - .with_context(|| format!("Unable to cast {:?} as {:?}", value, field.data_type()))?; - - let field = field.clone(); - Ok(PartitionFilter { - field, - operator, - value, - }) - } -} - -impl PartitionFilter { - fn cast_value(value: &[&str; 1], data_type: &DataType) -> Result> { - let cast_options = CastOptions { - safe: false, - format_options: Default::default(), - }; - - let value = StringArray::from(Vec::from(value)); - - Ok(Scalar::new(cast_with_options( - &value, - data_type, - &cast_options, - )?)) - } -} - #[cfg(test)] mod tests { use super::*; @@ -245,9 +157,7 @@ mod tests { IsHiveStylePartitioning, IsPartitionPathUrlencoded, }; use arrow::datatypes::{DataType, Field, Schema}; - use arrow_array::{Array, Datum}; use hudi_tests::assert_not; - use std::str::FromStr; fn create_test_schema() -> Schema { Schema::new(vec![ @@ -257,87 +167,6 @@ mod tests { ]) } - #[test] - fn test_partition_filter_try_from_valid() { - let schema = create_test_schema(); - let filter_tuple = ("date", "=", "2023-01-01"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_ok()); - let filter = filter.unwrap(); - assert_eq!(filter.field.name(), "date"); - assert_eq!(filter.operator, Operator::Eq); - assert_eq!(filter.value.get().0.len(), 1); - - let filter_tuple = ("category", "!=", "foo"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_ok()); - let filter = filter.unwrap(); - assert_eq!(filter.field.name(), "category"); - assert_eq!(filter.operator, Operator::Ne); - assert_eq!(filter.value.get().0.len(), 1); - assert_eq!( - StringArray::from(filter.value.into_inner().to_data()).value(0), - "foo" - ) - } - - #[test] - fn test_partition_filter_try_from_invalid_field() { - let schema = create_test_schema(); - let filter_tuple = ("invalid_field", "=", "2023-01-01"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_err()); - assert!(filter - .unwrap_err() - .to_string() - .contains("not found in partition schema")); - } - - #[test] - fn test_partition_filter_try_from_invalid_operator() { - let schema = create_test_schema(); - let filter_tuple = ("date", "??", "2023-01-01"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_err()); - assert!(filter - .unwrap_err() - .to_string() - .contains("Unsupported operator: ??")); - } - - #[test] - fn test_partition_filter_try_from_invalid_value() { - let schema = create_test_schema(); - let filter_tuple = ("count", "=", "not_a_number"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_err()); - assert!(filter.unwrap_err().to_string().contains("Unable to cast")); - } - - #[test] - fn test_partition_filter_try_from_all_operators() { - let schema = create_test_schema(); - for (op, _) 
in Operator::TOKEN_OP_PAIRS { - let filter_tuple = ("count", op, "10"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_ok(), "Failed for operator: {}", op); - let filter = filter.unwrap(); - assert_eq!(filter.field.name(), "count"); - assert_eq!(filter.operator, Operator::from_str(op).unwrap()); - } - } - - #[test] - fn test_operator_from_str() { - assert_eq!(Operator::from_str("=").unwrap(), Operator::Eq); - assert_eq!(Operator::from_str("!=").unwrap(), Operator::Ne); - assert_eq!(Operator::from_str("<").unwrap(), Operator::Lt); - assert_eq!(Operator::from_str("<=").unwrap(), Operator::Lte); - assert_eq!(Operator::from_str(">").unwrap(), Operator::Gt); - assert_eq!(Operator::from_str(">=").unwrap(), Operator::Gte); - assert!(Operator::from_str("??").is_err()); - } - fn create_hudi_configs(is_hive_style: bool, is_url_encoded: bool) -> HudiConfigs { HudiConfigs::new([ (IsHiveStylePartitioning, is_hive_style.to_string()), @@ -348,9 +177,15 @@ mod tests { fn test_partition_pruner_new() { let schema = create_test_schema(); let configs = create_hudi_configs(true, false); - let filters = vec![("date", ">", "2023-01-01"), ("category", "=", "A")]; - let pruner = PartitionPruner::new(&filters, &schema, &configs); + let filter_gt_date = PartitionFilter::try_from((("date", ">", "2023-01-01"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); + let filter_eq_a = PartitionFilter::try_from((("category", "=", "A"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); + + let pruner = PartitionPruner::new(&[filter_gt_date, filter_eq_a], &schema, &configs); assert!(pruner.is_ok()); let pruner = pruner.unwrap(); @@ -375,8 +210,10 @@ mod tests { let pruner_empty = PartitionPruner::new(&[], &schema, &configs).unwrap(); assert!(pruner_empty.is_empty()); - let pruner_non_empty = - PartitionPruner::new(&[("date", ">", "2023-01-01")], &schema, &configs).unwrap(); + let filter_gt_date = PartitionFilter::try_from((("date", ">", "2023-01-01"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); + let pruner_non_empty = PartitionPruner::new(&[filter_gt_date], &schema, &configs).unwrap(); assert_not!(pruner_non_empty.is_empty()); } @@ -384,13 +221,23 @@ mod tests { fn test_partition_pruner_should_include() { let schema = create_test_schema(); let configs = create_hudi_configs(true, false); - let filters = vec![ - ("date", ">", "2023-01-01"), - ("category", "=", "A"), - ("count", "<=", "100"), - ]; - let pruner = PartitionPruner::new(&filters, &schema, &configs).unwrap(); + let filter_gt_date = PartitionFilter::try_from((("date", ">", "2023-01-01"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); + let filter_eq_a = PartitionFilter::try_from((("category", "=", "A"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); + let filter_lte_100 = PartitionFilter::try_from((("count", "<=", "100"), &schema)) + .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) + .unwrap(); + + let pruner = PartitionPruner::new( + &[filter_gt_date, filter_eq_a, filter_lte_100], + &schema, + &configs, + ) + .unwrap(); assert!(pruner.should_include("date=2023-02-01/category=A/count=10")); assert!(pruner.should_include("date=2023-02-01/category=A/count=100")); diff --git a/crates/datafusion/Cargo.toml b/crates/datafusion/Cargo.toml index 120aa8a..3a8b1e6 100644 --- a/crates/datafusion/Cargo.toml +++ 
b/crates/datafusion/Cargo.toml
@@ -30,6 +30,8 @@ repository.workspace = true
 
 [dependencies]
 hudi-core = { version = "0.3.0", path = "../core", features = ["datafusion"] }
 # arrow
+arrow-array = { workspace = true }
+arrow-cast = { workspace = true }
 arrow-schema = { workspace = true }
 
 # datafusion
diff --git a/crates/datafusion/src/lib.rs b/crates/datafusion/src/lib.rs
index 1bf17ea..1cab70c 100644
--- a/crates/datafusion/src/lib.rs
+++ b/crates/datafusion/src/lib.rs
@@ -17,6 +17,8 @@
  * under the License.
  */
 
+pub mod utils;
+
 use std::any::Any;
 use std::collections::HashMap;
 use std::fmt::Debug;
@@ -36,9 +38,10 @@ use datafusion_common::config::TableParquetOptions;
 use datafusion_common::DFSchema;
 use datafusion_common::DataFusionError::Execution;
 use datafusion_common::Result;
-use datafusion_expr::{CreateExternalTable, Expr, TableType};
+use datafusion_expr::{CreateExternalTable, Expr, TableProviderFilterPushDown, TableType};
 use datafusion_physical_expr::create_physical_expr;
 
+use crate::utils::exprs_to_filter::convert_exprs_to_filter;
 use hudi_core::config::read::HudiReadConfig::InputPartitions;
 use hudi_core::config::utils::empty_options;
 use hudi_core::storage::utils::{get_scheme_authority, parse_uri};
@@ -129,10 +132,11 @@
     ) -> Result<Arc<dyn ExecutionPlan>> {
         self.table.register_storage(state.runtime_env().clone());
 
+        // Convert DataFusion `Expr` to `PartitionFilter`
+        let partition_filters = convert_exprs_to_filter(filters, &self.schema());
         let file_slices = self
             .table
-            // TODO: implement supports_filters_pushdown() to pass filters to Hudi table API
-            .get_file_slices_splits(self.get_input_partitions(), &[])
+            .get_file_slices_splits(self.get_input_partitions(), partition_filters.as_slice())
             .await
             .map_err(|e| Execution(format!("Failed to get file slices from Hudi table: {}", e)))?;
         let mut parquet_file_groups: Vec<Vec<PartitionedFile>> = Vec::new();
@@ -176,6 +180,16 @@
 
         Ok(exec_builder.build_arc())
     }
+
+    fn supports_filters_pushdown(
+        &self,
+        filters: &[&Expr],
+    ) -> Result<Vec<TableProviderFilterPushDown>> {
+        Ok(vec![
+            TableProviderFilterPushDown::Unsupported;
+            filters.len()
+        ])
+    }
 }
 
 /// `HudiTableFactory` is responsible for creating and configuring Hudi tables.
diff --git a/crates/datafusion/src/utils/exprs_to_filter.rs b/crates/datafusion/src/utils/exprs_to_filter.rs
new file mode 100644
index 0000000..9dd282b
--- /dev/null
+++ b/crates/datafusion/src/utils/exprs_to_filter.rs
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use arrow_array::{Array, Scalar};
+use arrow_schema::SchemaRef;
+use datafusion::logical_expr::Operator;
+use datafusion_expr::{BinaryExpr, Expr};
+use hudi_core::exprs::{HudiOperator, PartitionFilter};
+use std::sync::Arc;
+
+// TODO: Handle other DataFusion `Expr`
+
+/// Converts a slice of DataFusion expressions (`Expr`) into a vector of `PartitionFilter`.
+/// Expressions that cannot be converted into a partition filter (unsupported shapes,
+/// operators, or types) are silently skipped rather than treated as errors.
+pub fn convert_exprs_to_filter(
+    filters: &[Expr],
+    partition_schema: &SchemaRef,
+) -> Vec<PartitionFilter> {
+    let mut partition_filters = Vec::new();
+
+    for expr in filters {
+        match expr {
+            Expr::BinaryExpr(binary_expr) => {
+                if let Some(partition_filter) = convert_binary_expr(binary_expr, partition_schema) {
+                    partition_filters.push(partition_filter);
+                } else {
+                    continue;
+                }
+            }
+            Expr::Not(not_expr) => {
+                // Handle NOT expressions
+                if let Some(partition_filter) = convert_not_expr(not_expr, partition_schema) {
+                    partition_filters.push(partition_filter);
+                } else {
+                    continue;
+                }
+            }
+            _ => {
+                continue;
+            }
+        }
+    }
+
+    partition_filters
+}
+
+/// Converts a binary expression (`Expr::BinaryExpr`) into a `PartitionFilter`.
+fn convert_binary_expr(
+    binary_expr: &BinaryExpr,
+    partition_schema: &SchemaRef,
+) -> Option<PartitionFilter> {
+    // Extract the column and literal from the binary expression
+    let (column, literal) = match (&*binary_expr.left, &*binary_expr.right) {
+        (Expr::Column(col), Expr::Literal(lit)) => (col, lit),
+        (Expr::Literal(lit), Expr::Column(col)) => (col, lit),
+        _ => return None,
+    };
+
+    // Bail out (rather than panic) if the column is not part of the schema
+    let field = partition_schema
+        .field_with_name(column.name())
+        .ok()?
+        .clone();
+
+    let operator = match binary_expr.op {
+        Operator::Eq => HudiOperator::Eq,
+        Operator::NotEq => HudiOperator::Ne,
+        Operator::Lt => HudiOperator::Lt,
+        Operator::LtEq => HudiOperator::Lte,
+        Operator::Gt => HudiOperator::Gt,
+        Operator::GtEq => HudiOperator::Gte,
+        _ => return None,
+    };
+
+    let value = match literal.cast_to(field.data_type()) {
+        Ok(value) => {
+            let array_ref: Arc<dyn Array> = value.to_array().ok()?;
+            Scalar::new(array_ref)
+        }
+        Err(_) => return None,
+    };
+
+    Some(PartitionFilter {
+        field,
+        operator,
+        value,
+    })
+}
+
+/// Converts a NOT expression (`Expr::Not`) into a `PartitionFilter`.
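+/// Only a negated binary comparison can be converted; for example, `NOT (col = 42)`
+/// becomes `col != 42` by negating the inner operator. Anything else yields `None`.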
+fn convert_not_expr(not_expr: &Expr, partition_schema: &SchemaRef) -> Option { + match not_expr { + Expr::BinaryExpr(ref binary_expr) => { + let mut partition_filter = convert_binary_expr(binary_expr, partition_schema)?; + partition_filter.operator = negate_operator(partition_filter.operator)?; + Some(partition_filter) + } + _ => None, + } +} + +/// Negates a given operator +fn negate_operator(op: HudiOperator) -> Option { + match op { + HudiOperator::Eq => Some(HudiOperator::Ne), + HudiOperator::Ne => Some(HudiOperator::Eq), + HudiOperator::Lt => Some(HudiOperator::Gte), + HudiOperator::Lte => Some(HudiOperator::Gt), + HudiOperator::Gt => Some(HudiOperator::Lte), + HudiOperator::Gte => Some(HudiOperator::Lt), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{ArrayRef, Float64Array, Int32Array, Int64Array, StringArray}; + use arrow_schema::{DataType, Field, Schema}; + use datafusion::logical_expr::{col, lit}; + use datafusion_expr::{BinaryExpr, Expr}; + use hudi_core::exprs::{HudiOperator, PartitionFilter}; + use std::f64::consts::PI; + use std::sync::Arc; + + #[test] + fn test_convert_simple_binary_expr() { + let partition_schema = Arc::new(Schema::new(vec![Field::new( + "partition_col", + DataType::Int32, + false, + )])); + + let expr = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col("partition_col")), + Operator::Eq, + Box::new(lit(42i32)), + )); + + let filters = vec![expr]; + + let result = convert_exprs_to_filter(&filters, &partition_schema); + + assert_eq!(result.len(), 1); + + let expected_filter = PartitionFilter { + field: partition_schema.field(0).clone(), + operator: HudiOperator::Eq, + value: Scalar::new(Arc::new(Int32Array::from(vec![42])) as ArrayRef), + }; + + assert_eq!(result[0].field, expected_filter.field); + assert_eq!(result[0].operator, expected_filter.operator); + assert_eq!( + *result[0].value.clone().into_inner(), + expected_filter.value.into_inner() + ); + } + + // Tests the conversion of a NOT expression + #[test] + fn test_convert_not_expr() { + let partition_schema = Arc::new(Schema::new(vec![Field::new( + "partition_col", + DataType::Int32, + false, + )])); + + let inner_expr = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col("partition_col")), + Operator::Eq, + Box::new(lit(42i32)), + )); + let expr = Expr::Not(Box::new(inner_expr)); + + let filters = vec![expr]; + + let result = convert_exprs_to_filter(&filters, &partition_schema); + + assert_eq!(result.len(), 1); + + let expected_filter = PartitionFilter { + field: partition_schema.field(0).clone(), + operator: HudiOperator::Ne, + value: Scalar::new(Arc::new(Int32Array::from(vec![42])) as ArrayRef), + }; + + assert_eq!(result[0].field, expected_filter.field); + assert_eq!(result[0].operator, expected_filter.operator); + assert_eq!( + *result[0].value.clone().into_inner(), + expected_filter.value.into_inner() + ); + } + + #[test] + fn test_convert_binary_expr_extensive() { + // partition schema with multiple fields of different data types + let partition_schema = Arc::new(Schema::new(vec![ + Field::new("int32_col", DataType::Int32, false), + Field::new("int64_col", DataType::Int64, false), + Field::new("float64_col", DataType::Float64, false), + Field::new("string_col", DataType::Utf8, false), + ])); + + // list of test cases with different operators and data types + let test_cases = vec![ + ( + col("int32_col").eq(lit(42i32)), + Some(PartitionFilter { + field: partition_schema + .field_with_name("int32_col") + .unwrap() + .clone(), + operator: HudiOperator::Eq, + value: 
Scalar::new(Arc::new(Int32Array::from(vec![42])) as ArrayRef), + }), + ), + ( + col("int64_col").gt_eq(lit(100i64)), + Some(PartitionFilter { + field: partition_schema + .field_with_name("int64_col") + .unwrap() + .clone(), + operator: HudiOperator::Gte, + value: Scalar::new(Arc::new(Int64Array::from(vec![100])) as ArrayRef), + }), + ), + ( + col("float64_col").lt(lit(PI)), + Some(PartitionFilter { + field: partition_schema + .field_with_name("float64_col") + .unwrap() + .clone(), + operator: HudiOperator::Lt, + value: Scalar::new(Arc::new(Float64Array::from(vec![PI])) as ArrayRef), + }), + ), + ( + col("string_col").not_eq(lit("test")), + Some(PartitionFilter { + field: partition_schema + .field_with_name("string_col") + .unwrap() + .clone(), + operator: HudiOperator::Ne, + value: Scalar::new(Arc::new(StringArray::from(vec!["test"])) as ArrayRef), + }), + ), + ]; + + let filters: Vec = test_cases.iter().map(|(expr, _)| expr.clone()).collect(); + let result = convert_exprs_to_filter(&filters, &partition_schema); + let expected_filters: Vec<&PartitionFilter> = test_cases + .iter() + .filter_map(|(_, opt_filter)| opt_filter.as_ref()) + .collect(); + + assert_eq!(result.len(), expected_filters.len()); + + for (converted_filter, expected_filter) in result.iter().zip(expected_filters.iter()) { + assert_eq!(converted_filter.field.name(), expected_filter.field.name()); + assert_eq!(converted_filter.operator, expected_filter.operator); + assert_eq!( + *converted_filter.value.clone().into_inner(), + expected_filter.value.clone().into_inner() + ); + } + } + + // Tests conversion with different operators (e.g., <, <=, >, >=) + #[test] + fn test_convert_various_operators() { + let partition_schema = Arc::new(Schema::new(vec![Field::new( + "partition_col", + DataType::Int32, + false, + )])); + + let operators = vec![ + (Operator::Lt, HudiOperator::Lt), + (Operator::LtEq, HudiOperator::Lte), + (Operator::Gt, HudiOperator::Gt), + (Operator::GtEq, HudiOperator::Gte), + ]; + + for (op, expected_op) in operators { + let expr = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col("partition_col")), + op, + Box::new(lit(42i32)), + )); + + let filters = vec![expr]; + + let result = convert_exprs_to_filter(&filters, &partition_schema); + + assert_eq!(result.len(), 1); + + let expected_filter = PartitionFilter { + field: partition_schema.field(0).clone(), + operator: expected_op, + value: Scalar::new(Arc::new(Int32Array::from(vec![42])) as ArrayRef), + }; + + assert_eq!(result[0].field, expected_filter.field); + assert_eq!(result[0].operator, expected_filter.operator); + assert_eq!( + *result[0].value.clone().into_inner(), + expected_filter.value.into_inner() + ); + } + } +} diff --git a/crates/datafusion/src/utils/mod.rs b/crates/datafusion/src/utils/mod.rs new file mode 100644 index 0000000..7d0f090 --- /dev/null +++ b/crates/datafusion/src/utils/mod.rs @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +pub mod exprs_to_filter; diff --git a/python/src/internal.rs b/python/src/internal.rs index 37201cd..a4fb8e5 100644 --- a/python/src/internal.rs +++ b/python/src/internal.rs @@ -23,15 +23,16 @@ use std::sync::OnceLock; use anyhow::Context; use arrow::pyarrow::ToPyArrow; +use arrow_schema::Schema; +use pyo3::exceptions::PyValueError; use pyo3::{pyclass, pyfunction, pymethods, PyErr, PyObject, PyResult, Python}; use tokio::runtime::Runtime; +use hudi::exprs::PartitionFilter; use hudi::file_group::reader::FileGroupReader; use hudi::file_group::FileSlice; use hudi::table::builder::TableBuilder; use hudi::table::Table; -use hudi::util::convert_vec_to_slice; -use hudi::util::vec_to_slice; #[cfg(not(tarpaulin))] #[derive(Clone, Debug)] @@ -163,10 +164,13 @@ impl HudiTable { filters: Option>, py: Python, ) -> PyResult>> { + let schema: Schema = rt().block_on(self.inner.get_schema())?; + let partition_filters = convert_filters(filters, &schema)?; + py.allow_threads(|| { let file_slices = rt().block_on( self.inner - .get_file_slices_splits(n, vec_to_slice!(filters.unwrap_or_default())), + .get_file_slices_splits(n, partition_filters.as_slice()), )?; Ok(file_slices .iter() @@ -181,11 +185,12 @@ impl HudiTable { filters: Option>, py: Python, ) -> PyResult> { + let schema: Schema = rt().block_on(self.inner.get_schema())?; + let partition_filters = convert_filters(filters, &schema)?; + py.allow_threads(|| { - let file_slices = rt().block_on( - self.inner - .get_file_slices(vec_to_slice!(filters.unwrap_or_default())), - )?; + let file_slices = + rt().block_on(self.inner.get_file_slices(partition_filters.as_slice()))?; Ok(file_slices.iter().map(convert_file_slice).collect()) }) } @@ -201,14 +206,37 @@ impl HudiTable { filters: Option>, py: Python, ) -> PyResult { - rt().block_on( - self.inner - .read_snapshot(vec_to_slice!(filters.unwrap_or_default())), - )? - .to_pyarrow(py) + let schema: Schema = rt().block_on(self.inner.get_schema())?; + let partition_filters = convert_filters(filters, &schema)?; + + rt().block_on(self.inner.read_snapshot(partition_filters.as_slice()))? 
+ .to_pyarrow(py) } } +// Temporary fix +fn convert_filters( + filters: Option>, + partition_schema: &Schema, +) -> PyResult> { + filters + .unwrap_or_default() + .into_iter() + .map(|(field, op, value)| { + PartitionFilter::try_from(( + (field.as_str(), op.as_str(), value.as_str()), + partition_schema, + )) + .map_err(|e| { + PyValueError::new_err(format!( + "Invalid filter ({}, {}, {}): {}", + field, op, value, e + )) + }) + }) + .collect() +} + #[cfg(not(tarpaulin))] #[pyfunction] #[pyo3(signature = (base_uri, hudi_options=None, storage_options=None, options=None))] From 6b72ac1e861ef1fbb218a6722219753317152928 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Fri, 29 Nov 2024 03:47:53 -0500 Subject: [PATCH 2/6] fix filters pushdown --- crates/core/src/exprs/mod.rs | 2 +- crates/datafusion/src/lib.rs | 45 ++++++++++++++++++++++++++++++++---- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/crates/core/src/exprs/mod.rs b/crates/core/src/exprs/mod.rs index 2803dce..6dc0e40 100644 --- a/crates/core/src/exprs/mod.rs +++ b/crates/core/src/exprs/mod.rs @@ -51,7 +51,7 @@ impl Display for HudiOperator { } } -// TODO: Add more operators +// TODO: Add more operators impl HudiOperator { pub const TOKEN_OP_PAIRS: [(&'static str, HudiOperator); 6] = [ ("=", HudiOperator::Eq), diff --git a/crates/datafusion/src/lib.rs b/crates/datafusion/src/lib.rs index 1cab70c..320c10f 100644 --- a/crates/datafusion/src/lib.rs +++ b/crates/datafusion/src/lib.rs @@ -33,6 +33,7 @@ use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::parquet::ParquetExecBuilder; use datafusion::datasource::physical_plan::FileScanConfig; use datafusion::datasource::TableProvider; +use datafusion::logical_expr::Operator; use datafusion::physical_plan::ExecutionPlan; use datafusion_common::config::TableParquetOptions; use datafusion_common::DFSchema; @@ -95,12 +96,42 @@ impl HudiDataSource { } } + // Helper functions until all exprs are supported fn get_input_partitions(&self) -> usize { self.table .hudi_configs .get_or_default(InputPartitions) .to::() } + + fn can_push_down(&self, expr: &Expr) -> bool { + match expr { + Expr::BinaryExpr(binary_expr) => { + let left = &binary_expr.left; + let op = &binary_expr.op; + let right = &binary_expr.right; + self.is_supported_operator(op) + && self.is_supported_operand(left) + && self.is_supported_operand(right) + } + _ => false, + } + } + + fn is_supported_operator(&self, op: &Operator) -> bool { + matches!( + op, + Operator::Eq | Operator::Gt | Operator::Lt | Operator::GtEq | Operator::LtEq + ) + } + + fn is_supported_operand(&self, expr: &Expr) -> bool { + match expr { + Expr::Column(col) => self.schema().field_with_name(&col.name).is_ok(), + Expr::Literal(_) => true, + _ => false, + } + } } #[async_trait] @@ -185,10 +216,16 @@ impl TableProvider for HudiDataSource { &self, filters: &[&Expr], ) -> Result> { - Ok(vec![ - TableProviderFilterPushDown::Unsupported; - filters.len() - ]) + filters + .iter() + .map(|expr| { + if self.can_push_down(expr) { + Ok(TableProviderFilterPushDown::Exact) + } else { + Ok(TableProviderFilterPushDown::Unsupported) + } + }) + .collect() } } From 3125080bdf71ce6763838165c5bd131cbc897002 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Fri, 29 Nov 2024 15:11:45 -0500 Subject: [PATCH 3/6] add code coverage --- crates/core/src/exprs/mod.rs | 10 ++++ crates/datafusion/Cargo.toml | 1 + crates/datafusion/src/lib.rs | 49 ++++++++++++++++++- .../datafusion/src/utils/exprs_to_filter.rs | 23 +++++++-- 4 files 
changed, 78 insertions(+), 5 deletions(-) diff --git a/crates/core/src/exprs/mod.rs b/crates/core/src/exprs/mod.rs index 6dc0e40..0614e0b 100644 --- a/crates/core/src/exprs/mod.rs +++ b/crates/core/src/exprs/mod.rs @@ -94,4 +94,14 @@ mod tests { assert_eq!(HudiOperator::from_str(">=").unwrap(), HudiOperator::Gte); assert!(HudiOperator::from_str("??").is_err()); } + + #[test] + fn test_operator_display() { + assert_eq!(HudiOperator::Eq.to_string(), "="); + assert_eq!(HudiOperator::Ne.to_string(), "!="); + assert_eq!(HudiOperator::Lt.to_string(), "<"); + assert_eq!(HudiOperator::Lte.to_string(), "<="); + assert_eq!(HudiOperator::Gt.to_string(), ">"); + assert_eq!(HudiOperator::Gte.to_string(), ">="); + } } diff --git a/crates/datafusion/Cargo.toml b/crates/datafusion/Cargo.toml index 3a8b1e6..53bc2e4 100644 --- a/crates/datafusion/Cargo.toml +++ b/crates/datafusion/Cargo.toml @@ -30,6 +30,7 @@ repository.workspace = true [dependencies] hudi-core = { version = "0.3.0", path = "../core", features = ["datafusion"] } # arrow +arrow = { workspace = true } arrow-array = { workspace = true } arrow-cast = { workspace = true } arrow-schema = { workspace = true } diff --git a/crates/datafusion/src/lib.rs b/crates/datafusion/src/lib.rs index 320c10f..76a24b7 100644 --- a/crates/datafusion/src/lib.rs +++ b/crates/datafusion/src/lib.rs @@ -302,12 +302,13 @@ mod tests { use super::*; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::{SessionConfig, SessionContext}; - use datafusion_common::{DataFusionError, ScalarValue}; + use datafusion_common::{Column, DataFusionError, ScalarValue}; use std::fs::canonicalize; use std::path::Path; use std::sync::Arc; use url::Url; + use datafusion::logical_expr::BinaryExpr; use hudi_core::config::read::HudiReadConfig::InputPartitions; use hudi_tests::TestTable::{ V6ComplexkeygenHivestyle, V6Empty, V6Nonpartitioned, V6SimplekeygenHivestyleNoMetafields, @@ -516,4 +517,50 @@ mod tests { verify_data_with_replacecommits(&ctx, &sql, test_table.as_ref()).await } } + + #[tokio::test] + async fn test_supports_filters_pushdown() { + let table_provider = + HudiDataSource::new_with_options(V6Nonpartitioned.path().as_str(), empty_options()) + .await + .unwrap(); + + let expr1 = Expr::BinaryExpr(BinaryExpr { + left: Box::new(Expr::Column(Column::from_name("name".to_string()))), + op: Operator::Eq, + right: Box::new(Expr::Literal(ScalarValue::Utf8(Some("Alice".to_string())))), + }); + + let expr2 = Expr::BinaryExpr(BinaryExpr { + left: Box::new(Expr::Column(Column::from_name("intField".to_string()))), + op: Operator::Gt, + right: Box::new(Expr::Literal(ScalarValue::Int32(Some(20000)))), + }); + + let expr3 = Expr::BinaryExpr(BinaryExpr { + left: Box::new(Expr::Column(Column::from_name( + "nonexistent_column".to_string(), + ))), + op: Operator::Eq, + right: Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), + }); + + let expr4 = Expr::BinaryExpr(BinaryExpr { + left: Box::new(Expr::Column(Column::from_name("name".to_string()))), + op: Operator::NotEq, + right: Box::new(Expr::Literal(ScalarValue::Utf8(Some("Diana".to_string())))), + }); + + let expr5 = Expr::Literal(ScalarValue::Int32(Some(10))); + + let filters = vec![&expr1, &expr2, &expr3, &expr4, &expr5]; + let result = table_provider.supports_filters_pushdown(&filters).unwrap(); + + assert_eq!(result.len(), 5); + assert_eq!(result[0], TableProviderFilterPushDown::Exact); // expr1 should be pushed down + assert_eq!(result[1], TableProviderFilterPushDown::Exact); // expr2 should be pushed down + 
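+        // expr3 (nonexistent column), expr4 (`!=` is not in the supported operator
+        // set yet), and expr5 (a bare literal) cannot be pushed down: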
assert_eq!(result[2], TableProviderFilterPushDown::Unsupported); + assert_eq!(result[3], TableProviderFilterPushDown::Unsupported); + assert_eq!(result[4], TableProviderFilterPushDown::Unsupported); + } } diff --git a/crates/datafusion/src/utils/exprs_to_filter.rs b/crates/datafusion/src/utils/exprs_to_filter.rs index 9dd282b..b6388d1 100644 --- a/crates/datafusion/src/utils/exprs_to_filter.rs +++ b/crates/datafusion/src/utils/exprs_to_filter.rs @@ -40,16 +40,12 @@ pub fn convert_exprs_to_filter( Expr::BinaryExpr(binary_expr) => { if let Some(partition_filter) = convert_binary_expr(binary_expr, partition_schema) { partition_filters.push(partition_filter); - } else { - continue; } } Expr::Not(not_expr) => { // Handle NOT expressions if let Some(partition_filter) = convert_not_expr(not_expr, partition_schema) { partition_filters.push(partition_filter); - } else { - continue; } } _ => { @@ -328,4 +324,23 @@ mod tests { ); } } + + #[test] + fn test_convert_expr_with_unsupported_operator() { + let partition_schema = Arc::new(Schema::new(vec![Field::new( + "partition_col", + DataType::Int32, + false, + )])); + + let expr = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col("partition_col")), + Operator::And, + Box::new(lit("value")), + )); + + let filters = vec![expr]; + let result = convert_exprs_to_filter(&filters, &partition_schema); + assert!(result.is_empty()); + } } From 2537a9d9a6adcc28e12717fd6316de7d7293e26c Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Fri, 29 Nov 2024 22:02:35 -0500 Subject: [PATCH 4/6] add not expr --- crates/datafusion/src/lib.rs | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/crates/datafusion/src/lib.rs b/crates/datafusion/src/lib.rs index 76a24b7..d088797 100644 --- a/crates/datafusion/src/lib.rs +++ b/crates/datafusion/src/lib.rs @@ -114,6 +114,10 @@ impl HudiDataSource { && self.is_supported_operand(left) && self.is_supported_operand(right) } + Expr::Not(inner_expr) => { + // Recursively check if the inner expression can be pushed down + self.can_push_down(inner_expr) + } _ => false, } } @@ -553,14 +557,21 @@ mod tests { let expr5 = Expr::Literal(ScalarValue::Int32(Some(10))); - let filters = vec![&expr1, &expr2, &expr3, &expr4, &expr5]; + let expr6 = Expr::Not(Box::new(Expr::BinaryExpr(BinaryExpr { + left: Box::new(Expr::Column(Column::from_name("intField".to_string()))), + op: Operator::Gt, + right: Box::new(Expr::Literal(ScalarValue::Int32(Some(20000)))), + }))); + + let filters = vec![&expr1, &expr2, &expr3, &expr4, &expr5, &expr6]; let result = table_provider.supports_filters_pushdown(&filters).unwrap(); - assert_eq!(result.len(), 5); - assert_eq!(result[0], TableProviderFilterPushDown::Exact); // expr1 should be pushed down - assert_eq!(result[1], TableProviderFilterPushDown::Exact); // expr2 should be pushed down + assert_eq!(result.len(), 6); + assert_eq!(result[0], TableProviderFilterPushDown::Exact); + assert_eq!(result[1], TableProviderFilterPushDown::Exact); assert_eq!(result[2], TableProviderFilterPushDown::Unsupported); assert_eq!(result[3], TableProviderFilterPushDown::Unsupported); assert_eq!(result[4], TableProviderFilterPushDown::Unsupported); + assert_eq!(result[5], TableProviderFilterPushDown::Exact); } } From eb807a7fc391a880e8729bb764cb43ec6c261b8e Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Mon, 2 Dec 2024 00:52:11 -0500 Subject: [PATCH 5/6] small changes --- crates/core/src/exprs/filter.rs | 146 ----------------- crates/core/src/exprs/mod.rs | 64 ++++---- crates/core/src/table/fs_view.rs | 
2 - crates/core/src/table/mod.rs | 20 +-- crates/core/src/table/partition.rs | 150 ++++++++++++++++-- crates/datafusion/src/lib.rs | 2 +- .../datafusion/src/utils/exprs_to_filter.rs | 53 ++++--- python/src/internal.rs | 1 - 8 files changed, 200 insertions(+), 238 deletions(-) diff --git a/crates/core/src/exprs/filter.rs b/crates/core/src/exprs/filter.rs index d3b4328..042f3ce 100644 --- a/crates/core/src/exprs/filter.rs +++ b/crates/core/src/exprs/filter.rs @@ -16,149 +16,3 @@ * specific language governing permissions and limitations * under the License. */ - -use crate::exprs::HudiOperator; - -use anyhow::{Context, Result}; -use arrow_array::{ArrayRef, Scalar, StringArray}; -use arrow_cast::{cast_with_options, CastOptions}; -use arrow_schema::{DataType, Field, Schema}; -use std::str::FromStr; - -/// A partition filter that represents a filter expression for partition pruning. -#[derive(Debug, Clone)] -pub struct PartitionFilter { - pub field: Field, - pub operator: HudiOperator, - pub value: Scalar, -} - -impl TryFrom<((&str, &str, &str), &Schema)> for PartitionFilter { - type Error = anyhow::Error; - - fn try_from((filter, partition_schema): ((&str, &str, &str), &Schema)) -> Result { - let (field_name, operator_str, value_str) = filter; - - let field: &Field = partition_schema - .field_with_name(field_name) - .with_context(|| format!("Field '{}' not found in partition schema", field_name))?; - - let operator = HudiOperator::from_str(operator_str) - .with_context(|| format!("Unsupported operator: {}", operator_str))?; - - let value = &[value_str]; - let value = Self::cast_value(value, field.data_type()) - .with_context(|| format!("Unable to cast {:?} as {:?}", value, field.data_type()))?; - - let field = field.clone(); - Ok(PartitionFilter { - field, - operator, - value, - }) - } -} - -impl PartitionFilter { - pub fn cast_value(value: &[&str; 1], data_type: &DataType) -> Result> { - let cast_options = CastOptions { - safe: false, - format_options: Default::default(), - }; - - let value = StringArray::from(Vec::from(value)); - - Ok(Scalar::new(cast_with_options( - &value, - data_type, - &cast_options, - )?)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::exprs::HudiOperator; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_array::Datum; - use std::str::FromStr; - - fn create_test_schema() -> Schema { - Schema::new(vec![ - Field::new("date", DataType::Date32, false), - Field::new("category", DataType::Utf8, false), - Field::new("count", DataType::Int32, false), - ]) - } - - #[test] - fn test_partition_filter_try_from_valid() { - let schema = create_test_schema(); - let filter_tuple = ("date", "=", "2023-01-01"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_ok()); - let filter = filter.unwrap(); - assert_eq!(filter.field.name(), "date"); - assert_eq!(filter.operator, HudiOperator::Eq); - assert_eq!(filter.value.get().0.len(), 1); - - let filter_tuple = ("category", "!=", "foo"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_ok()); - let filter = filter.unwrap(); - assert_eq!(filter.field.name(), "category"); - assert_eq!(filter.operator, HudiOperator::Ne); - assert_eq!(filter.value.get().0.len(), 1); - assert_eq!( - StringArray::from(filter.value.into_inner().to_data()).value(0), - "foo" - ) - } - - #[test] - fn test_partition_filter_try_from_invalid_field() { - let schema = create_test_schema(); - let filter_tuple = ("invalid_field", "=", "2023-01-01"); - let filter = 
PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_err()); - assert!(filter - .unwrap_err() - .to_string() - .contains("not found in partition schema")); - } - - #[test] - fn test_partition_filter_try_from_invalid_operator() { - let schema = create_test_schema(); - let filter_tuple = ("date", "??", "2023-01-01"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_err()); - assert!(filter - .unwrap_err() - .to_string() - .contains("Unsupported operator: ??")); - } - - #[test] - fn test_partition_filter_try_from_invalid_value() { - let schema = create_test_schema(); - let filter_tuple = ("count", "=", "not_a_number"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_err()); - assert!(filter.unwrap_err().to_string().contains("Unable to cast")); - } - - #[test] - fn test_partition_filter_try_from_all_operators() { - let schema = create_test_schema(); - for (op, _) in HudiOperator::TOKEN_OP_PAIRS { - let filter_tuple = ("count", op, "10"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_ok(), "Failed for operator: {}", op); - let filter = filter.unwrap(); - assert_eq!(filter.field.name(), "count"); - assert_eq!(filter.operator, HudiOperator::from_str(op).unwrap()); - } - } -} diff --git a/crates/core/src/exprs/mod.rs b/crates/core/src/exprs/mod.rs index 0614e0b..063f0f2 100644 --- a/crates/core/src/exprs/mod.rs +++ b/crates/core/src/exprs/mod.rs @@ -24,11 +24,9 @@ use std::cmp::PartialEq; use std::fmt::{Display, Formatter, Result as FmtResult}; use std::str::FromStr; -pub use filter::*; - /// An operator that represents a comparison operation used in a partition filter expression. #[derive(Debug, Clone, Copy, PartialEq)] -pub enum HudiOperator { +pub enum ExprOperator { Eq, Ne, Lt, @@ -37,37 +35,37 @@ pub enum HudiOperator { Gte, } -impl Display for HudiOperator { +impl Display for ExprOperator { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { match self { // Binary Operators - HudiOperator::Eq => write!(f, "="), - HudiOperator::Ne => write!(f, "!="), - HudiOperator::Lt => write!(f, "<"), - HudiOperator::Lte => write!(f, "<="), - HudiOperator::Gt => write!(f, ">"), - HudiOperator::Gte => write!(f, ">="), + ExprOperator::Eq => write!(f, "="), + ExprOperator::Ne => write!(f, "!="), + ExprOperator::Lt => write!(f, "<"), + ExprOperator::Lte => write!(f, "<="), + ExprOperator::Gt => write!(f, ">"), + ExprOperator::Gte => write!(f, ">="), } } } // TODO: Add more operators -impl HudiOperator { - pub const TOKEN_OP_PAIRS: [(&'static str, HudiOperator); 6] = [ - ("=", HudiOperator::Eq), - ("!=", HudiOperator::Ne), - ("<", HudiOperator::Lt), - ("<=", HudiOperator::Lte), - (">", HudiOperator::Gt), - (">=", HudiOperator::Gte), +impl ExprOperator { + pub const TOKEN_OP_PAIRS: [(&'static str, ExprOperator); 6] = [ + ("=", ExprOperator::Eq), + ("!=", ExprOperator::Ne), + ("<", ExprOperator::Lt), + ("<=", ExprOperator::Lte), + (">", ExprOperator::Gt), + (">=", ExprOperator::Gte), ]; } -impl FromStr for HudiOperator { +impl FromStr for ExprOperator { type Err = Error; fn from_str(s: &str) -> Result { - HudiOperator::TOKEN_OP_PAIRS + ExprOperator::TOKEN_OP_PAIRS .iter() .find_map(|&(token, op)| { if token.eq_ignore_ascii_case(s) { @@ -86,22 +84,22 @@ mod tests { #[test] fn test_operator_from_str() { - assert_eq!(HudiOperator::from_str("=").unwrap(), HudiOperator::Eq); - assert_eq!(HudiOperator::from_str("!=").unwrap(), HudiOperator::Ne); - 
assert_eq!(HudiOperator::from_str("<").unwrap(), HudiOperator::Lt); - assert_eq!(HudiOperator::from_str("<=").unwrap(), HudiOperator::Lte); - assert_eq!(HudiOperator::from_str(">").unwrap(), HudiOperator::Gt); - assert_eq!(HudiOperator::from_str(">=").unwrap(), HudiOperator::Gte); - assert!(HudiOperator::from_str("??").is_err()); + assert_eq!(ExprOperator::from_str("=").unwrap(), ExprOperator::Eq); + assert_eq!(ExprOperator::from_str("!=").unwrap(), ExprOperator::Ne); + assert_eq!(ExprOperator::from_str("<").unwrap(), ExprOperator::Lt); + assert_eq!(ExprOperator::from_str("<=").unwrap(), ExprOperator::Lte); + assert_eq!(ExprOperator::from_str(">").unwrap(), ExprOperator::Gt); + assert_eq!(ExprOperator::from_str(">=").unwrap(), ExprOperator::Gte); + assert!(ExprOperator::from_str("??").is_err()); } #[test] fn test_operator_display() { - assert_eq!(HudiOperator::Eq.to_string(), "="); - assert_eq!(HudiOperator::Ne.to_string(), "!="); - assert_eq!(HudiOperator::Lt.to_string(), "<"); - assert_eq!(HudiOperator::Lte.to_string(), "<="); - assert_eq!(HudiOperator::Gt.to_string(), ">"); - assert_eq!(HudiOperator::Gte.to_string(), ">="); + assert_eq!(ExprOperator::Eq.to_string(), "="); + assert_eq!(ExprOperator::Ne.to_string(), "!="); + assert_eq!(ExprOperator::Lt.to_string(), "<"); + assert_eq!(ExprOperator::Lte.to_string(), "<="); + assert_eq!(ExprOperator::Gt.to_string(), ">"); + assert_eq!(ExprOperator::Gte.to_string(), ">="); } } diff --git a/crates/core/src/table/fs_view.rs b/crates/core/src/table/fs_view.rs index 955081b..fff30e1 100644 --- a/crates/core/src/table/fs_view.rs +++ b/crates/core/src/table/fs_view.rs @@ -309,10 +309,8 @@ mod tests { let schema = create_test_schema(); let filter_lt_20 = PartitionFilter::try_from((("byteField", "<", "20"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let filter_eq_300 = PartitionFilter::try_from((("shortField", "=", "300"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let partition_pruner = PartitionPruner::new( &[filter_lt_20, filter_eq_300], diff --git a/crates/core/src/table/mod.rs b/crates/core/src/table/mod.rs index dee36ad..2ad2f45 100644 --- a/crates/core/src/table/mod.rs +++ b/crates/core/src/table/mod.rs @@ -96,17 +96,16 @@ use crate::config::read::HudiReadConfig::AsOfTimestamp; use crate::config::table::HudiTableConfig; use crate::config::table::HudiTableConfig::PartitionFields; use crate::config::HudiConfigs; -use crate::exprs::PartitionFilter; use crate::file_group::reader::FileGroupReader; use crate::file_group::FileSlice; use crate::table::builder::TableBuilder; use crate::table::fs_view::FileSystemView; -use crate::table::partition::PartitionPruner; +use crate::table::partition::{PartitionFilter, PartitionPruner}; use crate::table::timeline::Timeline; pub mod builder; mod fs_view; -mod partition; +pub mod partition; mod timeline; /// Hudi Table in-memory @@ -301,10 +300,8 @@ mod tests { use std::path::PathBuf; use std::{env, panic}; use url::Url; - - use crate::exprs::PartitionFilter; - use hudi_tests::{assert_not, TestTable}; + use crate::table::PartitionFilter; use crate::config::read::HudiReadConfig::AsOfTimestamp; use crate::config::table::HudiTableConfig::{ @@ -735,11 +732,9 @@ mod tests { let schema = create_test_schema(); let filter_ge_10 = PartitionFilter::try_from((("byteField", ">=", "10"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let filter_lt_30 = PartitionFilter::try_from((("byteField", "<", 
"30"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let actual = hudi_table @@ -758,7 +753,6 @@ mod tests { assert_eq!(actual, expected); let filter_gt_30 = PartitionFilter::try_from((("byteField", ">", "30"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let actual = hudi_table .get_file_paths_with_filters(&[filter_gt_30]) @@ -795,13 +789,10 @@ mod tests { let schema = create_test_schema(); let filter_gte_10 = PartitionFilter::try_from((("byteField", ">=", "10"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let filter_lt_20 = PartitionFilter::try_from((("byteField", "<", "20"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let filter_ne_100 = PartitionFilter::try_from((("shortField", "!=", "100"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let actual = hudi_table @@ -818,10 +809,8 @@ mod tests { .collect::>(); assert_eq!(actual, expected); let filter_lt_20 = PartitionFilter::try_from((("byteField", ">", "20"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let filter_eq_300 = PartitionFilter::try_from((("shortField", "=", "300"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let actual = hudi_table @@ -841,13 +830,10 @@ mod tests { let schema = create_test_schema(); let filter_gte_10 = PartitionFilter::try_from((("byteField", ">=", "10"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let filter_lt_20 = PartitionFilter::try_from((("byteField", "<", "20"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let filter_ne_100 = PartitionFilter::try_from((("shortField", "!=", "100"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let records = hudi_table diff --git a/crates/core/src/table/partition.rs b/crates/core/src/table/partition.rs index 7218097..7774a36 100644 --- a/crates/core/src/table/partition.rs +++ b/crates/core/src/table/partition.rs @@ -18,12 +18,18 @@ */ use crate::config::table::HudiTableConfig; use crate::config::HudiConfigs; -use crate::exprs::{HudiOperator, PartitionFilter}; +use crate::exprs::ExprOperator; use anyhow::anyhow; use anyhow::Result; use arrow_array::{ArrayRef, Scalar}; use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq}; use arrow_schema::Schema; +use anyhow::Context; +use arrow_array::StringArray; +use arrow_cast::{cast_with_options, CastOptions}; +use arrow_schema::{DataType, Field}; +use std::str::FromStr; + use std::collections::HashMap; use std::sync::Arc; @@ -86,12 +92,12 @@ impl PartitionPruner { match segments.get(filter.field.name()) { Some(segment_value) => { let comparison_result = match filter.operator { - HudiOperator::Eq => eq(segment_value, &filter.value), - HudiOperator::Ne => neq(segment_value, &filter.value), - HudiOperator::Lt => lt(segment_value, &filter.value), - HudiOperator::Lte => lt_eq(segment_value, &filter.value), - HudiOperator::Gt => gt(segment_value, &filter.value), - HudiOperator::Gte => gt_eq(segment_value, &filter.value), + ExprOperator::Eq => eq(segment_value, &filter.value), + ExprOperator::Ne => neq(segment_value, &filter.value), + ExprOperator::Lt => lt(segment_value, &filter.value), + ExprOperator::Lte => lt_eq(segment_value, &filter.value), + ExprOperator::Gt => gt(segment_value, &filter.value), + 
ExprOperator::Gte => gt_eq(segment_value, &filter.value), }; match comparison_result { @@ -150,6 +156,57 @@ impl PartitionPruner { } } +/// A partition filter that represents a filter expression for partition pruning. +#[derive(Debug, Clone)] +pub struct PartitionFilter { + pub field: Field, + pub operator: ExprOperator, + pub value: Scalar, +} + +impl TryFrom<((&str, &str, &str), &Schema)> for PartitionFilter { + type Error = anyhow::Error; + + fn try_from((filter, partition_schema): ((&str, &str, &str), &Schema)) -> Result { + let (field_name, operator_str, value_str) = filter; + + let field: &Field = partition_schema + .field_with_name(field_name) + .with_context(|| format!("Field '{}' not found in partition schema", field_name))?; + + let operator = ExprOperator::from_str(operator_str) + .with_context(|| format!("Unsupported operator: {}", operator_str))?; + + let value = &[value_str]; + let value = Self::cast_value(value, field.data_type()) + .with_context(|| format!("Unable to cast {:?} as {:?}", value, field.data_type()))?; + + let field = field.clone(); + Ok(PartitionFilter { + field, + operator, + value, + }) + } +} + +impl PartitionFilter { + pub fn cast_value(value: &[&str; 1], data_type: &DataType) -> Result> { + let cast_options = CastOptions { + safe: false, + format_options: Default::default(), + }; + + let value = StringArray::from(Vec::from(value)); + + Ok(Scalar::new(cast_with_options( + &value, + data_type, + &cast_options, + )?)) + } +} + #[cfg(test)] mod tests { use super::*; @@ -158,6 +215,9 @@ mod tests { }; use arrow::datatypes::{DataType, Field, Schema}; use hudi_tests::assert_not; + use crate::exprs::ExprOperator; + use arrow_array::Datum; + use std::str::FromStr; fn create_test_schema() -> Schema { Schema::new(vec![ @@ -179,10 +239,8 @@ mod tests { let configs = create_hudi_configs(true, false); let filter_gt_date = PartitionFilter::try_from((("date", ">", "2023-01-01"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let filter_eq_a = PartitionFilter::try_from((("category", "=", "A"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let pruner = PartitionPruner::new(&[filter_gt_date, filter_eq_a], &schema, &configs); @@ -211,7 +269,6 @@ mod tests { assert!(pruner_empty.is_empty()); let filter_gt_date = PartitionFilter::try_from((("date", ">", "2023-01-01"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let pruner_non_empty = PartitionPruner::new(&[filter_gt_date], &schema, &configs).unwrap(); assert_not!(pruner_non_empty.is_empty()); @@ -223,13 +280,10 @@ mod tests { let configs = create_hudi_configs(true, false); let filter_gt_date = PartitionFilter::try_from((("date", ">", "2023-01-01"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let filter_eq_a = PartitionFilter::try_from((("category", "=", "A"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let filter_lte_100 = PartitionFilter::try_from((("count", "<=", "100"), &schema)) - .map_err(|e| anyhow!("Failed to create PartitionFilter: {}", e)) .unwrap(); let pruner = PartitionPruner::new( @@ -283,4 +337,74 @@ mod tests { assert!(pruner.parse_segments("invalid/path").is_err()); } + + #[test] + fn test_partition_filter_try_from_valid() { + let schema = create_test_schema(); + let filter_tuple = ("date", "=", "2023-01-01"); + let filter = PartitionFilter::try_from((filter_tuple, &schema)); + 
assert!(filter.is_ok()); + let filter = filter.unwrap(); + assert_eq!(filter.field.name(), "date"); + assert_eq!(filter.operator, ExprOperator::Eq); + assert_eq!(filter.value.get().0.len(), 1); + + let filter_tuple = ("category", "!=", "foo"); + let filter = PartitionFilter::try_from((filter_tuple, &schema)); + assert!(filter.is_ok()); + let filter = filter.unwrap(); + assert_eq!(filter.field.name(), "category"); + assert_eq!(filter.operator, ExprOperator::Ne); + assert_eq!(filter.value.get().0.len(), 1); + assert_eq!( + StringArray::from(filter.value.into_inner().to_data()).value(0), + "foo" + ) + } + + #[test] + fn test_partition_filter_try_from_invalid_field() { + let schema = create_test_schema(); + let filter_tuple = ("invalid_field", "=", "2023-01-01"); + let filter = PartitionFilter::try_from((filter_tuple, &schema)); + assert!(filter.is_err()); + assert!(filter + .unwrap_err() + .to_string() + .contains("not found in partition schema")); + } + + #[test] + fn test_partition_filter_try_from_invalid_operator() { + let schema = create_test_schema(); + let filter_tuple = ("date", "??", "2023-01-01"); + let filter = PartitionFilter::try_from((filter_tuple, &schema)); + assert!(filter.is_err()); + assert!(filter + .unwrap_err() + .to_string() + .contains("Unsupported operator: ??")); + } + + #[test] + fn test_partition_filter_try_from_invalid_value() { + let schema = create_test_schema(); + let filter_tuple = ("count", "=", "not_a_number"); + let filter = PartitionFilter::try_from((filter_tuple, &schema)); + assert!(filter.is_err()); + assert!(filter.unwrap_err().to_string().contains("Unable to cast")); + } + + #[test] + fn test_partition_filter_try_from_all_operators() { + let schema = create_test_schema(); + for (op, _) in ExprOperator::TOKEN_OP_PAIRS { + let filter_tuple = ("count", op, "10"); + let filter = PartitionFilter::try_from((filter_tuple, &schema)); + assert!(filter.is_ok(), "Failed for operator: {}", op); + let filter = filter.unwrap(); + assert_eq!(filter.field.name(), "count"); + assert_eq!(filter.operator, ExprOperator::from_str(op).unwrap()); + } + } } diff --git a/crates/datafusion/src/lib.rs b/crates/datafusion/src/lib.rs index d088797..fc14f4d 100644 --- a/crates/datafusion/src/lib.rs +++ b/crates/datafusion/src/lib.rs @@ -224,7 +224,7 @@ impl TableProvider for HudiDataSource { .iter() .map(|expr| { if self.can_push_down(expr) { - Ok(TableProviderFilterPushDown::Exact) + Ok(TableProviderFilterPushDown::Inexact) } else { Ok(TableProviderFilterPushDown::Unsupported) } diff --git a/crates/datafusion/src/utils/exprs_to_filter.rs b/crates/datafusion/src/utils/exprs_to_filter.rs index b6388d1..0f9b5ed 100644 --- a/crates/datafusion/src/utils/exprs_to_filter.rs +++ b/crates/datafusion/src/utils/exprs_to_filter.rs @@ -19,9 +19,11 @@ use arrow_array::{Array, Scalar}; use arrow_schema::SchemaRef; +use datafusion_common::DataFusionError; use datafusion::logical_expr::Operator; use datafusion_expr::{BinaryExpr, Expr}; -use hudi_core::exprs::{HudiOperator, PartitionFilter}; +use hudi_core::exprs::ExprOperator; +use hudi_core::table::partition::PartitionFilter; use std::sync::Arc; // TODO: Handle other Datafusion `Expr` @@ -71,16 +73,17 @@ fn convert_binary_expr( let field = partition_schema .field_with_name(column.name()) + .map_err(|e| DataFusionError::Plan(format!("Error finding field with name '{}': {}", column.name(), e))) .unwrap() .clone(); let operator = match binary_expr.op { - Operator::Eq => HudiOperator::Eq, - Operator::NotEq => HudiOperator::Ne, - Operator::Lt => 
HudiOperator::Lt, - Operator::LtEq => HudiOperator::Lte, - Operator::Gt => HudiOperator::Gt, - Operator::GtEq => HudiOperator::Gte, + Operator::Eq => ExprOperator::Eq, + Operator::NotEq => ExprOperator::Ne, + Operator::Lt => ExprOperator::Lt, + Operator::LtEq => ExprOperator::Lte, + Operator::Gt => ExprOperator::Gt, + Operator::GtEq => ExprOperator::Gte, _ => return None, }; @@ -112,14 +115,14 @@ fn convert_not_expr(not_expr: &Expr, partition_schema: &SchemaRef) -> Option Option { +fn negate_operator(op: ExprOperator) -> Option { match op { - HudiOperator::Eq => Some(HudiOperator::Ne), - HudiOperator::Ne => Some(HudiOperator::Eq), - HudiOperator::Lt => Some(HudiOperator::Gte), - HudiOperator::Lte => Some(HudiOperator::Gt), - HudiOperator::Gt => Some(HudiOperator::Lte), - HudiOperator::Gte => Some(HudiOperator::Lt), + ExprOperator::Eq => Some(ExprOperator::Ne), + ExprOperator::Ne => Some(ExprOperator::Eq), + ExprOperator::Lt => Some(ExprOperator::Gte), + ExprOperator::Lte => Some(ExprOperator::Gt), + ExprOperator::Gt => Some(ExprOperator::Lte), + ExprOperator::Gte => Some(ExprOperator::Lt), } } @@ -130,7 +133,7 @@ mod tests { use arrow_schema::{DataType, Field, Schema}; use datafusion::logical_expr::{col, lit}; use datafusion_expr::{BinaryExpr, Expr}; - use hudi_core::exprs::{HudiOperator, PartitionFilter}; + use hudi_core::exprs::ExprOperator; use std::f64::consts::PI; use std::sync::Arc; @@ -156,7 +159,7 @@ mod tests { let expected_filter = PartitionFilter { field: partition_schema.field(0).clone(), - operator: HudiOperator::Eq, + operator: ExprOperator::Eq, value: Scalar::new(Arc::new(Int32Array::from(vec![42])) as ArrayRef), }; @@ -192,7 +195,7 @@ mod tests { let expected_filter = PartitionFilter { field: partition_schema.field(0).clone(), - operator: HudiOperator::Ne, + operator: ExprOperator::Ne, value: Scalar::new(Arc::new(Int32Array::from(vec![42])) as ArrayRef), }; @@ -223,7 +226,7 @@ mod tests { .field_with_name("int32_col") .unwrap() .clone(), - operator: HudiOperator::Eq, + operator: ExprOperator::Eq, value: Scalar::new(Arc::new(Int32Array::from(vec![42])) as ArrayRef), }), ), @@ -234,7 +237,7 @@ mod tests { .field_with_name("int64_col") .unwrap() .clone(), - operator: HudiOperator::Gte, + operator: ExprOperator::Gte, value: Scalar::new(Arc::new(Int64Array::from(vec![100])) as ArrayRef), }), ), @@ -245,7 +248,7 @@ mod tests { .field_with_name("float64_col") .unwrap() .clone(), - operator: HudiOperator::Lt, + operator: ExprOperator::Lt, value: Scalar::new(Arc::new(Float64Array::from(vec![PI])) as ArrayRef), }), ), @@ -256,7 +259,7 @@ mod tests { .field_with_name("string_col") .unwrap() .clone(), - operator: HudiOperator::Ne, + operator: ExprOperator::Ne, value: Scalar::new(Arc::new(StringArray::from(vec!["test"])) as ArrayRef), }), ), @@ -291,10 +294,10 @@ mod tests { )])); let operators = vec![ - (Operator::Lt, HudiOperator::Lt), - (Operator::LtEq, HudiOperator::Lte), - (Operator::Gt, HudiOperator::Gt), - (Operator::GtEq, HudiOperator::Gte), + (Operator::Lt, ExprOperator::Lt), + (Operator::LtEq, ExprOperator::Lte), + (Operator::Gt, ExprOperator::Gt), + (Operator::GtEq, ExprOperator::Gte), ]; for (op, expected_op) in operators { diff --git a/python/src/internal.rs b/python/src/internal.rs index a4fb8e5..82abb54 100644 --- a/python/src/internal.rs +++ b/python/src/internal.rs @@ -214,7 +214,6 @@ impl HudiTable { } } -// Temporary fix fn convert_filters( filters: Option>, partition_schema: &Schema, From 8d7aa6268e4305ee0ef21e7c02c712a89e77e327 Mon Sep 17 00:00:00 2001 From: 
Jonathan Chen Date: Mon, 2 Dec 2024 22:24:05 -0500 Subject: [PATCH 6/6] fixes --- crates/core/src/exprs/filter.rs | 53 +++++ crates/core/src/table/fs_view.rs | 20 +- crates/core/src/table/mod.rs | 69 ++---- crates/core/src/table/partition.rs | 144 ++++++------ crates/datafusion/src/lib.rs | 12 +- .../datafusion/src/utils/exprs_to_filter.rs | 217 +++++++----------- python/src/internal.rs | 34 +-- python/tests/test_table_read.py | 14 ++ 8 files changed, 259 insertions(+), 304 deletions(-) diff --git a/crates/core/src/exprs/filter.rs b/crates/core/src/exprs/filter.rs index 042f3ce..4b082b3 100644 --- a/crates/core/src/exprs/filter.rs +++ b/crates/core/src/exprs/filter.rs @@ -16,3 +16,56 @@ * specific language governing permissions and limitations * under the License. */ + +use crate::exprs::ExprOperator; +use anyhow::{Context, Result}; +use arrow_array::StringArray; +use arrow_array::{ArrayRef, Scalar}; +use arrow_cast::{cast_with_options, CastOptions}; +use arrow_schema::DataType; +use std::str::FromStr; + +#[derive(Debug, Clone)] +pub struct Filter { + pub field_name: String, + pub operator: ExprOperator, + pub value: String, +} + +impl Filter { + pub fn cast_value(value: &[&str; 1], data_type: &DataType) -> Result> { + let cast_options = CastOptions { + safe: false, + format_options: Default::default(), + }; + + let value = StringArray::from(Vec::from(value)); + + Ok(Scalar::new(cast_with_options( + &value, + data_type, + &cast_options, + )?)) + } +} + +impl TryFrom<(&str, &str, &str)> for Filter { + type Error = anyhow::Error; + + fn try_from(filter: (&str, &str, &str)) -> Result { + let (field_name, operator_str, value_str) = filter; + + let field_name = field_name.to_string(); + + let operator = ExprOperator::from_str(operator_str) + .with_context(|| format!("Unsupported operator: {}", operator_str))?; + + let value = value_str.to_string(); + + Ok(Filter { + field_name, + operator, + value, + }) + } +} diff --git a/crates/core/src/table/fs_view.rs b/crates/core/src/table/fs_view.rs index fff30e1..e3ba47c 100644 --- a/crates/core/src/table/fs_view.rs +++ b/crates/core/src/table/fs_view.rs @@ -178,25 +178,17 @@ impl FileSystemView { mod tests { use crate::config::table::HudiTableConfig; use crate::config::HudiConfigs; + use crate::exprs::filter::Filter; use crate::storage::Storage; use crate::table::fs_view::FileSystemView; use crate::table::partition::PartitionPruner; - use crate::table::{PartitionFilter, Table}; + use crate::table::Table; - use anyhow::anyhow; - use arrow::datatypes::{DataType, Field, Schema}; use hudi_tests::TestTable; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use url::Url; - fn create_test_schema() -> Schema { - Schema::new(vec![ - Field::new("byteField", DataType::Int32, false), - Field::new("shortField", DataType::Int32, false), - ]) - } - async fn create_test_fs_view(base_url: Url) -> FileSystemView { FileSystemView::new( Arc::new(HudiConfigs::new([(HudiTableConfig::BasePath, base_url)])), @@ -307,17 +299,15 @@ mod tests { .unwrap(); let partition_schema = hudi_table.get_partition_schema().await.unwrap(); - let schema = create_test_schema(); - let filter_lt_20 = PartitionFilter::try_from((("byteField", "<", "20"), &schema)) - .unwrap(); - let filter_eq_300 = PartitionFilter::try_from((("shortField", "=", "300"), &schema)) - .unwrap(); + let filter_lt_20 = Filter::try_from(("byteField", "<", "20")).unwrap(); + let filter_eq_300 = Filter::try_from(("shortField", "=", "300")).unwrap(); let partition_pruner = PartitionPruner::new( &[filter_lt_20, 
filter_eq_300], &partition_schema, hudi_table.hudi_configs.as_ref(), ) .unwrap(); + let file_slices = fs_view .get_file_slices_as_of("20240418173235694", &partition_pruner, excludes) .await diff --git a/crates/core/src/table/mod.rs b/crates/core/src/table/mod.rs index 2ad2f45..3bbea8f 100644 --- a/crates/core/src/table/mod.rs +++ b/crates/core/src/table/mod.rs @@ -96,11 +96,12 @@ use crate::config::read::HudiReadConfig::AsOfTimestamp; use crate::config::table::HudiTableConfig; use crate::config::table::HudiTableConfig::PartitionFields; use crate::config::HudiConfigs; +use crate::exprs::filter::Filter; use crate::file_group::reader::FileGroupReader; use crate::file_group::FileSlice; use crate::table::builder::TableBuilder; use crate::table::fs_view::FileSystemView; -use crate::table::partition::{PartitionFilter, PartitionPruner}; +use crate::table::partition::PartitionPruner; use crate::table::timeline::Timeline; pub mod builder; @@ -191,10 +192,11 @@ impl Table { /// The file slices are split into `n` chunks. /// /// If the [AsOfTimestamp] configuration is set, the file slices at the specified timestamp will be returned. + /// pub async fn get_file_slices_splits( &self, n: usize, - filters: &[PartitionFilter], + filters: &[Filter], ) -> Result>> { let file_slices = self.get_file_slices(filters).await?; if file_slices.is_empty() { @@ -213,7 +215,7 @@ impl Table { /// Get all the [FileSlice]s in the table. /// /// If the [AsOfTimestamp] configuration is set, the file slices at the specified timestamp will be returned. - pub async fn get_file_slices(&self, filters: &[PartitionFilter]) -> Result> { + pub async fn get_file_slices(&self, filters: &[Filter]) -> Result> { if let Some(timestamp) = self.hudi_configs.try_get(AsOfTimestamp) { self.get_file_slices_as_of(timestamp.to::().as_str(), filters) .await @@ -228,7 +230,7 @@ impl Table { async fn get_file_slices_as_of( &self, timestamp: &str, - filters: &[PartitionFilter], + filters: &[Filter], ) -> Result> { let excludes = self.timeline.get_replaced_file_groups().await?; let partition_schema = self.get_partition_schema().await?; @@ -242,7 +244,7 @@ impl Table { /// Get all the latest records in the table. /// /// If the [AsOfTimestamp] configuration is set, the records at the specified timestamp will be returned. - pub async fn read_snapshot(&self, filters: &[PartitionFilter]) -> Result> { + pub async fn read_snapshot(&self, filters: &[Filter]) -> Result> { if let Some(timestamp) = self.hudi_configs.try_get(AsOfTimestamp) { self.read_snapshot_as_of(timestamp.to::().as_str(), filters) .await @@ -257,7 +259,7 @@ impl Table { async fn read_snapshot_as_of( &self, timestamp: &str, - filters: &[PartitionFilter], + filters: &[Filter], ) -> Result> { let file_slices = self .get_file_slices_as_of(timestamp, filters) @@ -275,10 +277,7 @@ impl Table { } #[cfg(test)] - async fn get_file_paths_with_filters( - &self, - filters: &[PartitionFilter], - ) -> Result> { + async fn get_file_paths_with_filters(&self, filters: &[Filter]) -> Result> { let mut file_paths = Vec::new(); for f in self.get_file_slices(filters).await? 
{ file_paths.push(f.base_file_path().to_string()); @@ -293,15 +292,14 @@ impl Table { #[cfg(test)] mod tests { - use arrow::datatypes::{DataType, Field, Schema}; + use crate::table::Filter; use arrow_array::StringArray; + use hudi_tests::{assert_not, TestTable}; use std::collections::HashSet; use std::fs::canonicalize; use std::path::PathBuf; use std::{env, panic}; use url::Url; - use hudi_tests::{assert_not, TestTable}; - use crate::table::PartitionFilter; use crate::config::read::HudiReadConfig::AsOfTimestamp; use crate::config::table::HudiTableConfig::{ @@ -313,14 +311,7 @@ mod tests { use crate::config::HUDI_CONF_DIR; use crate::storage::utils::join_url_segments; use crate::storage::Storage; - use crate::table::{anyhow, Table}; - - fn create_test_schema() -> Schema { - Schema::new(vec![ - Field::new("byteField", DataType::Int32, false), - Field::new("shortField", DataType::Int32, false), - ]) - } + use crate::table::Table; /// Test helper to create a new `Table` instance without validating the configuration. /// @@ -730,12 +721,9 @@ mod tests { .collect::>(); assert_eq!(actual, expected); - let schema = create_test_schema(); - let filter_ge_10 = PartitionFilter::try_from((("byteField", ">=", "10"), &schema)) - .unwrap(); + let filter_ge_10 = Filter::try_from(("byteField", ">=", "10")).unwrap(); - let filter_lt_30 = PartitionFilter::try_from((("byteField", "<", "30"), &schema)) - .unwrap(); + let filter_lt_30 = Filter::try_from(("byteField", "<", "30")).unwrap(); let actual = hudi_table .get_file_paths_with_filters(&[filter_ge_10, filter_lt_30]) @@ -752,8 +740,7 @@ mod tests { .collect::>(); assert_eq!(actual, expected); - let filter_gt_30 = PartitionFilter::try_from((("byteField", ">", "30"), &schema)) - .unwrap(); + let filter_gt_30 = Filter::try_from(("byteField", ">", "30")).unwrap(); let actual = hudi_table .get_file_paths_with_filters(&[filter_gt_30]) .await @@ -787,13 +774,9 @@ mod tests { .collect::>(); assert_eq!(actual, expected); - let schema = create_test_schema(); - let filter_gte_10 = PartitionFilter::try_from((("byteField", ">=", "10"), &schema)) - .unwrap(); - let filter_lt_20 = PartitionFilter::try_from((("byteField", "<", "20"), &schema)) - .unwrap(); - let filter_ne_100 = PartitionFilter::try_from((("shortField", "!=", "100"), &schema)) - .unwrap(); + let filter_gte_10 = Filter::try_from(("byteField", ">=", "10")).unwrap(); + let filter_lt_20 = Filter::try_from(("byteField", "<", "20")).unwrap(); + let filter_ne_100 = Filter::try_from(("shortField", "!=", "100")).unwrap(); let actual = hudi_table .get_file_paths_with_filters(&[filter_gte_10, filter_lt_20, filter_ne_100]) @@ -808,10 +791,8 @@ mod tests { .into_iter() .collect::>(); assert_eq!(actual, expected); - let filter_lt_20 = PartitionFilter::try_from((("byteField", ">", "20"), &schema)) - .unwrap(); - let filter_eq_300 = PartitionFilter::try_from((("shortField", "=", "300"), &schema)) - .unwrap(); + let filter_lt_20 = Filter::try_from(("byteField", ">", "20")).unwrap(); + let filter_eq_300 = Filter::try_from(("shortField", "=", "300")).unwrap(); let actual = hudi_table .get_file_paths_with_filters(&[filter_lt_20, filter_eq_300]) @@ -828,13 +809,9 @@ mod tests { let base_url = TestTable::V6ComplexkeygenHivestyle.url(); let hudi_table = Table::new(base_url.path()).await.unwrap(); - let schema = create_test_schema(); - let filter_gte_10 = PartitionFilter::try_from((("byteField", ">=", "10"), &schema)) - .unwrap(); - let filter_lt_20 = PartitionFilter::try_from((("byteField", "<", "20"), &schema)) - .unwrap(); - let 
filter_ne_100 = PartitionFilter::try_from((("shortField", "!=", "100"), &schema)) - .unwrap(); + let filter_gte_10 = Filter::try_from(("byteField", ">=", "10")).unwrap(); + let filter_lt_20 = Filter::try_from(("byteField", "<", "20")).unwrap(); + let filter_ne_100 = Filter::try_from(("shortField", "!=", "100")).unwrap(); let records = hudi_table .read_snapshot(&[filter_gte_10, filter_lt_20, filter_ne_100]) diff --git a/crates/core/src/table/partition.rs b/crates/core/src/table/partition.rs index 7774a36..7f7d046 100644 --- a/crates/core/src/table/partition.rs +++ b/crates/core/src/table/partition.rs @@ -18,18 +18,17 @@ */ use crate::config::table::HudiTableConfig; use crate::config::HudiConfigs; +use crate::exprs::filter::Filter; use crate::exprs::ExprOperator; use anyhow::anyhow; +use anyhow::Context; use anyhow::Result; +use arrow_array::StringArray; use arrow_array::{ArrayRef, Scalar}; +use arrow_cast::{cast_with_options, CastOptions}; use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq}; use arrow_schema::Schema; -use anyhow::Context; -use arrow_array::StringArray; -use arrow_cast::{cast_with_options, CastOptions}; use arrow_schema::{DataType, Field}; -use std::str::FromStr; - use std::collections::HashMap; use std::sync::Arc; @@ -45,11 +44,14 @@ pub struct PartitionPruner { impl PartitionPruner { pub fn new( - and_filters: &[PartitionFilter], + and_filters: &[Filter], partition_schema: &Schema, hudi_configs: &HudiConfigs, ) -> Result<Self> { - let and_filters = and_filters.to_vec(); + let and_filters = and_filters + .iter() + .map(|filter| PartitionFilter::try_from((filter.clone(), partition_schema))) + .collect::<Result<Vec<PartitionFilter>>>()?; let schema = Arc::new(partition_schema.clone()); let is_hive_style: bool = hudi_configs @@ -164,20 +166,21 @@ pub struct PartitionFilter { pub value: Scalar<ArrayRef>, } -impl TryFrom<((&str, &str, &str), &Schema)> for PartitionFilter { +impl TryFrom<(Filter, &Schema)> for PartitionFilter { type Error = anyhow::Error; - fn try_from((filter, partition_schema): ((&str, &str, &str), &Schema)) -> Result<Self> { - let (field_name, operator_str, value_str) = filter; - + fn try_from((filter, partition_schema): (Filter, &Schema)) -> Result<Self> { let field: &Field = partition_schema - .field_with_name(field_name) - .with_context(|| format!("Field '{}' not found in partition schema", field_name))?; + .field_with_name(&filter.field_name) + .with_context(|| { + format!( + "Field '{}' not found in partition schema", + &filter.field_name + ) + })?; + + let operator = filter.operator; + let value = &[filter.value.as_str()]; let value = Self::cast_value(value, field.data_type()) .with_context(|| format!("Unable to cast {:?} as {:?}", value, field.data_type()))?; @@ -213,10 +217,10 @@ mod tests { use crate::config::table::HudiTableConfig::{ IsHiveStylePartitioning, IsPartitionPathUrlencoded, }; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_array::Date32Array; use hudi_tests::assert_not; - use crate::exprs::ExprOperator; - use arrow_array::Datum; use std::str::FromStr; fn create_test_schema() -> Schema { @@ -238,10 +242,8 @@ let schema = create_test_schema(); let configs = create_hudi_configs(true, false); - let filter_gt_date = PartitionFilter::try_from((("date", ">", "2023-01-01"), &schema)) - .unwrap(); - let filter_eq_a = PartitionFilter::try_from((("category", "=", "A"), &schema)) - .unwrap();
+ let filter_gt_date = Filter::try_from(("date", ">", "2023-01-01")).unwrap(); + let filter_eq_a = Filter::try_from(("category", "=", "A")).unwrap(); let pruner = PartitionPruner::new(&[filter_gt_date, filter_eq_a], &schema, &configs); assert!(pruner.is_ok()); @@ -268,8 +270,7 @@ let pruner_empty = PartitionPruner::new(&[], &schema, &configs).unwrap(); assert!(pruner_empty.is_empty()); - let filter_gt_date = PartitionFilter::try_from((("date", ">", "2023-01-01"), &schema)) - .unwrap(); + let filter_gt_date = Filter::try_from(("date", ">", "2023-01-01")).unwrap(); let pruner_non_empty = PartitionPruner::new(&[filter_gt_date], &schema, &configs).unwrap(); assert_not!(pruner_non_empty.is_empty()); } @@ -279,12 +280,9 @@ let schema = create_test_schema(); let configs = create_hudi_configs(true, false); - let filter_gt_date = PartitionFilter::try_from((("date", ">", "2023-01-01"), &schema)) - .unwrap(); - let filter_eq_a = PartitionFilter::try_from((("category", "=", "A"), &schema)) - .unwrap(); - let filter_lte_100 = PartitionFilter::try_from((("count", "<=", "100"), &schema)) - .unwrap(); + let filter_gt_date = Filter::try_from(("date", ">", "2023-01-01")).unwrap(); + let filter_eq_a = Filter::try_from(("category", "=", "A")).unwrap(); + let filter_lte_100 = Filter::try_from(("count", "<=", "100")).unwrap(); let pruner = PartitionPruner::new( &[filter_gt_date, filter_eq_a, filter_lte_100], @@ -341,68 +339,64 @@ #[test] fn test_partition_filter_try_from_valid() { let schema = create_test_schema(); - let filter_tuple = ("date", "=", "2023-01-01"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_ok()); - let filter = filter.unwrap(); - assert_eq!(filter.field.name(), "date"); - assert_eq!(filter.operator, ExprOperator::Eq); - assert_eq!(filter.value.get().0.len(), 1); - - let filter_tuple = ("category", "!=", "foo"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_ok()); - let filter = filter.unwrap(); - assert_eq!(filter.field.name(), "category"); - assert_eq!(filter.operator, ExprOperator::Ne); - assert_eq!(filter.value.get().0.len(), 1); - assert_eq!( - StringArray::from(filter.value.into_inner().to_data()).value(0), - "foo" - ) + let filter = Filter { + field_name: "date".to_string(), + operator: ExprOperator::Eq, + value: "2023-01-01".to_string(), + }; + + let partition_filter = PartitionFilter::try_from((filter, &schema)).unwrap(); + assert_eq!(partition_filter.field.name(), "date"); + assert_eq!(partition_filter.operator, ExprOperator::Eq); + + let value_inner = partition_filter.value.into_inner(); + + let date_array = value_inner.as_any().downcast_ref::<Date32Array>().unwrap(); + + let date_value = date_array.value_as_date(0).unwrap(); + assert_eq!(date_value.to_string(), "2023-01-01"); } #[test] fn test_partition_filter_try_from_invalid_field() { let schema = create_test_schema(); - let filter_tuple = ("invalid_field", "=", "2023-01-01"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_err()); - assert!(filter + let filter = Filter { + field_name: "invalid_field".to_string(), + operator: ExprOperator::Eq, + value: "2023-01-01".to_string(), + }; + let result = PartitionFilter::try_from((filter, &schema)); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("not found in partition schema")); } - #[test] - fn test_partition_filter_try_from_invalid_operator() { - let schema = create_test_schema(); - let filter_tuple =
("date", "??", "2023-01-01"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_err()); - assert!(filter - .unwrap_err() - .to_string() - .contains("Unsupported operator: ??")); - } - #[test] fn test_partition_filter_try_from_invalid_value() { let schema = create_test_schema(); - let filter_tuple = ("count", "=", "not_a_number"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_err()); - assert!(filter.unwrap_err().to_string().contains("Unable to cast")); + let filter = Filter { + field_name: "count".to_string(), + operator: ExprOperator::Eq, + value: "not_a_number".to_string(), + }; + let result = PartitionFilter::try_from((filter, &schema)); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Unable to cast")); } #[test] fn test_partition_filter_try_from_all_operators() { let schema = create_test_schema(); for (op, _) in ExprOperator::TOKEN_OP_PAIRS { - let filter_tuple = ("count", op, "10"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_ok(), "Failed for operator: {}", op); - let filter = filter.unwrap(); + let filter = Filter { + field_name: "count".to_string(), + operator: ExprOperator::from_str(op).unwrap(), + value: "5".to_string(), + }; + let partition_filter = PartitionFilter::try_from((filter, &schema)); + let filter = partition_filter.unwrap(); assert_eq!(filter.field.name(), "count"); assert_eq!(filter.operator, ExprOperator::from_str(op).unwrap()); } diff --git a/crates/datafusion/src/lib.rs b/crates/datafusion/src/lib.rs index fc14f4d..f3ae5d4 100644 --- a/crates/datafusion/src/lib.rs +++ b/crates/datafusion/src/lib.rs @@ -167,11 +167,11 @@ impl TableProvider for HudiDataSource { ) -> Result> { self.table.register_storage(state.runtime_env().clone()); - // Convert Datafusion `Expr` to `PartitionFilter` - let partition_filters = convert_exprs_to_filter(filters, &self.schema()); + // Convert Datafusion `Expr` to `Filter` + let pushdown_filters = convert_exprs_to_filter(filters); let file_slices = self .table - .get_file_slices_splits(self.get_input_partitions(), partition_filters.as_slice()) + .get_file_slices_splits(self.get_input_partitions(), pushdown_filters.as_slice()) .await .map_err(|e| Execution(format!("Failed to get file slices from Hudi table: {}", e)))?; let mut parquet_file_groups: Vec> = Vec::new(); @@ -567,11 +567,11 @@ mod tests { let result = table_provider.supports_filters_pushdown(&filters).unwrap(); assert_eq!(result.len(), 6); - assert_eq!(result[0], TableProviderFilterPushDown::Exact); - assert_eq!(result[1], TableProviderFilterPushDown::Exact); + assert_eq!(result[0], TableProviderFilterPushDown::Inexact); + assert_eq!(result[1], TableProviderFilterPushDown::Inexact); assert_eq!(result[2], TableProviderFilterPushDown::Unsupported); assert_eq!(result[3], TableProviderFilterPushDown::Unsupported); assert_eq!(result[4], TableProviderFilterPushDown::Unsupported); - assert_eq!(result[5], TableProviderFilterPushDown::Exact); + assert_eq!(result[5], TableProviderFilterPushDown::Inexact); } } diff --git a/crates/datafusion/src/utils/exprs_to_filter.rs b/crates/datafusion/src/utils/exprs_to_filter.rs index 0f9b5ed..a8724a4 100644 --- a/crates/datafusion/src/utils/exprs_to_filter.rs +++ b/crates/datafusion/src/utils/exprs_to_filter.rs @@ -17,37 +17,30 @@ * under the License. 
*/ -use arrow_array::{Array, Scalar}; -use arrow_schema::SchemaRef; -use datafusion_common::DataFusionError; use datafusion::logical_expr::Operator; use datafusion_expr::{BinaryExpr, Expr}; +use hudi_core::exprs::filter::Filter; use hudi_core::exprs::ExprOperator; -use hudi_core::table::partition::PartitionFilter; -use std::sync::Arc; // TODO: Handle other Datafusion `Expr` -/// Converts a slice of DataFusion expressions (`Expr`) into a vector of `PartitionFilter`. -/// Returns `Some(Vec<PartitionFilter>)` if at least one filter is successfully converted, -/// otherwise returns `None`. +/// Converts a slice of DataFusion expressions (`Expr`) into a vector of `Filter`. +/// Expressions that cannot be converted are skipped, so the returned vector +/// may be shorter than the input slice. -pub fn convert_exprs_to_filter( - filters: &[Expr], - partition_schema: &SchemaRef, -) -> Vec<PartitionFilter> { - let mut partition_filters = Vec::new(); +pub fn convert_exprs_to_filter(exprs: &[Expr]) -> Vec<Filter> { + let mut filters: Vec<Filter> = Vec::new(); - for expr in filters { + for expr in exprs { match expr { Expr::BinaryExpr(binary_expr) => { - if let Some(partition_filter) = convert_binary_expr(binary_expr, partition_schema) { - partition_filters.push(partition_filter); + if let Some(filter) = convert_binary_expr(binary_expr) { + filters.push(filter); } } Expr::Not(not_expr) => { // Handle NOT expressions - if let Some(partition_filter) = convert_not_expr(not_expr, partition_schema) { - partition_filters.push(partition_filter); + if let Some(filter) = convert_not_expr(not_expr) { + filters.push(filter); } } _ => { @@ -56,14 +49,11 @@ pub fn convert_exprs_to_filter( } } - partition_filters + filters } -/// Converts a binary expression (`Expr::BinaryExpr`) into a `PartitionFilter`. -fn convert_binary_expr( - binary_expr: &BinaryExpr, - partition_schema: &SchemaRef, -) -> Option<PartitionFilter> { +/// Converts a binary expression (`Expr::BinaryExpr`) into a `Filter`. +fn convert_binary_expr(binary_expr: &BinaryExpr) -> Option<Filter> { // extract the column and literal from the binary expression let (column, literal) = match (&*binary_expr.left, &*binary_expr.right) { (Expr::Column(col), Expr::Literal(lit)) => (col, lit), @@ -71,11 +61,7 @@ fn convert_binary_expr( _ => return None, }; - let field = partition_schema - .field_with_name(column.name()) - .map_err(|e| DataFusionError::Plan(format!("Error finding field with name '{}': {}", column.name(), e))) - .unwrap() - .clone(); + let field_name = column.name().to_string(); let operator = match binary_expr.op { Operator::Eq => ExprOperator::Eq, @@ -87,28 +73,22 @@ fn convert_binary_expr( _ => return None, }; - let value = match literal.cast_to(field.data_type()) { - Ok(value) => { - let array_ref: Arc<dyn Array> = value.to_array().unwrap(); - Scalar::new(array_ref) - } - Err(_) => return None, - }; + let value = literal.to_string(); - Some(PartitionFilter { - field, + Some(Filter { + field_name, operator, value, }) } -/// Converts a NOT expression (`Expr::Not`) into a `PartitionFilter`. +/// Converts a NOT expression (`Expr::Not`) into a `Filter`.
-fn convert_not_expr(not_expr: &Expr, partition_schema: &SchemaRef) -> Option<PartitionFilter> { +fn convert_not_expr(not_expr: &Expr) -> Option<Filter> { match not_expr { Expr::BinaryExpr(ref binary_expr) => { - let mut partition_filter = convert_binary_expr(binary_expr, partition_schema)?; - partition_filter.operator = negate_operator(partition_filter.operator)?; - Some(partition_filter) + let mut filter = convert_binary_expr(binary_expr)?; + filter.operator = negate_operator(filter.operator)?; + Some(filter) } _ => None, } @@ -129,59 +109,47 @@ fn negate_operator(op: ExprOperator) -> Option<ExprOperator> { #[cfg(test)] mod tests { use super::*; - use arrow_array::{ArrayRef, Float64Array, Int32Array, Int64Array, StringArray}; use arrow_schema::{DataType, Field, Schema}; use datafusion::logical_expr::{col, lit}; use datafusion_expr::{BinaryExpr, Expr}; - use hudi_core::exprs::{HudiOperator, PartitionFilter}; + use hudi_core::exprs::ExprOperator; - use std::f64::consts::PI; + use std::str::FromStr; use std::sync::Arc; #[test] fn test_convert_simple_binary_expr() { - let partition_schema = Arc::new(Schema::new(vec![Field::new( - "partition_col", - DataType::Int32, - false, - )])); + let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Int32, false)])); let expr = Expr::BinaryExpr(BinaryExpr::new( - Box::new(col("partition_col")), + Box::new(col("col")), Operator::Eq, Box::new(lit(42i32)), )); let filters = vec![expr]; - let result = convert_exprs_to_filter(&filters, &partition_schema); + let result = convert_exprs_to_filter(&filters); assert_eq!(result.len(), 1); - let expected_filter = PartitionFilter { - field: partition_schema.field(0).clone(), + let expected_filter = Filter { + field_name: schema.field(0).name().to_string(), operator: ExprOperator::Eq, - value: Scalar::new(Arc::new(Int32Array::from(vec![42])) as ArrayRef), + value: "42".to_string(), }; - assert_eq!(result[0].field, expected_filter.field); + assert_eq!(result[0].field_name, expected_filter.field_name); assert_eq!(result[0].operator, expected_filter.operator); - assert_eq!( - *result[0].value.clone().into_inner(), - expected_filter.value.into_inner() - ); + assert_eq!(result[0].value, expected_filter.value); } // Tests the conversion of a NOT expression #[test] fn test_convert_not_expr() { - let partition_schema = Arc::new(Schema::new(vec![Field::new( - "partition_col", - DataType::Int32, - false, - )])); + let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Int32, false)])); let inner_expr = Expr::BinaryExpr(BinaryExpr::new( - Box::new(col("partition_col")), + Box::new(col("col")), Operator::Eq, Box::new(lit(42i32)), )); @@ -189,109 +157,79 @@ mod tests { let filters = vec![expr]; - let result = convert_exprs_to_filter(&filters, &partition_schema); + let result = convert_exprs_to_filter(&filters); assert_eq!(result.len(), 1); - let expected_filter = PartitionFilter { - field: partition_schema.field(0).clone(), + let expected_filter = Filter { + field_name: schema.field(0).name().to_string(), operator: ExprOperator::Ne, - value: Scalar::new(Arc::new(Int32Array::from(vec![42])) as ArrayRef), + value: "42".to_string(), }; - assert_eq!(result[0].field, expected_filter.field); + assert_eq!(result[0].field_name, expected_filter.field_name); assert_eq!(result[0].operator, expected_filter.operator); - assert_eq!( - *result[0].value.clone().into_inner(), - expected_filter.value.into_inner() - ); + assert_eq!(result[0].value, expected_filter.value); } #[test] fn test_convert_binary_expr_extensive() { - // partition schema with multiple fields of different data types - let
partition_schema = Arc::new(Schema::new(vec![ - Field::new("int32_col", DataType::Int32, false), - Field::new("int64_col", DataType::Int64, false), - Field::new("float64_col", DataType::Float64, false), - Field::new("string_col", DataType::Utf8, false), - ])); - // list of test cases with different operators and data types let test_cases = vec![ ( col("int32_col").eq(lit(42i32)), - Some(PartitionFilter { - field: partition_schema - .field_with_name("int32_col") - .unwrap() - .clone(), + Some(Filter { + field_name: String::from("int32_col"), operator: ExprOperator::Eq, - value: Scalar::new(Arc::new(Int32Array::from(vec![42])) as ArrayRef), + value: String::from("42"), }), ), ( col("int64_col").gt_eq(lit(100i64)), - Some(PartitionFilter { - field: partition_schema - .field_with_name("int64_col") - .unwrap() - .clone(), + Some(Filter { + field_name: String::from("int64_col"), operator: ExprOperator::Gte, - value: Scalar::new(Arc::new(Int64Array::from(vec![100])) as ArrayRef), + value: String::from("100"), }), ), ( - col("float64_col").lt(lit(PI)), - Some(PartitionFilter { - field: partition_schema - .field_with_name("float64_col") - .unwrap() - .clone(), + col("float64_col").lt(lit(32.666)), + Some(Filter { + field_name: String::from("float64_col"), operator: ExprOperator::Lt, - value: Scalar::new(Arc::new(Float64Array::from(vec![PI])) as ArrayRef), + value: "32.666".to_string(), }), ), ( col("string_col").not_eq(lit("test")), - Some(PartitionFilter { - field: partition_schema - .field_with_name("string_col") - .unwrap() - .clone(), + Some(Filter { + field_name: String::from("string_col"), operator: ExprOperator::Ne, - value: Scalar::new(Arc::new(StringArray::from(vec!["test"])) as ArrayRef), + value: String::from("test"), }), ), ]; let filters: Vec = test_cases.iter().map(|(expr, _)| expr.clone()).collect(); - let result = convert_exprs_to_filter(&filters, &partition_schema); - let expected_filters: Vec<&PartitionFilter> = test_cases + let result = convert_exprs_to_filter(&filters); + let expected_filters: Vec<&Filter> = test_cases .iter() .filter_map(|(_, opt_filter)| opt_filter.as_ref()) .collect(); assert_eq!(result.len(), expected_filters.len()); - for (converted_filter, expected_filter) in result.iter().zip(expected_filters.iter()) { - assert_eq!(converted_filter.field.name(), expected_filter.field.name()); - assert_eq!(converted_filter.operator, expected_filter.operator); - assert_eq!( - *converted_filter.value.clone().into_inner(), - expected_filter.value.clone().into_inner() - ); + for (result, expected_filter) in result.iter().zip(expected_filters.iter()) { + assert_eq!(result.field_name, expected_filter.field_name); + assert_eq!(result.operator, expected_filter.operator); + assert_eq!(*result.value.clone(), expected_filter.value); } } // Tests conversion with different operators (e.g., <, <=, >, >=) #[test] fn test_convert_various_operators() { - let partition_schema = Arc::new(Schema::new(vec![Field::new( - "partition_col", - DataType::Int32, - false, - )])); + let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Int32, false)])); let operators = vec![ (Operator::Lt, ExprOperator::Lt), @@ -302,48 +240,51 @@ mod tests { for (op, expected_op) in operators { let expr = Expr::BinaryExpr(BinaryExpr::new( - Box::new(col("partition_col")), + Box::new(col("col")), op, Box::new(lit(42i32)), )); let filters = vec![expr]; - let result = convert_exprs_to_filter(&filters, &partition_schema); + let result = convert_exprs_to_filter(&filters); assert_eq!(result.len(), 1); - let 
expected_filter = PartitionFilter { - field: partition_schema.field(0).clone(), + let expected_filter = Filter { + field_name: schema.field(0).name().to_string(), operator: expected_op, - value: Scalar::new(Arc::new(Int32Array::from(vec![42])) as ArrayRef), + value: String::from("42"), }; - assert_eq!(result[0].field, expected_filter.field); + assert_eq!(result[0].field_name, expected_filter.field_name); assert_eq!(result[0].operator, expected_filter.operator); - assert_eq!( - *result[0].value.clone().into_inner(), - expected_filter.value.into_inner() - ); + assert_eq!(result[0].value, expected_filter.value); } } #[test] fn test_convert_expr_with_unsupported_operator() { - let partition_schema = Arc::new(Schema::new(vec![Field::new( - "partition_col", - DataType::Int32, - false, - )])); let expr = Expr::BinaryExpr(BinaryExpr::new( - Box::new(col("partition_col")), + Box::new(col("col")), Operator::And, Box::new(lit("value")), )); let filters = vec![expr]; - let result = convert_exprs_to_filter(&filters, &partition_schema); + let result = convert_exprs_to_filter(&filters); assert!(result.is_empty()); } + + #[test] + fn test_negate_operator_for_all_ops() { + for (op, _) in ExprOperator::TOKEN_OP_PAIRS { + if let Some(negated_op) = negate_operator(ExprOperator::from_str(op).unwrap()) { + let double_negated_op = negate_operator(negated_op) + .expect("Negation should be defined for all operators"); + + assert_eq!(double_negated_op, ExprOperator::from_str(op).unwrap()); + } + } + } } diff --git a/python/src/internal.rs b/python/src/internal.rs index 82abb54..30deb28 100644 --- a/python/src/internal.rs +++ b/python/src/internal.rs @@ -23,12 +23,11 @@ use std::sync::OnceLock; use anyhow::Context; use arrow::pyarrow::ToPyArrow; -use arrow_schema::Schema; use pyo3::exceptions::PyValueError; use pyo3::{pyclass, pyfunction, pymethods, PyErr, PyObject, PyResult, Python}; use tokio::runtime::Runtime; -use hudi::exprs::PartitionFilter; +use hudi::exprs::filter::Filter; use hudi::file_group::reader::FileGroupReader; use hudi::file_group::FileSlice; use hudi::table::builder::TableBuilder; @@ -164,14 +163,11 @@ impl HudiTable { filters: Option<Vec<(String, String, String)>>, py: Python, ) -> PyResult<Vec<Vec<HudiFileSlice>>> { - let schema: Schema = rt().block_on(self.inner.get_schema())?; - let partition_filters = convert_filters(filters, &schema)?; + let filters = convert_filters(filters)?; py.allow_threads(|| { - let file_slices = rt().block_on( - self.inner - .get_file_slices_splits(n, partition_filters.as_slice()), - )?; + let file_slices = + rt().block_on(self.inner.get_file_slices_splits(n, filters.as_slice()))?; Ok(file_slices .iter() .map(|inner_vec| inner_vec.iter().map(convert_file_slice).collect()) @@ -185,12 +181,10 @@ impl HudiTable { filters: Option<Vec<(String, String, String)>>, py: Python, ) -> PyResult<Vec<HudiFileSlice>> { - let schema: Schema = rt().block_on(self.inner.get_schema())?; - let partition_filters = convert_filters(filters, &schema)?; + let filters = convert_filters(filters)?; py.allow_threads(|| { - let file_slices = rt().block_on(self.inner.get_file_slices(partition_filters.as_slice()))?; + let file_slices = rt().block_on(self.inner.get_file_slices(filters.as_slice()))?; Ok(file_slices.iter().map(convert_file_slice).collect()) }) } @@ -206,27 +200,19 @@ impl HudiTable { filters: Option<Vec<(String, String, String)>>, py: Python, ) -> PyResult<PyObject> { - let schema: Schema = rt().block_on(self.inner.get_schema())?; - let partition_filters = convert_filters(filters, &schema)?; + let filters = convert_filters(filters)?; - rt().block_on(self.inner.read_snapshot(partition_filters.as_slice()))?
+ rt().block_on(self.inner.read_snapshot(filters.as_slice()))? .to_pyarrow(py) } } -fn convert_filters( - filters: Option<Vec<(String, String, String)>>, - partition_schema: &Schema, -) -> PyResult<Vec<PartitionFilter>> { +fn convert_filters(filters: Option<Vec<(String, String, String)>>) -> PyResult<Vec<Filter>> { filters .unwrap_or_default() .into_iter() .map(|(field, op, value)| { - PartitionFilter::try_from(( - (field.as_str(), op.as_str(), value.as_str()), - partition_schema, - )) - .map_err(|e| { + Filter::try_from((field.as_str(), op.as_str(), value.as_str())).map_err(|e| { PyValueError::new_err(format!( "Invalid filter ({}, {}, {}): {}", field, op, value, e diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py index baebdff..9448a46 100644 --- a/python/tests/test_table_read.py +++ b/python/tests/test_table_read.py @@ -193,3 +193,17 @@ def test_read_table_as_of_timestamp(get_sample_table): "fare": 34.15, }, ] + +def test_convert_filters_valid(get_sample_table): + table_path = get_sample_table + table = HudiTable(table_path) + + filters = [ + ("city", "=", "san_francisco"), + ("fare", ">", "50"), + ("driver", "!=", "john_doe"), + ] + + file_slices = table.get_file_slices(filters=filters) + + assert len(file_slices) > 0
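
Taken together, the series splits filtering into two layers: engine-agnostic `Filter` triples (field name, operator, value) in the core crate, with typed casting against the partition schema deferred to `PartitionPruner`, plus an `Inexact` pushdown report so DataFusion re-applies every predicate above the Hudi scan. A minimal sketch of the resulting read path, assuming the core crate is imported as `hudi` (as the Python bindings in this series do) and using a placeholder table path:

```rust
use hudi::exprs::filter::Filter;
use hudi::table::Table;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Placeholder path; any Hudi table with `city` and `fare` partition fields works.
    let table = Table::new("/tmp/trips").await?;

    // Filters are plain (field, operator, value) string triples. Casting the value
    // to the partition field's data type happens later, inside PartitionPruner, so
    // an unparsable value surfaces as an error at read time, not at construction.
    let filters = vec![
        Filter::try_from(("city", "=", "san_francisco"))?,
        Filter::try_from(("fare", ">", "50"))?,
    ];

    let batches = table.read_snapshot(&filters).await?;
    println!("read {} record batches", batches.len());
    Ok(())
}
```

Because `supports_filters_pushdown` now answers `Inexact` rather than `Exact`, DataFusion keeps its own filter operator above the scan, so the string-typed `Filter` representation can only affect how much is pruned, never the correctness of query results.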