Skip to content

Commit

Permalink
refactor: use reference type in DataFrame
Browse files Browse the repository at this point in the history
Signed-off-by: bsbds <[email protected]>
  • Loading branch information
bsbds committed Mar 26, 2024
1 parent 2eac564 commit 5046879
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 32 deletions.
68 changes: 45 additions & 23 deletions crates/curp/src/server/storage/wal/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ trait FrameEncoder {
#[derive(Debug)]
pub(super) struct WAL<C, H = Sha256> {
/// Frames stored in decoding
frames: Vec<DataFrame<C>>,
frames: Vec<DataFrameOwned<C>>,
/// The hasher state for decoding
hasher: H,
}
Expand All @@ -48,7 +48,7 @@ pub(super) struct WAL<C, H = Sha256> {
#[derive(Debug)]
enum WALFrame<C> {
/// Data frame type
Data(DataFrame<C>),
Data(DataFrameOwned<C>),
/// Commit frame type
Commit(CommitFrame),
}
Expand All @@ -58,13 +58,25 @@ enum WALFrame<C> {
/// Contains either a log entry or a seal index
#[derive(Debug, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub(crate) enum DataFrame<C> {
pub(crate) enum DataFrameOwned<C> {
/// A Frame containing a log entry
Entry(LogEntry<C>),
/// A Frame containing the sealed index
SealIndex(LogIndex),
}

/// The data frame
///
/// Contains either a log entry or a seal index
#[derive(Debug, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub(crate) enum DataFrame<'a, C> {
/// A Frame containing a log entry
Entry(&'a LogEntry<C>),
/// A Frame containing the sealed index
SealIndex(LogIndex),
}

/// The commit frame
///
/// This frames contains a SHA256 checksum of all previous frames since last commit
Expand Down Expand Up @@ -98,14 +110,14 @@ impl<C> WAL<C> {
}
}

impl<C> Encoder<Vec<DataFrame<C>>> for WAL<C>
impl<C> Encoder<Vec<DataFrame<'_, C>>> for WAL<C>
where
C: Serialize,
{
type Error = io::Error;

/// Encodes a frame
fn encode(&mut self, frames: Vec<DataFrame<C>>) -> Result<Vec<u8>, Self::Error> {
fn encode(&mut self, frames: Vec<DataFrame<'_, C>>) -> Result<Vec<u8>, Self::Error> {
let mut frame_data: Vec<_> = frames.into_iter().flat_map(|f| f.encode()).collect();
let commit_frame = CommitFrame::new_from_data(&frame_data);
frame_data.extend_from_slice(&commit_frame.encode());
Expand All @@ -118,7 +130,7 @@ impl<C> Decoder for WAL<C>
where
C: Serialize + DeserializeOwned,
{
type Item = Vec<DataFrame<C>>;
type Item = Vec<DataFrameOwned<C>>;

type Error = WALError;

Expand Down Expand Up @@ -210,14 +222,14 @@ where
let entry: LogEntry<C> = bincode::deserialize(payload)
.map_err(|e| WALError::Corrupted(CorruptType::Codec(e.to_string())))?;

Ok(Some((Self::Data(DataFrame::Entry(entry)), 8 + len)))
Ok(Some((Self::Data(DataFrameOwned::Entry(entry)), 8 + len)))
}

/// Decodes an seal index frame from source
fn decode_seal_index(header: [u8; 8]) -> Result<Option<(Self, usize)>, WALError> {
let index = Self::decode_u64_from_header(header);

Ok(Some((Self::Data(DataFrame::SealIndex(index)), 8)))
Ok(Some((Self::Data(DataFrameOwned::SealIndex(index)), 8)))
}

/// Decodes a commit frame from source
Expand All @@ -241,7 +253,17 @@ where
}
}

impl<C> FrameType for DataFrame<C> {
impl<C> DataFrameOwned<C> {
/// Converts `DataFrameOwned` to `DataFrame`
pub(super) fn get_ref(&self) -> DataFrame<'_, C> {
match *self {
DataFrameOwned::Entry(ref entry) => DataFrame::Entry(entry),
DataFrameOwned::SealIndex(index) => DataFrame::SealIndex(index),
}
}
}

impl<C> FrameType for DataFrame<'_, C> {
fn frame_type(&self) -> u8 {
match *self {
DataFrame::Entry(_) => ENTRY,
Expand All @@ -250,7 +272,7 @@ impl<C> FrameType for DataFrame<C> {
}
}

impl<C> FrameEncoder for DataFrame<C>
impl<C> FrameEncoder for DataFrame<'_, C>
where
C: Serialize,
{
Expand Down Expand Up @@ -324,31 +346,31 @@ mod tests {
async fn frame_encode_decode_is_ok() {
let mut codec = WAL::<TestCommand>::new();
let entry = LogEntry::<TestCommand>::new(1, 1, ProposeId(1, 2), EntryData::Empty);
let data_frame = DataFrame::Entry(entry.clone());
let seal_frame = DataFrame::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame]).unwrap();
encoded.extend_from_slice(&codec.encode(vec![seal_frame]).unwrap());
let data_frame = DataFrameOwned::Entry(entry.clone());
let seal_frame = DataFrameOwned::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame.get_ref()]).unwrap();
encoded.extend_from_slice(&codec.encode(vec![seal_frame.get_ref()]).unwrap());

let (data_frame_get, len) = codec.decode(&encoded).unwrap();
let (seal_frame_get, _) = codec.decode(&encoded[len..]).unwrap();
let DataFrame::Entry(ref entry_get) = data_frame_get[0] else {
let DataFrameOwned::Entry(ref entry_get) = data_frame_get[0] else {
panic!("frame should be type: DataFrame::Entry");
};
let DataFrame::SealIndex(ref index) = seal_frame_get[0] else {
let DataFrameOwned::SealIndex(index) = seal_frame_get[0] else {
panic!("frame should be type: DataFrame::Entry");
};

assert_eq!(*entry_get, entry);
assert_eq!(*index, 1);
assert_eq!(index, 1);
}

#[tokio::test]
async fn frame_zero_write_will_be_detected() {
let mut codec = WAL::<TestCommand>::new();
let entry = LogEntry::<TestCommand>::new(1, 1, ProposeId(1, 2), EntryData::Empty);
let data_frame = DataFrame::Entry(entry.clone());
let seal_frame = DataFrame::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame]).unwrap();
let data_frame = DataFrameOwned::Entry(entry.clone());
let seal_frame = DataFrameOwned::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame.get_ref()]).unwrap();
encoded[0] = 0;

let err = codec.decode(&encoded).unwrap_err();
Expand All @@ -362,9 +384,9 @@ mod tests {
async fn frame_corrupt_will_be_detected() {
let mut codec = WAL::<TestCommand>::new();
let entry = LogEntry::<TestCommand>::new(1, 1, ProposeId(1, 2), EntryData::Empty);
let data_frame = DataFrame::Entry(entry.clone());
let seal_frame = DataFrame::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame]).unwrap();
let data_frame = DataFrameOwned::Entry(entry.clone());
let seal_frame = DataFrameOwned::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame.get_ref()]).unwrap();
encoded[1] = 0;

