Skip to content

Commit

Permalink
checkpointing checkpoint code
Browse files Browse the repository at this point in the history
  • Loading branch information
ameyc committed Aug 6, 2024
1 parent fde9ecb commit 6e8797f
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 15 deletions.
2 changes: 2 additions & 0 deletions crates/core/src/accumulators/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub(crate) mod serializable_accumulator;
mod serialize;
148 changes: 148 additions & 0 deletions crates/core/src/accumulators/serializable_accumulator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
use arrow::array::{Array, ArrayRef};
use datafusion::functions_aggregate::array_agg::ArrayAggAccumulator;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;
use serde::{Deserialize, Serialize};

use super::serialize::SerializableScalarValue;

pub trait SerializableAccumulator {
fn serialize(&mut self) -> Result<String>;
fn deserialize(bytes: String) -> Result<Box<dyn Accumulator>>;
}

#[derive(Debug, Serialize, Deserialize)]
struct SerializableArrayAggState {
state: Vec<SerializableScalarValue>,
}

impl SerializableAccumulator for ArrayAggAccumulator {
fn serialize(&mut self) -> Result<String> {
let state = self.state()?;
let serializable_state = SerializableArrayAggState {
state: state
.into_iter()
.map(SerializableScalarValue::from)
.collect(),
};
Ok(serde_json::to_string(&serializable_state).unwrap())
}

fn deserialize(bytes: String) -> Result<Box<dyn Accumulator>> {
let serializable_state: SerializableArrayAggState =
serde_json::from_str(bytes.as_str()).unwrap();
let state: Vec<ScalarValue> = serializable_state
.state
.into_iter()
.map(ScalarValue::from)
.collect();

// Infer the datatype from the first element of the state
let datatype = if let Some(ScalarValue::List(list)) = state.first() {
list.data_type().clone()
} else {
return Err(datafusion_common::DataFusionError::Internal(
"Invalid state for ArrayAggAccumulator".to_string(),
));
};

let mut acc = ArrayAggAccumulator::try_new(&datatype)?;

// Convert ScalarValue to ArrayRef for merge_batch
let arrays: Vec<ArrayRef> = state
.into_iter()
.filter_map(|s| {
if let ScalarValue::List(list) = s {
Some(list.values().clone())
} else {
None
}
})
.collect();

acc.merge_batch(&arrays)?;

Ok(Box::new(acc))
}
}

