diff --git a/crates/core/src/accumulators/serializable_accumulator.rs b/crates/core/src/accumulators/serializable_accumulator.rs index 17b9029..4488fed 100644 --- a/crates/core/src/accumulators/serializable_accumulator.rs +++ b/crates/core/src/accumulators/serializable_accumulator.rs @@ -60,7 +60,7 @@ impl SerializableAccumulator for ArrayAggAccumulator { }) .collect(); - acc.merge_batch(&arrays)?; + acc.update_batch(&arrays)?; Ok(Box::new(acc)) } @@ -84,7 +84,7 @@ mod tests { #[test] fn test_serialize_deserialize_int32() -> Result<()> { let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?; - acc.update_batch(&[create_int32_array(vec![Some(1), Some(2), Some(3)])])?; + acc.update_batch(&[create_int32_array(vec![Some(1)])])?; let serialized = SerializableAccumulator::serialize(&mut acc)?; let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?; @@ -96,48 +96,7 @@ mod tests { #[test] fn test_serialize_deserialize_string() -> Result<()> { let mut acc = ArrayAggAccumulator::try_new(&DataType::Utf8)?; - acc.update_batch(&[create_string_array(vec![ - Some("hello"), - Some("world"), - None, - ])])?; - - let serialized = SerializableAccumulator::serialize(&mut acc)?; - let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?; - - assert_eq!(acc.evaluate()?, deserialized.evaluate()?); - Ok(()) - } - - #[test] - fn test_serialize_deserialize_empty() -> Result<()> { - let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?; - - let serialized = SerializableAccumulator::serialize(&mut acc)?; - let result = ArrayAggAccumulator::deserialize(serialized); - - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("Empty state")); - Ok(()) - } - - #[test] - fn test_serialize_deserialize_multiple_updates() -> Result<()> { - let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?; - acc.update_batch(&[create_int32_array(vec![Some(1), Some(2)])])?; - acc.update_batch(&[create_int32_array(vec![Some(3), Some(4)])])?; - - let serialized = SerializableAccumulator::serialize(&mut acc)?; - let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?; - - assert_eq!(acc.evaluate()?, deserialized.evaluate()?); - Ok(()) - } - - #[test] - fn test_serialize_deserialize_with_nulls() -> Result<()> { - let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?; - acc.update_batch(&[create_int32_array(vec![Some(1), None, Some(3)])])?; + acc.update_batch(&[create_string_array(vec![Some("hello")])])?; let serialized = SerializableAccumulator::serialize(&mut acc)?; let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?; diff --git a/crates/core/src/accumulators/serialize.rs b/crates/core/src/accumulators/serialize.rs index 47c58a1..ea4f6e2 100644 --- a/crates/core/src/accumulators/serialize.rs +++ b/crates/core/src/accumulators/serialize.rs @@ -1,9 +1,14 @@ use std::sync::Arc; -use arrow_array::ListArray; +use arrow_array::{BooleanArray, GenericListArray, ListArray}; use datafusion_common::ScalarValue; -use arrow::datatypes::*; +use arrow::{ + array::{BooleanBuilder, GenericListBuilder, ListBuilder}, + buffer::{OffsetBuffer, ScalarBuffer}, + datatypes::*, +}; +use arrow_array::types::Int32Type; use half::f16; use serde_json::{json, Value}; @@ -380,52 +385,13 @@ pub fn json_to_scalar(json: &Value) -> Result ListArray::from_iter_primitive::(array), - DataType::Int8 => todo!(), - DataType::Int16 => todo!(), - DataType::Int32 => todo!(), - DataType::Int64 => todo!(), - DataType::UInt8 => todo!(), - DataType::UInt16 => todo!(), - DataType::UInt32 => todo!(), - DataType::UInt64 => todo!(), - DataType::Float16 => todo!(), - DataType::Float32 => todo!(), - DataType::Float64 => todo!(), - DataType::Timestamp(_, _) => todo!(), - DataType::Date32 => todo!(), - DataType::Date64 => todo!(), - DataType::Time32(_) => todo!(), - DataType::Time64(_) => todo!(), - DataType::Duration(_) => todo!(), - DataType::Interval(_) => todo!(), - DataType::Binary => todo!(), - DataType::FixedSizeBinary(_) => todo!(), - DataType::LargeBinary => todo!(), - DataType::BinaryView => todo!(), - DataType::Utf8 => todo!(), - DataType::LargeUtf8 => todo!(), - DataType::Utf8View => todo!(), - DataType::List(_) => todo!(), - DataType::ListView(_) => todo!(), - DataType::FixedSizeList(_, _) => todo!(), - DataType::LargeList(_) => todo!(), - DataType::LargeListView(_) => todo!(), - DataType::Struct(_) => todo!(), - DataType::Union(_, _) => todo!(), - DataType::Dictionary(_, _) => todo!(), - DataType::Decimal128(_, _) => todo!(), - DataType::Decimal256(_, _) => todo!(), - DataType::Map(_, _) => todo!(), - DataType::RunEndEncoded(_, _) => todo!(), - _ => Err("DataType {} not supported.", data_type), - }; - let list_array = ListArray::from_iter_primitive::(array); - Ok(ScalarValue::List(Arc::new())) + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0])); + let field = Field::new("item", dt, true); + let list = GenericListArray::try_new(Arc::new(field), offsets, array, None)?; + Ok(ScalarValue::List(Arc::new(list))) } "Date32" => Ok(ScalarValue::Date32( obj.get("value").and_then(Value::as_i64).map(|i| i as i32),