diff --git a/sealable-trie/src/bits.rs b/sealable-trie/src/bits.rs index 0f293f88..fab8b0b3 100644 --- a/sealable-trie/src/bits.rs +++ b/sealable-trie/src/bits.rs @@ -50,6 +50,7 @@ impl<'a> Slice<'a> { /// /// Returns `None` if `offset` or `length` is too large or `bytes` doesn’t /// have enough underlying data for the length of the slice. + #[inline] pub fn new(bytes: &'a [u8], offset: u8, length: u16) -> Option { if offset >= 8 { return None; @@ -68,6 +69,7 @@ impl<'a> Slice<'a> { /// /// Returns `None` if the slice is too long. The maximum length is 8191 /// bytes. + #[inline] pub fn from_bytes(bytes: &'a [u8]) -> Option { Some(Self { offset: 0, @@ -78,9 +80,11 @@ impl<'a> Slice<'a> { } /// Returns length of the slice in bits. + #[inline] pub fn len(&self) -> u16 { self.length } /// Returns whether the slice is empty. + #[inline] pub fn is_empty(&self) -> bool { self.length == 0 } /// Returns the first bit in the slice advances the slice by one position. @@ -96,17 +100,18 @@ impl<'a> Slice<'a> { /// assert_eq!(Some(true), slice.pop_front()); /// assert_eq!(None, slice.pop_front()); /// ``` + #[inline] pub fn pop_front(&mut self) -> Option { if self.length == 0 { return None; } - // SAFETY: self.length != 0 ⇒ self.ptr points at a valid byte - let bit = (unsafe { self.ptr.read() } & (0x80 >> self.offset)) != 0; - self.offset = (self.offset + 1) & 7; + // SAFETY: self.length != 0 ⇒ self.ptr points at a valid byte and + // `self.ptr + 1` is valid pointer value. + let (first, rest) = unsafe { (self.ptr.read(), self.ptr.add(1)) }; + let bit = first & (0x80 >> self.offset) != 0; + self.offset = (self.offset + 1) % 8; if self.offset == 0 { - // SAFETY: self.ptr pointing at valid byte ⇒ self.ptr+1 is valid - // pointer - self.ptr = unsafe { self.ptr.add(1) } + self.ptr = rest; } self.length -= 1; Some(bit) @@ -125,6 +130,7 @@ impl<'a> Slice<'a> { /// assert_eq!(Some(false), slice.pop_back()); /// assert_eq!(None, slice.pop_back()); /// ``` + #[inline] pub fn pop_back(&mut self) -> Option { self.length = self.length.checked_sub(1)?; let total_bits = self.underlying_bits_length(); @@ -136,13 +142,39 @@ impl<'a> Slice<'a> { Some(byte & mask != 0) } + /// Returns subslice from the beginning of the slice shrinking the slice by + /// its length. + /// + /// Behaves like [`Self::split_at`] except instead of returning two slices + /// it advances `self` and returns the head. Returns `None` if the slice is + /// too short. + /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits; + /// + /// let mut slice = bits::Slice::new(&[0x81], 0, 8).unwrap(); + /// let head = slice.pop_front_slice(4).unwrap(); + /// assert_eq!(bits::Slice::new(&[0x80], 0, 4), Some(head)); + /// assert_eq!(bits::Slice::new(&[0x01], 4, 4), Some(slice)); + /// + /// assert_eq!(None, slice.pop_front_slice(5)); + /// assert_eq!(bits::Slice::new(&[0x01], 4, 4), Some(slice)); + /// ``` + #[inline] + pub fn pop_front_slice(&mut self, length: u16) -> Option { + let (head, tail) = self.split_at(length)?; + *self = tail; + Some(head) + } + /// Returns subslice from the end of the slice shrinking the slice by its /// length. /// - /// Returns `None` if the slice is too short. - /// - /// This is an ‘rpsilt_at’ operation but instead of returning two slices it - /// shortens the slice and returns the tail. + /// Behaves similarly to [`Self::split_at`] except the `length` is the + /// length of the suffix and instead of returning two slices it shortens + /// `self` and returns the tail. Returns `None` if the slice is too short. /// /// ## Example /// @@ -157,14 +189,11 @@ impl<'a> Slice<'a> { /// assert_eq!(None, slice.pop_back_slice(5)); /// assert_eq!(bits::Slice::new(&[0x80], 0, 4), Some(slice)); /// ``` + #[inline] pub fn pop_back_slice(&mut self, length: u16) -> Option { - self.length = self.length.checked_sub(length)?; - let total_bits = self.underlying_bits_length(); - // SAFETY: `ptr` is guaranteed to point at offset + original length - // valid bits. - let ptr = unsafe { self.ptr.add(total_bits / 8) }; - let offset = (total_bits % 8) as u8; - Some(Self { ptr, offset, length, phantom: Default::default() }) + let (head, tail) = self.split_at(self.length.checked_sub(length)?)?; + *self = head; + Some(tail) } /// Returns iterator over chunks of slice where each chunk occupies at most @@ -173,8 +202,32 @@ impl<'a> Slice<'a> { /// The final chunk may be shorter. Note that due to offset the first chunk /// may be shorter than 272 bits (i.e. 34 * 8) however it will span full 34 /// bytes. + #[inline] pub fn chunks(&self) -> Chunks<'a> { Chunks(*self) } + /// Splits slice into two at given index. + /// + /// This is like `[T]::split_at` except rather than panicking it returns + /// `None` if the slice is too short. + #[inline] + pub fn split_at(&self, length: u16) -> Option<(Self, Self)> { + let remaining = self.length.checked_sub(length)?; + let left = Slice { length, ..*self }; + // SAFETY: By invariant, `ptr..ptr+(self.offset + self.length + 7) / 8` + // is a valid range. Since `length ≤ self.length` then `ptr + + // (self.offset + length / 8) is valid as well`. + let ptr = unsafe { + self.ptr.add((usize::from(self.offset) + usize::from(length)) / 8) + }; + let right = Slice { + offset: (self.offset + length as u8 % 8) % 8, + length: remaining, + ptr, + phantom: self.phantom, + }; + Some((left, right)) + } + /// Returns whether the slice starts with given prefix. /// /// **Note**: If the `prefix` slice has a different bit offset it is not @@ -192,12 +245,15 @@ impl<'a> Slice<'a> { /// // Different offset: /// assert!(!slice.starts_with(bits::Slice::new(&[0xAA], 4, 4).unwrap())); /// ``` + #[inline] pub fn starts_with(&self, prefix: Slice<'_>) -> bool { - if self.offset != prefix.offset || self.length < prefix.length { - return false; + if self.offset != prefix.offset { + false + } else if let Some((head, _)) = self.split_at(prefix.length) { + head == prefix + } else { + false } - let subslice = Slice { length: prefix.length, ..*self }; - subslice == prefix } /// Removes prefix from the slice; returns `false` if slice doesn’t start @@ -229,16 +285,17 @@ impl<'a> Slice<'a> { /// assert!(slice.strip_prefix(slice.clone())); /// assert_eq!(bits::Slice::new(&[0x00], 4, 0).unwrap(), slice); /// ``` + #[inline] pub fn strip_prefix(&mut self, prefix: Slice<'_>) -> bool { - if !self.starts_with(prefix) { - return false; + if self.offset == prefix.offset { + if let Some((head, tail)) = self.split_at(prefix.length) { + if head == prefix { + *self = tail; + return true; + } + } } - let length = usize::from(prefix.length) + usize::from(prefix.offset); - // SAFETY: self.ptr points to at least length+offset valid bits. - unsafe { self.ptr = self.ptr.add(length / 8) }; - self.offset = (length % 8) as u8; - self.length -= prefix.length; - true + false } /// Strips common prefix from two slices; returns new slice with the common @@ -267,32 +324,41 @@ impl<'a> Slice<'a> { /// assert_eq!(bits::Slice::new(&[0xFF], 6, 0).unwrap(), right); /// ``` pub fn forward_common_prefix(&mut self, other: &mut Slice<'_>) -> Self { - let offset = self.offset; - if offset != other.offset { - return Self { length: 0, ..*self }; - } + let length = (|| { + let offset = self.offset; + if offset != other.offset { + return 0; + } + let length = self.length.min(other.length); + let length = u32::from(length) + u32::from(offset); + let lhs = self.bytes().split_at(((length + 7) / 8) as usize).0; + let rhs = other.bytes().split_at(((length + 7) / 8) as usize).0; + + let (fst, lhs, rhs) = match (lhs.split_first(), rhs.split_first()) { + (Some(lhs), Some(rhs)) => (lhs.0 ^ rhs.0, lhs.1, rhs.1), + _ => return 0, + }; + let fst = fst & (0xFF >> offset); + + let total_bits_matched = if fst != 0 { + fst.leading_zeros() + } else if let Some(n) = + lhs.iter().zip(rhs.iter()).position(|(a, b)| a != b) + { + 8 + n as u32 * 8 + (lhs[n] ^ rhs[n]).leading_zeros() + } else { + 8 + lhs.len() as u32 * 8 + } + .min(length); - let length = self.length.min(other.length); - // SAFETY: offset is common offset of both slices and length is shorter - // of either slice, which means that both pointers point to at least - // offset+length bits. - let (idx, length) = unsafe { - forward_common_prefix_impl(self.ptr, other.ptr, offset, length) - }; - let result = Self { length, ..*self }; - - self.length -= length; - self.offset = ((u16::from(self.offset) + length) % 8) as u8; - other.length -= length; - other.offset = self.offset; - // SAFETY: forward_common_prefix_impl guarantees that `idx` is no more - // than what the slices have. - unsafe { - self.ptr = self.ptr.add(idx); - other.ptr = other.ptr.add(idx); + total_bits_matched.saturating_sub(u32::from(offset)) as u16 + })(); + if length == 0 { + Self { length: 0, ..*self } + } else { + other.pop_front_slice(length).unwrap(); + self.pop_front_slice(length).unwrap() } - - result } /// Checks that all bits outside of the specified range are set to zero. @@ -425,49 +491,6 @@ impl<'a> Slice<'a> { } } -/// Implementation of [`Slice::forward_common_prefix`]. -/// -/// ## Safety -/// -/// `lhs` and `rhs` must point to at least `offset + max_length` bits. -unsafe fn forward_common_prefix_impl( - lhs: *const u8, - rhs: *const u8, - offset: u8, - max_length: u16, -) -> (usize, u16) { - let max_length = u32::from(max_length) + u32::from(offset); - // SAFETY: Caller promises that both pointers point to at least offset + - // max_length bits. - let (lhs, rhs) = unsafe { - let len = ((max_length + 7) / 8) as usize; - let lhs = core::slice::from_raw_parts(lhs, len).split_first(); - let rhs = core::slice::from_raw_parts(rhs, len).split_first(); - (lhs, rhs) - }; - - let (first, lhs, rhs) = match (lhs, rhs) { - (Some(lhs), Some(rhs)) => (lhs.0 ^ rhs.0, lhs.1, rhs.1), - _ => return (0, 0), - }; - let first = first & (0xFF >> offset); - - let total_bits_matched = if first != 0 { - first.leading_zeros() - } else if let Some(n) = lhs.iter().zip(rhs.iter()).position(|(a, b)| a != b) - { - 8 + n as u32 * 8 + (lhs[n] ^ rhs[n]).leading_zeros() - } else { - 8 + lhs.len() as u32 * 8 - } - .min(max_length); - - ( - (total_bits_matched / 8) as usize, - total_bits_matched.saturating_sub(u32::from(offset)) as u16, - ) -} - impl core::cmp::PartialEq for Slice<'_> { /// Compares two slices to see if they contain the same bits and have the /// same offset. @@ -580,21 +603,13 @@ impl<'a> core::iter::Iterator for Chunks<'a> { type Item = Slice<'a>; fn next(&mut self) -> Option> { - let bytes_len = self.0.bytes().len().min(nodes::MAX_EXTENSION_KEY_SIZE); - if bytes_len == 0 { - return None; + const MAX_LENGTH: u16 = (nodes::MAX_EXTENSION_KEY_SIZE * 8) as u16; + let length = (MAX_LENGTH - u16::from(self.0.offset)).min(self.0.length); + if length == 0 { + None + } else { + self.0.pop_front_slice(length) } - let slice = &mut self.0; - let offset = slice.offset; - let length = (bytes_len * 8 - usize::from(offset)) - .min(usize::from(slice.length)) as u16; - let ptr = slice.ptr; - slice.offset = 0; - slice.length -= length; - // SAFETY: `ptr` points at a slice which is at least `bytes_len` bytes - // long so it’s safe to advance it by that offset. - slice.ptr = unsafe { slice.ptr.add(bytes_len) }; - Some(Slice { offset, length, ptr, phantom: Default::default() }) } fn size_hint(&self) -> (usize, Option) {