Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor bits::Slice and bits::Iter to reduce number of unsafe blocks #6

Merged
merged 1 commit into from
Aug 28, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 125 additions & 110 deletions sealable-trie/src/bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ impl<'a> Slice<'a> {
///
/// Returns `None` if `offset` or `length` is too large or `bytes` doesn’t
/// have enough underlying data for the length of the slice.
#[inline]
pub fn new(bytes: &'a [u8], offset: u8, length: u16) -> Option<Self> {
if offset >= 8 {
return None;
Expand All @@ -68,6 +69,7 @@ impl<'a> Slice<'a> {
///
/// Returns `None` if the slice is too long. The maximum length is 8191
/// bytes.
#[inline]
pub fn from_bytes(bytes: &'a [u8]) -> Option<Self> {
Some(Self {
offset: 0,
Expand All @@ -78,9 +80,11 @@ impl<'a> Slice<'a> {
}

/// Returns length of the slice in bits.
#[inline]
pub fn len(&self) -> u16 { self.length }

/// Returns whether the slice is empty.
#[inline]
pub fn is_empty(&self) -> bool { self.length == 0 }

/// Returns the first bit in the slice advances the slice by one position.
Expand All @@ -96,17 +100,18 @@ impl<'a> Slice<'a> {
/// assert_eq!(Some(true), slice.pop_front());
/// assert_eq!(None, slice.pop_front());
/// ```
#[inline]
pub fn pop_front(&mut self) -> Option<bool> {
if self.length == 0 {
return None;
}
// SAFETY: self.length != 0 ⇒ self.ptr points at a valid byte
let bit = (unsafe { self.ptr.read() } & (0x80 >> self.offset)) != 0;
self.offset = (self.offset + 1) & 7;
// SAFETY: self.length != 0 ⇒ self.ptr points at a valid byte and
// `self.ptr + 1` is valid pointer value.
let (first, rest) = unsafe { (self.ptr.read(), self.ptr.add(1)) };
let bit = first & (0x80 >> self.offset) != 0;
self.offset = (self.offset + 1) % 8;
if self.offset == 0 {
// SAFETY: self.ptr pointing at valid byte ⇒ self.ptr+1 is valid
// pointer
self.ptr = unsafe { self.ptr.add(1) }
self.ptr = rest;
}
self.length -= 1;
Some(bit)
Expand All @@ -125,6 +130,7 @@ impl<'a> Slice<'a> {
/// assert_eq!(Some(false), slice.pop_back());
/// assert_eq!(None, slice.pop_back());
/// ```
#[inline]
pub fn pop_back(&mut self) -> Option<bool> {
self.length = self.length.checked_sub(1)?;
let total_bits = self.underlying_bits_length();
Expand All @@ -136,13 +142,39 @@ impl<'a> Slice<'a> {
Some(byte & mask != 0)
}

/// Returns subslice from the beginning of the slice shrinking the slice by
/// its length.
///
/// Behaves like [`Self::split_at`] except instead of returning two slices
/// it advances `self` and returns the head. Returns `None` if the slice is
/// too short.
///
/// ## Example
///
/// ```
/// # use sealable_trie::bits;
///
/// let mut slice = bits::Slice::new(&[0x81], 0, 8).unwrap();
/// let head = slice.pop_front_slice(4).unwrap();
/// assert_eq!(bits::Slice::new(&[0x80], 0, 4), Some(head));
/// assert_eq!(bits::Slice::new(&[0x01], 4, 4), Some(slice));
///
/// assert_eq!(None, slice.pop_front_slice(5));
/// assert_eq!(bits::Slice::new(&[0x01], 4, 4), Some(slice));
/// ```
#[inline]
pub fn pop_front_slice(&mut self, length: u16) -> Option<Self> {
let (head, tail) = self.split_at(length)?;
*self = tail;
Some(head)
}

/// Returns subslice from the end of the slice shrinking the slice by its
/// length.
///
/// Returns `None` if the slice is too short.
///
/// This is an ‘rpsilt_at’ operation but instead of returning two slices it
/// shortens the slice and returns the tail.
/// Behaves similarly to [`Self::split_at`] except the `length` is the
/// length of the suffix and instead of returning two slices it shortens
/// `self` and returns the tail. Returns `None` if the slice is too short.
///
/// ## Example
///
Expand All @@ -157,14 +189,11 @@ impl<'a> Slice<'a> {
/// assert_eq!(None, slice.pop_back_slice(5));
/// assert_eq!(bits::Slice::new(&[0x80], 0, 4), Some(slice));
/// ```
#[inline]
pub fn pop_back_slice(&mut self, length: u16) -> Option<Self> {
self.length = self.length.checked_sub(length)?;
let total_bits = self.underlying_bits_length();
// SAFETY: `ptr` is guaranteed to point at offset + original length
// valid bits.
let ptr = unsafe { self.ptr.add(total_bits / 8) };
let offset = (total_bits % 8) as u8;
Some(Self { ptr, offset, length, phantom: Default::default() })
let (head, tail) = self.split_at(self.length.checked_sub(length)?)?;
*self = head;
Some(tail)
}

/// Returns iterator over chunks of slice where each chunk occupies at most
Expand All @@ -173,8 +202,32 @@ impl<'a> Slice<'a> {
/// The final chunk may be shorter. Note that due to offset the first chunk
/// may be shorter than 272 bits (i.e. 34 * 8) however it will span full 34
/// bytes.
#[inline]
pub fn chunks(&self) -> Chunks<'a> { Chunks(*self) }

/// Splits slice into two at given index.
///
/// This is like `[T]::split_at` except rather than panicking it returns
/// `None` if the slice is too short.
#[inline]
pub fn split_at(&self, length: u16) -> Option<(Self, Self)> {
let remaining = self.length.checked_sub(length)?;
let left = Slice { length, ..*self };
// SAFETY: By invariant, `ptr..ptr+(self.offset + self.length + 7) / 8`
// is a valid range. Since `length ≤ self.length` then `ptr +
// (self.offset + length / 8) is valid as well`.
let ptr = unsafe {
self.ptr.add((usize::from(self.offset) + usize::from(length)) / 8)
};
let right = Slice {
offset: (self.offset + length as u8 % 8) % 8,
length: remaining,
ptr,
phantom: self.phantom,
};
Some((left, right))
}

/// Returns whether the slice starts with given prefix.
///
/// **Note**: If the `prefix` slice has a different bit offset it is not
Expand All @@ -192,12 +245,15 @@ impl<'a> Slice<'a> {
/// // Different offset:
/// assert!(!slice.starts_with(bits::Slice::new(&[0xAA], 4, 4).unwrap()));
/// ```
#[inline]
pub fn starts_with(&self, prefix: Slice<'_>) -> bool {
if self.offset != prefix.offset || self.length < prefix.length {
return false;
if self.offset != prefix.offset {
false
} else if let Some((head, _)) = self.split_at(prefix.length) {
head == prefix
} else {
false
}
let subslice = Slice { length: prefix.length, ..*self };
subslice == prefix
}

/// Removes prefix from the slice; returns `false` if slice doesn’t start
Expand Down Expand Up @@ -229,16 +285,17 @@ impl<'a> Slice<'a> {
/// assert!(slice.strip_prefix(slice.clone()));
/// assert_eq!(bits::Slice::new(&[0x00], 4, 0).unwrap(), slice);
/// ```
#[inline]
pub fn strip_prefix(&mut self, prefix: Slice<'_>) -> bool {
if !self.starts_with(prefix) {
return false;
if self.offset == prefix.offset {
if let Some((head, tail)) = self.split_at(prefix.length) {
if head == prefix {
*self = tail;
return true;
}
}
}
let length = usize::from(prefix.length) + usize::from(prefix.offset);
// SAFETY: self.ptr points to at least length+offset valid bits.
unsafe { self.ptr = self.ptr.add(length / 8) };
self.offset = (length % 8) as u8;
self.length -= prefix.length;
true
false
}

/// Strips common prefix from two slices; returns new slice with the common
Expand Down Expand Up @@ -267,32 +324,41 @@ impl<'a> Slice<'a> {
/// assert_eq!(bits::Slice::new(&[0xFF], 6, 0).unwrap(), right);
/// ```
pub fn forward_common_prefix(&mut self, other: &mut Slice<'_>) -> Self {
let offset = self.offset;
if offset != other.offset {
return Self { length: 0, ..*self };
}
let length = (|| {
let offset = self.offset;
if offset != other.offset {
return 0;
}
let length = self.length.min(other.length);
let length = u32::from(length) + u32::from(offset);
let lhs = self.bytes().split_at(((length + 7) / 8) as usize).0;
let rhs = other.bytes().split_at(((length + 7) / 8) as usize).0;

let (fst, lhs, rhs) = match (lhs.split_first(), rhs.split_first()) {
(Some(lhs), Some(rhs)) => (lhs.0 ^ rhs.0, lhs.1, rhs.1),
_ => return 0,
};
let fst = fst & (0xFF >> offset);

let total_bits_matched = if fst != 0 {
fst.leading_zeros()
} else if let Some(n) =
lhs.iter().zip(rhs.iter()).position(|(a, b)| a != b)
{
8 + n as u32 * 8 + (lhs[n] ^ rhs[n]).leading_zeros()
} else {
8 + lhs.len() as u32 * 8
}
.min(length);

let length = self.length.min(other.length);
// SAFETY: offset is common offset of both slices and length is shorter
// of either slice, which means that both pointers point to at least
// offset+length bits.
let (idx, length) = unsafe {
forward_common_prefix_impl(self.ptr, other.ptr, offset, length)
};
let result = Self { length, ..*self };

self.length -= length;
self.offset = ((u16::from(self.offset) + length) % 8) as u8;
other.length -= length;
other.offset = self.offset;
// SAFETY: forward_common_prefix_impl guarantees that `idx` is no more
// than what the slices have.
unsafe {
self.ptr = self.ptr.add(idx);
other.ptr = other.ptr.add(idx);
total_bits_matched.saturating_sub(u32::from(offset)) as u16
})();
if length == 0 {
Self { length: 0, ..*self }
} else {
other.pop_front_slice(length).unwrap();
self.pop_front_slice(length).unwrap()
}

result
}

/// Checks that all bits outside of the specified range are set to zero.
Expand Down Expand Up @@ -425,49 +491,6 @@ impl<'a> Slice<'a> {
}
}

/// Implementation of [`Slice::forward_common_prefix`].
///
/// ## Safety
///
/// `lhs` and `rhs` must point to at least `offset + max_length` bits.
unsafe fn forward_common_prefix_impl(
lhs: *const u8,
rhs: *const u8,
offset: u8,
max_length: u16,
) -> (usize, u16) {
let max_length = u32::from(max_length) + u32::from(offset);
// SAFETY: Caller promises that both pointers point to at least offset +
// max_length bits.
let (lhs, rhs) = unsafe {
let len = ((max_length + 7) / 8) as usize;
let lhs = core::slice::from_raw_parts(lhs, len).split_first();
let rhs = core::slice::from_raw_parts(rhs, len).split_first();
(lhs, rhs)
};

let (first, lhs, rhs) = match (lhs, rhs) {
(Some(lhs), Some(rhs)) => (lhs.0 ^ rhs.0, lhs.1, rhs.1),
_ => return (0, 0),
};
let first = first & (0xFF >> offset);

let total_bits_matched = if first != 0 {
first.leading_zeros()
} else if let Some(n) = lhs.iter().zip(rhs.iter()).position(|(a, b)| a != b)
{
8 + n as u32 * 8 + (lhs[n] ^ rhs[n]).leading_zeros()
} else {
8 + lhs.len() as u32 * 8
}
.min(max_length);

(
(total_bits_matched / 8) as usize,
total_bits_matched.saturating_sub(u32::from(offset)) as u16,
)
}

impl core::cmp::PartialEq for Slice<'_> {
/// Compares two slices to see if they contain the same bits and have the
/// same offset.
Expand Down Expand Up @@ -580,21 +603,13 @@ impl<'a> core::iter::Iterator for Chunks<'a> {
type Item = Slice<'a>;

fn next(&mut self) -> Option<Slice<'a>> {
let bytes_len = self.0.bytes().len().min(nodes::MAX_EXTENSION_KEY_SIZE);
if bytes_len == 0 {
return None;
const MAX_LENGTH: u16 = (nodes::MAX_EXTENSION_KEY_SIZE * 8) as u16;
let length = (MAX_LENGTH - u16::from(self.0.offset)).min(self.0.length);
if length == 0 {
None
} else {
self.0.pop_front_slice(length)
}
let slice = &mut self.0;
let offset = slice.offset;
let length = (bytes_len * 8 - usize::from(offset))
.min(usize::from(slice.length)) as u16;
let ptr = slice.ptr;
slice.offset = 0;
slice.length -= length;
// SAFETY: `ptr` points at a slice which is at least `bytes_len` bytes
// long so it’s safe to advance it by that offset.
slice.ptr = unsafe { slice.ptr.add(bytes_len) };
Some(Slice { offset, length, ptr, phantom: Default::default() })
}

fn size_hint(&self) -> (usize, Option<usize>) {
Expand Down