diff --git a/Cargo.lock b/Cargo.lock index a70b961..7a1ce9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1080,6 +1080,7 @@ dependencies = [ "chrono", "datafusion", "futures", + "half", "itertools 0.13.0", "log", "rdkafka", diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 817ee82..f031124 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -26,3 +26,4 @@ itertools = { workspace = true } serde.workspace = true rocksdb = "0.22.0" bincode = "1.3.3" +half = "2.4.1" diff --git a/crates/core/src/accumulators/mod.rs b/crates/core/src/accumulators/mod.rs new file mode 100644 index 0000000..be1c218 --- /dev/null +++ b/crates/core/src/accumulators/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod serializable_accumulator; +mod serialize; diff --git a/crates/core/src/accumulators/serializable_accumulator.rs b/crates/core/src/accumulators/serializable_accumulator.rs new file mode 100644 index 0000000..0c34626 --- /dev/null +++ b/crates/core/src/accumulators/serializable_accumulator.rs @@ -0,0 +1,112 @@ +use arrow::array::{Array, ArrayRef}; +use datafusion::common::{Result, ScalarValue}; +use datafusion::functions_aggregate::array_agg::ArrayAggAccumulator; +use datafusion::logical_expr::Accumulator; +use serde::{Deserialize, Serialize}; + +use super::serialize::SerializableScalarValue; + +#[allow(dead_code)] +pub trait SerializableAccumulator { + fn serialize(&mut self) -> Result; + fn deserialize(self, bytes: String) -> Result>; +} + +#[derive(Debug, Serialize, Deserialize)] +struct SerializableArrayAggState { + state: Vec, +} + +impl SerializableAccumulator for ArrayAggAccumulator { + fn serialize(&mut self) -> Result { + let state = self.state()?; + let serializable_state = SerializableArrayAggState { + state: state + .into_iter() + .map(SerializableScalarValue::from) + .collect(), + }; + Ok(serde_json::to_string(&serializable_state).unwrap()) + } + + fn deserialize(self, bytes: String) -> Result> { + let serializable_state: SerializableArrayAggState = + serde_json::from_str(bytes.as_str()).unwrap(); + let state: Vec = serializable_state + .state + .into_iter() + .map(ScalarValue::from) + .collect(); + + // Infer the datatype from the first element of the state + let datatype = if let Some(ScalarValue::List(list)) = state.first() { + list.data_type().clone() + } else { + return Err(datafusion::common::DataFusionError::Internal( + "Invalid state for ArrayAggAccumulator".to_string(), + )); + }; + + let mut acc = ArrayAggAccumulator::try_new(&datatype)?; + + // Convert ScalarValue to ArrayRef for merge_batch + let arrays: Vec = state + .into_iter() + .filter_map(|s| { + if let ScalarValue::List(list) = s { + Some(list.values().clone()) + } else { + None + } + }) + .collect(); + + acc.update_batch(&arrays)?; + + Ok(Box::new(acc)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int32Array, StringArray}; + use arrow::datatypes::DataType; + use std::sync::Arc; + + fn create_int32_array(values: Vec>) -> ArrayRef { + Arc::new(Int32Array::from(values)) as ArrayRef + } + + fn create_string_array(values: Vec>) -> ArrayRef { + Arc::new(StringArray::from(values)) as ArrayRef + } + + #[test] + fn test_serialize_deserialize_int32() -> Result<()> { + let mut acc = ArrayAggAccumulator::try_new(&DataType::Int32)?; + acc.update_batch(&[create_int32_array(vec![Some(1)])])?; + + let serialized = SerializableAccumulator::serialize(&mut acc)?; + let acc2 = ArrayAggAccumulator::try_new(&DataType::Int32)?; + + let mut deserialized = ArrayAggAccumulator::deserialize(acc2, serialized)?; + + assert_eq!(acc.evaluate()?, deserialized.evaluate()?); + Ok(()) + } + + #[test] + fn test_serialize_deserialize_string() -> Result<()> { + let mut acc = ArrayAggAccumulator::try_new(&DataType::Utf8)?; + acc.update_batch(&[create_string_array(vec![Some("hello")])])?; + + let serialized = SerializableAccumulator::serialize(&mut acc)?; + let acc2 = ArrayAggAccumulator::try_new(&DataType::Utf8)?; + + let mut deserialized = ArrayAggAccumulator::deserialize(acc2, serialized)?; + + assert_eq!(acc.evaluate()?, deserialized.evaluate()?); + Ok(()) + } +} diff --git a/crates/core/src/accumulators/serialize.rs b/crates/core/src/accumulators/serialize.rs new file mode 100644 index 0000000..92332d2 --- /dev/null +++ b/crates/core/src/accumulators/serialize.rs @@ -0,0 +1,633 @@ +use std::sync::Arc; + +use arrow_array::GenericListArray; +use base64::{engine::general_purpose::STANDARD, Engine as _}; +use datafusion::common::ScalarValue; + +use arrow::{ + buffer::{OffsetBuffer, ScalarBuffer}, + datatypes::*, +}; +use half::f16; +use serde_json::{json, Value}; + +use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct SerializableScalarValue(#[serde(with = "scalar_value_serde")] ScalarValue); + +impl From for SerializableScalarValue { + fn from(value: ScalarValue) -> Self { + SerializableScalarValue(value) + } +} + +impl From for ScalarValue { + fn from(value: SerializableScalarValue) -> Self { + value.0 + } +} + +mod scalar_value_serde { + use super::*; + use serde::{de::Error, Deserializer, Serializer}; + + pub fn serialize(value: &ScalarValue, serializer: S) -> Result + where + S: Serializer, + { + let json = scalar_to_json(value); + json.serialize(serializer) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let json = serde_json::Value::deserialize(deserializer)?; + json_to_scalar(&json).map_err(D::Error::custom) + } +} + +pub fn string_to_data_type(s: &str) -> Result> { + match s { + "Null" => Ok(DataType::Null), + "Boolean" => Ok(DataType::Boolean), + "Int8" => Ok(DataType::Int8), + "Int16" => Ok(DataType::Int16), + "Int32" => Ok(DataType::Int32), + "Int64" => Ok(DataType::Int64), + "UInt8" => Ok(DataType::UInt8), + "UInt16" => Ok(DataType::UInt16), + "UInt32" => Ok(DataType::UInt32), + "UInt64" => Ok(DataType::UInt64), + "Float16" => Ok(DataType::Float16), + "Float32" => Ok(DataType::Float32), + "Float64" => Ok(DataType::Float64), + "Binary" => Ok(DataType::Binary), + "LargeBinary" => Ok(DataType::LargeBinary), + "Utf8" => Ok(DataType::Utf8), + "LargeUtf8" => Ok(DataType::LargeUtf8), + "Date32" => Ok(DataType::Date32), + "Date64" => Ok(DataType::Date64), + s if s.starts_with("Timestamp(") => { + let parts: Vec<&str> = s[10..s.len() - 1].split(',').collect(); + if parts.len() != 2 { + return Err("Invalid Timestamp format".into()); + } + let time_unit = match parts[0].trim() { + "Second" => TimeUnit::Second, + "Millisecond" => TimeUnit::Millisecond, + "Microsecond" => TimeUnit::Microsecond, + "Nanosecond" => TimeUnit::Nanosecond, + _ => return Err("Invalid TimeUnit".into()), + }; + let timezone = parts[1].trim().trim_matches('"'); + let timezone = if timezone == "None" { + None + } else { + Some(timezone.into()) + }; + Ok(DataType::Timestamp(time_unit, timezone)) + } + s if s.starts_with("Time32(") => { + let time_unit = match &s[7..s.len() - 1] { + "Second" => TimeUnit::Second, + "Millisecond" => TimeUnit::Millisecond, + _ => return Err("Invalid TimeUnit for Time32".into()), + }; + Ok(DataType::Time32(time_unit)) + } + s if s.starts_with("Time64(") => { + let time_unit = match &s[7..s.len() - 1] { + "Microsecond" => TimeUnit::Microsecond, + "Nanosecond" => TimeUnit::Nanosecond, + _ => return Err("Invalid TimeUnit for Time64".into()), + }; + Ok(DataType::Time64(time_unit)) + } + s if s.starts_with("Duration(") => { + let time_unit = match &s[9..s.len() - 1] { + "Second" => TimeUnit::Second, + "Millisecond" => TimeUnit::Millisecond, + "Microsecond" => TimeUnit::Microsecond, + "Nanosecond" => TimeUnit::Nanosecond, + _ => return Err("Invalid TimeUnit for Duration".into()), + }; + Ok(DataType::Duration(time_unit)) + } + s if s.starts_with("Interval(") => { + let interval_unit = match &s[9..s.len() - 1] { + "YearMonth" => IntervalUnit::YearMonth, + "DayTime" => IntervalUnit::DayTime, + "MonthDayNano" => IntervalUnit::MonthDayNano, + _ => return Err("Invalid IntervalUnit".into()), + }; + Ok(DataType::Interval(interval_unit)) + } + s if s.starts_with("FixedSizeBinary(") => { + let size: i32 = s[16..s.len() - 1].parse()?; + Ok(DataType::FixedSizeBinary(size)) + } + s if s.starts_with("List(") => { + let inner_type = string_to_data_type(&s[5..s.len() - 1])?; + Ok(DataType::List(Arc::new(Field::new( + "item", inner_type, true, + )))) + } + s if s.starts_with("LargeList(") => { + let inner_type = string_to_data_type(&s[10..s.len() - 1])?; + Ok(DataType::LargeList(Arc::new(Field::new( + "item", inner_type, true, + )))) + } + s if s.starts_with("FixedSizeList(") => { + let parts: Vec<&str> = s[14..s.len() - 1].split(',').collect(); + if parts.len() != 2 { + return Err("Invalid FixedSizeList format".into()); + } + let inner_type = string_to_data_type(parts[0].trim())?; + let size: i32 = parts[1].trim().parse()?; + Ok(DataType::FixedSizeList( + Arc::new(Field::new("item", inner_type, true)), + size, + )) + } + s if s.starts_with("Decimal128(") => { + let parts: Vec<&str> = s[10..s.len() - 1].split(',').collect(); + if parts.len() != 2 { + return Err("Invalid Decimal128 format".into()); + } + let precision: u8 = parts[0].trim().parse()?; + let scale: i8 = parts[1].trim().parse()?; + Ok(DataType::Decimal128(precision, scale)) + } + s if s.starts_with("Decimal256(") => { + let parts: Vec<&str> = s[10..s.len() - 1].split(',').collect(); + if parts.len() != 2 { + return Err("Invalid Decimal256 format".into()); + } + let precision: u8 = parts[0].trim().parse()?; + let scale: i8 = parts[1].trim().parse()?; + Ok(DataType::Decimal256(precision, scale)) + } + _ => Err(format!("Unsupported DataType string: {}", s).into()), + } +} + +pub fn scalar_to_json(value: &ScalarValue) -> serde_json::Value { + match value { + ScalarValue::Null => json!({"type": "Null"}), + ScalarValue::Boolean(v) => json!({"type": "Boolean", "value": v}), + ScalarValue::Float16(v) => json!({"type": "Float16", "value": v.map(|f| f.to_f32())}), + ScalarValue::Float32(v) => json!({"type": "Float32", "value": v}), + ScalarValue::Float64(v) => json!({"type": "Float64", "value": v}), + ScalarValue::Decimal128(v, p, s) => json!({ + "type": "Decimal128", + "value": v.map(|x| x.to_string()), + "precision": p, + "scale": s + }), + ScalarValue::Decimal256(v, p, s) => json!({ + "type": "Decimal256", + "value": v.map(|x| x.to_string()), + "precision": p, + "scale": s + }), + ScalarValue::Int8(v) => json!({"type": "Int8", "value": v}), + ScalarValue::Int16(v) => json!({"type": "Int16", "value": v}), + ScalarValue::Int32(v) => json!({"type": "Int32", "value": v}), + ScalarValue::Int64(v) => json!({"type": "Int64", "value": v}), + ScalarValue::UInt8(v) => json!({"type": "UInt8", "value": v}), + ScalarValue::UInt16(v) => json!({"type": "UInt16", "value": v}), + ScalarValue::UInt32(v) => json!({"type": "UInt32", "value": v}), + ScalarValue::UInt64(v) => json!({"type": "UInt64", "value": v}), + ScalarValue::Utf8(v) => json!({"type": "Utf8", "value": v}), + ScalarValue::LargeUtf8(v) => json!({"type": "LargeUtf8", "value": v}), + ScalarValue::Binary(v) => json!({ + "type": "Binary", + "value": v.as_ref().map(|b| STANDARD.encode(b)) + }), + ScalarValue::LargeBinary(v) => json!({ + "type": "LargeBinary", + "value": v.as_ref().map(|b| STANDARD.encode(b)) + }), + ScalarValue::FixedSizeBinary(size, v) => json!({ + "type": "FixedSizeBinary", + "size": size, + "value": v.as_ref().map(|b| STANDARD.encode(b)) + }), + ScalarValue::List(v) => { + let sv = ScalarValue::try_from_array(&v.value(0), 0).unwrap(); + let dt = sv.data_type().to_string(); + json!({ + "type": "List", + "field_type": dt, + "value": scalar_to_json(&sv) + }) + } + ScalarValue::Date32(v) => json!({"type": "Date32", "value": v}), + ScalarValue::Date64(v) => json!({"type": "Date64", "value": v}), + ScalarValue::Time32Second(v) => json!({"type": "Time32Second", "value": v}), + ScalarValue::Time32Millisecond(v) => json!({"type": "Time32Millisecond", "value": v}), + ScalarValue::Time64Microsecond(v) => json!({"type": "Time64Microsecond", "value": v}), + ScalarValue::Time64Nanosecond(v) => json!({"type": "Time64Nanosecond", "value": v}), + ScalarValue::TimestampSecond(v, tz) + | ScalarValue::TimestampMillisecond(v, tz) + | ScalarValue::TimestampMicrosecond(v, tz) + | ScalarValue::TimestampNanosecond(v, tz) => json!({ + "type": "Timestamp", + "value": v, + "timezone": tz.as_ref().map(|s| s.to_string()), + "unit": match value { + ScalarValue::TimestampSecond(_, _) => "Second", + ScalarValue::TimestampMillisecond(_, _) => "Millisecond", + ScalarValue::TimestampMicrosecond(_, _) => "Microsecond", + ScalarValue::TimestampNanosecond(_, _) => "Nanosecond", + _ => unreachable!(), + } + }), + ScalarValue::IntervalYearMonth(v) => json!({"type": "IntervalYearMonth", "value": v}), + ScalarValue::IntervalDayTime(v) => json!({ + "type": "IntervalDayTime", + "value": v.map(|x| (x.days, x.milliseconds)) + }), + ScalarValue::IntervalMonthDayNano(v) => json!({ + "type": "IntervalMonthDayNano", + "value": v.map(|x| (x.months, x.days, x.nanoseconds)) + }), + ScalarValue::DurationSecond(v) => json!({"type": "DurationSecond", "value": v}), + ScalarValue::DurationMillisecond(v) => json!({"type": "DurationMillisecond", "value": v}), + ScalarValue::DurationMicrosecond(v) => json!({"type": "DurationMicrosecond", "value": v}), + ScalarValue::DurationNanosecond(v) => json!({"type": "DurationNanosecond", "value": v}), + ScalarValue::Struct(v) => { + let fields = v + .as_ref() + .fields() + .iter() + .map(|f| (f.name().to_string(), f.data_type().to_string())) + .collect::>(); + + let values = v + .columns() + .as_ref() + .iter() + .map(|c| { + let sv = ScalarValue::try_from_array(c, 0).unwrap(); + scalar_to_json(&sv) + }) + .collect::>(); + json!({"type" : "Struct", "fields": fields, "values": values}) + } + ScalarValue::Utf8View(_) => todo!(), + ScalarValue::BinaryView(_) => todo!(), + ScalarValue::FixedSizeList(_) => todo!(), + ScalarValue::LargeList(_) => todo!(), + ScalarValue::Map(_) => todo!(), + ScalarValue::Union(_, _, _) => todo!(), + ScalarValue::Dictionary(_, _) => todo!(), + } +} + +pub fn json_to_scalar(json: &Value) -> Result> { + let obj = json.as_object().ok_or("Expected JSON object")?; + let typ = obj + .get("type") + .and_then(Value::as_str) + .ok_or("Missing or invalid 'type'")?; + + match typ { + "Null" => Ok(ScalarValue::Null), + "Boolean" => Ok(ScalarValue::Boolean( + obj.get("value").and_then(Value::as_bool), + )), + "Float16" => Ok(ScalarValue::Float16( + obj.get("value") + .and_then(Value::as_f64) + .map(|f| f16::from_f32(f as f32)), + )), + "Float32" => Ok(ScalarValue::Float32( + obj.get("value").and_then(Value::as_f64).map(|f| f as f32), + )), + "Float64" => Ok(ScalarValue::Float64( + obj.get("value").and_then(Value::as_f64), + )), + "Decimal128" => { + let value = obj + .get("value") + .and_then(Value::as_str) + .map(|s| s.parse::().unwrap()); + let precision = obj.get("precision").and_then(Value::as_u64).unwrap() as u8; + let scale = obj.get("scale").and_then(Value::as_i64).unwrap() as i8; + Ok(ScalarValue::Decimal128(value, precision, scale)) + } + "Decimal256" => { + let value = obj + .get("value") + .and_then(Value::as_str) + .map(|s| s.parse::().unwrap()); + let precision = obj.get("precision").and_then(Value::as_u64).unwrap() as u8; + let scale = obj.get("scale").and_then(Value::as_i64).unwrap() as i8; + Ok(ScalarValue::Decimal256(value, precision, scale)) + } + "Int8" => Ok(ScalarValue::Int8( + obj.get("value").and_then(Value::as_i64).map(|i| i as i8), + )), + "Int16" => Ok(ScalarValue::Int16( + obj.get("value").and_then(Value::as_i64).map(|i| i as i16), + )), + "Int32" => Ok(ScalarValue::Int32( + obj.get("value").and_then(Value::as_i64).map(|i| i as i32), + )), + "Int64" => Ok(ScalarValue::Int64(obj.get("value").and_then(Value::as_i64))), + "UInt8" => Ok(ScalarValue::UInt8( + obj.get("value").and_then(Value::as_u64).map(|i| i as u8), + )), + "UInt16" => Ok(ScalarValue::UInt16( + obj.get("value").and_then(Value::as_u64).map(|i| i as u16), + )), + "UInt32" => Ok(ScalarValue::UInt32( + obj.get("value").and_then(Value::as_u64).map(|i| i as u32), + )), + "UInt64" => Ok(ScalarValue::UInt64( + obj.get("value").and_then(Value::as_u64), + )), + "Utf8" => Ok(ScalarValue::Utf8( + obj.get("value").and_then(Value::as_str).map(String::from), + )), + "LargeUtf8" => Ok(ScalarValue::LargeUtf8( + obj.get("value").and_then(Value::as_str).map(String::from), + )), + "Binary" => Ok(ScalarValue::Binary( + obj.get("value") + .and_then(Value::as_str) + .map(|s| STANDARD.decode(s).unwrap()), + )), + "LargeBinary" => Ok(ScalarValue::LargeBinary( + obj.get("value") + .and_then(Value::as_str) + .map(|s| STANDARD.decode(s).unwrap()), + )), + "FixedSizeBinary" => { + let size = obj.get("size").and_then(Value::as_u64).unwrap() as i32; + let value = obj + .get("value") + .and_then(Value::as_str) + .map(|s| STANDARD.decode(s).unwrap()); + Ok(ScalarValue::FixedSizeBinary(size, value)) + } + "List" => { + let value = obj.get("value").ok_or("Missing 'value' for List")?; + let field_type = obj + .get("field_type") + .map(|ft| ft.as_str()) + .ok_or("Missing 'field_type' for List")?; + let dt: DataType = string_to_data_type(field_type.unwrap())?; + let element: ScalarValue = json_to_scalar(value)?; + let array = element.to_array_of_size(1).unwrap(); + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0])); + let field = Field::new("item", dt, true); + let list = GenericListArray::try_new(Arc::new(field), offsets, array, None)?; + Ok(ScalarValue::List(Arc::new(list))) + } + "Date32" => Ok(ScalarValue::Date32( + obj.get("value").and_then(Value::as_i64).map(|i| i as i32), + )), + "Date64" => Ok(ScalarValue::Date64( + obj.get("value").and_then(Value::as_i64), + )), + "Time32Second" => Ok(ScalarValue::Time32Second( + obj.get("value").and_then(Value::as_i64).map(|i| i as i32), + )), + "Time32Millisecond" => Ok(ScalarValue::Time32Millisecond( + obj.get("value").and_then(Value::as_i64).map(|i| i as i32), + )), + "Time64Microsecond" => Ok(ScalarValue::Time64Microsecond( + obj.get("value").and_then(Value::as_i64), + )), + "Time64Nanosecond" => Ok(ScalarValue::Time64Nanosecond( + obj.get("value").and_then(Value::as_i64), + )), + "Timestamp" => { + let value = obj.get("value").and_then(Value::as_i64); + let timezone = obj + .get("timezone") + .and_then(Value::as_str) + .map(|s| s.to_string().into()); + let unit = obj + .get("unit") + .and_then(Value::as_str) + .ok_or("Missing or invalid 'unit'")?; + match unit { + "Second" => Ok(ScalarValue::TimestampSecond(value, timezone)), + "Millisecond" => Ok(ScalarValue::TimestampMillisecond(value, timezone)), + "Microsecond" => Ok(ScalarValue::TimestampMicrosecond(value, timezone)), + "Nanosecond" => Ok(ScalarValue::TimestampNanosecond(value, timezone)), + _ => Err("Invalid timestamp unit".into()), + } + } + "IntervalYearMonth" => Ok(ScalarValue::IntervalYearMonth( + obj.get("value").and_then(Value::as_i64).map(|i| i as i32), + )), + // "IntervalDayTime" => { + // let value = obj + // .get("value") + // .and_then(Value::as_array) + // .map(|arr| { + // if arr.len() == 2 { + // Some(arrow_buffer::IntervalDayTime::make_value( + // arr[0].as_i64().unwrap() as i32, + // arr[1].as_i64().unwrap() as i32, + // )) + // } else { + // None + // } + // }) + // .flatten(); + // Ok(ScalarValue::IntervalDayTime(value)) + // } + // "IntervalMonthDayNano" => { + // let value = obj + // .get("value") + // .and_then(Value::as_array) + // .map(|arr| { + // if arr.len() == 3 { + // Some(arrow_buffer::IntervalMonthDayNano::make_value( + // arr[0].as_i64().unwrap() as i32, + // arr[1].as_i64().unwrap() as i32, + // arr[2].as_i64().unwrap(), + // )) + // } else { + // None + // } + // }) + // .flatten(); + // Ok(ScalarValue::IntervalMonthDayNano(value)) + // } + "DurationSecond" => Ok(ScalarValue::DurationSecond( + obj.get("value").and_then(Value::as_i64), + )), + "DurationMillisecond" => Ok(ScalarValue::DurationMillisecond( + obj.get("value").and_then(Value::as_i64), + )), + "DurationMicrosecond" => Ok(ScalarValue::DurationMicrosecond( + obj.get("value").and_then(Value::as_i64), + )), + "DurationNanosecond" => Ok(ScalarValue::DurationNanosecond( + obj.get("value").and_then(Value::as_i64), + )), + // "Struct" => { + // let fields = obj + // .get("fields") + // .and_then(Value::as_array) + // .ok_or("Missing or invalid 'fields'")?; + // let values = obj + // .get("values") + // .and_then(Value::as_array) + // .ok_or("Missing or invalid 'values'")?; + + // let field_vec: Vec = fields + // .iter() + // .map(|f| { + // let name = f[0].as_str().unwrap().to_string(); + // let data_type = f[1].as_str().unwrap().parse().unwrap(); + // Field::new(name, data_type, true) + // }) + // .collect(); + + // let value_vec: Vec = values + // .iter() + // .map(|v| { + // let scalar = json_to_scalar(v).unwrap(); + // scalar.to_array().unwrap() + // }) + // .collect(); + + // let struct_array = StructArray::from((field_vec, value_vec)); + // Ok(ScalarValue::Struct(Arc::new(struct_array))) + // } + _ => Err(format!("Unsupported type: {}", typ).into()), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::common::ScalarValue; + + fn test_roundtrip(scalar: ScalarValue) { + let json = scalar_to_json(&scalar); + let roundtrip = json_to_scalar(&json).unwrap(); + assert_eq!(scalar, roundtrip, "Failed roundtrip for {:?}", scalar); + } + + #[test] + fn test_null() { + test_roundtrip(ScalarValue::Null); + } + + #[test] + fn test_boolean() { + test_roundtrip(ScalarValue::Boolean(Some(true))); + test_roundtrip(ScalarValue::Boolean(Some(false))); + test_roundtrip(ScalarValue::Boolean(None)); + } + + #[test] + fn test_float() { + test_roundtrip(ScalarValue::Float32(Some(3.24))); + test_roundtrip(ScalarValue::Float32(None)); + test_roundtrip(ScalarValue::Float64(Some(3.24159265359))); + test_roundtrip(ScalarValue::Float64(None)); + } + + #[test] + fn test_int() { + test_roundtrip(ScalarValue::Int8(Some(42))); + test_roundtrip(ScalarValue::Int8(None)); + test_roundtrip(ScalarValue::Int16(Some(-1000))); + test_roundtrip(ScalarValue::Int16(None)); + test_roundtrip(ScalarValue::Int32(Some(1_000_000))); + test_roundtrip(ScalarValue::Int32(None)); + test_roundtrip(ScalarValue::Int64(Some(-1_000_000_000))); + test_roundtrip(ScalarValue::Int64(None)); + } + + #[test] + fn test_uint() { + test_roundtrip(ScalarValue::UInt8(Some(255))); + test_roundtrip(ScalarValue::UInt8(None)); + test_roundtrip(ScalarValue::UInt16(Some(65535))); + test_roundtrip(ScalarValue::UInt16(None)); + test_roundtrip(ScalarValue::UInt32(Some(4_294_967_295))); + test_roundtrip(ScalarValue::UInt32(None)); + test_roundtrip(ScalarValue::UInt64(Some(18_446_744_073_709_551_615))); + test_roundtrip(ScalarValue::UInt64(None)); + } + + #[test] + fn test_utf8() { + test_roundtrip(ScalarValue::Utf8(Some("Hello, World!".to_string()))); + test_roundtrip(ScalarValue::Utf8(None)); + test_roundtrip(ScalarValue::LargeUtf8(Some("大きな文字列".to_string()))); + test_roundtrip(ScalarValue::LargeUtf8(None)); + } + + #[test] + fn test_binary() { + test_roundtrip(ScalarValue::Binary(Some(vec![0, 1, 2, 3, 4]))); + test_roundtrip(ScalarValue::Binary(None)); + test_roundtrip(ScalarValue::LargeBinary(Some(vec![ + 255, 254, 253, 252, 251, + ]))); + test_roundtrip(ScalarValue::LargeBinary(None)); + test_roundtrip(ScalarValue::FixedSizeBinary( + 5, + Some(vec![10, 20, 30, 40, 50]), + )); + test_roundtrip(ScalarValue::FixedSizeBinary(5, None)); + } + + // #[test] + // fn test_list() { + // let inner = ScalarValue::Int32(Some(42)); + // let list = ScalarValue::List(Arc::new(inner.to_array().unwrap())); + // test_roundtrip(list); + // } + + #[test] + fn test_timestamp() { + test_roundtrip(ScalarValue::TimestampSecond( + Some(1625097600), + Some("UTC".into()), + )); + test_roundtrip(ScalarValue::TimestampMillisecond(Some(1625097600000), None)); + test_roundtrip(ScalarValue::TimestampMicrosecond( + Some(1625097600000000), + Some("America/New_York".into()), + )); + test_roundtrip(ScalarValue::TimestampNanosecond( + Some(1625097600000000000), + None, + )); + } + + #[test] + fn test_serializable_scalar_value() { + let original = ScalarValue::Int32(Some(42)); + let serializable = SerializableScalarValue::from(original.clone()); + + // Serialize + let serialized = serde_json::to_string(&serializable).unwrap(); + + // Deserialize + let deserialized: SerializableScalarValue = serde_json::from_str(&serialized).unwrap(); + + // Convert back to ScalarValue + let result: ScalarValue = deserialized.into(); + + assert_eq!(original, result); + } +} diff --git a/crates/core/src/config_extensions/denormalized_config.rs b/crates/core/src/config_extensions/denormalized_config.rs index c75c3eb..640a4b8 100644 --- a/crates/core/src/config_extensions/denormalized_config.rs +++ b/crates/core/src/config_extensions/denormalized_config.rs @@ -1,5 +1,5 @@ -use datafusion::config::ConfigExtension; use datafusion::common::extensions_options; +use datafusion::config::ConfigExtension; extensions_options! { pub struct DenormalizedConfig { diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index 5c890b2..a261ea7 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -1,12 +1,12 @@ use std::sync::Arc; use tokio::sync::RwLock; +use datafusion::common::{DataFusionError, Result}; use datafusion::datasource::TableProvider; use datafusion::execution::{ config::SessionConfig, context::SessionContext, runtime_env::RuntimeEnv, session_state::SessionStateBuilder, }; -use datafusion::common::{DataFusionError, Result}; use crate::datasource::kafka::TopicReader; use crate::datastream::DataStream; diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index deaf05c..73b3cd2 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -1,3 +1,4 @@ +pub mod accumulators; pub mod config_extensions; pub mod context; pub mod datasource; diff --git a/crates/core/src/physical_plan/utils/accumulators.rs b/crates/core/src/physical_plan/utils/accumulators.rs index 554d6b2..caf938c 100644 --- a/crates/core/src/physical_plan/utils/accumulators.rs +++ b/crates/core/src/physical_plan/utils/accumulators.rs @@ -1,8 +1,8 @@ use std::sync::Arc; +use datafusion::common::Result; use datafusion::logical_expr::Accumulator; use datafusion::physical_expr::AggregateExpr; -use datafusion::common::Result; pub(crate) type AccumulatorItem = Box; diff --git a/crates/core/src/planner/streaming_window.rs b/crates/core/src/planner/streaming_window.rs index cb842e8..34ec2d2 100644 --- a/crates/core/src/planner/streaming_window.rs +++ b/crates/core/src/planner/streaming_window.rs @@ -129,7 +129,7 @@ impl ExtensionPlanner for StreamingWindowPlanner { }; let initial_aggr = Arc::new(FranzStreamingWindowExec::try_new( - AggregateMode::Partial, + AggregateMode::Single, groups.clone(), aggregates.clone(), filters.clone(), diff --git a/crates/core/src/utils/mod.rs b/crates/core/src/utils/mod.rs index 73cef59..3710df1 100644 --- a/crates/core/src/utils/mod.rs +++ b/crates/core/src/utils/mod.rs @@ -2,6 +2,5 @@ pub mod arrow_helpers; mod default_optimizer_rules; pub mod row_encoder; -pub mod serialize; pub use default_optimizer_rules::get_default_optimizer_rules; diff --git a/crates/core/src/utils/serialize.rs b/crates/core/src/utils/serialize.rs deleted file mode 100644 index 0941e27..0000000 --- a/crates/core/src/utils/serialize.rs +++ /dev/null @@ -1,933 +0,0 @@ -// use arrow::datatypes::DataType; -// use serde::ser::SerializeStruct; -// use serde::{Deserialize, Deserializer, Serialize, Serializer}; -// use std::sync::Arc; - -// use datafusion::common::scalar::ScalarStructBuilder; -// use datafusion::common::ScalarValue; - -// use arrow::array::*; -// use arrow::datatypes::*; - -// pub struct SerializableScalarValue { -// value: ScalarValue, -// } - -// impl Serialize for SerializableScalarValue { -// fn serialize(&self, serializer: S) -> Result -// where -// S: Serializer, -// { -// match self.value { -// ScalarValue::Null => { -// let mut st = serializer.serialize_struct("ScalarValue", 1)?; -// st.serialize_field("type", "Null")?; -// st.end() -// } -// ScalarValue::Boolean(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Boolean")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Float16(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Float16")?; -// st.serialize_field("value", &v.map(|f| f.to_f32()))?; -// st.end() -// } -// ScalarValue::Float32(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Float32")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Float64(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Float64")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Decimal128(v, p, s) => { -// let mut st = serializer.serialize_struct("ScalarValue", 4)?; -// st.serialize_field("type", "Decimal128")?; -// st.serialize_field("value", &v.map(|x| x.to_string()))?; -// st.serialize_field("precision", p)?; -// st.serialize_field("scale", s)?; -// st.end() -// } -// ScalarValue::Decimal256(v, p, s) => { -// let mut st = serializer.serialize_struct("ScalarValue", 4)?; -// st.serialize_field("type", "Decimal256")?; -// st.serialize_field("value", &v.map(|x| x.to_string()))?; -// st.serialize_field("precision", p)?; -// st.serialize_field("scale", s)?; -// st.end() -// } -// ScalarValue::Int8(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Int8")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Int16(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Int16")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Int32(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Int32")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Int64(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Int64")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::UInt8(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "UInt8")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::UInt16(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "UInt16")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::UInt32(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "UInt32")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::UInt64(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "UInt64")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Utf8(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Utf8")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::LargeUtf8(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "LargeUtf8")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Binary(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Binary")?; -// st.serialize_field("value", &v.as_ref().map(|b| base64::encode(b)))?; -// st.end() -// } -// ScalarValue::LargeBinary(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "LargeBinary")?; -// st.serialize_field("value", &v.as_ref().map(|b| base64::encode(b)))?; -// st.end() -// } -// ScalarValue::FixedSizeBinary(size, v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 3)?; -// st.serialize_field("type", "FixedSizeBinary")?; -// st.serialize_field("size", size)?; -// st.serialize_field("value", &v.as_ref().map(|b| base64::encode(b)))?; -// st.end() -// } -// ScalarValue::List(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 3)?; -// st.serialize_field("type", "List")?; -// if let arr = v.as_ref() { -// let list_data_type = arr.data_type(); -// println!("arr len {}", arr.len()); -// if let DataType::List(field) = list_data_type { -// st.serialize_field("child_type", &field.data_type().to_string())?; -// let values = arr.value(0); -// let nested_values = if arr.is_null(0) { -// vec![None] -// } else { -// let ret = (0..values.len()) -// .map(|i| ScalarValue::try_from_array(&values, i).map(Some).unwrap()) -// .collect::>>(); -// ret -// }; -// st.serialize_field("value", &nested_values)?; -// } else { -// return Err(serde::ser::Error::custom("Invalid List data type")); -// } -// } else { -// st.serialize_field("child_type", &DataType::Null.to_string())?; -// st.serialize_field("value", &Option::>>::None)?; -// } -// st.end() -// } -// ScalarValue::Date32(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Date32")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Date64(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Date64")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Time32Second(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Time32Second")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Time32Millisecond(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Time32Millisecond")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Time64Microsecond(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Time64Microsecond")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Time64Nanosecond(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Time64Nanosecond")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::TimestampSecond(v, tz) -// | ScalarValue::TimestampMillisecond(v, tz) -// | ScalarValue::TimestampMicrosecond(v, tz) -// | ScalarValue::TimestampNanosecond(v, tz) => { -// let mut st = serializer.serialize_struct("ScalarValue", 4)?; -// st.serialize_field("type", "Timestamp")?; -// st.serialize_field("value", v)?; -// st.serialize_field("timezone", &tz.as_ref().map(|s| s.to_string()))?; -// st.serialize_field( -// "unit", -// match self.value { -// ScalarValue::TimestampSecond(_, _) => "Second", -// ScalarValue::TimestampMillisecond(_, _) => "Millisecond", -// ScalarValue::TimestampMicrosecond(_, _) => "Microsecond", -// ScalarValue::TimestampNanosecond(_, _) => "Nanosecond", -// _ => unreachable!(), -// }, -// )?; -// st.end() -// } -// ScalarValue::IntervalYearMonth(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "IntervalYearMonth")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::IntervalDayTime(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "IntervalDayTime")?; -// st.serialize_field("value", &v.map(|x| (x.days, x.milliseconds)))?; -// st.end() -// } -// ScalarValue::IntervalMonthDayNano(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "IntervalMonthDayNano")?; -// st.serialize_field("value", &v.map(|x| (x.months, x.days, x.nanoseconds)))?; -// st.end() -// } -// ScalarValue::DurationSecond(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "DurationSecond")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::DurationMillisecond(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "DurationMillisecond")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::DurationMicrosecond(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "DurationMicrosecond")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::DurationNanosecond(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "DurationNanosecond")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::Union(v, fields, mode) => { -// let mut st = serializer.serialize_struct("ScalarValue", 4)?; -// st.serialize_field("type", "Union")?; -// st.serialize_field("value", v)?; -// st.serialize_field( -// "fields", -// &fields -// .iter() -// .map(|(i, f)| (i, f.name().to_string(), f.data_type().to_string())) -// .collect::>(), -// )?; -// st.serialize_field( -// "mode", -// match mode { -// UnionMode::Sparse => "Sparse", -// UnionMode::Dense => "Dense", -// }, -// )?; -// st.end() -// } -// ScalarValue::Dictionary(key_type, value) => { -// let mut st = serializer.serialize_struct("ScalarValue", 3)?; -// st.serialize_field("type", "Dictionary")?; -// st.serialize_field("key_type", &key_type.to_string())?; -// st.serialize_field("value", value)?; -// st.end() -// } -// ScalarValue::Utf8View(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "Utf8View")?; -// st.serialize_field("value", v)?; -// st.end() -// } -// ScalarValue::BinaryView(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 2)?; -// st.serialize_field("type", "BinaryView")?; -// st.serialize_field("value", &v.as_ref().map(|b| base64::encode(b)))?; -// st.end() -// } -// ScalarValue::LargeList(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 3)?; -// st.serialize_field("type", "LargeList")?; -// if let arr = v.as_ref() { -// let list_data_type = arr.data_type(); -// if let DataType::LargeList(field) = list_data_type { -// st.serialize_field("child_type", &field.data_type().to_string())?; -// let values = arr.value(0); -// let nested_values = if arr.is_null(0) { -// vec![None] -// } else { -// let ret = (0..values.len()) -// .map(|i| ScalarValue::try_from_array(&values, i).map(Some).unwrap()) -// .collect::>>(); -// ret -// }; -// st.serialize_field("value", &nested_values)?; -// } else { -// return Err(serde::ser::Error::custom("Invalid LargeList data type")); -// } -// } else { -// st.serialize_field("child_type", &DataType::Null.to_string())?; -// st.serialize_field("value", &Option::>>::None)?; -// } -// st.end() -// } -// ScalarValue::FixedSizeList(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 4)?; -// st.serialize_field("type", "FixedSizeList")?; -// if let arr = v.as_ref() { -// let list_data_type = arr.data_type(); -// if let DataType::FixedSizeList(field, size) = list_data_type { -// st.serialize_field("child_type", &field.data_type().to_string())?; -// st.serialize_field("size", size)?; -// let values = arr.value(0); -// let nested_values = if arr.is_null(0) { -// vec![None] -// } else { -// let ret = (0..values.len()) -// .map(|i| ScalarValue::try_from_array(&values, i).map(Some).unwrap()) -// .collect::>>(); -// ret -// }; -// st.serialize_field("value", &nested_values)?; -// } else { -// return Err(serde::ser::Error::custom("Invalid FixedSizeList data type")); -// } -// } else { -// st.serialize_field("child_type", &DataType::Null.to_string())?; -// st.serialize_field("size", &0)?; -// st.serialize_field("value", &Option::>>::None)?; -// } -// st.end() -// } -// ScalarValue::Struct(v) => { -// let mut st = serializer.serialize_struct("ScalarValue", 3)?; -// st.serialize_field("type", "Struct")?; -// if let struct_arr = v.as_ref() { -// let fields: Vec<(String, String)> = struct_arr -// .fields() -// .iter() -// .map(|f| (f.name().to_string(), f.data_type().to_string())) -// .collect(); -// st.serialize_field("fields", &fields)?; - -// let values: Vec> = struct_arr -// .columns() -// .iter() -// .enumerate() -// .map(|(i, field)| { -// if struct_arr.is_null(0) { -// Ok(None) -// } else { -// ScalarValue::try_from_array(field, 0) -// .map(Some) -// .map_err(serde::ser::Error::custom) -// } -// }) -// .collect::>()?; -// st.serialize_field("value", &values)?; -// } else { -// st.serialize_field("fields", &Vec::<(String, String)>::new())?; -// st.serialize_field("value", &Option::>>::None)?; -// } -// st.end() -// } -// } -// } -// } - -// use std::str::FromStr; - -// impl<'de> Deserialize<'de> for ScalarValue { -// fn deserialize(deserializer: D) -> Result -// where -// D: Deserializer<'de>, -// { -// #[derive(Deserialize)] -// #[serde(tag = "type")] -// enum ScalarValueHelper { -// Null, -// Boolean { -// value: Option, -// }, -// Float16 { -// value: Option, -// }, -// Float32 { -// value: Option, -// }, -// Float64 { -// value: Option, -// }, -// Decimal128 { -// value: Option, -// precision: u8, -// scale: i8, -// }, -// Decimal256 { -// value: Option, -// precision: u8, -// scale: i8, -// }, -// Int8 { -// value: Option, -// }, -// Int16 { -// value: Option, -// }, -// Int32 { -// value: Option, -// }, -// Int64 { -// value: Option, -// }, -// UInt8 { -// value: Option, -// }, -// UInt16 { -// value: Option, -// }, -// UInt32 { -// value: Option, -// }, -// UInt64 { -// value: Option, -// }, -// Utf8 { -// value: Option, -// }, -// LargeUtf8 { -// value: Option, -// }, -// Binary { -// value: Option, -// }, -// LargeBinary { -// value: Option, -// }, -// FixedSizeBinary { -// size: i32, -// value: Option, -// }, -// List { -// child_type: String, -// value: Option>>, -// }, -// LargeList { -// child_type: String, -// value: Option>>, -// }, -// FixedSizeList { -// child_type: String, -// size: usize, -// value: Option>>, -// }, -// Date32 { -// value: Option, -// }, -// Date64 { -// value: Option, -// }, -// Time32Second { -// value: Option, -// }, -// Time32Millisecond { -// value: Option, -// }, -// Time64Microsecond { -// value: Option, -// }, -// Time64Nanosecond { -// value: Option, -// }, -// Timestamp { -// value: Option, -// timezone: Option, -// unit: String, -// }, -// IntervalYearMonth { -// value: Option, -// }, -// IntervalDayTime { -// value: Option<(i32, i32)>, -// }, -// IntervalMonthDayNano { -// value: Option<(i32, i32, i64)>, -// }, -// DurationSecond { -// value: Option, -// }, -// DurationMillisecond { -// value: Option, -// }, -// DurationMicrosecond { -// value: Option, -// }, -// DurationNanosecond { -// value: Option, -// }, -// Union { -// value: Option<(i8, Box)>, -// fields: Vec<(i8, String, String)>, -// mode: String, -// }, -// Dictionary { -// key_type: String, -// value: Box, -// }, -// Utf8View { -// value: Option, -// }, -// BinaryView { -// value: Option, -// }, -// Struct { -// fields: Vec<(String, String)>, -// value: Option>>, -// }, -// } - -// let helper = ScalarValueHelper::deserialize(deserializer)?; - -// Ok(match helper { -// ScalarValueHelper::Null => ScalarValue::Null, -// ScalarValueHelper::Boolean { value } => ScalarValue::Boolean(value), -// ScalarValueHelper::Float16 { value } => { -// ScalarValue::Float16(value.map(half::f16::from_f32)) -// } -// ScalarValueHelper::Float32 { value } => ScalarValue::Float32(value), -// ScalarValueHelper::Float64 { value } => ScalarValue::Float64(value), -// ScalarValueHelper::Decimal128 { -// value, -// precision, -// scale, -// } => ScalarValue::Decimal128( -// value.map(|s| s.parse().unwrap()), //TODO: fix me -// precision, -// scale, -// ), -// ScalarValueHelper::Decimal256 { -// value, -// precision, -// scale, -// } => ScalarValue::Decimal256(value.map(|s| s.parse().unwrap()), precision, scale), -// ScalarValueHelper::Int8 { value } => ScalarValue::Int8(value), -// ScalarValueHelper::Int16 { value } => ScalarValue::Int16(value), -// ScalarValueHelper::Int32 { value } => ScalarValue::Int32(value), -// ScalarValueHelper::Int64 { value } => ScalarValue::Int64(value), -// ScalarValueHelper::UInt8 { value } => ScalarValue::UInt8(value), -// ScalarValueHelper::UInt16 { value } => ScalarValue::UInt16(value), -// ScalarValueHelper::UInt32 { value } => ScalarValue::UInt32(value), -// ScalarValueHelper::UInt64 { value } => ScalarValue::UInt64(value), -// ScalarValueHelper::Utf8 { value } => ScalarValue::Utf8(value), -// ScalarValueHelper::LargeUtf8 { value } => ScalarValue::LargeUtf8(value), -// ScalarValueHelper::Binary { value } => { -// ScalarValue::Binary(value.map(|s| base64::decode(s).unwrap())) -// } -// ScalarValueHelper::LargeBinary { value } => { -// ScalarValue::LargeBinary(value.map(|s| base64::decode(s).unwrap())) -// } -// ScalarValueHelper::FixedSizeBinary { size, value } => { -// ScalarValue::FixedSizeBinary(size, value.map(|s| base64::decode(s).unwrap())) -// } -// ScalarValueHelper::List { child_type, value } => { -// let field = Arc::new(Field::new( -// "item", -// DataType::from_str(&child_type).map_err(serde::de::Error::custom)?, -// true, -// )); -// let values: Vec> = value.unwrap_or_default(); -// let scalar_values: Vec = values -// .into_iter() -// .map(|v| v.unwrap_or(ScalarValue::Null)) -// .collect(); -// let data = ScalarValue::new_list(&scalar_values, &field.data_type()); -// ScalarValue::List(data) -// } -// ScalarValueHelper::LargeList { child_type, value } => { -// let field = Arc::new(Field::new( -// "item", -// DataType::from_str(&child_type).map_err(serde::de::Error::custom)?, -// true, -// )); -// let values: Vec> = value.unwrap_or_default(); -// let scalar_values: Vec = values -// .into_iter() -// .map(|v| v.unwrap_or(ScalarValue::Null)) -// .collect(); -// ScalarValue::LargeList(ScalarValue::new_large_list( -// &scalar_values, -// &field.data_type(), -// )) -// } -// ScalarValueHelper::FixedSizeList { -// child_type, -// size, -// value, -// } => { -// let field = Arc::new(Field::new( -// "item", -// DataType::from_str(&child_type).map_err(serde::de::Error::custom)?, -// true, -// )); -// let values: Vec> = value.unwrap_or_default(); -// let scalar_values: Vec = values -// .into_iter() -// .map(|v| v.unwrap_or(ScalarValue::Null)) -// .collect(); -// let data_type = DataType::FixedSizeList(field, scalar_values.len() as i32); -// let value_data = ScalarValue::new_list(&scalar_values, &data_type).to_data(); -// let list_array = FixedSizeListArray::from(value_data); -// ScalarValue::FixedSizeList(Arc::new(list_array)) -// } -// ScalarValueHelper::Date32 { value } => ScalarValue::Date32(value), -// ScalarValueHelper::Date64 { value } => ScalarValue::Date64(value), -// ScalarValueHelper::Time32Second { value } => ScalarValue::Time32Second(value), -// ScalarValueHelper::Time32Millisecond { value } => ScalarValue::Time32Millisecond(value), -// ScalarValueHelper::Time64Microsecond { value } => ScalarValue::Time64Microsecond(value), -// ScalarValueHelper::Time64Nanosecond { value } => ScalarValue::Time64Nanosecond(value), -// ScalarValueHelper::Timestamp { -// value, -// timezone, -// unit, -// } => match unit.as_str() { -// "Second" => ScalarValue::TimestampSecond(value, timezone.map(Arc::from)), -// "Millisecond" => ScalarValue::TimestampMillisecond(value, timezone.map(Arc::from)), -// "Microsecond" => ScalarValue::TimestampMicrosecond(value, timezone.map(Arc::from)), -// "Nanosecond" => ScalarValue::TimestampNanosecond(value, timezone.map(Arc::from)), -// _ => return Err(serde::de::Error::custom("Invalid timestamp unit")), -// }, -// ScalarValueHelper::IntervalYearMonth { value } => ScalarValue::IntervalYearMonth(value), -// ScalarValueHelper::IntervalDayTime { value } => ScalarValue::IntervalDayTime( -// value.map(|(days, millis)| IntervalDayTime::new(days, millis)), -// ), -// ScalarValueHelper::IntervalMonthDayNano { value } => ScalarValue::IntervalMonthDayNano( -// value.map(|(months, days, nanos)| IntervalMonthDayNano::new(months, days, nanos)), -// ), -// ScalarValueHelper::DurationSecond { value } => ScalarValue::DurationSecond(value), -// ScalarValueHelper::DurationMillisecond { value } => { -// ScalarValue::DurationMillisecond(value) -// } -// ScalarValueHelper::DurationMicrosecond { value } => { -// ScalarValue::DurationMicrosecond(value) -// } -// ScalarValueHelper::DurationNanosecond { value } => { -// ScalarValue::DurationNanosecond(value) -// } -// ScalarValueHelper::Union { -// value, -// fields, -// mode, -// } => { -// let union_fields = fields -// .into_iter() -// .map(|(i, name, type_str)| { -// ( -// i, -// Arc::new(Field::new( -// name, -// DataType::from_str(&type_str).unwrap(), -// true, -// )), -// ) -// }) -// .collect(); -// let union_mode = match mode.as_str() { -// "Sparse" => UnionMode::Sparse, -// "Dense" => UnionMode::Dense, -// _ => return Err(serde::de::Error::custom("Invalid union mode")), -// }; -// ScalarValue::Union(value, union_fields, union_mode) -// } -// ScalarValueHelper::Dictionary { key_type, value } => { -// ScalarValue::Dictionary(Box::new(DataType::from_str(&key_type).unwrap()), value) -// } -// ScalarValueHelper::Utf8View { value } => ScalarValue::Utf8View(value), -// ScalarValueHelper::BinaryView { value } => { -// ScalarValue::BinaryView(value.map(|s| base64::decode(s).unwrap())) -// } -// ScalarValueHelper::Struct { fields, value } => { -// let struct_fields: Vec = fields -// .into_iter() -// .map(|(name, type_str)| { -// Field::new(name, DataType::from_str(&type_str).unwrap(), true) -// }) -// .collect(); -// let values: Vec> = value.unwrap_or_default(); -// let scalar_values: Vec = values -// .into_iter() -// .map(|v| v.unwrap_or(ScalarValue::Null)) -// .collect(); -// let mut builder = ScalarStructBuilder::new(); -// for (field, value) in struct_fields.clone().into_iter().zip(scalar_values) { -// builder = builder.with_scalar(field, value); -// } -// let struct_array = builder.build().unwrap(); - -// ScalarValue::Struct(Arc::new(StructArray::new( -// Fields::from(struct_fields), -// vec![struct_array.to_array().unwrap()], -// None, -// ))) -// } -// }) -// } -// } - -// #[cfg(test)] -// mod tests { -// use super::*; -// use serde_json; -// use std::sync::Arc; - -// fn test_serde_roundtrip(scalar: ScalarValue) { -// let serialized = serde_json::to_string(&scalar).unwrap(); -// let deserialized: ScalarValue = serde_json::from_str(&serialized).unwrap(); -// assert_eq!(scalar, deserialized); -// } - -// #[test] -// fn test_large_utf8() { -// test_serde_roundtrip(ScalarValue::LargeUtf8(Some("hello".to_string()))); -// test_serde_roundtrip(ScalarValue::LargeUtf8(None)); -// } - -// #[test] -// fn test_binary() { -// test_serde_roundtrip(ScalarValue::Binary(Some(vec![1, 2, 3]))); -// test_serde_roundtrip(ScalarValue::Binary(None)); -// } - -// #[test] -// fn test_large_binary() { -// test_serde_roundtrip(ScalarValue::LargeBinary(Some(vec![1, 2, 3]))); -// test_serde_roundtrip(ScalarValue::LargeBinary(None)); -// } - -// #[test] -// fn test_fixed_size_binary() { -// test_serde_roundtrip(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))); -// test_serde_roundtrip(ScalarValue::FixedSizeBinary(3, None)); -// } - -// #[test] -// fn test_list() { -// let list = ListArray::from_iter_primitive::(vec![Some(vec![ -// Some(1), -// Some(2), -// Some(3), -// ])]); -// test_serde_roundtrip(ScalarValue::List(Arc::new(list))); -// } - -// #[test] -// fn test_large_list() { -// let list = LargeListArray::from_iter_primitive::(vec![Some(vec![ -// Some(1), -// Some(2), -// Some(3), -// ])]); -// test_serde_roundtrip(ScalarValue::LargeList(Arc::new(list))); -// } - -// // #[test] -// // fn test_fixed_size_list() { -// // let list = FixedSizeListArray::from_iter_primitive::( -// // vec![Some(vec![Some(1), Some(2), Some(3)])], -// // 3, -// // ); -// // test_serde_roundtrip(ScalarValue::FixedSizeList(Arc::new(list))); -// // } - -// #[test] -// fn test_date32() { -// test_serde_roundtrip(ScalarValue::Date32(Some(1000))); -// test_serde_roundtrip(ScalarValue::Date32(None)); -// } - -// #[test] -// fn test_date64() { -// test_serde_roundtrip(ScalarValue::Date64(Some(86400000))); -// test_serde_roundtrip(ScalarValue::Date64(None)); -// } - -// #[test] -// fn test_time32_second() { -// test_serde_roundtrip(ScalarValue::Time32Second(Some(3600))); -// test_serde_roundtrip(ScalarValue::Time32Second(None)); -// } - -// #[test] -// fn test_time32_millisecond() { -// test_serde_roundtrip(ScalarValue::Time32Millisecond(Some(3600000))); -// test_serde_roundtrip(ScalarValue::Time32Millisecond(None)); -// } - -// #[test] -// fn test_time64_microsecond() { -// test_serde_roundtrip(ScalarValue::Time64Microsecond(Some(3600000000))); -// test_serde_roundtrip(ScalarValue::Time64Microsecond(None)); -// } - -// #[test] -// fn test_time64_nanosecond() { -// test_serde_roundtrip(ScalarValue::Time64Nanosecond(Some(3600000000000))); -// test_serde_roundtrip(ScalarValue::Time64Nanosecond(None)); -// } - -// #[test] -// fn test_timestamp() { -// test_serde_roundtrip(ScalarValue::TimestampSecond( -// Some(1625097600), -// Some(Arc::from("UTC")), -// )); -// test_serde_roundtrip(ScalarValue::TimestampMillisecond(Some(1625097600000), None)); -// test_serde_roundtrip(ScalarValue::TimestampMicrosecond( -// Some(1625097600000000), -// Some(Arc::from("UTC")), -// )); -// test_serde_roundtrip(ScalarValue::TimestampNanosecond( -// Some(1625097600000000000), -// None, -// )); -// } - -// #[test] -// fn test_interval_year_month() { -// test_serde_roundtrip(ScalarValue::IntervalYearMonth(Some(14))); -// test_serde_roundtrip(ScalarValue::IntervalYearMonth(None)); -// } - -// #[test] -// fn test_interval_day_time() { -// test_serde_roundtrip(ScalarValue::IntervalDayTime(Some(IntervalDayTime::new( -// 5, 43200000, -// )))); -// test_serde_roundtrip(ScalarValue::IntervalDayTime(None)); -// } - -// #[test] -// fn test_interval_month_day_nano() { -// test_serde_roundtrip(ScalarValue::IntervalMonthDayNano(Some( -// IntervalMonthDayNano::new(1, 15, 1000000000), -// ))); -// test_serde_roundtrip(ScalarValue::IntervalMonthDayNano(None)); -// } - -// #[test] -// fn test_duration() { -// test_serde_roundtrip(ScalarValue::DurationSecond(Some(3600))); -// test_serde_roundtrip(ScalarValue::DurationMillisecond(Some(3600000))); -// test_serde_roundtrip(ScalarValue::DurationMicrosecond(Some(3600000000))); -// test_serde_roundtrip(ScalarValue::DurationNanosecond(Some(3600000000000))); -// } - -// // #[test] -// // fn test_union() { -// // let fields = vec![ -// // (0, Arc::new(Field::new("f1", DataType::Int32, true))), -// // (1, Arc::new(Field::new("f2", DataType::Utf8, true))), -// // ]; -// // test_serde_roundtrip(ScalarValue::Union( -// // Some((0, Box::new(ScalarValue::Int32(Some(42))))), -// // fields.clone(), -// // UnionMode::Sparse, -// // )); -// // test_serde_roundtrip(ScalarValue::Union(None, fields, UnionMode::Dense)); -// // } - -// #[test] -// fn test_dictionary() { -// test_serde_roundtrip(ScalarValue::Dictionary( -// Box::new(DataType::Int8), -// Box::new(ScalarValue::Utf8(Some("hello".to_string()))), -// )); -// } - -// #[test] -// fn test_utf8_view() { -// test_serde_roundtrip(ScalarValue::Utf8View(Some("hello".to_string()))); -// test_serde_roundtrip(ScalarValue::Utf8View(None)); -// } - -// #[test] -// fn test_binary_view() { -// test_serde_roundtrip(ScalarValue::BinaryView(Some(vec![1, 2, 3]))); -// test_serde_roundtrip(ScalarValue::BinaryView(None)); -// } - -// /* #[test] -// fn test_struct() { -// let fields = vec![ -// Field::new("f1", DataType::Int32, true), -// Field::new("f2", DataType::Utf8, true), -// ]; -// let values = vec![ -// Some(ScalarValue::Int32(Some(42))), -// Some(ScalarValue::Utf8(Some("hello".to_string()))), -// ]; -// let struct_array = StructArray::from(values); -// test_serde_roundtrip(ScalarValue::Struct(Arc::new(struct_array))); -// } */ -// } diff --git a/examples/examples/csv_streaming.rs b/examples/examples/csv_streaming.rs index c08e37a..aee299f 100644 --- a/examples/examples/csv_streaming.rs +++ b/examples/examples/csv_streaming.rs @@ -2,8 +2,8 @@ use datafusion::common::test_util::datafusion_test_data; use datafusion::dataframe::DataFrameWriteOptions; use datafusion::datasource::MemTable; use datafusion::error::Result; -use datafusion::prelude::*; use datafusion::logical_expr::{col, max, min}; +use datafusion::prelude::*; /// This example demonstrates executing a simple query against an Arrow data source (CSV) and /// fetching results with streaming aggregation and streaming window diff --git a/examples/examples/kafka_rideshare.rs b/examples/examples/kafka_rideshare.rs index 59491b7..eb858f0 100644 --- a/examples/examples/kafka_rideshare.rs +++ b/examples/examples/kafka_rideshare.rs @@ -2,9 +2,9 @@ #![allow(unused_variables)] use datafusion::error::Result; -use datafusion::logical_expr::{col, max, min}; use datafusion::functions::core::expr_ext::FieldAccessor; use datafusion::functions_aggregate::count::count; +use datafusion::logical_expr::{col, max, min}; use df_streams_core::context::Context; use df_streams_core::datasource::kafka::{ConnectionOpts, KafkaTopicBuilder}; diff --git a/examples/examples/seed_postgres_db.rs b/examples/examples/seed_postgres_db.rs deleted file mode 100644 index 9ea665a..0000000 --- a/examples/examples/seed_postgres_db.rs +++ /dev/null @@ -1,38 +0,0 @@ -use tokio_postgres::{Error, GenericClient, NoTls}; - -/// docker run --name postgres -e POSTGRES_PASSWORD=password -e POSTGRES_DB=postgres_db -p 5432:5432 -d postgres:16-alpine -#[tokio::main] -async fn main() -> Result<(), Error> { - let (client, connection) = tokio_postgres::connect( - "host=localhost user=postgres password=password dbname=postgres_db", - NoTls, - ) - .await?; - - // The connection object performs the actual communication with the database, - // so spawn it off to run on its own. - tokio::spawn(async move { - if let Err(e) = connection.await { - eprintln!("connection error: {}", e); - } - }); - - let table_name = "companies"; - client - .execute( - format!( - r#"CREATE TABLE IF NOT EXISTS {} (id SERIAL PRIMARY KEY, name VARCHAR)"#, - table_name - ) - .as_str(), - &[], - ) - .await?; - - let stmt = client - .prepare(format!("INSERT INTO {} (name) VALUES ($1)", table_name).as_str()) - .await?; - client.execute(&stmt, &[&"test"]).await?; - - Ok(()) -} diff --git a/examples/examples/sink_to_postgres.rs b/examples/examples/sink_to_postgres.rs deleted file mode 100644 index bffe124..0000000 --- a/examples/examples/sink_to_postgres.rs +++ /dev/null @@ -1,48 +0,0 @@ -use std::time::Duration; - -use datafusion::error::Result; -use datafusion::functions_aggregate::average::avg; -use datafusion::logical_expr::{col, max, min}; - -use df_streams_core::context::Context; -use df_streams_core::datasource::kafka::{ConnectionOpts, KafkaTopicBuilder}; -use df_streams_core::physical_plan::utils::time::TimestampUnit; - -#[tokio::main] -async fn main() -> Result<()> { - let sample_event = r#"{"occurred_at_ms": 1715201766763, "temperature": 87.2}"#; - - let bootstrap_servers = String::from("localhost:9092"); - - let ctx = Context::new()?; - - let mut topic_builder = KafkaTopicBuilder::new(bootstrap_servers.clone()); - - let source_topic = topic_builder - .with_timestamp(String::from("occurred_at_ms"), TimestampUnit::Int64Millis) - .with_encoding("json")? - .with_topic(String::from("temperature")) - .infer_schema_from_json(sample_event)? - .build_reader(ConnectionOpts::from([ - ("auto.offset.reset".to_string(), "earliest".to_string()), - ("group.id".to_string(), "sample_pipeline".to_string()), - ])) - .await?; - - let ds = ctx.from_topic(source_topic).await?.streaming_window( - vec![], - vec![ - min(col("temperature")).alias("min"), - max(col("temperature")).alias("max"), - avg(col("temperature")).alias("average"), - ], - Duration::from_millis(1_000), // 5 second window - None, - )?; - - // ds.clone().print_stream().await?; - - println!("{}", ds.df.schema()); - - Ok(()) -} diff --git a/examples/examples/stream_join.rs b/examples/examples/stream_join.rs index f9d541c..223f26c 100644 --- a/examples/examples/stream_join.rs +++ b/examples/examples/stream_join.rs @@ -3,9 +3,9 @@ #![allow(unused_imports)] use datafusion::error::Result; -use datafusion::logical_expr::{col, max, min}; use datafusion::functions::core::expr_ext::FieldAccessor; use datafusion::functions_aggregate::count::count; +use datafusion::logical_expr::{col, max, min}; use df_streams_core::context::Context; use df_streams_core::datasource::kafka::{ConnectionOpts, KafkaTopicBuilder};