From e5f77d26c5dc5ea6e2d56aa51fc297c9d66fdb24 Mon Sep 17 00:00:00 2001
From: Amey Chaugule
Date: Tue, 17 Dec 2024 14:22:35 -0800
Subject: [PATCH] Graceful shutdown handling

---
 crates/common/src/error/mod.rs                |   2 +
 crates/core/src/datastream.rs                 | 161 ++++++++++--------
 .../continuous/grouped_window_agg_stream.rs   |   6 +-
 crates/core/src/utils/serialization.rs        |   2 +-
 examples/examples/simple_aggregation.rs       |  33 ++--
 5 files changed, 117 insertions(+), 87 deletions(-)

diff --git a/crates/common/src/error/mod.rs b/crates/common/src/error/mod.rs
index d99a1b8..2134ac6 100644
--- a/crates/common/src/error/mod.rs
+++ b/crates/common/src/error/mod.rs
@@ -17,6 +17,8 @@ pub enum DenormalizedError {
     // #[allow(clippy::disallowed_types)]
     #[error("DataFusion error")]
     DataFusion(#[from] DataFusionError),
+    #[error("Shutdown")]
+    Shutdown(),
     #[error("RocksDB error: {0}")]
     RocksDB(String),
     #[error("Kafka error")]
diff --git a/crates/core/src/datastream.rs b/crates/core/src/datastream.rs
index 6b22f72..687f9d6 100644
--- a/crates/core/src/datastream.rs
+++ b/crates/core/src/datastream.rs
@@ -230,7 +230,6 @@ impl DataStream {
         let (session_state, plan) = self.df.as_ref().clone().into_parts();
         let physical_plan = self.df.as_ref().clone().create_physical_plan().await?;
         let node_id = physical_plan.node_id();
-        debug!("topline node id = {:?}", node_id);
 
         let displayable_plan = DisplayableExecutionPlan::new(physical_plan.as_ref());
         println!("{}", displayable_plan.indent(true));
@@ -243,62 +242,49 @@ impl DataStream {
         })
     }
 
-    /// execute the stream and print the results to stdout.
-    /// Mainly used for development and debugging
-    pub async fn print_stream(mut self) -> Result<()> {
+    async fn with_orchestrator<F, Fut>(&mut self, stream_fn: F) -> Result<()>
+    where
+        F: FnOnce(watch::Receiver<bool>) -> Fut,
+        Fut: std::future::Future<Output = Result<()>>,
+    {
         self.start_shutdown_listener();
-        let mut maybe_orchestrator_handle = None;
-
         let config = self.context.session_context.copied_config();
         let config_options = config.options().extensions.get::<DenormalizedConfig>();
         let should_checkpoint = config_options.map_or(false, |c| c.checkpoint);
+        let mut maybe_orchestrator_handle = None;
+
+        // Start orchestrator if checkpointing is enabled
        if should_checkpoint {
            let mut orchestrator = Orchestrator::default();
            let cloned_shutdown_rx = self.shutdown_rx.clone();
            let orchestrator_handle =
                SpawnedTask::spawn_blocking(move || orchestrator.run(10, cloned_shutdown_rx));
-
-            maybe_orchestrator_handle = Some(orchestrator_handle)
+            maybe_orchestrator_handle = Some(orchestrator_handle);
        }
 
-        let mut stream: SendableRecordBatchStream =
-            self.df.as_ref().clone().execute_stream().await?;
-
-        // Stream loop with shutdown check
-        loop {
-            tokio::select! {
-                // Check if shutdown signal has changed
-                _ = self.shutdown_rx.changed() => {
-                    info!("Graceful shutdown initiated, exiting stream loop...");
-
-                    break;
-                }
-                // Handle the next batch from the DataFusion stream
-                next_batch = stream.next() => {
-                    match next_batch.transpose() {
-                        Ok(Some(batch)) => {
-                            println!(
-                                "{}",
-                                datafusion::common::arrow::util::pretty::pretty_format_batches(&[batch])
-                                    .unwrap()
-                            );
-                        }
-                        Ok(None) => {
-                            info!("No more RecordBatch in stream");
-                            break; // End of stream
-                        }
-                        Err(err) => {
-                            log::error!("Error reading stream: {:?}", err);
-                            return Err(err.into());
-                        }
-                    }
-                }
-            }
-        }
+        // Run the stream processing function, racing it against the shutdown signal
+        let mut shutdown_rx = self.shutdown_rx.clone();
+
+        let result = tokio::select! {
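+            // Whichever branch finishes first wins the race; tokio::select!
+            // drops the other future, which is what cancels a still-running side.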
+            res = stream_fn(shutdown_rx.clone()) => {
+                // `stream_fn` ran to completion before any shutdown signal
+                res
+            },
+            _ = shutdown_rx.changed() => {
+                // Shutdown signal won the race: surface a Shutdown error, but
+                // fall through so the cleanup below still runs
+                log::info!("Shutdown signal received while the pipeline was running, cancelling...");
+                Err(denormalized_common::DenormalizedError::Shutdown())
+            }
+        };
+
         // Cleanup
         log::info!("Stream processing stopped. Cleaning up...");
         if should_checkpoint {
@@ -309,7 +295,7 @@ impl DataStream {
             }
         }
 
-        // Join the orchestrator handle if it exists, ensuring it is joined and awaited
+        // Join orchestrator if it was started
         if let Some(orchestrator_handle) = maybe_orchestrator_handle {
             log::info!("Waiting for orchestrator task to complete...");
             match orchestrator_handle.join_unwind().await {
@@ -317,35 +303,76 @@ impl DataStream {
                 Err(e) => log::error!("Error joining orchestrator task: {:?}", e),
             }
         }
-        Ok(())
-    }
-
-    /// execute the stream and write the results to a give kafka topic
-    pub async fn sink_kafka(self, bootstrap_servers: String, topic: String) -> Result<()> {
-        let processed_schema = Arc::new(datafusion::common::arrow::datatypes::Schema::from(
-            self.df.schema(),
-        ));
-
-        let sink_topic = KafkaTopicBuilder::new(bootstrap_servers.clone())
-            .with_timestamp(String::from("occurred_at_ms"), TimestampUnit::Int64Millis)
-            .with_encoding("json")?
-            .with_topic(topic.clone())
-            .with_schema(processed_schema)
-            .build_writer(ConnectionOpts::new())
-            .await?;
-        self.context
-            .register_table(topic.clone(), Arc::new(sink_topic))
-            .await?;
+        result
+    }
 
-        self.df
-            .as_ref()
-            .clone()
-            .write_table(topic.as_str(), DataFrameWriteOptions::default())
+    /// execute the stream and print the results to stdout.
+    /// Mainly used for development and debugging
+    pub async fn print_stream(self) -> Result<()> {
+        self.clone()
+            .with_orchestrator(|_shutdown_rx| async move {
+                let mut stream: SendableRecordBatchStream =
+                    self.df.as_ref().clone().execute_stream().await?;
+
+                loop {
+                    match stream.next().await.transpose() {
+                        Ok(Some(batch)) => {
+                            if batch.num_rows() > 0 {
+                                println!(
+                                    "{}",
+                                    datafusion::common::arrow::util::pretty::pretty_format_batches(
+                                        &[batch]
+                                    )
+                                    .unwrap()
+                                );
+                            }
+                        }
+                        Ok(None) => {
+                            info!("No more RecordBatches in stream");
+                            break; // End of stream
+                        }
+                        Err(err) => {
+                            log::error!("Error reading stream: {:?}", err);
+                            return Err(err.into());
+                        }
+                    }
+                }
+                Ok(())
+            })
             .await?;
-
         Ok(())
     }
+
+    /// execute the stream and write the results to a given kafka topic
+    pub async fn sink_kafka(self, bootstrap_servers: String, topic: String) -> Result<()> {
+        self.clone()
+            .with_orchestrator(|_shutdown_rx| async move {
+                let processed_schema = Arc::new(
+                    datafusion::common::arrow::datatypes::Schema::from(self.df.schema()),
+                );
+
+                let sink_topic = KafkaTopicBuilder::new(bootstrap_servers.clone())
+                    .with_timestamp(String::from("occurred_at_ms"), TimestampUnit::Int64Millis)
+                    .with_encoding("json")?
+                    .with_topic(topic.clone())
+                    .with_schema(processed_schema)
+                    .build_writer(ConnectionOpts::new())
+                    .await?;
+
+                self.context
+                    .register_table(topic.clone(), Arc::new(sink_topic))
+                    .await?;
+
+                self.df
+                    .as_ref()
+                    .clone()
+                    .write_table(topic.as_str(), DataFrameWriteOptions::default())
+                    .await?;
+
+                Ok(())
+            })
+            .await
+    }
 }
 
 /// Trait that allows both DataStream and DataFrame objects to be joined to
diff --git a/crates/core/src/physical_plan/continuous/grouped_window_agg_stream.rs b/crates/core/src/physical_plan/continuous/grouped_window_agg_stream.rs
index 536822b..0646a7d 100644
--- a/crates/core/src/physical_plan/continuous/grouped_window_agg_stream.rs
+++ b/crates/core/src/physical_plan/continuous/grouped_window_agg_stream.rs
@@ -632,11 +632,10 @@ impl GroupedAggWindowFrame {
         &mut self,
         state: &CheckpointedGroupedWindowFrame,
     ) -> Result<(), DataFusionError> {
-        let _ = self
-            .accumulators
+        self.accumulators
             .iter_mut()
             .zip(state.accumulators.iter())
-            .map(|(acc, checkpointed_acc)| {
+            .for_each(|(acc, checkpointed_acc)| {
                 let group_indices = (0..checkpointed_acc.num_groups).collect::<Vec<usize>>();
                 acc.merge_batch(
                     &checkpointed_acc.states.arrays,
@@ -644,6 +643,7 @@ impl GroupedAggWindowFrame {
                     None,
                     checkpointed_acc.num_groups,
                 )
+                .unwrap();
             });
         Ok(())
     }
diff --git a/crates/core/src/utils/serialization.rs b/crates/core/src/utils/serialization.rs
index 13169e7..2156f1c 100644
--- a/crates/core/src/utils/serialization.rs
+++ b/crates/core/src/utils/serialization.rs
@@ -297,7 +297,7 @@ mod tests {
     use arrow_schema::{Field, Fields};
     use datafusion::{
         functions_aggregate::average::AvgAccumulator, logical_expr::Accumulator,
-        scalar::ScalarValue,
+        physical_expr::GroupsAccumulatorAdapter, scalar::ScalarValue,
     };
     use std::sync::Arc;
diff --git a/examples/examples/simple_aggregation.rs b/examples/examples/simple_aggregation.rs
index 9341529..f1628e1 100644
--- a/examples/examples/simple_aggregation.rs
+++ b/examples/examples/simple_aggregation.rs
@@ -1,8 +1,8 @@
 use std::time::Duration;
 
 use datafusion::functions_aggregate::count::count;
-use datafusion::functions_aggregate::expr_fn::{max, min};
-use datafusion::logical_expr::{col, lit};
+use datafusion::functions_aggregate::expr_fn::{avg, max, min};
+use datafusion::logical_expr::col;
 
 use denormalized::datasource::kafka::{ConnectionOpts, KafkaTopicBuilder};
 use denormalized::prelude::*;
@@ -17,27 +17,29 @@ async fn main() -> Result<()> {
         .filter_level(log::LevelFilter::Debug)
         .init();
 
-    let bootstrap_servers = String::from("localhost:9092");
+    let bootstrap_servers = String::from("localhost:19092");
 
     let config = Context::default_config().set_bool("denormalized_config.checkpoint", false);
 
-    let mut topic_builder = KafkaTopicBuilder::new(bootstrap_servers);
+    let mut topic_builder = KafkaTopicBuilder::new(bootstrap_servers.clone());
 
     // Connect to source topic
     let source_topic = topic_builder
         .with_topic(String::from("temperature"))
         .infer_schema_from_json(get_sample_json().as_str())?
         .with_encoding("json")?
-        .with_timestamp(String::from("occurred_at_ms"), TimestampUnit::Int64Millis)
+        //.with_timestamp(String::from("occurred_at_ms"), TimestampUnit::Int64Millis)
         .build_reader(ConnectionOpts::from([
             ("auto.offset.reset".to_string(), "latest".to_string()),
             ("group.id".to_string(), "sample_pipeline".to_string()),
         ]))
         .await?;
 
-    let _ctx = Context::with_config(config)?
-        //.with_slatedb_backend(String::from("/tmp/checkpoints/simple-agg-checkpoint-1"))
-        //.await
+    let ds = Context::with_config(config)?
+        // .with_slatedb_backend(String::from(
+        //     "/tmp/checkpoints/simple-aggregation-example",
+        // ))
+        // .await
         .from_topic(source_topic)
         .await?
         .window(
@@ -46,14 +48,13 @@ async fn main() -> Result<()> {
             count(col("reading")).alias("count"),
             min(col("reading")).alias("min"),
             max(col("reading")).alias("max"),
-            //avg(col("reading")).alias("average"),
+            avg(col("reading")).alias("average"),
             ],
-            Duration::from_millis(1_000), // aggregate every 1 second
-            None,                         // None means tumbling window
-        )?
-        .filter(col("max").gt(lit(113)))?
-        .print_stream() // Print out the results
-        .await?;
-
+            Duration::from_millis(5_000), // Window length
+            None, // Slide duration. None defaults to a tumbling window.
+        )?;
+    ds.print_stream().await?;
+    //ds.sink_kafka(bootstrap_servers, String::from("checkpointed-output"))
+    //    .await?;
     Ok(())
 }
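
Note: `start_shutdown_listener()` and the `shutdown_rx` field are not shown in
this diff; the patch assumes they implement the usual tokio watch-channel
pattern. A minimal sketch of that wiring, with hypothetical names
(`spawn_shutdown_listener` is illustrative only, not the real method), looks
like this:

    use tokio::sync::watch;

    // Hypothetical stand-in for DataStream::start_shutdown_listener().
    fn spawn_shutdown_listener() -> watch::Receiver<bool> {
        let (tx, rx) = watch::channel(false);
        tokio::spawn(async move {
            // Resolves when the process receives Ctrl-C (SIGINT).
            if tokio::signal::ctrl_c().await.is_ok() {
                // Flip the watch value: every cloned receiver parked on
                // `changed()`, like the tokio::select! above, wakes up.
                let _ = tx.send(true);
            }
        });
        rx
    }

    #[tokio::main]
    async fn main() {
        let mut rx = spawn_shutdown_listener();
        // A pipeline would poll this inside tokio::select!, as the patch does.
        let _ = rx.changed().await;
        println!("shutdown signal received");
    }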