From 6d3ec66be01670ca0fcf5b95f2df3d0214da8065 Mon Sep 17 00:00:00 2001
From: Matt Green
Date: Sat, 21 Sep 2024 15:47:22 -0700
Subject: [PATCH] Add udf example (#41)

* Add udf example

* formatting

* remove comments
---
 Cargo.lock                        |   1 +
 crates/core/src/datastream.rs     |   7 ++
 examples/Cargo.toml               |   3 +-
 examples/examples/udf_example.rs  | 129 ++++++++++++++++++++++++++++++
 py-denormalized/src/datastream.rs |   8 +-
 5 files changed, 143 insertions(+), 5 deletions(-)
 create mode 100644 examples/examples/udf_example.rs

diff --git a/Cargo.lock b/Cargo.lock
index 20a3dc1..c76efd5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1334,6 +1334,7 @@ name = "denormalized-examples"
 version = "0.0.1"
 dependencies = [
  "arrow",
+ "arrow-array",
  "arrow-schema",
  "datafusion",
  "denormalized",
diff --git a/crates/core/src/datastream.rs b/crates/core/src/datastream.rs
index 03c1e96..482ebd2 100644
--- a/crates/core/src/datastream.rs
+++ b/crates/core/src/datastream.rs
@@ -64,6 +64,13 @@ impl DataStream {
         })
     }
 
+    pub fn with_column(self, name: &str, expr: Expr) -> Result<Self> {
+        Ok(Self {
+            df: Arc::new(self.df.as_ref().clone().with_column(name, expr)?),
+            context: self.context.clone(),
+        })
+    }
+
     // Join two streams using the specified expression
     pub fn join_on(
         self,
diff --git a/examples/Cargo.toml b/examples/Cargo.toml
index b807118..2af54c7 100644
--- a/examples/Cargo.toml
+++ b/examples/Cargo.toml
@@ -9,8 +9,9 @@
 denormalized = { workspace = true }
 datafusion = { workspace = true }
 
-arrow = { workspace = true }
+arrow = { workspace = true, features = ["prettyprint"] }
 arrow-schema = { workspace = true }
+arrow-array = { workspace = true }
 tracing = { workspace = true }
 futures = { workspace = true }
 tracing-log = { workspace = true }
diff --git a/examples/examples/udf_example.rs b/examples/examples/udf_example.rs
new file mode 100644
index 0000000..2206bea
--- /dev/null
+++ b/examples/examples/udf_example.rs
@@ -0,0 +1,129 @@
+use std::any::Any;
+use std::sync::Arc;
+use std::time::Duration;
+
+use datafusion::common::cast::as_float64_array;
+use datafusion::functions_aggregate::average::avg;
+use datafusion::functions_aggregate::count::count;
+use datafusion::functions_aggregate::expr_fn::{max, min};
+use datafusion::logical_expr::sort_properties::{ExprProperties, SortProperties};
+use datafusion::logical_expr::Volatility;
+use datafusion::logical_expr::{col, lit};
+use datafusion::logical_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature};
+
+use denormalized::datasource::kafka::{ConnectionOpts, KafkaTopicBuilder};
+use denormalized::prelude::*;
+
+use arrow::array::{ArrayRef, Float64Array};
+use arrow::datatypes::DataType;
+
+use denormalized_examples::get_sample_json;
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    env_logger::builder()
+        .filter_level(log::LevelFilter::Info)
+        .init();
+
+    let sample_event = get_sample_json();
+
+    let bootstrap_servers = String::from("localhost:9092");
+
+    let ctx = Context::new()?;
+    let mut topic_builder = KafkaTopicBuilder::new(bootstrap_servers.clone());
+
+    let source_topic = topic_builder
+        .with_topic(String::from("temperature"))
+        .infer_schema_from_json(sample_event.as_str())?
+        .with_encoding("json")?
+        .with_timestamp(String::from("occurred_at_ms"), TimestampUnit::Int64Millis)
+        .build_reader(ConnectionOpts::from([
+            ("auto.offset.reset".to_string(), "latest".to_string()),
+            ("group.id".to_string(), "sample_pipeline".to_string()),
+        ]))
+        .await?;
+
+    let sample_udf = ScalarUDF::from(SampleUdf::new());
+
+    let ds = ctx
+        .from_topic(source_topic)
+        .await?
+        .window(
+            vec![col("sensor_name")],
+            vec![
+                count(col("reading")).alias("count"),
+                min(col("reading")).alias("min"),
+                max(col("reading")).alias("max"),
+                avg(col("reading")).alias("average"),
+            ],
+            Duration::from_millis(1_000),
+            None,
+        )?
+        .filter(col("max").gt(lit(113)))?
+        .with_column("sample", sample_udf.call(vec![col("max")]))?;
+
+    ds.clone().print_physical_plan().await?;
+    ds.clone().print_stream().await?;
+
+    Ok(())
+}
+
+#[derive(Debug, Clone)]
+struct SampleUdf {
+    signature: Signature,
+}
+
+impl SampleUdf {
+    fn new() -> Self {
+        Self {
+            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SampleUdf {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "sample_udf"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> datafusion::error::Result<ColumnarValue> {
+        assert_eq!(args.len(), 1);
+        let value = &args[0];
+        assert_eq!(value.data_type(), DataType::Float64);
+
+        let args = ColumnarValue::values_to_arrays(args)?;
+        let value = as_float64_array(&args[0]).expect("cast failed");
+
+        let array = value
+            .iter()
+            .map(|v| match v {
+                Some(f) => {
+                    let value = f + 20_f64;
+                    Some(value)
+                }
+                _ => None,
+            })
+            .collect::<Float64Array>();
+
+        Ok(ColumnarValue::from(Arc::new(array) as ArrayRef))
+    }
+
+    fn output_ordering(
+        &self,
+        input: &[ExprProperties],
+    ) -> datafusion::error::Result<SortProperties> {
+        Ok(input[0].sort_properties)
+    }
+}
diff --git a/py-denormalized/src/datastream.rs b/py-denormalized/src/datastream.rs
index 6e56e70..fa5f604 100644
--- a/py-denormalized/src/datastream.rs
+++ b/py-denormalized/src/datastream.rs
@@ -136,7 +136,7 @@ impl PyDataStream {
         // Use u64 for durations since using PyDelta type requires non-Py_LIMITED_API to be
         // enabled
         let window_length_duration = Duration::from_millis(window_length_millis);
-        let window_slide_duration = slide_millis.map(|d| Duration::from_millis(d));
+        let window_slide_duration = slide_millis.map(Duration::from_millis);
 
         let ds = self.ds.as_ref().clone().window(
             groups,
@@ -183,7 +183,7 @@ impl PyDataStream {
         let fut: JoinHandle<Result<()>> =
             rt.spawn(async move { ds.print_stream().await });
 
-        let _ = wait_for_future(py, fut).map_err(py_denormalized_err)??;
+        wait_for_future(py, fut).map_err(py_denormalized_err)??;
 
         Ok(())
     }
@@ -194,7 +194,7 @@ impl PyDataStream {
         let fut: JoinHandle<Result<()>> =
             rt.spawn(async move { ds.sink_kafka(bootstrap_servers, topic).await });
 
-        let _ = wait_for_future(py, fut).map_err(py_denormalized_err)??;
+        wait_for_future(py, fut).map_err(py_denormalized_err)??;
 
         Ok(())
     }
@@ -234,7 +234,7 @@ impl PyDataStream {
         });
 
         // rt.block_on(fut).map_err(py_denormalized_err)??;
-        let _ = wait_for_future(py, fut).map_err(py_denormalized_err)??;
+        wait_for_future(py, fut).map_err(py_denormalized_err)??;
 
         Ok(())
     }
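Usage note (not part of the patch): any DataFusion scalar UDF built this way drops straight into the new `DataStream::with_column` API. The sketch below is a minimal variation on the `SampleUdf` above, assuming the same DataFusion and arrow APIs the example already imports; the hypothetical `to_celsius` name and conversion kernel are illustrative, only the per-value math differs from the patch.

    use std::any::Any;
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Float64Array};
    use arrow::datatypes::DataType;
    use datafusion::common::cast::as_float64_array;
    use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

    // Hypothetical UDF converting Fahrenheit readings to Celsius.
    #[derive(Debug, Clone)]
    struct ToCelsius {
        signature: Signature,
    }

    impl ToCelsius {
        fn new() -> Self {
            Self {
                // Same shape as SampleUdf: exactly one Float64 argument.
                signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
            }
        }
    }

    impl ScalarUDFImpl for ToCelsius {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "to_celsius"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
            Ok(DataType::Float64)
        }

        fn invoke(&self, args: &[ColumnarValue]) -> datafusion::error::Result<ColumnarValue> {
            // Materialize the argument as an array, then map each non-null value.
            let args = ColumnarValue::values_to_arrays(args)?;
            let value = as_float64_array(&args[0]).expect("cast failed");
            let array = value
                .iter()
                .map(|v| v.map(|f| (f - 32.0) * 5.0 / 9.0))
                .collect::<Float64Array>();
            Ok(ColumnarValue::from(Arc::new(array) as ArrayRef))
        }
    }

Wiring it into the stream mirrors the example above: `let to_celsius = ScalarUDF::from(ToCelsius::new());` followed by `.with_column("celsius", to_celsius.call(vec![col("max")]))?` (the `"celsius"` column name is likewise illustrative).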