Skip to content

Commit

Permalink
Adding serde for Accumulators
Browse files (browse the repository at this point in the history)
  • Loading branch information
ameyc committed Aug 7, 2024
1 parent 6e8797f commit 7957cb1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 90 deletions.
47 changes: 3 additions & 44 deletions crates/core/src/accumulators/serializable_accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl SerializableAccumulator for ArrayAggAccumulator {
})
.collect();

acc.merge_batch(&arrays)?;
acc.update_batch(&arrays)?;

Ok(Box::new(acc))
}
Expand All @@ -84,7 +84,7 @@ mod tests {
#[test]
fn test_serialize_deserialize_int32() -> Result<()> {
let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?;
acc.update_batch(&[create_int32_array(vec![Some(1), Some(2), Some(3)])])?;
acc.update_batch(&[create_int32_array(vec![Some(1)])])?;

let serialized = SerializableAccumulator::serialize(&mut acc)?;
let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?;
Expand All @@ -96,48 +96,7 @@ mod tests {
#[test]
fn test_serialize_deserialize_string() -> Result<()> {
let mut acc = ArrayAggAccumulator::try_new(&DataType::Utf8)?;
acc.update_batch(&[create_string_array(vec![
Some("hello"),
Some("world"),
None,
])])?;

let serialized = SerializableAccumulator::serialize(&mut acc)?;
let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?;

assert_eq!(acc.evaluate()?, deserialized.evaluate()?);
Ok(())
}

#[test]
fn test_serialize_deserialize_empty() -> Result<()> {
let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?;

let serialized = SerializableAccumulator::serialize(&mut acc)?;
let result = ArrayAggAccumulator::deserialize(serialized);

assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Empty state"));
Ok(())
}

#[test]
fn test_serialize_deserialize_multiple_updates() -> Result<()> {
let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?;
acc.update_batch(&[create_int32_array(vec![Some(1), Some(2)])])?;
acc.update_batch(&[create_int32_array(vec![Some(3), Some(4)])])?;

let serialized = SerializableAccumulator::serialize(&mut acc)?;
let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?;

assert_eq!(acc.evaluate()?, deserialized.evaluate()?);
Ok(())
}

#[test]
fn test_serialize_deserialize_with_nulls() -> Result<()> {
let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?;
acc.update_batch(&[create_int32_array(vec![Some(1), None, Some(3)])])?;
acc.update_batch(&[create_string_array(vec![Some("hello")])])?;

let serialized = SerializableAccumulator::serialize(&mut acc)?;
let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?;
Expand Down
58 changes: 12 additions & 46 deletions crates/core/src/accumulators/serialize.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
use std::sync::Arc;

use arrow_array::ListArray;
use arrow_array::{BooleanArray, GenericListArray, ListArray};
use datafusion_common::ScalarValue;

use arrow::datatypes::*;
use arrow::{
array::{BooleanBuilder, GenericListBuilder, ListBuilder},
buffer::{OffsetBuffer, ScalarBuffer},
datatypes::*,
};
use arrow_array::types::Int32Type;
use half::f16;
use serde_json::{json, Value};

Expand Down Expand Up @@ -380,52 +385,13 @@ pub fn json_to_scalar(json: &Value) -> Result<ScalarValue, Box<dyn std::error::E
.get("field_type")
.map(|ft| ft.as_str())
.ok_or("Missing 'field_type' for List")?;
let data_type: DataType = string_to_data_type(field_type.unwrap())?;
let dt: DataType = string_to_data_type(field_type.unwrap())?;
let element: ScalarValue = json_to_scalar(value)?;
let array = element.to_array_of_size(1).unwrap();
let list_array = match data_type {
DataType::Boolean => ListArray::from_iter_primitive::<BooleanType, _, _>(array),
DataType::Int8 => todo!(),
DataType::Int16 => todo!(),
DataType::Int32 => todo!(),
DataType::Int64 => todo!(),
DataType::UInt8 => todo!(),
DataType::UInt16 => todo!(),
DataType::UInt32 => todo!(),
DataType::UInt64 => todo!(),
DataType::Float16 => todo!(),
DataType::Float32 => todo!(),
DataType::Float64 => todo!(),
DataType::Timestamp(_, _) => todo!(),
DataType::Date32 => todo!(),
DataType::Date64 => todo!(),
DataType::Time32(_) => todo!(),
DataType::Time64(_) => todo!(),
DataType::Duration(_) => todo!(),
DataType::Interval(_) => todo!(),
DataType::Binary => todo!(),
DataType::FixedSizeBinary(_) => todo!(),
DataType::LargeBinary => todo!(),
DataType::BinaryView => todo!(),
DataType::Utf8 => todo!(),
DataType::LargeUtf8 => todo!(),
DataType::Utf8View => todo!(),
DataType::List(_) => todo!(),
DataType::ListView(_) => todo!(),
DataType::FixedSizeList(_, _) => todo!(),
DataType::LargeList(_) => todo!(),
DataType::LargeListView(_) => todo!(),
DataType::Struct(_) => todo!(),
DataType::Union(_, _) => todo!(),
DataType::Dictionary(_, _) => todo!(),
DataType::Decimal128(_, _) => todo!(),
DataType::Decimal256(_, _) => todo!(),
DataType::Map(_, _) => todo!(),
DataType::RunEndEncoded(_, _) => todo!(),
_ => Err("DataType {} not supported.", data_type),
};
let list_array = ListArray::from_iter_primitive::<data_type, _, _>(array);
Ok(ScalarValue::List(Arc::new()))
let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0]));
let field = Field::new("item", dt, true);
let list = GenericListArray::try_new(Arc::new(field), offsets, array, None)?;
Ok(ScalarValue::List(Arc::new(list)))
}
"Date32" => Ok(ScalarValue::Date32(
obj.get("value").and_then(Value::as_i64).map(|i| i as i32),
Expand Down

0 comments on commit 7957cb1

Please sign in to comment.