let err = codec.decode(&encoded).unwrap_err();
Expand Down
24 changes: 15 additions & 9 deletions crates/curp/src/server/storage/wal/segment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@ use tokio::{
};
use tokio_stream::StreamExt;

use crate::log_entry::LogEntry;

use super::{
codec::{DataFrame, WAL},
codec::{DataFrame, DataFrameOwned, WAL},
error::{CorruptType, WALError},
framed::{Decoder, Encoder},
util::{get_checksum, parse_u64, validate_data, LockedFile},
WAL_FILE_EXT, WAL_MAGIC, WAL_VERSION,
};
use crate::log_entry::LogEntry;

/// The size of wal file header in bytes
const WAL_HEADER_SIZE: usize = 56;
Expand Down Expand Up @@ -106,7 +107,7 @@ impl WALSegment {
let frame = frames
.last()
.unwrap_or_else(|| unreachable!("a batch should contains at least one frame"));
if let DataFrame::SealIndex(index) = *frame {
if let DataFrameOwned::SealIndex(index) = *frame {
highest_index = index;
}
}
Expand All @@ -116,7 +117,7 @@ impl WALSegment {

// Get log entries that index is no larger than `highest_index`
Ok(frame_batches.into_iter().flatten().filter_map(move |f| {
if let DataFrame::Entry(e) = f {
if let DataFrameOwned::Entry(e) = f {
(e.index <= highest_index).then_some(e)
} else {
None
Expand Down Expand Up @@ -168,7 +169,7 @@ impl WALSegment {

/// Updates the seal index
pub(super) fn update_seal_index(&mut self, index: LogIndex) {
self.seal_index = self.seal_index.max(index);
self.seal_index = self.seal_index.min(index);
}

/// Get the size of the segment
Expand Down Expand Up @@ -289,9 +290,10 @@ mod tests {

use curp_test_utils::test_cmd::TestCommand;

use super::*;
use crate::log_entry::EntryData;

use super::*;

#[test]
fn gen_parse_header_is_correct() {
fn corrupt(mut header: Vec<u8>, pos: usize) -> Vec<u8> {
Expand Down Expand Up @@ -352,15 +354,19 @@ mod tests {

let frames: Vec<_> = (0..100)
.map(|i| {
DataFrame::Entry(LogEntry::new(
DataFrameOwned::Entry(LogEntry::new(
i,
1,
crate::rpc::ProposeId(0, 0),
EntryData::Command(Arc::new(TestCommand::new_put(vec![i as u32], i as u32))),
))
})
.collect();
segment.write_sync(frames.clone(), WAL::new());

segment.write_sync(
frames.iter().map(DataFrameOwned::get_ref).collect(),
WAL::new(),
);

drop(segment);

Expand All @@ -369,7 +375,7 @@ mod tests {
let recovered: Vec<_> = segment
.recover_segment_logs::<TestCommand>()
.unwrap()
.map(|e| DataFrame::Entry(e))
.map(|e| DataFrameOwned::Entry(e))
.collect();
assert_eq!(frames, recovered);
}
Expand Down

0 comments on commit 5046879

Please sign in to comment.