refactor: use reference type in DataFrame
Signed-off-by: bsbds <[email protected]>
bsbds authored and Phoenix500526 committed Jun 3, 2024
1 parent 8213d2d commit 661c9f0
Showing 3 changed files with 133 additions and 78 deletions.
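The change splits the WAL data-frame type in two: an owned `DataFrameOwned<C>` that the decoder produces, and a borrowed `DataFrame<'a, C>` that the encoder consumes, with `get_ref()` converting the former into the latter. The apparent motivation is to let encoding paths hand the codec a borrowed view instead of cloning each entry. The sketch below condenses that pattern from the diff into a standalone example; the `LogEntry` struct here is a simplified stand-in for the crate's real type.

```rust
/// Simplified stand-in for the crate's real `LogEntry` type.
#[derive(Debug, Clone, PartialEq)]
struct LogEntry<C> {
    index: u64,
    data: C,
}
type LogIndex = u64;

/// Owned frame, produced when decoding from the WAL.
#[derive(Debug, Clone)]
enum DataFrameOwned<C> {
    Entry(LogEntry<C>),
    SealIndex(LogIndex),
}

/// Borrowed frame, passed to the encoder so the caller keeps ownership
/// of its entries and no clone is needed.
#[derive(Debug, Clone)]
enum DataFrame<'a, C> {
    Entry(&'a LogEntry<C>),
    SealIndex(LogIndex),
}

impl<C> DataFrameOwned<C> {
    /// Borrows an owned frame as a `DataFrame`, mirroring `get_ref` in the diff.
    fn get_ref(&self) -> DataFrame<'_, C> {
        match *self {
            DataFrameOwned::Entry(ref entry) => DataFrame::Entry(entry),
            DataFrameOwned::SealIndex(index) => DataFrame::SealIndex(index),
        }
    }
}

fn main() {
    let owned = DataFrameOwned::Entry(LogEntry { index: 1, data: "cmd" });
    // Encoding paths now take `DataFrame<'_, C>`, so the entry is only borrowed.
    let borrowed = owned.get_ref();
    match borrowed {
        DataFrame::Entry(entry) => assert_eq!(*entry, LogEntry { index: 1, data: "cmd" }),
        DataFrame::SealIndex(_) => unreachable!(),
    }

    // Seal-index frames carry only a `u64`, so they are simply copied.
    let seal = DataFrameOwned::<&str>::SealIndex(7);
    assert!(matches!(seal.get_ref(), DataFrame::SealIndex(7)));
}
```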
68 changes: 45 additions & 23 deletions crates/curp/src/server/storage/wal/codec.rs
@@ -39,7 +39,7 @@ trait FrameEncoder {
#[derive(Debug)]
pub(super) struct WAL<C, H = Sha256> {
/// Frames stored in decoding
frames: Vec<DataFrame<C>>,
frames: Vec<DataFrameOwned<C>>,
/// The hasher state for decoding
hasher: H,
}
@@ -48,7 +48,7 @@ pub(super) struct WAL<C, H = Sha256> {
#[derive(Debug)]
enum WALFrame<C> {
/// Data frame type
Data(DataFrame<C>),
Data(DataFrameOwned<C>),
/// Commit frame type
Commit(CommitFrame),
}
@@ -58,13 +58,25 @@ enum WALFrame<C> {
/// Contains either a log entry or a seal index
#[derive(Debug, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub(crate) enum DataFrame<C> {
pub(crate) enum DataFrameOwned<C> {
/// A Frame containing a log entry
Entry(LogEntry<C>),
/// A Frame containing the sealed index
SealIndex(LogIndex),
}

/// The data frame
///
/// Contains either a log entry or a seal index
#[derive(Debug, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub(crate) enum DataFrame<'a, C> {
/// A Frame containing a log entry
Entry(&'a LogEntry<C>),
/// A Frame containing the sealed index
SealIndex(LogIndex),
}

/// The commit frame
///
/// This frames contains a SHA256 checksum of all previous frames since last commit
@@ -98,14 +110,14 @@ impl<C> WAL<C> {
}
}

impl<C> Encoder<Vec<DataFrame<C>>> for WAL<C>
impl<C> Encoder<Vec<DataFrame<'_, C>>> for WAL<C>
where
C: Serialize,
{
type Error = io::Error;

/// Encodes a frame
fn encode(&mut self, frames: Vec<DataFrame<C>>) -> Result<Vec<u8>, Self::Error> {
fn encode(&mut self, frames: Vec<DataFrame<'_, C>>) -> Result<Vec<u8>, Self::Error> {
let mut frame_data: Vec<_> = frames.into_iter().flat_map(|f| f.encode()).collect();
let commit_frame = CommitFrame::new_from_data(&frame_data);
frame_data.extend_from_slice(&commit_frame.encode());
@@ -118,7 +130,7 @@ impl<C> Decoder for WAL<C>
where
C: Serialize + DeserializeOwned,
{
type Item = Vec<DataFrame<C>>;
type Item = Vec<DataFrameOwned<C>>;

type Error = WALError;

@@ -208,14 +220,14 @@ where
let entry: LogEntry<C> = bincode::deserialize(payload)
.map_err(|e| WALError::Corrupted(CorruptType::Codec(e.to_string())))?;

Ok(Some((Self::Data(DataFrame::Entry(entry)), 8 + len)))
Ok(Some((Self::Data(DataFrameOwned::Entry(entry)), 8 + len)))
}

/// Decodes a seal index frame from source
fn decode_seal_index(header: [u8; 8]) -> Result<Option<(Self, usize)>, WALError> {
let index = Self::decode_u64_from_header(header);

Ok(Some((Self::Data(DataFrame::SealIndex(index)), 8)))
Ok(Some((Self::Data(DataFrameOwned::SealIndex(index)), 8)))
}

/// Decodes a commit frame from source
@@ -239,7 +251,7 @@ where
}
}

impl<C> FrameType for DataFrame<C> {
impl<C> DataFrameOwned<C> {
/// Converts `DataFrameOwned` to `DataFrame`
pub(super) fn get_ref(&self) -> DataFrame<'_, C> {
match *self {
DataFrameOwned::Entry(ref entry) => DataFrame::Entry(entry),
DataFrameOwned::SealIndex(index) => DataFrame::SealIndex(index),
}
}
}

impl<C> FrameType for DataFrame<'_, C> {
fn frame_type(&self) -> u8 {
match *self {
DataFrame::Entry(_) => ENTRY,
@@ -248,7 +270,7 @@ impl<C> FrameType for DataFrame<C> {
}
}

impl<C> FrameEncoder for DataFrame<C>
impl<C> FrameEncoder for DataFrame<'_, C>
where
C: Serialize,
{
@@ -322,31 +344,31 @@ mod tests {
async fn frame_encode_decode_is_ok() {
let mut codec = WAL::<TestCommand>::new();
let entry = LogEntry::<TestCommand>::new(1, 1, ProposeId(1, 2), EntryData::Empty);
let data_frame = DataFrame::Entry(entry.clone());
let seal_frame = DataFrame::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame]).unwrap();
encoded.extend_from_slice(&codec.encode(vec![seal_frame]).unwrap());
let data_frame = DataFrameOwned::Entry(entry.clone());
let seal_frame = DataFrameOwned::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame.get_ref()]).unwrap();
encoded.extend_from_slice(&codec.encode(vec![seal_frame.get_ref()]).unwrap());

let (data_frame_get, len) = codec.decode(&encoded).unwrap();
let (seal_frame_get, _) = codec.decode(&encoded[len..]).unwrap();
let DataFrame::Entry(ref entry_get) = data_frame_get[0] else {
let DataFrameOwned::Entry(ref entry_get) = data_frame_get[0] else {
panic!("frame should be type: DataFrame::Entry");
};
let DataFrame::SealIndex(ref index) = seal_frame_get[0] else {
let DataFrameOwned::SealIndex(index) = seal_frame_get[0] else {
panic!("frame should be type: DataFrame::Entry");
};

assert_eq!(*entry_get, entry);
assert_eq!(*index, 1);
assert_eq!(index, 1);
}

#[tokio::test]
async fn frame_zero_write_will_be_detected() {
let mut codec = WAL::<TestCommand>::new();
let entry = LogEntry::<TestCommand>::new(1, 1, ProposeId(1, 2), EntryData::Empty);
let data_frame = DataFrame::Entry(entry.clone());
let seal_frame = DataFrame::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame]).unwrap();
let data_frame = DataFrameOwned::Entry(entry.clone());
let seal_frame = DataFrameOwned::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame.get_ref()]).unwrap();
encoded[0] = 0;

let err = codec.decode(&encoded).unwrap_err();
@@ -357,9 +379,9 @@ mod tests {
async fn frame_corrupt_will_be_detected() {
let mut codec = WAL::<TestCommand>::new();
let entry = LogEntry::<TestCommand>::new(1, 1, ProposeId(1, 2), EntryData::Empty);
let data_frame = DataFrame::Entry(entry.clone());
let seal_frame = DataFrame::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame]).unwrap();
let data_frame = DataFrameOwned::Entry(entry.clone());
let seal_frame = DataFrameOwned::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame.get_ref()]).unwrap();
encoded[1] = 0;

let err = codec.decode(&encoded).unwrap_err();
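The encoder can work through the borrowed variant because serde's `Serialize` is implemented for references, so serializing `&LogEntry<C>` yields the same bytes as serializing the owned value. A minimal sketch of that property, assuming `serde` (with the `derive` feature) and `bincode` 1.x as dependencies; the `LogEntry` type is again a simplified stand-in:

```rust
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct LogEntry {
    index: u64,
    data: Vec<u8>,
}

fn main() {
    let entry = LogEntry { index: 1, data: vec![1, 2, 3] };

    // Serializing the owned value and serializing a borrow of it produce the
    // same bytes, so an encoder that only holds `&LogEntry` loses nothing.
    let owned_bytes = bincode::serialize(&entry).unwrap();
    let borrowed: &LogEntry = &entry;
    let borrowed_bytes = bincode::serialize(borrowed).unwrap();
    assert_eq!(owned_bytes, borrowed_bytes);

    // Round-trip back into an owned value, as the decoder does.
    let decoded: LogEntry = bincode::deserialize(&borrowed_bytes).unwrap();
    assert_eq!(decoded, entry);
}
```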
119 changes: 73 additions & 46 deletions crates/curp/src/server/storage/wal/pipeline.rs
@@ -6,6 +6,7 @@ use std::{
Arc,
},
task::Poll,
thread::JoinHandle,
};

use clippy_utilities::OverflowArithmetic;
@@ -28,77 +29,94 @@ pub(super) struct FilePipeline {
///
/// As tokio::fs is generally slower than std::fs, we use synchronous file allocation.
/// Please also refer to the issue discussed on the tokio repo: https://github.com/tokio-rs/tokio/issues/3664
file_iter: flume::IntoIter<LockedFile>,
file_iter: Option<flume::IntoIter<LockedFile>>,
/// Stopped flag
stopped: Arc<AtomicBool>,
/// Join handle of the allocation task
file_alloc_task_handle: Option<JoinHandle<()>>,
}

impl FilePipeline {
/// Creates a new `FilePipeline`
pub(super) fn new(dir: PathBuf, file_size: u64) -> io::Result<Self> {
Self::clean_up(&dir)?;
pub(super) fn new(dir: PathBuf, file_size: u64) -> Self {
if let Err(e) = Self::clean_up(&dir) {
error!("Failed to clean up tmp files: {e}");
}

let (file_tx, file_rx) = flume::bounded(1);
let dir_c = dir.clone();
let stopped = Arc::new(AtomicBool::new(false));
let stopped_c = Arc::clone(&stopped);

#[cfg(not(madsim))]
let _ignore = std::thread::spawn(move || {
let mut file_count = 0;
loop {
match Self::alloc(&dir_c, file_size, &mut file_count) {
Ok(file) => {
if file_tx.send(file).is_err() {
// The receiver is already dropped, stop this task
break;
}
if stopped_c.load(Ordering::Relaxed) {
if let Err(e) = Self::clean_up(&dir_c) {
error!("failed to clean up pipeline temp files: {e}");
{
let file_alloc_task_handle = std::thread::spawn(move || {
let mut file_count = 0;
loop {
match Self::alloc(&dir_c, file_size, &mut file_count) {
Ok(file) => {
if file_tx.send(file).is_err() {
// The receiver is already dropped, stop this task
break;
}
if stopped_c.load(Ordering::Relaxed) {
if let Err(e) = Self::clean_up(&dir_c) {
error!("failed to clean up pipeline temp files: {e}");
}
break;
}
}
Err(e) => {
error!("failed to allocate file: {e}");
break;
}
}
Err(e) => {
error!("failed to allocate file: {e}");
break;
}
}
});

Self {
dir,
file_size,
file_iter: Some(file_rx.into_iter()),
stopped,
file_alloc_task_handle: Some(file_alloc_task_handle),
}
});
}

#[cfg(madsim)]
let _ignore = tokio::spawn(async move {
let mut file_count = 0;
loop {
match Self::alloc(&dir_c, file_size, &mut file_count) {
Ok(file) => {
if file_tx.send_async(file).await.is_err() {
// The receiver is already dropped, stop this task
break;
}
if stopped_c.load(Ordering::Relaxed) {
if let Err(e) = Self::clean_up(&dir_c) {
error!("failed to clean up pipeline temp files: {e}");
{
let _ignore = tokio::spawn(async move {
let mut file_count = 0;
loop {
match Self::alloc(&dir_c, file_size, &mut file_count) {
Ok(file) => {
if file_tx.send_async(file).await.is_err() {
// The receiver is already dropped, stop this task
break;
}
if stopped_c.load(Ordering::Relaxed) {
if let Err(e) = Self::clean_up(&dir_c) {
error!("failed to clean up pipeline temp files: {e}");
}
break;
}
}
Err(e) => {
error!("failed to allocate file: {e}");
break;
}
}
Err(e) => {
error!("failed to allocate file: {e}");
break;
}
}
});

Self {
dir,
file_size,
file_iter: Some(file_rx.into_iter()),
stopped,
file_alloc_task_handle: None,
}
});

Ok(Self {
dir,
file_size,
file_iter: file_rx.into_iter(),
stopped,
})
}
}

/// Stops the pipeline
@@ -132,6 +150,11 @@ impl FilePipeline {
impl Drop for FilePipeline {
fn drop(&mut self) {
self.stop();
// Drops the file rx so that the allocation task could exit
drop(self.file_iter.take());
if let Some(Err(e)) = self.file_alloc_task_handle.take().map(JoinHandle::join) {
error!("failed to join file allocation task: {e:?}");
}
}
}

@@ -142,7 +165,7 @@ impl Iterator for FilePipeline {
if self.stopped.load(Ordering::Relaxed) {
return None;
}
self.file_iter.next().map(Ok)
self.file_iter
.as_mut()
.unwrap_or_else(|| unreachable!("Option is always `Some`"))
.next()
.map(Ok)
}
}

Expand All @@ -164,7 +191,7 @@ mod tests {
async fn file_pipeline_is_ok() {
let file_size = 1024;
let dir = tempfile::tempdir().unwrap();
let mut pipeline = FilePipeline::new(dir.as_ref().into(), file_size).unwrap();
let mut pipeline = FilePipeline::new(dir.as_ref().into(), file_size);

let check_size = |mut file: LockedFile| {
let file = file.into_std();
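The `FilePipeline` change also tightens shutdown: the allocation thread's `JoinHandle` is now stored, and `Drop` releases the receiver half of the channel first (so a blocked `send` in the worker fails and its loop exits) before joining the thread. A minimal sketch of that shutdown pattern, using `std::sync::mpsc` in place of the bounded `flume` channel the real code uses, and a plain counter in place of file allocation:

```rust
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::thread::JoinHandle;

struct Pipeline {
    // Wrapped in `Option` so `Drop` can release the receiver before joining.
    rx: Option<Receiver<u64>>,
    handle: Option<JoinHandle<()>>,
}

impl Pipeline {
    fn new() -> Self {
        let (tx, rx): (SyncSender<u64>, _) = sync_channel(1);
        let handle = std::thread::spawn(move || {
            let mut n = 0;
            loop {
                n += 1;
                // Once the receiver is dropped, `send` errors and the worker exits.
                if tx.send(n).is_err() {
                    break;
                }
            }
        });
        Self { rx: Some(rx), handle: Some(handle) }
    }

    fn next_file(&mut self) -> Option<u64> {
        self.rx.as_ref().and_then(|rx| rx.recv().ok())
    }
}

impl Drop for Pipeline {
    fn drop(&mut self) {
        // Drop the receiving end first so the worker observes the closed
        // channel, then wait for it to finish.
        drop(self.rx.take());
        if let Some(Err(e)) = self.handle.take().map(JoinHandle::join) {
            eprintln!("failed to join worker: {e:?}");
        }
    }
}

fn main() {
    let mut p = Pipeline::new();
    assert_eq!(p.next_file(), Some(1));
    // Dropping `p` here shuts the worker down cleanly.
}
```

In the real code a stop flag additionally covers the case where the worker is not blocked on `send`; the sketch keeps only the channel-closure path.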