Skip to content

Commit

Permalink
fix: prevent batch_index overflow in raw_curp
Browse files Browse the repository at this point in the history
Signed-off-by: GFX9 <[email protected]>
  • Loading branch information
GFX9 committed Mar 18, 2024
1 parent f081173 commit f22ed94
Showing 1 changed file with 145 additions and 44 deletions.
189 changes: 145 additions & 44 deletions crates/curp/src/server/raw_curp/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use bincode::serialized_size;
use clippy_utilities::{NumericCast, OverflowArithmetic};
use itertools::Itertools;
use tokio::sync::mpsc;
use tracing::error;
use tracing::{error, warn};

use crate::{
cmd::Command,
Expand Down Expand Up @@ -82,21 +82,28 @@ impl<C: Command> FallbackContext<C> {
struct LogEntryVecDeque<C: Command> {
/// A VecDeque to store log entries, it will be serialized and persisted
entries: VecDeque<Arc<LogEntry<C>>>,
/// The sum of serialized size of previous log entries
/// batch_index[i+1] = batch_index[i] + size(entries[i])
batch_index: VecDeque<u64>,
/// entry size of each item in entries
entry_size: VecDeque<u64>,
/// the right index of the batch (offset)
/// batch_range: [i, i + batch_index[i]]
batch_index: VecDeque<usize>,
/// the first entry idx of the current batch window
first_entry_at_last_batch: usize,
/// the current batch window size
last_batch_size: u64,
/// Batch size limit
batch_limit: u64,
}

impl<C: Command> LogEntryVecDeque<C> {
/// return a log entries with cap
fn new(cap: usize, batch_limit: u64) -> Self {
let mut batch_index = VecDeque::with_capacity(cap.overflow_add(1));
batch_index.push_back(0);
Self {
entries: VecDeque::with_capacity(cap),
batch_index,
entry_size: VecDeque::with_capacity(cap),
batch_index: VecDeque::with_capacity(cap),
first_entry_at_last_batch: 0,
last_batch_size: 0,
batch_limit,
}
}
Expand All @@ -111,21 +118,71 @@ impl<C: Command> LogEntryVecDeque<C> {

/// push a log entry into the back of queue
fn push_back(&mut self, entry: Arc<LogEntry<C>>) -> Result<(), bincode::Error> {
#![allow(clippy::indexing_slicing)]
let entry_size = serialized_size(&entry)?;

if entry_size > self.batch_limit {
warn!("entry_size of an entry > batch_limit, which may be too small.",);
}

self.entries.push_back(entry);
let Some(&pre_entries_size) = self.batch_index.back() else {
unreachable!("batch_index cannot be None")
};
self.batch_index
.push_back(pre_entries_size.overflow_add(entry_size));
self.entry_size.push_back(entry_size);
self.batch_index.push_back(0); // placeholder

if entry_size > self.batch_limit {
let entry_idx = self.batch_index.len() - 1;
for prev_idx in self.first_entry_at_last_batch..entry_idx {
self.batch_index[prev_idx] = entry_idx - prev_idx; // record offset but not absolute index
}
self.batch_index[entry_idx] = 1;
self.last_batch_size = 0;
self.first_entry_at_last_batch = entry_idx + 1;
return Ok(());
}

while self.last_batch_size + entry_size > self.batch_limit
&& self.first_entry_at_last_batch < self.entries.len()
{
self.batch_index[self.first_entry_at_last_batch] =
self.entries.len() - 1 - self.first_entry_at_last_batch; // record offset but not absolute index
self.last_batch_size -= self.entry_size[self.first_entry_at_last_batch];
self.first_entry_at_last_batch += 1;
}

self.last_batch_size += entry_size;

if self.first_entry_at_last_batch >= self.entries.len() {
self.batch_index[self.entries.len() - 1] = 1;
}

if self.last_batch_size == self.batch_limit {
self.batch_index[self.first_entry_at_last_batch] =
self.entries.len() - self.first_entry_at_last_batch; // record offset but not absolute index
}

Ok(())
}

/// pop a log entry from the front of queue
fn pop_front(&mut self) -> Option<Arc<LogEntry<C>>> {
#![allow(clippy::indexing_slicing)]
if self.entries.front().is_some() {
_ = self.batch_index.pop_front();
let front_size = self.entry_size[0];

if self.first_entry_at_last_batch == 0 {
self.last_batch_size -= front_size;
} else {
self.first_entry_at_last_batch -= 1;
}

let _ = self
.batch_index
.pop_front()
.unwrap_or_else(|| unreachable!());
let _ = self
.entry_size
.pop_front()
.unwrap_or_else(|| unreachable!());
self.entries.pop_front()
} else {
None
Expand All @@ -134,36 +191,34 @@ impl<C: Command> LogEntryVecDeque<C> {

/// restore log entries from Vec
fn restore(&mut self, entries: Vec<LogEntry<C>>) {
let mut batch_index = VecDeque::with_capacity(entries.capacity());
batch_index.push_back(0);
for entry in &entries {
#[allow(clippy::expect_used)]
let entry_size =
serialized_size(entry).expect("log entry {entry:?} cannot be serialized");
if let Some(cur_size) = batch_index.back() {
batch_index.push_back(cur_size.overflow_add(entry_size));
}
}
self.batch_index = VecDeque::with_capacity(entries.capacity());
self.entries = VecDeque::with_capacity(entries.capacity());
self.entry_size = VecDeque::with_capacity(entries.capacity());

self.entries = entries.into_iter().map(Arc::new).collect();
self.batch_index = batch_index;
self.last_batch_size = 0;
self.first_entry_at_last_batch = 0;

for entry in entries {
let _unuse = self.push_back(Arc::from(entry));
}
}

/// clear whole log entries
fn clear(&mut self) {
self.entries.clear();
self.entry_size.clear();
self.batch_index.clear();
self.batch_index.push_back(0);
self.last_batch_size = 0;
self.first_entry_at_last_batch = 0;
}

/// Get the range [left, right) of the log entry, whose size should be equal or smaller than `batch_limit`
fn get_range_by_batch(&self, left: usize) -> Range<usize> {
#[allow(clippy::indexing_slicing)]
let target = self.batch_index[left].overflow_add(self.batch_limit);
// remove the fake index 0 in `batch_index`
match self.batch_index.binary_search(&target) {
Ok(right) => left..right,
Err(right) => left..right - 1,
#![allow(clippy::indexing_slicing)]
if self.batch_index[left] == 0 {
left..self.entries.len()
} else {
left..left + self.batch_index[left]
}
}

Expand All @@ -175,15 +230,56 @@ impl<C: Command> LogEntryVecDeque<C> {

/// check whether the log entry range [li,..) exceeds the batch limit or not
fn has_next_batch(&self, left: usize) -> bool {
if let (Some(&cur_size), Some(&last_size)) =
(self.batch_index.get(left), self.batch_index.back())
{
let target_size = cur_size.overflow_add(self.batch_limit);
target_size <= last_size
if let Some(&offset) = self.batch_index.get(left) {
offset != 0
} else {
false
}
}

#[allow(unused)]
/// Set a new `batch_limit` and reconstruct `batch_index` from scratch.
///
/// Replays every stored entry through the same sliding batch-window
/// algorithm used by `push_back`: `first_entry_at_last_batch` is the first
/// entry of the currently open (rightmost) batch and `last_batch_size` is
/// its accumulated serialized size. For each closed batch, the entry at the
/// batch's start records the batch length as an *offset* (not an absolute
/// index); `0` marks entries whose batch is still open.
fn set_batch_limit(&mut self, batch_limit: u64) {
    #![allow(clippy::indexing_slicing)]
    self.batch_limit = batch_limit;
    // Reset the sliding-window state before replaying all entries.
    self.last_batch_size = 0;
    self.first_entry_at_last_batch = 0;
    // Zero every recorded offset; 0 means "batch not closed here".
    self.batch_index.iter_mut().for_each(|val| *val = 0);

    for entry_idx in 0..self.entries.len() {
        let entry_size = self.entry_size[entry_idx];

        // An entry larger than the limit closes the pending window and
        // forms a single-entry batch by itself.
        if entry_size > self.batch_limit {
            for prev_idx in self.first_entry_at_last_batch..entry_idx {
                self.batch_index[prev_idx] = entry_idx - prev_idx; // record offset but not absolute index
            }
            self.batch_index[entry_idx] = 1;
            self.last_batch_size = 0;
            self.first_entry_at_last_batch = entry_idx + 1;
            continue;
        }

        // Slide the window start forward until the current entry fits,
        // closing the batch at each evicted start position.
        while self.last_batch_size + entry_size > self.batch_limit
            && self.first_entry_at_last_batch < self.entries.len()
        {
            self.batch_index[self.first_entry_at_last_batch] =
                entry_idx - self.first_entry_at_last_batch; // record offset but not absolute index
            self.last_batch_size -= self.entry_size[self.first_entry_at_last_batch];
            self.first_entry_at_last_batch += 1;
        }

        self.last_batch_size += entry_size;

        // NOTE(review): window start having passed the end implies the
        // current entry starts a fresh batch on its own — presumably only
        // reachable via the eviction loop above; confirm invariant.
        if self.first_entry_at_last_batch >= self.entries.len() {
            self.batch_index[entry_idx] = 1;
        }

        // If the final entry exactly fills the window, close the last batch.
        if entry_idx == self.entries.len() - 1 && self.last_batch_size == self.batch_limit {
            self.batch_index[self.first_entry_at_last_batch] =
                self.entries.len() - self.first_entry_at_last_batch; // record offset but not absolute index
        }
    }
}
}

impl<C: Command> std::ops::Deref for LogEntryVecDeque<C> {
Expand Down Expand Up @@ -448,6 +544,12 @@ impl<C: Command> Log<C> {
false
});
}

#[allow(unused)]
/// Set a new batch limit and reconstruct the inner `batch_index` table by
/// delegating to the underlying entry deque.
pub(super) fn set_batch_limit(&mut self, batch_limit: u64) {
    self.entries.set_batch_limit(batch_limit);
}
}

#[cfg(test)]
Expand All @@ -470,7 +572,7 @@ mod tests {
}

fn set_batch_limit(log: &mut Log<TestCommand>, batch_limit: u64) {
log.entries.batch_limit = batch_limit;
log.set_batch_limit(batch_limit);
}

#[test]
Expand Down Expand Up @@ -575,7 +677,7 @@ mod tests {
.enumerate()
.map(|(idx, cmd)| log.push(1, ProposeId(0, idx.numeric_cast()), cmd).unwrap())
.collect::<Vec<_>>();
let log_entry_size = log.entries.batch_index[1];
let log_entry_size = log.entries.entry_size[0];

set_batch_limit(&mut log, 3 * log_entry_size - 1);
let bound_1 = log.entries.get_range_by_batch(3);
Expand Down Expand Up @@ -633,7 +735,7 @@ mod tests {
let bound_5 = log.entries.get_range_by_batch(3);
assert_eq!(
bound_5,
3..3,
3..4,
"batch_index = {:?}, batch = {}, log_entry_size = {}",
log.entries.batch_index,
log.entries.batch_limit,
Expand Down Expand Up @@ -664,8 +766,7 @@ mod tests {

log.restore_entries(entries);
assert_eq!(log.entries.len(), 10);
assert_eq!(log.entries.batch_index.len(), 11);
assert_eq!(log.entries.batch_index[0], 0);
assert_eq!(log.entries.batch_index.len(), 10);
let entry_size = log.entries.batch_index[1];

log.entries
Expand All @@ -675,7 +776,7 @@ mod tests {
.for_each(|(idx, &size)| {
assert_eq!(
size,
entry_size * idx.numeric_cast::<u64>(),
entry_size * idx,
"batch_index = {:?}, batch = {}, entry_size = {}",
log.entries.batch_index,
log.entries.batch_limit,
Expand All @@ -698,6 +799,6 @@ mod tests {
log.compact();
assert_eq!(log.base_index, 12);
assert_eq!(log.entries.front().unwrap().index, 13);
assert_eq!(log.entries.batch_index.len(), 19);
assert_eq!(log.entries.batch_index.len(), 18);
}
}

0 comments on commit f22ed94

Please sign in to comment.