Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: correct request validation #942

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions crates/utils/src/lca_tree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
use std::ops::{Add, Sub as _};

/// A LCA tree to accelerate Txns' key overlap validation
#[non_exhaustive]
#[derive(Debug)]
pub struct LCATree {
///
nodes: Vec<LCANode>,
}

///
#[non_exhaustive]
#[derive(Debug)]
pub struct LCANode {
///
pub parent: Vec<usize>,
///
pub depth: usize,
}

#[allow(clippy::indexing_slicing)]
impl LCATree {
/// build a `LCATree` with a sentinel node
#[inline]
#[must_use]
pub fn new() -> Self {
Self {
nodes: vec![LCANode {
parent: vec![0],
depth: 0,
}],
}
}
/// get a node by index
///
/// # Panics
///
/// The function panics if given `i` > max index
#[inline]
#[must_use]
pub fn get_node(&self, i: usize) -> &LCANode {
assert!(i < self.nodes.len(), "Node {i} doesn't exist");
&self.nodes[i]
}
/// insert a node and return its index
///
/// # Panics
///
/// The function panics if given `parent` doesn't exist
#[inline]
#[must_use]
#[allow(clippy::as_conversions)]
pub fn insert_node(&mut self, parent: usize) -> usize {
let depth = if parent == 0 {
0
} else {
self.get_node(parent).depth.add(1)
};
let mut node = LCANode {
parent: vec![],
depth,
};
node.parent.push(parent);
let parent_num = if depth == 0 { 0 } else { depth.ilog2() } as usize;
for i in 0..parent_num {
node.parent.push(self.get_node(node.parent[i]).parent[i]);
}
self.nodes.push(node);
self.nodes.len().sub(1)
}
/// Use Binary Lifting to find the LCA of `node_a` and `node_b`
///
/// # Panics
///
/// The function panics if given `node_a` or `node_b` doesn't exist
#[inline]
#[must_use]
pub fn find_lca(&self, node_a: usize, node_b: usize) -> usize {
let (mut x, mut y) = if self.get_node(node_a).depth < self.get_node(node_b).depth {
(node_a, node_b)
} else {
(node_b, node_a)
};
while self.get_node(x).depth < self.get_node(y).depth {
for ancestor in self.get_node(y).parent.iter().rev() {
if self.get_node(x).depth <= self.get_node(*ancestor).depth {
y = *ancestor;
}
}
}
while x != y {
let node_x = self.get_node(x);
let node_y = self.get_node(y);
if node_x.parent[0] == node_y.parent[0] {
x = node_x.parent[0];
break;
}
for i in (0..node_x.parent.len()).rev() {
if node_x.parent[i] != node_y.parent[i] {
x = node_x.parent[i];
y = node_y.parent[i];
break;
}
}
}
x
}
}

impl Default for LCATree {
#[inline]
fn default() -> Self {
Self::new()
}
}

#[cfg(test)]
mod test {
use crate::lca_tree::LCATree;

#[test]
fn test_ilog2() {
assert_eq!(3_i32.ilog2(), 1);
assert_eq!(5_i32.ilog2(), 2);
assert_eq!(7_i32.ilog2(), 2);
assert_eq!(10_i32.ilog2(), 3);
}

#[test]
// root
// / | \
// / | \
// / | \
// node1 node2 node3
// | \ | |
// | \ | |
// node4 node5 node6 node7
// | \ \
// | \ node10
// node8 node9
//
//
fn test_lca() {
let mut tree = LCATree::new();
let root = 0;
let node1 = tree.insert_node(root);
let node2 = tree.insert_node(root);
let node3 = tree.insert_node(root);

let node4 = tree.insert_node(node1);
let node5 = tree.insert_node(node1);

let node6 = tree.insert_node(node2);

let node7 = tree.insert_node(node3);

let node8 = tree.insert_node(node4);
let node9 = tree.insert_node(node4);

let node10 = tree.insert_node(node5);

assert_eq!(tree.find_lca(node1, node2), root);
assert_eq!(tree.find_lca(node1, node3), root);
assert_eq!(tree.find_lca(node1, node4), node1);
assert_eq!(tree.find_lca(node4, node5), node1);
assert_eq!(tree.find_lca(node5, node7), root);
assert_eq!(tree.find_lca(node6, node7), root);
assert_eq!(tree.find_lca(node8, node9), node4);
assert_eq!(tree.find_lca(node8, node10), node1);
assert_eq!(tree.find_lca(node6, node10), root);
assert_eq!(tree.find_lca(node8, node5), node1);
assert_eq!(tree.find_lca(node9, node3), root);
assert_eq!(tree.find_lca(node10, node2), root);
}
}
2 changes: 2 additions & 0 deletions crates/utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ pub struct ServerTlsConfig;
pub mod barrier;
/// configuration
pub mod config;
/// LCA tree implementation
pub mod lca_tree;
/// utils for metrics
pub mod metrics;
/// utils of `parking_lot` lock
Expand Down
149 changes: 110 additions & 39 deletions crates/xlineapi/src/request_validation.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::collections::HashSet;
use std::collections::{hash_map::Entry, HashMap};

