Adding avro decoder (#40)
ameyc authored Sep 12, 2024
1 parent c5135cf commit 03bf773
Showing 10 changed files with 464 additions and 9 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion crates/common/Cargo.toml
@@ -11,8 +11,9 @@ default = ["python"]

[dependencies]
anyhow = "1.0.86"
-datafusion = { workspace = true }
+datafusion = { workspace = true, features = ["avro"] }
arrow = { workspace = true }
thiserror = "1.0.63"
pyo3 = { workspace = true, optional = true }
serde_json.workspace = true
+apache-avro = "0.16.0"
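
Note: both this crate and core (below) switch DataFusion to features = ["avro"] because DataFusion's avro_to_arrow reader, which the new decoder is built on, is compiled only when that feature is enabled; apache-avro itself provides Avro schema parsing and the error type wrapped below.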
3 changes: 3 additions & 0 deletions crates/common/src/error/mod.rs
@@ -1,6 +1,7 @@
use std::result;
use thiserror::Error;

+use apache_avro::Error as AvroError;
use arrow::error::ArrowError;
use datafusion::error::DataFusionError;
use serde_json::Error as JsonError;
@@ -22,6 +23,8 @@ pub enum DenormalizedError {
KafkaConfig(String),
#[error("Arrow Error")]
Arrow(#[from] ArrowError),
#[error("Avro Error")]
AvroError(#[from] AvroError),
#[error("Json Error")]
Json(#[from] JsonError),
#[error(transparent)]
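With #[from] on the new variant, any apache_avro::Error raised in the crate can be propagated with ? and arrives wrapped as DenormalizedError::AvroError. A minimal sketch of that conversion (the helper function and its name are illustrative, not part of the commit):

    use apache_avro::Schema as AvroSchema;
    use denormalized_common::DenormalizedError;

    // `?` converts apache_avro::Error into DenormalizedError::AvroError
    // through the From impl generated by thiserror's #[from].
    fn parse_avro_schema(raw: &str) -> Result<AvroSchema, DenormalizedError> {
        Ok(AvroSchema::parse_str(raw)?)
    }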
2 changes: 1 addition & 1 deletion crates/core/Cargo.toml
@@ -7,7 +7,7 @@ edition = { workspace = true }
denormalized-common = { workspace = true }
denormalized-orchestrator = { workspace = true }

-datafusion = { workspace = true }
+datafusion = { workspace = true, features = ["avro"] }

arrow = { workspace = true }
arrow-schema = { workspace = true }
21 changes: 21 additions & 0 deletions crates/core/src/datasource/kafka/kafka_config.rs
@@ -2,10 +2,15 @@ use std::collections::HashMap;
use std::str::FromStr;
use std::{sync::Arc, time::Duration};

+use apache_avro::Schema as AvroSchema;
use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef, TimeUnit};

use datafusion::logical_expr::SortExpr;

+use crate::formats::decoders::avro::AvroDecoder;
+use crate::formats::decoders::json::JsonDecoder;
+use crate::formats::decoders::utils::to_arrow_schema;
+use crate::formats::decoders::Decoder;
use crate::formats::StreamEncoding;
use crate::physical_plan::utils::time::TimestampUnit;
use crate::utils::arrow_helpers::infer_arrow_schema_from_json_value;
@@ -53,6 +58,13 @@ impl KafkaReadConfig {
let consumer: StreamConsumer = client_config.create().expect("Consumer creation failed");
Ok(consumer)
}

+    pub fn build_decoder(&self) -> Box<dyn Decoder> {
+        match self.encoding {
+            StreamEncoding::Avro => Box::new(AvroDecoder::new(self.original_schema.clone())),
+            StreamEncoding::Json => Box::new(JsonDecoder::new(self.original_schema.clone())),
+        }
+    }
}

#[derive(Debug)]
@@ -146,6 +158,15 @@ impl KafkaTopicBuilder {
Ok(self)
}

+    pub fn infer_schema_from_avro(&mut self, avro_schema_str: &str) -> Result<&mut Self> {
+        self.infer_schema = false;
+        let avro_schema: AvroSchema =
+            AvroSchema::parse_str(avro_schema_str).expect("Invalid schema!");
+        let arrow_schema = to_arrow_schema(&avro_schema)?;
+        self.schema = Some(Arc::new(arrow_schema));
+        Ok(self)
+    }

pub fn with_timestamp(
&mut self,
timestamp_column: String,
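infer_schema_from_avro lets callers pin a topic's schema from an Avro definition instead of inferring it from sampled JSON messages: it parses the Avro schema, converts it to an Arrow schema with to_arrow_schema, and switches inference off. A sketch of a call site (the builder instance and schema string are illustrative, not part of the commit):

    // `builder` is assumed to be a KafkaTopicBuilder; the call disables
    // schema inference and stores the converted Arrow schema.
    builder.infer_schema_from_avro(
        r#"{
            "type": "record",
            "name": "measurement",
            "fields": [
                {"name": "id", "type": "int"},
                {"name": "reading", "type": "double"}
            ]
        }"#,
    )?;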
10 changes: 4 additions & 6 deletions crates/core/src/datasource/kafka/kafka_stream_read.rs
@@ -11,8 +11,6 @@ use log::{debug, error};
use serde::{Deserialize, Serialize};

use crate::config_extensions::denormalized_config::DenormalizedConfig;
-use crate::formats::decoders::json::JsonDecoder;
-use crate::formats::decoders::Decoder;
use crate::physical_plan::stream_table::PartitionStreamExt;
use crate::physical_plan::utils::time::array_to_timestamp_array;
use crate::state_backend::rocksdb_backend::get_global_rocksdb;
@@ -133,7 +131,6 @@ impl PartitionStream for KafkaStreamRead {
let mut builder = RecordBatchReceiverStreamBuilder::new(self.config.schema.clone(), 1);
let tx = builder.tx();
let canonical_schema = self.config.schema.clone();
-        let arrow_schema = self.config.original_schema.clone();
let timestamp_column: String = self.config.timestamp_column.clone();
let timestamp_unit = self.config.timestamp_unit.clone();
let batch_timeout = Duration::from_millis(100);
@@ -143,14 +140,15 @@
channel_tag = format!("{}_{}", node_id, partition_tag);
create_channel(channel_tag.as_str(), 10);
}
+        let mut decoder = self.config.build_decoder();

builder.spawn(async move {
let mut epoch = 0;
if orchestrator::SHOULD_CHECKPOINT {
let orchestrator_sender = get_sender("orchestrator");
let msg = OrchestrationMessage::RegisterStream(channel_tag.clone());
orchestrator_sender.as_ref().unwrap().send(msg).unwrap();
}
-            let mut json_decoder: JsonDecoder = JsonDecoder::new(arrow_schema.clone());
loop {
let mut last_offsets = HashMap::new();
if let Some(backend) = &state_backend {
@@ -192,7 +190,7 @@
{
Ok(Ok(m)) => {
let payload = m.payload().expect("Message payload is empty");
-                        json_decoder.push_to_buffer(payload.to_owned());
+                        decoder.push_to_buffer(payload.to_owned());
offsets_read.insert(m.partition(), m.offset());
}
Ok(Err(err)) => {
@@ -207,7 +205,7 @@
}

if !offsets_read.is_empty() {
-                let record_batch = json_decoder.to_record_batch().unwrap();
+                let record_batch = decoder.to_record_batch().unwrap();
let ts_column = record_batch
.column_by_name(timestamp_column.as_str())
.map(|ts_col| {
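With these changes the read loop is format-agnostic: it depends only on the two-method Decoder contract, so switching a topic between JSON and Avro is a configuration change rather than a code change. The interaction reduces to this pattern (a sketch; config and payloads stand in for the stream state above):

    let mut decoder = config.build_decoder();   // Box<dyn Decoder>: AvroDecoder or JsonDecoder
    for payload in payloads {
        decoder.push_to_buffer(payload);        // buffer one Kafka message's raw bytes
    }
    let batch = decoder.to_record_batch()?;     // drain the buffer into a single RecordBatch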
159 changes: 159 additions & 0 deletions crates/core/src/formats/decoders/avro.rs
@@ -0,0 +1,159 @@
use std::{io::Cursor, sync::Arc};

use arrow_array::RecordBatch;
use arrow_schema::Schema;
use datafusion::datasource::avro_to_arrow::ReaderBuilder;
use denormalized_common::DenormalizedError;

use super::Decoder;

#[derive(Clone)]
pub struct AvroDecoder {
schema: Arc<Schema>,
cache: Vec<Vec<u8>>,
size: usize,
}

impl Decoder for AvroDecoder {
fn push_to_buffer(&mut self, bytes: Vec<u8>) {
self.cache.push(bytes);
self.size += 1;
}

fn to_record_batch(&mut self) -> Result<arrow_array::RecordBatch, DenormalizedError> {
if self.size == 0 {
return Ok(RecordBatch::new_empty(self.schema.clone()));
}
let all_bytes: Vec<u8> = self.cache.iter().flatten().cloned().collect();
// Create a cursor from the concatenated bytes
let cursor = Cursor::new(all_bytes);

// Build the reader
let mut reader = ReaderBuilder::new()
.with_batch_size(self.size)
.with_schema(self.schema.clone())
.build(cursor)?;

// Read the batch
match reader.next() {
Some(Ok(batch)) => Ok(batch),
Some(Err(e)) => Err(DenormalizedError::Arrow(e)),
None => Ok(RecordBatch::new_empty(self.schema.clone())),
}
}
}

impl AvroDecoder {
pub fn new(schema: Arc<Schema>) -> Self {
AvroDecoder {
schema,
cache: Vec::new(),
size: 0,
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use apache_avro::{types::Record, Schema as AvroSchema, Writer};
use arrow_array::{Int32Array, StringArray};
use arrow_schema::{DataType, Field};

fn create_test_schema() -> Arc<Schema> {
Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, false),
]))
}

fn create_avro_data(records: Vec<(i32, &str)>) -> Vec<u8> {
let avro_schema = AvroSchema::parse_str(
r#"
{
"type": "record",
"name": "test",
"fields": [
{"name": "id", "type": "int"},
{"name": "name", "type": "string"}
]
}
"#,
)
.unwrap();

let mut writer = Writer::new(&avro_schema, Vec::new());

for (id, name) in records {
let mut record: Record<'_> = Record::new(writer.schema()).unwrap();
record.put("id", id);
record.put("name", name);
writer.append(record).unwrap();
}

writer.into_inner().unwrap()
}

#[test]
fn test_push_to_buffer() {
let schema = create_test_schema();
let mut decoder = AvroDecoder::new(schema);

decoder.push_to_buffer(vec![1, 2, 3]);
decoder.push_to_buffer(vec![4, 5, 6]);

assert_eq!(decoder.size, 2);
assert_eq!(decoder.cache, vec![vec![1, 2, 3], vec![4, 5, 6]]);
}

#[test]
fn test_empty_record_batch() {
let schema = create_test_schema();
let mut decoder = AvroDecoder::new(schema.clone());

let result = decoder.to_record_batch().unwrap();

assert_eq!(result.schema(), schema);
assert_eq!(result.num_rows(), 0);
}

#[test]
fn test_record_batch_with_data() {
let schema = create_test_schema();
let mut decoder = AvroDecoder::new(schema.clone());

let avro_data = create_avro_data(vec![(1, "Alice")]);
decoder.push_to_buffer(avro_data);

let result = decoder.to_record_batch().unwrap();

assert_eq!(result.schema(), schema);
assert_eq!(result.num_rows(), 1);

let id_array = result
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
let name_array = result
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();

assert_eq!(id_array.value(0), 1);
assert_eq!(name_array.value(0), "Alice");
}

#[test]
fn test_invalid_avro_data() {
let schema = create_test_schema();
let mut decoder = AvroDecoder::new(schema);

decoder.push_to_buffer(vec![1, 2, 3]);

let result = decoder.to_record_batch();
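        // The malformed bytes fail inside DataFusion's Avro reader when
        // ReaderBuilder::build parses them, so the error surfaces as the
        // DataFusion variant rather than AvroError.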

assert!(matches!(result, Err(DenormalizedError::DataFusion(_))));
}
}
1 change: 1 addition & 0 deletions crates/core/src/formats/decoders/json.rs
@@ -7,6 +7,7 @@ use crate::utils::arrow_helpers::json_records_to_arrow_record_batch;

use super::Decoder;

+#[derive(Clone)]
pub struct JsonDecoder {
schema: Arc<Schema>,
cache: Vec<Vec<u8>>,
4 changes: 3 additions & 1 deletion crates/core/src/formats/decoders/mod.rs
@@ -1,10 +1,12 @@
use arrow_array::RecordBatch;
use denormalized_common::DenormalizedError;

-pub trait Decoder {
+pub trait Decoder: Send + Sync {
fn push_to_buffer(&mut self, bytes: Vec<u8>);

fn to_record_batch(&mut self) -> Result<RecordBatch, DenormalizedError>;
}

+pub mod avro;
pub mod json;
+pub mod utils;
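
The Send + Sync bound added to Decoder is what allows build_decoder's Box<dyn Decoder> to move into the stream task spawned in kafka_stream_read.rs. A minimal illustration of the requirement (hypothetical function, not part of the commit):

    use crate::formats::decoders::Decoder;

    // Moving the boxed decoder into another thread compiles because
    // Decoder: Send; Sync additionally allows sharing &dyn Decoder.
    fn decode_in_background(mut decoder: Box<dyn Decoder>) {
        std::thread::spawn(move || {
            let _ = decoder.to_record_batch();
        });
    }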