Implement stream join in python #51

Merged: 8 commits, Nov 7, 2024
2 changes: 1 addition & 1 deletion crates/core/src/datasource/kafka/kafka_stream_read.rs
@@ -253,7 +253,7 @@ impl PartitionStream for KafkaStreamRead
max_timestamp,
offsets_read,
};
-            let _ = state_backend
+            state_backend
.as_ref()
.put(channel_tag.as_bytes().to_vec(), off.to_bytes().unwrap());
debug!("checkpointed offsets {:?}", off);
@@ -168,7 +168,7 @@ impl GroupedWindowAggStream
context,
epoch: 0,
partition,
-            channel_tag: channel_tag,
+            channel_tag,
receiver,
state_backend,
};
@@ -184,7 +184,7 @@ impl GroupedWindowAggStream
.collect();
let _ = stream.ensure_window_frames_for_ranges(&ranges);
state.frames.iter().for_each(|f| {
-            let _ = stream.update_accumulators_for_frame(f.window_start_time, &f);
+            let _ = stream.update_accumulators_for_frame(f.window_start_time, f);
});
let state_watermark = state.watermark.unwrap();
stream.process_watermark(RecordBatchWatermark {
@@ -387,7 +387,7 @@ impl GroupedWindowAggStream

let watermark = {
let watermark_lock = self.latest_watermark.lock().unwrap();
-            watermark_lock.clone()
+            *watermark_lock
};

let checkpointed_state = CheckpointedGroupedWindowAggStream {
2 changes: 1 addition & 1 deletion examples/examples/simple_aggregation.rs
@@ -32,7 +32,7 @@ async fn main() -> Result<()> {
]))
.await?;

-    let _ctx = Context::new()?
+    Context::new()?
.with_slatedb_backend(String::from("/tmp/checkpoints/simple-agg-checkpoint-1"))
.await
.from_topic(source_topic)
3 changes: 3 additions & 0 deletions py-denormalized/.gitignore
@@ -70,3 +70,6 @@ docs/_build/

# Pyenv
.python-version

.ipynb_checkpoints/
Untitled.ipynb
1 change: 1 addition & 0 deletions py-denormalized/pyproject.toml
@@ -16,6 +16,7 @@ dependencies = [
[project.optional-dependencies]
tests = ["pytest"]
feast = ["feast"]
dev = []

[tool.maturin]
python-source = "python"
5 changes: 5 additions & 0 deletions py-denormalized/python/denormalized/data_stream.py
@@ -73,6 +73,11 @@ def with_column(self, name: str, predicate: Expr) -> "DataStream":
DataStream: A new DataStream with the additional column.
"""
return DataStream(self.ds.with_column(name, to_internal_expr(predicate)))

    def drop_columns(self, columns: list[str]) -> "DataStream":
        """Drops the given columns from the DataStream.
        Returns a new DataStream without those columns."""
        return DataStream(self.ds.drop_columns(columns))

def join_on(
self, right: "DataStream", join_type: str, on_exprs: list[Expr]
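For reference, a minimal sketch of the new Python-side call; the ds variable and column names here are illustrative, and the stream_join.py example below exercises the same method:

# Add a renamed copy of the column, then drop the original;
# drop_columns returns a new DataStream rather than mutating in place.
ds = ds.with_column("humidity_sensor", col("sensor_name")).drop_columns(["sensor_name"])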
72 changes: 72 additions & 0 deletions py-denormalized/python/examples/stream_join.py
@@ -0,0 +1,72 @@
"""stream_aggregate example."""

import json
import signal
import sys
import pprint as pp

from denormalized import Context
from denormalized.datafusion import col, expr
from denormalized.datafusion import functions as f
from denormalized.datafusion import lit


def signal_handler(sig, frame):
print("You pressed Ctrl+C!")
sys.exit(0)


signal.signal(signal.SIGINT, signal_handler)

bootstrap_server = "localhost:9092"

sample_event = {
"occurred_at_ms": 100,
"sensor_name": "foo",
"reading": 0.0,
}


def print_batch(rb):
pp.pprint(rb.to_pydict())


ctx = Context()
temperature_ds = ctx.from_topic(
"temperature", json.dumps(sample_event), bootstrap_server
)

humidity_ds = (
ctx.from_topic("humidity", json.dumps(sample_event), bootstrap_server)
.with_column("humidity_sensor", col("sensor_name"))
.drop_columns(["sensor_name"])
.window(
[col("humidity_sensor")],
[
f.count(col("reading")).alias("avg_humidity"),
],
4000,
None,
)
.with_column("humidity_window_start_time", col("window_start_time"))
.with_column("humidity_window_end_time", col("window_end_time"))
.drop_columns(["window_start_time", "window_end_time"])
)

joined_ds = (
temperature_ds.window(
[col("sensor_name")],
[
f.avg(col("reading")).alias("avg_temperature"),
],
4000,
None,
)
.join(
humidity_ds,
"inner",
["sensor_name", "window_start_time"],
["humidity_sensor", "humidity_window_start_time"],
)
.sink(print_batch)
)
41 changes: 31 additions & 10 deletions py-denormalized/src/datastream.rs
@@ -9,13 +9,14 @@ use tokio::task::JoinHandle;
use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::pyarrow::ToPyArrow;
use datafusion::common::JoinType;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::display::DisplayableExecutionPlan;
use datafusion_python::expr::{join::PyJoinType, PyExpr};

use denormalized::datastream::DataStream;

-use crate::errors::{py_denormalized_err, Result};
+use crate::errors::{py_denormalized_err, DenormalizedError, Result};
use crate::utils::{get_tokio_runtime, python_print, wait_for_future};

#[pyclass(name = "PyDataStream", module = "denormalized", subclass)]
@@ -86,6 +87,13 @@ impl PyDataStream
Ok(Self::new(ds))
}

pub fn drop_columns(&self, columns: Vec<String>) -> Result<Self> {
let columns_ref: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();

let ds = self.ds.as_ref().clone().drop_columns(&columns_ref)?;
Ok(Self::new(ds))
}

pub fn join_on(
&self,
_right: PyDataStream,
@@ -95,29 +103,42 @@
todo!()
}

-    #[pyo3(signature = (right, join_type, left_cols, right_cols, filter=None))]
+    #[pyo3(signature = (right, how, left_cols, right_cols, filter=None))]
pub fn join(
&self,
right: PyDataStream,
-        join_type: PyJoinType,
+        how: &str,
left_cols: Vec<String>,
right_cols: Vec<String>,
filter: Option<PyExpr>,
) -> PyResult<Self> {
let right_ds = right.ds.as_ref().clone();

let join_type = match how {
"inner" => JoinType::Inner,
"left" => JoinType::Left,
"right" => JoinType::Right,
"full" => JoinType::Full,
"semi" => JoinType::LeftSemi,
"anti" => JoinType::LeftAnti,
how => {
return Err(DenormalizedError::Common(format!(
"The join type {how} does not exist or is not implemented"
))
.into());
}
};

let filter = filter.map(|f| f.into());

let left_cols = left_cols.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
let right_cols = right_cols.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();

-        let ds = self.ds.as_ref().clone().join(
-            right_ds,
-            join_type.into(),
-            &left_cols,
-            &right_cols,
-            filter,
-        )?;
+        let ds =
+            self.ds
+                .as_ref()
+                .clone()
+                .join(right_ds, join_type, &left_cols, &right_cols, filter)?;
Ok(Self::new(ds))
}

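With this change the Python join call passes the join type as a plain string rather than a PyJoinType. A minimal sketch of the accepted values, using hypothetical left_ds and right_ds streams that both carry a sensor_name column:

# Strings accepted by the new join signature, per the match above:
#   "inner", "left", "right", "full", "semi", "anti"
# (mapped to DataFusion's Inner/Left/Right/Full/LeftSemi/LeftAnti);
# any other value is rejected with a DenormalizedError.
joined = left_ds.join(
    right_ds,
    "left",            # join type as a string
    ["sensor_name"],   # join keys from the left stream
    ["sensor_name"],   # join keys from the right stream
)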