diff --git a/crates/curp/src/server/storage/wal/codec.rs b/crates/curp/src/server/storage/wal/codec.rs index 958f6af21..fc93801c3 100644 --- a/crates/curp/src/server/storage/wal/codec.rs +++ b/crates/curp/src/server/storage/wal/codec.rs @@ -21,6 +21,10 @@ const ENTRY: u8 = 0x01; const SEAL: u8 = 0x02; /// Commit frame type const COMMIT: u8 = 0x03; +/// The size in bytes of an frame header +const FRAME_HEADER_SIZE: usize = 8; +/// The size in bytes of an sha256 checksum +const CHECK_SUM_SIZE: usize = 32; /// Getting the frame type trait FrameType { @@ -118,7 +122,10 @@ where /// Encodes a frame fn encode(&mut self, frames: Vec>) -> Result, Self::Error> { - let mut frame_data: Vec<_> = frames.into_iter().flat_map(|f| f.encode()).collect(); + let mut frame_data = Vec::new(); + for frame in frames { + frame_data.extend_from_slice(&frame.encode()); + } let commit_frame = CommitFrame::new_from_data(&frame_data); frame_data.extend_from_slice(&commit_frame.encode()); @@ -192,18 +199,18 @@ where /// | Commit | 0x03 | Stores the checksum | /// |------------+-------+-------------------------------------------------------| fn decode(src: &[u8]) -> Result, WALError> { - if src.len() < 8 { + if src.len() < FRAME_HEADER_SIZE { return Ok(None); } - let header: [u8; 8] = src[0..8] + let header: [u8; FRAME_HEADER_SIZE] = src[..FRAME_HEADER_SIZE] .try_into() .unwrap_or_else(|_| unreachable!("this conversion will always succeed")); let frame_type = header[0]; match frame_type { INVALID => Err(WALError::MaybeEnded), - ENTRY => Self::decode_entry(header, &src[8..]), + ENTRY => Self::decode_entry(header, &src[FRAME_HEADER_SIZE..]), SEAL => Self::decode_seal_index(header), - COMMIT => Self::decode_commit(&src[8..]), + COMMIT => Self::decode_commit(&src[FRAME_HEADER_SIZE..]), _ => Err(WALError::Corrupted(CorruptType::Codec( "Unexpected frame type".to_owned(), ))), @@ -211,7 +218,10 @@ where } /// Decodes an entry frame from source - fn decode_entry(header: [u8; 8], src: &[u8]) -> Result, WALError> { + fn decode_entry( + header: [u8; FRAME_HEADER_SIZE], + src: &[u8], + ) -> Result, WALError> { let len: usize = Self::decode_u64_from_header(header).numeric_cast(); if src.len() < len { return Ok(None); @@ -220,31 +230,42 @@ where let entry: LogEntry = bincode::deserialize(payload) .map_err(|e| WALError::Corrupted(CorruptType::Codec(e.to_string())))?; - Ok(Some((Self::Data(DataFrameOwned::Entry(entry)), 8 + len))) + Ok(Some(( + Self::Data(DataFrameOwned::Entry(entry)), + FRAME_HEADER_SIZE + len, + ))) } /// Decodes an seal index frame from source - fn decode_seal_index(header: [u8; 8]) -> Result, WALError> { + fn decode_seal_index( + header: [u8; FRAME_HEADER_SIZE], + ) -> Result, WALError> { let index = Self::decode_u64_from_header(header); - Ok(Some((Self::Data(DataFrameOwned::SealIndex(index)), 8))) + Ok(Some(( + Self::Data(DataFrameOwned::SealIndex(index)), + FRAME_HEADER_SIZE, + ))) } /// Decodes a commit frame from source fn decode_commit(src: &[u8]) -> Result, WALError> { - if src.len() < 32 { + if src.len() < CHECK_SUM_SIZE { return Ok(None); } - let checksum = src[..32].to_vec(); + let checksum = src[..CHECK_SUM_SIZE].to_vec(); - Ok(Some((Self::Commit(CommitFrame { checksum }), 8 + 32))) + Ok(Some(( + Self::Commit(CommitFrame { checksum }), + FRAME_HEADER_SIZE + CHECK_SUM_SIZE, + ))) } /// Gets a u64 from the header /// /// NOTE: The u64 is encoded using 7 bytes, it can be either a length /// or a log index that is smaller than `2^56` - fn decode_u64_from_header(mut header: [u8; 8]) -> u64 { + fn decode_u64_from_header(mut header: [u8; FRAME_HEADER_SIZE]) -> u64 { header.rotate_left(1); header[7] = 0; u64::from_le_bytes(header) @@ -282,17 +303,18 @@ where .unwrap_or_else(|_| unreachable!("serialization should never fail")); let len = entry_bytes.len(); assert_eq!(len >> 56, 0, "log entry length: {len} too large"); - let len_bytes = len.to_le_bytes().into_iter().take(7); - let header = std::iter::once(self.frame_type()).chain(len_bytes); - header.chain(entry_bytes).collect() + let mut bytes = Vec::with_capacity(FRAME_HEADER_SIZE + entry_bytes.len()); + bytes.push(self.frame_type()); + bytes.extend_from_slice(&len.to_le_bytes()[..7]); + bytes.extend_from_slice(&entry_bytes); + bytes } DataFrame::SealIndex(index) => { assert_eq!(index >> 56, 0, "log index: {index} too large"); - // use the first 7 bytes - let index_bytes = index.to_le_bytes().into_iter().take(7); - std::iter::once(self.frame_type()) - .chain(index_bytes) - .collect() + let mut bytes = index.to_le_bytes(); + bytes.rotate_right(1); + bytes[0] = self.frame_type(); + bytes.to_vec() } } } @@ -319,9 +341,16 @@ impl FrameType for CommitFrame { } impl FrameEncoder for CommitFrame { + #[allow( + clippy::arithmetic_side_effects, // won't overflow + clippy::indexing_slicing // index position is always valid + )] fn encode(&self) -> Vec { - let header = std::iter::once(self.frame_type()).chain([0u8; 7]); - header.chain(self.checksum.clone()).collect() + let mut bytes = Vec::with_capacity(FRAME_HEADER_SIZE + self.checksum.len()); + bytes.extend_from_slice(&[0; FRAME_HEADER_SIZE]); + bytes[0] = self.frame_type(); + bytes.extend_from_slice(&self.checksum); + bytes } } diff --git a/crates/curp/src/server/storage/wal/pipeline.rs b/crates/curp/src/server/storage/wal/pipeline.rs index 7be1d4540..5e8385a37 100644 --- a/crates/curp/src/server/storage/wal/pipeline.rs +++ b/crates/curp/src/server/storage/wal/pipeline.rs @@ -34,6 +34,10 @@ pub(super) struct FilePipeline { stopped: Arc, /// Join handle of the allocation task file_alloc_task_handle: Option>, + // #[cfg_attr(not(madsim), allow(unused))] + #[cfg(madsim)] + /// File count used in madsim tests + file_count: usize, } impl FilePipeline { @@ -43,13 +47,13 @@ impl FilePipeline { error!("Failed to clean up tmp files: {e}"); } - let (file_tx, file_rx) = flume::bounded(1); let dir_c = dir.clone(); let stopped = Arc::new(AtomicBool::new(false)); let stopped_c = Arc::clone(&stopped); #[cfg(not(madsim))] { + let (file_tx, file_rx) = flume::bounded(1); let file_alloc_task_handle = std::thread::spawn(move || { let mut file_count = 0; loop { @@ -85,36 +89,13 @@ impl FilePipeline { #[cfg(madsim)] { - let _ignore = tokio::spawn(async move { - let mut file_count = 0; - loop { - match Self::alloc(&dir_c, file_size, &mut file_count) { - Ok(file) => { - if file_tx.send_async(file).await.is_err() { - // The receiver is already dropped, stop this task - break; - } - if stopped_c.load(Ordering::Relaxed) { - if let Err(e) = Self::clean_up(&dir_c) { - error!("failed to clean up pipeline temp files: {e}"); - } - break; - } - } - Err(e) => { - error!("failed to allocate file: {e}"); - break; - } - } - } - }); - Self { dir, file_size, - file_iter: Some(file_rx.into_iter()), + file_iter: None, stopped, file_alloc_task_handle: None, + file_count: 0, } } } @@ -161,6 +142,7 @@ impl Drop for FilePipeline { impl Iterator for FilePipeline { type Item = io::Result; + #[cfg(not(madsim))] fn next(&mut self) -> Option { if self.stopped.load(Ordering::Relaxed) { return None; @@ -171,6 +153,14 @@ impl Iterator for FilePipeline { .next() .map(Ok) } + + #[cfg(madsim)] + fn next(&mut self) -> Option { + if self.stopped.load(Ordering::Relaxed) { + return None; + } + Some(Self::alloc(&self.dir, self.file_size, &mut self.file_count)) + } } impl std::fmt::Debug for FilePipeline { diff --git a/crates/curp/src/server/storage/wal/segment.rs b/crates/curp/src/server/storage/wal/segment.rs index 24c5deb06..c50ab6573 100644 --- a/crates/curp/src/server/storage/wal/segment.rs +++ b/crates/curp/src/server/storage/wal/segment.rs @@ -96,13 +96,17 @@ impl WALSegment { &mut self, ) -> Result>, WALError> where - C: Serialize + DeserializeOwned + 'static, + C: Serialize + DeserializeOwned + 'static + std::fmt::Debug, { let frame_batches = self.read_all(WAL::::new())?; + let frame_batches_filtered: Vec<_> = frame_batches + .into_iter() + .filter(|b| !b.is_empty()) + .collect(); // The highest_index of this segment let mut highest_index = u64::MAX; // We get the last frame batch to check it's type - if let Some(frames) = frame_batches.last() { + if let Some(frames) = frame_batches_filtered.last() { let frame = frames .last() .unwrap_or_else(|| unreachable!("a batch should contains at least one frame")); @@ -115,13 +119,16 @@ impl WALSegment { self.update_seal_index(highest_index); // Get log entries that index is no larger than `highest_index` - Ok(frame_batches.into_iter().flatten().filter_map(move |f| { - if let DataFrameOwned::Entry(e) = f { - (e.index <= highest_index).then_some(e) - } else { - None - } - })) + Ok(frame_batches_filtered + .into_iter() + .flatten() + .filter_map(move |f| { + if let DataFrameOwned::Entry(e) = f { + (e.index <= highest_index).then_some(e) + } else { + None + } + })) } /// Seal the current segment