Commit

simple example works
emgeee committed Sep 16, 2024
1 parent cc89d37 commit 18d36d0
Showing 9 changed files with 193 additions and 33 deletions.
2 changes: 1 addition & 1 deletion examples/examples/simple_aggregation.rs
@@ -15,7 +15,7 @@ use denormalized_examples::get_sample_json;
 #[tokio::main]
 async fn main() -> Result<()> {
     env_logger::builder()
-        .filter_level(log::LevelFilter::Debug)
+        .filter_level(log::LevelFilter::Info)
         .init();
 
     let sample_event = get_sample_json();
158 changes: 153 additions & 5 deletions py-denormalized/python/denormalized/datastream.py
@@ -1,21 +1,169 @@
-from denormalized._internal import PyDataStream
-
+import pyarrow as pa
+from datafusion import Expr
+from denormalized._internal import PyDataStream
+from denormalized._internal import expr as internal_expr
 
 
 class DataStream:
-    """DataStream."""
+    """Represents a stream of data that can be manipulated using various operations."""
 
     def __init__(self, ds: PyDataStream) -> None:
-        """__init__."""
+        """Initialize a new DataStream object.
+
+        Args:
+            ds (PyDataStream): The underlying PyDataStream object.
+        """
         self.ds = ds
 
     def __repr__(self):
+        """Return a string representation of the DataStream object.
+
+        Returns:
+            str: A string representation of the DataStream.
+        """
        return self.ds.__repr__()
 
     def __str__(self):
+        """Return a string description of the DataStream object.
+
+        Returns:
+            str: A string description of the DataStream.
+        """
         return self.ds.__str__()
 
     def schema(self) -> pa.Schema:
-        """Schema."""
+        """Get the schema of the DataStream.
+
+        Returns:
+            pa.Schema: The PyArrow schema of the DataStream.
+        """
         return self.ds.schema()
 
+    def print_expr(self, expr: Expr):
+        """Print the given expression.
+
+        Args:
+            expr (Expr): The expression to print.
+        """
+        self.ds.print_expr(expr)
+
+    def select(self, expr_list: list[Expr]) -> "DataStream":
+        """Select specific columns or expressions from the DataStream.
+
+        Args:
+            expr_list (list[Expr]): A list of expressions to select.
+
+        Returns:
+            DataStream: A new DataStream with the selected columns/expressions.
+        """
+        return DataStream(self.ds.select(expr_list))
+
+    def filter(self, predicate: Expr) -> "DataStream":
+        """Filter the DataStream based on a predicate.
+
+        Args:
+            predicate (Expr): The filter predicate.
+
+        Returns:
+            DataStream: A new DataStream with the filter applied.
+        """
+        return DataStream(self.ds.filter(predicate))
+
+    def join_on(
+        self, right: "DataStream", join_type: str, on_exprs: list[Expr]
+    ) -> "DataStream":
+        """Join this DataStream with another one based on join expressions.
+
+        Args:
+            right (DataStream): The right DataStream to join with.
+            join_type (str): The type of join to perform.
+            on_exprs (list[Expr]): The expressions to join on.
+
+        Returns:
+            DataStream: A new DataStream resulting from the join operation.
+        """
+        return DataStream(self.ds.join_on(right.ds, join_type, on_exprs))
+
+    def join(
+        self,
+        right: "DataStream",
+        join_type: str,
+        left_cols: list[str],
+        right_cols: list[str],
+        filter: Expr = None,
+    ) -> "DataStream":
+        """Join this DataStream with another one based on column names.
+
+        Args:
+            right (DataStream): The right DataStream to join with.
+            join_type (str): The type of join to perform.
+            left_cols (list[str]): The columns from the left DataStream to join on.
+            right_cols (list[str]): The columns from the right DataStream to join on.
+            filter (Expr, optional): An additional filter to apply to the join.
+
+        Returns:
+            DataStream: A new DataStream resulting from the join operation.
+        """
+        return DataStream(
+            self.ds.join(right.ds, join_type, left_cols, right_cols, filter)
+        )
+
+    def window(
+        self,
+        group_expr: list[Expr],
+        aggr_expr: list[Expr],
+        window_length_millis: int,
+        slide_millis: int = None,
+    ) -> "DataStream":
+        """Apply a windowing operation to the DataStream.
+
+        Args:
+            group_expr (list[Expr]): The expressions to group by.
+            aggr_expr (list[Expr]): The aggregation expressions to apply.
+            window_length_millis (int): The length of the window in milliseconds.
+            slide_millis (int, optional): The slide interval of the window in
+                milliseconds.
+
+        Returns:
+            DataStream: A new DataStream with the windowing operation applied.
+        """
+        return DataStream(
+            self.ds.window(group_expr, aggr_expr, window_length_millis, slide_millis)
+        )
+
+    def print_stream(self) -> None:
+        """Print the contents of the DataStream."""
+        self.ds.print_stream()
+
+    def print_schema(self) -> "DataStream":
+        """Print the schema of the DataStream.
+
+        Returns:
+            DataStream: This DataStream object for method chaining.
+        """
+        return DataStream(self.ds.print_schema())
+
+    def print_plan(self) -> "DataStream":
+        """Print the execution plan of the DataStream.
+
+        Returns:
+            DataStream: This DataStream object for method chaining.
+        """
+        return DataStream(self.ds.print_plan())
+
+    def print_physical_plan(self) -> "DataStream":
+        """Print the physical execution plan of the DataStream.
+
+        Returns:
+            DataStream: This DataStream object for method chaining.
+        """
+        return DataStream(self.ds.print_physical_plan())
+
+    def sink_kafka(self, bootstrap_servers: str, topic: str) -> None:
+        """Sink the DataStream to a Kafka topic.
+
+        Args:
+            bootstrap_servers (str): The Kafka bootstrap servers.
+            topic (str): The Kafka topic to sink the data to.
+        """
+        self.ds.sink_kafka(bootstrap_servers, topic)
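
Taken together, these wrappers form a small fluent API: each transform returns a new DataStream around the PyDataStream the Rust side hands back, so operations chain. A minimal sketch of that chaining, reusing the topic, broker, and sample payload from stream_aggregate.py below (the payload fields are a guess at the elided dict, and select is assumed to accept the same internal Expr objects that filter and window do):

import json

import pyarrow as pa
from denormalized import Context
from denormalized._internal import expr

# Hypothetical payload; stream_aggregate.py elides part of the real dict.
sample_event = {"occurred_at_ms": 100, "sensor_name": "foo", "reading": 0.0}

ctx = Context()
ds = ctx.from_topic("temperature", json.dumps(sample_event), "localhost:9092")

# Each call wraps a new PyDataStream, so transforms compose left to right.
ds.filter(
    expr.Expr.column("reading") > expr.Expr.literal(pa.scalar(0.0))
).select(
    [expr.Expr.column("sensor_name"), expr.Expr.column("reading")]
).print_stream()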
30 changes: 18 additions & 12 deletions py-denormalized/python/examples/stream_aggregate.py
@@ -1,7 +1,11 @@
 import json
 
 import pyarrow as pa
 from denormalized import Context, DataStream
 from datafusion import Expr
+from denormalized._internal import expr
+from denormalized._internal import functions as f
+
+# from denormalized._internal import expr as expr_internal
 
 sample_event = {
     "occurred_at_ms": 100,
@@ -10,17 +14,19 @@
 }
 
 ctx = Context()
-print('here')
-print(ctx)
 ds = ctx.from_topic("temperature", json.dumps(sample_event), "localhost:9092")
 
-expr = Expr.literal(4)
-print(expr)
-
-print(ds.schema())
-
-
-# from denormalized._internal import PyContext
-#
-# ctx_internal = PyContext()
-# ctx_internal.foo()
+ds.window(
+    [expr.Expr.column("sensor_name")],
+    [
+        f.count(expr.Expr.column("reading"), distinct=False, filter=None).alias(
+            "count"
+        ),
+        f.min(expr.Expr.column("reading")).alias("min"),
+        f.max(expr.Expr.column("reading")).alias("max"),
+        f.avg(expr.Expr.column("reading")).alias("average"),
+    ],
+    1000,
+    None,
+).filter(expr.Expr.column("max") > (expr.Expr.literal(pa.scalar(113)))).print_stream()
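
The window call above passes slide_millis=None, leaving the slide unset, which presumably yields back-to-back tumbling windows of window_length_millis. A hedged variant with an explicit slide, assuming slide_millis takes milliseconds like the window length (the 250 is illustrative, not from this commit):

# Sliding variant: 1000 ms windows re-evaluated every 250 ms (assumed semantics).
ds.window(
    [expr.Expr.column("sensor_name")],
    [f.avg(expr.Expr.column("reading")).alias("average")],
    1000,
    250,
).print_stream()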
6 changes: 0 additions & 6 deletions py-denormalized/python/tests/test_all.py

This file was deleted.

1 change: 0 additions & 1 deletion py-denormalized/src/context.rs
@@ -36,7 +36,6 @@ impl PyContext {
         Ok("foo wtf".to_string())
     }
 
-
     fn __repr__(&self, _py: Python) -> PyResult<String> {
         Ok("__repr__ PyContext".to_string())
     }
21 changes: 16 additions & 5 deletions py-denormalized/src/datastream.rs
@@ -3,12 +3,16 @@ use pyo3::prelude::*;
 use std::sync::Arc;
 use std::time::Duration;
 
+use tokio::task::JoinHandle;
+
 use datafusion::arrow::datatypes::Schema;
+use datafusion::arrow::pyarrow::PyArrowType;
+use datafusion_python::expr::{join::PyJoinType, PyExpr};
 
 use denormalized::datastream::DataStream;
 
-use datafusion::arrow::pyarrow::PyArrowType;
-use datafusion_python::expr::{join::PyJoinType, PyExpr};
+use crate::errors::py_denormalized_err;
+use crate::utils::{get_tokio_runtime, wait_for_future};
 
 #[pyclass(name = "PyDataStream", module = "denormalized", subclass)]
 #[derive(Clone)]
@@ -119,9 +123,16 @@ impl PyDataStream {
         println!("{:?}", expr);
     }
 
-    pub fn print_stream(&self) -> PyResult<()> {
+    pub fn print_stream(&self, py: Python) -> PyResult<()> {
         // Implement the method using the original Rust code
-        todo!()
+        let ds = self.ds.clone();
+        let rt = &get_tokio_runtime(py).0;
+        let fut: JoinHandle<denormalized::common::error::Result<()>> =
+            rt.spawn(async move { ds.print_stream().await });
+
+        let _ = wait_for_future(py, fut).map_err(py_denormalized_err)??;
+
+        Ok(())
     }
 
     pub fn print_schema(&self) -> PyResult<Self> {
@@ -139,7 +150,7 @@ impl PyDataStream {
         todo!()
     }
 
-    pub fn sink_kafka(&self, bootstrap_servers: String, topic: String) -> PyResult<()> {
+    pub fn sink_kafka(&self, _bootstrap_servers: String, _topic: String) -> PyResult<()> {
         // Implement the method using the original Rust code
         todo!()
     }
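
The rewritten print_stream shows the async-to-Python bridge this crate settles on: clone the underlying DataStream, spawn its future on the shared Tokio runtime obtained via get_tokio_runtime, and block on the resulting JoinHandle with wait_for_future. The two ? operators unwrap in turn the join result and the stream's own denormalized Result, with errors surfaced to Python through py_denormalized_err.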
2 changes: 1 addition & 1 deletion py-denormalized/src/errors.rs
@@ -5,8 +5,8 @@ use std::error::Error;
 use std::fmt::Debug;
 
 use datafusion::arrow::error::ArrowError;
-use pyo3::{exceptions::PyException, PyErr};
 use denormalized::common::error::DenormalizedError as InnerDenormalizedError;
+use pyo3::{exceptions::PyException, PyErr};
 
 pub type Result<T> = std::result::Result<T, DenormalizedError>;
4 changes: 3 additions & 1 deletion py-denormalized/src/lib.rs
@@ -12,9 +12,11 @@ pub(crate) struct TokioRuntime(tokio::runtime::Runtime);
 
 /// A Python module implemented in Rust.
 #[pymodule]
-fn _internal(_py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
+fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<datastream::PyDataStream>()?;
     m.add_class::<context::PyContext>()?;
 
+    datafusion_python::_internal(py, m)?;
+
     Ok(())
 }
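
Forwarding to datafusion_python::_internal is what lets DataFusion's helpers be imported from denormalized's own extension module, which the updated example depends on. A quick smoke test, assuming the forwarded registration exposes the expr and functions submodules the example imports:

from denormalized._internal import expr
from denormalized._internal import functions as f

# Builds a DataFusion aggregate expression through denormalized's module.
avg_reading = f.avg(expr.Expr.column("reading")).alias("average")
print(avg_reading)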
2 changes: 1 addition & 1 deletion py-denormalized/src/utils.rs
@@ -7,7 +7,7 @@ use tokio::runtime::Runtime;
 
 /// Utility to get the Tokio Runtime from Python
 pub(crate) fn get_tokio_runtime(py: Python) -> PyRef<TokioRuntime> {
-    let datafusion = py.import_bound("datafusion._internal").unwrap();
+    let datafusion = py.import_bound("denormalized._internal").unwrap();
     let tmp = datafusion.getattr("runtime").unwrap();
     match tmp.extract::<PyRef<TokioRuntime>>() {
         Ok(runtime) => runtime,
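
This one-line fix keeps the runtime lookup inside the package: get_tokio_runtime previously reached into datafusion._internal for the runtime attribute, and now that lib.rs forwards registration to datafusion_python::_internal, the same attribute is presumably exposed on denormalized._internal, so the wrapper no longer depends on importing the datafusion package by name.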
