Commit bf12e6d

add From methods for context/pycontext
emgeee committed Sep 17, 2024
1 parent 26b0b20 · commit bf12e6d
Showing 5 changed files with 85 additions and 56 deletions.
66 changes: 33 additions & 33 deletions crates/core/src/datastream.rs
@@ -131,39 +131,6 @@ impl DataStream {
         })
     }
 
-    /// execute the stream and print the results to stdout.
-    /// Mainly used for development and debugging
-    pub async fn print_stream(&self) -> Result<()> {
-        if orchestrator::SHOULD_CHECKPOINT {
-            let plan = self.df.as_ref().clone().create_physical_plan().await?;
-            let node_ids = extract_node_ids_and_partitions(&plan);
-            let max_buffer_size = node_ids.iter().map(|x| x.1).sum::<usize>();
-            let mut orchestrator = Orchestrator::default();
-            SpawnedTask::spawn_blocking(move || orchestrator.run(max_buffer_size));
-        }
-
-        let mut stream: SendableRecordBatchStream =
-            self.df.as_ref().clone().execute_stream().await?;
-        loop {
-            match stream.next().await.transpose() {
-                Ok(Some(batch)) => {
-                    println!(
-                        "{}",
-                        datafusion::common::arrow::util::pretty::pretty_format_batches(&[batch])
-                            .unwrap()
-                    );
-                }
-                Ok(None) => {
-                    log::warn!("No RecordBatch in stream");
-                }
-                Err(err) => {
-                    log::error!("Error reading stream: {:?}", err);
-                    return Err(err.into());
-                }
-            }
-        }
-    }
-
     /// Return the schema of DataFrame that backs the DataStream
     pub fn schema(&self) -> &DFSchema {
         self.df.schema()
Expand Down Expand Up @@ -198,6 +165,39 @@ impl DataStream {
})
}

/// execute the stream and print the results to stdout.
/// Mainly used for development and debugging
pub async fn print_stream(&self) -> Result<()> {
if orchestrator::SHOULD_CHECKPOINT {
let plan = self.df.as_ref().clone().create_physical_plan().await?;
let node_ids = extract_node_ids_and_partitions(&plan);
let max_buffer_size = node_ids.iter().map(|x| x.1).sum::<usize>();
let mut orchestrator = Orchestrator::default();
SpawnedTask::spawn_blocking(move || orchestrator.run(max_buffer_size));
}

let mut stream: SendableRecordBatchStream =
self.df.as_ref().clone().execute_stream().await?;
loop {
match stream.next().await.transpose() {
Ok(Some(batch)) => {
println!(
"{}",
datafusion::common::arrow::util::pretty::pretty_format_batches(&[batch])
.unwrap()
);
}
Ok(None) => {
log::warn!("No RecordBatch in stream");
}
Err(err) => {
log::error!("Error reading stream: {:?}", err);
return Err(err.into());
}
}
}
}

/// execute the stream and write the results to a give kafka topic
pub async fn sink_kafka(self, bootstrap_servers: String, topic: String) -> Result<()> {
let processed_schema = Arc::new(datafusion::common::arrow::datatypes::Schema::from(
Expand Down
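
Aside: the `stream.next().await.transpose()` idiom above turns the `Option<Result<RecordBatch>>` yielded by the stream into a `Result<Option<RecordBatch>>`, so the error case can be matched first. A minimal sketch of the same loop shape on a plain `futures` stream — stand-in types only, assuming the `futures` and `tokio` crates, not code from this commit:

```rust
use futures::{stream, StreamExt};

#[tokio::main]
async fn main() -> Result<(), String> {
    // Stand-in for a SendableRecordBatchStream: a stream of Result items.
    let mut stream = stream::iter(vec![Ok(1), Ok(2), Err("boom".to_string())]);

    loop {
        // next() yields Option<Result<T, E>>; transpose() flips it to
        // Result<Option<T>, E> so errors surface in a single match arm.
        match stream.next().await.transpose() {
            Ok(Some(item)) => println!("{item}"),
            // None means the stream is exhausted; an unbounded streaming
            // source would normally never reach this arm.
            Ok(None) => break Ok(()),
            Err(err) => break Err(err),
        }
    }
}
```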
4 changes: 3 additions & 1 deletion py-denormalized/python/examples/stream_aggregate.py
@@ -1,6 +1,8 @@
 """stream_aggregate example."""
+import json
+
 import pyarrow as pa
-from denormalized import Context, DataStream
+from denormalized import Context
 from denormalized._internal import expr
 from denormalized._internal import functions as f
 
26 changes: 26 additions & 0 deletions py-denormalized/src/context.rs
@@ -21,6 +21,32 @@ pub struct PyContext {
 
 impl PyContext {}
 
+impl From<Context> for PyContext {
+    fn from(context: Context) -> Self {
+        PyContext {
+            context: Arc::new(context),
+        }
+    }
+}
+
+impl From<PyContext> for Context {
+    fn from(py_context: PyContext) -> Self {
+        Arc::try_unwrap(py_context.context).unwrap_or_else(|arc| (*arc).clone())
+    }
+}
+
+impl From<Arc<Context>> for PyContext {
+    fn from(context: Arc<Context>) -> Self {
+        PyContext { context }
+    }
+}
+
+impl From<PyContext> for Arc<Context> {
+    fn from(py_context: PyContext) -> Self {
+        py_context.context
+    }
+}
+
 #[pymethods]
 impl PyContext {
     /// creates a new PyDataFrame
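
Taken together, these four impls let callers move between `Context`, `Arc<Context>`, and `PyContext` with plain `.into()` calls. The `PyContext -> Context` direction is the subtle one: `Arc::try_unwrap` moves the inner value out when the `Arc` is uniquely held, and otherwise falls back to cloning it (which presumes `Context: Clone`). A self-contained sketch with stand-in types — not this crate's actual definitions:

```rust
use std::sync::Arc;

#[derive(Clone, Default)]
struct Context; // stand-in for denormalized's Context

struct PyContext {
    context: Arc<Context>,
}

impl From<Context> for PyContext {
    fn from(context: Context) -> Self {
        PyContext { context: Arc::new(context) }
    }
}

impl From<PyContext> for Context {
    fn from(py_context: PyContext) -> Self {
        // Move the inner value out if this is the only handle; otherwise clone.
        Arc::try_unwrap(py_context.context).unwrap_or_else(|arc| (*arc).clone())
    }
}

impl From<Arc<Context>> for PyContext {
    fn from(context: Arc<Context>) -> Self {
        PyContext { context }
    }
}

impl From<PyContext> for Arc<Context> {
    fn from(py_context: PyContext) -> Self {
        py_context.context
    }
}

fn main() {
    let py_ctx: PyContext = Context::default().into(); // Context -> PyContext
    let shared: Arc<Context> = py_ctx.into(); // PyContext -> Arc<Context>
    let py_again: PyContext = Arc::clone(&shared).into(); // Arc<Context> -> PyContext
    // `shared` still holds a second Arc handle here, so try_unwrap fails
    // and the fallback clones the inner Context instead of moving it.
    let _owned: Context = py_again.into(); // PyContext -> Context
}
```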
35 changes: 13 additions & 22 deletions py-denormalized/src/datastream.rs
@@ -13,7 +13,7 @@ use datafusion_python::expr::{join::PyJoinType, PyExpr};
 use denormalized::datastream::DataStream;
 
 use crate::errors::py_denormalized_err;
-use crate::utils::{get_tokio_runtime, wait_for_future};
+use crate::utils::{get_tokio_runtime, wait_for_future, python_print};
 
 #[pyclass(name = "PyDataStream", module = "denormalized", subclass)]
 #[derive(Clone)]
@@ -144,18 +144,6 @@ impl PyDataStream {
         Ok(Self::new(ds))
     }
 
-    pub fn print_stream(&self, py: Python) -> PyResult<()> {
-        // Implement the method using the original Rust code
-        let ds = self.ds.clone();
-        let rt = &get_tokio_runtime(py).0;
-        let fut: JoinHandle<denormalized::common::error::Result<()>> =
-            rt.spawn(async move { ds.print_stream().await });
-
-        let _ = wait_for_future(py, fut).map_err(py_denormalized_err)??;
-
-        Ok(())
-    }
-
     pub fn print_schema(&self, py: Python) -> PyResult<Self> {
         let schema = format!("{}", self.ds.schema());
         python_print(py, schema)?;
@@ -187,17 +175,20 @@ impl PyDataStream {
         Ok(self.to_owned())
     }
 
+    pub fn print_stream(&self, py: Python) -> PyResult<()> {
+        // Implement the method using the original Rust code
+        let ds = self.ds.clone();
+        let rt = &get_tokio_runtime(py).0;
+        let fut: JoinHandle<denormalized::common::error::Result<()>> =
+            rt.spawn(async move { ds.print_stream().await });
+
+        let _ = wait_for_future(py, fut).map_err(py_denormalized_err)??;
+
+        Ok(())
+    }
+
     pub fn sink_kafka(&self, _bootstrap_servers: String, _topic: String) -> PyResult<()> {
         // Implement the method using the original Rust code
         todo!()
     }
 }
-
-
-fn python_print(py: Python, str: String) -> PyResult<()> {
-    // Import the Python 'builtins' module to access the print function
-    // Note that println! does not print to the Python debug console and is not visible in notebooks for instance
-    let print = py.import_bound("builtins")?.getattr("print")?;
-    print.call1((str,))?;
-    Ok(())
-}
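
The relocated `print_stream` binding above illustrates the crate's spawn-then-block pattern for driving Rust futures from Python: clone the `DataStream`, spawn the future onto a tokio runtime so it runs on worker threads, then block the calling Python thread on the `JoinHandle` with the GIL released. A simplified sketch of that shape — the helper is a stand-in for `wait_for_future` in utils.rs, and unlike `get_tokio_runtime` it builds a fresh runtime just to stay self-contained:

```rust
use std::future::Future;

use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use tokio::runtime::Runtime;

// Block on a future while releasing the GIL so other Python threads keep
// running; simplified stand-in for utils::wait_for_future.
fn wait_for_future<F>(py: Python, runtime: &Runtime, fut: F) -> F::Output
where
    F: Future + Send,
    F::Output: Send,
{
    py.allow_threads(|| runtime.block_on(fut))
}

// Hypothetical binding, not part of this commit.
#[pyfunction]
fn run_job(py: Python) -> PyResult<()> {
    let rt = Runtime::new().map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
    // Spawn first so the work starts on the runtime's worker threads...
    let handle = rt.spawn(async {
        // e.g. ds.print_stream().await
    });
    // ...then block (GIL released) until the JoinHandle resolves.
    wait_for_future(py, &rt, handle).map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
    Ok(())
}
```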
10 changes: 10 additions & 0 deletions py-denormalized/src/utils.rs
@@ -29,6 +29,16 @@ where
     py.allow_threads(|| runtime.block_on(f))
 }
 
+
+/// Print a string to the python console
+pub fn python_print(py: Python, str: String) -> PyResult<()> {
+    // Import the Python 'builtins' module to access the print function
+    // Note that println! does not print to the Python debug console and is not visible in notebooks for instance
+    let print = py.import_bound("builtins")?.getattr("print")?;
+    print.call1((str,))?;
+    Ok(())
+}
+
 // pub(crate) fn parse_volatility(value: &str) -> Result<Volatility, DataFusionError> {
 //     Ok(match value {
 //         "immutable" => Volatility::Immutable,
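
As the comment notes, `println!` writes to the process's stdout, which notebook kernels and some debug consoles never capture; routing output through Python's own `print` keeps it visible everywhere Python output goes. A hypothetical caller — the `greet` binding is illustrative, not part of this commit, and the helper is copied inline so the sketch stands alone:

```rust
use pyo3::prelude::*;

/// Print a string to the Python console (copy of the helper above).
fn python_print(py: Python, s: String) -> PyResult<()> {
    let print = py.import_bound("builtins")?.getattr("print")?;
    print.call1((s,))?;
    Ok(())
}

/// Hypothetical binding: its output lands wherever Python's print goes,
/// including notebook cells.
#[pyfunction]
fn greet(py: Python) -> PyResult<()> {
    python_print(py, "hello from rust".to_string())
}
```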
