From 779c286adc01ef2881ae3f002ace82803f5a8efa Mon Sep 17 00:00:00 2001 From: Amey Chaugule Date: Fri, 2 Aug 2024 15:50:06 -0700 Subject: [PATCH 1/8] Adding ScalarValue serialization back in --- crates/core/src/utils/serialize.rs | 1584 ++++++++++++---------------- 1 file changed, 651 insertions(+), 933 deletions(-) diff --git a/crates/core/src/utils/serialize.rs b/crates/core/src/utils/serialize.rs index d0e332a..60a3d7b 100644 --- a/crates/core/src/utils/serialize.rs +++ b/crates/core/src/utils/serialize.rs @@ -1,933 +1,651 @@ -// use arrow::datatypes::DataType; -// use serde::ser::SerializeStruct; -// use serde::{Deserialize, Deserializer, Serialize, Serializer}; -// use std::sync::Arc; - -// use datafusion_common::scalar::ScalarStructBuilder; -// use datafusion_common::ScalarValue; - -// use arrow::array::*; -// use arrow::datatypes::*; - -// pub struct SerializableScalarValue { -// value: ScalarValue, -// } - -// impl Serialize for SerializableScalarValue { -// fn serialize(&self, serializer: S) -> Result -// where -// S: Serializer, -// { -// match self.value { -// ScalarValue::Null => { -// let mut st = serializer.serialize_struct("ScalarValue", 1)?; -// st.serialize_field("type", "Null")?; -// st.end() -// } -// ScalarValue::Boolean(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Boolean")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Float16(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Float16")?; -// st.serialize_field("value", &v.map(|f| f.to_f32()))?; -// st.end() -// } -// ScalarValue::Float32(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Float32")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Float64(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Float64")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Decimal128(v, p, s) => { -// let mut st = serializer.serialize_struct("ScalarValue", 4)?; -// st.serialize_field("type", "Decimal128")?; -// st.serialize_field("value", &v.map(|x| x.to_string()))?; -// st.serialize_field("precision", p)?; -// st.serialize_field("scale", s)?; -// st.end() -// } -// ScalarValue::Decimal256(v, p, s) => { -// let mut st = serializer.serialize_struct("ScalarValue", 4)?; -// st.serialize_field("type", "Decimal256")?; -// st.serialize_field("value", &v.map(|x| x.to_string()))?; -// st.serialize_field("precision", p)?; -// st.serialize_field("scale", s)?; -// st.end() -// } -// ScalarValue::Int8(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Int8")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Int16(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Int16")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Int32(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Int32")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Int64(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Int64")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::UInt8(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// 
st.serialize_field("type", "UInt8")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::UInt16(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "UInt16")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::UInt32(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "UInt32")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::UInt64(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "UInt64")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Utf8(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Utf8")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::LargeUtf8(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "LargeUtf8")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Binary(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Binary")?; -// st.serialize_field("value", &v.as_ref().map(|b| base64::encode(b)))?; -// st.end() -// } -// ScalarValue::LargeBinary(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "LargeBinary")?; -// st.serialize_field("value", &v.as_ref().map(|b| base64::encode(b)))?; -// st.end() -// } -// ScalarValue::FixedSizeBinary(size, v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 3)?; -// st.serialize_field("type", "FixedSizeBinary")?; -// st.serialize_field("size", size)?; -// st.serialize_field("value", &v.as_ref().map(|b| base64::encode(b)))?; -// st.end() -// } -// ScalarValue::List(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 3)?; -// st.serialize_field("type", "List")?; -// if let arr = v.as_ref() { -// let list_data_type = arr.data_type(); -// println!("arr len {}", arr.len()); -// if let DataType::List(field) = list_data_type { -// st.serialize_field("child_type", &field.data_type().to_string())?; -// let values = arr.value(0); -// let nested_values = if arr.is_null(0) { -// vec![None] -// } else { -// let ret = (0..values.len()) -// .map(|i| ScalarValue::try_from_array(&values, i).map(Some).unwrap()) -// .collect::>>(); -// ret -// }; -// st.serialize_field("value", &nested_values)?; -// } else { -// return Err(serde::ser::Error::custom("Invalid List data type")); -// } -// } else { -// st.serialize_field("child_type", &DataType::Null.to_string())?; -// st.serialize_field("value", &Option::>>::None)?; -// } -// st.end() -// } -// ScalarValue::Date32(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Date32")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Date64(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Date64")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Time32Second(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Time32Second")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Time32Millisecond(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Time32Millisecond")?; -// st.serialize_field("value", 
v)?; -// st.end() -// } -// ScalarValue::Time64Microsecond(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Time64Microsecond")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Time64Nanosecond(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Time64Nanosecond")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::TimestampSecond(v, tz) -// | ScalarValue::TimestampMillisecond(v, tz) -// | ScalarValue::TimestampMicrosecond(v, tz) -// | ScalarValue::TimestampNanosecond(v, tz) => { -// let mut st = serializer.serialize_struct("ScalarValue", 4)?; -// st.serialize_field("type", "Timestamp")?; -// st.serialize_field("value", v)?; -// st.serialize_field("timezone", &tz.as_ref().map(|s| s.to_string()))?; -// st.serialize_field( -// "unit", -// match self.value { -// ScalarValue::TimestampSecond(_, _) => "Second", -// ScalarValue::TimestampMillisecond(_, _) => "Millisecond", -// ScalarValue::TimestampMicrosecond(_, _) => "Microsecond", -// ScalarValue::TimestampNanosecond(_, _) => "Nanosecond", -// _ => unreachable!(), -// }, -// )?; -// st.end() -// } -// ScalarValue::IntervalYearMonth(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "IntervalYearMonth")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::IntervalDayTime(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "IntervalDayTime")?; -// st.serialize_field("value", &v.map(|x| (x.days, x.milliseconds)))?; -// st.end() -// } -// ScalarValue::IntervalMonthDayNano(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "IntervalMonthDayNano")?; -// st.serialize_field("value", &v.map(|x| (x.months, x.days, x.nanoseconds)))?; -// st.end() -// } -// ScalarValue::DurationSecond(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "DurationSecond")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::DurationMillisecond(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "DurationMillisecond")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::DurationMicrosecond(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "DurationMicrosecond")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::DurationNanosecond(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "DurationNanosecond")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Union(v, fields, mode) => { -// let mut st = serializer.serialize_struct("ScalarValue", 4)?; -// st.serialize_field("type", "Union")?; -// st.serialize_field("value", v)?; -// st.serialize_field( -// "fields", -// &fields -// .iter() -// .map(|(i, f)| (i, f.name().to_string(), f.data_type().to_string())) -// .collect::>(), -// )?; -// st.serialize_field( -// "mode", -// match mode { -// UnionMode::Sparse => "Sparse", -// UnionMode::Dense => "Dense", -// }, -// )?; -// st.end() -// } -// ScalarValue::Dictionary(key_type, value) => { -// let mut st = serializer.serialize_struct("ScalarValue", 3)?; -// st.serialize_field("type", "Dictionary")?; -// st.serialize_field("key_type", 
&key_type.to_string())?; -// st.serialize_field("value", value)?; -// st.end() -// } -// ScalarValue::Utf8View(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Utf8View")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::BinaryView(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "BinaryView")?; -// st.serialize_field("value", &v.as_ref().map(|b| base64::encode(b)))?; -// st.end() -// } -// ScalarValue::LargeList(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 3)?; -// st.serialize_field("type", "LargeList")?; -// if let arr = v.as_ref() { -// let list_data_type = arr.data_type(); -// if let DataType::LargeList(field) = list_data_type { -// st.serialize_field("child_type", &field.data_type().to_string())?; -// let values = arr.value(0); -// let nested_values = if arr.is_null(0) { -// vec![None] -// } else { -// let ret = (0..values.len()) -// .map(|i| ScalarValue::try_from_array(&values, i).map(Some).unwrap()) -// .collect::>>(); -// ret -// }; -// st.serialize_field("value", &nested_values)?; -// } else { -// return Err(serde::ser::Error::custom("Invalid LargeList data type")); -// } -// } else { -// st.serialize_field("child_type", &DataType::Null.to_string())?; -// st.serialize_field("value", &Option::>>::None)?; -// } -// st.end() -// } -// ScalarValue::FixedSizeList(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 4)?; -// st.serialize_field("type", "FixedSizeList")?; -// if let arr = v.as_ref() { -// let list_data_type = arr.data_type(); -// if let DataType::FixedSizeList(field, size) = list_data_type { -// st.serialize_field("child_type", &field.data_type().to_string())?; -// st.serialize_field("size", size)?; -// let values = arr.value(0); -// let nested_values = if arr.is_null(0) { -// vec![None] -// } else { -// let ret = (0..values.len()) -// .map(|i| ScalarValue::try_from_array(&values, i).map(Some).unwrap()) -// .collect::>>(); -// ret -// }; -// st.serialize_field("value", &nested_values)?; -// } else { -// return Err(serde::ser::Error::custom("Invalid FixedSizeList data type")); -// } -// } else { -// st.serialize_field("child_type", &DataType::Null.to_string())?; -// st.serialize_field("size", &0)?; -// st.serialize_field("value", &Option::>>::None)?; -// } -// st.end() -// } -// ScalarValue::Struct(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 3)?; -// st.serialize_field("type", "Struct")?; -// if let struct_arr = v.as_ref() { -// let fields: Vec<(String, String)> = struct_arr -// .fields() -// .iter() -// .map(|f| (f.name().to_string(), f.data_type().to_string())) -// .collect(); -// st.serialize_field("fields", &fields)?; - -// let values: Vec> = struct_arr -// .columns() -// .iter() -// .enumerate() -// .map(|(i, field)| { -// if struct_arr.is_null(0) { -// Ok(None) -// } else { -// ScalarValue::try_from_array(field, 0) -// .map(Some) -// .map_err(serde::ser::Error::custom) -// } -// }) -// .collect::>()?; -// st.serialize_field("value", &values)?; -// } else { -// st.serialize_field("fields", &Vec::<(String, String)>::new())?; -// st.serialize_field("value", &Option::>>::None)?; -// } -// st.end() -// } -// } -// } -// } - -// use std::str::FromStr; - -// impl<'de> Deserialize<'de> for ScalarValue { -// fn deserialize(deserializer: D) -> Result -// where -// D: Deserializer<'de>, -// { -// #[derive(Deserialize)] -// #[serde(tag = "type")] -// enum ScalarValueHelper { -// 
Null, -// Boolean { -// value: Option, -// }, -// Float16 { -// value: Option, -// }, -// Float32 { -// value: Option, -// }, -// Float64 { -// value: Option, -// }, -// Decimal128 { -// value: Option, -// precision: u8, -// scale: i8, -// }, -// Decimal256 { -// value: Option, -// precision: u8, -// scale: i8, -// }, -// Int8 { -// value: Option, -// }, -// Int16 { -// value: Option, -// }, -// Int32 { -// value: Option, -// }, -// Int64 { -// value: Option, -// }, -// UInt8 { -// value: Option, -// }, -// UInt16 { -// value: Option, -// }, -// UInt32 { -// value: Option, -// }, -// UInt64 { -// value: Option, -// }, -// Utf8 { -// value: Option, -// }, -// LargeUtf8 { -// value: Option, -// }, -// Binary { -// value: Option, -// }, -// LargeBinary { -// value: Option, -// }, -// FixedSizeBinary { -// size: i32, -// value: Option, -// }, -// List { -// child_type: String, -// value: Option>>, -// }, -// LargeList { -// child_type: String, -// value: Option>>, -// }, -// FixedSizeList { -// child_type: String, -// size: usize, -// value: Option>>, -// }, -// Date32 { -// value: Option, -// }, -// Date64 { -// value: Option, -// }, -// Time32Second { -// value: Option, -// }, -// Time32Millisecond { -// value: Option, -// }, -// Time64Microsecond { -// value: Option, -// }, -// Time64Nanosecond { -// value: Option, -// }, -// Timestamp { -// value: Option, -// timezone: Option, -// unit: String, -// }, -// IntervalYearMonth { -// value: Option, -// }, -// IntervalDayTime { -// value: Option<(i32, i32)>, -// }, -// IntervalMonthDayNano { -// value: Option<(i32, i32, i64)>, -// }, -// DurationSecond { -// value: Option, -// }, -// DurationMillisecond { -// value: Option, -// }, -// DurationMicrosecond { -// value: Option, -// }, -// DurationNanosecond { -// value: Option, -// }, -// Union { -// value: Option<(i8, Box)>, -// fields: Vec<(i8, String, String)>, -// mode: String, -// }, -// Dictionary { -// key_type: String, -// value: Box, -// }, -// Utf8View { -// value: Option, -// }, -// BinaryView { -// value: Option, -// }, -// Struct { -// fields: Vec<(String, String)>, -// value: Option>>, -// }, -// } - -// let helper = ScalarValueHelper::deserialize(deserializer)?; - -// Ok(match helper { -// ScalarValueHelper::Null => ScalarValue::Null, -// ScalarValueHelper::Boolean { value } => ScalarValue::Boolean(value), -// ScalarValueHelper::Float16 { value } => { -// ScalarValue::Float16(value.map(half::f16::from_f32)) -// } -// ScalarValueHelper::Float32 { value } => ScalarValue::Float32(value), -// ScalarValueHelper::Float64 { value } => ScalarValue::Float64(value), -// ScalarValueHelper::Decimal128 { -// value, -// precision, -// scale, -// } => ScalarValue::Decimal128( -// value.map(|s| s.parse().unwrap()), //TODO: fix me -// precision, -// scale, -// ), -// ScalarValueHelper::Decimal256 { -// value, -// precision, -// scale, -// } => ScalarValue::Decimal256(value.map(|s| s.parse().unwrap()), precision, scale), -// ScalarValueHelper::Int8 { value } => ScalarValue::Int8(value), -// ScalarValueHelper::Int16 { value } => ScalarValue::Int16(value), -// ScalarValueHelper::Int32 { value } => ScalarValue::Int32(value), -// ScalarValueHelper::Int64 { value } => ScalarValue::Int64(value), -// ScalarValueHelper::UInt8 { value } => ScalarValue::UInt8(value), -// ScalarValueHelper::UInt16 { value } => ScalarValue::UInt16(value), -// ScalarValueHelper::UInt32 { value } => ScalarValue::UInt32(value), -// ScalarValueHelper::UInt64 { value } => ScalarValue::UInt64(value), -// ScalarValueHelper::Utf8 { value 
} => ScalarValue::Utf8(value), -// ScalarValueHelper::LargeUtf8 { value } => ScalarValue::LargeUtf8(value), -// ScalarValueHelper::Binary { value } => { -// ScalarValue::Binary(value.map(|s| base64::decode(s).unwrap())) -// } -// ScalarValueHelper::LargeBinary { value } => { -// ScalarValue::LargeBinary(value.map(|s| base64::decode(s).unwrap())) -// } -// ScalarValueHelper::FixedSizeBinary { size, value } => { -// ScalarValue::FixedSizeBinary(size, value.map(|s| base64::decode(s).unwrap())) -// } -// ScalarValueHelper::List { child_type, value } => { -// let field = Arc::new(Field::new( -// "item", -// DataType::from_str(&child_type).map_err(serde::de::Error::custom)?, -// true, -// )); -// let values: Vec> = value.unwrap_or_default(); -// let scalar_values: Vec = values -// .into_iter() -// .map(|v| v.unwrap_or(ScalarValue::Null)) -// .collect(); -// let data = ScalarValue::new_list(&scalar_values, &field.data_type()); -// ScalarValue::List(data) -// } -// ScalarValueHelper::LargeList { child_type, value } => { -// let field = Arc::new(Field::new( -// "item", -// DataType::from_str(&child_type).map_err(serde::de::Error::custom)?, -// true, -// )); -// let values: Vec> = value.unwrap_or_default(); -// let scalar_values: Vec = values -// .into_iter() -// .map(|v| v.unwrap_or(ScalarValue::Null)) -// .collect(); -// ScalarValue::LargeList(ScalarValue::new_large_list( -// &scalar_values, -// &field.data_type(), -// )) -// } -// ScalarValueHelper::FixedSizeList { -// child_type, -// size, -// value, -// } => { -// let field = Arc::new(Field::new( -// "item", -// DataType::from_str(&child_type).map_err(serde::de::Error::custom)?, -// true, -// )); -// let values: Vec> = value.unwrap_or_default(); -// let scalar_values: Vec = values -// .into_iter() -// .map(|v| v.unwrap_or(ScalarValue::Null)) -// .collect(); -// let data_type = DataType::FixedSizeList(field, scalar_values.len() as i32); -// let value_data = ScalarValue::new_list(&scalar_values, &data_type).to_data(); -// let list_array = FixedSizeListArray::from(value_data); -// ScalarValue::FixedSizeList(Arc::new(list_array)) -// } -// ScalarValueHelper::Date32 { value } => ScalarValue::Date32(value), -// ScalarValueHelper::Date64 { value } => ScalarValue::Date64(value), -// ScalarValueHelper::Time32Second { value } => ScalarValue::Time32Second(value), -// ScalarValueHelper::Time32Millisecond { value } => ScalarValue::Time32Millisecond(value), -// ScalarValueHelper::Time64Microsecond { value } => ScalarValue::Time64Microsecond(value), -// ScalarValueHelper::Time64Nanosecond { value } => ScalarValue::Time64Nanosecond(value), -// ScalarValueHelper::Timestamp { -// value, -// timezone, -// unit, -// } => match unit.as_str() { -// "Second" => ScalarValue::TimestampSecond(value, timezone.map(Arc::from)), -// "Millisecond" => ScalarValue::TimestampMillisecond(value, timezone.map(Arc::from)), -// "Microsecond" => ScalarValue::TimestampMicrosecond(value, timezone.map(Arc::from)), -// "Nanosecond" => ScalarValue::TimestampNanosecond(value, timezone.map(Arc::from)), -// _ => return Err(serde::de::Error::custom("Invalid timestamp unit")), -// }, -// ScalarValueHelper::IntervalYearMonth { value } => ScalarValue::IntervalYearMonth(value), -// ScalarValueHelper::IntervalDayTime { value } => ScalarValue::IntervalDayTime( -// value.map(|(days, millis)| IntervalDayTime::new(days, millis)), -// ), -// ScalarValueHelper::IntervalMonthDayNano { value } => ScalarValue::IntervalMonthDayNano( -// value.map(|(months, days, nanos)| IntervalMonthDayNano::new(months, 
days, nanos)), -// ), -// ScalarValueHelper::DurationSecond { value } => ScalarValue::DurationSecond(value), -// ScalarValueHelper::DurationMillisecond { value } => { -// ScalarValue::DurationMillisecond(value) -// } -// ScalarValueHelper::DurationMicrosecond { value } => { -// ScalarValue::DurationMicrosecond(value) -// } -// ScalarValueHelper::DurationNanosecond { value } => { -// ScalarValue::DurationNanosecond(value) -// } -// ScalarValueHelper::Union { -// value, -// fields, -// mode, -// } => { -// let union_fields = fields -// .into_iter() -// .map(|(i, name, type_str)| { -// ( -// i, -// Arc::new(Field::new( -// name, -// DataType::from_str(&type_str).unwrap(), -// true, -// )), -// ) -// }) -// .collect(); -// let union_mode = match mode.as_str() { -// "Sparse" => UnionMode::Sparse, -// "Dense" => UnionMode::Dense, -// _ => return Err(serde::de::Error::custom("Invalid union mode")), -// }; -// ScalarValue::Union(value, union_fields, union_mode) -// } -// ScalarValueHelper::Dictionary { key_type, value } => { -// ScalarValue::Dictionary(Box::new(DataType::from_str(&key_type).unwrap()), value) -// } -// ScalarValueHelper::Utf8View { value } => ScalarValue::Utf8View(value), -// ScalarValueHelper::BinaryView { value } => { -// ScalarValue::BinaryView(value.map(|s| base64::decode(s).unwrap())) -// } -// ScalarValueHelper::Struct { fields, value } => { -// let struct_fields: Vec = fields -// .into_iter() -// .map(|(name, type_str)| { -// Field::new(name, DataType::from_str(&type_str).unwrap(), true) -// }) -// .collect(); -// let values: Vec> = value.unwrap_or_default(); -// let scalar_values: Vec = values -// .into_iter() -// .map(|v| v.unwrap_or(ScalarValue::Null)) -// .collect(); -// let mut builder = ScalarStructBuilder::new(); -// for (field, value) in struct_fields.clone().into_iter().zip(scalar_values) { -// builder = builder.with_scalar(field, value); -// } -// let struct_array = builder.build().unwrap(); - -// ScalarValue::Struct(Arc::new(StructArray::new( -// Fields::from(struct_fields), -// vec![struct_array.to_array().unwrap()], -// None, -// ))) -// } -// }) -// } -// } - -// #[cfg(test)] -// mod tests { -// use super::*; -// use serde_json; -// use std::sync::Arc; - -// fn test_serde_roundtrip(scalar: ScalarValue) { -// let serialized = serde_json::to_string(&scalar).unwrap(); -// let deserialized: ScalarValue = serde_json::from_str(&serialized).unwrap(); -// assert_eq!(scalar, deserialized); -// } - -// #[test] -// fn test_large_utf8() { -// test_serde_roundtrip(ScalarValue::LargeUtf8(Some("hello".to_string()))); -// test_serde_roundtrip(ScalarValue::LargeUtf8(None)); -// } - -// #[test] -// fn test_binary() { -// test_serde_roundtrip(ScalarValue::Binary(Some(vec![1, 2, 3]))); -// test_serde_roundtrip(ScalarValue::Binary(None)); -// } - -// #[test] -// fn test_large_binary() { -// test_serde_roundtrip(ScalarValue::LargeBinary(Some(vec![1, 2, 3]))); -// test_serde_roundtrip(ScalarValue::LargeBinary(None)); -// } - -// #[test] -// fn test_fixed_size_binary() { -// test_serde_roundtrip(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))); -// test_serde_roundtrip(ScalarValue::FixedSizeBinary(3, None)); -// } - -// #[test] -// fn test_list() { -// let list = ListArray::from_iter_primitive::(vec![Some(vec![ -// Some(1), -// Some(2), -// Some(3), -// ])]); -// test_serde_roundtrip(ScalarValue::List(Arc::new(list))); -// } - -// #[test] -// fn test_large_list() { -// let list = LargeListArray::from_iter_primitive::(vec![Some(vec![ -// Some(1), -// Some(2), -// Some(3), -// 
])]); -// test_serde_roundtrip(ScalarValue::LargeList(Arc::new(list))); -// } - -// // #[test] -// // fn test_fixed_size_list() { -// // let list = FixedSizeListArray::from_iter_primitive::( -// // vec![Some(vec![Some(1), Some(2), Some(3)])], -// // 3, -// // ); -// // test_serde_roundtrip(ScalarValue::FixedSizeList(Arc::new(list))); -// // } - -// #[test] -// fn test_date32() { -// test_serde_roundtrip(ScalarValue::Date32(Some(1000))); -// test_serde_roundtrip(ScalarValue::Date32(None)); -// } - -// #[test] -// fn test_date64() { -// test_serde_roundtrip(ScalarValue::Date64(Some(86400000))); -// test_serde_roundtrip(ScalarValue::Date64(None)); -// } - -// #[test] -// fn test_time32_second() { -// test_serde_roundtrip(ScalarValue::Time32Second(Some(3600))); -// test_serde_roundtrip(ScalarValue::Time32Second(None)); -// } - -// #[test] -// fn test_time32_millisecond() { -// test_serde_roundtrip(ScalarValue::Time32Millisecond(Some(3600000))); -// test_serde_roundtrip(ScalarValue::Time32Millisecond(None)); -// } - -// #[test] -// fn test_time64_microsecond() { -// test_serde_roundtrip(ScalarValue::Time64Microsecond(Some(3600000000))); -// test_serde_roundtrip(ScalarValue::Time64Microsecond(None)); -// } - -// #[test] -// fn test_time64_nanosecond() { -// test_serde_roundtrip(ScalarValue::Time64Nanosecond(Some(3600000000000))); -// test_serde_roundtrip(ScalarValue::Time64Nanosecond(None)); -// } - -// #[test] -// fn test_timestamp() { -// test_serde_roundtrip(ScalarValue::TimestampSecond( -// Some(1625097600), -// Some(Arc::from("UTC")), -// )); -// test_serde_roundtrip(ScalarValue::TimestampMillisecond(Some(1625097600000), None)); -// test_serde_roundtrip(ScalarValue::TimestampMicrosecond( -// Some(1625097600000000), -// Some(Arc::from("UTC")), -// )); -// test_serde_roundtrip(ScalarValue::TimestampNanosecond( -// Some(1625097600000000000), -// None, -// )); -// } - -// #[test] -// fn test_interval_year_month() { -// test_serde_roundtrip(ScalarValue::IntervalYearMonth(Some(14))); -// test_serde_roundtrip(ScalarValue::IntervalYearMonth(None)); -// } - -// #[test] -// fn test_interval_day_time() { -// test_serde_roundtrip(ScalarValue::IntervalDayTime(Some(IntervalDayTime::new( -// 5, 43200000, -// )))); -// test_serde_roundtrip(ScalarValue::IntervalDayTime(None)); -// } - -// #[test] -// fn test_interval_month_day_nano() { -// test_serde_roundtrip(ScalarValue::IntervalMonthDayNano(Some( -// IntervalMonthDayNano::new(1, 15, 1000000000), -// ))); -// test_serde_roundtrip(ScalarValue::IntervalMonthDayNano(None)); -// } - -// #[test] -// fn test_duration() { -// test_serde_roundtrip(ScalarValue::DurationSecond(Some(3600))); -// test_serde_roundtrip(ScalarValue::DurationMillisecond(Some(3600000))); -// test_serde_roundtrip(ScalarValue::DurationMicrosecond(Some(3600000000))); -// test_serde_roundtrip(ScalarValue::DurationNanosecond(Some(3600000000000))); -// } - -// // #[test] -// // fn test_union() { -// // let fields = vec![ -// // (0, Arc::new(Field::new("f1", DataType::Int32, true))), -// // (1, Arc::new(Field::new("f2", DataType::Utf8, true))), -// // ]; -// // test_serde_roundtrip(ScalarValue::Union( -// // Some((0, Box::new(ScalarValue::Int32(Some(42))))), -// // fields.clone(), -// // UnionMode::Sparse, -// // )); -// // test_serde_roundtrip(ScalarValue::Union(None, fields, UnionMode::Dense)); -// // } - -// #[test] -// fn test_dictionary() { -// test_serde_roundtrip(ScalarValue::Dictionary( -// Box::new(DataType::Int8), -// Box::new(ScalarValue::Utf8(Some("hello".to_string()))), -// )); -// 
} - -// #[test] -// fn test_utf8_view() { -// test_serde_roundtrip(ScalarValue::Utf8View(Some("hello".to_string()))); -// test_serde_roundtrip(ScalarValue::Utf8View(None)); -// } - -// #[test] -// fn test_binary_view() { -// test_serde_roundtrip(ScalarValue::BinaryView(Some(vec![1, 2, 3]))); -// test_serde_roundtrip(ScalarValue::BinaryView(None)); -// } - -// /* #[test] -// fn test_struct() { -// let fields = vec![ -// Field::new("f1", DataType::Int32, true), -// Field::new("f2", DataType::Utf8, true), -// ]; -// let values = vec![ -// Some(ScalarValue::Int32(Some(42))), -// Some(ScalarValue::Utf8(Some("hello".to_string()))), -// ]; -// let struct_array = StructArray::from(values); -// test_serde_roundtrip(ScalarValue::Struct(Arc::new(struct_array))); -// } */ -// } +use arrow::datatypes::DataType; +use serde::ser::SerializeStruct; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::sync::Arc; + +use datafusion_common::scalar::ScalarStructBuilder; +use datafusion_common::ScalarValue; + +use arrow::array::*; +use arrow::datatypes::*; +use serde::Serializer; +use serde_json::json; + +pub fn serialize_scalar_value(value: &ScalarValue, serializer: S) -> Result +where + S: Serializer, +{ + let json_value = match value { + ScalarValue::Null => json!({"type": "Null"}), + ScalarValue::Boolean(v) => json!({"type": "Boolean", "value": v}), + ScalarValue::Float16(v) => json!({"type": "Float16", "value": v.map(|f| f.to_f32())}), + ScalarValue::Float32(v) => json!({"type": "Float32", "value": v}), + ScalarValue::Float64(v) => json!({"type": "Float64", "value": v}), + ScalarValue::Decimal128(v, p, s) => json!({ + "type": "Decimal128", + "value": v.map(|x| x.to_string()), + "precision": p, + "scale": s + }), + ScalarValue::Decimal256(v, p, s) => json!({ + "type": "Decimal256", + "value": v.map(|x| x.to_string()), + "precision": p, + "scale": s + }), + ScalarValue::Int8(v) => json!({"type": "Int8", "value": v}), + ScalarValue::Int16(v) => json!({"type": "Int16", "value": v}), + ScalarValue::Int32(v) => json!({"type": "Int32", "value": v}), + ScalarValue::Int64(v) => json!({"type": "Int64", "value": v}), + ScalarValue::UInt8(v) => json!({"type": "UInt8", "value": v}), + ScalarValue::UInt16(v) => json!({"type": "UInt16", "value": v}), + ScalarValue::UInt32(v) => json!({"type": "UInt32", "value": v}), + ScalarValue::UInt64(v) => json!({"type": "UInt64", "value": v}), + ScalarValue::Utf8(v) => json!({"type": "Utf8", "value": v}), + ScalarValue::LargeUtf8(v) => json!({"type": "LargeUtf8", "value": v}), + ScalarValue::Binary(v) => json!({ + "type": "Binary", + "value": v.as_ref().map(|b| base64::encode(b)) + }), + ScalarValue::LargeBinary(v) => json!({ + "type": "LargeBinary", + "value": v.as_ref().map(|b| base64::encode(b)) + }), + ScalarValue::FixedSizeBinary(size, v) => json!({ + "type": "FixedSizeBinary", + "size": size, + "value": v.as_ref().map(|b| base64::encode(b)) + }), + ScalarValue::List(v) => json!({ + "type": "List", + "value": serialize_array(v)? 
+ }), + ScalarValue::Date32(v) => json!({"type": "Date32", "value": v}), + ScalarValue::Date64(v) => json!({"type": "Date64", "value": v}), + ScalarValue::Time32Second(v) => json!({"type": "Time32Second", "value": v}), + ScalarValue::Time32Millisecond(v) => json!({"type": "Time32Millisecond", "value": v}), + ScalarValue::Time64Microsecond(v) => json!({"type": "Time64Microsecond", "value": v}), + ScalarValue::Time64Nanosecond(v) => json!({"type": "Time64Nanosecond", "value": v}), + ScalarValue::TimestampSecond(v, tz) + | ScalarValue::TimestampMillisecond(v, tz) + | ScalarValue::TimestampMicrosecond(v, tz) + | ScalarValue::TimestampNanosecond(v, tz) => json!({ + "type": "Timestamp", + "value": v, + "timezone": tz.as_ref().map(|s| s.to_string()), + "unit": match value { + ScalarValue::TimestampSecond(_, _) => "Second", + ScalarValue::TimestampMillisecond(_, _) => "Millisecond", + ScalarValue::TimestampMicrosecond(_, _) => "Microsecond", + ScalarValue::TimestampNanosecond(_, _) => "Nanosecond", + _ => unreachable!(), + } + }), + ScalarValue::IntervalYearMonth(v) => json!({"type": "IntervalYearMonth", "value": v}), + ScalarValue::IntervalDayTime(v) => json!({ + "type": "IntervalDayTime", + "value": v.map(|x| (x.days, x.milliseconds)) + }), + ScalarValue::IntervalMonthDayNano(v) => json!({ + "type": "IntervalMonthDayNano", + "value": v.map(|x| (x.months, x.days, x.nanoseconds)) + }), + ScalarValue::DurationSecond(v) => json!({"type": "DurationSecond", "value": v}), + ScalarValue::DurationMillisecond(v) => json!({"type": "DurationMillisecond", "value": v}), + ScalarValue::DurationMicrosecond(v) => json!({"type": "DurationMicrosecond", "value": v}), + ScalarValue::DurationNanosecond(v) => json!({"type": "DurationNanosecond", "value": v}), + ScalarValue::Struct(v) => json!({ + "type": "Struct", + "fields": v.as_ref().fields().iter().map(|f| (f.name().to_string(), f.data_type().to_string())).collect::>(), + "value": v.as_ref().fields().iter().enumerate().map(|(i, field)| ScalarValue::try_from_array(field, 0).ok()).collect::>(), + + }), + // Add other variants as needed... 
+ }; + + json_value.serialize(serializer) +} + +fn serialize_array(arr: &Arc) -> Result>, S::Error> { + (0..arr.len()) + .map(|i| { + ScalarValue::try_from_array(arr.as_ref(), i) + .map(Some) + .map_err(serde::ser::Error::custom) + }) + .collect() +} +use std::str::FromStr; + +impl<'de> Deserialize<'de> for ScalarValue { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(tag = "type")] + enum ScalarValueHelper { + Null, + Boolean { + value: Option, + }, + Float16 { + value: Option, + }, + Float32 { + value: Option, + }, + Float64 { + value: Option, + }, + Decimal128 { + value: Option, + precision: u8, + scale: i8, + }, + Decimal256 { + value: Option, + precision: u8, + scale: i8, + }, + Int8 { + value: Option, + }, + Int16 { + value: Option, + }, + Int32 { + value: Option, + }, + Int64 { + value: Option, + }, + UInt8 { + value: Option, + }, + UInt16 { + value: Option, + }, + UInt32 { + value: Option, + }, + UInt64 { + value: Option, + }, + Utf8 { + value: Option, + }, + LargeUtf8 { + value: Option, + }, + Binary { + value: Option, + }, + LargeBinary { + value: Option, + }, + FixedSizeBinary { + size: i32, + value: Option, + }, + List { + child_type: String, + value: Option>>, + }, + LargeList { + child_type: String, + value: Option>>, + }, + FixedSizeList { + child_type: String, + size: usize, + value: Option>>, + }, + Date32 { + value: Option, + }, + Date64 { + value: Option, + }, + Time32Second { + value: Option, + }, + Time32Millisecond { + value: Option, + }, + Time64Microsecond { + value: Option, + }, + Time64Nanosecond { + value: Option, + }, + Timestamp { + value: Option, + timezone: Option, + unit: String, + }, + IntervalYearMonth { + value: Option, + }, + IntervalDayTime { + value: Option<(i32, i32)>, + }, + IntervalMonthDayNano { + value: Option<(i32, i32, i64)>, + }, + DurationSecond { + value: Option, + }, + DurationMillisecond { + value: Option, + }, + DurationMicrosecond { + value: Option, + }, + DurationNanosecond { + value: Option, + }, + Union { + value: Option<(i8, Box)>, + fields: Vec<(i8, String, String)>, + mode: String, + }, + Dictionary { + key_type: String, + value: Box, + }, + Utf8View { + value: Option, + }, + BinaryView { + value: Option, + }, + Struct { + fields: Vec<(String, String)>, + value: Option>>, + }, + } + + let helper = ScalarValueHelper::deserialize(deserializer)?; + + Ok(match helper { + ScalarValueHelper::Null => ScalarValue::Null, + ScalarValueHelper::Boolean { value } => ScalarValue::Boolean(value), + ScalarValueHelper::Float16 { value } => { + ScalarValue::Float16(value.map(half::f16::from_f32)) + } + ScalarValueHelper::Float32 { value } => ScalarValue::Float32(value), + ScalarValueHelper::Float64 { value } => ScalarValue::Float64(value), + ScalarValueHelper::Decimal128 { + value, + precision, + scale, + } => ScalarValue::Decimal128( + value.map(|s| s.parse().unwrap()), //TODO: fix me + precision, + scale, + ), + ScalarValueHelper::Decimal256 { + value, + precision, + scale, + } => ScalarValue::Decimal256(value.map(|s| s.parse().unwrap()), precision, scale), + ScalarValueHelper::Int8 { value } => ScalarValue::Int8(value), + ScalarValueHelper::Int16 { value } => ScalarValue::Int16(value), + ScalarValueHelper::Int32 { value } => ScalarValue::Int32(value), + ScalarValueHelper::Int64 { value } => ScalarValue::Int64(value), + ScalarValueHelper::UInt8 { value } => ScalarValue::UInt8(value), + ScalarValueHelper::UInt16 { value } => ScalarValue::UInt16(value), + ScalarValueHelper::UInt32 { 
value } => ScalarValue::UInt32(value), + ScalarValueHelper::UInt64 { value } => ScalarValue::UInt64(value), + ScalarValueHelper::Utf8 { value } => ScalarValue::Utf8(value), + ScalarValueHelper::LargeUtf8 { value } => ScalarValue::LargeUtf8(value), + ScalarValueHelper::Binary { value } => { + ScalarValue::Binary(value.map(|s| base64::decode(s).unwrap())) + } + ScalarValueHelper::LargeBinary { value } => { + ScalarValue::LargeBinary(value.map(|s| base64::decode(s).unwrap())) + } + ScalarValueHelper::FixedSizeBinary { size, value } => { + ScalarValue::FixedSizeBinary(size, value.map(|s| base64::decode(s).unwrap())) + } + ScalarValueHelper::List { child_type, value } => { + let field = Arc::new(Field::new( + "item", + DataType::from_str(&child_type).map_err(serde::de::Error::custom)?, + true, + )); + let values: Vec> = value.unwrap_or_default(); + let scalar_values: Vec = values + .into_iter() + .map(|v| v.unwrap_or(ScalarValue::Null)) + .collect(); + let data = ScalarValue::new_list(&scalar_values, &field.data_type()); + ScalarValue::List(data) + } + ScalarValueHelper::LargeList { child_type, value } => { + let field = Arc::new(Field::new( + "item", + DataType::from_str(&child_type).map_err(serde::de::Error::custom)?, + true, + )); + let values: Vec> = value.unwrap_or_default(); + let scalar_values: Vec = values + .into_iter() + .map(|v| v.unwrap_or(ScalarValue::Null)) + .collect(); + ScalarValue::LargeList(ScalarValue::new_large_list( + &scalar_values, + &field.data_type(), + )) + } + ScalarValueHelper::FixedSizeList { + child_type, + size, + value, + } => { + let field = Arc::new(Field::new( + "item", + DataType::from_str(&child_type).map_err(serde::de::Error::custom)?, + true, + )); + let values: Vec> = value.unwrap_or_default(); + let scalar_values: Vec = values + .into_iter() + .map(|v| v.unwrap_or(ScalarValue::Null)) + .collect(); + let data_type = DataType::FixedSizeList(field, scalar_values.len() as i32); + let value_data = ScalarValue::new_list(&scalar_values, &data_type).to_data(); + let list_array = FixedSizeListArray::from(value_data); + ScalarValue::FixedSizeList(Arc::new(list_array)) + } + ScalarValueHelper::Date32 { value } => ScalarValue::Date32(value), + ScalarValueHelper::Date64 { value } => ScalarValue::Date64(value), + ScalarValueHelper::Time32Second { value } => ScalarValue::Time32Second(value), + ScalarValueHelper::Time32Millisecond { value } => ScalarValue::Time32Millisecond(value), + ScalarValueHelper::Time64Microsecond { value } => ScalarValue::Time64Microsecond(value), + ScalarValueHelper::Time64Nanosecond { value } => ScalarValue::Time64Nanosecond(value), + ScalarValueHelper::Timestamp { + value, + timezone, + unit, + } => match unit.as_str() { + "Second" => ScalarValue::TimestampSecond(value, timezone.map(Arc::from)), + "Millisecond" => ScalarValue::TimestampMillisecond(value, timezone.map(Arc::from)), + "Microsecond" => ScalarValue::TimestampMicrosecond(value, timezone.map(Arc::from)), + "Nanosecond" => ScalarValue::TimestampNanosecond(value, timezone.map(Arc::from)), + _ => return Err(serde::de::Error::custom("Invalid timestamp unit")), + }, + ScalarValueHelper::IntervalYearMonth { value } => ScalarValue::IntervalYearMonth(value), + ScalarValueHelper::IntervalDayTime { value } => ScalarValue::IntervalDayTime( + value.map(|(days, millis)| IntervalDayTime::new(days, millis)), + ), + ScalarValueHelper::IntervalMonthDayNano { value } => ScalarValue::IntervalMonthDayNano( + value.map(|(months, days, nanos)| IntervalMonthDayNano::new(months, days, nanos)), + ), + 
ScalarValueHelper::DurationSecond { value } => ScalarValue::DurationSecond(value), + ScalarValueHelper::DurationMillisecond { value } => { + ScalarValue::DurationMillisecond(value) + } + ScalarValueHelper::DurationMicrosecond { value } => { + ScalarValue::DurationMicrosecond(value) + } + ScalarValueHelper::DurationNanosecond { value } => { + ScalarValue::DurationNanosecond(value) + } + ScalarValueHelper::Union { + value, + fields, + mode, + } => { + let union_fields = fields + .into_iter() + .map(|(i, name, type_str)| { + ( + i, + Arc::new(Field::new( + name, + DataType::from_str(&type_str).unwrap(), + true, + )), + ) + }) + .collect(); + let union_mode = match mode.as_str() { + "Sparse" => UnionMode::Sparse, + "Dense" => UnionMode::Dense, + _ => return Err(serde::de::Error::custom("Invalid union mode")), + }; + ScalarValue::Union(value, union_fields, union_mode) + } + ScalarValueHelper::Dictionary { key_type, value } => { + ScalarValue::Dictionary(Box::new(DataType::from_str(&key_type).unwrap()), value) + } + ScalarValueHelper::Utf8View { value } => ScalarValue::Utf8View(value), + ScalarValueHelper::BinaryView { value } => { + ScalarValue::BinaryView(value.map(|s| base64::decode(s).unwrap())) + } + ScalarValueHelper::Struct { fields, value } => { + let struct_fields: Vec = fields + .into_iter() + .map(|(name, type_str)| { + Field::new(name, DataType::from_str(&type_str).unwrap(), true) + }) + .collect(); + let values: Vec> = value.unwrap_or_default(); + let scalar_values: Vec = values + .into_iter() + .map(|v| v.unwrap_or(ScalarValue::Null)) + .collect(); + let mut builder = ScalarStructBuilder::new(); + for (field, value) in struct_fields.clone().into_iter().zip(scalar_values) { + builder = builder.with_scalar(field, value); + } + let struct_array = builder.build().unwrap(); + + ScalarValue::Struct(Arc::new(StructArray::new( + Fields::from(struct_fields), + vec![struct_array.to_array().unwrap()], + None, + ))) + } + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json; + use std::sync::Arc; + + fn test_serde_roundtrip(scalar: ScalarValue) { + let serialized = serde_json::to_string(&scalar).unwrap(); + let deserialized: ScalarValue = serde_json::from_str(&serialized).unwrap(); + assert_eq!(scalar, deserialized); + } + + #[test] + fn test_large_utf8() { + test_serde_roundtrip(ScalarValue::LargeUtf8(Some("hello".to_string()))); + test_serde_roundtrip(ScalarValue::LargeUtf8(None)); + } + + #[test] + fn test_binary() { + test_serde_roundtrip(ScalarValue::Binary(Some(vec![1, 2, 3]))); + test_serde_roundtrip(ScalarValue::Binary(None)); + } + + #[test] + fn test_large_binary() { + test_serde_roundtrip(ScalarValue::LargeBinary(Some(vec![1, 2, 3]))); + test_serde_roundtrip(ScalarValue::LargeBinary(None)); + } + + #[test] + fn test_fixed_size_binary() { + test_serde_roundtrip(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))); + test_serde_roundtrip(ScalarValue::FixedSizeBinary(3, None)); + } + + #[test] + fn test_list() { + let list = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]); + test_serde_roundtrip(ScalarValue::List(Arc::new(list))); + } + + #[test] + fn test_large_list() { + let list = LargeListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]); + test_serde_roundtrip(ScalarValue::LargeList(Arc::new(list))); + } + + // #[test] + // fn test_fixed_size_list() { + // let list = FixedSizeListArray::from_iter_primitive::( + // vec![Some(vec![Some(1), Some(2), Some(3)])], + // 3, + // ); + // 
test_serde_roundtrip(ScalarValue::FixedSizeList(Arc::new(list))); + // } + + #[test] + fn test_date32() { + test_serde_roundtrip(ScalarValue::Date32(Some(1000))); + test_serde_roundtrip(ScalarValue::Date32(None)); + } + + #[test] + fn test_date64() { + test_serde_roundtrip(ScalarValue::Date64(Some(86400000))); + test_serde_roundtrip(ScalarValue::Date64(None)); + } + + #[test] + fn test_time32_second() { + test_serde_roundtrip(ScalarValue::Time32Second(Some(3600))); + test_serde_roundtrip(ScalarValue::Time32Second(None)); + } + + #[test] + fn test_time32_millisecond() { + test_serde_roundtrip(ScalarValue::Time32Millisecond(Some(3600000))); + test_serde_roundtrip(ScalarValue::Time32Millisecond(None)); + } + + #[test] + fn test_time64_microsecond() { + test_serde_roundtrip(ScalarValue::Time64Microsecond(Some(3600000000))); + test_serde_roundtrip(ScalarValue::Time64Microsecond(None)); + } + + #[test] + fn test_time64_nanosecond() { + test_serde_roundtrip(ScalarValue::Time64Nanosecond(Some(3600000000000))); + test_serde_roundtrip(ScalarValue::Time64Nanosecond(None)); + } + + #[test] + fn test_timestamp() { + test_serde_roundtrip(ScalarValue::TimestampSecond( + Some(1625097600), + Some(Arc::from("UTC")), + )); + test_serde_roundtrip(ScalarValue::TimestampMillisecond(Some(1625097600000), None)); + test_serde_roundtrip(ScalarValue::TimestampMicrosecond( + Some(1625097600000000), + Some(Arc::from("UTC")), + )); + test_serde_roundtrip(ScalarValue::TimestampNanosecond( + Some(1625097600000000000), + None, + )); + } + + #[test] + fn test_interval_year_month() { + test_serde_roundtrip(ScalarValue::IntervalYearMonth(Some(14))); + test_serde_roundtrip(ScalarValue::IntervalYearMonth(None)); + } + + #[test] + fn test_interval_day_time() { + test_serde_roundtrip(ScalarValue::IntervalDayTime(Some(IntervalDayTime::new( + 5, 43200000, + )))); + test_serde_roundtrip(ScalarValue::IntervalDayTime(None)); + } + + #[test] + fn test_interval_month_day_nano() { + test_serde_roundtrip(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano::new(1, 15, 1000000000), + ))); + test_serde_roundtrip(ScalarValue::IntervalMonthDayNano(None)); + } + + #[test] + fn test_duration() { + test_serde_roundtrip(ScalarValue::DurationSecond(Some(3600))); + test_serde_roundtrip(ScalarValue::DurationMillisecond(Some(3600000))); + test_serde_roundtrip(ScalarValue::DurationMicrosecond(Some(3600000000))); + test_serde_roundtrip(ScalarValue::DurationNanosecond(Some(3600000000000))); + } + + // #[test] + // fn test_union() { + // let fields = vec![ + // (0, Arc::new(Field::new("f1", DataType::Int32, true))), + // (1, Arc::new(Field::new("f2", DataType::Utf8, true))), + // ]; + // test_serde_roundtrip(ScalarValue::Union( + // Some((0, Box::new(ScalarValue::Int32(Some(42))))), + // fields.clone(), + // UnionMode::Sparse, + // )); + // test_serde_roundtrip(ScalarValue::Union(None, fields, UnionMode::Dense)); + // } + + #[test] + fn test_dictionary() { + test_serde_roundtrip(ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(ScalarValue::Utf8(Some("hello".to_string()))), + )); + } + + #[test] + fn test_utf8_view() { + test_serde_roundtrip(ScalarValue::Utf8View(Some("hello".to_string()))); + test_serde_roundtrip(ScalarValue::Utf8View(None)); + } + + #[test] + fn test_binary_view() { + test_serde_roundtrip(ScalarValue::BinaryView(Some(vec![1, 2, 3]))); + test_serde_roundtrip(ScalarValue::BinaryView(None)); + } + + /* #[test] + fn test_struct() { + let fields = vec![ + Field::new("f1", DataType::Int32, true), + Field::new("f2", 
DataType::Utf8, true), + ]; + let values = vec![ + Some(ScalarValue::Int32(Some(42))), + Some(ScalarValue::Utf8(Some("hello".to_string()))), + ]; + let struct_array = StructArray::from(values); + test_serde_roundtrip(ScalarValue::Struct(Arc::new(struct_array))); + } */ +} From fde9ecb66dc0331c8085c722df13f7e799c66a93 Mon Sep 17 00:00:00 2001 From: Amey Chaugule Date: Sun, 4 Aug 2024 23:17:34 -0700 Subject: [PATCH 2/8] Serialization of Scalar Values to Json --- Cargo.lock | 1 + crates/core/Cargo.toml | 1 + crates/core/src/utils/serialize.rs | 970 +++++++++++++---------------- 3 files changed, 448 insertions(+), 524 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 99d8f0c..d175f74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1086,6 +1086,7 @@ dependencies = [ "datafusion-physical-optimizer", "datafusion-physical-plan", "futures", + "half", "itertools 0.13.0", "log", "rdkafka", diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 9b9b1c6..0fd1203 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -32,3 +32,4 @@ itertools = { workspace = true } serde.workspace = true rocksdb = "0.22.0" bincode = "1.3.3" +half = "2.4.1" diff --git a/crates/core/src/utils/serialize.rs b/crates/core/src/utils/serialize.rs index 60a3d7b..cba469e 100644 --- a/crates/core/src/utils/serialize.rs +++ b/crates/core/src/utils/serialize.rs @@ -1,21 +1,142 @@ -use arrow::datatypes::DataType; -use serde::ser::SerializeStruct; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::sync::Arc; -use datafusion_common::scalar::ScalarStructBuilder; use datafusion_common::ScalarValue; use arrow::array::*; use arrow::datatypes::*; -use serde::Serializer; -use serde_json::json; +use half::f16; +use serde_json::{json, Value}; + +use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; + +pub fn string_to_data_type(s: &str) -> Result> { + match s { + "Null" => Ok(DataType::Null), + "Boolean" => Ok(DataType::Boolean), + "Int8" => Ok(DataType::Int8), + "Int16" => Ok(DataType::Int16), + "Int32" => Ok(DataType::Int32), + "Int64" => Ok(DataType::Int64), + "UInt8" => Ok(DataType::UInt8), + "UInt16" => Ok(DataType::UInt16), + "UInt32" => Ok(DataType::UInt32), + "UInt64" => Ok(DataType::UInt64), + "Float16" => Ok(DataType::Float16), + "Float32" => Ok(DataType::Float32), + "Float64" => Ok(DataType::Float64), + "Binary" => Ok(DataType::Binary), + "LargeBinary" => Ok(DataType::LargeBinary), + "Utf8" => Ok(DataType::Utf8), + "LargeUtf8" => Ok(DataType::LargeUtf8), + "Date32" => Ok(DataType::Date32), + "Date64" => Ok(DataType::Date64), + s if s.starts_with("Timestamp(") => { + let parts: Vec<&str> = s[10..s.len() - 1].split(',').collect(); + if parts.len() != 2 { + return Err("Invalid Timestamp format".into()); + } + let time_unit = match parts[0].trim() { + "Second" => TimeUnit::Second, + "Millisecond" => TimeUnit::Millisecond, + "Microsecond" => TimeUnit::Microsecond, + "Nanosecond" => TimeUnit::Nanosecond, + _ => return Err("Invalid TimeUnit".into()), + }; + let timezone = parts[1].trim().trim_matches('"'); + let timezone = if timezone == "None" { + None + } else { + Some(timezone.into()) + }; + Ok(DataType::Timestamp(time_unit, timezone)) + } + s if s.starts_with("Time32(") => { + let time_unit = match &s[7..s.len() - 1] { + "Second" => TimeUnit::Second, + "Millisecond" => TimeUnit::Millisecond, + _ => return Err("Invalid TimeUnit for Time32".into()), + }; + Ok(DataType::Time32(time_unit)) + } + s if s.starts_with("Time64(") => { + let time_unit = match &s[7..s.len() - 1] { + 
"Microsecond" => TimeUnit::Microsecond, + "Nanosecond" => TimeUnit::Nanosecond, + _ => return Err("Invalid TimeUnit for Time64".into()), + }; + Ok(DataType::Time64(time_unit)) + } + s if s.starts_with("Duration(") => { + let time_unit = match &s[9..s.len() - 1] { + "Second" => TimeUnit::Second, + "Millisecond" => TimeUnit::Millisecond, + "Microsecond" => TimeUnit::Microsecond, + "Nanosecond" => TimeUnit::Nanosecond, + _ => return Err("Invalid TimeUnit for Duration".into()), + }; + Ok(DataType::Duration(time_unit)) + } + s if s.starts_with("Interval(") => { + let interval_unit = match &s[9..s.len() - 1] { + "YearMonth" => IntervalUnit::YearMonth, + "DayTime" => IntervalUnit::DayTime, + "MonthDayNano" => IntervalUnit::MonthDayNano, + _ => return Err("Invalid IntervalUnit".into()), + }; + Ok(DataType::Interval(interval_unit)) + } + s if s.starts_with("FixedSizeBinary(") => { + let size: i32 = s[16..s.len() - 1].parse()?; + Ok(DataType::FixedSizeBinary(size)) + } + s if s.starts_with("List(") => { + let inner_type = string_to_data_type(&s[5..s.len() - 1])?; + Ok(DataType::List(Arc::new(Field::new( + "item", inner_type, true, + )))) + } + s if s.starts_with("LargeList(") => { + let inner_type = string_to_data_type(&s[10..s.len() - 1])?; + Ok(DataType::LargeList(Arc::new(Field::new( + "item", inner_type, true, + )))) + } + s if s.starts_with("FixedSizeList(") => { + let parts: Vec<&str> = s[14..s.len() - 1].split(',').collect(); + if parts.len() != 2 { + return Err("Invalid FixedSizeList format".into()); + } + let inner_type = string_to_data_type(parts[0].trim())?; + let size: i32 = parts[1].trim().parse()?; + Ok(DataType::FixedSizeList( + Arc::new(Field::new("item", inner_type, true)), + size, + )) + } + s if s.starts_with("Decimal128(") => { + let parts: Vec<&str> = s[10..s.len() - 1].split(',').collect(); + if parts.len() != 2 { + return Err("Invalid Decimal128 format".into()); + } + let precision: u8 = parts[0].trim().parse()?; + let scale: i8 = parts[1].trim().parse()?; + Ok(DataType::Decimal128(precision, scale)) + } + s if s.starts_with("Decimal256(") => { + let parts: Vec<&str> = s[10..s.len() - 1].split(',').collect(); + if parts.len() != 2 { + return Err("Invalid Decimal256 format".into()); + } + let precision: u8 = parts[0].trim().parse()?; + let scale: i8 = parts[1].trim().parse()?; + Ok(DataType::Decimal256(precision, scale)) + } + _ => Err(format!("Unsupported DataType string: {}", s).into()), + } +} -pub fn serialize_scalar_value(value: &ScalarValue, serializer: S) -> Result -where - S: Serializer, -{ - let json_value = match value { +pub fn scalar_to_json(value: &ScalarValue) -> serde_json::Value { + match value { ScalarValue::Null => json!({"type": "Null"}), ScalarValue::Boolean(v) => json!({"type": "Boolean", "value": v}), ScalarValue::Float16(v) => json!({"type": "Float16", "value": v.map(|f| f.to_f32())}), @@ -56,10 +177,15 @@ where "size": size, "value": v.as_ref().map(|b| base64::encode(b)) }), - ScalarValue::List(v) => json!({ - "type": "List", - "value": serialize_array(v)? 
- }), + ScalarValue::List(v) => { + let sv = ScalarValue::try_from_array(&v.value(0), 0).unwrap(); + let dt = sv.data_type().to_string(); + json!({ + "type": "List", + "field_type": dt, + "value": scalar_to_json(&sv) + }) + } ScalarValue::Date32(v) => json!({"type": "Date32", "value": v}), ScalarValue::Date64(v) => json!({"type": "Date64", "value": v}), ScalarValue::Time32Second(v) => json!({"type": "Time32Second", "value": v}), @@ -94,558 +220,354 @@ where ScalarValue::DurationMillisecond(v) => json!({"type": "DurationMillisecond", "value": v}), ScalarValue::DurationMicrosecond(v) => json!({"type": "DurationMicrosecond", "value": v}), ScalarValue::DurationNanosecond(v) => json!({"type": "DurationNanosecond", "value": v}), - ScalarValue::Struct(v) => json!({ - "type": "Struct", - "fields": v.as_ref().fields().iter().map(|f| (f.name().to_string(), f.data_type().to_string())).collect::>(), - "value": v.as_ref().fields().iter().enumerate().map(|(i, field)| ScalarValue::try_from_array(field, 0).ok()).collect::>(), - - }), - // Add other variants as needed... - }; - - json_value.serialize(serializer) -} - -fn serialize_array(arr: &Arc) -> Result>, S::Error> { - (0..arr.len()) - .map(|i| { - ScalarValue::try_from_array(arr.as_ref(), i) - .map(Some) - .map_err(serde::ser::Error::custom) - }) - .collect() + ScalarValue::Struct(v) => { + let fields = v + .as_ref() + .fields() + .iter() + .map(|f| (f.name().to_string(), f.data_type().to_string())) + .collect::>(); + + let values = v + .columns() + .as_ref() + .iter() + .map(|c| { + let sv = ScalarValue::try_from_array(c, 0).unwrap(); + scalar_to_json(&sv) + }) + .collect::>(); + json!({"type" : "Struct", "fields": fields, "values": values}) + } + ScalarValue::Utf8View(_) => todo!(), + ScalarValue::BinaryView(_) => todo!(), + ScalarValue::FixedSizeList(_) => todo!(), + ScalarValue::LargeList(_) => todo!(), + ScalarValue::Map(_) => todo!(), + ScalarValue::Union(_, _, _) => todo!(), + ScalarValue::Dictionary(_, _) => todo!(), + } } -use std::str::FromStr; -impl<'de> Deserialize<'de> for ScalarValue { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - #[derive(Deserialize)] - #[serde(tag = "type")] - enum ScalarValueHelper { - Null, - Boolean { - value: Option, - }, - Float16 { - value: Option, - }, - Float32 { - value: Option, - }, - Float64 { - value: Option, - }, - Decimal128 { - value: Option, - precision: u8, - scale: i8, - }, - Decimal256 { - value: Option, - precision: u8, - scale: i8, - }, - Int8 { - value: Option, - }, - Int16 { - value: Option, - }, - Int32 { - value: Option, - }, - Int64 { - value: Option, - }, - UInt8 { - value: Option, - }, - UInt16 { - value: Option, - }, - UInt32 { - value: Option, - }, - UInt64 { - value: Option, - }, - Utf8 { - value: Option, - }, - LargeUtf8 { - value: Option, - }, - Binary { - value: Option, - }, - LargeBinary { - value: Option, - }, - FixedSizeBinary { - size: i32, - value: Option, - }, - List { - child_type: String, - value: Option>>, - }, - LargeList { - child_type: String, - value: Option>>, - }, - FixedSizeList { - child_type: String, - size: usize, - value: Option>>, - }, - Date32 { - value: Option, - }, - Date64 { - value: Option, - }, - Time32Second { - value: Option, - }, - Time32Millisecond { - value: Option, - }, - Time64Microsecond { - value: Option, - }, - Time64Nanosecond { - value: Option, - }, - Timestamp { - value: Option, - timezone: Option, - unit: String, - }, - IntervalYearMonth { - value: Option, - }, - IntervalDayTime { - value: Option<(i32, i32)>, 
- }, - IntervalMonthDayNano { - value: Option<(i32, i32, i64)>, - }, - DurationSecond { - value: Option, - }, - DurationMillisecond { - value: Option, - }, - DurationMicrosecond { - value: Option, - }, - DurationNanosecond { - value: Option, - }, - Union { - value: Option<(i8, Box)>, - fields: Vec<(i8, String, String)>, - mode: String, - }, - Dictionary { - key_type: String, - value: Box, - }, - Utf8View { - value: Option, - }, - BinaryView { - value: Option, - }, - Struct { - fields: Vec<(String, String)>, - value: Option>>, - }, +pub fn json_to_scalar(json: &Value) -> Result> { + let obj = json.as_object().ok_or("Expected JSON object")?; + let typ = obj + .get("type") + .and_then(Value::as_str) + .ok_or("Missing or invalid 'type'")?; + + match typ { + "Null" => Ok(ScalarValue::Null), + "Boolean" => Ok(ScalarValue::Boolean( + obj.get("value").and_then(Value::as_bool), + )), + "Float16" => Ok(ScalarValue::Float16( + obj.get("value") + .and_then(Value::as_f64) + .map(|f| f16::from_f32(f as f32)), + )), + "Float32" => Ok(ScalarValue::Float32( + obj.get("value").and_then(Value::as_f64).map(|f| f as f32), + )), + "Float64" => Ok(ScalarValue::Float64( + obj.get("value").and_then(Value::as_f64), + )), + "Decimal128" => { + let value = obj + .get("value") + .and_then(Value::as_str) + .map(|s| s.parse::().unwrap()); + let precision = obj.get("precision").and_then(Value::as_u64).unwrap() as u8; + let scale = obj.get("scale").and_then(Value::as_i64).unwrap() as i8; + Ok(ScalarValue::Decimal128(value, precision, scale)) } - - let helper = ScalarValueHelper::deserialize(deserializer)?; - - Ok(match helper { - ScalarValueHelper::Null => ScalarValue::Null, - ScalarValueHelper::Boolean { value } => ScalarValue::Boolean(value), - ScalarValueHelper::Float16 { value } => { - ScalarValue::Float16(value.map(half::f16::from_f32)) - } - ScalarValueHelper::Float32 { value } => ScalarValue::Float32(value), - ScalarValueHelper::Float64 { value } => ScalarValue::Float64(value), - ScalarValueHelper::Decimal128 { - value, - precision, - scale, - } => ScalarValue::Decimal128( - value.map(|s| s.parse().unwrap()), //TODO: fix me - precision, - scale, - ), - ScalarValueHelper::Decimal256 { - value, - precision, - scale, - } => ScalarValue::Decimal256(value.map(|s| s.parse().unwrap()), precision, scale), - ScalarValueHelper::Int8 { value } => ScalarValue::Int8(value), - ScalarValueHelper::Int16 { value } => ScalarValue::Int16(value), - ScalarValueHelper::Int32 { value } => ScalarValue::Int32(value), - ScalarValueHelper::Int64 { value } => ScalarValue::Int64(value), - ScalarValueHelper::UInt8 { value } => ScalarValue::UInt8(value), - ScalarValueHelper::UInt16 { value } => ScalarValue::UInt16(value), - ScalarValueHelper::UInt32 { value } => ScalarValue::UInt32(value), - ScalarValueHelper::UInt64 { value } => ScalarValue::UInt64(value), - ScalarValueHelper::Utf8 { value } => ScalarValue::Utf8(value), - ScalarValueHelper::LargeUtf8 { value } => ScalarValue::LargeUtf8(value), - ScalarValueHelper::Binary { value } => { - ScalarValue::Binary(value.map(|s| base64::decode(s).unwrap())) - } - ScalarValueHelper::LargeBinary { value } => { - ScalarValue::LargeBinary(value.map(|s| base64::decode(s).unwrap())) - } - ScalarValueHelper::FixedSizeBinary { size, value } => { - ScalarValue::FixedSizeBinary(size, value.map(|s| base64::decode(s).unwrap())) - } - ScalarValueHelper::List { child_type, value } => { - let field = Arc::new(Field::new( - "item", - DataType::from_str(&child_type).map_err(serde::de::Error::custom)?, - true, - )); - 
let values: Vec> = value.unwrap_or_default(); - let scalar_values: Vec = values - .into_iter() - .map(|v| v.unwrap_or(ScalarValue::Null)) - .collect(); - let data = ScalarValue::new_list(&scalar_values, &field.data_type()); - ScalarValue::List(data) - } - ScalarValueHelper::LargeList { child_type, value } => { - let field = Arc::new(Field::new( - "item", - DataType::from_str(&child_type).map_err(serde::de::Error::custom)?, - true, - )); - let values: Vec> = value.unwrap_or_default(); - let scalar_values: Vec = values - .into_iter() - .map(|v| v.unwrap_or(ScalarValue::Null)) - .collect(); - ScalarValue::LargeList(ScalarValue::new_large_list( - &scalar_values, - &field.data_type(), - )) - } - ScalarValueHelper::FixedSizeList { - child_type, - size, - value, - } => { - let field = Arc::new(Field::new( - "item", - DataType::from_str(&child_type).map_err(serde::de::Error::custom)?, - true, - )); - let values: Vec> = value.unwrap_or_default(); - let scalar_values: Vec = values - .into_iter() - .map(|v| v.unwrap_or(ScalarValue::Null)) - .collect(); - let data_type = DataType::FixedSizeList(field, scalar_values.len() as i32); - let value_data = ScalarValue::new_list(&scalar_values, &data_type).to_data(); - let list_array = FixedSizeListArray::from(value_data); - ScalarValue::FixedSizeList(Arc::new(list_array)) - } - ScalarValueHelper::Date32 { value } => ScalarValue::Date32(value), - ScalarValueHelper::Date64 { value } => ScalarValue::Date64(value), - ScalarValueHelper::Time32Second { value } => ScalarValue::Time32Second(value), - ScalarValueHelper::Time32Millisecond { value } => ScalarValue::Time32Millisecond(value), - ScalarValueHelper::Time64Microsecond { value } => ScalarValue::Time64Microsecond(value), - ScalarValueHelper::Time64Nanosecond { value } => ScalarValue::Time64Nanosecond(value), - ScalarValueHelper::Timestamp { - value, - timezone, - unit, - } => match unit.as_str() { - "Second" => ScalarValue::TimestampSecond(value, timezone.map(Arc::from)), - "Millisecond" => ScalarValue::TimestampMillisecond(value, timezone.map(Arc::from)), - "Microsecond" => ScalarValue::TimestampMicrosecond(value, timezone.map(Arc::from)), - "Nanosecond" => ScalarValue::TimestampNanosecond(value, timezone.map(Arc::from)), - _ => return Err(serde::de::Error::custom("Invalid timestamp unit")), - }, - ScalarValueHelper::IntervalYearMonth { value } => ScalarValue::IntervalYearMonth(value), - ScalarValueHelper::IntervalDayTime { value } => ScalarValue::IntervalDayTime( - value.map(|(days, millis)| IntervalDayTime::new(days, millis)), - ), - ScalarValueHelper::IntervalMonthDayNano { value } => ScalarValue::IntervalMonthDayNano( - value.map(|(months, days, nanos)| IntervalMonthDayNano::new(months, days, nanos)), - ), - ScalarValueHelper::DurationSecond { value } => ScalarValue::DurationSecond(value), - ScalarValueHelper::DurationMillisecond { value } => { - ScalarValue::DurationMillisecond(value) - } - ScalarValueHelper::DurationMicrosecond { value } => { - ScalarValue::DurationMicrosecond(value) - } - ScalarValueHelper::DurationNanosecond { value } => { - ScalarValue::DurationNanosecond(value) - } - ScalarValueHelper::Union { - value, - fields, - mode, - } => { - let union_fields = fields - .into_iter() - .map(|(i, name, type_str)| { - ( - i, - Arc::new(Field::new( - name, - DataType::from_str(&type_str).unwrap(), - true, - )), - ) - }) - .collect(); - let union_mode = match mode.as_str() { - "Sparse" => UnionMode::Sparse, - "Dense" => UnionMode::Dense, - _ => return Err(serde::de::Error::custom("Invalid union 
mode")), - }; - ScalarValue::Union(value, union_fields, union_mode) - } - ScalarValueHelper::Dictionary { key_type, value } => { - ScalarValue::Dictionary(Box::new(DataType::from_str(&key_type).unwrap()), value) - } - ScalarValueHelper::Utf8View { value } => ScalarValue::Utf8View(value), - ScalarValueHelper::BinaryView { value } => { - ScalarValue::BinaryView(value.map(|s| base64::decode(s).unwrap())) - } - ScalarValueHelper::Struct { fields, value } => { - let struct_fields: Vec = fields - .into_iter() - .map(|(name, type_str)| { - Field::new(name, DataType::from_str(&type_str).unwrap(), true) - }) - .collect(); - let values: Vec> = value.unwrap_or_default(); - let scalar_values: Vec = values - .into_iter() - .map(|v| v.unwrap_or(ScalarValue::Null)) - .collect(); - let mut builder = ScalarStructBuilder::new(); - for (field, value) in struct_fields.clone().into_iter().zip(scalar_values) { - builder = builder.with_scalar(field, value); - } - let struct_array = builder.build().unwrap(); - - ScalarValue::Struct(Arc::new(StructArray::new( - Fields::from(struct_fields), - vec![struct_array.to_array().unwrap()], - None, - ))) + "Decimal256" => { + let value = obj + .get("value") + .and_then(Value::as_str) + .map(|s| s.parse::().unwrap()); + let precision = obj.get("precision").and_then(Value::as_u64).unwrap() as u8; + let scale = obj.get("scale").and_then(Value::as_i64).unwrap() as i8; + Ok(ScalarValue::Decimal256(value, precision, scale)) + } + "Int8" => Ok(ScalarValue::Int8( + obj.get("value").and_then(Value::as_i64).map(|i| i as i8), + )), + "Int16" => Ok(ScalarValue::Int16( + obj.get("value").and_then(Value::as_i64).map(|i| i as i16), + )), + "Int32" => Ok(ScalarValue::Int32( + obj.get("value").and_then(Value::as_i64).map(|i| i as i32), + )), + "Int64" => Ok(ScalarValue::Int64(obj.get("value").and_then(Value::as_i64))), + "UInt8" => Ok(ScalarValue::UInt8( + obj.get("value").and_then(Value::as_u64).map(|i| i as u8), + )), + "UInt16" => Ok(ScalarValue::UInt16( + obj.get("value").and_then(Value::as_u64).map(|i| i as u16), + )), + "UInt32" => Ok(ScalarValue::UInt32( + obj.get("value").and_then(Value::as_u64).map(|i| i as u32), + )), + "UInt64" => Ok(ScalarValue::UInt64( + obj.get("value").and_then(Value::as_u64), + )), + "Utf8" => Ok(ScalarValue::Utf8( + obj.get("value").and_then(Value::as_str).map(String::from), + )), + "LargeUtf8" => Ok(ScalarValue::LargeUtf8( + obj.get("value").and_then(Value::as_str).map(String::from), + )), + "Binary" => Ok(ScalarValue::Binary( + obj.get("value") + .and_then(Value::as_str) + .map(|s| base64::decode(s).unwrap()), + )), + "LargeBinary" => Ok(ScalarValue::LargeBinary( + obj.get("value") + .and_then(Value::as_str) + .map(|s| base64::decode(s).unwrap()), + )), + "FixedSizeBinary" => { + let size = obj.get("size").and_then(Value::as_u64).unwrap() as i32; + let value = obj + .get("value") + .and_then(Value::as_str) + .map(|s| base64::decode(s).unwrap()); + Ok(ScalarValue::FixedSizeBinary(size, value)) + } + // "List" => { + // let value = obj.get("value").ok_or("Missing 'value' for List")?; + // let field_type = obj + // .get("field_type") + // .map(|ft| ft.as_str()) + // .ok_or("Missing 'field_type' for List")?; + // let data_type = string_to_data_type(field_type.unwrap())?; + // let element: ScalarValue = json_to_scalar(value)?; + // let array = element.to_array_of_size(1).unwrap(); + // ListArray::from_iter_primitive::(array); + // Ok(ScalarValue::List(Arc::new())) + // } + "Date32" => Ok(ScalarValue::Date32( + obj.get("value").and_then(Value::as_i64).map(|i| i 
as i32), + )), + "Date64" => Ok(ScalarValue::Date64( + obj.get("value").and_then(Value::as_i64), + )), + "Time32Second" => Ok(ScalarValue::Time32Second( + obj.get("value").and_then(Value::as_i64).map(|i| i as i32), + )), + "Time32Millisecond" => Ok(ScalarValue::Time32Millisecond( + obj.get("value").and_then(Value::as_i64).map(|i| i as i32), + )), + "Time64Microsecond" => Ok(ScalarValue::Time64Microsecond( + obj.get("value").and_then(Value::as_i64), + )), + "Time64Nanosecond" => Ok(ScalarValue::Time64Nanosecond( + obj.get("value").and_then(Value::as_i64), + )), + "Timestamp" => { + let value = obj.get("value").and_then(Value::as_i64); + let timezone = obj + .get("timezone") + .and_then(Value::as_str) + .map(|s| s.to_string().into()); + let unit = obj + .get("unit") + .and_then(Value::as_str) + .ok_or("Missing or invalid 'unit'")?; + match unit { + "Second" => Ok(ScalarValue::TimestampSecond(value, timezone)), + "Millisecond" => Ok(ScalarValue::TimestampMillisecond(value, timezone)), + "Microsecond" => Ok(ScalarValue::TimestampMicrosecond(value, timezone)), + "Nanosecond" => Ok(ScalarValue::TimestampNanosecond(value, timezone)), + _ => Err("Invalid timestamp unit".into()), } - }) + } + "IntervalYearMonth" => Ok(ScalarValue::IntervalYearMonth( + obj.get("value").and_then(Value::as_i64).map(|i| i as i32), + )), + // "IntervalDayTime" => { + // let value = obj + // .get("value") + // .and_then(Value::as_array) + // .map(|arr| { + // if arr.len() == 2 { + // Some(arrow_buffer::IntervalDayTime::make_value( + // arr[0].as_i64().unwrap() as i32, + // arr[1].as_i64().unwrap() as i32, + // )) + // } else { + // None + // } + // }) + // .flatten(); + // Ok(ScalarValue::IntervalDayTime(value)) + // } + // "IntervalMonthDayNano" => { + // let value = obj + // .get("value") + // .and_then(Value::as_array) + // .map(|arr| { + // if arr.len() == 3 { + // Some(arrow_buffer::IntervalMonthDayNano::make_value( + // arr[0].as_i64().unwrap() as i32, + // arr[1].as_i64().unwrap() as i32, + // arr[2].as_i64().unwrap(), + // )) + // } else { + // None + // } + // }) + // .flatten(); + // Ok(ScalarValue::IntervalMonthDayNano(value)) + // } + "DurationSecond" => Ok(ScalarValue::DurationSecond( + obj.get("value").and_then(Value::as_i64), + )), + "DurationMillisecond" => Ok(ScalarValue::DurationMillisecond( + obj.get("value").and_then(Value::as_i64), + )), + "DurationMicrosecond" => Ok(ScalarValue::DurationMicrosecond( + obj.get("value").and_then(Value::as_i64), + )), + "DurationNanosecond" => Ok(ScalarValue::DurationNanosecond( + obj.get("value").and_then(Value::as_i64), + )), + // "Struct" => { + // let fields = obj + // .get("fields") + // .and_then(Value::as_array) + // .ok_or("Missing or invalid 'fields'")?; + // let values = obj + // .get("values") + // .and_then(Value::as_array) + // .ok_or("Missing or invalid 'values'")?; + + // let field_vec: Vec = fields + // .iter() + // .map(|f| { + // let name = f[0].as_str().unwrap().to_string(); + // let data_type = f[1].as_str().unwrap().parse().unwrap(); + // Field::new(name, data_type, true) + // }) + // .collect(); + + // let value_vec: Vec = values + // .iter() + // .map(|v| { + // let scalar = json_to_scalar(v).unwrap(); + // scalar.to_array().unwrap() + // }) + // .collect(); + + // let struct_array = StructArray::from((field_vec, value_vec)); + // Ok(ScalarValue::Struct(Arc::new(struct_array))) + // } + _ => Err(format!("Unsupported type: {}", typ).into()), } } #[cfg(test)] mod tests { use super::*; - use serde_json; - use std::sync::Arc; + use 
datafusion_common::ScalarValue; + use serde_json::json; - fn test_serde_roundtrip(scalar: ScalarValue) { - let serialized = serde_json::to_string(&scalar).unwrap(); - let deserialized: ScalarValue = serde_json::from_str(&serialized).unwrap(); - assert_eq!(scalar, deserialized); + fn test_roundtrip(scalar: ScalarValue) { + let json = scalar_to_json(&scalar); + let roundtrip = json_to_scalar(&json).unwrap(); + assert_eq!(scalar, roundtrip, "Failed roundtrip for {:?}", scalar); } #[test] - fn test_large_utf8() { - test_serde_roundtrip(ScalarValue::LargeUtf8(Some("hello".to_string()))); - test_serde_roundtrip(ScalarValue::LargeUtf8(None)); + fn test_null() { + test_roundtrip(ScalarValue::Null); } #[test] - fn test_binary() { - test_serde_roundtrip(ScalarValue::Binary(Some(vec![1, 2, 3]))); - test_serde_roundtrip(ScalarValue::Binary(None)); + fn test_boolean() { + test_roundtrip(ScalarValue::Boolean(Some(true))); + test_roundtrip(ScalarValue::Boolean(Some(false))); + test_roundtrip(ScalarValue::Boolean(None)); } #[test] - fn test_large_binary() { - test_serde_roundtrip(ScalarValue::LargeBinary(Some(vec![1, 2, 3]))); - test_serde_roundtrip(ScalarValue::LargeBinary(None)); + fn test_float() { + test_roundtrip(ScalarValue::Float32(Some(3.14))); + test_roundtrip(ScalarValue::Float32(None)); + test_roundtrip(ScalarValue::Float64(Some(3.14159265359))); + test_roundtrip(ScalarValue::Float64(None)); } #[test] - fn test_fixed_size_binary() { - test_serde_roundtrip(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))); - test_serde_roundtrip(ScalarValue::FixedSizeBinary(3, None)); + fn test_int() { + test_roundtrip(ScalarValue::Int8(Some(42))); + test_roundtrip(ScalarValue::Int8(None)); + test_roundtrip(ScalarValue::Int16(Some(-1000))); + test_roundtrip(ScalarValue::Int16(None)); + test_roundtrip(ScalarValue::Int32(Some(1_000_000))); + test_roundtrip(ScalarValue::Int32(None)); + test_roundtrip(ScalarValue::Int64(Some(-1_000_000_000))); + test_roundtrip(ScalarValue::Int64(None)); } #[test] - fn test_list() { - let list = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - Some(3), - ])]); - test_serde_roundtrip(ScalarValue::List(Arc::new(list))); + fn test_uint() { + test_roundtrip(ScalarValue::UInt8(Some(255))); + test_roundtrip(ScalarValue::UInt8(None)); + test_roundtrip(ScalarValue::UInt16(Some(65535))); + test_roundtrip(ScalarValue::UInt16(None)); + test_roundtrip(ScalarValue::UInt32(Some(4_294_967_295))); + test_roundtrip(ScalarValue::UInt32(None)); + test_roundtrip(ScalarValue::UInt64(Some(18_446_744_073_709_551_615))); + test_roundtrip(ScalarValue::UInt64(None)); } #[test] - fn test_large_list() { - let list = LargeListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - Some(3), - ])]); - test_serde_roundtrip(ScalarValue::LargeList(Arc::new(list))); - } - - // #[test] - // fn test_fixed_size_list() { - // let list = FixedSizeListArray::from_iter_primitive::( - // vec![Some(vec![Some(1), Some(2), Some(3)])], - // 3, - // ); - // test_serde_roundtrip(ScalarValue::FixedSizeList(Arc::new(list))); - // } - - #[test] - fn test_date32() { - test_serde_roundtrip(ScalarValue::Date32(Some(1000))); - test_serde_roundtrip(ScalarValue::Date32(None)); - } - - #[test] - fn test_date64() { - test_serde_roundtrip(ScalarValue::Date64(Some(86400000))); - test_serde_roundtrip(ScalarValue::Date64(None)); - } - - #[test] - fn test_time32_second() { - test_serde_roundtrip(ScalarValue::Time32Second(Some(3600))); - test_serde_roundtrip(ScalarValue::Time32Second(None)); + fn test_utf8() 
{ + test_roundtrip(ScalarValue::Utf8(Some("Hello, World!".to_string()))); + test_roundtrip(ScalarValue::Utf8(None)); + test_roundtrip(ScalarValue::LargeUtf8(Some("大きな文字列".to_string()))); + test_roundtrip(ScalarValue::LargeUtf8(None)); } #[test] - fn test_time32_millisecond() { - test_serde_roundtrip(ScalarValue::Time32Millisecond(Some(3600000))); - test_serde_roundtrip(ScalarValue::Time32Millisecond(None)); - } - - #[test] - fn test_time64_microsecond() { - test_serde_roundtrip(ScalarValue::Time64Microsecond(Some(3600000000))); - test_serde_roundtrip(ScalarValue::Time64Microsecond(None)); + fn test_binary() { + test_roundtrip(ScalarValue::Binary(Some(vec![0, 1, 2, 3, 4]))); + test_roundtrip(ScalarValue::Binary(None)); + test_roundtrip(ScalarValue::LargeBinary(Some(vec![ + 255, 254, 253, 252, 251, + ]))); + test_roundtrip(ScalarValue::LargeBinary(None)); + test_roundtrip(ScalarValue::FixedSizeBinary( + 5, + Some(vec![10, 20, 30, 40, 50]), + )); + test_roundtrip(ScalarValue::FixedSizeBinary(5, None)); } - #[test] - fn test_time64_nanosecond() { - test_serde_roundtrip(ScalarValue::Time64Nanosecond(Some(3600000000000))); - test_serde_roundtrip(ScalarValue::Time64Nanosecond(None)); - } + // #[test] + // fn test_list() { + // let inner = ScalarValue::Int32(Some(42)); + // let list = ScalarValue::List(Arc::new(inner.to_array().unwrap())); + // test_roundtrip(list); + // } #[test] fn test_timestamp() { - test_serde_roundtrip(ScalarValue::TimestampSecond( + test_roundtrip(ScalarValue::TimestampSecond( Some(1625097600), - Some(Arc::from("UTC")), + Some("UTC".into()), )); - test_serde_roundtrip(ScalarValue::TimestampMillisecond(Some(1625097600000), None)); - test_serde_roundtrip(ScalarValue::TimestampMicrosecond( + test_roundtrip(ScalarValue::TimestampMillisecond(Some(1625097600000), None)); + test_roundtrip(ScalarValue::TimestampMicrosecond( Some(1625097600000000), - Some(Arc::from("UTC")), + Some("America/New_York".into()), )); - test_serde_roundtrip(ScalarValue::TimestampNanosecond( + test_roundtrip(ScalarValue::TimestampNanosecond( Some(1625097600000000000), None, )); } - - #[test] - fn test_interval_year_month() { - test_serde_roundtrip(ScalarValue::IntervalYearMonth(Some(14))); - test_serde_roundtrip(ScalarValue::IntervalYearMonth(None)); - } - - #[test] - fn test_interval_day_time() { - test_serde_roundtrip(ScalarValue::IntervalDayTime(Some(IntervalDayTime::new( - 5, 43200000, - )))); - test_serde_roundtrip(ScalarValue::IntervalDayTime(None)); - } - - #[test] - fn test_interval_month_day_nano() { - test_serde_roundtrip(ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNano::new(1, 15, 1000000000), - ))); - test_serde_roundtrip(ScalarValue::IntervalMonthDayNano(None)); - } - - #[test] - fn test_duration() { - test_serde_roundtrip(ScalarValue::DurationSecond(Some(3600))); - test_serde_roundtrip(ScalarValue::DurationMillisecond(Some(3600000))); - test_serde_roundtrip(ScalarValue::DurationMicrosecond(Some(3600000000))); - test_serde_roundtrip(ScalarValue::DurationNanosecond(Some(3600000000000))); - } - - // #[test] - // fn test_union() { - // let fields = vec![ - // (0, Arc::new(Field::new("f1", DataType::Int32, true))), - // (1, Arc::new(Field::new("f2", DataType::Utf8, true))), - // ]; - // test_serde_roundtrip(ScalarValue::Union( - // Some((0, Box::new(ScalarValue::Int32(Some(42))))), - // fields.clone(), - // UnionMode::Sparse, - // )); - // test_serde_roundtrip(ScalarValue::Union(None, fields, UnionMode::Dense)); - // } - - #[test] - fn test_dictionary() { - 
test_serde_roundtrip(ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Utf8(Some("hello".to_string()))), - )); - } - - #[test] - fn test_utf8_view() { - test_serde_roundtrip(ScalarValue::Utf8View(Some("hello".to_string()))); - test_serde_roundtrip(ScalarValue::Utf8View(None)); - } - - #[test] - fn test_binary_view() { - test_serde_roundtrip(ScalarValue::BinaryView(Some(vec![1, 2, 3]))); - test_serde_roundtrip(ScalarValue::BinaryView(None)); - } - - /* #[test] - fn test_struct() { - let fields = vec![ - Field::new("f1", DataType::Int32, true), - Field::new("f2", DataType::Utf8, true), - ]; - let values = vec![ - Some(ScalarValue::Int32(Some(42))), - Some(ScalarValue::Utf8(Some("hello".to_string()))), - ]; - let struct_array = StructArray::from(values); - test_serde_roundtrip(ScalarValue::Struct(Arc::new(struct_array))); - } */ } From 6e8797f960634a1051f098497ac430386084f7f8 Mon Sep 17 00:00:00 2001 From: Amey Chaugule Date: Tue, 6 Aug 2024 14:31:36 -0700 Subject: [PATCH 3/8] checkpointing checkpoint code --- crates/core/src/accumulators/mod.rs | 2 + .../accumulators/serializable_accumulator.rs | 148 ++++++++++++++++++ .../src/{utils => accumulators}/serialize.rs | 123 +++++++++++++-- crates/core/src/lib.rs | 1 + crates/core/src/utils/mod.rs | 1 - 5 files changed, 260 insertions(+), 15 deletions(-) create mode 100644 crates/core/src/accumulators/mod.rs create mode 100644 crates/core/src/accumulators/serializable_accumulator.rs rename crates/core/src/{utils => accumulators}/serialize.rs (84%) diff --git a/crates/core/src/accumulators/mod.rs b/crates/core/src/accumulators/mod.rs new file mode 100644 index 0000000..be1c218 --- /dev/null +++ b/crates/core/src/accumulators/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod serializable_accumulator; +mod serialize; diff --git a/crates/core/src/accumulators/serializable_accumulator.rs b/crates/core/src/accumulators/serializable_accumulator.rs new file mode 100644 index 0000000..17b9029 --- /dev/null +++ b/crates/core/src/accumulators/serializable_accumulator.rs @@ -0,0 +1,148 @@ +use arrow::array::{Array, ArrayRef}; +use datafusion::functions_aggregate::array_agg::ArrayAggAccumulator; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::Accumulator; +use serde::{Deserialize, Serialize}; + +use super::serialize::SerializableScalarValue; + +pub trait SerializableAccumulator { + fn serialize(&mut self) -> Result; + fn deserialize(bytes: String) -> Result>; +} + +#[derive(Debug, Serialize, Deserialize)] +struct SerializableArrayAggState { + state: Vec, +} + +impl SerializableAccumulator for ArrayAggAccumulator { + fn serialize(&mut self) -> Result { + let state = self.state()?; + let serializable_state = SerializableArrayAggState { + state: state + .into_iter() + .map(SerializableScalarValue::from) + .collect(), + }; + Ok(serde_json::to_string(&serializable_state).unwrap()) + } + + fn deserialize(bytes: String) -> Result> { + let serializable_state: SerializableArrayAggState = + serde_json::from_str(bytes.as_str()).unwrap(); + let state: Vec = serializable_state + .state + .into_iter() + .map(ScalarValue::from) + .collect(); + + // Infer the datatype from the first element of the state + let datatype = if let Some(ScalarValue::List(list)) = state.first() { + list.data_type().clone() + } else { + return Err(datafusion_common::DataFusionError::Internal( + "Invalid state for ArrayAggAccumulator".to_string(), + )); + }; + + let mut acc = ArrayAggAccumulator::try_new(&datatype)?; + + // Convert ScalarValue to ArrayRef for 
merge_batch + let arrays: Vec = state + .into_iter() + .filter_map(|s| { + if let ScalarValue::List(list) = s { + Some(list.values().clone()) + } else { + None + } + }) + .collect(); + + acc.merge_batch(&arrays)?; + + Ok(Box::new(acc)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int32Array, StringArray}; + use arrow::datatypes::DataType; + use std::sync::Arc; + + fn create_int32_array(values: Vec>) -> ArrayRef { + Arc::new(Int32Array::from(values)) as ArrayRef + } + + fn create_string_array(values: Vec>) -> ArrayRef { + Arc::new(StringArray::from(values)) as ArrayRef + } + + #[test] + fn test_serialize_deserialize_int32() -> Result<()> { + let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?; + acc.update_batch(&[create_int32_array(vec![Some(1), Some(2), Some(3)])])?; + + let serialized = SerializableAccumulator::serialize(&mut acc)?; + let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?; + + assert_eq!(acc.evaluate()?, deserialized.evaluate()?); + Ok(()) + } + + #[test] + fn test_serialize_deserialize_string() -> Result<()> { + let mut acc = ArrayAggAccumulator::try_new(&DataType::Utf8)?; + acc.update_batch(&[create_string_array(vec![ + Some("hello"), + Some("world"), + None, + ])])?; + + let serialized = SerializableAccumulator::serialize(&mut acc)?; + let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?; + + assert_eq!(acc.evaluate()?, deserialized.evaluate()?); + Ok(()) + } + + #[test] + fn test_serialize_deserialize_empty() -> Result<()> { + let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?; + + let serialized = SerializableAccumulator::serialize(&mut acc)?; + let result = ArrayAggAccumulator::deserialize(serialized); + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Empty state")); + Ok(()) + } + + #[test] + fn test_serialize_deserialize_multiple_updates() -> Result<()> { + let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?; + acc.update_batch(&[create_int32_array(vec![Some(1), Some(2)])])?; + acc.update_batch(&[create_int32_array(vec![Some(3), Some(4)])])?; + + let serialized = SerializableAccumulator::serialize(&mut acc)?; + let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?; + + assert_eq!(acc.evaluate()?, deserialized.evaluate()?); + Ok(()) + } + + #[test] + fn test_serialize_deserialize_with_nulls() -> Result<()> { + let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?; + acc.update_batch(&[create_int32_array(vec![Some(1), None, Some(3)])])?; + + let serialized = SerializableAccumulator::serialize(&mut acc)?; + let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?; + + assert_eq!(acc.evaluate()?, deserialized.evaluate()?); + Ok(()) + } +} diff --git a/crates/core/src/utils/serialize.rs b/crates/core/src/accumulators/serialize.rs similarity index 84% rename from crates/core/src/utils/serialize.rs rename to crates/core/src/accumulators/serialize.rs index cba469e..47c58a1 100644 --- a/crates/core/src/utils/serialize.rs +++ b/crates/core/src/accumulators/serialize.rs @@ -1,14 +1,52 @@ use std::sync::Arc; +use arrow_array::ListArray; use datafusion_common::ScalarValue; -use arrow::array::*; use arrow::datatypes::*; use half::f16; use serde_json::{json, Value}; use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct SerializableScalarValue(#[serde(with = "scalar_value_serde")] ScalarValue); + +impl 
From for SerializableScalarValue { + fn from(value: ScalarValue) -> Self { + SerializableScalarValue(value) + } +} + +impl From for ScalarValue { + fn from(value: SerializableScalarValue) -> Self { + value.0 + } +} + +mod scalar_value_serde { + use super::*; + use serde::{de::Error, Deserializer, Serializer}; + + pub fn serialize(value: &ScalarValue, serializer: S) -> Result + where + S: Serializer, + { + let json = scalar_to_json(value); + json.serialize(serializer) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let json = serde_json::Value::deserialize(deserializer)?; + json_to_scalar(&json).map_err(D::Error::custom) + } +} + pub fn string_to_data_type(s: &str) -> Result> { match s { "Null" => Ok(DataType::Null), @@ -336,18 +374,59 @@ pub fn json_to_scalar(json: &Value) -> Result { - // let value = obj.get("value").ok_or("Missing 'value' for List")?; - // let field_type = obj - // .get("field_type") - // .map(|ft| ft.as_str()) - // .ok_or("Missing 'field_type' for List")?; - // let data_type = string_to_data_type(field_type.unwrap())?; - // let element: ScalarValue = json_to_scalar(value)?; - // let array = element.to_array_of_size(1).unwrap(); - // ListArray::from_iter_primitive::(array); - // Ok(ScalarValue::List(Arc::new())) - // } + "List" => { + let value = obj.get("value").ok_or("Missing 'value' for List")?; + let field_type = obj + .get("field_type") + .map(|ft| ft.as_str()) + .ok_or("Missing 'field_type' for List")?; + let data_type: DataType = string_to_data_type(field_type.unwrap())?; + let element: ScalarValue = json_to_scalar(value)?; + let array = element.to_array_of_size(1).unwrap(); + let list_array = match data_type { + DataType::Boolean => ListArray::from_iter_primitive::(array), + DataType::Int8 => todo!(), + DataType::Int16 => todo!(), + DataType::Int32 => todo!(), + DataType::Int64 => todo!(), + DataType::UInt8 => todo!(), + DataType::UInt16 => todo!(), + DataType::UInt32 => todo!(), + DataType::UInt64 => todo!(), + DataType::Float16 => todo!(), + DataType::Float32 => todo!(), + DataType::Float64 => todo!(), + DataType::Timestamp(_, _) => todo!(), + DataType::Date32 => todo!(), + DataType::Date64 => todo!(), + DataType::Time32(_) => todo!(), + DataType::Time64(_) => todo!(), + DataType::Duration(_) => todo!(), + DataType::Interval(_) => todo!(), + DataType::Binary => todo!(), + DataType::FixedSizeBinary(_) => todo!(), + DataType::LargeBinary => todo!(), + DataType::BinaryView => todo!(), + DataType::Utf8 => todo!(), + DataType::LargeUtf8 => todo!(), + DataType::Utf8View => todo!(), + DataType::List(_) => todo!(), + DataType::ListView(_) => todo!(), + DataType::FixedSizeList(_, _) => todo!(), + DataType::LargeList(_) => todo!(), + DataType::LargeListView(_) => todo!(), + DataType::Struct(_) => todo!(), + DataType::Union(_, _) => todo!(), + DataType::Dictionary(_, _) => todo!(), + DataType::Decimal128(_, _) => todo!(), + DataType::Decimal256(_, _) => todo!(), + DataType::Map(_, _) => todo!(), + DataType::RunEndEncoded(_, _) => todo!(), + _ => Err("DataType {} not supported.", data_type), + }; + let list_array = ListArray::from_iter_primitive::(array); + Ok(ScalarValue::List(Arc::new())) + } "Date32" => Ok(ScalarValue::Date32( obj.get("value").and_then(Value::as_i64).map(|i| i as i32), )), @@ -472,7 +551,6 @@ pub fn json_to_scalar(json: &Value) -> Result Date: Wed, 7 Aug 2024 12:34:22 -0700 Subject: [PATCH 4/8] Adding serde for Accumulators --- .../accumulators/serializable_accumulator.rs | 47 +-------------- 
crates/core/src/accumulators/serialize.rs | 58 ++++--------------- 2 files changed, 15 insertions(+), 90 deletions(-) diff --git a/crates/core/src/accumulators/serializable_accumulator.rs b/crates/core/src/accumulators/serializable_accumulator.rs index 17b9029..4488fed 100644 --- a/crates/core/src/accumulators/serializable_accumulator.rs +++ b/crates/core/src/accumulators/serializable_accumulator.rs @@ -60,7 +60,7 @@ impl SerializableAccumulator for ArrayAggAccumulator { }) .collect(); - acc.merge_batch(&arrays)?; + acc.update_batch(&arrays)?; Ok(Box::new(acc)) } @@ -84,7 +84,7 @@ mod tests { #[test] fn test_serialize_deserialize_int32() -> Result<()> { let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?; - acc.update_batch(&[create_int32_array(vec![Some(1), Some(2), Some(3)])])?; + acc.update_batch(&[create_int32_array(vec![Some(1)])])?; let serialized = SerializableAccumulator::serialize(&mut acc)?; let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?; @@ -96,48 +96,7 @@ mod tests { #[test] fn test_serialize_deserialize_string() -> Result<()> { let mut acc = ArrayAggAccumulator::try_new(&DataType::Utf8)?; - acc.update_batch(&[create_string_array(vec![ - Some("hello"), - Some("world"), - None, - ])])?; - - let serialized = SerializableAccumulator::serialize(&mut acc)?; - let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?; - - assert_eq!(acc.evaluate()?, deserialized.evaluate()?); - Ok(()) - } - - #[test] - fn test_serialize_deserialize_empty() -> Result<()> { - let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?; - - let serialized = SerializableAccumulator::serialize(&mut acc)?; - let result = ArrayAggAccumulator::deserialize(serialized); - - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("Empty state")); - Ok(()) - } - - #[test] - fn test_serialize_deserialize_multiple_updates() -> Result<()> { - let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?; - acc.update_batch(&[create_int32_array(vec![Some(1), Some(2)])])?; - acc.update_batch(&[create_int32_array(vec![Some(3), Some(4)])])?; - - let serialized = SerializableAccumulator::serialize(&mut acc)?; - let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?; - - assert_eq!(acc.evaluate()?, deserialized.evaluate()?); - Ok(()) - } - - #[test] - fn test_serialize_deserialize_with_nulls() -> Result<()> { - let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?; - acc.update_batch(&[create_int32_array(vec![Some(1), None, Some(3)])])?; + acc.update_batch(&[create_string_array(vec![Some("hello")])])?; let serialized = SerializableAccumulator::serialize(&mut acc)?; let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?; diff --git a/crates/core/src/accumulators/serialize.rs b/crates/core/src/accumulators/serialize.rs index 47c58a1..ea4f6e2 100644 --- a/crates/core/src/accumulators/serialize.rs +++ b/crates/core/src/accumulators/serialize.rs @@ -1,9 +1,14 @@ use std::sync::Arc; -use arrow_array::ListArray; +use arrow_array::{BooleanArray, GenericListArray, ListArray}; use datafusion_common::ScalarValue; -use arrow::datatypes::*; +use arrow::{ + array::{BooleanBuilder, GenericListBuilder, ListBuilder}, + buffer::{OffsetBuffer, ScalarBuffer}, + datatypes::*, +}; +use arrow_array::types::Int32Type; use half::f16; use serde_json::{json, Value}; @@ -380,52 +385,13 @@ pub fn json_to_scalar(json: &Value) -> Result ListArray::from_iter_primitive::(array), - DataType::Int8 => todo!(), - DataType::Int16 => todo!(), - 
DataType::Int32 => todo!(),
-                DataType::Int64 => todo!(),
-                DataType::UInt8 => todo!(),
-                DataType::UInt16 => todo!(),
-                DataType::UInt32 => todo!(),
-                DataType::UInt64 => todo!(),
-                DataType::Float16 => todo!(),
-                DataType::Float32 => todo!(),
-                DataType::Float64 => todo!(),
-                DataType::Timestamp(_, _) => todo!(),
-                DataType::Date32 => todo!(),
-                DataType::Date64 => todo!(),
-                DataType::Time32(_) => todo!(),
-                DataType::Time64(_) => todo!(),
-                DataType::Duration(_) => todo!(),
-                DataType::Interval(_) => todo!(),
-                DataType::Binary => todo!(),
-                DataType::FixedSizeBinary(_) => todo!(),
-                DataType::LargeBinary => todo!(),
-                DataType::BinaryView => todo!(),
-                DataType::Utf8 => todo!(),
-                DataType::LargeUtf8 => todo!(),
-                DataType::Utf8View => todo!(),
-                DataType::List(_) => todo!(),
-                DataType::ListView(_) => todo!(),
-                DataType::FixedSizeList(_, _) => todo!(),
-                DataType::LargeList(_) => todo!(),
-                DataType::LargeListView(_) => todo!(),
-                DataType::Struct(_) => todo!(),
-                DataType::Union(_, _) => todo!(),
-                DataType::Dictionary(_, _) => todo!(),
-                DataType::Decimal128(_, _) => todo!(),
-                DataType::Decimal256(_, _) => todo!(),
-                DataType::Map(_, _) => todo!(),
-                DataType::RunEndEncoded(_, _) => todo!(),
-                _ => Err("DataType {} not supported.", data_type),
-            };
-            let list_array = ListArray::from_iter_primitive::(array);
-            Ok(ScalarValue::List(Arc::new()))
+            let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0]));
+            let field = Field::new("item", dt, true);
+            let list = GenericListArray::try_new(Arc::new(field), offsets, array, None)?;
+            Ok(ScalarValue::List(Arc::new(list)))
         }
         "Date32" => Ok(ScalarValue::Date32(
             obj.get("value").and_then(Value::as_i64).map(|i| i as i32),
         )),

From 166e592febd228987ddf3370f7718d174eddac95 Mon Sep 17 00:00:00 2001
From: Amey Chaugule
Date: Wed, 7 Aug 2024 13:34:35 -0700
Subject: [PATCH 5/8] push rm serialize

---
 crates/core/src/utils/mod.rs | 1 -
 1 file changed, 1 deletion(-)

diff --git a/crates/core/src/utils/mod.rs b/crates/core/src/utils/mod.rs
index 73cef59..3710df1 100644
--- a/crates/core/src/utils/mod.rs
+++ b/crates/core/src/utils/mod.rs
@@ -2,6 +2,5 @@ pub mod arrow_helpers;
 mod default_optimizer_rules;
 pub mod row_encoder;
-pub mod serialize;
 
 pub use default_optimizer_rules::get_default_optimizer_rules;

From fce351cff31af1f5b1ecea8f8912e28a34e4560f Mon Sep 17 00:00:00 2001
From: Amey Chaugule
Date: Wed, 7 Aug 2024 13:56:09 -0700
Subject: [PATCH 6/8] Rm cargo warnings

---
 .../accumulators/serializable_accumulator.rs | 13 +++++++++----
 crates/core/src/accumulators/serialize.rs    | 17 ++++++++---------
 2 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/crates/core/src/accumulators/serializable_accumulator.rs b/crates/core/src/accumulators/serializable_accumulator.rs
index 4488fed..6b43415 100644
--- a/crates/core/src/accumulators/serializable_accumulator.rs
+++ b/crates/core/src/accumulators/serializable_accumulator.rs
@@ -6,9 +6,10 @@ use serde::{Deserialize, Serialize};
 
 use super::serialize::SerializableScalarValue;
 
+#[allow(dead_code)]
 pub trait SerializableAccumulator {
     fn serialize(&mut self) -> Result<String>;
-    fn deserialize(bytes: String) -> Result<Box<dyn Accumulator>>;
+    fn deserialize(self, bytes: String) -> Result<Box<dyn Accumulator>>;
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -28,7 +29,7 @@ impl SerializableAccumulator for ArrayAggAccumulator {
         Ok(serde_json::to_string(&serializable_state).unwrap())
     }
 
-    fn deserialize(bytes: String) -> Result<Box<dyn Accumulator>> {
+    fn deserialize(self, bytes: String) -> Result<Box<dyn Accumulator>> {
         let serializable_state: SerializableArrayAggState =
            serde_json::from_str(bytes.as_str()).unwrap();
         let state: Vec<ScalarValue> = serializable_state
@@ -87,7 +88,9 @@ mod tests {
         acc.update_batch(&[create_int32_array(vec![Some(1)])])?;
 
         let serialized = SerializableAccumulator::serialize(&mut acc)?;
-        let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?;
+        let acc2 = ArrayAggAccumulator::try_new(&DataType::Int32)?;
+
+        let mut deserialized = ArrayAggAccumulator::deserialize(acc2, serialized)?;
 
         assert_eq!(acc.evaluate()?, deserialized.evaluate()?);
         Ok(())
@@ -99,7 +102,9 @@ mod tests {
         acc.update_batch(&[create_string_array(vec![Some("hello")])])?;
 
         let serialized = SerializableAccumulator::serialize(&mut acc)?;
-        let mut deserialized = ArrayAggAccumulator::deserialize(serialized)?;
+        let acc2 = ArrayAggAccumulator::try_new(&DataType::Utf8)?;
+
+        let mut deserialized = ArrayAggAccumulator::deserialize(acc2, serialized)?;
 
         assert_eq!(acc.evaluate()?, deserialized.evaluate()?);
         Ok(())
diff --git a/crates/core/src/accumulators/serialize.rs b/crates/core/src/accumulators/serialize.rs
index ea4f6e2..9f7f754 100644
--- a/crates/core/src/accumulators/serialize.rs
+++ b/crates/core/src/accumulators/serialize.rs
@@ -1,14 +1,13 @@
 use std::sync::Arc;
 
-use arrow_array::{BooleanArray, GenericListArray, ListArray};
+use arrow_array::GenericListArray;
+use base64::{engine::general_purpose::STANDARD, Engine as _};
 use datafusion_common::ScalarValue;
 
 use arrow::{
-    array::{BooleanBuilder, GenericListBuilder, ListBuilder},
     buffer::{OffsetBuffer, ScalarBuffer},
     datatypes::*,
 };
-use arrow_array::types::Int32Type;
 use half::f16;
 use serde_json::{json, Value};
 
@@ -209,16 +208,16 @@ pub fn scalar_to_json(value: &ScalarValue) -> serde_json::Value {
         ScalarValue::LargeUtf8(v) => json!({"type": "LargeUtf8", "value": v}),
         ScalarValue::Binary(v) => json!({
             "type": "Binary",
-            "value": v.as_ref().map(|b| base64::encode(b))
+            "value": v.as_ref().map(|b| STANDARD.encode(b))
         }),
         ScalarValue::LargeBinary(v) => json!({
             "type": "LargeBinary",
-            "value": v.as_ref().map(|b| base64::encode(b))
+            "value": v.as_ref().map(|b| STANDARD.encode(b))
         }),
         ScalarValue::FixedSizeBinary(size, v) => json!({
             "type": "FixedSizeBinary",
             "size": size,
-            "value": v.as_ref().map(|b| base64::encode(b))
+            "value": v.as_ref().map(|b| STANDARD.encode(b))
         }),
         ScalarValue::List(v) => {
             let sv = ScalarValue::try_from_array(&v.value(0), 0).unwrap();
@@ -364,19 +363,19 @@ pub fn json_to_scalar(json: &Value) -> Result<ScalarValue, Box<dyn std::error::Error>> {
         "Binary" => Ok(ScalarValue::Binary(
             obj.get("value")
                 .and_then(Value::as_str)
-                .map(|s| base64::decode(s).unwrap()),
+                .map(|s| STANDARD.decode(s).unwrap()),
         )),
         "LargeBinary" => Ok(ScalarValue::LargeBinary(
             obj.get("value")
                 .and_then(Value::as_str)
-                .map(|s| base64::decode(s).unwrap()),
+                .map(|s| STANDARD.decode(s).unwrap()),
         )),
         "FixedSizeBinary" => {
             let size = obj.get("size").and_then(Value::as_u64).unwrap() as i32;
             let value = obj
                 .get("value")
                 .and_then(Value::as_str)
-                .map(|s| base64::decode(s).unwrap());
+                .map(|s| STANDARD.decode(s).unwrap());
             Ok(ScalarValue::FixedSizeBinary(size, value))
         }
         "List" => {

From cbf6814e157077d686acecf99c43cb93781633b7 Mon Sep 17 00:00:00 2001
From: Amey Chaugule
Date: Wed, 7 Aug 2024 14:02:52 -0700
Subject: [PATCH 7/8] fix clippy

---
 crates/core/src/accumulators/serialize.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/crates/core/src/accumulators/serialize.rs b/crates/core/src/accumulators/serialize.rs
index 9f7f754..253d008 100644
--- a/crates/core/src/accumulators/serialize.rs
+++ b/crates/core/src/accumulators/serialize.rs
@@ -537,9 +537,9 @@ mod tests {
 
     #[test]
     fn test_float() {
-        test_roundtrip(ScalarValue::Float32(Some(3.14)));
+        test_roundtrip(ScalarValue::Float32(Some(3.24)));
         test_roundtrip(ScalarValue::Float32(None));
-        test_roundtrip(ScalarValue::Float64(Some(3.14159265359)));
+        test_roundtrip(ScalarValue::Float64(Some(3.24159265359)));
         test_roundtrip(ScalarValue::Float64(None));
     }
 

From 6b4a10d1e9840ece5cec30c9a9efd34adb6ec1d5 Mon Sep 17 00:00:00 2001
From: Amey Chaugule
Date: Wed, 7 Aug 2024 18:01:45 -0700
Subject: [PATCH 8/8] Move from partial aggregates to single partitioned mode

---
 crates/core/src/planner/streaming_window.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/crates/core/src/planner/streaming_window.rs b/crates/core/src/planner/streaming_window.rs
index 9d9996a..9a5239e 100644
--- a/crates/core/src/planner/streaming_window.rs
+++ b/crates/core/src/planner/streaming_window.rs
@@ -128,7 +128,7 @@ impl ExtensionPlanner for StreamingWindowPlanner {
         };
 
         let initial_aggr = Arc::new(FranzStreamingWindowExec::try_new(
-            AggregateMode::Partial,
+            AggregateMode::Single,
            groups.clone(),
            aggregates.clone(),
            filters.clone(),
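
For context, the checkpoint/restore flow this series builds toward looks roughly like the sketch below. It is not part of the patches above: the function name checkpoint_and_restore and the local variable names are illustrative only, and the sketch assumes the SerializableAccumulator trait is in scope (the accumulators module is crate-private in this series) with the signature introduced in PATCH 6/8, where deserialize consumes a freshly constructed accumulator of the matching data type.

use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::DataType;
use datafusion::functions_aggregate::array_agg::ArrayAggAccumulator;
use datafusion_common::Result;
use datafusion_expr::Accumulator;

// Sketch only: checkpoint an in-flight ArrayAggAccumulator as a JSON string
// (via SerializableAccumulator::serialize, which routes the state through
// SerializableScalarValue / scalar_to_json) and restore it into a fresh
// accumulator of the same data type.
fn checkpoint_and_restore() -> Result<()> {
    let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?;
    let batch: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2)]));
    acc.update_batch(&[batch])?;

    // Snapshot the accumulator state; the payload is the JSON layout defined
    // by scalar_to_json in crates/core/src/accumulators/serialize.rs.
    let snapshot = SerializableAccumulator::serialize(&mut acc)?;

    // Restoring requires an accumulator of the matching data type; the saved
    // state is re-applied with update_batch (see the fix in PATCH 4/8).
    let fresh = ArrayAggAccumulator::try_new(&DataType::Int32)?;
    let mut restored = ArrayAggAccumulator::deserialize(fresh, snapshot)?;

    assert_eq!(acc.evaluate()?, restored.evaluate()?);
    Ok(())
}

The same pattern is what PATCH 8/8 relies on when it switches the streaming window to AggregateMode::Single: with a single-partition accumulator, the serialized state above is the complete per-group state that a checkpoint needs to persist.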