diff --git a/Cargo.lock b/Cargo.lock
index 2a1a81d..310f967 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1278,6 +1278,7 @@ name = "denormalized-common"
 version = "0.0.1"
 dependencies = [
  "anyhow",
+ "apache-avro",
  "arrow",
  "datafusion",
  "pyo3",
diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml
index 9c28e89..fc274f9 100644
--- a/crates/common/Cargo.toml
+++ b/crates/common/Cargo.toml
@@ -11,8 +11,9 @@ default = ["python"]
 
 [dependencies]
 anyhow = "1.0.86"
-datafusion = { workspace = true }
+datafusion = { workspace = true, features = ["avro"] }
 arrow = { workspace = true }
 thiserror = "1.0.63"
 pyo3 = { workspace = true, optional = true }
 serde_json.workspace = true
+apache-avro = "0.16.0"
diff --git a/crates/common/src/error/mod.rs b/crates/common/src/error/mod.rs
index 3849ee9..93a5eae 100644
--- a/crates/common/src/error/mod.rs
+++ b/crates/common/src/error/mod.rs
@@ -1,6 +1,7 @@
 use std::result;
 use thiserror::Error;
 
+use apache_avro::Error as AvroError;
 use arrow::error::ArrowError;
 use datafusion::error::DataFusionError;
 use serde_json::Error as JsonError;
@@ -22,6 +23,8 @@ pub enum DenormalizedError {
     KafkaConfig(String),
     #[error("Arrow Error")]
     Arrow(#[from] ArrowError),
+    #[error("Avro Error")]
+    AvroError(#[from] AvroError),
     #[error("Json Error")]
     Json(#[from] JsonError),
     #[error(transparent)]
diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml
index 95091bb..8927dd0 100644
--- a/crates/core/Cargo.toml
+++ b/crates/core/Cargo.toml
@@ -7,7 +7,7 @@ edition = { workspace = true }
 denormalized-common = { workspace = true }
 denormalized-orchestrator = { workspace = true }
 
-datafusion = { workspace = true }
+datafusion = { workspace = true, features = ["avro"] }
 
 arrow = { workspace = true }
 arrow-schema = { workspace = true }
diff --git a/crates/core/src/datasource/kafka/kafka_config.rs b/crates/core/src/datasource/kafka/kafka_config.rs
index 03ab399..6559cb9 100644
--- a/crates/core/src/datasource/kafka/kafka_config.rs
+++ b/crates/core/src/datasource/kafka/kafka_config.rs
@@ -2,10 +2,15 @@ use std::collections::HashMap;
 use std::str::FromStr;
 use std::{sync::Arc, time::Duration};
 
+use apache_avro::Schema as AvroSchema;
 use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef, TimeUnit};
 use datafusion::logical_expr::SortExpr;
 
+use crate::formats::decoders::avro::AvroDecoder;
+use crate::formats::decoders::json::JsonDecoder;
+use crate::formats::decoders::utils::to_arrow_schema;
+use crate::formats::decoders::Decoder;
 use crate::formats::StreamEncoding;
 use crate::physical_plan::utils::time::TimestampUnit;
 use crate::utils::arrow_helpers::infer_arrow_schema_from_json_value;
 
@@ -53,6 +58,13 @@ impl KafkaReadConfig {
         let consumer: StreamConsumer = client_config.create().expect("Consumer creation failed");
         Ok(consumer)
     }
+
+    pub fn build_decoder(&self) -> Box<dyn Decoder> {
+        match self.encoding {
+            StreamEncoding::Avro => Box::new(AvroDecoder::new(self.original_schema.clone())),
+            StreamEncoding::Json => Box::new(JsonDecoder::new(self.original_schema.clone())),
+        }
+    }
 }
 
 #[derive(Debug)]
@@ -146,6 +158,15 @@ impl KafkaTopicBuilder {
         Ok(self)
     }
 
+    pub fn infer_schema_from_avro(&mut self, avro_schema_str: &str) -> Result<&mut Self> {
+        self.infer_schema = false;
+        let avro_schema: AvroSchema =
+            AvroSchema::parse_str(avro_schema_str).expect("Invalid schema!");
+        let arrow_schema = to_arrow_schema(&avro_schema)?;
+        self.schema = Some(Arc::new(arrow_schema));
+        Ok(self)
+    }
+
     pub fn with_timestamp(
         &mut self,
         timestamp_column: String,
diff --git a/crates/core/src/datasource/kafka/kafka_stream_read.rs b/crates/core/src/datasource/kafka/kafka_stream_read.rs
index e82f530..4fd9b6e 100644
--- a/crates/core/src/datasource/kafka/kafka_stream_read.rs
+++ b/crates/core/src/datasource/kafka/kafka_stream_read.rs
@@ -11,8 +11,6 @@ use log::{debug, error};
 use serde::{Deserialize, Serialize};
 
 use crate::config_extensions::denormalized_config::DenormalizedConfig;
-use crate::formats::decoders::json::JsonDecoder;
-use crate::formats::decoders::Decoder;
 use crate::physical_plan::stream_table::PartitionStreamExt;
 use crate::physical_plan::utils::time::array_to_timestamp_array;
 use crate::state_backend::rocksdb_backend::get_global_rocksdb;
@@ -133,7 +131,6 @@ impl PartitionStream for KafkaStreamRead {
         let mut builder = RecordBatchReceiverStreamBuilder::new(self.config.schema.clone(), 1);
         let tx = builder.tx();
         let canonical_schema = self.config.schema.clone();
-        let arrow_schema = self.config.original_schema.clone();
         let timestamp_column: String = self.config.timestamp_column.clone();
         let timestamp_unit = self.config.timestamp_unit.clone();
         let batch_timeout = Duration::from_millis(100);
@@ -143,6 +140,8 @@
             channel_tag = format!("{}_{}", node_id, partition_tag);
             create_channel(channel_tag.as_str(), 10);
         }
+        let mut decoder = self.config.build_decoder();
+
         builder.spawn(async move {
             let mut epoch = 0;
             if orchestrator::SHOULD_CHECKPOINT {
@@ -150,7 +149,6 @@
                 let msg = OrchestrationMessage::RegisterStream(channel_tag.clone());
                 orchestrator_sender.as_ref().unwrap().send(msg).unwrap();
             }
-            let mut json_decoder: JsonDecoder = JsonDecoder::new(arrow_schema.clone());
             loop {
                 let mut last_offsets = HashMap::new();
                 if let Some(backend) = &state_backend {
@@ -192,7 +190,7 @@
                     {
                         Ok(Ok(m)) => {
                             let payload = m.payload().expect("Message payload is empty");
-                            json_decoder.push_to_buffer(payload.to_owned());
+                            decoder.push_to_buffer(payload.to_owned());
                            offsets_read.insert(m.partition(), m.offset());
                         }
                         Ok(Err(err)) => {
@@ -207,7 +205,7 @@
                 }
 
                 if !offsets_read.is_empty() {
-                    let record_batch = json_decoder.to_record_batch().unwrap();
+                    let record_batch = decoder.to_record_batch().unwrap();
                     let ts_column = record_batch
                         .column_by_name(timestamp_column.as_str())
                         .map(|ts_col| {
diff --git a/crates/core/src/formats/decoders/avro.rs b/crates/core/src/formats/decoders/avro.rs
new file mode 100644
index 0000000..ac249b3
--- /dev/null
+++ b/crates/core/src/formats/decoders/avro.rs
@@ -0,0 +1,159 @@
+use std::{io::Cursor, sync::Arc};
+
+use arrow_array::RecordBatch;
+use arrow_schema::Schema;
+use datafusion::datasource::avro_to_arrow::ReaderBuilder;
+use denormalized_common::DenormalizedError;
+
+use super::Decoder;
+
+#[derive(Clone)]
+pub struct AvroDecoder {
+    schema: Arc<Schema>,
+    cache: Vec<Vec<u8>>,
+    size: usize,
+}
+
+impl Decoder for AvroDecoder {
+    fn push_to_buffer(&mut self, bytes: Vec<u8>) {
+        self.cache.push(bytes);
+        self.size += 1;
+    }
+
+    fn to_record_batch(&mut self) -> Result<RecordBatch, DenormalizedError> {
+        if self.size == 0 {
+            return Ok(RecordBatch::new_empty(self.schema.clone()));
+        }
+        let all_bytes: Vec<u8> = self.cache.iter().flatten().cloned().collect();
+        // Create a cursor from the concatenated bytes
+        let cursor = Cursor::new(all_bytes);
+
+        // Build the reader
+        let mut reader = ReaderBuilder::new()
+            .with_batch_size(self.size)
+            .with_schema(self.schema.clone())
+            .build(cursor)?;
+
+        // Read the batch
+        match reader.next() {
+            Some(Ok(batch)) => Ok(batch),
+            Some(Err(e)) => Err(DenormalizedError::Arrow(e)),
+            None => Ok(RecordBatch::new_empty(self.schema.clone())),
+        }
+    }
+}
+
+impl AvroDecoder {
+    pub fn new(schema: Arc<Schema>) -> Self {
+        AvroDecoder {
+            schema,
+            cache: Vec::new(),
+            size: 0,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use apache_avro::{types::Record, Schema as AvroSchema, Writer};
+    use arrow_array::{Int32Array, StringArray};
+    use arrow_schema::{DataType, Field};
+
+    fn create_test_schema() -> Arc<Schema> {
+        Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("name", DataType::Utf8, false),
+        ]))
+    }
+
+    fn create_avro_data(records: Vec<(i32, &str)>) -> Vec<u8> {
+        let avro_schema = AvroSchema::parse_str(
+            r#"
+            {
+                "type": "record",
+                "name": "test",
+                "fields": [
+                    {"name": "id", "type": "int"},
+                    {"name": "name", "type": "string"}
+                ]
+            }
+            "#,
+        )
+        .unwrap();
+
+        let mut writer = Writer::new(&avro_schema, Vec::new());
+
+        for (id, name) in records {
+            let mut record: Record<'_> = Record::new(writer.schema()).unwrap();
+            record.put("id", id);
+            record.put("name", name);
+            writer.append(record).unwrap();
+        }
+
+        writer.into_inner().unwrap()
+    }
+
+    #[test]
+    fn test_push_to_buffer() {
+        let schema = create_test_schema();
+        let mut decoder = AvroDecoder::new(schema);
+
+        decoder.push_to_buffer(vec![1, 2, 3]);
+        decoder.push_to_buffer(vec![4, 5, 6]);
+
+        assert_eq!(decoder.size, 2);
+        assert_eq!(decoder.cache, vec![vec![1, 2, 3], vec![4, 5, 6]]);
+    }
+
+    #[test]
+    fn test_empty_record_batch() {
+        let schema = create_test_schema();
+        let mut decoder = AvroDecoder::new(schema.clone());
+
+        let result = decoder.to_record_batch().unwrap();
+
+        assert_eq!(result.schema(), schema);
+        assert_eq!(result.num_rows(), 0);
+    }
+
+    #[test]
+    fn test_record_batch_with_data() {
+        let schema = create_test_schema();
+        let mut decoder = AvroDecoder::new(schema.clone());
+
+        let avro_data = create_avro_data(vec![(1, "Alice")]);
+        decoder.push_to_buffer(avro_data);
+
+        let result = decoder.to_record_batch().unwrap();
+
+        assert_eq!(result.schema(), schema);
+        assert_eq!(result.num_rows(), 1);
+
+        let id_array = result
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        let name_array = result
+            .column(1)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .unwrap();
+
+        assert_eq!(id_array.value(0), 1);
+        assert_eq!(name_array.value(0), "Alice");
+    }
+
+    #[test]
+    fn test_invalid_avro_data() {
+        let schema = create_test_schema();
+        let mut decoder = AvroDecoder::new(schema);
+
+        decoder.push_to_buffer(vec![1, 2, 3]);
+
+        let result = decoder.to_record_batch();
+
+        assert!(matches!(result, Err(DenormalizedError::DataFusion(_))));
+    }
+}
diff --git a/crates/core/src/formats/decoders/json.rs b/crates/core/src/formats/decoders/json.rs
index a240d37..05e236b 100644
--- a/crates/core/src/formats/decoders/json.rs
+++ b/crates/core/src/formats/decoders/json.rs
@@ -7,6 +7,7 @@ use crate::utils::arrow_helpers::json_records_to_arrow_record_batch;
 
 use super::Decoder;
 
+#[derive(Clone)]
 pub struct JsonDecoder {
     schema: Arc<Schema>,
     cache: Vec<Vec<u8>>,
diff --git a/crates/core/src/formats/decoders/mod.rs b/crates/core/src/formats/decoders/mod.rs
index ae56178..19af557 100644
--- a/crates/core/src/formats/decoders/mod.rs
+++ b/crates/core/src/formats/decoders/mod.rs
@@ -1,10 +1,12 @@
 use arrow_array::RecordBatch;
 use denormalized_common::DenormalizedError;
 
-pub trait Decoder {
+pub trait Decoder: Send + Sync {
     fn push_to_buffer(&mut self, bytes: Vec<u8>);
 
     fn to_record_batch(&mut self) -> Result<RecordBatch, DenormalizedError>;
 }
 
+pub mod avro;
 pub mod json;
+pub mod utils;
diff --git a/crates/core/src/formats/decoders/utils.rs b/crates/core/src/formats/decoders/utils.rs
new file mode 100644
index 0000000..11cba42
--- /dev/null
+++ b/crates/core/src/formats/decoders/utils.rs
@@ -0,0 +1,269 @@
+//TODO: Remove this once upstream pub changes are put in.
+
+use apache_avro::schema::{Alias, DecimalSchema, EnumSchema, FixedSchema, Name, RecordSchema};
+use apache_avro::types::Value;
+use apache_avro::Error as AvErr;
+use apache_avro::Schema as AvroSchema;
+use arrow::datatypes::{DataType, IntervalUnit, Schema, TimeUnit, UnionMode};
+use arrow::datatypes::{Field, UnionFields};
+use denormalized_common::error::{DenormalizedError, Result};
+use std::collections::HashMap;
+use std::sync::Arc;
+
+/// Converts an avro schema to an arrow schema
+pub fn to_arrow_schema(avro_schema: &apache_avro::Schema) -> Result<Schema> {
+    let mut schema_fields = vec![];
+    match avro_schema {
+        AvroSchema::Record(RecordSchema { fields, .. }) => {
+            for field in fields {
+                schema_fields.push(schema_to_field_with_props(
+                    &field.schema,
+                    Some(&field.name),
+                    field.is_nullable(),
+                    Some(external_props(&field.schema)),
+                )?)
+            }
+        }
+        schema => schema_fields.push(schema_to_field(schema, Some(""), false)?),
+    }
+
+    let schema = Schema::new(schema_fields);
+    Ok(schema)
+}
+
+fn schema_to_field(
+    schema: &apache_avro::Schema,
+    name: Option<&str>,
+    nullable: bool,
+) -> Result<Field> {
+    schema_to_field_with_props(schema, name, nullable, Default::default())
+}
+
+fn schema_to_field_with_props(
+    schema: &AvroSchema,
+    name: Option<&str>,
+    nullable: bool,
+    props: Option<HashMap<String, String>>,
+) -> Result<Field> {
+    let mut nullable = nullable;
+    let field_type: DataType = match schema {
+        AvroSchema::Ref { .. } => todo!("Add support for AvroSchema::Ref"),
+        AvroSchema::Null => DataType::Null,
+        AvroSchema::Boolean => DataType::Boolean,
+        AvroSchema::Int => DataType::Int32,
+        AvroSchema::Long => DataType::Int64,
+        AvroSchema::Float => DataType::Float32,
+        AvroSchema::Double => DataType::Float64,
+        AvroSchema::Bytes => DataType::Binary,
+        AvroSchema::String => DataType::Utf8,
+        AvroSchema::Array(item_schema) => DataType::List(Arc::new(schema_to_field_with_props(
+            item_schema,
+            Some("element"),
+            false,
+            None,
+        )?)),
+        AvroSchema::Map(value_schema) => {
+            let value_field = schema_to_field_with_props(value_schema, Some("value"), false, None)?;
+            DataType::Dictionary(
+                Box::new(DataType::Utf8),
+                Box::new(value_field.data_type().clone()),
+            )
+        }
+        AvroSchema::Union(us) => {
+            // If there are only two variants and one of them is null, set the other type as the field data type
+            let has_nullable = us
+                .find_schema_with_known_schemata::<AvroSchema>(&Value::Null, None, &None)
+                .is_some();
+            let sub_schemas = us.variants();
+            if has_nullable && sub_schemas.len() == 2 {
+                nullable = true;
+                if let Some(schema) = sub_schemas
+                    .iter()
+                    .find(|&schema| !matches!(schema, AvroSchema::Null))
+                {
+                    schema_to_field_with_props(schema, None, has_nullable, None)?
+                        .data_type()
+                        .clone()
+                } else {
+                    return Err(DenormalizedError::AvroError(AvErr::GetUnionDuplicate));
+                }
+            } else {
+                let fields = sub_schemas
+                    .iter()
+                    .map(|s| schema_to_field_with_props(s, None, has_nullable, None))
+                    .collect::<Result<Vec<Field>>>()?;
+                let type_ids = 0_i8..fields.len() as i8;
+                DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Dense)
+            }
+        }
+        AvroSchema::Record(RecordSchema { fields, .. }) => {
+            let fields: Result<_> = fields
+                .iter()
+                .map(|field| {
+                    let mut props = HashMap::new();
+                    if let Some(doc) = &field.doc {
+                        props.insert("avro::doc".to_string(), doc.clone());
+                    }
+                    /*if let Some(aliases) = fields.aliases {
+                        props.insert("aliases", aliases);
+                    }*/
+                    schema_to_field_with_props(&field.schema, Some(&field.name), false, Some(props))
+                })
+                .collect();
+            DataType::Struct(fields?)
+        }
+        AvroSchema::Enum(EnumSchema { .. }) => DataType::Utf8,
+        AvroSchema::Fixed(FixedSchema { size, .. }) => DataType::FixedSizeBinary(*size as i32),
+        AvroSchema::Decimal(DecimalSchema {
+            precision, scale, ..
+        }) => DataType::Decimal128(*precision as u8, *scale as i8),
+        AvroSchema::Uuid => DataType::FixedSizeBinary(16),
+        AvroSchema::Date => DataType::Date32,
+        AvroSchema::TimeMillis => DataType::Time32(TimeUnit::Millisecond),
+        AvroSchema::TimeMicros => DataType::Time64(TimeUnit::Microsecond),
+        AvroSchema::TimestampMillis => DataType::Timestamp(TimeUnit::Millisecond, None),
+        AvroSchema::TimestampMicros => DataType::Timestamp(TimeUnit::Microsecond, None),
+        AvroSchema::LocalTimestampMillis => todo!(),
+        AvroSchema::LocalTimestampMicros => todo!(),
+        AvroSchema::Duration => DataType::Duration(TimeUnit::Millisecond),
+    };
+
+    let data_type = field_type.clone();
+    let name = name.unwrap_or_else(|| default_field_name(&data_type));
+
+    let mut field = Field::new(name, field_type, nullable);
+    field.set_metadata(props.unwrap_or_default());
+    Ok(field)
+}
+
+fn default_field_name(dt: &DataType) -> &str {
+    match dt {
+        DataType::Null => "null",
+        DataType::Boolean => "bit",
+        DataType::Int8 => "tinyint",
+        DataType::Int16 => "smallint",
+        DataType::Int32 => "int",
+        DataType::Int64 => "bigint",
+        DataType::UInt8 => "uint1",
+        DataType::UInt16 => "uint2",
+        DataType::UInt32 => "uint4",
+        DataType::UInt64 => "uint8",
+        DataType::Float16 => "float2",
+        DataType::Float32 => "float4",
+        DataType::Float64 => "float8",
+        DataType::Date32 => "dateday",
+        DataType::Date64 => "datemilli",
+        DataType::Time32(tu) | DataType::Time64(tu) => match tu {
+            TimeUnit::Second => "timesec",
+            TimeUnit::Millisecond => "timemilli",
+            TimeUnit::Microsecond => "timemicro",
+            TimeUnit::Nanosecond => "timenano",
+        },
+        DataType::Timestamp(tu, tz) => {
+            if tz.is_some() {
+                match tu {
+                    TimeUnit::Second => "timestampsectz",
+                    TimeUnit::Millisecond => "timestampmillitz",
+                    TimeUnit::Microsecond => "timestampmicrotz",
+                    TimeUnit::Nanosecond => "timestampnanotz",
+                }
+            } else {
+                match tu {
+                    TimeUnit::Second => "timestampsec",
+                    TimeUnit::Millisecond => "timestampmilli",
+                    TimeUnit::Microsecond => "timestampmicro",
+                    TimeUnit::Nanosecond => "timestampnano",
+                }
+            }
+        }
+        DataType::Duration(_) => "duration",
+        DataType::Interval(unit) => match unit {
+            IntervalUnit::YearMonth => "intervalyear",
+            IntervalUnit::DayTime => "intervalmonth",
+            IntervalUnit::MonthDayNano => "intervalmonthdaynano",
+        },
+        DataType::Binary => "varbinary",
+        DataType::FixedSizeBinary(_) => "fixedsizebinary",
+        DataType::LargeBinary => "largevarbinary",
+        DataType::Utf8 => "varchar",
+        DataType::LargeUtf8 => "largevarchar",
+        DataType::List(_) => "list",
+        DataType::FixedSizeList(_, _) => "fixed_size_list",
+        DataType::LargeList(_) => "largelist",
+        DataType::Struct(_) => "struct",
+        DataType::Union(_, _) => "union",
+        DataType::Dictionary(_, _) => "map",
+        DataType::Map(_, _) => unimplemented!("Map support not implemented"),
+        DataType::RunEndEncoded(_, _) => {
+            unimplemented!("RunEndEncoded support not implemented")
+        }
+        DataType::Utf8View
+        | DataType::BinaryView
+        | DataType::ListView(_)
+        | DataType::LargeListView(_) => {
+            unimplemented!("View support not implemented")
+        }
+        DataType::Decimal128(_, _) => "decimal",
+        DataType::Decimal256(_, _) => "decimal",
+    }
+}
+
+fn external_props(schema: &AvroSchema) -> HashMap<String, String> {
+    let mut props = HashMap::new();
+    match &schema {
+        AvroSchema::Record(RecordSchema {
+            doc: Some(ref doc), ..
+        })
+        | AvroSchema::Enum(EnumSchema {
+            doc: Some(ref doc), ..
+        })
+        | AvroSchema::Fixed(FixedSchema {
+            doc: Some(ref doc), ..
+        }) => {
+            props.insert("avro::doc".to_string(), doc.clone());
+        }
+        _ => {}
+    }
+    match &schema {
+        AvroSchema::Record(RecordSchema {
+            name: Name { namespace, .. },
+            aliases: Some(aliases),
+            ..
+        })
+        | AvroSchema::Enum(EnumSchema {
+            name: Name { namespace, .. },
+            aliases: Some(aliases),
+            ..
+        })
+        | AvroSchema::Fixed(FixedSchema {
+            name: Name { namespace, .. },
+            aliases: Some(aliases),
+            ..
+        }) => {
+            let aliases: Vec<String> = aliases
+                .iter()
+                .map(|alias| aliased(alias, namespace.as_deref(), None))
+                .collect();
+            props.insert(
+                "avro::aliases".to_string(),
+                format!("[{}]", aliases.join(",")),
+            );
+        }
+        _ => {}
+    }
+    props
+}
+
+/// Returns the fully qualified name for a field
+pub fn aliased(alias: &Alias, namespace: Option<&str>, default_namespace: Option<&str>) -> String {
+    if alias.namespace().is_some() {
+        alias.fullname(None)
+    } else {
+        let namespace = namespace.as_ref().copied().or(default_namespace);
+
+        match namespace {
+            Some(ref namespace) => format!("{}.{}", namespace, alias.name()),
+            None => alias.fullname(None),
+        }
+    }
+}
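
Usage note: the sketch below shows how the new entry points compose; it is not part of the patch. Module paths and the two wrapper functions are assumptions for illustration; only `infer_schema_from_avro`, `build_decoder`, `push_to_buffer`, and `to_record_batch` are APIs introduced by this diff.

```rust
// Illustrative sketch only (hypothetical wiring; module paths are guessed
// from the file layout in this patch, not verified against re-exports).
use crate::datasource::kafka::kafka_config::{KafkaReadConfig, KafkaTopicBuilder};
use denormalized_common::error::Result;

// Instead of sampling JSON messages to infer a schema, parse an Avro schema
// up front; infer_schema_from_avro() converts it to an Arrow schema via
// to_arrow_schema() and disables inference.
fn configure_avro_topic(builder: &mut KafkaTopicBuilder) -> Result<()> {
    let avro_schema_str = r#"
    {
        "type": "record",
        "name": "measurement",
        "fields": [
            {"name": "id", "type": "int"},
            {"name": "value", "type": "double"}
        ]
    }
    "#;
    builder.infer_schema_from_avro(avro_schema_str)?;
    Ok(())
}

// KafkaStreamRead now asks the config for a decoder instead of hardcoding
// JsonDecoder; for StreamEncoding::Avro this yields an AvroDecoder.
fn decode_payloads(config: &KafkaReadConfig, payloads: &[Vec<u8>]) -> Result<()> {
    let mut decoder = config.build_decoder(); // Box<dyn Decoder>
    for payload in payloads {
        // AvroDecoder concatenates buffered payloads and feeds them to
        // DataFusion's avro_to_arrow Reader, so each payload is expected to
        // carry Avro object-container bytes (as in the tests above).
        decoder.push_to_buffer(payload.clone());
    }
    let _batch = decoder.to_record_batch()?;
    Ok(())
}
```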