Skip to content

Commit

Permalink
Refactor bits::Slice and bits::Iter to reduce number of unsafe blocks
Browse files Browse the repository at this point in the history
The biggest change is the introduction of the Slice::split_at method, which is
then used by other methods (such as starts_with, pop_back_slice, etc.).
This allows the number of unsafe blocks to be reduced.
  • Loading branch information
mina86 committed Aug 23, 2023
1 parent 58c8e9a commit f287ad2
Showing 1 changed file with 125 additions and 110 deletions.
235 changes: 125 additions & 110 deletions sealable-trie/src/bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ impl<'a> Slice<'a> {
///
/// Returns `None` if `offset` or `length` is too large or `bytes` doesn’t
/// have enough underlying data for the length of the slice.
#[inline]
pub fn new(bytes: &'a [u8], offset: u8, length: u16) -> Option<Self> {
if offset >= 8 {
return None;
Expand All @@ -68,6 +69,7 @@ impl<'a> Slice<'a> {
///
/// Returns `None` if the slice is too long. The maximum length is 8191
/// bytes.
#[inline]
pub fn from_bytes(bytes: &'a [u8]) -> Option<Self> {
Some(Self {
offset: 0,
Expand All @@ -78,9 +80,11 @@ impl<'a> Slice<'a> {
}

/// Reports how many bits the slice covers.
#[inline]
pub fn len(&self) -> u16 {
    self.length
}

/// Tells whether the slice covers no bits at all, i.e. its length is zero.
#[inline]
pub fn is_empty(&self) -> bool {
    self.len() == 0
}

/// Returns the first bit in the slice and advances the slice by one position.
Expand All @@ -96,17 +100,18 @@ impl<'a> Slice<'a> {
/// assert_eq!(Some(true), slice.pop_front());
/// assert_eq!(None, slice.pop_front());
/// ```
#[inline]
pub fn pop_front(&mut self) -> Option<bool> {
if self.length == 0 {
return None;
}
// SAFETY: self.length != 0 ⇒ self.ptr points at a valid byte
let bit = (unsafe { self.ptr.read() } & (0x80 >> self.offset)) != 0;
self.offset = (self.offset + 1) & 7;
// SAFETY: self.length != 0 ⇒ self.ptr points at a valid byte and
// `self.ptr + 1` is valid pointer value.
let (first, rest) = unsafe { (self.ptr.read(), self.ptr.add(1)) };
let bit = first & (0x80 >> self.offset) != 0;
self.offset = (self.offset + 1) % 8;
if self.offset == 0 {
// SAFETY: self.ptr pointing at valid byte ⇒ self.ptr+1 is valid
// pointer
self.ptr = unsafe { self.ptr.add(1) }
self.ptr = rest;
}
self.length -= 1;
Some(bit)
Expand All @@ -125,6 +130,7 @@ impl<'a> Slice<'a> {
/// assert_eq!(Some(false), slice.pop_back());
/// assert_eq!(None, slice.pop_back());
/// ```
#[inline]
pub fn pop_back(&mut self) -> Option<bool> {
self.length = self.length.checked_sub(1)?;
let total_bits = self.underlying_bits_length();
Expand All @@ -136,13 +142,39 @@ impl<'a> Slice<'a> {
Some(byte & mask != 0)
}

/// Detaches the first `length` bits of the slice and returns them as a new
/// slice, leaving `self` to cover only what follows.
///
/// This is [`Self::split_at`] applied in place: the head is returned and
/// `self` becomes the tail.  If the slice holds fewer than `length` bits,
/// `None` is returned and `self` is left unchanged.
///
/// ## Example
///
/// ```
/// # use sealable_trie::bits;
///
/// let mut slice = bits::Slice::new(&[0x81], 0, 8).unwrap();
/// let head = slice.pop_front_slice(4).unwrap();
/// assert_eq!(bits::Slice::new(&[0x80], 0, 4), Some(head));
/// assert_eq!(bits::Slice::new(&[0x01], 4, 4), Some(slice));
///
/// assert_eq!(None, slice.pop_front_slice(5));
/// assert_eq!(bits::Slice::new(&[0x01], 4, 4), Some(slice));
/// ```
#[inline]
pub fn pop_front_slice(&mut self, length: u16) -> Option<Self> {
    self.split_at(length).map(|(front, rest)| {
        // Only reached when the split succeeded, so a failed call never
        // modifies `self`.
        *self = rest;
        front
    })
}

/// Returns subslice from the end of the slice shrinking the slice by its
/// length.
///
/// Returns `None` if the slice is too short.
///
/// This is an ‘rpsilt_at’ operation but instead of returning two slices it
/// shortens the slice and returns the tail.
/// Behaves similarly to [`Self::split_at`] except the `length` is the
/// length of the suffix and instead of returning two slices it shortens
/// `self` and returns the tail. Returns `None` if the slice is too short.
///
/// ## Example
///
Expand All @@ -157,14 +189,11 @@ impl<'a> Slice<'a> {
/// assert_eq!(None, slice.pop_back_slice(5));
/// assert_eq!(bits::Slice::new(&[0x80], 0, 4), Some(slice));
/// ```
#[inline]
pub fn pop_back_slice(&mut self, length: u16) -> Option<Self> {
self.length = self.length.checked_sub(length)?;
let total_bits = self.underlying_bits_length();
// SAFETY: `ptr` is guaranteed to point at offset + original length
// valid bits.
let ptr = unsafe { self.ptr.add(total_bits / 8) };
let offset = (total_bits % 8) as u8;
Some(Self { ptr, offset, length, phantom: Default::default() })
let (head, tail) = self.split_at(self.length.checked_sub(length)?)?;
*self = head;
Some(tail)
}

/// Returns iterator over chunks of slice where each chunk occupies at most
Expand All @@ -173,8 +202,32 @@ impl<'a> Slice<'a> {
/// The final chunk may be shorter. Note that due to offset the first chunk
/// may be shorter than 272 bits (i.e. 34 * 8) however it will span full 34
/// bytes.
#[inline]
pub fn chunks(&self) -> Chunks<'a> { Chunks(*self) }

/// Divides the slice into a head of `length` bits and a tail holding the
/// remaining bits.
///
/// Works like `[T]::split_at`, but instead of panicking it returns `None`
/// when the slice has fewer than `length` bits.
#[inline]
pub fn split_at(&self, length: u16) -> Option<(Self, Self)> {
    let tail_length = self.length.checked_sub(length)?;

    // Bit position of the split point, counted from the most significant
    // bit of the byte `self.ptr` points at.
    let split_bit = usize::from(self.offset) + usize::from(length);

    // SAFETY: By invariant, `ptr..ptr + (self.offset + self.length + 7) / 8`
    // is a valid range.  Since `length ≤ self.length`, the pointer
    // `ptr + (self.offset + length) / 8` is valid as well.
    let tail_ptr = unsafe { self.ptr.add(split_bit / 8) };

    let head = Slice { length, ..*self };
    let tail = Slice {
        // `split_bit % 8` is always below 8, so the cast cannot truncate.
        offset: (split_bit % 8) as u8,
        length: tail_length,
        ptr: tail_ptr,
        phantom: self.phantom,
    };
    Some((head, tail))
}

/// Returns whether the slice starts with given prefix.
///
/// **Note**: If the `prefix` slice has a different bit offset it is not
Expand All @@ -192,12 +245,15 @@ impl<'a> Slice<'a> {
/// // Different offset:
/// assert!(!slice.starts_with(bits::Slice::new(&[0xAA], 4, 4).unwrap()));
/// ```
#[inline]
pub fn starts_with(&self, prefix: Slice<'_>) -> bool {
if self.offset != prefix.offset || self.length < prefix.length {
return false;
if self.offset != prefix.offset {
false
} else if let Some((head, _)) = self.split_at(prefix.length) {
head == prefix
} else {
false
}
let subslice = Slice { length: prefix.length, ..*self };
subslice == prefix
}

/// Removes prefix from the slice; returns `false` if slice doesn’t start
Expand Down Expand Up @@ -229,16 +285,17 @@ impl<'a> Slice<'a> {
/// assert!(slice.strip_prefix(slice.clone()));
/// assert_eq!(bits::Slice::new(&[0x00], 4, 0).unwrap(), slice);
/// ```
#[inline]
pub fn strip_prefix(&mut self, prefix: Slice<'_>) -> bool {
if !self.starts_with(prefix) {
return false;
if self.offset == prefix.offset {
if let Some((head, tail)) = self.split_at(prefix.length) {
if head == prefix {
*self = tail;
return true;
}
}
}
let length = usize::from(prefix.length) + usize::from(prefix.offset);
// SAFETY: self.ptr points to at least length+offset valid bits.
unsafe { self.ptr = self.ptr.add(length / 8) };
self.offset = (length % 8) as u8;
self.length -= prefix.length;
true
false
}

/// Strips common prefix from two slices; returns new slice with the common
Expand Down Expand Up @@ -267,32 +324,41 @@ impl<'a> Slice<'a> {
/// assert_eq!(bits::Slice::new(&[0xFF], 6, 0).unwrap(), right);
/// ```
pub fn forward_common_prefix(&mut self, other: &mut Slice<'_>) -> Self {
let offset = self.offset;
if offset != other.offset {
return Self { length: 0, ..*self };
}
let length = (|| {
let offset = self.offset;
if offset != other.offset {
return 0;
}
let length = self.length.min(other.length);
let length = u32::from(length) + u32::from(offset);
let lhs = self.bytes().split_at(((length + 7) / 8) as usize).0;
let rhs = other.bytes().split_at(((length + 7) / 8) as usize).0;

let (fst, lhs, rhs) = match (lhs.split_first(), rhs.split_first()) {
(Some(lhs), Some(rhs)) => (lhs.0 ^ rhs.0, lhs.1, rhs.1),
_ => return 0,
};
let fst = fst & (0xFF >> offset);

let total_bits_matched = if fst != 0 {
fst.leading_zeros()
} else if let Some(n) =
lhs.iter().zip(rhs.iter()).position(|(a, b)| a != b)
{
8 + n as u32 * 8 + (lhs[n] ^ rhs[n]).leading_zeros()
} else {
8 + lhs.len() as u32 * 8
}
.min(length);

let length = self.length.min(other.length);
// SAFETY: offset is common offset of both slices and length is shorter
// of either slice, which means that both pointers point to at least
// offset+length bits.
let (idx, length) = unsafe {
forward_common_prefix_impl(self.ptr, other.ptr, offset, length)
};
let result = Self { length, ..*self };

self.length -= length;
self.offset = ((u16::from(self.offset) + length) % 8) as u8;
other.length -= length;
other.offset = self.offset;
// SAFETY: forward_common_prefix_impl guarantees that `idx` is no more
// than what the slices have.
unsafe {
self.ptr = self.ptr.add(idx);
other.ptr = other.ptr.add(idx);
total_bits_matched.saturating_sub(u32::from(offset)) as u16
})();
if length == 0 {
Self { length: 0, ..*self }
} else {
other.pop_front_slice(length).unwrap();
self.pop_front_slice(length).unwrap()
}

result
}

/// Checks that all bits outside of the specified range are set to zero.
Expand Down Expand Up @@ -425,49 +491,6 @@ impl<'a> Slice<'a> {
}
}

/// Implementation of [`Slice::forward_common_prefix`].
///
/// Compares the two bit sequences which start `offset` bits into the bytes
/// at `lhs` and `rhs` and span at most `max_length` bits.  Bits in front of
/// `offset` in the first byte are ignored.
///
/// Returns a pair `(bytes, bits)`: `bits` is the length of the common prefix
/// in bits (not counting the leading `offset` bits) and `bytes` is the number
/// of whole bytes covered by `offset + bits`, i.e. how far the pointers can be
/// advanced past the prefix.
///
/// ## Safety
///
/// `lhs` and `rhs` must point to at least `offset + max_length` bits.
unsafe fn forward_common_prefix_impl(
    lhs: *const u8,
    rhs: *const u8,
    offset: u8,
    max_length: u16,
) -> (usize, u16) {
    let skipped_bits = u32::from(offset);
    let total_bits = u32::from(max_length) + skipped_bits;
    let byte_count = ((total_bits + 7) / 8) as usize;

    // SAFETY: Caller promises that both pointers point to at least offset +
    // max_length bits, which is what `byte_count` bytes cover.
    let (lhs, rhs) = unsafe {
        (
            core::slice::from_raw_parts(lhs, byte_count),
            core::slice::from_raw_parts(rhs, byte_count),
        )
    };

    // Walk the bytes in lock-step counting matching bits; the count includes
    // the `offset` leading bits of the first byte, which always ‘match’
    // because they are masked out.
    let mut matched_bits = 0u32;
    for (index, (left, right)) in lhs.iter().zip(rhs.iter()).enumerate() {
        let mut difference = left ^ right;
        if index == 0 {
            difference &= 0xFF >> offset;
        }
        if difference != 0 {
            matched_bits += difference.leading_zeros();
            break;
        }
        matched_bits += 8;
    }
    // The last byte may contain bits past the requested range; don’t count
    // those.
    let matched_bits = matched_bits.min(total_bits);

    (
        (matched_bits / 8) as usize,
        matched_bits.saturating_sub(skipped_bits) as u16,
    )
}

impl core::cmp::PartialEq for Slice<'_> {
/// Compares two slices to see if they contain the same bits and have the
/// same offset.
Expand Down Expand Up @@ -580,21 +603,13 @@ impl<'a> core::iter::Iterator for Chunks<'a> {
type Item = Slice<'a>;

fn next(&mut self) -> Option<Slice<'a>> {
let bytes_len = self.0.bytes().len().min(nodes::MAX_EXTENSION_KEY_SIZE);
if bytes_len == 0 {
return None;
const MAX_LENGTH: u16 = (nodes::MAX_EXTENSION_KEY_SIZE * 8) as u16;
let length = (MAX_LENGTH - u16::from(self.0.offset)).min(self.0.length);
if length == 0 {
None
} else {
self.0.pop_front_slice(length)
}
let slice = &mut self.0;
let offset = slice.offset;
let length = (bytes_len * 8 - usize::from(offset))
.min(usize::from(slice.length)) as u16;
let ptr = slice.ptr;
slice.offset = 0;
slice.length -= length;
// SAFETY: `ptr` points at a slice which is at least `bytes_len` bytes
// long so it’s safe to advance it by that offset.
slice.ptr = unsafe { slice.ptr.add(bytes_len) };
Some(Slice { offset, length, ptr, phantom: Default::default() })
}

fn size_hint(&self) -> (usize, Option<usize>) {
Expand Down

0 comments on commit f287ad2

Please sign in to comment.