Adding config option for checkpointing #50

Merged · 10 commits · Nov 7, 2024
7 changes: 4 additions & 3 deletions README.md
@@ -96,11 +96,12 @@ Details about developing the python bindings can be found in [py-denormalized/RE

### Checkpointing

We use SlateDB for state backend. Initialize your Job Context to a path to local directory -
We use SlateDB for the state backend. Initialize your Job Context with a custom config and a path for the SlateDB backend to store state -

```
let ctx = Context::new()?
.with_slatedb_backend(String::from("/tmp/checkpoints/simple-agg-checkpoint-1"))
let config = Context::default_config().set_bool("denormalized_config.checkpoint", true);
let ctx = Context::with_config(config)?
.with_slatedb_backend(String::from("/tmp/checkpoints/simple-agg/job1"))
.await;
```

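Checkpointing stays off by default: `Context::new()` simply delegates to `Context::with_config(Context::default_config())`, so a job that does not opt in needs no changes -

```
// Equivalent ways to build a context with checkpointing left disabled (the default):
let ctx = Context::new()?;
let ctx = Context::with_config(Context::default_config())?;
```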
24 changes: 17 additions & 7 deletions crates/core/src/context.rs
@@ -6,6 +6,7 @@ use datafusion::execution::{
session_state::SessionStateBuilder,
};

use crate::config_extensions::denormalized_config::DenormalizedConfig;
use crate::datasource::kafka::TopicReader;
use crate::datastream::DataStream;
use crate::physical_optimizer::EnsureHashPartititionOnGroupByForStreamingAggregates;
@@ -17,12 +18,13 @@ use denormalized_common::error::{DenormalizedError, Result};

#[derive(Clone)]
pub struct Context {
pub session_conext: Arc<SessionContext>,
pub session_context: Arc<SessionContext>,
}

impl Context {
pub fn new() -> Result<Self, DenormalizedError> {
let config = SessionConfig::new()
pub fn default_config() -> SessionConfig {
let ext_config = DenormalizedConfig::default();
let mut config = SessionConfig::new()
.set(
"datafusion.execution.batch_size",
&datafusion::common::ScalarValue::UInt64(Some(32)),
@@ -34,8 +36,16 @@ impl Context {
&datafusion::common::ScalarValue::Boolean(Some(false)),
);

let runtime = Arc::new(RuntimeEnv::default());
let _ = config.options_mut().extensions.insert(ext_config);
config
}

pub fn new() -> Result<Self, DenormalizedError> {
Context::with_config(Context::default_config())
}

pub fn with_config(config: SessionConfig) -> Result<Self, DenormalizedError> {
let runtime = Arc::new(RuntimeEnv::default());
let state = SessionStateBuilder::new()
.with_default_features()
.with_config(config)
@@ -48,15 +58,15 @@ impl Context {
.build();

Ok(Self {
session_conext: Arc::new(SessionContext::new_with_state(state)),
session_context: Arc::new(SessionContext::new_with_state(state)),
})
}

pub async fn from_topic(&self, topic: TopicReader) -> Result<DataStream, DenormalizedError> {
let topic_name = topic.0.topic.clone();
self.register_table(topic_name.clone(), Arc::new(topic))
.await?;
let df = self.session_conext.table(topic_name.as_str()).await?;
let df = self.session_context.table(topic_name.as_str()).await?;
let ds = DataStream::new(Arc::new(df), Arc::new(self.clone()));
Ok(ds)
}
@@ -66,7 +76,7 @@ impl Context {
name: String,
table: Arc<impl TableProvider + 'static>,
) -> Result<(), DenormalizedError> {
self.session_conext
self.session_context
.register_table(name.as_str(), table.clone())?;

Ok(())
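A round-trip test along these lines could verify that the new option flows through the session config; this is only a sketch, assuming the `DenormalizedConfig` extension registers under the `denormalized_config` prefix and exposes the `checkpoint: bool` field added in this PR:

```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn checkpoint_flag_round_trips_through_session_config() {
        // default_config() inserts the DenormalizedConfig extension, so the
        // namespaced key should route to it.
        let config = Context::default_config()
            .set_bool("denormalized_config.checkpoint", true);

        let enabled = config
            .options()
            .extensions
            .get::<DenormalizedConfig>()
            .map_or(false, |c| c.checkpoint);

        assert!(enabled);
    }
}
```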
27 changes: 16 additions & 11 deletions crates/core/src/datasource/kafka/kafka_stream_read.rs
@@ -7,7 +7,7 @@ use arrow_array::{Array, ArrayRef, PrimitiveArray, RecordBatch, StringArray, Str
use arrow_schema::{DataType, Field, SchemaRef, TimeUnit};
use crossbeam::channel;
use denormalized_orchestrator::channel_manager::{create_channel, get_sender, take_receiver};
use denormalized_orchestrator::orchestrator::{self, OrchestrationMessage};
use denormalized_orchestrator::orchestrator::OrchestrationMessage;
use futures::executor::block_on;
use log::{debug, error};
use serde::{Deserialize, Serialize};
@@ -83,13 +83,13 @@ impl PartitionStream for KafkaStreamRead {
}

fn execute(&self, ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
let _config_options = ctx
let config_options = ctx
.session_config()
.options()
.extensions
.get::<DenormalizedConfig>();

let mut should_checkpoint = false; //config_options.map_or(false, |c| c.checkpoint);
let should_checkpoint = config_options.map_or(false, |c| c.checkpoint);

let node_id = self.exec_node_id.unwrap();
let partition_tag = self
@@ -101,13 +101,16 @@

let channel_tag = format!("{}_{}", node_id, partition_tag);
let mut serialized_state: Option<Vec<u8>> = None;
let state_backend = get_global_slatedb().unwrap();
let mut state_backend = None;

let mut starting_offsets: HashMap<i32, i64> = HashMap::new();
if orchestrator::SHOULD_CHECKPOINT {

if should_checkpoint {
create_channel(channel_tag.as_str(), 10);
let backend = get_global_slatedb().unwrap();
debug!("checking for last checkpointed offsets");
serialized_state = block_on(state_backend.clone().get(channel_tag.as_bytes().to_vec()));
serialized_state = block_on(backend.get(channel_tag.as_bytes().to_vec()));
state_backend = Some(backend);
}

if let Some(serialized_state) = serialized_state {
@@ -151,25 +154,26 @@
builder.spawn(async move {
let mut epoch = 0;
let mut receiver: Option<channel::Receiver<OrchestrationMessage>> = None;
if orchestrator::SHOULD_CHECKPOINT {
if should_checkpoint {
let orchestrator_sender = get_sender("orchestrator");
let msg: OrchestrationMessage =
OrchestrationMessage::RegisterStream(channel_tag.clone());
orchestrator_sender.as_ref().unwrap().send(msg).unwrap();
receiver = take_receiver(channel_tag.as_str());
}
let mut checkpoint_batch = false;

loop {
//let mut checkpoint_barrier: Option<String> = None;
let mut _checkpoint_barrier: Option<i64> = None;

if orchestrator::SHOULD_CHECKPOINT {
if should_checkpoint {
let r = receiver.as_ref().unwrap();
for message in r.try_iter() {
debug!("received checkpoint barrier for {:?}", message);
if let OrchestrationMessage::CheckpointBarrier(epoch_ts) = message {
epoch = epoch_ts;
should_checkpoint = true;
checkpoint_batch = true;
}
}
}
@@ -245,7 +249,7 @@
let tx_result = tx.send(Ok(timestamped_record_batch)).await;
match tx_result {
Ok(_) => {
if should_checkpoint {
if checkpoint_batch {
debug!("about to checkpoint offsets");
let off = BatchReadMetadata {
epoch,
Expand All @@ -255,9 +259,10 @@ impl PartitionStream for KafkaStreamRead {
};
let _ = state_backend
.as_ref()
.unwrap()
.put(channel_tag.as_bytes().to_vec(), off.to_bytes().unwrap());
debug!("checkpointed offsets {:?}", off);
should_checkpoint = false;
checkpoint_batch = false;
}
}
Err(err) => error!("result err {:?}. shutdown signal detected.", err),
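The reader's checkpoint path boils down to barrier-driven gating: drain any pending orchestrator messages before emitting a batch, mark the batch if a barrier arrived, and persist offsets only after the batch is sent. A self-contained sketch of that control flow (the message type and epochs are simplified stand-ins, not the crate's API):

```rust
use crossbeam::channel;

// Simplified stand-in for the orchestrator message type used above.
#[derive(Debug)]
enum OrchestrationMessage {
    CheckpointBarrier(u128),
}

fn main() {
    let (tx, rx) = channel::unbounded::<OrchestrationMessage>();
    let mut checkpoint_batch = false;

    // The orchestrator injects a barrier somewhere between batches.
    tx.send(OrchestrationMessage::CheckpointBarrier(42)).unwrap();

    for batch_id in 0..2 {
        // Drain pending barriers; a barrier marks the next emitted batch for checkpointing.
        for message in rx.try_iter() {
            if let OrchestrationMessage::CheckpointBarrier(epoch) = message {
                println!("received barrier for epoch {epoch}");
                checkpoint_batch = true;
            }
        }

        // ... read from Kafka and send the batch downstream here ...

        if checkpoint_batch {
            // In the real operator the partition offsets are serialized and written
            // to the SlateDB state backend at this point, then the flag resets.
            println!("checkpointing offsets after batch {batch_id}");
            checkpoint_batch = false;
        }
    }
}
```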
19 changes: 13 additions & 6 deletions crates/core/src/datastream.rs
@@ -1,7 +1,6 @@
use datafusion::common::runtime::SpawnedTask;
use datafusion::logical_expr::LogicalPlan;
use datafusion::physical_plan::ExecutionPlanProperties;
use denormalized_orchestrator::orchestrator;
use futures::StreamExt;
use log::debug;
use log::info;
@@ -18,6 +17,7 @@ use datafusion::logical_expr::{
};
use datafusion::physical_plan::display::DisplayableExecutionPlan;

use crate::config_extensions::denormalized_config::DenormalizedConfig;
use crate::context::Context;
use crate::datasource::kafka::{ConnectionOpts, KafkaTopicBuilder};
use crate::logical_plan::StreamingLogicalPlanBuilder;
@@ -231,7 +231,12 @@ impl DataStream {

let mut maybe_orchestrator_handle = None;

if orchestrator::SHOULD_CHECKPOINT {
let config = self.context.session_context.copied_config();
let config_options = config.options().extensions.get::<DenormalizedConfig>();

let should_checkpoint = config_options.map_or(false, |c| c.checkpoint);

if should_checkpoint {
let mut orchestrator = Orchestrator::default();
let cloned_shutdown_rx = self.shutdown_rx.clone();
let orchestrator_handle =
@@ -277,10 +282,12 @@

log::info!("Stream processing stopped. Cleaning up...");

let state_backend = get_global_slatedb();
if let Ok(db) = state_backend {
log::info!("Closing the state backend (slatedb)...");
db.close().await.unwrap();
if should_checkpoint {
let state_backend = get_global_slatedb();
if let Ok(db) = state_backend {
log::info!("Closing the state backend (slatedb)...");
db.close().await.unwrap();
}
}

// Join the orchestrator handle if it exists, ensuring it is joined and awaited
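For orientation, the job-level check added above can be factored into a small helper; this is a sketch only (the helper name is illustrative), relying on `SessionContext::copied_config()` and the `DenormalizedConfig` extension shown in this PR:

```rust
use datafusion::prelude::SessionContext;

use crate::config_extensions::denormalized_config::DenormalizedConfig;

/// Returns true only when `denormalized_config.checkpoint` was set on the
/// session; a missing extension means checkpointing stays off.
fn job_should_checkpoint(ctx: &SessionContext) -> bool {
    ctx.copied_config()
        .options()
        .extensions
        .get::<DenormalizedConfig>()
        .map_or(false, |c| c.checkpoint)
}
```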
crates/core/src/physical_plan/continuous/grouped_window_agg_stream.rs
@@ -38,13 +38,14 @@ use datafusion::{

use denormalized_orchestrator::{
channel_manager::take_receiver,
orchestrator::{self, OrchestrationMessage},
orchestrator::OrchestrationMessage,
};
use futures::{executor::block_on, Stream, StreamExt};
use log::debug;
use serde::{Deserialize, Serialize};

use crate::{
config_extensions::denormalized_config::DenormalizedConfig,
physical_plan::utils::time::RecordBatchWatermark,
state_backend::slatedb::{get_global_slatedb, SlateDBWrapper},
utils::serialization::ArrayContainer,
@@ -73,11 +74,11 @@ pub struct GroupedWindowAggStream {
group_by: PhysicalGroupBy,
group_schema: Arc<Schema>,
context: Arc<TaskContext>,
epoch: i64,
checkpoint: bool,
partition: usize,
channel_tag: String,
receiver: Option<Receiver<OrchestrationMessage>>,
state_backend: Arc<SlateDBWrapper>,
state_backend: Option<Arc<SlateDBWrapper>>,
}

#[derive(Serialize, Deserialize)]
@@ -147,11 +148,23 @@ impl GroupedWindowAggStream {
.and_then(|tag| take_receiver(tag.as_str()));

let channel_tag: String = channel_tag.unwrap_or(String::from(""));
let state_backend = get_global_slatedb().unwrap();

let serialized_state = block_on(state_backend.get(channel_tag.as_bytes().to_vec()));
let config_options = context
.session_config()
.options()
.extensions
.get::<DenormalizedConfig>();

let checkpoint = config_options.map_or(false, |c| c.checkpoint);

let mut serialized_state: Option<Vec<u8>> = None;
let mut state_backend = None;
if checkpoint {
let backend = get_global_slatedb().unwrap();
serialized_state = block_on(backend.get(channel_tag.as_bytes().to_vec()));
state_backend = Some(backend);
}

//let window_frames: BTreeMap<SystemTime, GroupedAggWindowFrame> = BTreeMap::new();
let mut stream = Self {
schema: agg_schema,
input,
@@ -166,7 +179,7 @@
group_by,
group_schema,
context,
epoch: 0,
checkpoint,
partition,
channel_tag: channel_tag,
receiver,
@@ -340,19 +353,19 @@ impl GroupedWindowAggStream {
return Poll::Pending;
}
};
self.epoch += 1;

if orchestrator::SHOULD_CHECKPOINT {
let mut checkpoint_batch = false;

if self.checkpoint {
let r = self.receiver.as_ref().unwrap();
let mut epoch: u128 = 0;
for message in r.try_iter() {
debug!("received checkpoint barrier for {:?}", message);
if let OrchestrationMessage::CheckpointBarrier(epoch_ts) = message {
epoch = epoch_ts;
if let OrchestrationMessage::CheckpointBarrier(_epoch_ts) = message {
checkpoint_batch = true;
}
}

if epoch != 0 {
if checkpoint_batch {
// Prepare data for checkpointing

// Clone or extract necessary data
@@ -400,7 +413,7 @@ impl GroupedWindowAggStream {
let key = self.channel_tag.as_bytes().to_vec();

// Clone or use `Arc` for `state_backend`
let state_backend = self.state_backend.clone();
let state_backend = self.state_backend.clone().unwrap();

state_backend.put(key, serialized_checkpoint);
}
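The window stream now holds the SlateDB handle as an `Option`, fetched only when checkpointing is enabled and unwrapped on the checkpoint path. A minimal sketch of that pattern with placeholder types (not the crate's real `SlateDBWrapper` API):

```rust
use std::sync::Arc;

// Placeholder for the real SlateDBWrapper.
struct StateBackend;

impl StateBackend {
    fn put(&self, _key: Vec<u8>, _value: Vec<u8>) {
        // the real backend persists the serialized window state here
    }
}

fn main() {
    let checkpoint = true; // would come from DenormalizedConfig

    // Only materialize the backend when checkpointing is on ...
    let state_backend: Option<Arc<StateBackend>> = if checkpoint {
        Some(Arc::new(StateBackend))
    } else {
        None
    };

    // ... so the checkpoint path is the only place that unwraps it.
    if checkpoint {
        let backend = state_backend.clone().unwrap();
        backend.put(b"channel_tag".to_vec(), Vec::new());
    }
}
```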
25 changes: 18 additions & 7 deletions crates/core/src/physical_plan/continuous/streaming_window.rs
@@ -40,16 +40,19 @@ use datafusion::{
};
use denormalized_orchestrator::{
channel_manager::{create_channel, get_sender},
orchestrator::{self, OrchestrationMessage},
orchestrator::OrchestrationMessage,
};
use futures::{Stream, StreamExt};
use tracing::debug;

use crate::physical_plan::{
continuous::grouped_window_agg_stream::GroupedWindowAggStream,
utils::{
accumulators::{create_accumulators, AccumulatorItem},
time::{system_time_from_epoch, RecordBatchWatermark},
use crate::{
config_extensions::denormalized_config::DenormalizedConfig,
physical_plan::{
continuous::grouped_window_agg_stream::GroupedWindowAggStream,
utils::{
accumulators::{create_accumulators, AccumulatorItem},
time::{system_time_from_epoch, RecordBatchWatermark},
},
},
};

@@ -427,7 +430,15 @@ impl ExecutionPlan for StreamingWindowExec {
.node_id()
.expect("expected node id to be set.");

let channel_tag = if orchestrator::SHOULD_CHECKPOINT {
let config_options = context
.session_config()
.options()
.extensions
.get::<DenormalizedConfig>();

let checkpoint = config_options.map_or(false, |c| c.checkpoint);

let channel_tag = if checkpoint {
let tag = format!("{}_{}", node_id, partition);
create_channel(tag.as_str(), 10);
let orchestrator_sender = get_sender("orchestrator");
2 changes: 0 additions & 2 deletions crates/orchestrator/src/orchestrator.rs
@@ -20,8 +20,6 @@ pub struct Orchestrator {
senders: HashMap<String, channel::Sender<OrchestrationMessage>>,
}

pub const SHOULD_CHECKPOINT: bool = false; // THIS WILL BE MOVED INTO CONFIG

/**
* 1. Keep track of checkpoint per source.
* 2. Tell each downstream which checkpoints it needs to know.
5 changes: 5 additions & 0 deletions examples/examples/simple_aggregation.rs
@@ -18,6 +18,11 @@ async fn main() -> Result<()> {
.init();

let bootstrap_servers = String::from("localhost:9092");
let config = Context::default_config().set_bool("denormalized_config.checkpoint", true);
let ctx = Context::with_config(config)?
.with_slatedb_backend(String::from("/tmp/checkpoints/simple-agg/job1"))
.await;

let mut topic_builder = KafkaTopicBuilder::new(bootstrap_servers);

// Connect to source topic