
Add example python udaf #53

Merged Nov 13, 2024 · 11 commits
.gitignore (2 additions, 0 deletions)
@@ -1,3 +1,5 @@
 /target
 .vscode
 .DS_Store
+.ipynb_checkpoints/
+Untitled.ipynb
crates/core/src/logical_plan/streaming_window.rs (1 addition, 1 deletion)
@@ -83,7 +83,7 @@ pub struct StreamingWindowSchema {
 impl StreamingWindowSchema {
     pub fn try_new(aggr_expr: Aggregate) -> Result<Self> {
         let inner_schema = aggr_expr.schema.inner().clone();
-        let fields = inner_schema.flattened_fields().to_owned();
+        let fields = inner_schema.fields();

         let mut builder = SchemaBuilder::new();

crates/core/src/physical_plan/continuous/mod.rs (1 addition, 1 deletion)
@@ -40,7 +40,7 @@ pub(crate) fn create_group_accumulator(
 }

 fn add_window_columns_to_schema(schema: SchemaRef) -> Schema {
-    let fields = schema.flattened_fields().to_owned();
+    let fields = schema.fields();

     let mut builder = SchemaBuilder::new();

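Both Rust hunks make the same swap: flattened_fields() recursively expands nested struct columns into their leaf fields, while fields() returns only the schema's top-level fields, which is what the SchemaBuilder loops that follow expect. A rough pyarrow sketch of the distinction, with a made-up schema and pa.Field.flatten() standing in for arrow-rs's flattened_fields():

import pyarrow as pa

# Hypothetical schema: one flat column plus one nested struct column
schema = pa.schema([
    ("sensor_name", pa.string()),
    ("window", pa.struct([("start", pa.int64()), ("end", pa.int64())])),
])

# Top-level view, analogous to fields(): ['sensor_name', 'window']
print([f.name for f in schema])

# Flattened view, analogous to flattened_fields():
# ['sensor_name', 'window.start', 'window.end']
flat = [sub for f in schema
        for sub in (f.flatten() if pa.types.is_struct(f.type) else [f])]
print([f.name for f in flat])

For a flat schema the two calls agree; the difference only shows up once windowed struct columns enter the schema.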
py-denormalized/pyproject.toml (5 additions, 6 deletions)
@@ -5,13 +5,10 @@ build-backend = "maturin"
 [project]
 name = "denormalized"
 requires-python = ">=3.12"
-classifiers = [ ]
-dynamic = ["version"] # Version specified in py-denormalized/Cargo.toml
+classifiers = []
+dynamic = ["version"] # Version specified in py-denormalized/Cargo.toml
 description = ""
-dependencies = [
-    "pyarrow>=17.0.0",
-    "datafusion>=40.1.0",
-]
+dependencies = ["pyarrow>=17.0.0", "datafusion>=40.1.0"]

 [project.optional-dependencies]
 tests = ["pytest"]
@@ -30,6 +27,8 @@ dev-dependencies = [
"pytest>=8.3.2",
"maturin>=1.7.4",
"pyarrow-stubs>=17.11",
"pandas>=2.2.3",
"jupyterlab>=4.3.0",
]

 # Enable docstring linting using the google style guide
py-denormalized/python/examples/udaf_example.py (new file, 82 additions)
@@ -0,0 +1,82 @@
"""stream_aggregate example."""

import json
import signal
import sys
from collections import Counter
from typing import List
import pyarrow as pa

from denormalized import Context
from denormalized.datafusion import Accumulator, col
from denormalized.datafusion import functions as f
from denormalized.datafusion import udaf


def signal_handler(sig, frame):
    sys.exit(0)


signal.signal(signal.SIGINT, signal_handler)

bootstrap_server = "localhost:9092"

sample_event = {
    "occurred_at_ms": 100,
    "sensor_name": "foo",
    "reading": 0.0,
}

class TotalValuesRead(Accumulator):
    # Define the state type as a struct containing a map
    acc_state_type = pa.struct([("counts", pa.map_(pa.string(), pa.int64()))])

    def __init__(self):
        self.counts = Counter()

    def update(self, values: pa.Array) -> None:
        # Update the counter with newly observed values
        if values is not None:
            self.counts.update(values.to_pylist())

    def merge(self, states: pa.Array) -> None:
        # Merge partial states from other accumulators into this one
        if states is None or len(states) == 0:
            return
        for state in states:
            if state is not None:
                counts_map = state.to_pylist()[0]  # each state is a single-element struct
                for k, v in counts_map["counts"]:  # map entries read back as (key, value) pairs
                    self.counts[k] += v

    def state(self) -> List[pa.Scalar]:
        # Convert the current state to Arrow format
        result = {"counts": dict(self.counts.items())}
        return [pa.scalar(result, type=TotalValuesRead.acc_state_type)]

    def evaluate(self) -> pa.Scalar:
        return self.state()[0]


input_type = [pa.string()]
return_type = TotalValuesRead.acc_state_type
state_type = [TotalValuesRead.acc_state_type]
sample_udaf = udaf(TotalValuesRead, input_type, return_type, state_type, "stable")


def print_batch(rb: pa.RecordBatch):
    if not len(rb):
        return
    print(rb)

ctx = Context()
ds = ctx.from_topic("temperature", json.dumps(sample_event), bootstrap_server, "occurred_at_ms")

ds = ds.window(
    [],
    [
        sample_udaf(col("sensor_name")).alias("count"),
    ],
    2000,
    None,
).sink(print_batch)
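
For a quick offline sanity check, the accumulator can be driven directly with pyarrow data, no Kafka topic required. The driver below is an illustrative sketch; the list-of-single-element-struct-arrays shape passed to merge() mirrors what the comments in the class describe, and is an assumption rather than a documented API guarantee:

import pyarrow as pa

# Exercise update() with a batch of raw sensor names
acc = TotalValuesRead()
acc.update(pa.array(["foo", "bar", "foo"]))
print(acc.evaluate())  # struct scalar holding the counts, e.g. foo -> 2, bar -> 1

# Exercise merge() with a partial state shaped like state() produces:
# a single-element struct array whose map entries read back as (key, value) pairs
partial = pa.array([{"counts": {"foo": 3}}], type=TotalValuesRead.acc_state_type)
acc.merge([partial])
print(acc.counts)  # Counter({'foo': 5, 'bar': 1})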