Add udf example (#41)
* Add udf example

* formatting

* remove comments
emgeee authored Sep 21, 2024
1 parent 2a7aad9 commit 6d3ec66
Showing 5 changed files with 143 additions and 5 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

7 changes: 7 additions & 0 deletions crates/core/src/datastream.rs
@@ -64,6 +64,13 @@ impl DataStream {
})
}

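// Add a derived column to the stream by evaluating the given expression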
pub fn with_column(self, name: &str, expr: Expr) -> Result<Self> {
Ok(Self {
df: Arc::new(self.df.as_ref().clone().with_column(name, expr)?),
context: self.context.clone(),
})
}

// Join two streams using the specified expression
pub fn join_on(
self,
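The new with_column mirrors DataFusion's DataFrame::with_column: it evaluates an expression against the stream and appends the result as a named column. A minimal usage sketch (hypothetical stream `ds` with a `temp_c` column, not part of this commit):

// Hypothetical: derive a Fahrenheit column from an existing Celsius column.
let ds = ds.with_column("temp_f", col("temp_c") * lit(1.8) + lit(32.0))?;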
3 changes: 2 additions & 1 deletion examples/Cargo.toml
@@ -9,8 +9,9 @@ denormalized = { workspace = true }

datafusion = { workspace = true }

- arrow = { workspace = true }
+ arrow = { workspace = true, features = ["prettyprint"] }
arrow-schema = { workspace = true }
+ arrow-array = { workspace = true }
tracing = { workspace = true }
futures = { workspace = true }
tracing-log = { workspace = true }
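The newly enabled prettyprint feature presumably backs the table-formatted batch output in the examples; a minimal sketch of the arrow API it gates (assuming a RecordBatch named `batch` is in scope):

use arrow::util::pretty::print_batches;

// Renders the record batches as an ASCII table on stdout.
print_batches(&[batch])?;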
129 changes: 129 additions & 0 deletions examples/examples/udf_example.rs
@@ -0,0 +1,129 @@
use std::any::Any;
use std::sync::Arc;
use std::time::Duration;

use datafusion::common::cast::as_float64_array;
use datafusion::functions_aggregate::average::avg;
use datafusion::functions_aggregate::count::count;
use datafusion::functions_aggregate::expr_fn::{max, min};
use datafusion::logical_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion::logical_expr::Volatility;
use datafusion::logical_expr::{col, lit};
use datafusion::logical_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature};

use denormalized::datasource::kafka::{ConnectionOpts, KafkaTopicBuilder};
use denormalized::prelude::*;

use arrow::array::{ArrayRef, Float64Array};
use arrow::datatypes::DataType;

use denormalized_examples::get_sample_json;

#[tokio::main]
async fn main() -> Result<()> {
env_logger::builder()
.filter_level(log::LevelFilter::Info)
.init();

let sample_event = get_sample_json();

let bootstrap_servers = String::from("localhost:9092");

let ctx = Context::new()?;
let mut topic_builder = KafkaTopicBuilder::new(bootstrap_servers.clone());

let source_topic = topic_builder
.with_topic(String::from("temperature"))
.infer_schema_from_json(sample_event.as_str())?
.with_encoding("json")?
.with_timestamp(String::from("occurred_at_ms"), TimestampUnit::Int64Millis)
.build_reader(ConnectionOpts::from([
("auto.offset.reset".to_string(), "latest".to_string()),
("group.id".to_string(), "sample_pipeline".to_string()),
]))
.await?;

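// Wrap the implementation in a ScalarUDF so it can be used inside expressions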
let sample_udf = ScalarUDF::from(SampleUdf::new());

let ds = ctx
.from_topic(source_topic)
.await?
.window(
vec![col("sensor_name")],
vec![
count(col("reading")).alias("count"),
min(col("reading")).alias("min"),
max(col("reading")).alias("max"),
avg(col("reading")).alias("average"),
],
Duration::from_millis(1_000),
None,
)?
.filter(col("max").gt(lit(113)))?
.with_column("sample", sample_udf.call(vec![col("max")]))?;

ds.clone().print_physical_plan().await?;
ds.clone().print_stream().await?;

Ok(())
}

#[derive(Debug, Clone)]
struct SampleUdf {
signature: Signature,
}

impl SampleUdf {
fn new() -> Self {
Self {
signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
}
}
}

impl ScalarUDFImpl for SampleUdf {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"sample_udf"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
Ok(DataType::Float64)
}

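// Invoked once per record batch: shifts every non-null input value up by 20.0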
fn invoke(&self, args: &[ColumnarValue]) -> datafusion::error::Result<ColumnarValue> {
assert_eq!(args.len(), 1);
let value = &args[0];
assert_eq!(value.data_type(), DataType::Float64);

let args = ColumnarValue::values_to_arrays(args)?;
let value = as_float64_array(&args[0]).expect("cast failed");

let array = value
.iter()
.map(|v| match v {
Some(f) => {
let value = f + 20_f64;
Some(value)
}
_ => None,
})
.collect::<Float64Array>();

Ok(ColumnarValue::from(Arc::new(array) as ArrayRef))
}

fn output_ordering(
&self,
input: &[ExprProperties],
) -> datafusion::error::Result<SortProperties> {
Ok(input[0].sort_properties)
}
}
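Since invoke operates on plain ColumnarValues, the kernel can also be exercised off-stream. A minimal sketch of such a check (hypothetical, not part of this commit):

// Feed the UDF an in-memory column and inspect the shifted output.
let input: ArrayRef = Arc::new(Float64Array::from(vec![Some(113.5), None]));
let result = SampleUdf::new().invoke(&[ColumnarValue::Array(input)])?;
// `result` now holds [133.5, null]: each non-null value raised by 20.0.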
8 changes: 4 additions & 4 deletions py-denormalized/src/datastream.rs
@@ -136,7 +136,7 @@ impl PyDataStream {
// Use u64 for durations since using PyDelta type requires non-Py_LIMITED_API to be
// enabled
let window_length_duration = Duration::from_millis(window_length_millis);
- let window_slide_duration = slide_millis.map(|d| Duration::from_millis(d));
+ let window_slide_duration = slide_millis.map(Duration::from_millis);

let ds = self.ds.as_ref().clone().window(
groups,
@@ -183,7 +183,7 @@ impl PyDataStream {
let fut: JoinHandle<denormalized::common::error::Result<()>> =
rt.spawn(async move { ds.print_stream().await });

- let _ = wait_for_future(py, fut).map_err(py_denormalized_err)??;
+ wait_for_future(py, fut).map_err(py_denormalized_err)??;

Ok(())
}
@@ -194,7 +194,7 @@ impl PyDataStream {

let fut: JoinHandle<denormalized::common::error::Result<()>> =
rt.spawn(async move { ds.sink_kafka(bootstrap_servers, topic).await });
- let _ = wait_for_future(py, fut).map_err(py_denormalized_err)??;
+ wait_for_future(py, fut).map_err(py_denormalized_err)??;

Ok(())
}
@@ -234,7 +234,7 @@ impl PyDataStream {
});

// rt.block_on(fut).map_err(py_denormalized_err)??;
- let _ = wait_for_future(py, fut).map_err(py_denormalized_err)??;
+ wait_for_future(py, fut).map_err(py_denormalized_err)??;

Ok(())
}
