diff --git a/crates/curp/src/server/storage/wal/codec.rs b/crates/curp/src/server/storage/wal/codec.rs
index b891ccbd1..958f6af21 100644
--- a/crates/curp/src/server/storage/wal/codec.rs
+++ b/crates/curp/src/server/storage/wal/codec.rs
@@ -39,7 +39,7 @@ trait FrameEncoder {
 #[derive(Debug)]
 pub(super) struct WAL<C, H = Sha256> {
     /// Frames stored in decoding
-    frames: Vec<DataFrame<C>>,
+    frames: Vec<DataFrameOwned<C>>,
     /// The hasher state for decoding
     hasher: H,
 }
@@ -48,7 +48,7 @@ pub(super) struct WAL<C, H = Sha256> {
 #[derive(Debug)]
 enum WALFrame<C> {
     /// Data frame type
-    Data(DataFrame<C>),
+    Data(DataFrameOwned<C>),
     /// Commit frame type
     Commit(CommitFrame),
 }
@@ -58,13 +58,25 @@ enum WALFrame<C> {
 /// Contains either a log entry or a seal index
 #[derive(Debug, Clone)]
 #[cfg_attr(test, derive(PartialEq))]
-pub(crate) enum DataFrame<C> {
+pub(crate) enum DataFrameOwned<C> {
     /// A Frame containing a log entry
     Entry(LogEntry<C>),
     /// A Frame containing the sealed index
     SealIndex(LogIndex),
 }
 
+/// The data frame
+///
+/// Contains either a log entry or a seal index
+#[derive(Debug, Clone)]
+#[cfg_attr(test, derive(PartialEq))]
+pub(crate) enum DataFrame<'a, C> {
+    /// A Frame containing a log entry
+    Entry(&'a LogEntry<C>),
+    /// A Frame containing the sealed index
+    SealIndex(LogIndex),
+}
+
 /// The commit frame
 ///
 /// This frames contains a SHA256 checksum of all previous frames since last commit
@@ -98,14 +110,14 @@ impl<C> WAL<C> {
     }
 }
 
-impl<C> Encoder<Vec<DataFrame<C>>> for WAL<C>
+impl<C> Encoder<Vec<DataFrame<'_, C>>> for WAL<C>
 where
     C: Serialize,
 {
     type Error = io::Error;
 
     /// Encodes a frame
-    fn encode(&mut self, frames: Vec<DataFrame<C>>) -> Result<Vec<u8>, Self::Error> {
+    fn encode(&mut self, frames: Vec<DataFrame<'_, C>>) -> Result<Vec<u8>, Self::Error> {
         let mut frame_data: Vec<_> = frames.into_iter().flat_map(|f| f.encode()).collect();
         let commit_frame = CommitFrame::new_from_data(&frame_data);
         frame_data.extend_from_slice(&commit_frame.encode());
@@ -118,7 +130,7 @@ impl<C> Decoder for WAL<C>
 where
     C: Serialize + DeserializeOwned,
 {
-    type Item = Vec<DataFrame<C>>;
+    type Item = Vec<DataFrameOwned<C>>;
 
     type Error = WALError;
 
@@ -208,14 +220,14 @@ where
         let entry: LogEntry<C> = bincode::deserialize(payload)
             .map_err(|e| WALError::Corrupted(CorruptType::Codec(e.to_string())))?;
 
-        Ok(Some((Self::Data(DataFrame::Entry(entry)), 8 + len)))
+        Ok(Some((Self::Data(DataFrameOwned::Entry(entry)), 8 + len)))
     }
 
     /// Decodes an seal index frame from source
     fn decode_seal_index(header: [u8; 8]) -> Result<Option<(Self, usize)>, WALError> {
         let index = Self::decode_u64_from_header(header);
 
-        Ok(Some((Self::Data(DataFrame::SealIndex(index)), 8)))
+        Ok(Some((Self::Data(DataFrameOwned::SealIndex(index)), 8)))
     }
 
     /// Decodes a commit frame from source
@@ -239,7 +251,17 @@ where
     }
 }
 
-impl<C> FrameType for DataFrame<C> {
+impl<C> DataFrameOwned<C> {
+    /// Converts `DataFrameOwned` to `DataFrame`
+    pub(super) fn get_ref(&self) -> DataFrame<'_, C> {
+        match *self {
+            DataFrameOwned::Entry(ref entry) => DataFrame::Entry(entry),
+            DataFrameOwned::SealIndex(index) => DataFrame::SealIndex(index),
+        }
+    }
+}
+
+impl<C> FrameType for DataFrame<'_, C> {
     fn frame_type(&self) -> u8 {
         match *self {
             DataFrame::Entry(_) => ENTRY,
@@ -248,7 +270,7 @@ where
     }
 }
 
-impl<C> FrameEncoder for DataFrame<C>
+impl<C> FrameEncoder for DataFrame<'_, C>
 where
     C: Serialize,
 {
@@ -322,31 +344,31 @@ mod tests {
     async fn frame_encode_decode_is_ok() {
         let mut codec = WAL::<TestCommand>::new();
         let entry = LogEntry::<TestCommand>::new(1, 1, ProposeId(1, 2), EntryData::Empty);
-        let data_frame = DataFrame::Entry(entry.clone());
-        let seal_frame = DataFrame::<TestCommand>::SealIndex(1);
-        let mut encoded = codec.encode(vec![data_frame]).unwrap();
-        encoded.extend_from_slice(&codec.encode(vec![seal_frame]).unwrap());
+        let data_frame = DataFrameOwned::Entry(entry.clone());
+        let seal_frame = DataFrameOwned::<TestCommand>::SealIndex(1);
+        let mut encoded = codec.encode(vec![data_frame.get_ref()]).unwrap();
+        encoded.extend_from_slice(&codec.encode(vec![seal_frame.get_ref()]).unwrap());
 
         let (data_frame_get, len) = codec.decode(&encoded).unwrap();
         let (seal_frame_get, _) = codec.decode(&encoded[len..]).unwrap();
-        let DataFrame::Entry(ref entry_get) = data_frame_get[0] else {
+        let DataFrameOwned::Entry(ref entry_get) = data_frame_get[0] else {
             panic!("frame should be type: DataFrame::Entry");
         };
-        let DataFrame::SealIndex(ref index) = seal_frame_get[0] else {
+        let DataFrameOwned::SealIndex(index) = seal_frame_get[0] else {
             panic!("frame should be type: DataFrame::Entry");
         };
 
         assert_eq!(*entry_get, entry);
-        assert_eq!(*index, 1);
+        assert_eq!(index, 1);
     }
 
     #[tokio::test]
     async fn frame_zero_write_will_be_detected() {
         let mut codec = WAL::<TestCommand>::new();
         let entry = LogEntry::<TestCommand>::new(1, 1, ProposeId(1, 2), EntryData::Empty);
-        let data_frame = DataFrame::Entry(entry.clone());
-        let seal_frame = DataFrame::<TestCommand>::SealIndex(1);
-        let mut encoded = codec.encode(vec![data_frame]).unwrap();
+        let data_frame = DataFrameOwned::Entry(entry.clone());
+        let seal_frame = DataFrameOwned::<TestCommand>::SealIndex(1);
+        let mut encoded = codec.encode(vec![data_frame.get_ref()]).unwrap();
         encoded[0] = 0;
 
         let err = codec.decode(&encoded).unwrap_err();
@@ -357,9 +379,9 @@ mod tests {
     async fn frame_corrupt_will_be_detected() {
         let mut codec = WAL::<TestCommand>::new();
         let entry = LogEntry::<TestCommand>::new(1, 1, ProposeId(1, 2), EntryData::Empty);
-        let data_frame = DataFrame::Entry(entry.clone());
-        let seal_frame = DataFrame::<TestCommand>::SealIndex(1);
-        let mut encoded = codec.encode(vec![data_frame]).unwrap();
+        let data_frame = DataFrameOwned::Entry(entry.clone());
+        let seal_frame = DataFrameOwned::<TestCommand>::SealIndex(1);
+        let mut encoded = codec.encode(vec![data_frame.get_ref()]).unwrap();
         encoded[1] = 0;
 
         let err = codec.decode(&encoded).unwrap_err();
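Note on the codec.rs change above: the point of splitting `DataFrame` into an owned and a borrowed variant is that the encode path no longer clones every log entry. Callers keep `DataFrameOwned` values and hand the codec a temporary borrowed view built with `get_ref()`; only the decode path, which has nothing to borrow from, materializes owned frames. A minimal self-contained sketch of the pattern (illustrative only, not Xline code; `FrameOwned`, `Frame`, and the `String` payload are stand-ins):

```rust
/// Owned frame, as produced by a decoder.
#[derive(Debug, Clone, PartialEq)]
enum FrameOwned {
    Entry(String), // stands in for `LogEntry<C>`
    SealIndex(u64),
}

/// Borrowed frame, as consumed by an encoder.
#[derive(Debug, Clone)]
enum Frame<'a> {
    Entry(&'a String),
    SealIndex(u64),
}

impl FrameOwned {
    /// Mirrors `DataFrameOwned::get_ref`: a cheap owned-to-borrowed view.
    fn get_ref(&self) -> Frame<'_> {
        match *self {
            FrameOwned::Entry(ref entry) => Frame::Entry(entry),
            FrameOwned::SealIndex(index) => Frame::SealIndex(index),
        }
    }
}

/// The encoder needs only a borrowed view of the payload.
fn encode(frames: Vec<Frame<'_>>) -> Vec<u8> {
    frames
        .iter()
        .flat_map(|f| match f {
            Frame::Entry(entry) => entry.as_bytes().to_vec(),
            Frame::SealIndex(index) => index.to_le_bytes().to_vec(),
        })
        .collect()
}

fn main() {
    let owned = vec![
        FrameOwned::Entry("entry-1".to_owned()),
        FrameOwned::SealIndex(1),
    ];
    // As at the updated `write_sync` call sites: borrow instead of cloning.
    let bytes = encode(owned.iter().map(FrameOwned::get_ref).collect());
    println!("encoded {} bytes from {:?}", bytes.len(), owned);
}
```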
diff --git a/crates/curp/src/server/storage/wal/pipeline.rs b/crates/curp/src/server/storage/wal/pipeline.rs
index deadd5418..7be1d4540 100644
--- a/crates/curp/src/server/storage/wal/pipeline.rs
+++ b/crates/curp/src/server/storage/wal/pipeline.rs
@@ -6,6 +6,7 @@ use std::{
         Arc,
     },
     task::Poll,
+    thread::JoinHandle,
 };
 
 use clippy_utilities::OverflowArithmetic;
@@ -28,15 +29,19 @@ pub(super) struct FilePipeline {
     ///
     /// As tokio::fs is generally slower than std::fs, we use synchronous file allocation.
     /// Please also refer to the issue discussed on the tokio repo: https://github.com/tokio-rs/tokio/issues/3664
-    file_iter: flume::IntoIter<LockedFile>,
+    file_iter: Option<flume::IntoIter<LockedFile>>,
     /// Stopped flag
     stopped: Arc<AtomicBool>,
+    /// Join handle of the allocation task
+    file_alloc_task_handle: Option<JoinHandle<()>>,
 }
 
 impl FilePipeline {
     /// Creates a new `FilePipeline`
-    pub(super) fn new(dir: PathBuf, file_size: u64) -> io::Result<Self> {
-        Self::clean_up(&dir)?;
+    pub(super) fn new(dir: PathBuf, file_size: u64) -> Self {
+        if let Err(e) = Self::clean_up(&dir) {
+            error!("Failed to clean up tmp files: {e}");
+        }
 
         let (file_tx, file_rx) = flume::bounded(1);
         let dir_c = dir.clone();
@@ -44,61 +49,74 @@ impl FilePipeline {
         let stopped_c = Arc::clone(&stopped);
 
         #[cfg(not(madsim))]
-        let _ignore = std::thread::spawn(move || {
-            let mut file_count = 0;
-            loop {
-                match Self::alloc(&dir_c, file_size, &mut file_count) {
-                    Ok(file) => {
-                        if file_tx.send(file).is_err() {
-                            // The receiver is already dropped, stop this task
-                            break;
-                        }
-                        if stopped_c.load(Ordering::Relaxed) {
-                            if let Err(e) = Self::clean_up(&dir_c) {
-                                error!("failed to clean up pipeline temp files: {e}");
-                            }
-                            break;
-                        }
-                    }
-                    Err(e) => {
-                        error!("failed to allocate file: {e}");
-                        break;
-                    }
-                }
-            }
-        });
+        {
+            let file_alloc_task_handle = std::thread::spawn(move || {
+                let mut file_count = 0;
+                loop {
+                    match Self::alloc(&dir_c, file_size, &mut file_count) {
+                        Ok(file) => {
+                            if file_tx.send(file).is_err() {
+                                // The receiver is already dropped, stop this task
+                                break;
+                            }
+                            if stopped_c.load(Ordering::Relaxed) {
+                                if let Err(e) = Self::clean_up(&dir_c) {
+                                    error!("failed to clean up pipeline temp files: {e}");
+                                }
+                                break;
+                            }
+                        }
+                        Err(e) => {
+                            error!("failed to allocate file: {e}");
+                            break;
+                        }
+                    }
+                }
+            });
+
+            Self {
+                dir,
+                file_size,
+                file_iter: Some(file_rx.into_iter()),
+                stopped,
+                file_alloc_task_handle: Some(file_alloc_task_handle),
+            }
+        }
 
         #[cfg(madsim)]
-        let _ignore = tokio::spawn(async move {
-            let mut file_count = 0;
-            loop {
-                match Self::alloc(&dir_c, file_size, &mut file_count) {
-                    Ok(file) => {
-                        if file_tx.send_async(file).await.is_err() {
-                            // The receiver is already dropped, stop this task
-                            break;
-                        }
-                        if stopped_c.load(Ordering::Relaxed) {
-                            if let Err(e) = Self::clean_up(&dir_c) {
-                                error!("failed to clean up pipeline temp files: {e}");
-                            }
-                            break;
-                        }
-                    }
-                    Err(e) => {
-                        error!("failed to allocate file: {e}");
-                        break;
-                    }
-                }
-            }
-        });
-
-        Ok(Self {
-            dir,
-            file_size,
-            file_iter: file_rx.into_iter(),
-            stopped,
-        })
+        {
+            let _ignore = tokio::spawn(async move {
+                let mut file_count = 0;
+                loop {
+                    match Self::alloc(&dir_c, file_size, &mut file_count) {
+                        Ok(file) => {
+                            if file_tx.send_async(file).await.is_err() {
+                                // The receiver is already dropped, stop this task
+                                break;
+                            }
+                            if stopped_c.load(Ordering::Relaxed) {
+                                if let Err(e) = Self::clean_up(&dir_c) {
+                                    error!("failed to clean up pipeline temp files: {e}");
+                                }
+                                break;
+                            }
+                        }
+                        Err(e) => {
+                            error!("failed to allocate file: {e}");
+                            break;
+                        }
+                    }
+                }
+            });
+
+            Self {
+                dir,
+                file_size,
+                file_iter: Some(file_rx.into_iter()),
+                stopped,
+                file_alloc_task_handle: None,
+            }
+        }
     }
 
     /// Stops the pipeline
@@ -132,6 +150,11 @@ impl Drop for FilePipeline {
     fn drop(&mut self) {
         self.stop();
+        // Drops the file rx so that the allocation task could exit
+        drop(self.file_iter.take());
+        if let Some(Err(e)) = self.file_alloc_task_handle.take().map(JoinHandle::join) {
+            error!("failed to join file allocation task: {e:?}");
+        }
     }
 }
 
@@ -142,7 +165,11 @@ impl Iterator for FilePipeline {
         if self.stopped.load(Ordering::Relaxed) {
             return None;
         }
-        self.file_iter.next().map(Ok)
+        self.file_iter
+            .as_mut()
+            .unwrap_or_else(|| unreachable!("Option is always `Some`"))
+            .next()
+            .map(Ok)
     }
 }
 
@@ -164,7 +191,7 @@ mod tests {
     async fn file_pipeline_is_ok() {
        let file_size = 1024;
         let dir = tempfile::tempdir().unwrap();
-        let mut pipeline = FilePipeline::new(dir.as_ref().into(), file_size).unwrap();
+        let mut pipeline = FilePipeline::new(dir.as_ref().into(), file_size);
 
         let check_size = |mut file: LockedFile| {
             let file = file.into_std();
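Note on the `Drop` change above: the ordering matters. Dropping the receiver iterator first makes the allocator thread's blocking `send` fail, so the subsequent `join` cannot deadlock; joining before dropping the receiver could hang forever on a full channel. A self-contained sketch of that shutdown ordering (illustrative only, not Xline code; `std::sync::mpsc` stands in for `flume`, and integers stand in for `LockedFile`s):

```rust
use std::sync::mpsc;
use std::thread::{self, JoinHandle};

struct Pipeline {
    /// Receiver end; wrapped in `Option` so `drop` can take it out first.
    rx: Option<mpsc::Receiver<u64>>,
    /// Handle of the background allocation thread.
    handle: Option<JoinHandle<()>>,
}

impl Pipeline {
    fn new() -> Self {
        let (tx, rx) = mpsc::sync_channel(1);
        let handle = thread::spawn(move || {
            let mut count = 0;
            // Allocation loop: exits once the receiver is gone.
            while tx.send(count).is_ok() {
                count += 1;
            }
        });
        Self {
            rx: Some(rx),
            handle: Some(handle),
        }
    }
}

impl Drop for Pipeline {
    fn drop(&mut self) {
        // Order matters: dropping the receiver unblocks the sender...
        drop(self.rx.take());
        // ...so this join cannot deadlock.
        if let Some(Err(e)) = self.handle.take().map(JoinHandle::join) {
            eprintln!("failed to join allocation task: {e:?}");
        }
    }
}

fn main() {
    let pipeline = Pipeline::new();
    let first = pipeline.rx.as_ref().and_then(|rx| rx.recv().ok());
    println!("first allocated id: {first:?}");
} // `pipeline` dropped here; the thread exits and is joined
```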
diff --git a/crates/curp/src/server/storage/wal/segment.rs b/crates/curp/src/server/storage/wal/segment.rs
index 217d1f066..96166c7e4 100644
--- a/crates/curp/src/server/storage/wal/segment.rs
+++ b/crates/curp/src/server/storage/wal/segment.rs
@@ -17,14 +17,15 @@ use tokio::{
 };
 use tokio_stream::StreamExt;
 
+use crate::log_entry::LogEntry;
+
 use super::{
-    codec::{DataFrame, WAL},
+    codec::{DataFrame, DataFrameOwned, WAL},
     error::{CorruptType, WALError},
     framed::{Decoder, Encoder},
     util::{get_checksum, parse_u64, validate_data, LockedFile},
     WAL_FILE_EXT, WAL_MAGIC, WAL_VERSION,
 };
-use crate::log_entry::LogEntry;
 
 /// The size of wal file header in bytes
 const WAL_HEADER_SIZE: usize = 56;
@@ -106,7 +107,7 @@ impl WALSegment {
             let frame = frames
                 .last()
                 .unwrap_or_else(|| unreachable!("a batch should contains at least one frame"));
-            if let DataFrame::SealIndex(index) = *frame {
+            if let DataFrameOwned::SealIndex(index) = *frame {
                 highest_index = index;
             }
         }
@@ -116,7 +117,7 @@ impl WALSegment {
 
         // Get log entries that index is no larger than `highest_index`
         Ok(frame_batches.into_iter().flatten().filter_map(move |f| {
-            if let DataFrame::Entry(e) = f {
+            if let DataFrameOwned::Entry(e) = f {
                 (e.index <= highest_index).then_some(e)
             } else {
                 None
@@ -185,7 +186,7 @@ impl WALSegment {
 
     /// Updates the seal index
     pub(super) fn update_seal_index(&mut self, index: LogIndex) {
-        self.seal_index = self.seal_index.max(index);
+        self.seal_index = index;
     }
 
     /// Get the size of the segment
@@ -306,9 +307,10 @@ mod tests {
 
     use curp_test_utils::test_cmd::TestCommand;
 
-    use super::*;
     use crate::log_entry::EntryData;
 
+    use super::*;
+
     #[test]
     fn gen_parse_header_is_correct() {
         fn corrupt(mut header: Vec<u8>, pos: usize) -> Vec<u8> {
@@ -369,7 +371,7 @@ mod tests {
 
         let frames: Vec<_> = (0..100)
             .map(|i| {
-                DataFrame::Entry(LogEntry::new(
+                DataFrameOwned::Entry(LogEntry::new(
                     i,
                     1,
                     crate::rpc::ProposeId(0, 0),
@@ -377,7 +379,11 @@ mod tests {
                 ))
             })
             .collect();
-        segment.write_sync(frames.clone(), WAL::new());
+
+        segment.write_sync(
+            frames.iter().map(DataFrameOwned::get_ref).collect(),
+            WAL::new(),
+        );
 
         drop(segment);
 
@@ -386,7 +392,7 @@ mod tests {
         let recovered: Vec<_> = segment
             .recover_segment_logs::<TestCommand>()
             .unwrap()
-            .map(|e| DataFrame::Entry(e))
+            .map(|e| DataFrameOwned::Entry(e))
             .collect();
         assert_eq!(frames, recovered);
     }
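Note on `recover_segment_logs` above: recovery scans every decoded batch for a trailing seal frame, remembers the highest sealed index, and keeps only entries whose index is at or below it; anything beyond the seal is discarded on recovery. A self-contained sketch of that filter (illustrative only, not Xline code; `FrameOwned` with a bare `u64` index stands in for the real frame and entry types, and the `u64::MAX` default for "no seal frame seen" is an assumption):

```rust
#[derive(Debug, Clone, PartialEq)]
enum FrameOwned {
    Entry(u64),     // stands in for a `LogEntry` carrying its index
    SealIndex(u64), // the highest index sealed so far
}

/// Keeps only entry indexes at or below the highest seal index.
fn recover(frame_batches: Vec<Vec<FrameOwned>>) -> Vec<u64> {
    // Assumption: no seal frame at all means nothing is filtered out.
    let mut highest_index = u64::MAX;
    for frames in &frame_batches {
        // As in the diff: the seal frame, when present, is the last in a batch.
        if let Some(FrameOwned::SealIndex(index)) = frames.last() {
            highest_index = *index;
        }
    }
    frame_batches
        .into_iter()
        .flatten()
        .filter_map(move |frame| {
            if let FrameOwned::Entry(index) = frame {
                (index <= highest_index).then_some(index)
            } else {
                None
            }
        })
        .collect()
}

fn main() {
    let batches = vec![
        vec![FrameOwned::Entry(1), FrameOwned::Entry(2)],
        vec![FrameOwned::Entry(3), FrameOwned::SealIndex(2)],
    ];
    // Entry 3 lies beyond the seal index, so recovery drops it.
    assert_eq!(recover(batches), vec![1, 2]);
    println!("recovery filter ok");
}
```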