-
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor Segment Tree Implementation (#835)
ref: refactor segment tree
- Loading branch information
Showing
1 changed file
with
181 additions
and
142 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,185 +1,224 @@ | ||
use std::cmp::min; | ||
//! A module providing a Segment Tree data structure for efficient range queries | ||
//! and updates. It supports operations like finding the minimum, maximum, | ||
//! and sum of segments in an array. | ||
|
||
use std::fmt::Debug; | ||
use std::ops::Range; | ||
|
||
/// This data structure implements a segment-tree that can efficiently answer range (interval) queries on arrays. | ||
/// It represents this array as a binary tree of merged intervals. From top to bottom: [aggregated value for the overall array], then [left-hand half, right hand half], etc. until [each individual value, ...] | ||
/// It is generic over a reduction function for each segment or interval: basically, to describe how we merge two intervals together. | ||
/// Note that this function should be commutative and associative | ||
/// It could be `std::cmp::min(interval_1, interval_2)` or `std::cmp::max(interval_1, interval_2)`, or `|a, b| a + b`, `|a, b| a * b` | ||
pub struct SegmentTree<T: Debug + Default + Ord + Copy> { | ||
len: usize, // length of the represented | ||
tree: Vec<T>, // represents a binary tree of intervals as an array (as a BinaryHeap does, for instance) | ||
merge: fn(T, T) -> T, // how we merge two values together | ||
/// Custom error types representing possible errors that can occur during operations on the `SegmentTree`. | ||
#[derive(Debug, PartialEq, Eq)] | ||
pub enum SegmentTreeError { | ||
/// Error indicating that an index is out of bounds. | ||
IndexOutOfBounds, | ||
/// Error indicating that a range provided for a query is invalid. | ||
InvalidRange, | ||
} | ||
|
||
/// A structure representing a Segment Tree. This tree can be used to efficiently | ||
/// perform range queries and updates on an array of elements. | ||
pub struct SegmentTree<T, F> | ||
where | ||
T: Debug + Default + Ord + Copy, | ||
F: Fn(T, T) -> T, | ||
{ | ||
/// The length of the input array for which the segment tree is built. | ||
size: usize, | ||
/// A vector representing the segment tree. | ||
nodes: Vec<T>, | ||
/// A merging function defined as a closure or callable type. | ||
merge_fn: F, | ||
} | ||
|
||
impl<T: Debug + Default + Ord + Copy> SegmentTree<T> { | ||
/// Builds a SegmentTree from an array and a merge function | ||
pub fn from_vec(arr: &[T], merge: fn(T, T) -> T) -> Self { | ||
let len = arr.len(); | ||
let mut buf: Vec<T> = vec![T::default(); 2 * len]; | ||
// Populate the tree bottom-up, from right to left | ||
buf[len..(2 * len)].clone_from_slice(&arr[0..len]); // last len pos is the bottom of the tree -> every individual value | ||
for i in (1..len).rev() { | ||
// a nice property of this "flat" representation of a tree: the parent of an element at index i is located at index i/2 | ||
buf[i] = merge(buf[2 * i], buf[2 * i + 1]); | ||
impl<T, F> SegmentTree<T, F> | ||
where | ||
T: Debug + Default + Ord + Copy, | ||
F: Fn(T, T) -> T, | ||
{ | ||
/// Creates a new `SegmentTree` from the provided slice of elements. | ||
/// | ||
/// # Arguments | ||
/// | ||
/// * `arr`: A slice of elements of type `T` to initialize the segment tree. | ||
/// * `merge`: A merging function that defines how to merge two elements of type `T`. | ||
/// | ||
/// # Returns | ||
/// | ||
/// A new `SegmentTree` instance populated with the given elements. | ||
pub fn from_vec(arr: &[T], merge: F) -> Self { | ||
let size = arr.len(); | ||
let mut buffer: Vec<T> = vec![T::default(); 2 * size]; | ||
|
||
// Populate the leaves of the tree | ||
buffer[size..(2 * size)].clone_from_slice(arr); | ||
for idx in (1..size).rev() { | ||
buffer[idx] = merge(buffer[2 * idx], buffer[2 * idx + 1]); | ||
} | ||
|
||
SegmentTree { | ||
len, | ||
tree: buf, | ||
merge, | ||
size, | ||
nodes: buffer, | ||
merge_fn: merge, | ||
} | ||
} | ||
|
||
/// Query the range (exclusive) | ||
/// returns None if the range is out of the array's boundaries (eg: if start is after the end of the array, or start > end, etc.) | ||
/// return the aggregate of values over this range otherwise | ||
pub fn query(&self, range: Range<usize>) -> Option<T> { | ||
let mut l = range.start + self.len; | ||
let mut r = min(self.len, range.end) + self.len; | ||
let mut res = None; | ||
// Check Wikipedia or other detailed explanations here for how to navigate the tree bottom-up to limit the number of operations | ||
while l < r { | ||
if l % 2 == 1 { | ||
res = Some(match res { | ||
None => self.tree[l], | ||
Some(old) => (self.merge)(old, self.tree[l]), | ||
/// Queries the segment tree for the result of merging the elements in the given range. | ||
/// | ||
/// # Arguments | ||
/// | ||
/// * `range`: A range specified as `Range<usize>`, indicating the start (inclusive) | ||
/// and end (exclusive) indices of the segment to query. | ||
/// | ||
/// # Returns | ||
/// | ||
/// * `Ok(Some(result))` if the query was successful and there are elements in the range, | ||
/// * `Ok(None)` if the range is empty, | ||
/// * `Err(SegmentTreeError::InvalidRange)` if the provided range is invalid. | ||
pub fn query(&self, range: Range<usize>) -> Result<Option<T>, SegmentTreeError> { | ||
if range.start >= self.size || range.end > self.size { | ||
return Err(SegmentTreeError::InvalidRange); | ||
} | ||
|
||
let mut left = range.start + self.size; | ||
let mut right = range.end + self.size; | ||
let mut result = None; | ||
|
||
// Iterate through the segment tree to accumulate results | ||
while left < right { | ||
if left % 2 == 1 { | ||
result = Some(match result { | ||
None => self.nodes[left], | ||
Some(old) => (self.merge_fn)(old, self.nodes[left]), | ||
}); | ||
l += 1; | ||
left += 1; | ||
} | ||
if r % 2 == 1 { | ||
r -= 1; | ||
res = Some(match res { | ||
None => self.tree[r], | ||
Some(old) => (self.merge)(old, self.tree[r]), | ||
if right % 2 == 1 { | ||
right -= 1; | ||
result = Some(match result { | ||
None => self.nodes[right], | ||
Some(old) => (self.merge_fn)(old, self.nodes[right]), | ||
}); | ||
} | ||
l /= 2; | ||
r /= 2; | ||
left /= 2; | ||
right /= 2; | ||
} | ||
res | ||
|
||
Ok(result) | ||
} | ||
|
||
/// Updates the value at index `idx` in the original array with a new value `val` | ||
pub fn update(&mut self, idx: usize, val: T) { | ||
// change every value where `idx` plays a role, bottom -> up | ||
// 1: change in the right-hand side of the tree (bottom row) | ||
let mut idx = idx + self.len; | ||
self.tree[idx] = val; | ||
|
||
// 2: then bubble up | ||
idx /= 2; | ||
while idx != 0 { | ||
self.tree[idx] = (self.merge)(self.tree[2 * idx], self.tree[2 * idx + 1]); | ||
idx /= 2; | ||
/// Updates the value at the specified index in the segment tree. | ||
/// | ||
/// # Arguments | ||
/// | ||
/// * `idx`: The index (0-based) of the element to update. | ||
/// * `val`: The new value of type `T` to set at the specified index. | ||
/// | ||
/// # Returns | ||
/// | ||
/// * `Ok(())` if the update was successful, | ||
/// * `Err(SegmentTreeError::IndexOutOfBounds)` if the index is out of bounds. | ||
pub fn update(&mut self, idx: usize, val: T) -> Result<(), SegmentTreeError> { | ||
if idx >= self.size { | ||
return Err(SegmentTreeError::IndexOutOfBounds); | ||
} | ||
|
||
let mut index = idx + self.size; | ||
if self.nodes[index] == val { | ||
return Ok(()); | ||
} | ||
|
||
self.nodes[index] = val; | ||
while index > 1 { | ||
index /= 2; | ||
self.nodes[index] = (self.merge_fn)(self.nodes[2 * index], self.nodes[2 * index + 1]); | ||
} | ||
|
||
Ok(()) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use quickcheck::TestResult; | ||
use quickcheck_macros::quickcheck; | ||
use std::cmp::{max, min}; | ||
|
||
#[test] | ||
fn test_min_segments() { | ||
let vec = vec![-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; | ||
let min_seg_tree = SegmentTree::from_vec(&vec, min); | ||
assert_eq!(Some(-5), min_seg_tree.query(4..7)); | ||
assert_eq!(Some(-30), min_seg_tree.query(0..vec.len())); | ||
assert_eq!(Some(-30), min_seg_tree.query(0..2)); | ||
assert_eq!(Some(-4), min_seg_tree.query(1..3)); | ||
assert_eq!(Some(-5), min_seg_tree.query(1..7)); | ||
let mut min_seg_tree = SegmentTree::from_vec(&vec, min); | ||
assert_eq!(min_seg_tree.query(4..7), Ok(Some(-5))); | ||
assert_eq!(min_seg_tree.query(0..vec.len()), Ok(Some(-30))); | ||
assert_eq!(min_seg_tree.query(0..2), Ok(Some(-30))); | ||
assert_eq!(min_seg_tree.query(1..3), Ok(Some(-4))); | ||
assert_eq!(min_seg_tree.query(1..7), Ok(Some(-5))); | ||
assert_eq!(min_seg_tree.update(5, 10), Ok(())); | ||
assert_eq!(min_seg_tree.update(14, -8), Ok(())); | ||
assert_eq!(min_seg_tree.query(4..7), Ok(Some(3))); | ||
assert_eq!( | ||
min_seg_tree.update(15, 100), | ||
Err(SegmentTreeError::IndexOutOfBounds) | ||
); | ||
assert_eq!(min_seg_tree.query(5..5), Ok(None)); | ||
assert_eq!( | ||
min_seg_tree.query(10..16), | ||
Err(SegmentTreeError::InvalidRange) | ||
); | ||
assert_eq!( | ||
min_seg_tree.query(15..20), | ||
Err(SegmentTreeError::InvalidRange) | ||
); | ||
} | ||
|
||
#[test] | ||
fn test_max_segments() { | ||
let val_at_6 = 6; | ||
let vec = vec![1, 2, -4, 7, 3, -5, val_at_6, 11, -20, 9, 14, 15, 5, 2, -8]; | ||
let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; | ||
let mut max_seg_tree = SegmentTree::from_vec(&vec, max); | ||
assert_eq!(Some(15), max_seg_tree.query(0..vec.len())); | ||
let max_4_to_6 = 6; | ||
assert_eq!(Some(max_4_to_6), max_seg_tree.query(4..7)); | ||
let delta = 2; | ||
max_seg_tree.update(6, val_at_6 + delta); | ||
assert_eq!(Some(val_at_6 + delta), max_seg_tree.query(4..7)); | ||
assert_eq!(max_seg_tree.query(0..vec.len()), Ok(Some(15))); | ||
assert_eq!(max_seg_tree.query(3..5), Ok(Some(7))); | ||
assert_eq!(max_seg_tree.query(4..8), Ok(Some(11))); | ||
assert_eq!(max_seg_tree.query(8..10), Ok(Some(9))); | ||
assert_eq!(max_seg_tree.query(9..12), Ok(Some(15))); | ||
assert_eq!(max_seg_tree.update(4, 10), Ok(())); | ||
assert_eq!(max_seg_tree.update(14, -8), Ok(())); | ||
assert_eq!(max_seg_tree.query(3..5), Ok(Some(10))); | ||
assert_eq!( | ||
max_seg_tree.update(15, 100), | ||
Err(SegmentTreeError::IndexOutOfBounds) | ||
); | ||
assert_eq!(max_seg_tree.query(5..5), Ok(None)); | ||
assert_eq!( | ||
max_seg_tree.query(10..16), | ||
Err(SegmentTreeError::InvalidRange) | ||
); | ||
assert_eq!( | ||
max_seg_tree.query(15..20), | ||
Err(SegmentTreeError::InvalidRange) | ||
); | ||
} | ||
|
||
#[test] | ||
fn test_sum_segments() { | ||
let val_at_6 = 6; | ||
let vec = vec![1, 2, -4, 7, 3, -5, val_at_6, 11, -20, 9, 14, 15, 5, 2, -8]; | ||
let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; | ||
let mut sum_seg_tree = SegmentTree::from_vec(&vec, |a, b| a + b); | ||
for (i, val) in vec.iter().enumerate() { | ||
assert_eq!(Some(*val), sum_seg_tree.query(i..(i + 1))); | ||
} | ||
let sum_4_to_6 = sum_seg_tree.query(4..7); | ||
assert_eq!(Some(4), sum_4_to_6); | ||
let delta = 3; | ||
sum_seg_tree.update(6, val_at_6 + delta); | ||
assert_eq!(sum_seg_tree.query(0..vec.len()), Ok(Some(38))); | ||
assert_eq!(sum_seg_tree.query(1..4), Ok(Some(5))); | ||
assert_eq!(sum_seg_tree.query(4..7), Ok(Some(4))); | ||
assert_eq!(sum_seg_tree.query(6..9), Ok(Some(-3))); | ||
assert_eq!(sum_seg_tree.query(9..vec.len()), Ok(Some(37))); | ||
assert_eq!(sum_seg_tree.update(5, 10), Ok(())); | ||
assert_eq!(sum_seg_tree.update(14, -8), Ok(())); | ||
assert_eq!(sum_seg_tree.query(4..7), Ok(Some(19))); | ||
assert_eq!( | ||
sum_4_to_6.unwrap() + delta, | ||
sum_seg_tree.query(4..7).unwrap() | ||
sum_seg_tree.update(15, 100), | ||
Err(SegmentTreeError::IndexOutOfBounds) | ||
); | ||
assert_eq!(sum_seg_tree.query(5..5), Ok(None)); | ||
assert_eq!( | ||
sum_seg_tree.query(10..16), | ||
Err(SegmentTreeError::InvalidRange) | ||
); | ||
assert_eq!( | ||
sum_seg_tree.query(15..20), | ||
Err(SegmentTreeError::InvalidRange) | ||
); | ||
} | ||
|
||
// Some properties over segment trees: | ||
// When asking for the range of the overall array, return the same as iter().min() or iter().max(), etc. | ||
// When asking for an interval containing a single value, return this value, no matter the merge function | ||
|
||
#[quickcheck] | ||
fn check_overall_interval_min(array: Vec<i32>) -> TestResult { | ||
let seg_tree = SegmentTree::from_vec(&array, min); | ||
TestResult::from_bool(array.iter().min().copied() == seg_tree.query(0..array.len())) | ||
} | ||
|
||
#[quickcheck] | ||
fn check_overall_interval_max(array: Vec<i32>) -> TestResult { | ||
let seg_tree = SegmentTree::from_vec(&array, max); | ||
TestResult::from_bool(array.iter().max().copied() == seg_tree.query(0..array.len())) | ||
} | ||
|
||
#[quickcheck] | ||
fn check_overall_interval_sum(array: Vec<i32>) -> TestResult { | ||
let seg_tree = SegmentTree::from_vec(&array, max); | ||
TestResult::from_bool(array.iter().max().copied() == seg_tree.query(0..array.len())) | ||
} | ||
|
||
#[quickcheck] | ||
fn check_single_interval_min(array: Vec<i32>) -> TestResult { | ||
let seg_tree = SegmentTree::from_vec(&array, min); | ||
for (i, value) in array.into_iter().enumerate() { | ||
let res = seg_tree.query(i..(i + 1)); | ||
if res != Some(value) { | ||
return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res)); | ||
} | ||
} | ||
TestResult::passed() | ||
} | ||
|
||
#[quickcheck] | ||
fn check_single_interval_max(array: Vec<i32>) -> TestResult { | ||
let seg_tree = SegmentTree::from_vec(&array, max); | ||
for (i, value) in array.into_iter().enumerate() { | ||
let res = seg_tree.query(i..(i + 1)); | ||
if res != Some(value) { | ||
return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res)); | ||
} | ||
} | ||
TestResult::passed() | ||
} | ||
|
||
#[quickcheck] | ||
fn check_single_interval_sum(array: Vec<i32>) -> TestResult { | ||
let seg_tree = SegmentTree::from_vec(&array, max); | ||
for (i, value) in array.into_iter().enumerate() { | ||
let res = seg_tree.query(i..(i + 1)); | ||
if res != Some(value) { | ||
return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res)); | ||
} | ||
} | ||
TestResult::passed() | ||
} | ||
} |