Skip to content

Commit

Permalink
Merge pull request #13 from probably-nothing-labs/amey/checkpointing-…
Browse files Browse the repository at this point in the history
…deux

Amey/checkpointing deux
  • Loading branch information
ameyc authored Aug 7, 2024
2 parents 57db66f + cbf6814 commit ee54420
Show file tree
Hide file tree
Showing 8 changed files with 750 additions and 934 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ itertools = { workspace = true }
serde.workspace = true
rocksdb = "0.22.0"
bincode = "1.3.3"
half = "2.4.1"
2 changes: 2 additions & 0 deletions crates/core/src/accumulators/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub(crate) mod serializable_accumulator;
mod serialize;
112 changes: 112 additions & 0 deletions crates/core/src/accumulators/serializable_accumulator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
use arrow::array::{Array, ArrayRef};
use datafusion::functions_aggregate::array_agg::ArrayAggAccumulator;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;
use serde::{Deserialize, Serialize};

use super::serialize::SerializableScalarValue;

/// Converts an accumulator to and from a `String` representation of its state.
///
/// The implementation for `ArrayAggAccumulator` in this file encodes the state
/// as JSON via `serde_json`.
// NOTE(review): `dead_code` is presumably allowed because nothing outside the
// tests calls this trait yet — confirm and drop the attribute once it is used.
#[allow(dead_code)]
pub trait SerializableAccumulator {
/// Produces a `String` encoding of the accumulator's current state.
///
/// Takes `&mut self`; the implementation in this file calls the
/// accumulator's `state()` method to obtain the values to encode.
fn serialize(&mut self) -> Result<String>;
/// Builds a boxed `Accumulator` from a `String` previously produced by
/// [`SerializableAccumulator::serialize`].
///
/// Consumes `self`. Despite its name, `bytes` is a `String`, not a byte
/// buffer.
fn deserialize(self, bytes: String) -> Result<Box<dyn Accumulator>>;
}

/// Serde-friendly container for the state of an `ArrayAggAccumulator`.
///
/// This is the value that gets written to / read from JSON by the
/// `SerializableAccumulator` implementation for `ArrayAggAccumulator`.
#[derive(Debug, Serialize, Deserialize)]
struct SerializableArrayAggState {
// The accumulator's state values, each converted to `SerializableScalarValue`.
state: Vec<SerializableScalarValue>,
}

impl SerializableAccumulator for ArrayAggAccumulator {
    /// Serializes the accumulator's current state to a JSON string.
    ///
    /// # Errors
    ///
    /// Returns the error from `state()` if the state cannot be produced, or a
    /// `DataFusionError::External` wrapping the `serde_json` error if JSON
    /// encoding fails.
    fn serialize(&mut self) -> Result<String> {
        let serializable_state = SerializableArrayAggState {
            state: self
                .state()?
                .into_iter()
                .map(SerializableScalarValue::from)
                .collect(),
        };

        // Propagate encoding failures instead of panicking.
        serde_json::to_string(&serializable_state)
            .map_err(|e| datafusion_common::DataFusionError::External(Box::new(e)))
    }

    /// Rebuilds an `ArrayAggAccumulator` from a JSON string produced by
    /// `serialize` and returns it boxed as a `dyn Accumulator`.
    ///
    /// `self` is consumed and not otherwise used; the element type of the new
    /// accumulator is inferred from the serialized state.
    ///
    /// # Errors
    ///
    /// * `DataFusionError::External` if `bytes` is not valid JSON for the
    ///   expected state layout.
    /// * `DataFusionError::Internal` if the first state entry is missing or is
    ///   not a `ScalarValue::List`.
    /// * Any error returned while constructing or updating the accumulator.
    fn deserialize(self, bytes: String) -> Result<Box<dyn Accumulator>> {
        // Malformed checkpoint data is a recoverable condition, not a bug:
        // surface it as an error rather than unwrapping.
        let serializable_state: SerializableArrayAggState = serde_json::from_str(&bytes)
            .map_err(|e| datafusion_common::DataFusionError::External(Box::new(e)))?;

        let state: Vec<ScalarValue> = serializable_state
            .state
            .into_iter()
            .map(ScalarValue::from)
            .collect();

        // Infer the *element* type from the first state entry. The accumulator
        // is constructed with the type of the aggregated values, so use the
        // data type of the list's child array rather than the list type itself.
        let Some(ScalarValue::List(first)) = state.first() else {
            return Err(datafusion_common::DataFusionError::Internal(
                "Invalid state for ArrayAggAccumulator".to_string(),
            ));
        };
        let element_type = first.values().data_type().clone();

        let mut acc = ArrayAggAccumulator::try_new(&element_type)?;

        // Unwrap each list entry into its child values so they can be fed back
        // through `update_batch`; non-list entries are skipped.
        let arrays: Vec<ArrayRef> = state
            .into_iter()
            .filter_map(|scalar| match scalar {
                ScalarValue::List(list) => Some(list.values().clone()),
                _ => None,
            })
            .collect();

        acc.update_batch(&arrays)?;

        Ok(Box::new(acc))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int32Array, StringArray};
    use arrow::datatypes::DataType;
    use std::sync::Arc;

    /// Serializes `original`, restores it through `deserialize` using a fresh
    /// accumulator of `element_type`, and checks that both accumulators
    /// evaluate to the same value.
    fn assert_round_trip(
        mut original: ArrayAggAccumulator,
        element_type: &DataType,
    ) -> Result<()> {
        let payload = SerializableAccumulator::serialize(&mut original)?;
        let placeholder = ArrayAggAccumulator::try_new(element_type)?;

        let mut restored = ArrayAggAccumulator::deserialize(placeholder, payload)?;

        assert_eq!(original.evaluate()?, restored.evaluate()?);
        Ok(())
    }

    /// Round-trips an accumulator holding a single `Int32` value.
    #[test]
    fn test_serialize_deserialize_int32() -> Result<()> {
        let element_type = DataType::Int32;
        let input: ArrayRef = Arc::new(Int32Array::from(vec![Some(1)]));

        let mut acc = ArrayAggAccumulator::try_new(&element_type)?;
        acc.update_batch(&[input])?;

        assert_round_trip(acc, &element_type)
    }

    /// Round-trips an accumulator holding a single `Utf8` value.
    #[test]
    fn test_serialize_deserialize_string() -> Result<()> {
        let element_type = DataType::Utf8;
        let input: ArrayRef = Arc::new(StringArray::from(vec![Some("hello")]));

        let mut acc = ArrayAggAccumulator::try_new(&element_type)?;
        acc.update_batch(&[input])?;

        assert_round_trip(acc, &element_type)
    }
}
Loading

0 comments on commit ee54420

Please sign in to comment.