use serde::{Deserialize, Serialize};
use thiserror::Error;
use utils::interval_map::{Interval, IntervalMap};
use utils::lca_tree::LCATree;

use crate::{
command::KeyRange, AuthRoleAddRequest, AuthRoleGrantPermissionRequest, AuthUserAddRequest,
interval::BytesAffine, AuthRoleAddRequest, AuthRoleGrantPermissionRequest, AuthUserAddRequest,
DeleteRangeRequest, PutRequest, RangeRequest, Request, RequestOp, SortOrder, SortTarget,
TxnRequest,
};
Expand Down Expand Up @@ -85,61 +87,133 @@ impl RequestValidator for TxnRequest {
}
}

let _ignore_success = check_intervals(&self.success)?;
let _ignore_failure = check_intervals(&self.failure)?;
check_intervals(&self.success)?;
check_intervals(&self.failure)?;

Ok(())
}
}

/// Check if puts and deletes overlap
fn check_intervals(ops: &[RequestOp]) -> Result<(HashSet<&[u8]>, Vec<KeyRange>), ValidationError> {
// TODO: use interval tree is better?
type DelsIntervalMap<'a> = IntervalMap<BytesAffine, Vec<usize>>;

let mut dels = Vec::new();
fn new_bytes_affine_interval(start: &[u8], key_end: &[u8]) -> Interval<BytesAffine> {
let high = match key_end {
&[] => {
let mut end = start.to_vec();
end.push(0);
BytesAffine::Bytes(end)
}
&[0] => BytesAffine::Unbounded,
bytes => BytesAffine::Bytes(bytes.to_vec()),
};
Interval::new(BytesAffine::new_key(start), high)
}

for op in ops {
if let Some(Request::RequestDeleteRange(ref req)) = op.request {
// collect dels
let del = KeyRange::new(req.key.as_slice(), req.range_end.as_slice());
dels.push(del);
/// Check if puts and deletes overlap
fn check_intervals(ops: &[RequestOp]) -> Result<(), ValidationError> {
let mut lca_tree = LCATree::new();
// Because `dels` stores Vec<node_idx> corresponding to the interval, merging two `dels` is slightly cumbersome.
// Here, `dels` are directly passed into the build function
let mut dels = DelsIntervalMap::new();
// This function will traverse all RequestOp and collect all the parent nodes corresponding to `put` and `del` operations.
// During this process, the atomicity of the put operation can be guaranteed.
let puts = build_interval_tree(ops, &mut dels, &mut lca_tree, 0)?;

// Now we have `dels` and `puts` which contain all node index corresponding to `del` and `put` ops,
// we only need to iterate through the puts to find out whether each put overlaps with the del operation in the dels,
// and even if it overlaps, whether it satisfies lca.depth % 2 == 0.
for (put_key, put_vec) in puts {
let put_interval = new_bytes_affine_interval(put_key, &[]);
let overlaps = dels.find_all_overlap(&put_interval);
for put_node_idx in put_vec {
for (_, del_vec) in overlaps.iter() {
for del_node_idx in del_vec.iter() {
let lca_node_idx = lca_tree.find_lca(put_node_idx, *del_node_idx);
// lca.depth % 2 == 0 means this lca is on a success or failure branch,
// and two nodes on the same branch are prohibited from overlapping.
if lca_tree.get_node(lca_node_idx).depth % 2 == 0 {
return Err(ValidationError::DuplicateKey);
}
}
}
}
}

let mut puts: HashSet<&[u8]> = HashSet::new();
Ok(())
}

fn build_interval_tree<'a>(
ops: &'a [RequestOp],
dels_map: &mut DelsIntervalMap<'a>,
lca_tree: &mut LCATree,
parent: usize,
) -> Result<HashMap<&'a [u8], Vec<usize>>, ValidationError> {
let mut puts_map: HashMap<&[u8], Vec<usize>> = HashMap::new();
for op in ops {
if let Some(Request::RequestTxn(ref req)) = op.request {
// handle child txn request
let (success_puts, mut success_dels) = check_intervals(&req.success)?;
let (failure_puts, mut failure_dels) = check_intervals(&req.failure)?;

for k in success_puts.union(&failure_puts) {
if !puts.insert(k) {
return Err(ValidationError::DuplicateKey);
match op.request {
Some(Request::RequestDeleteRange(ref req)) => {
// collect dels
let cur_node_idx = lca_tree.insert_node(parent);
let del = new_bytes_affine_interval(req.key.as_slice(), req.range_end.as_slice());
dels_map.entry(del).or_insert(vec![]).push(cur_node_idx);
}
Some(Request::RequestTxn(ref req)) => {
// RequestTxn is absolutely a node
let cur_node_idx = lca_tree.insert_node(parent);
let success_puts_map = if !req.success.is_empty() {
// success branch is also a node
let success_node_idx = lca_tree.insert_node(cur_node_idx);
build_interval_tree(&req.success, dels_map, lca_tree, success_node_idx)?
} else {
HashMap::new()
};
let failure_puts_map = if !req.failure.is_empty() {
// failure branch is also a node
let failure_node_idx = lca_tree.insert_node(cur_node_idx);
build_interval_tree(&req.failure, dels_map, lca_tree, failure_node_idx)?
} else {
HashMap::new()
};
// success_puts_map and failure_puts_map cannot overlap with other op's puts_map.
for (sub_put_key, sub_put_node_idx) in success_puts_map.iter() {
if puts_map.contains_key(sub_put_key) {
return Err(ValidationError::DuplicateKey);
}
puts_map.insert(&sub_put_key, sub_put_node_idx.to_vec());
}
if dels.iter().any(|del| del.contains_key(k)) {
return Err(ValidationError::DuplicateKey);
// but they can overlap with each other
for (sub_put_key, mut sub_put_node_idx) in failure_puts_map.into_iter() {
match puts_map.entry(&sub_put_key) {
Entry::Vacant(_) => {
puts_map.insert(&sub_put_key, sub_put_node_idx);
}
Entry::Occupied(mut put_entry) => {
if !success_puts_map.contains_key(sub_put_key) {
return Err(ValidationError::DuplicateKey);
}
let put_vec = put_entry.get_mut();
bsbds marked this conversation as resolved.
Show resolved Hide resolved
put_vec.append(&mut sub_put_node_idx);
}
};
}
}

dels.append(&mut success_dels);
dels.append(&mut failure_dels);
_ => {}
}
}

// put in RequestPut cannot overlap with all puts in RequestTxn
for op in ops {
if let Some(Request::RequestPut(ref req)) = op.request {
// check puts in this level
if !puts.insert(&req.key) {
return Err(ValidationError::DuplicateKey);
}
if dels.iter().any(|del| del.contains_key(&req.key)) {
return Err(ValidationError::DuplicateKey);
match op.request {
Some(Request::RequestPut(ref req)) => {
if puts_map.contains_key(&req.key.as_slice()) {
return Err(ValidationError::DuplicateKey);
}
let cur_node_idx = lca_tree.insert_node(parent);
puts_map.insert(&req.key, vec![cur_node_idx]);
}
_ => {}
}
}
Ok((puts, dels))
Ok(puts_map)
}

impl RequestValidator for AuthUserAddRequest {
Expand Down Expand Up @@ -583,9 +657,6 @@ mod test {
run_test(testcases);
}

// FIXME: This test will fail in the current implementation.
// See https://github.com/xline-kv/Xline/issues/410 for more details
#[ignore]
#[test]
fn check_intervals_txn_nested_overlap_should_return_error() {
let put_op = RequestOp {
Expand Down
Loading