Skip to content

Commit

Permalink
Refactor Segment Tree Implementation (#835)
Browse files Browse the repository at this point in the history
ref: refactor segment tree
  • Loading branch information
sozelfist authored Oct 31, 2024
1 parent 5151982 commit 5a83939
Showing 1 changed file with 181 additions and 142 deletions.
323 changes: 181 additions & 142 deletions src/data_structures/segment_tree.rs
Original file line number Diff line number Diff line change
@@ -1,185 +1,224 @@
use std::cmp::min;
//! A module providing a Segment Tree data structure for efficient range queries
//! and updates. It supports operations like finding the minimum, maximum,
//! and sum of segments in an array.

use std::fmt::Debug;
use std::ops::Range;

/// This data structure implements a segment tree that can efficiently answer range (interval) queries on arrays.
/// It represents the array as a binary tree of merged intervals. From top to bottom: [aggregated value for the overall array], then [left-hand half, right-hand half], etc., down to [each individual value, ...].
/// It is generic over a reduction function for each segment or interval, i.e. a description of how two intervals are merged together.
/// Note that this function should be commutative and associative.
/// It could be `std::cmp::min(interval_1, interval_2)`, `std::cmp::max(interval_1, interval_2)`, `|a, b| a + b`, or `|a, b| a * b`.
pub struct SegmentTree<T: Debug + Default + Ord + Copy> {
len: usize, // length of the represented
tree: Vec<T>, // represents a binary tree of intervals as an array (as a BinaryHeap does, for instance)
merge: fn(T, T) -> T, // how we merge two values together
/// Custom error type representing the errors that can occur during operations on the `SegmentTree`.
///
/// The type implements [`std::fmt::Display`] and [`std::error::Error`], so it can be
/// propagated with `?` into `Box<dyn std::error::Error>` and printed in user-facing messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SegmentTreeError {
    /// Error indicating that an index is out of bounds.
    IndexOutOfBounds,
    /// Error indicating that a range provided for a query is invalid.
    InvalidRange,
}

impl std::fmt::Display for SegmentTreeError {
    /// Writes a short, lowercase, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::IndexOutOfBounds => "index out of bounds",
            Self::InvalidRange => "invalid range",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SegmentTreeError {}

/// A structure representing a Segment Tree. This tree can be used to efficiently
/// perform range queries and updates on an array of elements.
///
/// The struct definition itself places no trait bounds on `T` or `F`; the
/// requirements on the element type and the merge function are stated on the
/// `impl` blocks that actually need them.
pub struct SegmentTree<T, F> {
    /// The length of the input array for which the segment tree is built.
    size: usize,
    /// A vector representing the segment tree as a flat array of `2 * size`
    /// entries: the leaves (the input elements) occupy `size..2 * size`, and
    /// internal node `i` holds the merge of nodes `2 * i` and `2 * i + 1`.
    /// Index 0 is never written by the builder and keeps its default value.
    nodes: Vec<T>,
    /// A merging function defined as a closure or callable type.
    merge_fn: F,
}

impl<T: Debug + Default + Ord + Copy> SegmentTree<T> {
/// Builds a SegmentTree from an array and a merge function
pub fn from_vec(arr: &[T], merge: fn(T, T) -> T) -> Self {
let len = arr.len();
let mut buf: Vec<T> = vec![T::default(); 2 * len];
// Populate the tree bottom-up, from right to left
buf[len..(2 * len)].clone_from_slice(&arr[0..len]); // last len pos is the bottom of the tree -> every individual value
for i in (1..len).rev() {
// a nice property of this "flat" representation of a tree: the parent of an element at index i is located at index i/2
buf[i] = merge(buf[2 * i], buf[2 * i + 1]);
impl<T, F> SegmentTree<T, F>
where
T: Debug + Default + Ord + Copy,
F: Fn(T, T) -> T,
{
/// Creates a new `SegmentTree` from the provided slice of elements.
///
/// # Arguments
///
/// * `arr`: A slice of elements of type `T` to initialize the segment tree.
/// * `merge`: A merging function that defines how to merge two elements of type `T`.
///
/// # Returns
///
/// A new `SegmentTree` instance populated with the given elements.
pub fn from_vec(arr: &[T], merge: F) -> Self {
let size = arr.len();
let mut buffer: Vec<T> = vec![T::default(); 2 * size];

// Populate the leaves of the tree
buffer[size..(2 * size)].clone_from_slice(arr);
for idx in (1..size).rev() {
buffer[idx] = merge(buffer[2 * idx], buffer[2 * idx + 1]);
}

SegmentTree {
len,
tree: buf,
merge,
size,
nodes: buffer,
merge_fn: merge,
}
}

/// Query the range (exclusive)
/// returns None if the range is out of the array's boundaries (eg: if start is after the end of the array, or start > end, etc.)
/// return the aggregate of values over this range otherwise
pub fn query(&self, range: Range<usize>) -> Option<T> {
let mut l = range.start + self.len;
let mut r = min(self.len, range.end) + self.len;
let mut res = None;
// Check Wikipedia or other detailed explanations here for how to navigate the tree bottom-up to limit the number of operations
while l < r {
if l % 2 == 1 {
res = Some(match res {
None => self.tree[l],
Some(old) => (self.merge)(old, self.tree[l]),
/// Queries the segment tree for the result of merging the elements in the given range.
///
/// # Arguments
///
/// * `range`: A range specified as `Range<usize>`, indicating the start (inclusive)
/// and end (exclusive) indices of the segment to query.
///
/// # Returns
///
/// * `Ok(Some(result))` if the query was successful and there are elements in the range,
/// * `Ok(None)` if the range is empty,
/// * `Err(SegmentTreeError::InvalidRange)` if the provided range is invalid.
pub fn query(&self, range: Range<usize>) -> Result<Option<T>, SegmentTreeError> {
if range.start >= self.size || range.end > self.size {
return Err(SegmentTreeError::InvalidRange);
}

let mut left = range.start + self.size;
let mut right = range.end + self.size;
let mut result = None;

// Iterate through the segment tree to accumulate results
while left < right {
if left % 2 == 1 {
result = Some(match result {
None => self.nodes[left],
Some(old) => (self.merge_fn)(old, self.nodes[left]),
});
l += 1;
left += 1;
}
if r % 2 == 1 {
r -= 1;
res = Some(match res {
None => self.tree[r],
Some(old) => (self.merge)(old, self.tree[r]),
if right % 2 == 1 {
right -= 1;
result = Some(match result {
None => self.nodes[right],
Some(old) => (self.merge_fn)(old, self.nodes[right]),
});
}
l /= 2;
r /= 2;
left /= 2;
right /= 2;
}
res

Ok(result)
}

/// Updates the value at index `idx` in the original array with a new value `val`
pub fn update(&mut self, idx: usize, val: T) {
// change every value where `idx` plays a role, bottom -> up
// 1: change in the right-hand side of the tree (bottom row)
let mut idx = idx + self.len;
self.tree[idx] = val;

// 2: then bubble up
idx /= 2;
while idx != 0 {
self.tree[idx] = (self.merge)(self.tree[2 * idx], self.tree[2 * idx + 1]);
idx /= 2;
/// Replaces the element at position `idx` and refreshes every ancestor node.
///
/// # Arguments
///
/// * `idx`: Zero-based position of the element to overwrite.
/// * `val`: The value of type `T` to store at that position.
///
/// # Returns
///
/// * `Ok(())` when the element was updated (or already held `val`),
/// * `Err(SegmentTreeError::IndexOutOfBounds)` when `idx` is not below the tree's size.
pub fn update(&mut self, idx: usize, val: T) -> Result<(), SegmentTreeError> {
    if idx >= self.size {
        return Err(SegmentTreeError::IndexOutOfBounds);
    }

    // Leaves live in the second half of `nodes`.
    let leaf = idx + self.size;

    // Only touch the tree when the stored value actually changes.
    if self.nodes[leaf] != val {
        self.nodes[leaf] = val;

        // Walk from the leaf's parent up to the root (index 1), recomputing
        // each node from its two children.
        let mut parent = leaf / 2;
        while parent >= 1 {
            let lhs = self.nodes[2 * parent];
            let rhs = self.nodes[2 * parent + 1];
            self.nodes[parent] = (self.merge_fn)(lhs, rhs);
            parent /= 2;
        }
    }

    Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;
use quickcheck::TestResult;
use quickcheck_macros::quickcheck;
use std::cmp::{max, min};

#[test]
fn test_min_segments() {
let vec = vec![-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8];
let min_seg_tree = SegmentTree::from_vec(&vec, min);
assert_eq!(Some(-5), min_seg_tree.query(4..7));
assert_eq!(Some(-30), min_seg_tree.query(0..vec.len()));
assert_eq!(Some(-30), min_seg_tree.query(0..2));
assert_eq!(Some(-4), min_seg_tree.query(1..3));
assert_eq!(Some(-5), min_seg_tree.query(1..7));
let mut min_seg_tree = SegmentTree::from_vec(&vec, min);
assert_eq!(min_seg_tree.query(4..7), Ok(Some(-5)));
assert_eq!(min_seg_tree.query(0..vec.len()), Ok(Some(-30)));
assert_eq!(min_seg_tree.query(0..2), Ok(Some(-30)));
assert_eq!(min_seg_tree.query(1..3), Ok(Some(-4)));
assert_eq!(min_seg_tree.query(1..7), Ok(Some(-5)));
assert_eq!(min_seg_tree.update(5, 10), Ok(()));
assert_eq!(min_seg_tree.update(14, -8), Ok(()));
assert_eq!(min_seg_tree.query(4..7), Ok(Some(3)));
assert_eq!(
min_seg_tree.update(15, 100),
Err(SegmentTreeError::IndexOutOfBounds)
);
assert_eq!(min_seg_tree.query(5..5), Ok(None));
assert_eq!(
min_seg_tree.query(10..16),
Err(SegmentTreeError::InvalidRange)
);
assert_eq!(
min_seg_tree.query(15..20),
Err(SegmentTreeError::InvalidRange)
);
}

#[test]
fn test_max_segments() {
let val_at_6 = 6;
let vec = vec![1, 2, -4, 7, 3, -5, val_at_6, 11, -20, 9, 14, 15, 5, 2, -8];
let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8];
let mut max_seg_tree = SegmentTree::from_vec(&vec, max);
assert_eq!(Some(15), max_seg_tree.query(0..vec.len()));
let max_4_to_6 = 6;
assert_eq!(Some(max_4_to_6), max_seg_tree.query(4..7));
let delta = 2;
max_seg_tree.update(6, val_at_6 + delta);
assert_eq!(Some(val_at_6 + delta), max_seg_tree.query(4..7));
assert_eq!(max_seg_tree.query(0..vec.len()), Ok(Some(15)));
assert_eq!(max_seg_tree.query(3..5), Ok(Some(7)));
assert_eq!(max_seg_tree.query(4..8), Ok(Some(11)));
assert_eq!(max_seg_tree.query(8..10), Ok(Some(9)));
assert_eq!(max_seg_tree.query(9..12), Ok(Some(15)));
assert_eq!(max_seg_tree.update(4, 10), Ok(()));
assert_eq!(max_seg_tree.update(14, -8), Ok(()));
assert_eq!(max_seg_tree.query(3..5), Ok(Some(10)));
assert_eq!(
max_seg_tree.update(15, 100),
Err(SegmentTreeError::IndexOutOfBounds)
);
assert_eq!(max_seg_tree.query(5..5), Ok(None));
assert_eq!(
max_seg_tree.query(10..16),
Err(SegmentTreeError::InvalidRange)
);
assert_eq!(
max_seg_tree.query(15..20),
Err(SegmentTreeError::InvalidRange)
);
}

#[test]
fn test_sum_segments() {
let val_at_6 = 6;
let vec = vec![1, 2, -4, 7, 3, -5, val_at_6, 11, -20, 9, 14, 15, 5, 2, -8];
let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8];
let mut sum_seg_tree = SegmentTree::from_vec(&vec, |a, b| a + b);
for (i, val) in vec.iter().enumerate() {
assert_eq!(Some(*val), sum_seg_tree.query(i..(i + 1)));
}
let sum_4_to_6 = sum_seg_tree.query(4..7);
assert_eq!(Some(4), sum_4_to_6);
let delta = 3;
sum_seg_tree.update(6, val_at_6 + delta);
assert_eq!(sum_seg_tree.query(0..vec.len()), Ok(Some(38)));
assert_eq!(sum_seg_tree.query(1..4), Ok(Some(5)));
assert_eq!(sum_seg_tree.query(4..7), Ok(Some(4)));
assert_eq!(sum_seg_tree.query(6..9), Ok(Some(-3)));
assert_eq!(sum_seg_tree.query(9..vec.len()), Ok(Some(37)));
assert_eq!(sum_seg_tree.update(5, 10), Ok(()));
assert_eq!(sum_seg_tree.update(14, -8), Ok(()));
assert_eq!(sum_seg_tree.query(4..7), Ok(Some(19)));
assert_eq!(
sum_4_to_6.unwrap() + delta,
sum_seg_tree.query(4..7).unwrap()
sum_seg_tree.update(15, 100),
Err(SegmentTreeError::IndexOutOfBounds)
);
assert_eq!(sum_seg_tree.query(5..5), Ok(None));
assert_eq!(
sum_seg_tree.query(10..16),
Err(SegmentTreeError::InvalidRange)
);
assert_eq!(
sum_seg_tree.query(15..20),
Err(SegmentTreeError::InvalidRange)
);
}

// Some properties over segment trees:
// When asking for the range of the overall array, return the same as iter().min() or iter().max(), etc.
// When asking for an interval containing a single value, return this value, no matter the merge function

#[quickcheck]
fn check_overall_interval_min(array: Vec<i32>) -> TestResult {
    // Property: querying the whole array with a `min` tree agrees with `Iterator::min`.
    let tree = SegmentTree::from_vec(&array, min);
    let expected = array.iter().copied().min();
    let whole = tree.query(0..array.len());
    TestResult::from_bool(expected == whole)
}

#[quickcheck]
fn check_overall_interval_max(array: Vec<i32>) -> TestResult {
    // Property: querying the whole array with a `max` tree agrees with `Iterator::max`.
    let tree = SegmentTree::from_vec(&array, max);
    let expected = array.iter().copied().max();
    let whole = tree.query(0..array.len());
    TestResult::from_bool(expected == whole)
}

#[quickcheck]
fn check_overall_interval_sum(array: Vec<i32>) -> TestResult {
    // Property: querying the whole array with an additive tree agrees with folding
    // the array with the same operation.
    //
    // Wrapping addition is used because quickcheck generates arbitrary `i32`
    // values whose plain sum could overflow; it is still commutative and
    // associative, so the order in which the tree merges nodes does not matter.
    // `reduce` yields `None` for an empty array, matching an empty query.
    let seg_tree = SegmentTree::from_vec(&array, i32::wrapping_add);
    let expected = array.iter().copied().reduce(i32::wrapping_add);
    TestResult::from_bool(expected == seg_tree.query(0..array.len()))
}

#[quickcheck]
fn check_single_interval_min(array: Vec<i32>) -> TestResult {
let seg_tree = SegmentTree::from_vec(&array, min);
for (i, value) in array.into_iter().enumerate() {
let res = seg_tree.query(i..(i + 1));
if res != Some(value) {
return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res));
}
}
TestResult::passed()
}

#[quickcheck]
fn check_single_interval_max(array: Vec<i32>) -> TestResult {
let seg_tree = SegmentTree::from_vec(&array, max);
for (i, value) in array.into_iter().enumerate() {
let res = seg_tree.query(i..(i + 1));
if res != Some(value) {
return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res));
}
}
TestResult::passed()
}

#[quickcheck]
fn check_single_interval_sum(array: Vec<i32>) -> TestResult {
    // Property: a one-element query on an additive tree returns that element.
    //
    // The tree is built with wrapping addition (not `max`) so that this test
    // really exercises a sum tree; wrapping avoids overflow panics while the
    // internal nodes are built from arbitrary `i32` inputs.
    let seg_tree = SegmentTree::from_vec(&array, i32::wrapping_add);
    for (i, value) in array.into_iter().enumerate() {
        let res = seg_tree.query(i..(i + 1));
        if res != Some(value) {
            return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res));
        }
    }
    TestResult::passed()
}
}

0 comments on commit 5a83939

Please sign in to comment.