#[cfg(test)]
mod tests {
use super::*;
use arrow::array::{Int32Array, StringArray};
use arrow::datatypes::DataType;
use std::sync::Arc;

fn create_int32_array(values: Vec<Option<i32>>) -> ArrayRef {
Arc::new(Int32Array::from(values)) as ArrayRef
}

fn create_string_array(values: Vec<Option<&str>>) -> ArrayRef {
Arc::new(StringArray::from(values)) as ArrayRef
}

#[test]
fn test_serialize_deserialize_int32() -> Result<()> {
let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?;
acc.update_batch(&[create_int32_array(vec![Some(1), Some(2), Some(3)])])?;

let serialized = SerializableAccumulator::serialize(&mut acc)?;
let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?;

assert_eq!(acc.evaluate()?, deserialized.evaluate()?);
Ok(())
}

#[test]
fn test_serialize_deserialize_string() -> Result<()> {
let mut acc = ArrayAggAccumulator::try_new(&DataType::Utf8)?;
acc.update_batch(&[create_string_array(vec![
Some("hello"),
Some("world"),
None,
])])?;

let serialized = SerializableAccumulator::serialize(&mut acc)?;
let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?;

assert_eq!(acc.evaluate()?, deserialized.evaluate()?);
Ok(())
}

#[test]
fn test_serialize_deserialize_empty() -> Result<()> {
let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?;

let serialized = SerializableAccumulator::serialize(&mut acc)?;
let result = ArrayAggAccumulator::deserialize(serialized);

assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Empty state"));
Ok(())
}

#[test]
fn test_serialize_deserialize_multiple_updates() -> Result<()> {
let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?;
acc.update_batch(&[create_int32_array(vec![Some(1), Some(2)])])?;
acc.update_batch(&[create_int32_array(vec![Some(3), Some(4)])])?;

let serialized = SerializableAccumulator::serialize(&mut acc)?;
let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?;

assert_eq!(acc.evaluate()?, deserialized.evaluate()?);
Ok(())
}

#[test]
fn test_serialize_deserialize_with_nulls() -> Result<()> {
let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?;
acc.update_batch(&[create_int32_array(vec![Some(1), None, Some(3)])])?;

let serialized = SerializableAccumulator::serialize(&mut acc)?;
let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?;

assert_eq!(acc.evaluate()?, deserialized.evaluate()?);
Ok(())
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,52 @@
use std::sync::Arc;

use arrow_array::ListArray;
use datafusion_common::ScalarValue;

use arrow::array::*;
use arrow::datatypes::*;
use half::f16;
use serde_json::{json, Value};

use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};

use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SerializableScalarValue(#[serde(with = "scalar_value_serde")] ScalarValue);

impl From<ScalarValue> for SerializableScalarValue {
fn from(value: ScalarValue) -> Self {
SerializableScalarValue(value)
}
}

impl From<SerializableScalarValue> for ScalarValue {
fn from(value: SerializableScalarValue) -> Self {
value.0
}
}

mod scalar_value_serde {
use super::*;
use serde::{de::Error, Deserializer, Serializer};

pub fn serialize<S>(value: &ScalarValue, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let json = scalar_to_json(value);
json.serialize(serializer)
}

pub fn deserialize<'de, D>(deserializer: D) -> Result<ScalarValue, D::Error>
where
D: Deserializer<'de>,
{
let json = serde_json::Value::deserialize(deserializer)?;
json_to_scalar(&json).map_err(D::Error::custom)
}
}

pub fn string_to_data_type(s: &str) -> Result<DataType, Box<dyn std::error::Error>> {
match s {
"Null" => Ok(DataType::Null),
Expand Down Expand Up @@ -336,18 +374,59 @@ pub fn json_to_scalar(json: &Value) -> Result<ScalarValue, Box<dyn std::error::E
.map(|s| base64::decode(s).unwrap());
Ok(ScalarValue::FixedSizeBinary(size, value))
}
// "List" => {
// let value = obj.get("value").ok_or("Missing 'value' for List")?;
// let field_type = obj
// .get("field_type")
// .map(|ft| ft.as_str())
// .ok_or("Missing 'field_type' for List")?;
// let data_type = string_to_data_type(field_type.unwrap())?;
// let element: ScalarValue = json_to_scalar(value)?;
// let array = element.to_array_of_size(1).unwrap();
// ListArray::from_iter_primitive::<data_type, _, _>(array);
// Ok(ScalarValue::List(Arc::new()))
// }
"List" => {
let value = obj.get("value").ok_or("Missing 'value' for List")?;
let field_type = obj
.get("field_type")
.map(|ft| ft.as_str())
.ok_or("Missing 'field_type' for List")?;
let data_type: DataType = string_to_data_type(field_type.unwrap())?;
let element: ScalarValue = json_to_scalar(value)?;
let array = element.to_array_of_size(1).unwrap();
let list_array = match data_type {
DataType::Boolean => ListArray::from_iter_primitive::<BooleanType, _, _>(array),
DataType::Int8 => todo!(),
DataType::Int16 => todo!(),
DataType::Int32 => todo!(),
DataType::Int64 => todo!(),
DataType::UInt8 => todo!(),
DataType::UInt16 => todo!(),
DataType::UInt32 => todo!(),
DataType::UInt64 => todo!(),
DataType::Float16 => todo!(),
DataType::Float32 => todo!(),
DataType::Float64 => todo!(),
DataType::Timestamp(_, _) => todo!(),
DataType::Date32 => todo!(),
DataType::Date64 => todo!(),
DataType::Time32(_) => todo!(),
DataType::Time64(_) => todo!(),
DataType::Duration(_) => todo!(),
DataType::Interval(_) => todo!(),
DataType::Binary => todo!(),
DataType::FixedSizeBinary(_) => todo!(),
DataType::LargeBinary => todo!(),
DataType::BinaryView => todo!(),
DataType::Utf8 => todo!(),
DataType::LargeUtf8 => todo!(),
DataType::Utf8View => todo!(),
DataType::List(_) => todo!(),
DataType::ListView(_) => todo!(),
DataType::FixedSizeList(_, _) => todo!(),
DataType::LargeList(_) => todo!(),
DataType::LargeListView(_) => todo!(),
DataType::Struct(_) => todo!(),
DataType::Union(_, _) => todo!(),
DataType::Dictionary(_, _) => todo!(),
DataType::Decimal128(_, _) => todo!(),
DataType::Decimal256(_, _) => todo!(),
DataType::Map(_, _) => todo!(),
DataType::RunEndEncoded(_, _) => todo!(),
_ => Err("DataType {} not supported.", data_type),
};
let list_array = ListArray::from_iter_primitive::<data_type, _, _>(array);
Ok(ScalarValue::List(Arc::new()))
}
"Date32" => Ok(ScalarValue::Date32(
obj.get("value").and_then(Value::as_i64).map(|i| i as i32),
)),
Expand Down Expand Up @@ -472,7 +551,6 @@ pub fn json_to_scalar(json: &Value) -> Result<ScalarValue, Box<dyn std::error::E
mod tests {
use super::*;
use datafusion_common::ScalarValue;
use serde_json::json;

fn test_roundtrip(scalar: ScalarValue) {
let json = scalar_to_json(&scalar);
Expand Down Expand Up @@ -570,4 +648,21 @@ mod tests {
None,
));
}

#[test]
fn test_serializable_scalar_value() {
let original = ScalarValue::Int32(Some(42));
let serializable = SerializableScalarValue::from(original.clone());

// Serialize
let serialized = serde_json::to_string(&serializable).unwrap();

// Deserialize
let deserialized: SerializableScalarValue = serde_json::from_str(&serialized).unwrap();

// Convert back to ScalarValue
let result: ScalarValue = deserialized.into();

assert_eq!(original, result);
}
}
1 change: 1 addition & 0 deletions crates/core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod accumulators;
pub mod config_extensions;
pub mod context;
pub mod datasource;
Expand Down
1 change: 0 additions & 1 deletion crates/core/src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#[allow(dead_code)]
pub mod arrow_helpers;
mod default_optimizer_rules;
pub mod serialize;

pub use default_optimizer_rules::get_default_optimizer_rules;

0 comments on commit 6e8797f

Please sign in to comment.