diff --git a/README.md b/README.md
index eaf2b0d..3ba7348 100644
--- a/README.md
+++ b/README.md
@@ -96,11 +96,12 @@ Details about developing the python bindings can be found in [py-denormalized/RE
 
 ### Checkpointing
 
-We use SlateDB for state backend. Initialize your Job Context to a path to local directory -
+We use SlateDB for the state backend. Initialize your Job Context with a custom config and a path for the SlateDB backend to store state -
 
 ```
-let ctx = Context::new()?
-    .with_slatedb_backend(String::from("/tmp/checkpoints/simple-agg-checkpoint-1"))
+let config = Context::default_config().set_bool("denormalized_config.checkpoint", true);
+let ctx = Context::with_config(config)?
+    .with_slatedb_backend(String::from("/tmp/checkpoints/simple-agg/job1"))
     .await;
 ```
 
diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs
index 3419720..f3b078d 100644
--- a/crates/core/src/context.rs
+++ b/crates/core/src/context.rs
@@ -6,6 +6,7 @@ use datafusion::execution::{
     session_state::SessionStateBuilder,
 };
 
+use crate::config_extensions::denormalized_config::DenormalizedConfig;
 use crate::datasource::kafka::TopicReader;
 use crate::datastream::DataStream;
 use crate::physical_optimizer::EnsureHashPartititionOnGroupByForStreamingAggregates;
@@ -17,12 +18,13 @@ use denormalized_common::error::{DenormalizedError, Result};
 
 #[derive(Clone)]
 pub struct Context {
-    pub session_conext: Arc<SessionContext>,
+    pub session_context: Arc<SessionContext>,
 }
 
 impl Context {
-    pub fn new() -> Result<Self> {
-        let config = SessionConfig::new()
+    pub fn default_config() -> SessionConfig {
+        let ext_config = DenormalizedConfig::default();
+        let mut config = SessionConfig::new()
             .set(
                 "datafusion.execution.batch_size",
                 &datafusion::common::ScalarValue::UInt64(Some(32)),
@@ -34,8 +36,16 @@ impl Context {
                 &datafusion::common::ScalarValue::Boolean(Some(false)),
             );
 
-        let runtime = Arc::new(RuntimeEnv::default());
+        let _ = config.options_mut().extensions.insert(ext_config);
+        config
+    }
 
+    pub fn new() -> Result<Self> {
+        Context::with_config(Context::default_config())
+    }
+
+    pub fn with_config(config: SessionConfig) -> Result<Self> {
+        let runtime = Arc::new(RuntimeEnv::default());
         let state = SessionStateBuilder::new()
             .with_default_features()
             .with_config(config)
@@ -48,7 +58,7 @@ impl Context {
             .build();
 
         Ok(Self {
-            session_conext: Arc::new(SessionContext::new_with_state(state)),
+            session_context: Arc::new(SessionContext::new_with_state(state)),
         })
     }
 
@@ -56,7 +66,7 @@ impl Context {
         let topic_name = topic.0.topic.clone();
         self.register_table(topic_name.clone(), Arc::new(topic))
            .await?;
-        let df = self.session_conext.table(topic_name.as_str()).await?;
+        let df = self.session_context.table(topic_name.as_str()).await?;
         let ds = DataStream::new(Arc::new(df), Arc::new(self.clone()));
         Ok(ds)
     }
@@ -66,7 +76,7 @@ impl Context {
         name: String,
         table: Arc<dyn TableProvider>,
     ) -> Result<(), DenormalizedError> {
-        self.session_conext
+        self.session_context
             .register_table(name.as_str(), table.clone())?;
 
         Ok(())
diff --git a/crates/core/src/datasource/kafka/kafka_stream_read.rs b/crates/core/src/datasource/kafka/kafka_stream_read.rs
index 9d484ee..caed655 100644
--- a/crates/core/src/datasource/kafka/kafka_stream_read.rs
+++ b/crates/core/src/datasource/kafka/kafka_stream_read.rs
@@ -7,7 +7,7 @@ use arrow_array::{Array, ArrayRef, PrimitiveArray, RecordBatch, StringArray, Str
 use arrow_schema::{DataType, Field, SchemaRef, TimeUnit};
 use crossbeam::channel;
 use denormalized_orchestrator::channel_manager::{create_channel, get_sender, take_receiver};
-use denormalized_orchestrator::orchestrator::{self, OrchestrationMessage};
+use denormalized_orchestrator::orchestrator::OrchestrationMessage;
 use futures::executor::block_on;
 use log::{debug, error};
 use serde::{Deserialize, Serialize};
@@ -83,13 +83,13 @@ impl PartitionStream for KafkaStreamRead {
     }
 
     fn execute(&self, ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
-        let _config_options = ctx
+        let config_options = ctx
             .session_config()
             .options()
             .extensions
             .get::<DenormalizedConfig>();
 
-        let mut should_checkpoint = false; //config_options.map_or(false, |c| c.checkpoint);
+        let should_checkpoint = config_options.map_or(false, |c| c.checkpoint);
 
         let node_id = self.exec_node_id.unwrap();
         let partition_tag = self
@@ -101,13 +101,16 @@ impl PartitionStream for KafkaStreamRead {
         let channel_tag = format!("{}_{}", node_id, partition_tag);
 
         let mut serialized_state: Option<Vec<u8>> = None;
-        let state_backend = get_global_slatedb().unwrap();
+        let mut state_backend = None;
 
         let mut starting_offsets: HashMap<i32, i64> = HashMap::new();
-        if orchestrator::SHOULD_CHECKPOINT {
+
+        if should_checkpoint {
             create_channel(channel_tag.as_str(), 10);
+            let backend = get_global_slatedb().unwrap();
             debug!("checking for last checkpointed offsets");
-            serialized_state = block_on(state_backend.clone().get(channel_tag.as_bytes().to_vec()));
+            serialized_state = block_on(backend.get(channel_tag.as_bytes().to_vec()));
+            state_backend = Some(backend);
         }
 
         if let Some(serialized_state) = serialized_state {
@@ -151,25 +154,26 @@ impl PartitionStream for KafkaStreamRead {
         builder.spawn(async move {
             let mut epoch = 0;
             let mut receiver: Option<Receiver<OrchestrationMessage>> = None;
-            if orchestrator::SHOULD_CHECKPOINT {
+            if should_checkpoint {
                 let orchestrator_sender = get_sender("orchestrator");
                 let msg: OrchestrationMessage =
                     OrchestrationMessage::RegisterStream(channel_tag.clone());
                 orchestrator_sender.as_ref().unwrap().send(msg).unwrap();
                 receiver = take_receiver(channel_tag.as_str());
             }
+            let mut checkpoint_batch = false;
 
             loop {
                 //let mut checkpoint_barrier: Option<u128> = None;
                 let mut _checkpoint_barrier: Option<u128> = None;
 
-                if orchestrator::SHOULD_CHECKPOINT {
+                if should_checkpoint {
                     let r = receiver.as_ref().unwrap();
                     for message in r.try_iter() {
                         debug!("received checkpoint barrier for {:?}", message);
                         if let OrchestrationMessage::CheckpointBarrier(epoch_ts) = message {
                             epoch = epoch_ts;
-                            should_checkpoint = true;
+                            checkpoint_batch = true;
                         }
                     }
                 }
@@ -245,7 +249,7 @@ impl PartitionStream for KafkaStreamRead {
                 let tx_result = tx.send(Ok(timestamped_record_batch)).await;
                 match tx_result {
                     Ok(_) => {
-                        if should_checkpoint {
+                        if checkpoint_batch {
                             debug!("about to checkpoint offsets");
                             let off = BatchReadMetadata {
                                 epoch,
@@ -255,9 +259,10 @@ impl PartitionStream for KafkaStreamRead {
                             };
                             state_backend
                                 .as_ref()
+                                .unwrap()
                                 .put(channel_tag.as_bytes().to_vec(), off.to_bytes().unwrap());
                             debug!("checkpointed offsets {:?}", off);
-                            should_checkpoint = false;
+                            checkpoint_batch = false;
                         }
                     }
                     Err(err) => error!("result err {:?}. shutdown signal detected.", err),
diff --git a/crates/core/src/datastream.rs b/crates/core/src/datastream.rs
index 42a8f4f..ea86b72 100644
--- a/crates/core/src/datastream.rs
+++ b/crates/core/src/datastream.rs
@@ -1,7 +1,6 @@
 use datafusion::common::runtime::SpawnedTask;
 use datafusion::logical_expr::LogicalPlan;
 use datafusion::physical_plan::ExecutionPlanProperties;
-use denormalized_orchestrator::orchestrator;
 use futures::StreamExt;
 use log::debug;
 use log::info;
@@ -18,6 +17,7 @@ use datafusion::logical_expr::{
 };
 use datafusion::physical_plan::display::DisplayableExecutionPlan;
 
+use crate::config_extensions::denormalized_config::DenormalizedConfig;
 use crate::context::Context;
 use crate::datasource::kafka::{ConnectionOpts, KafkaTopicBuilder};
 use crate::logical_plan::StreamingLogicalPlanBuilder;
@@ -240,7 +240,12 @@ impl DataStream {
 
         let mut maybe_orchestrator_handle = None;
 
-        if orchestrator::SHOULD_CHECKPOINT {
+        let config = self.context.session_context.copied_config();
+        let config_options = config.options().extensions.get::<DenormalizedConfig>();
+
+        let should_checkpoint = config_options.map_or(false, |c| c.checkpoint);
+
+        if should_checkpoint {
             let mut orchestrator = Orchestrator::default();
             let cloned_shutdown_rx = self.shutdown_rx.clone();
             let orchestrator_handle =
@@ -286,10 +291,12 @@ impl DataStream {
 
         log::info!("Stream processing stopped. Cleaning up...");
 
-        let state_backend = get_global_slatedb();
-        if let Ok(db) = state_backend {
-            log::info!("Closing the state backend (slatedb)...");
-            db.close().await.unwrap();
+        if should_checkpoint {
+            let state_backend = get_global_slatedb();
+            if let Ok(db) = state_backend {
+                log::info!("Closing the state backend (slatedb)...");
+                db.close().await.unwrap();
+            }
         }
 
         // Join the orchestrator handle if it exists, ensuring it is joined and awaited
diff --git a/crates/core/src/physical_plan/continuous/grouped_window_agg_stream.rs b/crates/core/src/physical_plan/continuous/grouped_window_agg_stream.rs
index da43430..0388a23 100644
--- a/crates/core/src/physical_plan/continuous/grouped_window_agg_stream.rs
+++ b/crates/core/src/physical_plan/continuous/grouped_window_agg_stream.rs
@@ -37,14 +37,14 @@ use datafusion::{
 };
 
 use denormalized_orchestrator::{
-    channel_manager::take_receiver,
-    orchestrator::{self, OrchestrationMessage},
+    channel_manager::take_receiver, orchestrator::OrchestrationMessage,
 };
 use futures::{executor::block_on, Stream, StreamExt};
 use log::debug;
 use serde::{Deserialize, Serialize};
 
 use crate::{
+    config_extensions::denormalized_config::DenormalizedConfig,
     physical_plan::utils::time::RecordBatchWatermark,
     state_backend::slatedb::{get_global_slatedb, SlateDBWrapper},
     utils::serialization::ArrayContainer,
@@ -73,11 +73,11 @@ pub struct GroupedWindowAggStream {
     group_by: PhysicalGroupBy,
     group_schema: Arc<Schema>,
     context: Arc<TaskContext>,
-    epoch: i64,
+    checkpoint: bool,
     partition: usize,
     channel_tag: String,
     receiver: Option<Receiver<OrchestrationMessage>>,
-    state_backend: Arc<SlateDBWrapper>,
+    state_backend: Option<Arc<SlateDBWrapper>>,
 }
 
 #[derive(Serialize, Deserialize)]
@@ -147,11 +147,23 @@ impl GroupedWindowAggStream {
             .and_then(|tag| take_receiver(tag.as_str()));
 
         let channel_tag: String = channel_tag.unwrap_or(String::from(""));
-        let state_backend = get_global_slatedb().unwrap();
-        let serialized_state = block_on(state_backend.get(channel_tag.as_bytes().to_vec()));
+        let config_options = context
+            .session_config()
+            .options()
+            .extensions
+            .get::<DenormalizedConfig>();
+
+        let checkpoint = config_options.map_or(false, |c| c.checkpoint);
+
+        let mut serialized_state: Option<Vec<u8>> = None;
+        let mut state_backend = None;
+        if checkpoint {
+            let backend = get_global_slatedb().unwrap();
+            serialized_state = block_on(backend.get(channel_tag.as_bytes().to_vec()));
+            state_backend = Some(backend);
+        }
 
-        //let window_frames: BTreeMap = BTreeMap::new();
 
         let mut stream = Self {
             schema: agg_schema,
             input,
@@ -166,7 +178,7 @@ impl GroupedWindowAggStream {
             group_by,
             group_schema,
             context,
-            epoch: 0,
+            checkpoint,
             partition,
             channel_tag,
             receiver,
@@ -340,19 +352,19 @@ impl GroupedWindowAggStream {
                 return Poll::Pending;
             }
         };
-        self.epoch += 1;
 
-        if orchestrator::SHOULD_CHECKPOINT {
+        let mut checkpoint_batch = false;
+
+        if self.checkpoint {
             let r = self.receiver.as_ref().unwrap();
-            let mut epoch: u128 = 0;
             for message in r.try_iter() {
                 debug!("received checkpoint barrier for {:?}", message);
-                if let OrchestrationMessage::CheckpointBarrier(epoch_ts) = message {
-                    epoch = epoch_ts;
+                if let OrchestrationMessage::CheckpointBarrier(_epoch_ts) = message {
+                    checkpoint_batch = true;
                 }
             }
 
-            if epoch != 0 {
+            if checkpoint_batch {
                 // Prepare data for checkpointing
 
                 // Clone or extract necessary data
@@ -400,7 +412,7 @@ impl GroupedWindowAggStream {
                 let key = self.channel_tag.as_bytes().to_vec();
 
                 // Clone or use `Arc` for `state_backend`
-                let state_backend = self.state_backend.clone();
+                let state_backend = self.state_backend.clone().unwrap();
 
                 state_backend.put(key, serialized_checkpoint);
             }
diff --git a/crates/core/src/physical_plan/continuous/streaming_window.rs b/crates/core/src/physical_plan/continuous/streaming_window.rs
index e70261b..c6ea378 100644
--- a/crates/core/src/physical_plan/continuous/streaming_window.rs
+++ b/crates/core/src/physical_plan/continuous/streaming_window.rs
@@ -40,16 +40,19 @@ use datafusion::{
 };
 use denormalized_orchestrator::{
     channel_manager::{create_channel, get_sender},
-    orchestrator::{self, OrchestrationMessage},
+    orchestrator::OrchestrationMessage,
 };
 use futures::{Stream, StreamExt};
 use tracing::debug;
 
-use crate::physical_plan::{
-    continuous::grouped_window_agg_stream::GroupedWindowAggStream,
-    utils::{
-        accumulators::{create_accumulators, AccumulatorItem},
-        time::{system_time_from_epoch, RecordBatchWatermark},
+use crate::{
+    config_extensions::denormalized_config::DenormalizedConfig,
+    physical_plan::{
+        continuous::grouped_window_agg_stream::GroupedWindowAggStream,
+        utils::{
+            accumulators::{create_accumulators, AccumulatorItem},
+            time::{system_time_from_epoch, RecordBatchWatermark},
+        },
     },
 };
 
@@ -427,7 +430,15 @@ impl ExecutionPlan for StreamingWindowExec {
             .node_id()
             .expect("expected node id to be set.");
 
-        let channel_tag = if orchestrator::SHOULD_CHECKPOINT {
+        let config_options = context
+            .session_config()
+            .options()
+            .extensions
+            .get::<DenormalizedConfig>();
+
+        let checkpoint = config_options.map_or(false, |c| c.checkpoint);
+
+        let channel_tag = if checkpoint {
             let tag = format!("{}_{}", node_id, partition);
             create_channel(tag.as_str(), 10);
             let orchestrator_sender = get_sender("orchestrator");
diff --git a/crates/orchestrator/src/orchestrator.rs b/crates/orchestrator/src/orchestrator.rs
index 6acfeed..ae9c1f0 100644
--- a/crates/orchestrator/src/orchestrator.rs
+++ b/crates/orchestrator/src/orchestrator.rs
@@ -20,8 +20,6 @@ pub struct Orchestrator {
     senders: HashMap<String, Sender<OrchestrationMessage>>,
 }
 
-pub const SHOULD_CHECKPOINT: bool = false; // THIS WILL BE MOVED INTO CONFIG
-
 /**
  * 1. Keep track of checkpoint per source.
  * 2. Tell each downstream which checkpoints it needs to know.
diff --git a/examples/examples/simple_aggregation.rs b/examples/examples/simple_aggregation.rs
index d785d65..9341529 100644
--- a/examples/examples/simple_aggregation.rs
+++ b/examples/examples/simple_aggregation.rs
@@ -18,6 +18,9 @@ async fn main() -> Result<()> {
         .init();
 
     let bootstrap_servers = String::from("localhost:9092");
+
+    let config = Context::default_config().set_bool("denormalized_config.checkpoint", false);
+
     let mut topic_builder = KafkaTopicBuilder::new(bootstrap_servers);
 
     // Connect to source topic
@@ -32,9 +35,9 @@ async fn main() -> Result<()> {
         ]))
         .await?;
 
-    Context::new()?
-        .with_slatedb_backend(String::from("/tmp/checkpoints/simple-agg-checkpoint-1"))
-        .await
+    let _ctx = Context::with_config(config)?
+        //.with_slatedb_backend(String::from("/tmp/checkpoints/simple-agg-checkpoint-1"))
+        //.await
         .from_topic(source_topic)
         .await?
         .window(
diff --git a/examples/examples/stream_join.rs b/examples/examples/stream_join.rs
index 5c6c672..2edc25e 100644
--- a/examples/examples/stream_join.rs
+++ b/examples/examples/stream_join.rs
@@ -17,10 +17,7 @@ async fn main() -> Result<()> {
 
     let bootstrap_servers = String::from("localhost:9092");
 
-    let ctx = Context::new()?
-        .with_slatedb_backend(String::from("/tmp/checkpoints/stream-join-checkpoint-1"))
-        .await;
-
+    let ctx = Context::new()?;
     let mut topic_builder = KafkaTopicBuilder::new(bootstrap_servers.clone());
 
     let source_topic_builder = topic_builder
@@ -73,8 +70,12 @@ async fn main() -> Result<()> {
         .join(
             humidity_ds,
             JoinType::Inner,
-            &["sensor_name", "window_start_time"],
-            &["humidity_sensor", "humidity_window_start_time"],
+            &["sensor_name", "window_start_time", "window_end_time"],
+            &[
+                "humidity_sensor",
+                "humidity_window_start_time",
+                "humidity_window_end_time",
+            ],
             None,
         )?;
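
With this change, checkpointing is no longer gated on the compile-time `SHOULD_CHECKPOINT` constant; it is switched on through a `DenormalizedConfig` extension carried on the `SessionConfig`, and each operator reads the flag back via `session_config().options().extensions.get::<DenormalizedConfig>()`. A minimal sketch of the new job setup follows; the `denormalized::context::Context` import path, the error type alias, the tokio runtime, and the checkpoint directory are assumptions for illustration, while the config key and the `Context` API come from this diff:

```
// Sketch only: import paths and the checkpoint directory are assumed;
// default_config/set_bool/with_config/with_slatedb_backend are from this diff.
use denormalized::context::Context;
use denormalized_common::error::Result;

#[tokio::main]
async fn main() -> Result<()> {
    // Enable checkpointing via the DenormalizedConfig extension.
    let config = Context::default_config().set_bool("denormalized_config.checkpoint", true);

    // Build the context with the custom config and point the SlateDB state
    // backend at a local path (only consulted when the flag above is true).
    let _ctx = Context::with_config(config)?
        .with_slatedb_backend(String::from("/tmp/checkpoints/simple-agg/job1"))
        .await;

    // From here the job is assembled as in the examples:
    // _ctx.from_topic(source_topic).await?.window(...)
    Ok(())
}
```

When the flag is false, the operators skip orchestrator channel registration and never touch the global SlateDB handle, which is why `state_backend` becomes an `Option` in both `KafkaStreamRead` and `GroupedWindowAggStream`.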