From 0958f35c883f31ed373bd10c7ace3d516e1c3ed8 Mon Sep 17 00:00:00 2001
From: Amey Chaugule
Date: Thu, 26 Sep 2024 15:16:08 -0700
Subject: [PATCH] Adding serde utils for Arrow Arrays

---
 Cargo.lock                                    |   1 +
 crates/core/Cargo.toml                        |   1 +
 crates/core/src/context.rs                    |   6 +-
 crates/core/src/logical_plan/mod.rs           |   2 +-
 ...lesce_before_streaming_window_aggregate.rs |   8 +-
 crates/core/src/physical_optimizer/mod.rs     |   2 +-
 crates/core/src/utils/mod.rs                  |   1 +
 crates/core/src/utils/serialization.rs        | 441 ++++++++++++++++++
 examples/examples/emit_measurements.rs        |   5 +-
 examples/examples/kafka_rideshare.rs          |   2 +-
 examples/examples/simple_aggregation.rs       |   2 +-
 11 files changed, 460 insertions(+), 11 deletions(-)
 create mode 100644 crates/core/src/utils/serialization.rs

diff --git a/Cargo.lock b/Cargo.lock
index c482824..88bb717 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1303,6 +1303,7 @@ dependencies = [
  "delegate",
  "denormalized-common",
  "denormalized-orchestrator",
+ "flatbuffers",
  "futures",
  "half",
  "hashbrown",
diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml
index 6d3374f..2a5b07a 100644
--- a/crates/core/Cargo.toml
+++ b/crates/core/Cargo.toml
@@ -37,3 +37,4 @@ half = "2.4.1"
 delegate = "0.12.0"
 ahash = "0.8.11"
 hashbrown = "0.14.5"
+flatbuffers = "24.3.25"
diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs
index dbf917b..6319160 100644
--- a/crates/core/src/context.rs
+++ b/crates/core/src/context.rs
@@ -8,7 +8,7 @@ use datafusion::execution::{
 
 use crate::datasource::kafka::TopicReader;
 use crate::datastream::DataStream;
-use crate::physical_optimizer::CoaslesceBeforeStreamingAggregate;
+use crate::physical_optimizer::EnsureHashPartititionOnGroupByForStreamingAggregates;
 use crate::query_planner::StreamingQueryPlanner;
 use crate::utils::get_default_optimizer_rules;
 
@@ -41,7 +41,9 @@ impl Context {
             .with_runtime_env(runtime)
             .with_query_planner(Arc::new(StreamingQueryPlanner {}))
             .with_optimizer_rules(get_default_optimizer_rules())
-            .with_physical_optimizer_rule(Arc::new(CoaslesceBeforeStreamingAggregate::new()))
+            .with_physical_optimizer_rule(Arc::new(
+                EnsureHashPartititionOnGroupByForStreamingAggregates::new(),
+            ))
             .build();
 
         Ok(Self {
diff --git a/crates/core/src/logical_plan/mod.rs b/crates/core/src/logical_plan/mod.rs
index 3433994..ce2846f 100644
--- a/crates/core/src/logical_plan/mod.rs
+++ b/crates/core/src/logical_plan/mod.rs
@@ -45,7 +45,7 @@ impl StreamingLogicalPlanBuilder for LogicalPlanBuilder {
         let plan = self.plan().clone();
 
         Aggregate::try_new(Arc::new(plan.clone()), group_expr, aggr_expr)
-            .map(|new_aggr| {
+            .map(|new_aggr: Aggregate| {
                 LogicalPlan::Extension(Extension {
                     node: Arc::new(StreamingWindowPlanNode {
                         window_type: window,
diff --git a/crates/core/src/physical_optimizer/coalesce_before_streaming_window_aggregate.rs b/crates/core/src/physical_optimizer/coalesce_before_streaming_window_aggregate.rs
index 3c7910a..8d3f796 100644
--- a/crates/core/src/physical_optimizer/coalesce_before_streaming_window_aggregate.rs
+++ b/crates/core/src/physical_optimizer/coalesce_before_streaming_window_aggregate.rs
@@ -10,15 +10,15 @@ use datafusion::error::Result;
 
 use crate::physical_plan::continuous::streaming_window::StreamingWindowExec;
 
-pub struct CoaslesceBeforeStreamingAggregate {}
+pub struct EnsureHashPartititionOnGroupByForStreamingAggregates {}
 
-impl Default for CoaslesceBeforeStreamingAggregate {
+impl Default for EnsureHashPartititionOnGroupByForStreamingAggregates {
     fn default() -> Self {
         Self::new()
     }
 }
 
-impl CoaslesceBeforeStreamingAggregate {
+impl EnsureHashPartititionOnGroupByForStreamingAggregates {
     #[allow(missing_docs)]
     pub fn new() -> Self {
         Self {}
     }
 }
@@ -29,7 +29,7 @@ impl CoaslesceBeforeStreamingAggregate {
 // Franz optimizer rule, added to ensure coalescing of partitions before a global aggregate
 // window. This rule may be removed once we have support for two stage partial and final
 // aggregates a la vanilla Datafusion.
-impl PhysicalOptimizerRule for CoaslesceBeforeStreamingAggregate {
+impl PhysicalOptimizerRule for EnsureHashPartititionOnGroupByForStreamingAggregates {
     fn optimize(
         &self,
         plan: Arc<dyn ExecutionPlan>,
diff --git a/crates/core/src/physical_optimizer/mod.rs b/crates/core/src/physical_optimizer/mod.rs
index ceb6834..369a420 100644
--- a/crates/core/src/physical_optimizer/mod.rs
+++ b/crates/core/src/physical_optimizer/mod.rs
@@ -1,3 +1,3 @@
 pub mod coalesce_before_streaming_window_aggregate;
 
-pub use coalesce_before_streaming_window_aggregate::CoaslesceBeforeStreamingAggregate;
+pub use coalesce_before_streaming_window_aggregate::EnsureHashPartititionOnGroupByForStreamingAggregates;
diff --git a/crates/core/src/utils/mod.rs b/crates/core/src/utils/mod.rs
index 3710df1..4f2b535 100644
--- a/crates/core/src/utils/mod.rs
+++ b/crates/core/src/utils/mod.rs
@@ -2,5 +2,6 @@ pub mod arrow_helpers;
 mod default_optimizer_rules;
 pub mod row_encoder;
+pub mod serialization;
 
 pub use default_optimizer_rules::get_default_optimizer_rules;
diff --git a/crates/core/src/utils/serialization.rs b/crates/core/src/utils/serialization.rs
new file mode 100644
index 0000000..70c198e
--- /dev/null
+++ b/crates/core/src/utils/serialization.rs
@@ -0,0 +1,441 @@
+use std::sync::Arc;
+
+use arrow::{array::ArrayData, buffer::Buffer};
+use arrow_array::{make_array, Array, ArrayRef};
+use arrow_schema::{DataType, Field};
+use serde::{Deserialize, Serialize};
+
+#[derive(Serialize, Deserialize)]
+struct SerializedArrayData {
+    data_type: SerializedDataType,
+    len: usize,
+    null_bit_buffer: Option<Vec<u8>>,
+    offset: usize,
+    buffers: Vec<Vec<u8>>,
+    child_data: Vec<SerializedArrayData>,
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+enum SerializedDataType {
+    Null,
+    Boolean,
+    Int8,
+    Int16,
+    Int32,
+    Int64,
+    UInt8,
+    UInt16,
+    UInt32,
+    UInt64,
+    Float16,
+    Float32,
+    Float64,
+    Utf8,
+    List(Box<SerializedField>),
+    Struct(Vec<SerializedField>),
+    Dictionary(Box<SerializedDataType>, Box<SerializedDataType>),
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+struct SerializedField {
+    name: String,
+    data_type: SerializedDataType,
+    nullable: bool,
+}
+
+impl From<&DataType> for SerializedDataType {
+    fn from(dt: &DataType) -> Self {
+        match dt {
+            DataType::Null => SerializedDataType::Null,
+            DataType::Boolean => SerializedDataType::Boolean,
+            DataType::Int8 => SerializedDataType::Int8,
+            DataType::Int16 => SerializedDataType::Int16,
+            DataType::Int32 => SerializedDataType::Int32,
+            DataType::Int64 => SerializedDataType::Int64,
+            DataType::UInt8 => SerializedDataType::UInt8,
+            DataType::UInt16 => SerializedDataType::UInt16,
+            DataType::UInt32 => SerializedDataType::UInt32,
+            DataType::UInt64 => SerializedDataType::UInt64,
+            DataType::Float16 => SerializedDataType::Float16,
+            DataType::Float32 => SerializedDataType::Float32,
+            DataType::Float64 => SerializedDataType::Float64,
+            DataType::Utf8 => SerializedDataType::Utf8,
+            DataType::List(field) => {
+                SerializedDataType::List(Box::new(SerializedField::from(field.as_ref())))
+            }
+            DataType::Struct(fields) => SerializedDataType::Struct(
+                fields
+                    .iter()
+                    .map(|f| SerializedField {
+                        name: f.name().clone(),
+                        data_type: SerializedDataType::from(f.data_type()),
+                        nullable: f.is_nullable(),
+                    })
+                    .collect(),
+            ),
+            DataType::Dictionary(key_type, value_type) => SerializedDataType::Dictionary(
+                Box::new(SerializedDataType::from(key_type.as_ref())),
+                Box::new(SerializedDataType::from(value_type.as_ref())),
+            ),
+            // Add other types as needed
+            _ => unimplemented!("Serialization not implemented for this data type"),
+        }
+    }
+}
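+
+// `SerializedDataType` deliberately mirrors `arrow_schema::DataType`, since the
+// arrow types are not serde-serializable in this configuration without extra
+// feature flags. A rough illustration of the mapping (not exercised anywhere
+// in this file): `DataType::List(Arc::new(Field::new("item", DataType::Int32, true)))`
+// becomes `SerializedDataType::List(Box::new(SerializedField { name: "item".into(),
+// data_type: SerializedDataType::Int32, nullable: true }))`.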
+
+impl From<SerializedDataType> for DataType {
+    fn from(sdt: SerializedDataType) -> Self {
+        match sdt {
+            SerializedDataType::Null => DataType::Null,
+            SerializedDataType::Boolean => DataType::Boolean,
+            SerializedDataType::Int8 => DataType::Int8,
+            SerializedDataType::Int16 => DataType::Int16,
+            SerializedDataType::Int32 => DataType::Int32,
+            SerializedDataType::Int64 => DataType::Int64,
+            SerializedDataType::UInt8 => DataType::UInt8,
+            SerializedDataType::UInt16 => DataType::UInt16,
+            SerializedDataType::UInt32 => DataType::UInt32,
+            SerializedDataType::UInt64 => DataType::UInt64,
+            SerializedDataType::Float16 => DataType::Float16,
+            SerializedDataType::Float32 => DataType::Float32,
+            SerializedDataType::Float64 => DataType::Float64,
+            SerializedDataType::Utf8 => DataType::Utf8,
+            SerializedDataType::List(field) => DataType::List(Arc::new(Field::from(*field))),
+            SerializedDataType::Struct(fields) => {
+                DataType::Struct(fields.into_iter().map(Field::from).collect())
+            }
+            SerializedDataType::Dictionary(key_type, value_type) => {
+                DataType::Dictionary(Box::new((*key_type).into()), Box::new((*value_type).into()))
+            }
+        }
+    }
+}
+
+impl From<&Field> for SerializedField {
+    fn from(field: &Field) -> Self {
+        SerializedField {
+            name: field.name().to_string(),
+            data_type: SerializedDataType::from(field.data_type()),
+            nullable: field.is_nullable(),
+        }
+    }
+}
+
+impl From<SerializedField> for Field {
+    fn from(sf: SerializedField) -> Self {
+        Field::new(sf.name, sf.data_type.into(), sf.nullable)
+    }
+}
+
+pub fn serialize_array(array: &ArrayRef) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
+    let array_data = array.to_data();
+    let serialized = serialize_array_data(&array_data)?;
+    bincode::serialize(&serialized).map_err(|e| e.into())
+}
+
+fn serialize_array_data(
+    array_data: &ArrayData,
+) -> Result<SerializedArrayData, Box<dyn std::error::Error>> {
+    Ok(SerializedArrayData {
+        data_type: SerializedDataType::from(array_data.data_type()),
+        len: array_data.len(),
+        null_bit_buffer: array_data.nulls().map(|n| n.buffer().as_slice().to_vec()),
+        offset: array_data.offset(),
+        buffers: array_data
+            .buffers()
+            .iter()
+            .map(|b| b.as_slice().to_vec())
+            .collect(),
+        child_data: array_data
+            .child_data()
+            .iter()
+            .map(|c| serialize_array_data(c))
+            .collect::<Result<Vec<_>, _>>()?,
+    })
+}
+
+pub fn deserialize_array(bytes: &[u8]) -> Result<ArrayRef, Box<dyn std::error::Error>> {
+    let serialized: SerializedArrayData = bincode::deserialize(bytes)?;
+    let array_data = deserialize_array_data(&serialized)?;
+    Ok(make_array(array_data))
+}
+
+fn deserialize_array_data(
+    serialized: &SerializedArrayData,
+) -> Result<ArrayData, Box<dyn std::error::Error>> {
+    let data_type: DataType = serialized.data_type.clone().into();
+
+    if serialized.len == 0 {
+        return Ok(ArrayData::new_empty(&data_type));
+    }
+
+    let buffers: Vec<Buffer> = serialized
+        .buffers
+        .iter()
+        .map(|buf| Buffer::from_vec(buf.clone()))
+        .collect();
+
+    let child_data: Vec<ArrayData> = serialized
+        .child_data
+        .iter()
+        .map(|c| deserialize_array_data(c))
+        .collect::<Result<Vec<_>, _>>()?;
+
+    let null_buffer = serialized
+        .null_bit_buffer
+        .as_ref()
+        .map(|buf| Buffer::from_vec(buf.clone()));
+
+    ArrayData::try_new(
+        data_type,
+        serialized.len,
+        null_buffer,
+        serialized.offset,
+        buffers,
+        child_data,
+    )
+    .map_err(|e| e.into())
+}
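+
+// A minimal round-trip sketch of the two public entry points (illustrative
+// only; assumes a caller whose error type can absorb `Box<dyn std::error::Error>`):
+//
+//     let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
+//     let bytes = serialize_array(&array)?;
+//     let restored = deserialize_array(&bytes)?;
+//     assert_eq!(array.to_data(), restored.to_data());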
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::{
+        array::AsArray,
+        datatypes::{
+            ArrowDictionaryKeyType, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
+            UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+        },
+    };
+    use arrow_array::{
+        Array, ArrowPrimitiveType, BooleanArray, DictionaryArray, Float64Array, Int32Array,
+        ListArray, StringArray, StructArray,
+    };
+    use arrow_schema::{Field, Fields};
+    use std::sync::Arc;
+
+    fn test_roundtrip<A: Array + Clone + 'static>(array: A) {
+        let array_ref: ArrayRef = Arc::new(array.clone());
+        let serialized = serialize_array(&array_ref).unwrap();
+        let deserialized: ArrayRef = deserialize_array(&serialized).unwrap();
+
+        assert_eq!(array_ref.len(), deserialized.len());
+        assert_eq!(array_ref.data_type(), deserialized.data_type());
+        assert_eq!(array_ref.null_count(), deserialized.null_count());
+
+        // Compare the actual data
+        compare_arrays(&array_ref, &deserialized);
+    }
+
+    fn compare_arrays(left: &ArrayRef, right: &ArrayRef) {
+        assert_eq!(left.len(), right.len());
+        assert_eq!(left.data_type(), right.data_type());
+
+        for i in 0..left.len() {
+            assert_eq!(left.is_null(i), right.is_null(i));
+            if !left.is_null(i) {
+                compare_array_values(left, right, i);
+            }
+        }
+    }
+
+    fn compare_dictionary<K: ArrowDictionaryKeyType>(
+        left: &ArrayRef,
+        right: &ArrayRef,
+        index: usize,
+    ) {
+        let l = left.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
+        let r = right.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
+        assert_eq!(l.key(index), r.key(index));
+        compare_arrays(l.values(), r.values());
+    }
+
+    fn compare_array_values(left: &ArrayRef, right: &ArrayRef, index: usize) {
+        match left.data_type() {
+            DataType::Boolean => {
+                let l = left.as_boolean();
+                let r = right.as_boolean();
+                assert_eq!(l.value(index), r.value(index));
+            }
+            DataType::Int32 => {
+                let l = left.as_primitive::<Int32Type>();
+                let r = right.as_primitive::<Int32Type>();
+                assert_eq!(l.value(index), r.value(index));
+            }
+            DataType::Float64 => {
+                let l = left.as_primitive::<Float64Type>();
+                let r = right.as_primitive::<Float64Type>();
+                assert!((l.value(index) - r.value(index)).abs() < f64::EPSILON);
+            }
+            DataType::Utf8 => {
+                let l = left.as_string::<i32>();
+                let r = right.as_string::<i32>();
+                assert_eq!(l.value(index), r.value(index));
+            }
+            DataType::LargeUtf8 => {
+                let l = left.as_string::<i64>();
+                let r = right.as_string::<i64>();
+                assert_eq!(l.value(index), r.value(index));
+            }
+            DataType::List(_) => {
+                let l = left.as_list::<i32>();
+                let r = right.as_list::<i32>();
+                compare_arrays(&l.value(index), &r.value(index));
+            }
+            DataType::LargeList(_) => {
+                let l = left.as_list::<i64>();
+                let r = right.as_list::<i64>();
+                compare_arrays(&l.value(index), &r.value(index));
+            }
+            DataType::Struct(_) => {
+                let l = left.as_struct();
+                let r = right.as_struct();
+                for j in 0..l.num_columns() {
+                    compare_arrays(&l.column(j), &r.column(j));
+                }
+            }
+            DataType::Dictionary(key_type, _) => match key_type.as_ref() {
+                DataType::Int8 => compare_dictionary::<Int8Type>(left, right, index),
+                DataType::Int16 => compare_dictionary::<Int16Type>(left, right, index),
+                DataType::Int32 => compare_dictionary::<Int32Type>(left, right, index),
+                DataType::Int64 => compare_dictionary::<Int64Type>(left, right, index),
+                DataType::UInt8 => compare_dictionary::<UInt8Type>(left, right, index),
+                DataType::UInt16 => compare_dictionary::<UInt16Type>(left, right, index),
+                DataType::UInt32 => compare_dictionary::<UInt32Type>(left, right, index),
+                DataType::UInt64 => compare_dictionary::<UInt64Type>(left, right, index),
+                _ => panic!("Unsupported dictionary key type: {:?}", key_type),
+            },
+            // Add other data types as needed
+            _ => panic!(
+                "Unsupported data type for comparison: {:?}",
+                left.data_type()
+            ),
+        }
+    }
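+
+    // Note on the helpers above: arrays are compared value-by-value rather
+    // than byte-by-byte, presumably so the assertions do not depend on
+    // physical layout details (buffer padding, offsets) that a round trip
+    // need not preserve; floats in particular are compared within
+    // f64::EPSILON instead of bitwise.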
+
+    #[test]
+    fn test_boolean_array() {
+        let array = BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]);
+        test_roundtrip(array);
+    }
+
+    #[test]
+    fn test_int32_array() {
+        let array = Int32Array::from(vec![Some(1), None, Some(3), Some(-5)]);
+        test_roundtrip(array);
+    }
+
+    #[test]
+    fn test_float64_array() {
+        let array = Float64Array::from(vec![Some(1.1), None, Some(3.3), Some(-5.5)]);
+        test_roundtrip(array);
+    }
+
+    #[test]
+    fn test_string_array() {
+        let array = StringArray::from(vec![Some("hello"), None, Some("world"), Some("")]);
+        test_roundtrip(array);
+    }
+
+    #[test]
+    fn test_list_array() {
+        let data = vec![
+            Some(vec![Some(1), Some(2), Some(3)]),
+            None,
+            Some(vec![Some(4), None, Some(6)]),
+        ];
+        let array = ListArray::from_iter_primitive::<Int32Type, _, _>(data);
+        test_roundtrip(array);
+    }
+
+    #[test]
+    fn test_struct_array() {
+        let boolean = BooleanArray::from(vec![Some(true), None, Some(false)]);
+        let int = Int32Array::from(vec![Some(1), Some(2), None]);
+        let fields = Fields::from(vec![
+            Field::new("b", DataType::Boolean, true),
+            Field::new("i", DataType::Int32, true),
+        ]);
+        let array =
+            StructArray::try_new(fields, vec![Arc::new(boolean), Arc::new(int)], None).unwrap();
+        test_roundtrip(array);
+    }
+
+    #[test]
+    fn test_dictionary_array() {
+        let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), None, Some(1)]);
+        let values = StringArray::from(vec!["foo", "bar", "baz"]);
+        let array = DictionaryArray::try_new(keys, Arc::new(values)).unwrap();
+        test_roundtrip(array);
+    }
+
+    #[test]
+    fn test_empty_array() {
+        let array_data = ArrayData::new_empty(&DataType::List(Arc::new(Field::new(
+            "item",
+            DataType::Int32,
+            true,
+        ))));
+        let array = make_array(array_data);
+        test_roundtrip(array);
+    }
+
+    #[test]
+    fn test_all_nulls_array() {
+        let array = Int32Array::from(vec![None, None, None]);
+        test_roundtrip(array);
+    }
+
+    #[test]
+    fn test_large_array() {
+        let array = Int32Array::from_iter((0..10000).map(Some));
+        test_roundtrip(array);
+    }
+
+    #[test]
+    fn test_nested_array() {
+        let value_data = ArrayData::builder(DataType::Int32)
+            .len(10)
+            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
+            .build()
+            .unwrap();
+
+        // Construct a buffer for value offsets, for the nested array:
+        // [[0, 1], null, null, [2, 3], [4, 5], null, [6, 7, 8], null, [9]]
+        let value_offsets = Buffer::from_slice_ref([0, 2, 2, 2, 4, 6, 6, 9, 9, 10]);
+        // Validity bitmap (LSB first): bytes 0b01011001, 0b00000001, i.e.
+        // entries 0, 3, 4, 6, and 8 are valid.
+        let mut null_bits: [u8; 2] = [0; 2];
+        arrow::util::bit_util::set_bit(&mut null_bits, 0);
+        arrow::util::bit_util::set_bit(&mut null_bits, 3);
+        arrow::util::bit_util::set_bit(&mut null_bits, 4);
+        arrow::util::bit_util::set_bit(&mut null_bits, 6);
+        arrow::util::bit_util::set_bit(&mut null_bits, 8);
+
+        // Construct a list array from the above two
+        let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false)));
+        let list_data = ArrayData::builder(list_data_type)
+            .len(9)
+            .add_buffer(value_offsets)
+            .add_child_data(value_data.clone())
+            .null_bit_buffer(Some(Buffer::from(null_bits)))
+            .build()
+            .unwrap();
+        let outer = make_array(list_data);
+
+        test_roundtrip(outer);
+    }
+
+    #[test]
+    fn test_serialized_size() {
+        let int_array = Int32Array::from_iter((0..1000).map(Some));
+        let array = Arc::new(int_array) as ArrayRef;
+        let serialized = serialize_array(&array).unwrap();
+        assert!(serialized.len() < 5000, "Serialized size is too large");
+    }
+
+    #[test]
+    fn test_error_handling() {
+        let invalid_bytes = vec![0, 1, 2, 3];
+        assert!(deserialize_array(&invalid_bytes).is_err());
+    }
+}
diff --git a/examples/examples/emit_measurements.rs b/examples/examples/emit_measurements.rs
index 749a2d4..94e5976 100644
--- a/examples/examples/emit_measurements.rs
+++ b/examples/examples/emit_measurements.rs
@@ -18,7 +18,10 @@ async fn main() -> Result<()> {
     let mut tasks = tokio::task::JoinSet::new();
 
     let producer: FutureProducer = ClientConfig::new()
-        .set("bootstrap.servers", String::from("localhost:9092"))
+        .set(
+            "bootstrap.servers",
+            String::from("localhost:19092,localhost:29092,localhost:39092"),
+        )
         .set("message.timeout.ms", "100")
         .create()
         .expect("Producer creation error");
diff --git a/examples/examples/kafka_rideshare.rs b/examples/examples/kafka_rideshare.rs
index b520a7b..1c24e29 100644
--- a/examples/examples/kafka_rideshare.rs
+++ b/examples/examples/kafka_rideshare.rs
@@ -49,7 +49,7 @@ async fn main() -> Result<()> {
     }
 }"#;
 
-    let bootstrap_servers = String::from("localhost:9092");
+    let bootstrap_servers = String::from("localhost:19092,localhost:29092,localhost:39092");
 
     let ctx = Context::new()?;
 
diff --git a/examples/examples/simple_aggregation.rs b/examples/examples/simple_aggregation.rs
index 98e396f..f0d45da 100644
--- a/examples/examples/simple_aggregation.rs
+++ b/examples/examples/simple_aggregation.rs
@@ -20,7 +20,7 @@ async fn main() -> Result<()> {
 
     let sample_event = get_sample_json();
 
-    let bootstrap_servers = String::from("localhost:9092");
+    let bootstrap_servers = String::from("localhost:19092,localhost:29092,localhost:39092");
     let ctx = Context::new()?;
 
     let mut topic_builder = KafkaTopicBuilder::new(bootstrap_servers.clone());