diff --git a/crates/core/src/accumulators/mod.rs b/crates/core/src/accumulators/mod.rs new file mode 100644 index 0000000..be1c218 --- /dev/null +++ b/crates/core/src/accumulators/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod serializable_accumulator; +mod serialize; diff --git a/crates/core/src/accumulators/serializable_accumulator.rs b/crates/core/src/accumulators/serializable_accumulator.rs new file mode 100644 index 0000000..17b9029 --- /dev/null +++ b/crates/core/src/accumulators/serializable_accumulator.rs @@ -0,0 +1,148 @@ +use arrow::array::{Array, ArrayRef}; +use datafusion::functions_aggregate::array_agg::ArrayAggAccumulator; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::Accumulator; +use serde::{Deserialize, Serialize}; + +use super::serialize::SerializableScalarValue; + +pub trait SerializableAccumulator { + fn serialize(&mut self) -> Result; + fn deserialize(bytes: String) -> Result>; +} + +#[derive(Debug, Serialize, Deserialize)] +struct SerializableArrayAggState { + state: Vec, +} + +impl SerializableAccumulator for ArrayAggAccumulator { + fn serialize(&mut self) -> Result { + let state = self.state()?; + let serializable_state = SerializableArrayAggState { + state: state + .into_iter() + .map(SerializableScalarValue::from) + .collect(), + }; + Ok(serde_json::to_string(&serializable_state).unwrap()) + } + + fn deserialize(bytes: String) -> Result> { + let serializable_state: SerializableArrayAggState = + serde_json::from_str(bytes.as_str()).unwrap(); + let state: Vec = serializable_state + .state + .into_iter() + .map(ScalarValue::from) + .collect(); + + // Infer the datatype from the first element of the state + let datatype = if let Some(ScalarValue::List(list)) = state.first() { + list.data_type().clone() + } else { + return Err(datafusion_common::DataFusionError::Internal( + "Invalid state for ArrayAggAccumulator".to_string(), + )); + }; + + let mut acc = ArrayAggAccumulator::try_new(&datatype)?; + + // Convert ScalarValue to ArrayRef for merge_batch + let arrays: Vec = state + .into_iter() + .filter_map(|s| { + if let ScalarValue::List(list) = s { + Some(list.values().clone()) + } else { + None + } + }) + .collect(); + + acc.merge_batch(&arrays)?; + + Ok(Box::new(acc)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int32Array, StringArray}; + use arrow::datatypes::DataType; + use std::sync::Arc; + + fn create_int32_array(values: Vec>) -> ArrayRef { + Arc::new(Int32Array::from(values)) as ArrayRef + } + + fn create_string_array(values: Vec>) -> ArrayRef { + Arc::new(StringArray::from(values)) as ArrayRef + } + + #[test] + fn test_serialize_deserialize_int32() -> Result<()> { + let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?; + acc.update_batch(&[create_int32_array(vec![Some(1), Some(2), Some(3)])])?; + + let serialized = SerializableAccumulator::serialize(&mut acc)?; + let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?; + + assert_eq!(acc.evaluate()?, deserialized.evaluate()?); + Ok(()) + } + + #[test] + fn test_serialize_deserialize_string() -> Result<()> { + let mut acc = ArrayAggAccumulator::try_new(&DataType::Utf8)?; + acc.update_batch(&[create_string_array(vec![ + Some("hello"), + Some("world"), + None, + ])])?; + + let serialized = SerializableAccumulator::serialize(&mut acc)?; + let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?; + + assert_eq!(acc.evaluate()?, deserialized.evaluate()?); + Ok(()) + } + + #[test] + fn test_serialize_deserialize_empty() -> Result<()> { + let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?; + + let serialized = SerializableAccumulator::serialize(&mut acc)?; + let result = ArrayAggAccumulator::deserialize(serialized); + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Empty state")); + Ok(()) + } + + #[test] + fn test_serialize_deserialize_multiple_updates() -> Result<()> { + let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?; + acc.update_batch(&[create_int32_array(vec![Some(1), Some(2)])])?; + acc.update_batch(&[create_int32_array(vec![Some(3), Some(4)])])?; + + let serialized = SerializableAccumulator::serialize(&mut acc)?; + let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?; + + assert_eq!(acc.evaluate()?, deserialized.evaluate()?); + Ok(()) + } + + #[test] + fn test_serialize_deserialize_with_nulls() -> Result<()> { + let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?; + acc.update_batch(&[create_int32_array(vec![Some(1), None, Some(3)])])?; + + let serialized = SerializableAccumulator::serialize(&mut acc)?; + let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?; + + assert_eq!(acc.evaluate()?, deserialized.evaluate()?); + Ok(()) + } +} diff --git a/crates/core/src/utils/serialize.rs b/crates/core/src/accumulators/serialize.rs similarity index 84% rename from crates/core/src/utils/serialize.rs rename to crates/core/src/accumulators/serialize.rs index cba469e..47c58a1 100644 --- a/crates/core/src/utils/serialize.rs +++ b/crates/core/src/accumulators/serialize.rs @@ -1,14 +1,52 @@ use std::sync::Arc; +use arrow_array::ListArray; use datafusion_common::ScalarValue; -use arrow::array::*; use arrow::datatypes::*; use half::f16; use serde_json::{json, Value}; use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct SerializableScalarValue(#[serde(with = "scalar_value_serde")] ScalarValue); + +impl From for SerializableScalarValue { + fn from(value: ScalarValue) -> Self { + SerializableScalarValue(value) + } +} + +impl From for ScalarValue { + fn from(value: SerializableScalarValue) -> Self { + value.0 + } +} + +mod scalar_value_serde { + use super::*; + use serde::{de::Error, Deserializer, Serializer}; + + pub fn serialize(value: &ScalarValue, serializer: S) -> Result + where + S: Serializer, + { + let json = scalar_to_json(value); + json.serialize(serializer) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let json = serde_json::Value::deserialize(deserializer)?; + json_to_scalar(&json).map_err(D::Error::custom) + } +} + pub fn string_to_data_type(s: &str) -> Result> { match s { "Null" => Ok(DataType::Null), @@ -336,18 +374,59 @@ pub fn json_to_scalar(json: &Value) -> Result { - // let value = obj.get("value").ok_or("Missing 'value' for List")?; - // let field_type = obj - // .get("field_type") - // .map(|ft| ft.as_str()) - // .ok_or("Missing 'field_type' for List")?; - // let data_type = string_to_data_type(field_type.unwrap())?; - // let element: ScalarValue = json_to_scalar(value)?; - // let array = element.to_array_of_size(1).unwrap(); - // ListArray::from_iter_primitive::(array); - // Ok(ScalarValue::List(Arc::new())) - // } + "List" => { + let value = obj.get("value").ok_or("Missing 'value' for List")?; + let field_type = obj + .get("field_type") + .map(|ft| ft.as_str()) + .ok_or("Missing 'field_type' for List")?; + let data_type: DataType = string_to_data_type(field_type.unwrap())?; + let element: ScalarValue = json_to_scalar(value)?; + let array = element.to_array_of_size(1).unwrap(); + let list_array = match data_type { + DataType::Boolean => ListArray::from_iter_primitive::(array), + DataType::Int8 => todo!(), + DataType::Int16 => todo!(), + DataType::Int32 => todo!(), + DataType::Int64 => todo!(), + DataType::UInt8 => todo!(), + DataType::UInt16 => todo!(), + DataType::UInt32 => todo!(), + DataType::UInt64 => todo!(), + DataType::Float16 => todo!(), + DataType::Float32 => todo!(), + DataType::Float64 => todo!(), + DataType::Timestamp(_, _) => todo!(), + DataType::Date32 => todo!(), + DataType::Date64 => todo!(), + DataType::Time32(_) => todo!(), + DataType::Time64(_) => todo!(), + DataType::Duration(_) => todo!(), + DataType::Interval(_) => todo!(), + DataType::Binary => todo!(), + DataType::FixedSizeBinary(_) => todo!(), + DataType::LargeBinary => todo!(), + DataType::BinaryView => todo!(), + DataType::Utf8 => todo!(), + DataType::LargeUtf8 => todo!(), + DataType::Utf8View => todo!(), + DataType::List(_) => todo!(), + DataType::ListView(_) => todo!(), + DataType::FixedSizeList(_, _) => todo!(), + DataType::LargeList(_) => todo!(), + DataType::LargeListView(_) => todo!(), + DataType::Struct(_) => todo!(), + DataType::Union(_, _) => todo!(), + DataType::Dictionary(_, _) => todo!(), + DataType::Decimal128(_, _) => todo!(), + DataType::Decimal256(_, _) => todo!(), + DataType::Map(_, _) => todo!(), + DataType::RunEndEncoded(_, _) => todo!(), + _ => Err("DataType {} not supported.", data_type), + }; + let list_array = ListArray::from_iter_primitive::(array); + Ok(ScalarValue::List(Arc::new())) + } "Date32" => Ok(ScalarValue::Date32( obj.get("value").and_then(Value::as_i64).map(|i| i as i32), )), @@ -472,7 +551,6 @@ pub fn json_to_scalar(json: &Value) -> Result