From 504687912d5b533f15dd2b6c94024f85f7ca302f Mon Sep 17 00:00:00 2001 From: bsbds <69835502+bsbds@users.noreply.github.com> Date: Tue, 26 Mar 2024 10:22:41 +0800 Subject: [PATCH] refactor: use reference type in DataFrame Signed-off-by: bsbds <69835502+bsbds@users.noreply.github.com> --- crates/curp/src/server/storage/wal/codec.rs | 68 ++++++++++++------- crates/curp/src/server/storage/wal/segment.rs | 24 ++++--- 2 files changed, 60 insertions(+), 32 deletions(-) diff --git a/crates/curp/src/server/storage/wal/codec.rs b/crates/curp/src/server/storage/wal/codec.rs index 41c2a6575..a124497a3 100644 --- a/crates/curp/src/server/storage/wal/codec.rs +++ b/crates/curp/src/server/storage/wal/codec.rs @@ -39,7 +39,7 @@ trait FrameEncoder { #[derive(Debug)] pub(super) struct WAL { /// Frames stored in decoding - frames: Vec>, + frames: Vec>, /// The hasher state for decoding hasher: H, } @@ -48,7 +48,7 @@ pub(super) struct WAL { #[derive(Debug)] enum WALFrame { /// Data frame type - Data(DataFrame), + Data(DataFrameOwned), /// Commit frame type Commit(CommitFrame), } @@ -58,13 +58,25 @@ enum WALFrame { /// Contains either a log entry or a seal index #[derive(Debug, Clone)] #[cfg_attr(test, derive(PartialEq))] -pub(crate) enum DataFrame { +pub(crate) enum DataFrameOwned { /// A Frame containing a log entry Entry(LogEntry), /// A Frame containing the sealed index SealIndex(LogIndex), } +/// The data frame +/// +/// Contains either a log entry or a seal index +#[derive(Debug, Clone)] +#[cfg_attr(test, derive(PartialEq))] +pub(crate) enum DataFrame<'a, C> { + /// A Frame containing a log entry + Entry(&'a LogEntry), + /// A Frame containing the sealed index + SealIndex(LogIndex), +} + /// The commit frame /// /// This frames contains a SHA256 checksum of all previous frames since last commit @@ -98,14 +110,14 @@ impl WAL { } } -impl Encoder>> for WAL +impl Encoder>> for WAL where C: Serialize, { type Error = io::Error; /// Encodes a frame - fn encode(&mut self, frames: Vec>) -> Result, Self::Error> { + fn encode(&mut self, frames: Vec>) -> Result, Self::Error> { let mut frame_data: Vec<_> = frames.into_iter().flat_map(|f| f.encode()).collect(); let commit_frame = CommitFrame::new_from_data(&frame_data); frame_data.extend_from_slice(&commit_frame.encode()); @@ -118,7 +130,7 @@ impl Decoder for WAL where C: Serialize + DeserializeOwned, { - type Item = Vec>; + type Item = Vec>; type Error = WALError; @@ -210,14 +222,14 @@ where let entry: LogEntry = bincode::deserialize(payload) .map_err(|e| WALError::Corrupted(CorruptType::Codec(e.to_string())))?; - Ok(Some((Self::Data(DataFrame::Entry(entry)), 8 + len))) + Ok(Some((Self::Data(DataFrameOwned::Entry(entry)), 8 + len))) } /// Decodes an seal index frame from source fn decode_seal_index(header: [u8; 8]) -> Result, WALError> { let index = Self::decode_u64_from_header(header); - Ok(Some((Self::Data(DataFrame::SealIndex(index)), 8))) + Ok(Some((Self::Data(DataFrameOwned::SealIndex(index)), 8))) } /// Decodes a commit frame from source @@ -241,7 +253,17 @@ where } } -impl FrameType for DataFrame { +impl DataFrameOwned { + /// Converts `DataFrameOwned` to `DataFrame` + pub(super) fn get_ref(&self) -> DataFrame<'_, C> { + match *self { + DataFrameOwned::Entry(ref entry) => DataFrame::Entry(entry), + DataFrameOwned::SealIndex(index) => DataFrame::SealIndex(index), + } + } +} + +impl FrameType for DataFrame<'_, C> { fn frame_type(&self) -> u8 { match *self { DataFrame::Entry(_) => ENTRY, @@ -250,7 +272,7 @@ impl FrameType for DataFrame { } } -impl FrameEncoder for DataFrame +impl FrameEncoder for DataFrame<'_, C> where C: Serialize, { @@ -324,31 +346,31 @@ mod tests { async fn frame_encode_decode_is_ok() { let mut codec = WAL::::new(); let entry = LogEntry::::new(1, 1, ProposeId(1, 2), EntryData::Empty); - let data_frame = DataFrame::Entry(entry.clone()); - let seal_frame = DataFrame::::SealIndex(1); - let mut encoded = codec.encode(vec![data_frame]).unwrap(); - encoded.extend_from_slice(&codec.encode(vec![seal_frame]).unwrap()); + let data_frame = DataFrameOwned::Entry(entry.clone()); + let seal_frame = DataFrameOwned::::SealIndex(1); + let mut encoded = codec.encode(vec![data_frame.get_ref()]).unwrap(); + encoded.extend_from_slice(&codec.encode(vec![seal_frame.get_ref()]).unwrap()); let (data_frame_get, len) = codec.decode(&encoded).unwrap(); let (seal_frame_get, _) = codec.decode(&encoded[len..]).unwrap(); - let DataFrame::Entry(ref entry_get) = data_frame_get[0] else { + let DataFrameOwned::Entry(ref entry_get) = data_frame_get[0] else { panic!("frame should be type: DataFrame::Entry"); }; - let DataFrame::SealIndex(ref index) = seal_frame_get[0] else { + let DataFrameOwned::SealIndex(index) = seal_frame_get[0] else { panic!("frame should be type: DataFrame::Entry"); }; assert_eq!(*entry_get, entry); - assert_eq!(*index, 1); + assert_eq!(index, 1); } #[tokio::test] async fn frame_zero_write_will_be_detected() { let mut codec = WAL::::new(); let entry = LogEntry::::new(1, 1, ProposeId(1, 2), EntryData::Empty); - let data_frame = DataFrame::Entry(entry.clone()); - let seal_frame = DataFrame::::SealIndex(1); - let mut encoded = codec.encode(vec![data_frame]).unwrap(); + let data_frame = DataFrameOwned::Entry(entry.clone()); + let seal_frame = DataFrameOwned::::SealIndex(1); + let mut encoded = codec.encode(vec![data_frame.get_ref()]).unwrap(); encoded[0] = 0; let err = codec.decode(&encoded).unwrap_err(); @@ -362,9 +384,9 @@ mod tests { async fn frame_corrupt_will_be_detected() { let mut codec = WAL::::new(); let entry = LogEntry::::new(1, 1, ProposeId(1, 2), EntryData::Empty); - let data_frame = DataFrame::Entry(entry.clone()); - let seal_frame = DataFrame::::SealIndex(1); - let mut encoded = codec.encode(vec![data_frame]).unwrap(); + let data_frame = DataFrameOwned::Entry(entry.clone()); + let seal_frame = DataFrameOwned::::SealIndex(1); + let mut encoded = codec.encode(vec![data_frame.get_ref()]).unwrap(); encoded[1] = 0; let err = codec.decode(&encoded).unwrap_err(); diff --git a/crates/curp/src/server/storage/wal/segment.rs b/crates/curp/src/server/storage/wal/segment.rs index 6e78f4999..3f9a76781 100644 --- a/crates/curp/src/server/storage/wal/segment.rs +++ b/crates/curp/src/server/storage/wal/segment.rs @@ -17,14 +17,15 @@ use tokio::{ }; use tokio_stream::StreamExt; +use crate::log_entry::LogEntry; + use super::{ - codec::{DataFrame, WAL}, + codec::{DataFrame, DataFrameOwned, WAL}, error::{CorruptType, WALError}, framed::{Decoder, Encoder}, util::{get_checksum, parse_u64, validate_data, LockedFile}, WAL_FILE_EXT, WAL_MAGIC, WAL_VERSION, }; -use crate::log_entry::LogEntry; /// The size of wal file header in bytes const WAL_HEADER_SIZE: usize = 56; @@ -106,7 +107,7 @@ impl WALSegment { let frame = frames .last() .unwrap_or_else(|| unreachable!("a batch should contains at least one frame")); - if let DataFrame::SealIndex(index) = *frame { + if let DataFrameOwned::SealIndex(index) = *frame { highest_index = index; } } @@ -116,7 +117,7 @@ impl WALSegment { // Get log entries that index is no larger than `highest_index` Ok(frame_batches.into_iter().flatten().filter_map(move |f| { - if let DataFrame::Entry(e) = f { + if let DataFrameOwned::Entry(e) = f { (e.index <= highest_index).then_some(e) } else { None @@ -168,7 +169,7 @@ impl WALSegment { /// Updates the seal index pub(super) fn update_seal_index(&mut self, index: LogIndex) { - self.seal_index = self.seal_index.max(index); + self.seal_index = self.seal_index.min(index); } /// Get the size of the segment @@ -289,9 +290,10 @@ mod tests { use curp_test_utils::test_cmd::TestCommand; - use super::*; use crate::log_entry::EntryData; + use super::*; + #[test] fn gen_parse_header_is_correct() { fn corrupt(mut header: Vec, pos: usize) -> Vec { @@ -352,7 +354,7 @@ mod tests { let frames: Vec<_> = (0..100) .map(|i| { - DataFrame::Entry(LogEntry::new( + DataFrameOwned::Entry(LogEntry::new( i, 1, crate::rpc::ProposeId(0, 0), @@ -360,7 +362,11 @@ mod tests { )) }) .collect(); - segment.write_sync(frames.clone(), WAL::new()); + + segment.write_sync( + frames.iter().map(DataFrameOwned::get_ref).collect(), + WAL::new(), + ); drop(segment); @@ -369,7 +375,7 @@ mod tests { let recovered: Vec<_> = segment .recover_segment_logs::() .unwrap() - .map(|e| DataFrame::Entry(e)) + .map(|e| DataFrameOwned::Entry(e)) .collect(); assert_eq!(frames, recovered); }