
Graceful shutdown handling #69

Merged · 1 commit · Dec 17, 2024
2 changes: 2 additions & 0 deletions crates/common/src/error/mod.rs
@@ -17,6 +17,8 @@ pub enum DenormalizedError {
// #[allow(clippy::disallowed_types)]
#[error("DataFusion error")]
DataFusion(#[from] DataFusionError),
#[error("Shutdown")]
Shutdown(),
#[error("RocksDB error: {0}")]
RocksDB(String),
#[error("Kafka error")]
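For context, the new variant gives callers a way to treat shutdown as a clean exit rather than a failure. A minimal sketch — assuming the crate is consumed as `denormalized_common`, as the return site later in this diff suggests:

```rust
use denormalized_common::DenormalizedError;

fn log_outcome(result: Result<(), DenormalizedError>) {
    match result {
        Ok(()) => log::info!("pipeline finished"),
        // Shutdown is a requested stop, not an error worth alerting on.
        Err(DenormalizedError::Shutdown()) => log::info!("pipeline stopped by shutdown signal"),
        Err(e) => log::error!("pipeline failed: {e:?}"),
    }
}
```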
161 changes: 94 additions & 67 deletions crates/core/src/datastream.rs
@@ -230,7 +230,6 @@ impl DataStream {
let (session_state, plan) = self.df.as_ref().clone().into_parts();
let physical_plan = self.df.as_ref().clone().create_physical_plan().await?;
let node_id = physical_plan.node_id();
debug!("topline node id = {:?}", node_id);
let displayable_plan = DisplayableExecutionPlan::new(physical_plan.as_ref());

println!("{}", displayable_plan.indent(true));
@@ -243,62 +242,49 @@ impl DataStream {
})
}

/// execute the stream and print the results to stdout.
/// Mainly used for development and debugging
pub async fn print_stream(mut self) -> Result<()> {
async fn with_orchestrator<F, Fut, T>(&mut self, stream_fn: F) -> Result<T>
where
F: FnOnce(watch::Receiver<bool>) -> Fut,
Fut: std::future::Future<Output = Result<T>>,
{
self.start_shutdown_listener();

let mut maybe_orchestrator_handle = None;

let config = self.context.session_context.copied_config();
let config_options = config.options().extensions.get::<DenormalizedConfig>();

let should_checkpoint = config_options.map_or(false, |c| c.checkpoint);

let mut maybe_orchestrator_handle = None;

// Start orchestrator if checkpointing is enabled
if should_checkpoint {
let mut orchestrator = Orchestrator::default();
let cloned_shutdown_rx = self.shutdown_rx.clone();
let orchestrator_handle =
SpawnedTask::spawn_blocking(move || orchestrator.run(10, cloned_shutdown_rx));

maybe_orchestrator_handle = Some(orchestrator_handle)
maybe_orchestrator_handle = Some(orchestrator_handle);
}

let mut stream: SendableRecordBatchStream =
self.df.as_ref().clone().execute_stream().await?;

// Stream loop with shutdown check
loop {
tokio::select! {
// Check if shutdown signal has changed
_ = self.shutdown_rx.changed() => {
info!("Graceful shutdown initiated, exiting stream loop...");

break;
}
// Handle the next batch from the DataFusion stream
next_batch = stream.next() => {
match next_batch.transpose() {
Ok(Some(batch)) => {
println!(
"{}",
datafusion::common::arrow::util::pretty::pretty_format_batches(&[batch])
.unwrap()
);
}
Ok(None) => {
info!("No more RecordBatch in stream");
break; // End of stream
}
Err(err) => {
log::error!("Error reading stream: {:?}", err);
return Err(err.into());
}
}
}
// Run the stream processing function

let mut shutdown_rx = self.shutdown_rx.clone();

let result = tokio::select! {
res = stream_fn(shutdown_rx.clone()) => {
// `stream_fn` completed first
res
},
_ = shutdown_rx.changed() => {
// Shutdown signal received first
log::info!("Shutdown signal received while the pipeline was running, cancelling...");
// Return a dedicated error here so callers can distinguish a
// requested shutdown from a genuine failure:
return Err(denormalized_common::DenormalizedError::Shutdown());
}
}
};

//let result = stream_fn(self.shutdown_rx.clone()).await;

// Cleanup
log::info!("Stream processing stopped. Cleaning up...");

if should_checkpoint {
@@ -309,43 +295,84 @@ impl DataStream {
}
}

// Join the orchestrator handle if it exists, ensuring it is joined and awaited
// Join orchestrator if it was started
if let Some(orchestrator_handle) = maybe_orchestrator_handle {
log::info!("Waiting for orchestrator task to complete...");
match orchestrator_handle.join_unwind().await {
Ok(_) => log::info!("Orchestrator task completed successfully."),
Err(e) => log::error!("Error joining orchestrator task: {:?}", e),
}
}
Ok(())
}

/// execute the stream and write the results to a given Kafka topic
pub async fn sink_kafka(self, bootstrap_servers: String, topic: String) -> Result<()> {
let processed_schema = Arc::new(datafusion::common::arrow::datatypes::Schema::from(
self.df.schema(),
));

let sink_topic = KafkaTopicBuilder::new(bootstrap_servers.clone())
.with_timestamp(String::from("occurred_at_ms"), TimestampUnit::Int64Millis)
.with_encoding("json")?
.with_topic(topic.clone())
.with_schema(processed_schema)
.build_writer(ConnectionOpts::new())
.await?;

self.context
.register_table(topic.clone(), Arc::new(sink_topic))
.await?;
result
}

self.df
.as_ref()
.clone()
.write_table(topic.as_str(), DataFrameWriteOptions::default())
/// execute the stream and print the results to stdout.
/// Mainly used for development and debugging
pub async fn print_stream(self) -> Result<()> {
self.clone()
.with_orchestrator(|_shutdown_rx| async move {
let mut stream: SendableRecordBatchStream =
self.df.as_ref().clone().execute_stream().await?;

loop {
match stream.next().await.transpose() {
Ok(Some(batch)) => {
if batch.num_rows() > 0 {
println!(
"{}",
datafusion::common::arrow::util::pretty::pretty_format_batches(
&[batch]
)
.unwrap()
);
}
}
Ok(None) => {
info!("No more RecordBatches in stream");
break; // End of stream
}
Err(err) => {
log::error!("Error reading stream: {:?}", err);
return Err(err.into());
}
}
}
Ok(())
})
.await?;

Ok(())
}

pub async fn sink_kafka(self, bootstrap_servers: String, topic: String) -> Result<()> {
self.clone()
.with_orchestrator(|_shutdown_rx| async move {
let processed_schema = Arc::new(
datafusion::common::arrow::datatypes::Schema::from(self.df.schema()),
);

let sink_topic = KafkaTopicBuilder::new(bootstrap_servers.clone())
.with_timestamp(String::from("occurred_at_ms"), TimestampUnit::Int64Millis)
.with_encoding("json")?
.with_topic(topic.clone())
.with_schema(processed_schema)
.build_writer(ConnectionOpts::new())
.await?;

self.context
.register_table(topic.clone(), Arc::new(sink_topic))
.await?;

self.df
.as_ref()
.clone()
.write_table(topic.as_str(), DataFrameWriteOptions::default())
.await?;

Ok(())
})
.await
}
}

/// Trait that allows both DataStream and DataFrame objects to be joined to
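The refactor above funnels every sink through `with_orchestrator`, which races the stream work against a shutdown signal. Below is a minimal, self-contained sketch of that pattern — a tokio watch channel polled via `select!` — with all names illustrative and only tokio APIs assumed; the PR's `start_shutdown_listener` presumably wires the channel up in a similar way:

```rust
use tokio::sync::watch;

async fn run(mut shutdown_rx: watch::Receiver<bool>) {
    loop {
        tokio::select! {
            _ = shutdown_rx.changed() => {
                println!("shutdown signal received, exiting loop");
                break;
            }
            _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
                // Stand-in for pulling the next RecordBatch from the stream.
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let (shutdown_tx, shutdown_rx) = watch::channel(false);

    // Signal task: flip the watch channel on Ctrl-C, mirroring what a
    // start_shutdown_listener-style helper would do.
    tokio::spawn(async move {
        tokio::signal::ctrl_c().await.expect("failed to install ctrl-c handler");
        let _ = shutdown_tx.send(true);
    });

    run(shutdown_rx).await;
    println!("stream loop exited; cleanup would run here");
}
```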
@@ -632,18 +632,18 @@ impl GroupedAggWindowFrame {
&mut self,
state: &CheckpointedGroupedWindowFrame,
) -> Result<(), DataFusionError> {
let _ = self
.accumulators
self.accumulators
.iter_mut()
.zip(state.accumulators.iter())
.map(|(acc, checkpointed_acc)| {
.for_each(|(acc, checkpointed_acc)| {
let group_indices = (0..checkpointed_acc.num_groups).collect::<Vec<usize>>();
acc.merge_batch(
&checkpointed_acc.states.arrays,
&group_indices,
None,
checkpointed_acc.num_groups,
)
.unwrap();
});
Ok(())
}
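One note on why the `for_each` swap above is more than style: iterator adapters like `map` are lazy, so the old `let _ = ….map(…)` built an iterator that was dropped unconsumed and never invoked `merge_batch`. A standalone illustration:

```rust
fn main() {
    let mut calls = 0;

    // Lazy: the closure never runs, because the Map iterator is
    // dropped without being consumed.
    let _ = (0..3).map(|_| calls += 1);
    assert_eq!(calls, 0);

    // Eager: for_each drives the iterator to completion.
    (0..3).for_each(|_| calls += 1);
    assert_eq!(calls, 3);

    println!("map ran 0 times, for_each ran 3 times");
}
```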
2 changes: 1 addition & 1 deletion crates/core/src/utils/serialization.rs
@@ -297,7 +297,7 @@ mod tests {
use arrow_schema::{Field, Fields};
use datafusion::{
functions_aggregate::average::AvgAccumulator, logical_expr::Accumulator,
scalar::ScalarValue,
physical_expr::GroupsAccumulatorAdapter, scalar::ScalarValue,
};
use std::sync::Arc;

33 changes: 17 additions & 16 deletions examples/examples/simple_aggregation.rs
@@ -1,8 +1,8 @@
use std::time::Duration;

use datafusion::functions_aggregate::count::count;
use datafusion::functions_aggregate::expr_fn::{max, min};
use datafusion::logical_expr::{col, lit};
use datafusion::functions_aggregate::expr_fn::{avg, max, min};
use datafusion::logical_expr::col;

use denormalized::datasource::kafka::{ConnectionOpts, KafkaTopicBuilder};
use denormalized::prelude::*;
@@ -17,27 +17,29 @@ async fn main() -> Result<()> {
.filter_level(log::LevelFilter::Debug)
.init();

let bootstrap_servers = String::from("localhost:9092");
let bootstrap_servers = String::from("localhost:19092");

let config = Context::default_config().set_bool("denormalized_config.checkpoint", false);

let mut topic_builder = KafkaTopicBuilder::new(bootstrap_servers);
let mut topic_builder = KafkaTopicBuilder::new(bootstrap_servers.clone());

// Connect to source topic
let source_topic = topic_builder
.with_topic(String::from("temperature"))
.infer_schema_from_json(get_sample_json().as_str())?
.with_encoding("json")?
.with_timestamp(String::from("occurred_at_ms"), TimestampUnit::Int64Millis)
//.with_timestamp(String::from("occurred_at_ms"), TimestampUnit::Int64Millis)
.build_reader(ConnectionOpts::from([
("auto.offset.reset".to_string(), "latest".to_string()),
("group.id".to_string(), "sample_pipeline".to_string()),
]))
.await?;

let _ctx = Context::with_config(config)?
//.with_slatedb_backend(String::from("/tmp/checkpoints/simple-agg-checkpoint-1"))
//.await
let ds = Context::with_config(config)?
// .with_slatedb_backend(String::from(
// "/tmp/checkpoints/simple-aggregation-example",
// ))
// .await
.from_topic(source_topic)
.await?
.window(
Expand All @@ -46,14 +48,13 @@ async fn main() -> Result<()> {
count(col("reading")).alias("count"),
min(col("reading")).alias("min"),
max(col("reading")).alias("max"),
//avg(col("reading")).alias("average"),
avg(col("reading")).alias("average"),
],
Duration::from_millis(1_000), // aggregate every 1 second
None, // None means tumbling window
)?
.filter(col("max").gt(lit(113)))?
.print_stream() // Print out the results
.await?;

Duration::from_millis(5_000), // Window length
None, // Slide duration. None defaults to a tumbling window.
)?;
ds.print_stream().await?;
//ds.sink_kafka(bootstrap_servers, String::from("checkpointed-output"))
// .await?;
Ok(())
}
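Assuming the usual Cargo layout implied by the `examples/examples/` path, the updated example can be run from the examples package with `cargo run --example simple_aggregation`, with a Kafka-compatible broker listening on `localhost:19092`.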