From 3c36be7da5f8dc36a9b9793169838995c6267bb2 Mon Sep 17 00:00:00 2001 From: Matt Green Date: Thu, 19 Sep 2024 20:06:00 -0700 Subject: [PATCH] make sink_python work --- .../python/denormalized/datastream.py | 4 ++-- .../python/examples/stream_aggregate.py | 6 ++++- py-denormalized/src/datastream.rs | 23 ++++++++++--------- py-denormalized/src/errors.rs | 9 ++++++++ 4 files changed, 28 insertions(+), 14 deletions(-) diff --git a/py-denormalized/python/denormalized/datastream.py b/py-denormalized/python/denormalized/datastream.py index c9b12f7..0919902 100644 --- a/py-denormalized/python/denormalized/datastream.py +++ b/py-denormalized/python/denormalized/datastream.py @@ -160,6 +160,6 @@ def sink_kafka(self, bootstrap_servers: str, topic: str) -> None: """ self.ds.sink_kafka(bootstrap_servers, topic) - def sink_python(self) -> None: + def sink_python(self, func) -> None: """Sink the DataStream to a Python function.""" - self.ds.sink_python() + self.ds.sink_python(func) diff --git a/py-denormalized/python/examples/stream_aggregate.py b/py-denormalized/python/examples/stream_aggregate.py index 0b9b7dc..ccd8026 100644 --- a/py-denormalized/python/examples/stream_aggregate.py +++ b/py-denormalized/python/examples/stream_aggregate.py @@ -23,6 +23,10 @@ def signal_handler(sig, frame): "reading": 0.0, } +def sample_func(rb): + print("hello world2!") + print(len(rb)) + ctx = Context() ds = ctx.from_topic("temperature", json.dumps(sample_event), bootstrap_server) @@ -41,4 +45,4 @@ def signal_handler(sig, frame): None, ).filter( expr.Expr.column("max") > (expr.Expr.literal(pa.scalar(113))) -).sink_python() +).sink_python(sample_func) diff --git a/py-denormalized/src/datastream.rs b/py-denormalized/src/datastream.rs index 34b1e9c..6e56e70 100644 --- a/py-denormalized/src/datastream.rs +++ b/py-denormalized/src/datastream.rs @@ -8,13 +8,14 @@ use tokio::task::JoinHandle; use datafusion::arrow::datatypes::Schema; use datafusion::arrow::pyarrow::PyArrowType; +use datafusion::arrow::pyarrow::ToPyArrow; use datafusion::execution::SendableRecordBatchStream; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion_python::expr::{join::PyJoinType, PyExpr}; use denormalized::datastream::DataStream; -use crate::errors::py_denormalized_err; +use crate::errors::{py_denormalized_err, Result}; use crate::utils::{get_tokio_runtime, python_print, wait_for_future}; #[pyclass(name = "PyDataStream", module = "denormalized", subclass)] @@ -198,11 +199,12 @@ impl PyDataStream { Ok(()) } - pub fn sink_python(&self, py: Python) -> PyResult<()> { + /// Execute the dataframe and pass the resulting recordbatch to a python function + pub fn sink_python(&self, func: PyObject, py: Python) -> PyResult<()> { let ds = self.ds.as_ref().clone(); let rt = &get_tokio_runtime(py).0; - let fut: JoinHandle> = rt.spawn(async move { + let fut: JoinHandle> = rt.spawn(async move { let mut stream: SendableRecordBatchStream = ds.df.as_ref().clone().execute_stream().await?; @@ -211,15 +213,13 @@ impl PyDataStream { _ = tokio::signal::ctrl_c() => break, // Explicitly check for ctrl-c and exit // loop if it occurs message = stream.next() => { - match message.transpose(){ + match message.transpose() { Ok(Some(batch)) => { - println!( - "{}", - datafusion::common::arrow::util::pretty::pretty_format_batches(&[ - batch - ]) - .unwrap() - ); + Python::with_gil(|py| -> PyResult<()> { + let batch = batch.clone().to_pyarrow(py)?; + func.call1(py, (batch,))?; + Ok(()) + })?; }, Ok(None) => {}, Err(err) => { @@ -233,6 +233,7 @@ impl PyDataStream { Ok(()) }); + // rt.block_on(fut).map_err(py_denormalized_err)??; let _ = wait_for_future(py, fut).map_err(py_denormalized_err)??; Ok(()) diff --git a/py-denormalized/src/errors.rs b/py-denormalized/src/errors.rs index ece47a8..0cb514a 100644 --- a/py-denormalized/src/errors.rs +++ b/py-denormalized/src/errors.rs @@ -5,6 +5,7 @@ use std::error::Error; use std::fmt::Debug; use datafusion::arrow::error::ArrowError; +use datafusion::error::DataFusionError; use denormalized::common::error::DenormalizedError as InnerDenormalizedError; use pyo3::{exceptions::PyException, PyErr}; @@ -16,6 +17,7 @@ pub enum DenormalizedError { ArrowError(ArrowError), Common(String), PythonError(PyErr), + DataFusionError(DataFusionError), } impl fmt::Display for DenormalizedError { @@ -25,6 +27,7 @@ impl fmt::Display for DenormalizedError { DenormalizedError::ArrowError(e) => write!(f, "Arrow error: {e:?}"), DenormalizedError::PythonError(e) => write!(f, "Python error {e:?}"), DenormalizedError::Common(e) => write!(f, "{e}"), + DenormalizedError::DataFusionError(e) => write!(f, "DataFusionError{e}"), } } } @@ -56,6 +59,12 @@ impl From for PyErr { } } +impl From for DenormalizedError { + fn from(err: DataFusionError) -> DenormalizedError { + DenormalizedError::DataFusionError(err) + } +} + impl Error for DenormalizedError {} pub fn py_type_err(e: impl Debug) -> PyErr {