diff --git a/.github/workflows/master.yml b/.github/workflows/master.yml new file mode 100644 index 00000000..1b061340 --- /dev/null +++ b/.github/workflows/master.yml @@ -0,0 +1,61 @@ +--- +name: PR Checks +on: + pull_request: + branches: + - '*' + push: + branches: + - master + +jobs: + misc: + name: Miscellaneous checks + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install Rust + uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly + components: rustfmt + + - name: Check formatting + run: cargo fmt --all --check + + - name: Install cargo-deny + run: cargo install cargo-deny + + - name: Check bans + run: cargo-deny --all-features check bans + + stable: + name: Rust stable + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install Rust + uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + components: rustfmt + + - name: Run tests + run: cargo test + + nightly: + name: Rust nightly + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install Rust + uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly + components: miri + + - name: Run tests + run: cargo test + + - name: Run tests with Miri + run: cargo miri test -- --skip ::stress_test diff --git a/.rustfmt.toml b/.rustfmt.toml index 346ded89..060cd137 100644 --- a/.rustfmt.toml +++ b/.rustfmt.toml @@ -1,19 +1,19 @@ -# binop_separator = "Back" -# blank_lines_lower_bound = 0 -# blank_lines_upper_bound = 3 -# brace_style = "PreferSameLine" -# color = "Auto" -# condense_wildcard_suffixes = true -# fn_single_line = true -# format_macro_matchers = true -# format_strings = true -# group_imports = "StdExternalCrate" -# normalize_doc_attributes = true -# overflow_delimited_expr = true - +binop_separator = "Back" +blank_lines_lower_bound = 0 +blank_lines_upper_bound = 3 +color = "Auto" +condense_wildcard_suffixes = true +fn_single_line = true +format_macro_matchers = true +format_strings = true +group_imports = 
"StdExternalCrate" +imports_granularity = "Module" max_width = 80 newline_style = "Unix" +normalize_doc_attributes = true +overflow_delimited_expr = true reorder_imports = true reorder_modules = true use_field_init_shorthand = true +use_small_heuristics = "Max" use_try_shorthand = true diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 00000000..5ea2296a --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,19 @@ +[workspace.package] +version = "0.0.0" +authors = ["Michal Nazarewicz "] +edition = "2021" +rust-version = "1.71.0" + +[workspace] +members = [ + "sealable-trie", +] +resolver = "2" + +[workspace.dependencies] +base64 = { version = "0.21", default-features = false, features = ["alloc"] } +derive_more = "0.99.17" +pretty_assertions = "1.4.0" +rand = { version = "0.8.5" } +sha2 = { version = "0.10.7", default-features = false } +strum = { version = "0.25.0", default-features = false, features = ["derive"] } diff --git a/deny.toml b/deny.toml new file mode 100644 index 00000000..f009c807 --- /dev/null +++ b/deny.toml @@ -0,0 +1,12 @@ +[sources] +unknown-registry = "deny" +unknown-git = "deny" +allow-registry = ["https://github.com/rust-lang/crates.io-index"] +allow-git = [] + +[bans] +multiple-versions = "deny" +skip = [ + # derive_more still uses old syn + { name = "syn", version = "1.0.*" }, +] diff --git a/sealable-trie/Cargo.toml b/sealable-trie/Cargo.toml new file mode 100644 index 00000000..b6e04c9b --- /dev/null +++ b/sealable-trie/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "sealable-trie" +version = "0.0.0" +edition = "2021" + +[dependencies] +base64.workspace = true +derive_more.workspace = true +sha2.workspace = true +strum.workspace = true + +[dev-dependencies] +pretty_assertions.workspace = true +rand.workspace = true diff --git a/sealable-trie/src/bits.rs b/sealable-trie/src/bits.rs new file mode 100644 index 00000000..ffd3dcba --- /dev/null +++ b/sealable-trie/src/bits.rs @@ -0,0 +1,1021 @@ +use core::fmt; + +#[cfg(test)] +use 
pretty_assertions::assert_eq; + +use crate::{nodes, stdx}; + +/// Representation of a slice of bits. +/// +/// **Note**: slices with different starting offset are considered different +/// even if iterating over all the bits gives the same result. +#[derive(Clone, Copy)] +pub struct Slice<'a> { + /// Offset in bits to start the slice in `bytes`. + /// + /// In other words, how many most significant bits to skip from `bytes`. + /// This is always less than eight (i.e. we never skip more than one byte). + pub(crate) offset: u8, + + /// Length of the slice in bits. + /// + /// `length + offset` is never more than `36 * 8`. + pub(crate) length: u16, + + /// The bytes to read the bits from. + /// + /// Value of bits outside of the range defined by `offset` and `length` is + /// unspecified and shouldn’t be read. + // Invariant: `ptr` points at `offset + length` valid bits. In other words, + // at `(offset + length + 7) / 8` valid bytes. + pub(crate) ptr: *const u8, + + phantom: core::marker::PhantomData<&'a [u8]>, +} + +/// An iterator over bits in a bit slice. +#[derive(Clone, Copy)] +pub struct Iter<'a> { + /// A 1-bit mask of the next bit to read from `*ptr`. + /// + /// Each time next bit is read, the mask is shifted once to the right. Once + /// it reaches zero, it’s reset to `0x80` and `ptr` is advanced to the next + /// byte. + mask: u8, + + /// Length of the slice in bits. + length: u16, + + /// Pointer to the byte at the beginning of the iterator. + // Invariant: `ptr` points at `offset + length` valid bits where `offset` + // equals `mask.leading_zeros()`. In other words, at `(offset + length + 7) + // / 8` valid bytes. + ptr: *const u8, + + phantom: core::marker::PhantomData<&'a [u8]>, +} + +/// An iterator over chunks of a slice where each chunk (except for the last +/// one) occupies exactly 34 bytes. +#[derive(Clone, Copy)] +pub struct Chunks<'a>(Slice<'a>); + +impl<'a> Slice<'a> { + /// Constructs a new bit slice. 
+ /// + /// `bytes` is underlying bytes slice to read bits from. + /// + /// `offset` specifies how many most significant bits of the first byte of + /// the bytes slice to skip. Must be at most 7. + /// + /// `length` specifies length in bits of the entire bit slice. + /// + /// Returns `None` if `offset` or `length` is too large or `bytes` doesn’t + /// have enough underlying data for the length of the slice. + pub fn new(bytes: &'a [u8], offset: u8, length: u16) -> Option { + if offset >= 8 { + return None; + } + let has_bits = + u32::try_from(bytes.len()).unwrap_or(u32::MAX).saturating_mul(8); + (u32::from(length) + u32::from(offset) <= has_bits).then_some(Self { + offset, + length, + ptr: bytes.as_ptr(), + phantom: Default::default(), + }) + } + + /// Constructs a new bit slice going through all bits in a bytes slice. + /// + /// Returns `None` if the slice is too long. The maximum length is 8191 + /// bytes. + pub fn from_bytes(bytes: &'a [u8]) -> Option { + Some(Self { + offset: 0, + length: u16::try_from(bytes.len().checked_mul(8)?).ok()?, + ptr: bytes.as_ptr(), + phantom: Default::default(), + }) + } + + /// Returns length of the slice in bits. + pub fn len(&self) -> u16 { self.length } + + /// Returns whether the slice is empty. + pub fn is_empty(&self) -> bool { self.length == 0 } + + /// Returns the first bit in the slice and advances the slice by one position. 
+ /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits; + /// + /// let mut slice = bits::Slice::new(&[0x60], 0, 3).unwrap(); + /// assert_eq!(Some(false), slice.pop_front()); + /// assert_eq!(Some(true), slice.pop_front()); + /// assert_eq!(Some(true), slice.pop_front()); + /// assert_eq!(None, slice.pop_front()); + /// ``` + pub fn pop_front(&mut self) -> Option { + if self.length == 0 { + return None; + } + // SAFETY: self.length != 0 ⇒ self.ptr points at a valid byte + let bit = (unsafe { self.ptr.read() } & (0x80 >> self.offset)) != 0; + self.offset = (self.offset + 1) & 7; + if self.offset == 0 { + // SAFETY: self.ptr pointing at valid byte ⇒ self.ptr+1 is valid + // pointer + self.ptr = unsafe { self.ptr.add(1) } + } + self.length -= 1; + Some(bit) + } + + /// Returns the last bit in the slice shrinking the slice by one bit. + /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits; + /// + /// let mut slice = bits::Slice::new(&[0x60], 0, 3).unwrap(); + /// assert_eq!(Some(true), slice.pop_back()); + /// assert_eq!(Some(true), slice.pop_back()); + /// assert_eq!(Some(false), slice.pop_back()); + /// assert_eq!(None, slice.pop_back()); + /// ``` + pub fn pop_back(&mut self) -> Option { + self.length = self.length.checked_sub(1)?; + let total_bits = self.underlying_bits_length(); + // SAFETY: `ptr` is guaranteed to point at offset + original length + // valid bits. Furthermore, since original length was positive then + // there’s at least one byte we can read. + let byte = unsafe { self.ptr.add(total_bits / 8).read() }; + let mask = 0x80 >> (total_bits % 8); + Some(byte & mask != 0) + } + + /// Returns subslice from the end of the slice shrinking the slice by its + /// length. + /// + /// Returns `None` if the slice is too short. + /// + /// This is an ‘rsplit_at’ operation but instead of returning two slices it + /// shortens the slice and returns the tail. 
+ /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits; + /// + /// let mut slice = bits::Slice::new(&[0x81], 0, 8).unwrap(); + /// let tail = slice.pop_back_slice(4).unwrap(); + /// assert_eq!(bits::Slice::new(&[0x80], 0, 4), Some(slice)); + /// assert_eq!(bits::Slice::new(&[0x01], 4, 4), Some(tail)); + /// + /// assert_eq!(None, slice.pop_back_slice(5)); + /// assert_eq!(bits::Slice::new(&[0x80], 0, 4), Some(slice)); + /// ``` + pub fn pop_back_slice(&mut self, length: u16) -> Option { + self.length = self.length.checked_sub(length)?; + let total_bits = self.underlying_bits_length(); + // SAFETY: `ptr` is guaranteed to point at offset + original length + // valid bits. + let ptr = unsafe { self.ptr.add(total_bits / 8) }; + let offset = (total_bits % 8) as u8; + Some(Self { ptr, offset, length, phantom: Default::default() }) + } + + /// Returns an iterator over bits in the bit slice. + /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits; + /// + /// let slice = bits::Slice::new(&[0xA0], 0, 3).unwrap(); + /// let bits: Vec = slice.iter().collect(); + /// assert_eq!(&[true, false, true], bits.as_slice()); + /// ``` + pub fn iter(&self) -> Iter<'a> { + Iter { + mask: 0x80 >> self.offset, + length: self.length, + ptr: self.ptr, + phantom: self.phantom, + } + } + + /// Returns iterator over chunks of slice where each chunk occupies at most + /// 34 bytes. + /// + /// The final chunk may be shorter. Note that due to offset the first chunk + /// may be shorter than 272 bits (i.e. 34 * 8) however it will span full 34 + /// bytes. + pub fn chunks(&self) -> Chunks<'a> { Chunks(*self) } + + /// Returns whether the slice starts with given prefix. + /// + /// **Note**: If the `prefix` slice has a different bit offset it is not + /// considered a prefix even if it starts with the same bits. 
+ /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits; + /// + /// let mut slice = bits::Slice::new(&[0xAA, 0xA0], 0, 12).unwrap(); + /// + /// assert!(slice.starts_with(bits::Slice::new(&[0xAA], 0, 4).unwrap())); + /// assert!(!slice.starts_with(bits::Slice::new(&[0xFF], 0, 4).unwrap())); + /// // Different offset: + /// assert!(!slice.starts_with(bits::Slice::new(&[0xAA], 4, 4).unwrap())); + /// ``` + pub fn starts_with(&self, prefix: Slice<'_>) -> bool { + if self.offset != prefix.offset || self.length < prefix.length { + return false; + } + let subslice = Slice { length: prefix.length, ..*self }; + subslice == prefix + } + + /// Removes prefix from the slice; returns `false` if slice doesn’t start + /// with given prefix. + /// + /// **Note**: If the `prefix` slice has a different bit offset it is not + /// considered a prefix even if it starts with the same bits. + /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits; + /// + /// let mut slice = bits::Slice::new(&[0xAA, 0xA0], 0, 12).unwrap(); + /// + /// assert!(slice.strip_prefix(bits::Slice::new(&[0xAA], 0, 4).unwrap())); + /// assert_eq!(bits::Slice::new(&[0x0A, 0xA0], 4, 8).unwrap(), slice); + /// + /// // Doesn’t match: + /// assert!(!slice.strip_prefix(bits::Slice::new(&[0x0F], 4, 4).unwrap())); + /// // Different offset: + /// assert!(!slice.strip_prefix(bits::Slice::new(&[0xAA], 0, 4).unwrap())); + /// // Too long: + /// assert!(!slice.strip_prefix(bits::Slice::new(&[0x0A, 0xAA], 4, 12).unwrap())); + /// + /// assert!(slice.strip_prefix(bits::Slice::new(&[0xAA, 0xAA], 4, 6).unwrap())); + /// assert_eq!(bits::Slice::new(&[0x20], 2, 2).unwrap(), slice); + /// + /// assert!(slice.strip_prefix(slice.clone())); + /// assert_eq!(bits::Slice::new(&[0x00], 4, 0).unwrap(), slice); + /// ``` + pub fn strip_prefix(&mut self, prefix: Slice<'_>) -> bool { + if !self.starts_with(prefix) { + return false; + } + let length = usize::from(prefix.length) + usize::from(prefix.offset); + 
// SAFETY: self.ptr points to at least length+offset valid bits. + unsafe { self.ptr = self.ptr.add(length / 8) }; + self.offset = (length % 8) as u8; + self.length -= prefix.length; + true + } + + /// Strips common prefix from two slices; returns new slice with the common + /// prefix. + /// + /// **Note**: If two slices have different bit offset they are considered to + /// have an empty prefix. + /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits; + /// + /// let mut left = bits::Slice::new(&[0xFF], 0, 8).unwrap(); + /// let mut right = bits::Slice::new(&[0xF0], 0, 8).unwrap(); + /// assert_eq!(bits::Slice::new(&[0xF0], 0, 4).unwrap(), + /// left.forward_common_prefix(&mut right)); + /// assert_eq!(bits::Slice::new(&[0xFF], 4, 4).unwrap(), left); + /// assert_eq!(bits::Slice::new(&[0xF0], 4, 4).unwrap(), right); + /// + /// let mut left = bits::Slice::new(&[0xFF], 0, 8).unwrap(); + /// let mut right = bits::Slice::new(&[0xFF], 0, 6).unwrap(); + /// assert_eq!(bits::Slice::new(&[0xFC], 0, 6).unwrap(), + /// left.forward_common_prefix(&mut right)); + /// assert_eq!(bits::Slice::new(&[0xFF], 6, 2).unwrap(), left); + /// assert_eq!(bits::Slice::new(&[0xFF], 6, 0).unwrap(), right); + /// ``` + pub fn forward_common_prefix(&mut self, other: &mut Slice<'_>) -> Self { + let offset = self.offset; + if offset != other.offset { + return Self { length: 0, ..*self }; + } + + let length = self.length.min(other.length); + // SAFETY: offset is common offset of both slices and length is shorter + // of either slice, which means that both pointers point to at least + // offset+length bits. 
+ let (idx, length) = unsafe { + forward_common_prefix_impl(self.ptr, other.ptr, offset, length) + }; + let result = Self { length, ..*self }; + + self.length -= length; + self.offset = ((u16::from(self.offset) + length) % 8) as u8; + other.length -= length; + other.offset = self.offset; + // SAFETY: forward_common_prefix_impl guarantees that `idx` is no more + // than what the slices have. + unsafe { + self.ptr = self.ptr.add(idx); + other.ptr = other.ptr.add(idx); + } + + result + } + + /// Checks that all bits outside of the specified range are set to zero. + fn check_bytes(bytes: &[u8], offset: u8, length: u16) -> bool { + let (front, back) = Self::masks(offset, length); + let bytes_len = (usize::from(offset) + usize::from(length) + 7) / 8; + bytes_len <= bytes.len() && + (bytes[0] & !front) == 0 && + (bytes[bytes_len - 1] & !back) == 0 && + bytes[bytes_len..].iter().all(|&b| b == 0) + } + + /// Returns total number of underlying bits, i.e. bits in the slice plus the + /// offset. + fn underlying_bits_length(&self) -> usize { + usize::from(self.offset) + usize::from(self.length) + } + + /// Returns bytes underlying the bit slice. + fn bytes(&self) -> &'a [u8] { + let len = (self.underlying_bits_length() + 7) / 8; + // SAFETY: `ptr` is guaranteed to point at offset+length valid bits. + unsafe { core::slice::from_raw_parts(self.ptr, len) } + } + + /// Encodes key into raw binary representation. + /// + /// Fills entire 36-byte buffer. The first two bytes encode + /// length and offset (`(length << 3) | offset` specifically leaving the + /// four most significant bits zero) and the rest being bytes holding the + /// bits. Bits which are not part of the slice are set to zero. + /// + /// The first byte written will be xored with `tag`. + /// + /// Returns the length of relevant portion of the buffer. For example, if + /// slice’s length is say 20 bits with zero offset returns five (two bytes + /// for the encoded length and three bytes for the 20 bits). 
+ /// + /// Returns `None` if the slice is empty or too long and won’t fit in the + /// destination buffer. + pub(crate) fn encode_into( + &self, + dest: &mut [u8; 36], + tag: u8, + ) -> Option { + if self.length == 0 { + return None; + } + let bytes = self.bytes(); + if bytes.is_empty() || bytes.len() > dest.len() - 2 { + return None; + } + let (num, tail) = + stdx::split_array_mut::<2, { nodes::MAX_EXTENSION_KEY_SIZE }, 36>( + dest, + ); + tail.fill(0); + *num = self.encode_num(tag); + let (key, _) = tail.split_at_mut(bytes.len()); + let (front, back) = Self::masks(self.offset, self.length); + key.copy_from_slice(bytes); + key[0] &= front; + key[bytes.len() - 1] &= back; + Some(2 + bytes.len()) + } + + /// Encodes key into raw binary representation and sends it to the consumer. + /// + /// This is like [`Self::encode_into`] except that it doesn’t check the + /// length of the key. + pub(crate) fn write_into(&self, mut consumer: impl FnMut(&[u8]), tag: u8) { + consumer(&self.encode_num(tag)); + + let (front, back) = Self::masks(self.offset, self.length); + let bytes = self.bytes(); + match bytes.len() { + 0 => (), + 1 => consumer(&[bytes[0] & front & back]), + 2 => consumer(&[bytes[0] & front, bytes[1] & back]), + n => { + consumer(&[bytes[0] & front]); + consumer(&bytes[1..n - 1]); + consumer(&[bytes[n - 1] & back]); + } + } + } + + /// Decodes key from a raw binary representation. + /// + /// The first byte read will be xored with `tag`. + /// + /// This is the inverse of [`Self::encode_into`]. + pub(crate) fn decode(src: &'a [u8], tag: u8) -> Option { + let (&[high, low], bytes) = stdx::split_at(src)?; + let tag = u16::from_be_bytes([high ^ tag, low]); + let (offset, length) = ((tag % 8) as u8, tag / 8); + (length > 0 && Self::check_bytes(bytes, offset, length)).then_some( + Self { + offset, + length, + ptr: bytes.as_ptr(), + phantom: Default::default(), + }, + ) + } + + /// Encodes offset and length as a two-byte number. 
+ /// + /// The encoding is `llll_llll llll_looo`, i.e. 13-bit length in the most + /// significant bits and 3-bit offset in the least significant bits. The + /// first byte is then further xored with the `tag` argument. + /// + /// This method doesn’t check whether the length and offset are within range. + fn encode_num(&self, tag: u8) -> [u8; 2] { + let num = (self.length << 3) | u16::from(self.offset); + (num ^ (u16::from(tag) << 8)).to_be_bytes() + } + + /// Helper method which returns masks for leading and trailing byte. + /// + /// Based on provided bit offset (which must be ≤ 7) and bit length of the + /// slice returns: mask of bits in the first byte that are part of the + /// slice and mask of bits in the last byte that are part of the slice. + fn masks(offset: u8, length: u16) -> (u8, u8) { + let bits = usize::from(offset) + usize::from(length); + // `1 << 20` is an arbitrary number which is divisible by 8 and greater + // than bits. + let tail = ((1 << 20) - bits) % 8; + (0xFF >> offset, 0xFF << tail) + } +} + +/// Implementation of [`Slice::forward_common_prefix`]. +/// +/// ## Safety +/// +/// `lhs` and `rhs` must point to at least `offset + max_length` bits. +unsafe fn forward_common_prefix_impl( + lhs: *const u8, + rhs: *const u8, + offset: u8, + max_length: u16, +) -> (usize, u16) { + let max_length = u32::from(max_length) + u32::from(offset); + // SAFETY: Caller promises that both pointers point to at least offset + + // max_length bits. 
+ let (lhs, rhs) = unsafe { + let len = ((max_length + 7) / 8) as usize; + let lhs = core::slice::from_raw_parts(lhs, len).split_first(); + let rhs = core::slice::from_raw_parts(rhs, len).split_first(); + (lhs, rhs) + }; + + let (first, lhs, rhs) = match (lhs, rhs) { + (Some(lhs), Some(rhs)) => (lhs.0 ^ rhs.0, lhs.1, rhs.1), + _ => return (0, 0), + }; + let first = first & (0xFF >> offset); + + let total_bits_matched = if first != 0 { + first.leading_zeros() + } else if let Some(n) = lhs.iter().zip(rhs.iter()).position(|(a, b)| a != b) + { + 8 + n as u32 * 8 + (lhs[n] ^ rhs[n]).leading_zeros() + } else { + 8 + lhs.len() as u32 * 8 + } + .min(max_length); + + ( + (total_bits_matched / 8) as usize, + total_bits_matched.saturating_sub(u32::from(offset)) as u16, + ) +} + +impl<'a> core::iter::IntoIterator for Slice<'a> { + type Item = bool; + type IntoIter = Iter<'a>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { self.iter() } +} + +impl<'a> core::iter::IntoIterator for &'a Slice<'a> { + type Item = bool; + type IntoIter = Iter<'a>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { (*self).iter() } +} + +impl core::cmp::PartialEq for Slice<'_> { + /// Compares two slices to see if they contain the same bits and have the + /// same offset. + /// + /// **Note**: If the slices are the same length and contain the same bits + /// but their offsets are different, they are considered non-equal. 
+ /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits; + /// + /// assert_eq!(bits::Slice::new(&[0xFF], 0, 6), + /// bits::Slice::new(&[0xFF], 0, 6)); + /// assert_ne!(bits::Slice::new(&[0xFF], 0, 6), + /// bits::Slice::new(&[0xF0], 0, 6)); + /// assert_ne!(bits::Slice::new(&[0xFF], 0, 6), + /// bits::Slice::new(&[0xFF], 2, 6)); + /// ``` + fn eq(&self, other: &Self) -> bool { + if self.offset != other.offset || self.length != other.length { + return false; + } else if self.length == 0 { + return true; + } + let (front, back) = Self::masks(self.offset, self.length); + let (lhs, rhs) = (self.bytes(), other.bytes()); + let len = lhs.len(); + if len == 1 { + ((lhs[0] ^ rhs[0]) & front & back) == 0 + } else { + ((lhs[0] ^ rhs[0]) & front) == 0 && + ((lhs[len - 1] ^ rhs[len - 1]) & back) == 0 && + lhs[1..len - 1] == rhs[1..len - 1] + } + } +} + +impl fmt::Display for Slice<'_> { + fn fmt(&self, fmtr: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(buf: &mut [u8], mut byte: u8) { + for ch in buf.iter_mut().rev() { + *ch = b'0' + (byte & 1); + byte >>= 1; + } + } + + let (first, mid) = match self.bytes().split_first() { + None => return fmtr.write_str("∅"), + Some(pair) => pair, + }; + + let off = usize::from(self.offset); + let len = usize::from(self.length); + let mut buf = [0; 10]; + fmt(&mut buf[2..], *first); + buf[0] = b'0'; + buf[1] = b'b'; + buf[2..2 + off].fill(b'.'); + + let (last, mid) = match mid.split_last() { + None => { + buf[2 + off + len..].fill(b'.'); + let val = unsafe { core::str::from_utf8_unchecked(&buf) }; + return fmtr.write_str(val); + } + Some(pair) => pair, + }; + + fmtr.write_str(unsafe { core::str::from_utf8_unchecked(&buf) })?; + for byte in mid { + write!(fmtr, "_{:08b}", byte)?; + } + fmt(&mut buf[..9], *last); + buf[0] = b'_'; + let len = (off + len) % 8; + if len != 0 { + buf[1 + len..].fill(b'.'); + } + fmtr.write_str(unsafe { core::str::from_utf8_unchecked(&buf[..9]) }) + } +} + +impl fmt::Debug for Slice<'_> { + fn 
fmt(&self, fmtr: &mut fmt::Formatter<'_>) -> fmt::Result { + debug_fmt("Slice", self, fmtr) + } +} + +impl fmt::Debug for Iter<'_> { + fn fmt(&self, fmtr: &mut fmt::Formatter<'_>) -> fmt::Result { + let slice = Slice { + offset: self.mask.leading_zeros() as u8, + length: self.length, + ptr: self.ptr, + phantom: self.phantom, + }; + debug_fmt("Iter", &slice, fmtr) + } +} + +impl fmt::Debug for Chunks<'_> { + fn fmt(&self, fmtr: &mut fmt::Formatter<'_>) -> fmt::Result { + debug_fmt("Chunks", &self.0, fmtr) + } +} + +/// Internal function for debug formatting objects. +fn debug_fmt( + name: &str, + slice: &Slice<'_>, + fmtr: &mut fmt::Formatter<'_>, +) -> fmt::Result { + fmtr.debug_struct(name) + .field("offset", &slice.offset) + .field("length", &slice.length) + .field("bytes", &core::format_args!("{:02x?}", slice.bytes())) + .finish() +} + +impl<'a> core::iter::Iterator for Iter<'a> { + type Item = bool; + + #[inline] + fn next(&mut self) -> Option { + if self.length == 0 { + return None; + } + // SAFETY: When length is non-zero, ptr points to a valid byte. + let result = (unsafe { self.ptr.read() } & self.mask) != 0; + self.length -= 1; + self.mask = self.mask.rotate_right(1); + if self.mask == 0x80 { + // SAFETY: ptr points to a valid object (see above) so ptr+1 is + // a valid pointer (at worst it’s one-past-the-end pointer). 
+ self.ptr = unsafe { self.ptr.add(1) }; + } + Some(result) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (usize::from(self.length), Some(usize::from(self.length))) + } + + #[inline] + fn count(self) -> usize { usize::from(self.length) } +} + +impl<'a> core::iter::ExactSizeIterator for Iter<'a> { + #[inline] + fn len(&self) -> usize { usize::from(self.length) } +} + +impl<'a> core::iter::FusedIterator for Iter<'a> {} + +impl<'a> core::iter::Iterator for Chunks<'a> { + type Item = Slice<'a>; + + fn next(&mut self) -> Option> { + let bytes_len = self.0.bytes().len().min(nodes::MAX_EXTENSION_KEY_SIZE); + if bytes_len == 0 { + return None; + } + let slice = &mut self.0; + let offset = slice.offset; + let length = (bytes_len * 8 - usize::from(offset)) + .min(usize::from(slice.length)) as u16; + let ptr = slice.ptr; + slice.offset = 0; + slice.length -= length; + // SAFETY: `ptr` points at a slice which is at least `bytes_len` bytes + // long so it’s safe to advance it by that offset. + slice.ptr = unsafe { slice.ptr.add(bytes_len) }; + Some(Slice { offset, length, ptr, phantom: Default::default() }) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.len(); + (len, Some(len)) + } +} + +impl<'a> core::iter::ExactSizeIterator for Chunks<'a> { + #[inline] + fn len(&self) -> usize { + self.0.bytes().chunks(nodes::MAX_EXTENSION_KEY_SIZE).len() + } +} + +impl<'a> core::iter::DoubleEndedIterator for Chunks<'a> { + fn next_back(&mut self) -> Option> { + let mut chunks = self.0.bytes().chunks(nodes::MAX_EXTENSION_KEY_SIZE); + let bytes = chunks.next_back()?; + + if chunks.next().is_none() { + let empty = Slice { + offset: 0, + length: 0, + ptr: self.0.ptr, + phantom: Default::default(), + }; + return Some(core::mem::replace(&mut self.0, empty)); + } + + // `1 << 20` is an arbitrary number which is divisible by 8 and greater + // than underlying_bits_length. 
+ let tail = ((1 << 20) - self.0.underlying_bits_length()) % 8; + let length = (bytes.len() * 8 - tail) as u16; + self.0.length -= length; + + Some(Slice { + offset: 0, + length, + ptr: bytes.as_ptr(), + phantom: Default::default(), + }) + } +} + +#[test] +fn test_encode() { + #[track_caller] + fn test(want_encoded: &[u8], offset: u8, length: u16, bytes: &[u8]) { + let slice = Slice::new(bytes, offset, length).unwrap(); + + if want_encoded.len() <= 36 { + let mut buf = [0; 36]; + let len = slice + .encode_into(&mut buf, 0) + .unwrap_or_else(|| panic!("Failed encoding {slice}")); + assert_eq!( + want_encoded, + &buf[..len], + "Unexpected encoded representation of {slice}" + ); + } + + let mut buf = alloc::vec::Vec::with_capacity(want_encoded.len()); + slice.write_into(|bytes| buf.extend_from_slice(bytes), 0); + assert_eq!( + want_encoded, + buf.as_slice(), + "Unexpected written representation of {slice}" + ); + + let round_trip = Slice::decode(want_encoded, 0) + .unwrap_or_else(|| panic!("Failed decoding {want_encoded:?}")); + assert_eq!(slice, round_trip); + } + + test(&[0, 1 * 8 + 0, 0x80], 0, 1, &[0x80]); + test(&[0, 1 * 8 + 0, 0x80], 0, 1, &[0xFF]); + test(&[0, 1 * 8 + 4, 0x08], 4, 1, &[0xFF]); + test(&[0, 9 * 8 + 0, 0xFF, 0x80], 0, 9, &[0xFF, 0xFF]); + test(&[0, 9 * 8 + 4, 0x0F, 0xF8], 4, 9, &[0xFF, 0xFF]); + test(&[0, 17 * 8 + 0, 0xFF, 0xFF, 0x80], 0, 17, &[0xFF, 0xFF, 0xFF]); + test(&[0, 17 * 8 + 4, 0x0F, 0xFF, 0xF8], 4, 17, &[0xFF, 0xFF, 0xFF]); + + let mut want = [0xFF; 1026]; + want[0] = (8191u16 >> 5) as u8; + want[1] = (8191u16 << 3) as u8; + want[1025] = 0xFE; + test(&want[..], 0, 8191, &[0xFF; 1024][..]); + + want[1] += 1; + want[2] = 0x7F; + want[1025] = 0xFF; + test(&want[..], 1, 8191, &[0xFF; 1024][..]); +} + +#[test] +fn test_decode() { + #[track_caller] + fn ok(num: u16, bytes: &[u8], want_offset: u8, want_length: u16) { + let bytes = [&num.to_be_bytes()[..], bytes].concat(); + let got = Slice::decode(&bytes, 0).unwrap_or_else(|| { + 
panic!("Expected to get a Slice from {bytes:x?}") + }); + assert_eq!((want_offset, want_length), (got.offset, got.length)); + } + + // Correct values, all bits zero. + ok(34 * 64, &[0; 34], 0, 34 * 8); + ok(33 * 64 + 7, &[0; 34], 7, 264); + ok(2 * 64, &[0, 0], 0, 16); + + // Empty + assert_eq!(None, Slice::decode(&[], 0)); + assert_eq!(None, Slice::decode(&[0], 0)); + assert_eq!(None, Slice::decode(&[0, 0], 0)); + + #[track_caller] + fn test(length: u16, offset: u8, bad: &[u8], good: &[u8]) { + let num = length * 8 + u16::from(offset); + let bad = [&num.to_be_bytes()[..], bad].concat(); + assert_eq!(None, Slice::decode(&bad, 0)); + + let good = [&num.to_be_bytes()[..], good].concat(); + let got = Slice::decode(&good, 0).unwrap_or_else(|| { + panic!("Expected to get a Slice from {good:x?}") + }); + assert_eq!( + (offset, length), + (got.offset, got.length), + "Invalid offset and length decoding {good:x?}" + ); + + let good = [&good[..], &[0, 0]].concat(); + let got = Slice::decode(&good, 0).unwrap_or_else(|| { + panic!("Expected to get a Slice from {good:x?}") + }); + assert_eq!( + (offset, length), + (got.offset, got.length), + "Invalid offset and length decoding {good:x?}" + ); + } + + // Bytes buffer doesn’t match the length. + test(8, 0, &[], &[0]); + test(8, 7, &[0], &[0, 0]); + test(16, 1, &[0, 0], &[0, 0, 0]); + + // Bits which should be zero aren’t. 
+ // Leading bits are skipped: + test(16 - 1, 1, &[0x80, 0], &[0x7F, 0xFF]); + test(16 - 2, 2, &[0x40, 0], &[0x3F, 0xFF]); + test(16 - 3, 3, &[0x20, 0], &[0x1F, 0xFF]); + test(16 - 4, 4, &[0x10, 0], &[0x0F, 0xFF]); + test(16 - 5, 5, &[0x08, 0], &[0x07, 0xFF]); + test(16 - 6, 6, &[0x04, 0], &[0x03, 0xFF]); + test(16 - 7, 7, &[0x02, 0], &[0x01, 0xFF]); + + // Tailing bits are skipped: + test(16 - 1, 0, &[0, 0x01], &[0xFF, 0xFE]); + test(16 - 2, 0, &[0, 0x02], &[0xFF, 0xFC]); + test(16 - 3, 0, &[0, 0x04], &[0xFF, 0xF8]); + test(16 - 4, 0, &[0, 0x08], &[0xFF, 0xF0]); + test(16 - 5, 0, &[0, 0x10], &[0xFF, 0xE0]); + test(16 - 6, 0, &[0, 0x20], &[0xFF, 0xC0]); + test(16 - 7, 0, &[0, 0x40], &[0xFF, 0x80]); + + // Some leading and some tailing bits are skipped of the same byte: + test(1, 1, &[!0x40], &[0x40]); + test(1, 2, &[!0x20], &[0x20]); + test(1, 3, &[!0x10], &[0x10]); + test(1, 4, &[!0x08], &[0x08]); + test(1, 5, &[!0x04], &[0x04]); + test(1, 6, &[!0x02], &[0x02]); +} + +#[test] +fn test_common_prefix() { + let mut lhs = Slice::new(&[0x86, 0xE9], 1, 15).unwrap(); + let mut rhs = Slice::new(&[0x06, 0xE9], 1, 15).unwrap(); + let got = lhs.forward_common_prefix(&mut rhs); + let want = ( + Slice::new(&[0x06, 0xE9], 1, 15).unwrap(), + Slice::new(&[], 0, 0).unwrap(), + Slice::new(&[], 0, 0).unwrap(), + ); + assert_eq!(want, (got, lhs, rhs)); +} + +#[test] +fn test_display() { + fn test(want: &str, bytes: &[u8], offset: u8, length: u16) { + use alloc::string::ToString; + + let got = Slice::new(bytes, offset, length).unwrap().to_string(); + assert_eq!(want, got) + } + + test("0b111111..", &[0xFF], 0, 6); + test("0b..1111..", &[0xFF], 2, 4); + test("0b..111111_11......", &[0xFF, 0xFF], 2, 8); + test("0b..111111_11111111_11......", &[0xFF, 0xFF, 0xFF], 2, 16); + + test("0b10101010", &[0xAA], 0, 8); + test("0b...0101.", &[0xAA], 3, 4); +} + +#[test] +fn test_eq() { + assert_eq!(Slice::new(&[0xFF], 0, 8), Slice::new(&[0xFF], 0, 8)); + assert_eq!(Slice::new(&[0xFF], 0, 4), 
Slice::new(&[0xF0], 0, 4)); + assert_eq!(Slice::new(&[0xFF], 4, 4), Slice::new(&[0x0F], 4, 4)); +} + +#[test] +#[rustfmt::skip] +fn test_iter() { + use alloc::vec::Vec; + + #[track_caller] + fn test(want: &[u8], bytes: &[u8], offset: u8, length: u16) { + let want = want.iter().map(|&b| b != 0).collect::>(); + let slice = Slice::new(bytes, offset, length).unwrap(); + let got = slice.iter().collect::>(); + assert_eq!(want, got); + } + + test(&[1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0], + &[0xAA, 0xAA], 0, 16); + test(&[1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0], + &[0x0A, 0xFA], 4, 12); + test(&[0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1], + &[0x0A, 0xFA], 0, 12); + test(&[1, 1, 0, 0], &[0x30], 2, 4); +} + +#[test] +fn test_chunks() { + let data = (0..=255).collect::>(); + let data = data.as_slice(); + + let slice = |off: u8, len: u16| Slice::new(data, off, len).unwrap(); + + // Single chunk + for offset in 0..8 { + for length in 1..(34 * 8 - u16::from(offset)) { + let want = Slice::new(data, offset, length); + + let mut chunks = slice(offset, length).chunks(); + assert_eq!(want, chunks.next()); + assert_eq!(None, chunks.next()); + + let mut chunks = slice(offset, length).chunks(); + assert_eq!(want, chunks.next_back()); + assert_eq!(None, chunks.next()); + } + } + + // Two chunks + for offset in 0..8 { + let want_first = Slice::new(data, offset, 34 * 8 - u16::from(offset)); + let want_second = Slice::new(&data[34..], 0, 10 + u16::from(offset)); + + let mut chunks = slice(offset, 34 * 8 + 10).chunks(); + assert_eq!(want_first, chunks.next()); + assert_eq!(want_second, chunks.next()); + assert_eq!(None, chunks.next()); + + let mut chunks = slice(offset, 34 * 8 + 10).chunks(); + assert_eq!(want_second, chunks.next_back()); + assert_eq!(want_first, chunks.next_back()); + assert_eq!(None, chunks.next()); + + let mut chunks = slice(offset, 34 * 8 + 10).chunks(); + assert_eq!(want_second, chunks.next_back()); + assert_eq!(want_first, chunks.next()); + assert_eq!(None, 
chunks.next()); + } +} + +#[test] +fn test_pop() { + use alloc::string::String; + + const WANT: &str = concat!("11001110", "00011110", "00011111"); + const BYTES: [u8; 3] = [0b1100_1110, 0b0001_1110, 0b0001_1111]; + + fn test( + want: &str, + mut slice: Slice, + reverse: bool, + pop: fn(&mut Slice) -> Option, + ) { + let got = core::iter::from_fn(move || pop(&mut slice)) + .map(|bit| char::from(b'0' + u8::from(bit))) + .collect::(); + let want = if reverse { + want.chars().rev().collect() + } else { + String::from(want) + }; + assert_eq!(want, got); + } + + fn test_set(reverse: bool, pop: fn(&mut Slice) -> Option) { + for start in 0..8 { + for end in start..=24 { + let slice = + Slice::new(&BYTES[..], start as u8, (end - start) as u16); + test(&WANT[start..end], slice.unwrap(), reverse, pop); + } + } + } + + test_set(false, |slice| slice.pop_front()); + test_set(true, |slice| slice.pop_back()); +} diff --git a/sealable-trie/src/hash.rs b/sealable-trie/src/hash.rs new file mode 100644 index 00000000..b0d2cc1b --- /dev/null +++ b/sealable-trie/src/hash.rs @@ -0,0 +1,214 @@ +use base64::engine::general_purpose::STANDARD as BASE64_ENGINE; +use base64::Engine; +use sha2::Digest; + +/// A cryptographic hash. +#[derive( + Clone, + Default, + PartialEq, + Eq, + derive_more::AsRef, + derive_more::AsMut, + derive_more::From, + derive_more::Into, +)] +#[as_ref(forward)] +#[into(owned, ref, ref_mut)] +#[repr(transparent)] +pub struct CryptoHash(pub [u8; CryptoHash::LENGTH]); + +// TODO(mina86): Make the code generic such that CryptoHash::digest take generic +// argument for the hash to use. This would then mean that Trie, Proof and +// other objects which need to calculate hashes would need to take that argument +// as well. +impl CryptoHash { + /// Length in bytes of the cryptographic hash. + pub const LENGTH: usize = 32; + + /// Default hash value (all zero bits). 
+ pub const DEFAULT: CryptoHash = CryptoHash([0; 32]); + + /// Returns a builder which can be used to construct cryptographic hash by + /// digesting bytes. + #[inline] + pub fn builder() -> Builder { Builder::default() } + + /// Returns hash of given bytes. + #[inline] + pub fn digest(bytes: &[u8]) -> Self { + Self(sha2::Sha256::digest(bytes).into()) + } + + /// Returns hash of concatenation of given byte slices. + #[inline] + pub fn digest_vec(slices: &[&[u8]]) -> Self { + let mut builder = Self::builder(); + for slice in slices { + builder.update(slice); + } + builder.build() + } + + + /// Creates a new hash with given number encoded in its first bytes. + /// + /// This is meant for tests which need to use arbitrary hash values. + #[cfg(test)] + pub(crate) const fn test(num: usize) -> CryptoHash { + let mut buf = [0; Self::LENGTH]; + let num = (num as u32).to_be_bytes(); + let mut idx = 0; + while idx < buf.len() { + buf[idx] = num[idx % num.len()]; + idx += 1; + } + Self(buf) + } + + /// Returns whether the hash is all zero bits. Equivalent to comparing to + /// the default `CryptoHash` object. + #[inline] + pub fn is_zero(&self) -> bool { self.0.iter().all(|&byte| byte == 0) } + + /// Returns reference to the hash as slice of bytes. + #[inline] + pub fn as_slice(&self) -> &[u8] { &self.0[..] } +} + +impl core::fmt::Display for CryptoHash { + /// Encodes the hash as base64 and prints it as a string. + fn fmt(&self, fmtr: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + const ENCODED_LENGTH: usize = (CryptoHash::LENGTH + 2) / 3 * 4; + let mut buf = [0u8; ENCODED_LENGTH]; + let len = + BASE64_ENGINE.encode_slice(self.as_slice(), &mut buf[..]).unwrap(); + // SAFETY: base64 fills the buffer with ASCII characters only. + fmtr.write_str(unsafe { core::str::from_utf8_unchecked(&buf[..len]) }) + } +} + +impl core::fmt::Debug for CryptoHash { + /// Encodes the hash as base64 and prints it as a string. 
+ #[inline] + fn fmt(&self, fmtr: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + core::fmt::Display::fmt(self, fmtr) + } +} + +impl<'a> From<&'a [u8; CryptoHash::LENGTH]> for CryptoHash { + #[inline] + fn from(hash: &'a [u8; CryptoHash::LENGTH]) -> Self { + <&CryptoHash>::from(hash).clone() + } +} + +impl From<&'_ CryptoHash> for [u8; CryptoHash::LENGTH] { + #[inline] + fn from(hash: &'_ CryptoHash) -> Self { hash.0.clone() } +} + +impl<'a> From<&'a [u8; CryptoHash::LENGTH]> for &'a CryptoHash { + #[inline] + fn from(hash: &'a [u8; CryptoHash::LENGTH]) -> Self { + let hash = + (hash as *const [u8; CryptoHash::LENGTH]).cast::(); + // SAFETY: CryptoHash is repr(transparent) over [u8; CryptoHash::LENGTH] + // thus transmuting is safe. + unsafe { &*hash } + } +} + +impl<'a> From<&'a mut [u8; CryptoHash::LENGTH]> for &'a mut CryptoHash { + #[inline] + fn from(hash: &'a mut [u8; CryptoHash::LENGTH]) -> Self { + let hash = (hash as *mut [u8; CryptoHash::LENGTH]).cast::(); + // SAFETY: CryptoHash is repr(transparent) over [u8; CryptoHash::LENGTH] + // thus transmuting is safe. + unsafe { &mut *hash } + } +} + +impl<'a> TryFrom<&'a [u8]> for &'a CryptoHash { + type Error = core::array::TryFromSliceError; + + #[inline] + fn try_from(hash: &'a [u8]) -> Result { + <&[u8; CryptoHash::LENGTH]>::try_from(hash).map(Into::into) + } +} + +impl<'a> TryFrom<&'a mut [u8]> for &'a mut CryptoHash { + type Error = core::array::TryFromSliceError; + + #[inline] + fn try_from(hash: &'a mut [u8]) -> Result { + <&mut [u8; CryptoHash::LENGTH]>::try_from(hash).map(Into::into) + } +} + +impl<'a> TryFrom<&'a [u8]> for CryptoHash { + type Error = core::array::TryFromSliceError; + + #[inline] + fn try_from(hash: &'a [u8]) -> Result { + <&CryptoHash>::try_from(hash).map(Clone::clone) + } +} + +/// Builder for the cryptographic hash. +/// +/// The builder calculates the digest of bytes that it’s fed using the +/// [`Builder::update`] method. 
+/// +/// This is useful if there are multiple discontiguous buffers that hold the +/// data to be hashed. If all data is in a single contiguous buffer it’s more +/// convenient to use [`CryptoHash::digest`] instead. +#[derive(Default)] +pub struct Builder(sha2::Sha256); + +impl Builder { + /// Process data, updating the internal state of the digest. + #[inline] + pub fn update(&mut self, bytes: &[u8]) { self.0.update(bytes) } + + /// Finalises the digest and returns the cryptographic hash. + #[inline] + pub fn build(self) -> CryptoHash { CryptoHash(self.0.finalize().into()) } +} + +#[test] +fn test_new_hash() { + assert_eq!(CryptoHash::from([0; 32]), CryptoHash::default()); + + // https://www.di-mgt.com.au/sha_testvectors.html + let want = CryptoHash::from([ + 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, 0x9a, 0xfb, 0xf4, 0xc8, + 0x99, 0x6f, 0xb9, 0x24, 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, + 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55, + ]); + assert_eq!(want, CryptoHash::digest(b"")); + assert_eq!(want, CryptoHash::builder().build()); + let got = { + let mut builder = CryptoHash::builder(); + builder.update(b""); + builder.build() + }; + assert_eq!(want, got); + assert_eq!(want, CryptoHash::builder().build()); + + let want = CryptoHash::from([ + 0xba, 0x78, 0x16, 0xbf, 0x8f, 0x01, 0xcf, 0xea, 0x41, 0x41, 0x40, 0xde, + 0x5d, 0xae, 0x22, 0x23, 0xb0, 0x03, 0x61, 0xa3, 0x96, 0x17, 0x7a, 0x9c, + 0xb4, 0x10, 0xff, 0x61, 0xf2, 0x00, 0x15, 0xad, + ]); + assert_eq!(want, CryptoHash::digest(b"abc")); + assert_eq!(want, CryptoHash::digest_vec(&[b"a", b"bc"])); + let got = { + let mut builder = CryptoHash::builder(); + builder.update(b"a"); + builder.update(b"bc"); + builder.build() + }; + assert_eq!(want, got); +} diff --git a/sealable-trie/src/lib.rs b/sealable-trie/src/lib.rs new file mode 100644 index 00000000..49500c3b --- /dev/null +++ b/sealable-trie/src/lib.rs @@ -0,0 +1,17 @@ +#![no_std] +extern crate alloc; +#[cfg(test)] +extern crate std; + +pub 
mod bits; +pub mod hash; +pub mod memory; +pub mod nodes; +pub mod proof; +pub(crate) mod stdx; +pub mod trie; + +#[cfg(test)] +mod test_utils; + +pub use trie::Trie; diff --git a/sealable-trie/src/memory.rs b/sealable-trie/src/memory.rs new file mode 100644 index 00000000..16492844 --- /dev/null +++ b/sealable-trie/src/memory.rs @@ -0,0 +1,408 @@ +use alloc::vec::Vec; +use core::fmt; +use core::num::NonZeroU32; + +use crate::nodes::RawNode; + +/// A pointer value. The value is 30-bit and always non-zero. +#[derive( + Copy, + Clone, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + derive_more::Into, + derive_more::Deref, +)] +#[into(owned, ref, ref_mut)] +#[repr(transparent)] +pub struct Ptr(NonZeroU32); + +#[derive(Copy, Clone, Debug, PartialEq, Eq, derive_more::Display)] +pub struct OutOfMemory; + +#[derive(Copy, Clone, Debug, PartialEq, Eq, derive_more::Display)] +pub struct AddressTooLarge(pub NonZeroU32); + +impl Ptr { + /// Largest value that can be stored in the pointer. + // The two most significant bits are used internally in RawNode encoding + // thus the max value is 30-bit. + const MAX: u32 = u32::MAX >> 2; + + /// Constructs a new pointer from given address. + /// + /// If the value is zero, returns `None` indicating a null pointer. If + /// value fits in 30 bits, returns a new `Ptr` with that value. Otherwise + /// returns an error with the argument. 
+ /// + /// ## Example + /// + /// ``` + /// # use core::num::NonZeroU32; + /// # use sealable_trie::memory::*; + /// + /// assert_eq!(Ok(None), Ptr::new(0)); + /// assert_eq!(42, Ptr::new(42).unwrap().unwrap().get()); + /// assert_eq!((1 << 30) - 1, + /// Ptr::new((1 << 30) - 1).unwrap().unwrap().get()); + /// assert_eq!(Err(AddressTooLarge(NonZeroU32::new(1 << 30).unwrap())), + /// Ptr::new(1 << 30)); + /// ``` + pub const fn new(ptr: u32) -> Result, AddressTooLarge> { + // Using match so the function is const + match NonZeroU32::new(ptr) { + None => Ok(None), + Some(num) if num.get() <= Self::MAX => Ok(Some(Self(num))), + Some(num) => Err(AddressTooLarge(num)), + } + } + + /// Constructs a new pointer from given address. + /// + /// Two most significant bits of the address are masked out thus ensuring + /// that the value is never too large. + pub(crate) fn new_truncated(ptr: u32) -> Option { + NonZeroU32::new(ptr & Self::MAX).map(Self) + } +} + +impl TryFrom for Ptr { + type Error = AddressTooLarge; + + /// Constructs a new pointer from given non-zero address. + /// + /// If the address is too large (see [`Ptr::MAX`]) returns an error with the + /// address which has been passed. 
+ /// + /// ## Example + /// + /// ``` + /// # use core::num::NonZeroU32; + /// # use sealable_trie::memory::*; + /// + /// let answer = NonZeroU32::new(42).unwrap(); + /// assert_eq!(42, Ptr::try_from(answer).unwrap().get()); + /// + /// let large = NonZeroU32::new(1 << 30).unwrap(); + /// assert_eq!(Err(AddressTooLarge(large)), Ptr::try_from(large)); + /// ``` + fn try_from(num: NonZeroU32) -> Result { + if num.get() <= Ptr::MAX { + Ok(Ptr(num)) + } else { + Err(AddressTooLarge(num)) + } + } +} + +impl fmt::Display for Ptr { + fn fmt(&self, fmtr: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.get().fmt(fmtr) + } +} + +impl fmt::Debug for Ptr { + fn fmt(&self, fmtr: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.get().fmt(fmtr) + } +} + +/// An interface for memory management used by the trie. +pub trait Allocator { + /// Allocates a new block and initialise it to given value. + fn alloc(&mut self, value: RawNode) -> Result; + + /// Returns value stored at given pointer. + /// + /// May panic or return garbage if `ptr` is invalid. + fn get(&self, ptr: Ptr) -> RawNode; + + /// Sets value at given pointer. + fn set(&mut self, ptr: Ptr, value: RawNode); + + /// Frees a block. + fn free(&mut self, ptr: Ptr); +} + +/// A write log which can be committed or rolled back. +/// +/// Rather than writing data directly to the allocate, it keeps all changes in +/// memory. When committing, the changes are then applied. Similarly, list of +/// all allocated nodes are kept and during rollback all of those nodes are +/// freed. +/// +/// **Note:** *All* reads are passed directly to the underlying allocator. This +/// means reading a node that has been written to will return the old result. +/// +/// Note that the write log doesn’t offer isolation. Most notably, writes to +/// the allocator performed outside of the write log are visible when accessing +/// the nodes via the write log. 
(To indicate that, the API doesn’t offer `get` +/// method and instead all reads need to go through the underlying allocator). +/// +/// Secondly, allocations done via the write log are visible outside of the +/// write log. The assumption is that nothing outside of the client of the +/// write log knows the pointer thus in practice they cannot refer to those +/// allocated but not-yet-committed nodes. +pub struct WriteLog<'a, A: Allocator> { + /// Allocator to pass requests to. + alloc: &'a mut A, + + /// List of changes in the transaction. + write_log: Vec<(Ptr, RawNode)>, + + /// List pointers to nodes allocated during the transaction. + allocated: Vec, + + /// List of nodes freed during the transaction. + freed: Vec, +} + +impl<'a, A: Allocator> WriteLog<'a, A> { + pub fn new(alloc: &'a mut A) -> Self { + Self { + alloc, + write_log: Vec::new(), + allocated: Vec::new(), + freed: Vec::new(), + } + } + + /// Commit all changes to the allocator. + /// + /// There’s no explicit rollback method. To roll changes back, drop the + /// object. + pub fn commit(mut self) { + self.allocated.clear(); + for (ptr, value) in self.write_log.drain(..) { + self.alloc.set(ptr, value) + } + for ptr in self.freed.drain(..) { + self.alloc.free(ptr) + } + } + + /// Returns underlying allocator. + pub fn allocator(&self) -> &A { &*self.alloc } + + pub fn alloc(&mut self, value: RawNode) -> Result { + let ptr = self.alloc.alloc(value)?; + self.allocated.push(ptr); + Ok(ptr) + } + + pub fn set(&mut self, ptr: Ptr, value: RawNode) { + self.write_log.push((ptr, value)) + } + + pub fn free(&mut self, ptr: Ptr) { self.freed.push(ptr); } +} + +impl<'a, A: Allocator> core::ops::Drop for WriteLog<'a, A> { + fn drop(&mut self) { + self.write_log.clear(); + self.freed.clear(); + for ptr in self.allocated.drain(..) 
{ + self.alloc.free(ptr) + } + } +} + +#[cfg(test)] +pub(crate) mod test_utils { + use super::*; + use crate::stdx; + + pub struct TestAllocator { + count: usize, + free: Option, + pool: alloc::vec::Vec, + allocated: std::collections::HashMap, + } + + impl TestAllocator { + pub fn new(capacity: usize) -> Self { + let max_cap = usize::try_from(Ptr::MAX).unwrap_or(usize::MAX); + let capacity = capacity.min(max_cap); + let mut pool = alloc::vec::Vec::with_capacity(capacity); + pool.push(RawNode([0xAA; 72])); + Self { count: 0, free: None, pool, allocated: Default::default() } + } + + pub fn count(&self) -> usize { self.count } + + /// Verifies that block has been allocated. Panics if it hasn’t. + fn check_allocated(&self, action: &str, ptr: Ptr) -> usize { + let adj = match self.allocated.get(&ptr.get()).copied() { + None => "unallocated", + Some(false) => "freed", + Some(true) => return usize::try_from(ptr.get()).unwrap(), + }; + panic!("Tried to {action} {adj} block at {ptr}") + } + } + + impl Allocator for TestAllocator { + fn alloc(&mut self, value: RawNode) -> Result { + let ptr = if let Some(ptr) = self.free { + // Grab node from free list + let node = &mut self.pool[ptr.get() as usize]; + let bytes = stdx::split_array_ref::<4, 68, 72>(&node.0).0; + self.free = Ptr::new(u32::from_ne_bytes(*bytes)).unwrap(); + *node = value; + ptr + } else if self.pool.len() < self.pool.capacity() { + // Grab new node + self.pool.push(value); + Ptr::new((self.pool.len() - 1) as u32).unwrap().unwrap() + } else { + // No free node to allocate + return Err(OutOfMemory); + }; + + assert!( + self.allocated.insert(ptr.get(), true) != Some(true), + "Internal error: Allocated the same block twice at {ptr}", + ); + self.count += 1; + Ok(ptr) + } + + #[track_caller] + fn get(&self, ptr: Ptr) -> RawNode { + self.pool[self.check_allocated("read", ptr)].clone() + } + + #[track_caller] + fn set(&mut self, ptr: Ptr, value: RawNode) { + let idx = self.check_allocated("read", ptr); + 
self.pool[idx] = value + } + + fn free(&mut self, ptr: Ptr) { + let idx = self.check_allocated("free", ptr); + self.allocated.insert(ptr.get(), false); + *stdx::split_array_mut::<4, 68, 72>(&mut self.pool[idx].0).0 = + self.free.map_or(0, |ptr| ptr.get()).to_ne_bytes(); + self.free = Some(ptr); + self.count -= 1; + } + } +} + +#[cfg(test)] +mod test_write_log { + use super::test_utils::TestAllocator; + use super::*; + use crate::hash::CryptoHash; + + fn make_allocator() -> (TestAllocator, Vec) { + let mut alloc = TestAllocator::new(100); + let ptrs = (0..10) + .map(|num| alloc.alloc(make_node(num)).unwrap()) + .collect::>(); + assert_nodes(10, &alloc, &ptrs, 0); + (alloc, ptrs) + } + + fn make_node(num: usize) -> RawNode { + let hash = CryptoHash::test(num); + let child = crate::nodes::Reference::node(None, &hash); + RawNode::branch(child, child) + } + + #[track_caller] + fn assert_nodes( + count: usize, + alloc: &TestAllocator, + ptrs: &[Ptr], + offset: usize, + ) { + assert_eq!(count, alloc.count()); + for (idx, ptr) in ptrs.iter().enumerate() { + assert_eq!( + make_node(idx + offset), + alloc.get(*ptr), + "Invalid value when reading {ptr}" + ); + } + } + + #[test] + fn test_set_commit() { + let (mut alloc, ptrs) = make_allocator(); + let mut wlog = WriteLog::new(&mut alloc); + for (idx, &ptr) in ptrs.iter().take(5).enumerate() { + wlog.set(ptr, make_node(idx + 10)); + } + assert_nodes(10, wlog.allocator(), &ptrs, 0); + wlog.commit(); + assert_nodes(10, &alloc, &ptrs[..5], 10); + assert_nodes(10, &alloc, &ptrs[5..], 5); + } + + #[test] + fn test_set_rollback() { + let (mut alloc, ptrs) = make_allocator(); + let mut wlog = WriteLog::new(&mut alloc); + for (idx, &ptr) in ptrs.iter().take(5).enumerate() { + wlog.set(ptr, make_node(idx + 10)); + } + assert_nodes(10, wlog.allocator(), &ptrs, 0); + core::mem::drop(wlog); + assert_nodes(10, &alloc, &ptrs, 0); + } + + #[test] + fn test_alloc_commit() { + let (mut alloc, ptrs) = make_allocator(); + let mut wlog = 
WriteLog::new(&mut alloc); + let new_ptrs = (10..20) + .map(|num| wlog.alloc(make_node(num)).unwrap()) + .collect::>(); + assert_nodes(20, &wlog.allocator(), &ptrs, 0); + assert_nodes(20, &wlog.allocator(), &new_ptrs, 10); + wlog.commit(); + assert_nodes(20, &alloc, &ptrs, 0); + assert_nodes(20, &alloc, &new_ptrs, 10); + } + + #[test] + fn test_alloc_rollback() { + let (mut alloc, ptrs) = make_allocator(); + let mut wlog = WriteLog::new(&mut alloc); + let new_ptrs = (10..20) + .map(|num| wlog.alloc(make_node(num)).unwrap()) + .collect::>(); + assert_nodes(20, &wlog.allocator(), &ptrs, 0); + assert_nodes(20, &wlog.allocator(), &new_ptrs, 10); + core::mem::drop(wlog); + assert_nodes(10, &alloc, &ptrs, 0); + } + + #[test] + fn test_free_commit() { + let (mut alloc, ptrs) = make_allocator(); + let mut wlog = WriteLog::new(&mut alloc); + for num in 5..10 { + wlog.free(ptrs[num]); + } + assert_nodes(10, wlog.allocator(), &ptrs, 0); + wlog.commit(); + assert_nodes(5, &alloc, &ptrs[..5], 0); + } + + #[test] + fn test_free_rollback() { + let (mut alloc, ptrs) = make_allocator(); + let mut wlog = WriteLog::new(&mut alloc); + for num in 5..10 { + wlog.free(ptrs[num]); + } + assert_nodes(10, wlog.allocator(), &ptrs, 0); + core::mem::drop(wlog); + assert_nodes(10, &alloc, &ptrs, 0); + } +} diff --git a/sealable-trie/src/nodes.rs b/sealable-trie/src/nodes.rs new file mode 100644 index 00000000..88ddae53 --- /dev/null +++ b/sealable-trie/src/nodes.rs @@ -0,0 +1,584 @@ +use crate::bits::Slice; +use crate::hash::CryptoHash; +use crate::memory::Ptr; +use crate::{bits, stdx}; + +#[cfg(test)] +mod stress_tests; +#[cfg(test)] +mod tests; + +pub(crate) const MAX_EXTENSION_KEY_SIZE: usize = 34; + +type Result = core::result::Result; + +/// A trie node. +/// +/// There are three types of nodes: branches, extensions and values. +/// +/// A branch node has two children which reference other nodes (both are always +/// present). 
+/// +/// An extension represents a path in a node which doesn’t branch. For example, +/// if trie contains key 0 and 1 then the root node will be an extension with +/// 0b0000_000 as the key and a branch node as a child. +/// +/// The space for key in extension node is limited (max 34 bytes), if longer key +/// is needed, an extension node may point at another extension node. +/// +/// A value node holds hash of the stored value at the key. Furthermore, if the +/// key is a prefix it stores a reference to another node which continues the +/// key. This reference is never a value reference. +/// +/// A node reference either points at another Node or is hash of the stored +/// value. The reference is represented by the `R` generic argument. +/// +/// [`Node`] object can be constructed either from a [`RawNode`]. +/// +/// The generic argument `P` specifies how pointers to nodes are represented and +/// `S` specifies how value being sealed or not is encoded. To represent value +/// parsed from a raw node representation, those types should be `Option` +/// and `bool` respectively. However, when dealing with proofs, pointer and +/// seal information is not available thus both of those types should be a unit +/// type. +#[derive(Clone, Copy, Debug)] +pub enum Node<'a, P = Option, S = bool> { + Branch { + /// Children of the branch. Both are always set. + children: [Reference<'a, P, S>; 2], + }, + Extension { + /// Key of the extension. + key: Slice<'a>, + /// Child node or value pointed by the extension. + child: Reference<'a, P, S>, + }, + Value { + value: ValueRef<'a, S>, + child: NodeRef<'a, P>, + }, +} + +/// Binary representation of the node as kept in the persistent storage. +/// +/// This representation is compact and includes internal details needed to +/// maintain the data-structure which shouldn’t be leaked to the clients of the +/// library and which don’t take part in hashing of the node. +// +// ```ignore +// Branch: +// A branch holds two references. 
Both of them are always set. Note that +// reference’s most significant bit is always zero thus the first bit of +// a node representation distinguishes whether node is a branch or not. +// +// Extension: 1000_kkkk kkkk_kooo +// `kkkk` is the length of the key in bits and `ooo` is number of most +// significant bits in to skip before getting to the key. is +// 36-byte array which holds the key extension. Only `o..o+k` bits in it +// are the actual key; others are set to zero. +// +// Value: 11s0_0000 0000_0000 0000_0000 0000_0000 +// is the hash of the stored value. `s` is zero if the value hasn’t +// been sealed, one otherwise. is a references the child node +// which points to the subtrie rooted at the key of the value. Value node +// can only point at Branch or Extension node. +// ``` +// +// A Reference is a 36-byte sequence consisting of a 4-byte pointer and +// a 32-byte hash. The most significant bit of the pointer is always zero (this +// is so that Branch nodes can be distinguished from other nodes). The second +// most significant bit is zero if the reference is a node reference and one if +// it’s a value reference. +// +// ```ignore +// Node Ref: 0b00pp_pppp pppp_pppp pppp_pppp pppp_pppp +// `ppp` is the pointer to the node. If it’s zero than the node is sealed +// the it’s not stored anywhere. +// +// Value Ref: 0b01s0_0000 0000_0000 0000_0000 0000_0000 +// `s` determines whether the value is sealed or not. If it is, it cannot be +// changed. +// ``` +// +// The actual pointer value is therefore 30-bit long. +#[derive(Clone, Copy, PartialEq, derive_more::Deref)] +#[repr(transparent)] +pub struct RawNode(pub(crate) [u8; 72]); + +/// Reference either to a node or a value as held in Branch or Extension nodes. +/// +/// See [`Node`] documentation for meaning of `P` and `S` generic arguments. 
+#[derive(Clone, Copy, Debug, derive_more::From, derive_more::TryInto)] +pub enum Reference<'a, P = Option, S = bool> { + Node(NodeRef<'a, P>), + Value(ValueRef<'a, S>), +} + +/// Reference to a node as held in Value node. +/// +/// See [`Node`] documentation for meaning of the `P` generic argument. +#[derive(Clone, Copy, Debug)] +pub struct NodeRef<'a, P = Option> { + pub hash: &'a CryptoHash, + pub ptr: P, +} + +/// Reference to a value as held in Value node. +/// +/// See [`Node`] documentation for meaning of the `S` generic argument. +#[derive(Clone, Copy, Debug)] +pub struct ValueRef<'a, S = bool> { + pub hash: &'a CryptoHash, + pub is_sealed: S, +} + + +// ============================================================================= +// Implementations + +impl<'a, P, S> Node<'a, P, S> { + /// Constructs a Branch node with specified children. + pub fn branch( + left: Reference<'a, P, S>, + right: Reference<'a, P, S>, + ) -> Self { + Self::Branch { children: [left, right] } + } + + /// Constructs an Extension node with given key and child. + /// + /// Note that length of the key is not checked. It’s possible to create + /// a node which cannot be encoded either in raw or proof format. For an + /// Extension node to be able to be encoded, the key’s underlying bytes + /// slice must not exceed [`MAX_EXTENSION_KEY_SIZE`] bytes. + pub fn extension(key: Slice<'a>, child: Reference<'a, P, S>) -> Self { + Self::Extension { key, child } + } + + /// Constructs a Value node with given value hash and child. + pub fn value(value: ValueRef<'a, S>, child: NodeRef<'a, P>) -> Self { + Self::Value { value, child } + } + + /// Returns a hash of the node. + /// + /// Hash changes if and only if the value of the node (if any) and any child + /// node changes. Sealing or changing pointer value in a node reference + /// doesn’t count as changing the node. 
+ pub fn hash(&self) -> CryptoHash { + let mut buf = [0; 68]; + + fn tag_hash_hash( + buf: &mut [u8; 68], + tag: u8, + lft: &CryptoHash, + rht: &CryptoHash, + ) -> usize { + let buf = stdx::split_array_mut::<65, 3, 68>(buf).0; + let (t, rest) = stdx::split_array_mut::<1, 64, 65>(buf); + let (l, r) = stdx::split_array_mut(rest); + *t = [tag]; + *l = lft.0; + *r = rht.0; + buf.len() + } + + let len = match self { + Node::Branch { children: [left, right] } => { + // tag = 0b0000_00xy where x and y indicate whether left and + // right children respectively are value references. + let tag = (u8::from(left.is_value()) << 1) | + u8::from(right.is_value()); + tag_hash_hash(&mut buf, tag, left.hash(), right.hash()) + } + Node::Value { value, child } => { + tag_hash_hash(&mut buf, 0xC0, &value.hash, &child.hash) + } + Node::Extension { key, child } => { + let key_buf = stdx::split_array_mut::<36, 32, 68>(&mut buf).0; + // tag = 0b100v_0000 where v indicates whether the child is + // a value reference. + let tag = 0x80 | (u8::from(child.is_value()) << 4); + if let Some(len) = key.encode_into(key_buf, tag) { + buf[len..len + 32].copy_from_slice(child.hash().as_slice()); + len + 32 + } else { + return hash_extension_slow_path(*key, child); + } + } + }; + CryptoHash::digest(&buf[..len]) + } +} + +impl<'a> Node<'a> { + /// Builds raw representation of given node. + /// + /// Returns an error if this node is an Extension with a key of invalid + /// length (either empty or too long). + pub fn encode(&self) -> Result { + match self { + Node::Branch { children: [left, right] } => { + Ok(RawNode::branch(*left, *right)) + } + Node::Extension { key, child } => { + RawNode::extension(*key, *child).ok_or(()) + } + Node::Value { value, child } => Ok(RawNode::value(*value, *child)), + } + } +} + +/// Hashes an Extension node with oversized key. +/// +/// Normally, this is never called since we should calculate hashes of nodes +/// whose keys fit in the [`MAX_EXTENSION_KEY_SIZE`] limit. 
However, to +/// avoid having to handle errors we use this slow path to calculate hashes +/// for nodes with longer keys. +#[cold] +fn hash_extension_slow_path( + key: bits::Slice, + child: &Reference, +) -> CryptoHash { + let mut builder = CryptoHash::builder(); + // tag = 0b100v_0000 where v indicates whether the child is a value + // reference. + let tag = 0x80 | (u8::from(child.is_value()) << 4); + key.write_into(|bytes| builder.update(bytes), tag); + builder.update(child.hash().as_slice()); + builder.build() +} + +impl RawNode { + /// Constructs a Branch node with specified children. + pub fn branch(left: Reference, right: Reference) -> Self { + let mut res = Self([0; 72]); + let (lft, rht) = res.halfs_mut(); + *lft = left.encode(); + *rht = right.encode(); + res + } + + /// Constructs an Extension node with given key and child. + /// + /// Fails and returns `None` if the key is empty or its underlying bytes + /// slice is too long. The slice must not exceed [`MAX_EXTENSION_KEY_SIZE`] + /// to be valid. + pub fn extension(key: Slice, child: Reference) -> Option { + let mut res = Self([0; 72]); + let (lft, rht) = res.halfs_mut(); + key.encode_into(lft, 0x80)?; + *rht = child.encode(); + Some(res) + } + + /// Constructs a Value node with given value hash and child. + pub fn value(value: ValueRef, child: NodeRef) -> Self { + let mut res = Self([0; 72]); + let (lft, rht) = res.halfs_mut(); + *lft = Reference::Value(value).encode(); + lft[0] |= 0x80; + *rht = Reference::Node(child).encode(); + res + } + + /// Decodes raw node into a [`Node`]. + /// + /// In debug builds panics if `node` holds malformed representation, i.e. if + /// any unused bits (which must be cleared) are set. + // TODO(mina86): Convert debug_assertions to the method returning Result. + pub fn decode(&self) -> Node { + let (left, right) = self.halfs(); + let right = Reference::from_raw(right, false); + // `>> 6` to grab the two most significant bits only. 
+ let tag = self.first() >> 6; + if tag == 0 || tag == 1 { + // Branch + Node::Branch { children: [Reference::from_raw(left, false), right] } + } else if tag == 2 { + // Extension + let key = Slice::decode(left, 0x80).unwrap_or_else(|| { + panic!("Failed decoding raw: {self:?}"); + }); + Node::Extension { key, child: right } + } else { + // Value + let (num, value) = stdx::split_array_ref::<4, 32, 36>(left); + let num = u32::from_be_bytes(*num); + debug_assert_eq!( + 0xC000_0000, + num & !0x2000_0000, + "Failed decoding raw node: {self:?}", + ); + let value = ValueRef::new(num & 0x2000_0000 != 0, value.into()); + let child = right.try_into().unwrap_or_else(|_| { + debug_assert!(false, "Failed decoding raw node: {self:?}"); + NodeRef::new(None, &CryptoHash::DEFAULT) + }); + Node::Value { value, child } + } + } + + /// Returns the first byte in the raw representation. + fn first(&self) -> u8 { self.0[0] } + + /// Splits the raw byte representation in two halfs. + fn halfs(&self) -> (&[u8; 36], &[u8; 36]) { + stdx::split_array_ref::<36, 36, 72>(&self.0) + } + + /// Splits the raw byte representation in two halfs. + fn halfs_mut(&mut self) -> (&mut [u8; 36], &mut [u8; 36]) { + stdx::split_array_mut::<36, 36, 72>(&mut self.0) + } +} + +impl<'a, P, S> Reference<'a, P, S> { + /// Returns whether the reference is to a node. + pub fn is_node(&self) -> bool { matches!(self, Self::Node(_)) } + + /// Returns whether the reference is to a value. + pub fn is_value(&self) -> bool { matches!(self, Self::Value(_)) } + + /// Returns node’s or value’s hash depending on type of reference. + /// + /// Use [`Self::is_value`] and [`Self::is_proof`] to check whether + fn hash(&self) -> &'a CryptoHash { + match self { + Self::Node(node) => node.hash, + Self::Value(value) => value.hash, + } + } +} + +impl<'a> Reference<'a> { + /// Creates a new reference pointing at given node. 
+ #[inline] + pub fn node(ptr: Option, hash: &'a CryptoHash) -> Self { + Self::Node(NodeRef::new(ptr, hash)) + } + + /// Creates a new reference pointing at value with given hash. + #[inline] + pub fn value(is_sealed: bool, hash: &'a CryptoHash) -> Self { + Self::Value(ValueRef::new(is_sealed, hash)) + } + + /// Returns whether the reference is to a sealed node or value. + #[inline] + pub fn is_sealed(&self) -> bool { + match self { + Self::Node(node) => node.ptr.is_none(), + Self::Value(value) => value.is_sealed, + } + } + + /// Parses bytes to form a raw node reference representation. + /// + /// Assumes that the bytes are trusted. I.e. doesn’t verify that the most + /// significant bit is zero or that if second bit is one than pointer value + /// must be zero. + /// + /// In debug builds, panics if `bytes` has non-canonical representation, + /// i.e. any unused bits are set. `value_high_bit` in this case determines + /// whether for value reference the most significant bit should be set or + /// not. This is to facilitate decoding Value nodes. The argument is + /// ignored in builds with debug assertions disabled. + // TODO(mina86): Convert debug_assertions to the method returning Result. + fn from_raw(bytes: &'a [u8; 36], value_high_bit: bool) -> Self { + let (ptr, hash) = stdx::split_array_ref::<4, 32, 36>(bytes); + let ptr = u32::from_be_bytes(*ptr); + let hash = hash.into(); + if ptr & 0x4000_0000 == 0 { + // The two most significant bits must be zero. + debug_assert_eq!( + 0, + ptr & 0xC000_0000, + "Failed decoding Reference: {bytes:?}" + ); + let ptr = Ptr::new_truncated(ptr); + Self::Node(NodeRef { ptr, hash }) + } else { + // * The most significant bit is set only if value_high_bit is true. + // * The second most significant bit (so 0b4000_0000) is always set. + // * The third most significant bit (so 0b2000_0000) specifies + // whether value is sealed. 
+ debug_assert_eq!( + 0x4000_0000 | (u32::from(value_high_bit) << 31), + ptr & !0x2000_0000, + "Failed decoding Reference: {bytes:?}" + ); + let is_sealed = ptr & 0x2000_0000 != 0; + Self::Value(ValueRef { is_sealed, hash }) + } + } + + /// Encodes the node reference into the buffer. + fn encode(&self) -> [u8; 36] { + let (num, hash) = match self { + Self::Node(node) => { + (node.ptr.map_or(0, |ptr| ptr.get()), node.hash) + } + Self::Value(value) => { + (0x4000_0000 | (u32::from(value.is_sealed) << 29), value.hash) + } + }; + let mut buf = [0; 36]; + let (left, right) = stdx::split_array_mut::<4, 32, 36>(&mut buf); + *left = num.to_be_bytes(); + *right = hash.into(); + buf + } +} + +impl<'a> Reference<'a, (), ()> { + pub fn new(is_value: bool, hash: &'a CryptoHash) -> Self { + match is_value { + false => NodeRef::new((), hash).into(), + true => ValueRef::new((), hash).into(), + } + } +} + +impl<'a, P> NodeRef<'a, P> { + /// Constructs a new node reference. + #[inline] + pub fn new(ptr: P, hash: &'a CryptoHash) -> Self { Self { ptr, hash } } +} + +impl<'a> NodeRef<'a, Option> { + /// Returns sealed version of the reference. The hash remains unchanged. + #[inline] + pub fn sealed(self) -> Self { Self { ptr: None, hash: self.hash } } +} + +impl<'a, S> ValueRef<'a, S> { + /// Constructs a new node reference. + #[inline] + pub fn new(is_sealed: S, hash: &'a CryptoHash) -> Self { + Self { is_sealed, hash } + } +} + +impl<'a> ValueRef<'a, bool> { + /// Returns sealed version of the reference. The hash remains unchanged. + #[inline] + pub fn sealed(self) -> Self { Self { is_sealed: true, hash: self.hash } } +} + + +// ============================================================================= +// PartialEq + +// Are those impls dumb? Yes, they absolutely are. However, when I used +// #[derive(PartialEq)] I run into lifetime issues. +// +// My understanding is that derive would create implementation for the same +// lifetime on LHS and RHS types (e.g. 
`impl<'a> PartialEq> for +// Ref<'a>`). As a result, when comparing two objects Rust would try to match +// their lifetimes which wasn’t always possible. + +impl<'a, 'b, P, S> core::cmp::PartialEq> for Node<'a, P, S> +where + P: PartialEq, + S: PartialEq, +{ + fn eq(&self, rhs: &Node<'b, P, S>) -> bool { + match (self, rhs) { + ( + Node::Branch { children: lhs }, + Node::Branch { children: rhs }, + ) => lhs == rhs, + ( + Node::Extension { key: lhs_key, child: lhs_child }, + Node::Extension { key: rhs_key, child: rhs_child }, + ) => lhs_key == rhs_key && lhs_child == rhs_child, + ( + Node::Value { value: lhs_value, child: lhs }, + Node::Value { value: rhs_value, child: rhs }, + ) => lhs_value == rhs_value && lhs == rhs, + _ => false, + } + } +} + +impl<'a, 'b, P, S> core::cmp::PartialEq> + for Reference<'a, P, S> +where + P: PartialEq, + S: PartialEq, +{ + fn eq(&self, rhs: &Reference<'b, P, S>) -> bool { + match (self, rhs) { + (Reference::Node(lhs), Reference::Node(rhs)) => lhs == rhs, + (Reference::Value(lhs), Reference::Value(rhs)) => lhs == rhs, + _ => false, + } + } +} + +impl<'a, 'b, P> core::cmp::PartialEq> for NodeRef<'a, P> +where + P: PartialEq, +{ + fn eq(&self, rhs: &NodeRef<'b, P>) -> bool { + self.ptr == rhs.ptr && self.hash == rhs.hash + } +} + +impl<'a, 'b, S> core::cmp::PartialEq> for ValueRef<'a, S> +where + S: PartialEq, +{ + fn eq(&self, rhs: &ValueRef<'b, S>) -> bool { + self.is_sealed == rhs.is_sealed && self.hash == rhs.hash + } +} + +// ============================================================================= +// Formatting + +impl core::fmt::Debug for RawNode { + fn fmt(&self, fmtr: &mut core::fmt::Formatter) -> core::fmt::Result { + fn write_raw_key( + fmtr: &mut core::fmt::Formatter, + separator: &str, + bytes: &[u8; 36], + ) -> core::fmt::Result { + let (tag, key) = stdx::split_array_ref::<2, 34, 36>(bytes); + write!(fmtr, "{separator}{:04x}", u16::from_be_bytes(*tag))?; + write_binary(fmtr, ":", key) + } + + fn write_raw_ptr( + 
fmtr: &mut core::fmt::Formatter, + separator: &str, + bytes: &[u8; 36], + ) -> core::fmt::Result { + let (ptr, hash) = stdx::split_array_ref::<4, 32, 36>(bytes); + let ptr = u32::from_be_bytes(*ptr); + let hash = <&CryptoHash>::from(hash); + write!(fmtr, "{separator}{ptr:08x}:{hash}") + } + + let (left, right) = self.halfs(); + if self.first() & 0xC0 == 0x80 { + write_raw_key(fmtr, "", left) + } else { + write_raw_ptr(fmtr, "", left) + }?; + write_raw_ptr(fmtr, ":", right) + } +} + +fn write_binary( + fmtr: &mut core::fmt::Formatter, + mut separator: &str, + bytes: &[u8], +) -> core::fmt::Result { + for byte in bytes { + write!(fmtr, "{separator}{byte:02x}")?; + separator = "_"; + } + Ok(()) +} diff --git a/sealable-trie/src/nodes/stress_tests.rs b/sealable-trie/src/nodes/stress_tests.rs new file mode 100644 index 00000000..ba73d74e --- /dev/null +++ b/sealable-trie/src/nodes/stress_tests.rs @@ -0,0 +1,134 @@ +//! Random stress tests. They generate random data and perform round-trip +//! conversion checking if they result in the same output. +//! +//! The test may be slow, especially when run under MIRI. Number of iterations +//! it performs can be controlled by STRESS_TEST_ITERATIONS environment +//! variable. + +use pretty_assertions::assert_eq; + +use crate::memory::Ptr; +use crate::nodes::{self, Node, NodeRef, RawNode, Reference, ValueRef}; +use crate::test_utils::get_iteration_count; +use crate::{bits, stdx}; + +/// Generates random raw representation and checks decode→encode round-trip. +#[test] +fn stress_test_raw_encoding_round_trip() { + let mut rng = rand::thread_rng(); + let mut raw = RawNode([0; 72]); + for _ in 0..get_iteration_count() { + gen_random_raw_node(&mut rng, &mut raw.0); + let node = raw.decode(); + // Test RawNode→Node→RawNode round trip conversion. + assert_eq!(Ok(raw), node.encode(), "node: {node:?}"); + } +} + +/// Generates a random raw node representation in canonical representation. 
+fn gen_random_raw_node(rng: &mut impl rand::Rng, bytes: &mut [u8; 72]) { + fn make_ref_canonical(bytes: &mut [u8]) { + if bytes[0] & 0x40 == 0 { + // Node reference. Pointer can be non-zero. + bytes[0] &= !0x80; + } else { + // Value reference. Pointer must be zero but key is_sealed flag: + // 0b01s0_0000 + bytes[..4].copy_from_slice(&0x6000_0000_u32.to_be_bytes()); + } + } + + rng.fill(&mut bytes[..]); + let tag = bytes[0] >> 6; + if tag == 0 || tag == 1 { + // Branch. + make_ref_canonical(&mut bytes[..36]); + make_ref_canonical(&mut bytes[36..]); + } else if tag == 2 { + // Extension. Key must be valid and the most significant bit of + // the child must be zero. For the former it’s easiest to just + // regenerate random data. + + // Random length and offset for the key. + let offset = rng.gen::() % 8; + let max_length = (nodes::MAX_EXTENSION_KEY_SIZE * 8) as u16; + let length = rng.gen_range(1..=max_length - u16::from(offset)); + let tag = 0x8000 | (length << 3) | u16::from(offset); + bytes[..2].copy_from_slice(&tag.to_be_bytes()[..]); + + // Clear unused bits in the key. The easiest way to do it is by using + // bits::Slice. + let mut tmp = [0; 36]; + bits::Slice::new(&bytes[2..36], offset, length) + .unwrap() + .encode_into(&mut tmp, 0) + .unwrap(); + bytes[0..36].copy_from_slice(&tmp); + + make_ref_canonical(&mut bytes[36..]); + } else { + // Value. Most bits in the first four bytes must be zero and child must + // be a node reference. + bytes[0] &= 0xE0; + bytes[1] = 0; + bytes[2] = 0; + bytes[3] = 0; + bytes[36] &= !0xC0; + } +} + +// ============================================================================= + +/// Generates random node and tests encode→decode round trips. 
+#[test] +fn stress_test_node_encoding_round_trip() { + let mut rng = rand::thread_rng(); + let mut buf = [0; 66]; + for _ in 0..get_iteration_count() { + let node = gen_random_node(&mut rng, &mut buf); + + let raw = super::tests::raw_from_node(&node); + assert_eq!(node, raw.decode(), "Failed decoding Raw: {raw:?}"); + } +} + +/// Generates a random Node. +fn gen_random_node<'a>( + rng: &mut impl rand::Rng, + buf: &'a mut [u8; 66], +) -> Node<'a> { + fn rand_ref<'a>( + rng: &mut impl rand::Rng, + hash: &'a [u8; 32], + ) -> Reference<'a> { + let num = rng.gen::(); + if num < 0x8000_0000 { + Reference::node(Ptr::new(num).ok().flatten(), hash.into()) + } else { + Reference::value(num & 1 != 0, hash.into()) + } + } + + rng.fill(&mut buf[..]); + let (key, right) = stdx::split_array_ref::<34, 32, 66>(buf); + let (_, left) = stdx::split_array_ref::<2, 32, 34>(key); + match rng.gen_range(0..3) { + 0 => Node::branch(rand_ref(rng, &left), rand_ref(rng, &right)), + 1 => { + let offset = rng.gen::() % 8; + let max_length = (nodes::MAX_EXTENSION_KEY_SIZE * 8) as u16; + let length = rng.gen_range(1..=max_length - u16::from(offset)); + let key = bits::Slice::new(&key[..], offset, length).unwrap(); + Node::extension(key, rand_ref(rng, &right)) + } + 2 => { + let num = rng.gen::(); + let is_sealed = num & 0x8000_0000 != 0; + let value = ValueRef::new(is_sealed, left.into()); + let ptr = Ptr::new(num & 0x7FFF_FFFF).ok().flatten(); + let child = NodeRef::new(ptr, right.into()); + Node::value(value, child) + } + _ => unreachable!(), + } +} diff --git a/sealable-trie/src/nodes/tests.rs b/sealable-trie/src/nodes/tests.rs new file mode 100644 index 00000000..e1e8dd5e --- /dev/null +++ b/sealable-trie/src/nodes/tests.rs @@ -0,0 +1,251 @@ +use base64::engine::general_purpose::STANDARD as BASE64_ENGINE; +use base64::Engine; +use pretty_assertions::assert_eq; + +use crate::bits; +use crate::hash::CryptoHash; +use crate::memory::Ptr; +use crate::nodes::{Node, NodeRef, RawNode, Reference, 
ValueRef}; + +const DEAD: Ptr = match Ptr::new(0xDEAD) { + Ok(Some(ptr)) => ptr, + _ => panic!(), +}; +const BEEF: Ptr = match Ptr::new(0xBEEF) { + Ok(Some(ptr)) => ptr, + _ => panic!(), +}; + +const ONE: CryptoHash = CryptoHash([1; 32]); +const TWO: CryptoHash = CryptoHash([2; 32]); + +/// Converts `Node` into `RawNode` while also checking inverse conversion. +/// +/// Converts `Node` into `RawNode` and then back into `Node`. Panics if the +/// first and last objects aren’t equal. Returns the raw node. +#[track_caller] +pub(super) fn raw_from_node(node: &Node) -> RawNode { + let raw = node + .encode() + .unwrap_or_else(|()| panic!("Failed encoding node as raw: {node:?}")); + assert_eq!( + *node, + raw.decode(), + "Node → RawNode → Node gave different result:\n Raw: {raw:?}" + ); + raw +} + +/// Checks raw encoding of given node. +/// +/// 1. Encodes `node` into raw node node representation and compares the result +/// with expected `want` slices. +/// 2. Verifies Node→RawNode→Node round-trip conversion. +/// 3. Verifies that hash of the node equals the one provided. +/// 4. If node is an Extension, checks if slow path hash calculation produces +/// the same hash. +#[track_caller] +fn check_node_encoding(node: Node, want: [u8; 72], want_hash: &str) { + let raw = raw_from_node(&node); + assert_eq!(want, raw.0, "Unexpected raw representation"); + assert_eq!(node, RawNode(want).decode(), "Bad Raw→Node conversion"); + + let want_hash = BASE64_ENGINE.decode(want_hash).unwrap(); + let want_hash = <&[u8; 32]>::try_from(want_hash.as_slice()).unwrap(); + let want_hash = CryptoHash::from(*want_hash); + assert_eq!(want_hash, node.hash(), "Unexpected hash of {node:?}"); + + if let Node::Extension { key, child } = node { + let got = super::hash_extension_slow_path(key, &child); + assert_eq!(want_hash, got, "Unexpected slow path hash of {node:?}"); + } +} + +#[test] +#[rustfmt::skip] +fn test_branch_encoding() { + // Branch with two node children. 
+ check_node_encoding(Node::Branch { + children: [ + Reference::node(Some(DEAD), &ONE), + Reference::node(None, &TWO), + ], + }, [ + /* ptr1: */ 0, 0, 0xDE, 0xAD, + /* hash1: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr2: */ 0, 0, 0, 0, + /* hash2: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "MvstRBYGfFv/BkI+GHFK04hDZde4FtNKd7M1J9hDhiQ="); + + check_node_encoding(Node::Branch { + children: [ + Reference::node(None, &ONE), + Reference::node(Some(DEAD), &TWO), + ], + }, [ + /* ptr1: */ 0, 0, 0, 0, + /* hash1: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr2: */ 0, 0, 0xDE, 0xAD, + /* hash2: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "MvstRBYGfFv/BkI+GHFK04hDZde4FtNKd7M1J9hDhiQ="); + + // Branch with first child being a node and second being a value. + check_node_encoding(Node::Branch { + children: [ + Reference::node(Some(DEAD), &ONE), + Reference::value(false, &TWO), + ], + }, [ + /* ptr1: */ 0, 0, 0xDE, 0xAD, + /* hash1: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr2: */ 0x40, 0, 0, 0, + /* hash2: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "szHabsSdRUfZlCpnJ+USP2m+1aC5esFxz7/WIBQx/Po="); + + check_node_encoding(Node::Branch { + children: [ + Reference::node(None, &ONE), + Reference::value(true, &TWO), + ], + }, [ + /* ptr1: */ 0, 0, 0, 0, + /* hash1: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr2: */ 0x60, 0, 0, 0, + /* hash2: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "szHabsSdRUfZlCpnJ+USP2m+1aC5esFxz7/WIBQx/Po="); + + // Branch with first child being a value and second being a node. 
+ check_node_encoding(Node::Branch { + children: [ + Reference::value(true, &ONE), + Reference::node(Some(BEEF), &TWO), + ], + }, [ + /* ptr1: */ 0x60, 0, 0, 0, + /* hash1: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr2: */ 0, 0, 0xBE, 0xEF, + /* hash2: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "LGZgDJ1qtRlrhOX7OJQBVprw9OvP2sXOdj9Ow0xMQ18="); + + check_node_encoding(Node::Branch { + children: [ + Reference::value(false, &ONE), + Reference::node(None, &TWO), + ], + }, [ + /* ptr1: */ 0x40, 0, 0, 0, + /* hash1: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr2: */ 0, 0, 0, 0, + /* hash2: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "LGZgDJ1qtRlrhOX7OJQBVprw9OvP2sXOdj9Ow0xMQ18="); + + // Branch with both children being values. + check_node_encoding(Node::Branch { + children: [ + Reference::value(false, &ONE), + Reference::value(true, &TWO), + ], + }, [ + /* ptr1: */ 0x40, 0, 0, 0, + /* hash1: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr2: */ 0x60, 0, 0, 0, + /* hash2: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "O+AyRw5cqn52zppsf3w7xebru6xQ50qGvI7JgFQBNnE="); + + check_node_encoding(Node::Branch { + children: [ + Reference::value(true, &ONE), + Reference::value(false, &TWO), + ], + }, [ + /* ptr1: */ 0x60, 0, 0, 0, + /* hash1: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr2: */ 0x40, 0, 0, 0, + /* hash2: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "O+AyRw5cqn52zppsf3w7xebru6xQ50qGvI7JgFQBNnE="); +} + +#[test] +#[rustfmt::skip] +fn test_extension_encoding() { + // Extension pointing at a node + 
check_node_encoding(Node::Extension { + key: bits::Slice::new(&[0xFF; 34], 5, 25).unwrap(), + child: Reference::node(Some(DEAD), &ONE), + }, [ + /* tag: */ 0x80, 0xCD, + /* key: */ 0x07, 0xFF, 0xFF, 0xFC, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + /* ptr: */ 0, 0, 0xDE, 0xAD, + /* hash: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + ], "JnUeS7R/A/mp22ytw/gzGLu24zHArCmVZJoMm4bcqGY="); + + // Extension pointing at a sealed node + check_node_encoding(Node::Extension { + key: bits::Slice::new(&[0xFF; 34], 5, 25).unwrap(), + child: Reference::node(None, &ONE), + }, [ + /* tag: */ 0x80, 0xCD, + /* key: */ 0x07, 0xFF, 0xFF, 0xFC, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + /* ptr: */ 0, 0, 0, 0, + /* hash: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + ], "JnUeS7R/A/mp22ytw/gzGLu24zHArCmVZJoMm4bcqGY="); + + // Extension pointing at a value + check_node_encoding(Node::Extension { + key: bits::Slice::new(&[0xFF; 34], 4, 248).unwrap(), + child: Reference::value(false, &ONE), + }, [ + /* tag: */ 0x87, 0xC4, + /* key: */ 0x0F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + /* */ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xF0, + /* */ 0x00, 0x00, + /* ptr: */ 0x40, 0, 0, 0, + /* hash: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + ], "uU9GlH+fEQAnezn3HWuvo/ZSBIhuSkuE2IGjhUFdC04="); + + check_node_encoding(Node::Extension { + key: bits::Slice::new(&[0xFF; 34], 4, 248).unwrap(), + child: Reference::value(true, &ONE), + }, [ + /* tag: */ 0x87, 0xC4, + /* key: */ 0x0F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + /* */ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 
0xFF, 0xFF, 0xFF, 0xF0, + /* */ 0x00, 0x00, + /* ptr: */ 0x60, 0, 0, 0, + /* hash: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + ], "uU9GlH+fEQAnezn3HWuvo/ZSBIhuSkuE2IGjhUFdC04="); +} + +#[test] +#[rustfmt::skip] +fn test_value_encoding() { + check_node_encoding(Node::Value { + value: ValueRef::new(false, &ONE), + child: NodeRef::new(Some(BEEF), &TWO), + }, [ + /* tag: */ 0xC0, 0, 0, 0, + /* vhash: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr: */ 0, 0, 0xBE, 0xEF, + /* chash: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "1uLWUNQTQCTNVP3Wle2aK1vQlrOPXf9EC0J6TLl4hrY="); + + check_node_encoding(Node::Value { + value: ValueRef::new(true, &ONE), + child: NodeRef::new(Some(BEEF), &TWO), + }, [ + /* tag: */ 0xE0, 0, 0, 0, + /* vhash: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr: */ 0, 0, 0xBE, 0xEF, + /* chash: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "1uLWUNQTQCTNVP3Wle2aK1vQlrOPXf9EC0J6TLl4hrY="); + + check_node_encoding(Node::Value { + value: ValueRef::new(true, &ONE), + child: NodeRef::new(None, &TWO), + }, [ + /* tag: */ 0xE0, 0, 0, 0, + /* vhash: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr: */ 0, 0, 0, 0, + /* chash: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "1uLWUNQTQCTNVP3Wle2aK1vQlrOPXf9EC0J6TLl4hrY="); +} diff --git a/sealable-trie/src/proof.rs b/sealable-trie/src/proof.rs new file mode 100644 index 00000000..922a884c --- /dev/null +++ b/sealable-trie/src/proof.rs @@ -0,0 +1,399 @@ +use alloc::boxed::Box; +use alloc::vec::Vec; +use core::num::NonZeroU16; + +use crate::bits; +use crate::hash::CryptoHash; +use crate::nodes::{Node, NodeRef, 
Reference, ValueRef}; + +/// A proof of a membership or non-membership of a key. +/// +/// The proof doesn’t include the key or value (in case of existence proofs). +/// It’s caller responsibility to pair proof with correct key and value. +#[derive(Clone, Debug, derive_more::From)] +pub enum Proof { + Positive(Membership), + Negative(NonMembership), +} + +/// A proof of a membership of a key. +#[derive(Clone, Debug)] +pub struct Membership(Vec); + +/// A proof of a membership of a key. +#[derive(Clone, Debug)] +pub struct NonMembership(Option>, Vec); + +/// A single item in a proof corresponding to a node in the trie. +#[derive(Clone, Debug)] +pub(crate) enum Item { + /// A Branch node where the other child is given reference. + Branch(OwnedRef), + /// An Extension node whose key has given length in bits. + Extension(NonZeroU16), + /// A Value node. + Value(CryptoHash), +} + +/// For non-membership proofs, description of the condition at which the lookup +/// failed. +#[derive(Clone, Debug)] +pub(crate) enum Actual { + /// A Branch node that has been reached at given key. + Branch(OwnedRef, OwnedRef), + + /// Length of the lookup key remaining after reaching given Extension node + /// whose key doesn’t match the lookup key. + Extension(u16, Box<[u8]>, OwnedRef), + + /// Length of the lookup key remaining after reaching a value reference with + /// given value hash. + LookupKeyLeft(NonZeroU16, CryptoHash), +} + +/// A reference to value or node. +#[derive(Clone, Debug)] +pub(crate) struct OwnedRef { + /// Whether the reference is for a value (rather than value). + is_value: bool, + /// Hash of the node or value the reference points at. + hash: CryptoHash, +} + +/// Builder for the proof. +pub(crate) struct Builder(Vec); + +impl Proof { + /// Verifies that this object proves membership or non-membership of given + /// key. + /// + /// If `value_hash` is `None`, verifies a non-membership proof. 
That is, + /// that this object proves that given `root_hash` there’s no value at + /// specified `key`. + /// + /// Otherwise, verifies a membership-proof. That is, that this object + /// proves that given `root_hash` there’s given `value_hash` stored at + /// specified `key`. + pub fn verify( + &self, + root_hash: &CryptoHash, + key: &[u8], + value_hash: Option<&CryptoHash>, + ) -> bool { + match (self, value_hash) { + (Self::Positive(proof), Some(hash)) => { + proof.verify(root_hash, key, hash) + } + (Self::Negative(proof), None) => proof.verify(root_hash, key), + _ => false, + } + } + + /// Creates a non-membership proof for cases when trie is empty. + pub(crate) fn empty_trie() -> Proof { + NonMembership(None, Vec::new()).into() + } + + /// Creates a builder which allows creation of proofs. + pub(crate) fn builder() -> Builder { Builder(Vec::new()) } +} + +impl Membership { + /// Verifies that this object proves membership of a given key. + pub fn verify( + &self, + root_hash: &CryptoHash, + key: &[u8], + value_hash: &CryptoHash, + ) -> bool { + if *root_hash == crate::trie::EMPTY_TRIE_ROOT { + false + } else if let Some(key) = bits::Slice::from_bytes(key) { + let want = OwnedRef::value(value_hash.clone()); + verify_impl(root_hash, key, want, &self.0).is_some() + } else { + false + } + } +} + +impl NonMembership { + /// Verifies that this object proves non-membership of a given key. + pub fn verify(&self, root_hash: &CryptoHash, key: &[u8]) -> bool { + if *root_hash == crate::trie::EMPTY_TRIE_ROOT { + true + } else if let Some((key, want)) = self.get_reference(key) { + verify_impl(root_hash, key, want, &self.1).is_some() + } else { + false + } + } + + /// Figures out reference to prove. + /// + /// For non-membership proofs, the proofs include the actual node that has + /// been found while looking up the key. This translates that information + /// into a key and reference that the rest of the commitment needs to prove. 
+ fn get_reference<'a>( + &self, + key: &'a [u8], + ) -> Option<(bits::Slice<'a>, OwnedRef)> { + let mut key = bits::Slice::from_bytes(key)?; + match self.0.as_deref()? { + Actual::Branch(lft, rht) => { + // When traversing the trie, we’ve reached a Branch node at the + // lookup key. Lookup key is therefore a prefix of an existing + // value but there’s no value stored at it. + // + // We’re converting non-membership proof into proof that at key + // the given branch Node exists. + let node = Node::Branch { children: [lft.into(), rht.into()] }; + Some((key, OwnedRef::to(node))) + } + + Actual::Extension(left, key_buf, child) => { + // When traversing the trie, we’ve reached an Extension node + // whose key wasn’t a prefix of a lookup key. This could be + // because the extension key was longer or because some bits + // didn’t match. + // + // The first element specifies how many bits of the lookup key + // were left in it when the Extension node has been reached. + // + // We’re converting non-membership proof into proof that at + // shortened key the given Extension node exists. + let suffix = key.pop_back_slice(*left)?; + let ext_key = bits::Slice::decode(&key_buf[..], 0)?; + if suffix.starts_with(ext_key) { + // If key in the Extension node is a prefix of the + // remaining suffix of the lookup key, the proof is + // invalid. + None + } else { + let node = Node::Extension { + key: ext_key, + child: Reference::from(child), + }; + Some((key, OwnedRef::to(node))) + } + } + + Actual::LookupKeyLeft(len, hash) => { + // When traversing the trie, we’ve encountered a value reference + // before the lookup key has finished. `len` determines how + // many bits of the lookup key were not processed. `hash` is + // the value that was found at key[..(key.len() - len)] key. + // + // We’re converting non-membership proof into proof that at + // key[..(key.len() - len)] a `hash` value is stored. 
+ key.pop_back_slice(len.get())?; + Some((key, OwnedRef::value(hash.clone()))) + } + } + } +} + +fn verify_impl( + root_hash: &CryptoHash, + mut key: bits::Slice, + mut want: OwnedRef, + proof: &[Item], +) -> Option<()> { + for item in proof { + let node = match item { + Item::Value(child) if want.is_value => Node::Value { + value: ValueRef::new((), &want.hash), + child: NodeRef::new((), child), + }, + Item::Value(child) => Node::Value { + value: ValueRef::new((), child), + child: NodeRef::new((), &want.hash), + }, + + Item::Branch(child) => { + let us = Reference::from(&want); + let them = child.into(); + let children = match key.pop_back()? { + false => [us, them], + true => [them, us], + }; + Node::Branch { children } + } + + Item::Extension(length) => Node::Extension { + key: key.pop_back_slice(length.get())?, + child: Reference::from(&want), + }, + }; + want = OwnedRef::to(node); + } + + // If we’re here we’ve reached root hash according to the proof. Check the + // key is empty and that hash we’ve calculated is the actual root. + (key.is_empty() && !want.is_value && *root_hash == want.hash).then_some(()) +} + +impl Item { + /// Constructs a new proof item corresponding to given branch. + /// + /// `us` indicates which branch is ours and which one is theirs. When + /// verifying a proof our hash can be computed thus the proof item will only + /// include their hash. + pub fn branch(us: bool, children: &[Reference; 2]) -> Self { + Self::Branch((&children[1 - usize::from(us)]).clone().into()) + } + + pub fn extension(length: u16) -> Option { + NonZeroU16::new(length).map(Self::Extension) + } +} + +impl Builder { + /// Adds a new item to the proof. + pub fn push(&mut self, item: Item) { self.0.push(item); } + + /// Reverses order of items in the builder. + /// + /// The items in the proof must be ordered from the node with the value + /// first with root node as the last entry. When traversing the trie nodes + /// may end up being added in opposite order. 
In those cases, this can be + /// used to reverse order of items so they are in the correct order. + pub fn reversed(mut self) -> Self { + self.0.reverse(); + self + } + + /// Constructs a new membership proof from added items. + pub fn build>(self) -> T { T::from(Membership(self.0)) } + + /// Constructs a new non-membership proof from added items and given + /// ‘actual’ entry. + /// + /// The actual describes what was actually found when traversing the trie. + pub fn negative>(self, actual: Actual) -> T { + T::from(NonMembership(Some(Box::new(actual)), self.0)) + } + + /// Creates a new non-membership proof after lookup reached a Branch node. + /// + /// If a Branch node has been found at the lookup key (rather than Value + /// node), this method allows creation of a non-membership proof. + /// `children` specifies children of the encountered Branch node. + pub fn reached_branch, P, S>( + self, + children: [Reference; 2], + ) -> T { + let [lft, rht] = children; + self.negative(Actual::Branch(lft.into(), rht.into())) + } + + /// Creates a new non-membership proof after lookup reached a Extension node. + /// + /// If a Extension node has been found which doesn’t match corresponding + /// portion of the lookup key (the extension key may be too long or just not + /// match it), this method allows creation of a non-membership proof. + /// + /// `left` is the number of bits left in the lookup key at the moment the + /// Extension node was encountered. `key` and `child` are corresponding + /// fields of the extension node. + pub fn reached_extension>( + self, + left: u16, + key: bits::Slice, + child: Reference, + ) -> T { + let mut buf = [0; 36]; + let len = key.encode_into(&mut buf, 0).unwrap(); + let ext_key = buf[..len].to_vec().into_boxed_slice(); + self.negative(Actual::Extension(left, ext_key, child.into())) + } + + /// Creates a new non-membership proof after lookup reached a value + /// reference. 
+ /// + /// If the lookup key hasn’t terminated yet but a value reference has been + /// found, , this method allows creation of a non-membership proof. + /// + /// `left` is the number of bits left in the lookup key at the moment the + /// reference was encountered. `value` is the hash of the value from the + /// reference. + pub fn lookup_key_left>( + self, + left: NonZeroU16, + value: CryptoHash, + ) -> T { + self.negative(Actual::LookupKeyLeft(left, value)) + } +} + +impl OwnedRef { + /// Creates a reference pointing at node with given hash. + fn node(hash: CryptoHash) -> Self { Self { is_value: false, hash } } + /// Creates a reference pointing at value with given hash. + fn value(hash: CryptoHash) -> Self { Self { is_value: true, hash } } + /// Creates a reference pointing at given node. + fn to(node: Node) -> Self { Self::node(node.hash()) } +} + +impl<'a, P, S> From<&'a Reference<'a, P, S>> for OwnedRef { + fn from(rf: &'a Reference<'a, P, S>) -> OwnedRef { + let (is_value, hash) = match rf { + Reference::Node(node) => (false, node.hash.clone()), + Reference::Value(value) => (true, value.hash.clone()), + }; + Self { is_value, hash } + } +} + +impl<'a, P, S> From> for OwnedRef { + fn from(rf: Reference<'a, P, S>) -> OwnedRef { Self::from(&rf) } +} + +impl<'a> From<&'a OwnedRef> for Reference<'a, (), ()> { + fn from(rf: &'a OwnedRef) -> Self { Self::new(rf.is_value, &rf.hash) } +} + +#[test] +fn test_simple_success() { + let mut trie = crate::trie::Trie::test(1000); + let some_hash = CryptoHash::test(usize::MAX); + + for (idx, key) in ["foo", "bar", "baz", "qux"].into_iter().enumerate() { + let hash = CryptoHash::test(idx); + + assert_eq!( + Ok(()), + trie.set(key.as_bytes(), &hash), + "Failed setting {key} → {hash}", + ); + + let (got, proof) = trie.prove(key.as_bytes()).unwrap(); + assert_eq!(Some(hash.clone()), got, "Failed getting {key}"); + assert!( + proof.verify(trie.hash(), key.as_bytes(), Some(&hash)), + "Failed verifying {key} → {hash} get proof: 
{proof:?}", + ); + + assert!( + !proof.verify(trie.hash(), key.as_bytes(), None), + "Unexpectedly succeeded {key} → (none) proof: {proof:?}", + ); + assert!( + !proof.verify(trie.hash(), key.as_bytes(), Some(&some_hash)), + "Unexpectedly succeeded {key} → {some_hash} proof: {proof:?}", + ); + } + + for key in ["Foo", "fo", "ba", "bay", "foobar"] { + let (got, proof) = trie.prove(key.as_bytes()).unwrap(); + assert_eq!(None, got, "Unexpected result when getting {key}"); + assert!( + proof.verify(trie.hash(), key.as_bytes(), None), + "Failed verifying {key} → (none) proof: {proof:?}", + ); + assert!( + !proof.verify(trie.hash(), key.as_bytes(), Some(&some_hash)), + "Unexpectedly succeeded {key} → {some_hash} proof: {proof:?}", + ); + } +} diff --git a/sealable-trie/src/stdx.rs b/sealable-trie/src/stdx.rs new file mode 100644 index 00000000..5431d01d --- /dev/null +++ b/sealable-trie/src/stdx.rs @@ -0,0 +1,53 @@ +/// Splits `&[u8; L + R]` into `(&[u8; L], &[u8; R])`. +pub(crate) fn split_array_ref< + const L: usize, + const R: usize, + const N: usize, +>( + xs: &[u8; N], +) -> (&[u8; L], &[u8; R]) { + let () = AssertEqSum::::OK; + + let (left, right) = xs.split_at(L); + (left.try_into().unwrap(), right.try_into().unwrap()) +} + +/// Splits `&mut [u8; L + R]` into `(&mut [u8; L], &mut [u8; R])`. +pub(crate) fn split_array_mut< + const L: usize, + const R: usize, + const N: usize, +>( + xs: &mut [u8; N], +) -> (&mut [u8; L], &mut [u8; R]) { + let () = AssertEqSum::::OK; + + let (left, right) = xs.split_at_mut(L); + (left.try_into().unwrap(), right.try_into().unwrap()) +} + +/// Splits `&[u8]` into `(&[u8; L], &[u8])`. Returns `None` if input is too +/// shorter. +pub(crate) fn split_at(xs: &[u8]) -> Option<(&[u8; L], &[u8])> { + if xs.len() < L { + return None; + } + let (head, tail) = xs.split_at(L); + Some((head.try_into().unwrap(), tail)) +} + +/// Splits `&[u8]` into `(&[u8], &[u8; R])`. Returns `None` if input is too +/// shorter. 
+#[allow(dead_code)] +pub(crate) fn rsplit_at( + xs: &[u8], +) -> Option<(&[u8], &[u8; R])> { + let (head, tail) = xs.split_at(xs.len().checked_sub(R)?); + Some((head, tail.try_into().unwrap())) +} + +/// Asserts, at compile time, that `A + B == S`. +struct AssertEqSum; +impl AssertEqSum { + const OK: () = assert!(S == A + B); +} diff --git a/sealable-trie/src/test_utils.rs b/sealable-trie/src/test_utils.rs new file mode 100644 index 00000000..49fc317c --- /dev/null +++ b/sealable-trie/src/test_utils.rs @@ -0,0 +1,23 @@ +/// Reads `STRESS_TEST_ITERATIONS` environment variable to determine how many +/// iterations random tests should try. +/// +/// The variable is used by stress tests which generate random data to verify +/// invariant. By default they run hundred thousand iterations. The +/// aforementioned environment variable allows that number to be changed +/// (including to zero which effectively disables such tests). +pub(crate) fn get_iteration_count() -> usize { + use core::str::FromStr; + match std::env::var_os("STRESS_TEST_ITERATIONS") { + None => 100_000, + Some(val) => usize::from_str(val.to_str().unwrap()).unwrap(), + } +} + +/// Returns zero if dividend is zero otherwise `max(dividend / divisor, 1)`. +pub(crate) fn div_max_1(dividind: usize, divisor: usize) -> usize { + if dividind == 0 { + 0 + } else { + 1.max(dividind / divisor) + } +} diff --git a/sealable-trie/src/trie.rs b/sealable-trie/src/trie.rs new file mode 100644 index 00000000..396430c8 --- /dev/null +++ b/sealable-trie/src/trie.rs @@ -0,0 +1,317 @@ +use core::num::NonZeroU16; + +use crate::hash::CryptoHash; +use crate::memory::Ptr; +use crate::nodes::{Node, NodeRef, Reference}; +use crate::{bits, memory, proof}; + +mod seal; +mod set; +#[cfg(test)] +mod tests; + +/// Root trie hash if the trie is empty. +pub const EMPTY_TRIE_ROOT: CryptoHash = CryptoHash::DEFAULT; + +/// A Merkle Patricia Trie with sealing/pruning feature. 
+/// +/// The trie is designed to work in situations where space is constrained. To +/// that effect, it imposes certain limitations and implements feature which +/// help reduce its size. +/// +/// In the abstract, the trie is a regular Merkle Patricia Trie which allows +/// storing arbitrary (key, value) pairs. However: +/// +/// 1. The trie doesn’t actually store values but only their hashes. (Another +/// way to think about it is that all values are 32-byte long byte slices). +/// It’s assumed that clients store values in another location and use this +/// data structure only as a witness. Even though it doesn’t contain the +/// values it can generate proof of existence or non-existence of keys. +/// +/// 2. The trie allows values to be sealed. A hash of a sealed value can no +/// longer be accessed even though in abstract sense the value still resides +/// in the trie. That is, sealing a value doesn’t affect the state root +/// hash and old proofs for the value continue to be valid. +/// +/// Nodes of sealed values are removed from the trie to save storage. +/// Furthermore, if a children of an internal node have been sealed, that +/// node becomes sealed as well. For example, if keys `a` and `b` has +/// both been sealed, than branch node above them becomes sealed as well. +/// +/// To take most benefits from sealing, it’s best to seal consecutive keys. +/// For example, sealing keys `a`, `b`, `c` and `d` will seal their parents +/// as well while sealing keys `a`, `c`, `e` and `g` will leave their parents +/// unsealed and thus kept in the trie. +/// +/// 3. The trie is designed to work with a pool allocator and supports keeping +/// at most 2³⁰-2 nodes. Sealed values don’t count towards this limit since +/// they aren’t stored. In any case, this should be plenty since fully +/// balanced binary tree with that many nodes allows storing 500K keys. +/// +/// 4. 
Keys are limited to 8191 bytes (technically 2¹⁶-1 bits but there’s no +/// interface for keys which hold partial bytes). It would be possible to +/// extend this limit but 8k bytes should be plenty for any reasonable usage. +/// +/// As an optimisation to take advantage of trie’s internal structure, it’s +/// best to keep keys up to 36-byte long. Or at least, to keep common key +/// prefixes to be at most 36-byte long. For example, a trie which has +/// a single value at a key whose length is withing 36 bytes has a single +/// node however if that key is longer than 36 bytes the trie needs at least +/// two nodes. +pub struct Trie { + /// Pointer to the root node. `None` if the trie is empty or the root node + /// has been sealed. + root_ptr: Option, + + /// Hash of the root node; [`EMPTY_TRIE_ROOT`] if trie is empty. + root_hash: CryptoHash, + + /// Allocator used to access and allocate nodes. + alloc: A, +} + +/// Possible errors when reading or modifying the trie. +#[derive(Copy, Clone, PartialEq, Eq, Debug, derive_more::Display)] +pub enum Error { + #[display(fmt = "Tried to access empty key")] + EmptyKey, + #[display(fmt = "Key longer than 8191 bytes")] + KeyTooLong, + #[display(fmt = "Tried to access sealed node")] + Sealed, + #[display(fmt = "Value not found")] + NotFound, + #[display(fmt = "Not enough space")] + OutOfMemory, +} + +impl From for Error { + fn from(_: memory::OutOfMemory) -> Self { Self::OutOfMemory } +} + +type Result = ::core::result::Result; + +macro_rules! proof { + ($proof:ident push $item:expr) => { + $proof.as_mut().map(|proof| proof.push($item)); + }; + ($proof:ident rev) => { + $proof.map(|builder| builder.reversed().build()) + }; + ($proof:ident rev .$func:ident $($tt:tt)*) => { + $proof.map(|builder| builder.reversed().$func $($tt)*) + }; +} + +impl Trie { + /// Creates a new trie using given allocator. 
+ pub fn new(alloc: A) -> Self { + Self { root_ptr: None, root_hash: EMPTY_TRIE_ROOT, alloc } + } + + /// Returns hash of the root node. + pub fn hash(&self) -> &CryptoHash { &self.root_hash } + + /// Returns whether the trie is empty. + pub fn is_empty(&self) -> bool { self.root_hash == EMPTY_TRIE_ROOT } + + /// Retrieves value at given key. + /// + /// Returns `None` if there’s no value at given key. Returns an error if + /// the value (or its ancestor) has been sealed. + pub fn get(&self, key: &[u8]) -> Result> { + let (value, _) = self.get_impl(key, true)?; + Ok(value) + } + + /// Retrieves value at given key and provides proof of the result. + /// + /// Returns `None` if there’s no value at given key. Returns an error if + /// the value (or its ancestor) has been sealed. + pub fn prove( + &self, + key: &[u8], + ) -> Result<(Option, proof::Proof)> { + let (value, proof) = self.get_impl(key, true)?; + Ok((value, proof.unwrap())) + } + + fn get_impl( + &self, + key: &[u8], + include_proof: bool, + ) -> Result<(Option, Option)> { + let mut key = bits::Slice::from_bytes(key).ok_or(Error::KeyTooLong)?; + if self.root_hash == EMPTY_TRIE_ROOT { + let proof = include_proof.then(|| proof::Proof::empty_trie()); + return Ok((None, proof)); + } + + let mut proof = include_proof.then(|| proof::Proof::builder()); + let mut node_ptr = self.root_ptr; + let mut node_hash = self.root_hash.clone(); + loop { + let node = self.alloc.get(node_ptr.ok_or(Error::Sealed)?); + let node = node.decode(); + debug_assert_eq!(node_hash, node.hash()); + + let child = match node { + Node::Branch { children } => { + if let Some(us) = key.pop_front() { + proof!(proof push proof::Item::branch(us, &children)); + children[usize::from(us)] + } else { + let proof = proof!(proof rev.reached_branch(children)); + return Ok((None, proof)); + } + } + + Node::Extension { key: ext_key, child } => { + if key.strip_prefix(ext_key) { + proof!(proof push proof::Item::extension(ext_key.len()).unwrap()); + child + 
} else { + let proof = proof!(proof rev.reached_extension(key.len(), ext_key, child)); + return Ok((None, proof)); + } + } + + Node::Value { value, child } => { + if value.is_sealed { + return Err(Error::Sealed); + } else if key.is_empty() { + proof!(proof push proof::Item::Value(child.hash.clone())); + let proof = proof!(proof rev.build()); + return Ok((Some(value.hash.clone()), proof)); + } else { + proof!(proof push proof::Item::Value(value.hash.clone())); + node_ptr = child.ptr; + node_hash = child.hash.clone(); + continue; + } + } + }; + + match child { + Reference::Node(node) => { + node_ptr = node.ptr; + node_hash = node.hash.clone(); + } + Reference::Value(value) => { + return if value.is_sealed { + Err(Error::Sealed) + } else if let Some(len) = NonZeroU16::new(key.len()) { + let proof = proof!(proof rev.lookup_key_left(len, value.hash.clone())); + Ok((None, proof)) + } else { + let proof = proof!(proof rev.build()); + Ok((Some(value.hash.clone()), proof)) + }; + } + }; + } + } + + /// Inserts a new value hash at given key. + /// + /// Sets value hash at given key to given to the provided one. If the value + /// (or one of its ancestors) has been sealed the operation fails with + /// [`Error::Sealed`] error. + /// + /// If `proof` is specified, stores proof nodes into the provided vector. + // TODO(mina86): Add set_with_proof as well as set_and_seal and + // set_and_seal_with_proof. + pub fn set(&mut self, key: &[u8], value_hash: &CryptoHash) -> Result<()> { + let (ptr, hash) = (self.root_ptr, self.root_hash.clone()); + let key = bits::Slice::from_bytes(key).ok_or(Error::KeyTooLong)?; + let (ptr, hash) = + set::SetContext::new(&mut self.alloc, key, value_hash) + .set(ptr, &hash)?; + self.root_ptr = Some(ptr); + self.root_hash = hash; + Ok(()) + } + + /// Seals value at given key as well as all descendant values. + /// + /// Once value is sealed, its hash can no longer be retrieved nor can it be + /// changed. 
Sealing a value seals the entire subtrie rooted at the key + /// (that is, if key `foo` is sealed, `foobar` is also sealed). + /// + /// However, it’s impossible to seal a subtrie unless there’s a value stored + /// at the key. For example, if trie contains key `foobar` only, neither + /// `foo` nor `qux` keys can be sealed. In those cases, function returns + /// an error. + // TODO(mina86): Add seal_with_proof. + pub fn seal(&mut self, key: &[u8]) -> Result<()> { + let key = bits::Slice::from_bytes(key).ok_or(Error::KeyTooLong)?; + if self.root_hash == EMPTY_TRIE_ROOT { + return Err(Error::NotFound); + } + + let seal = seal::SealContext::new(&mut self.alloc, key) + .seal(NodeRef::new(self.root_ptr, &self.root_hash))?; + if seal { + self.root_ptr = None; + } + Ok(()) + } + + /// Prints the trie. Used for testing and debugging only. + #[cfg(test)] + pub(crate) fn print(&self) { + use std::println; + + if self.root_hash == EMPTY_TRIE_ROOT { + println!("(empty)"); + } else { + self.print_impl(NodeRef::new(self.root_ptr, &self.root_hash), 0); + } + } + + #[cfg(test)] + fn print_impl(&self, nref: NodeRef, depth: usize) { + use std::{print, println}; + + let print_ref = |rf, depth| match rf { + Reference::Node(node) => self.print_impl(node, depth), + Reference::Value(value) => { + let is_sealed = if value.is_sealed { " (sealed)" } else { "" }; + println!("{:depth$}value {}{}", "", value.hash, is_sealed) + } + }; + + print!("{:depth$}«{}»", "", nref.hash); + let ptr = if let Some(ptr) = nref.ptr { + ptr + } else { + println!(" (sealed)"); + return; + }; + match self.alloc.get(ptr).decode() { + Node::Branch { children } => { + println!(" Branch"); + print_ref(children[0], depth + 2); + print_ref(children[1], depth + 2); + } + Node::Extension { key, child } => { + println!(" Extension {key}"); + print_ref(child, depth + 2); + } + Node::Value { value, child } => { + let is_sealed = if value.is_sealed { " (sealed)" } else { "" }; + println!(" Value {}{}", value.hash, 
is_sealed); + print_ref(Reference::from(child), depth + 2); + } + } + } +} + + +#[cfg(test)] +impl Trie { + /// Creates a test trie using a TestAllocator with given capacity. + pub(crate) fn test(capacity: usize) -> Self { + Self::new(memory::test_utils::TestAllocator::new(capacity)) + } +} diff --git a/sealable-trie/src/trie/seal.rs b/sealable-trie/src/trie/seal.rs new file mode 100644 index 00000000..50e68fdb --- /dev/null +++ b/sealable-trie/src/trie/seal.rs @@ -0,0 +1,175 @@ +use alloc::vec::Vec; + +use super::{Error, Result}; +use crate::memory::Ptr; +use crate::nodes::{Node, NodeRef, RawNode, Reference, ValueRef}; +use crate::{bits, memory}; + +/// Context for [`Trie::seal`] operation. +pub(super) struct SealContext<'a, A> { + /// Part of the key yet to be traversed. + /// + /// It starts as the key user provided and as trie is traversed bits are + /// removed from its front. + key: bits::Slice<'a>, + + /// Allocator used to retrieve and free nodes. + alloc: &'a mut A, +} + +impl<'a, A: memory::Allocator> SealContext<'a, A> { + pub(super) fn new(alloc: &'a mut A, key: bits::Slice<'a>) -> Self { + Self { key, alloc } + } + + /// Traverses the trie starting from node `ptr` to find node at context’s + /// key and seals it. + /// + /// Returns `true` if node at `ptr` has been sealed. This lets caller know + /// that `ptr` has been freed and it has to update references to it. 
+ pub(super) fn seal(&mut self, nref: NodeRef) -> Result { + let ptr = nref.ptr.ok_or(Error::Sealed)?; + let node = self.alloc.get(ptr); + let node = node.decode(); + debug_assert_eq!(*nref.hash, node.hash()); + + let result = match node { + Node::Branch { children } => self.seal_branch(children), + Node::Extension { key, child } => self.seal_extension(key, child), + Node::Value { value, child } => self.seal_value(value, child), + }?; + + match result { + SealResult::Replace(node) => { + self.alloc.set(ptr, node); + Ok(false) + } + SealResult::Free => { + self.alloc.free(ptr); + Ok(true) + } + SealResult::Done => Ok(false), + } + } + + fn seal_branch( + &mut self, + mut children: [Reference; 2], + ) -> Result { + let side = match self.key.pop_front() { + None => return Err(Error::NotFound), + Some(bit) => usize::from(bit), + }; + match self.seal_child(children[side])? { + None => Ok(SealResult::Done), + Some(_) if children[1 - side].is_sealed() => Ok(SealResult::Free), + Some(child) => { + children[side] = child; + let node = RawNode::branch(children[0], children[1]); + Ok(SealResult::Replace(node)) + } + } + } + + fn seal_extension( + &mut self, + ext_key: bits::Slice, + child: Reference, + ) -> Result { + if !self.key.strip_prefix(ext_key) { + return Err(Error::NotFound); + } + Ok(if let Some(child) = self.seal_child(child)? { + let node = RawNode::extension(ext_key, child).unwrap(); + SealResult::Replace(node) + } else { + SealResult::Done + }) + } + + fn seal_value( + &mut self, + value: ValueRef, + child: NodeRef, + ) -> Result { + if value.is_sealed { + Err(Error::Sealed) + } else if self.key.is_empty() { + prune(self.alloc, child.ptr); + Ok(SealResult::Free) + } else if self.seal(child)? 
{ + let child = NodeRef::new(None, child.hash); + let node = RawNode::value(value, child); + Ok(SealResult::Replace(node)) + } else { + Ok(SealResult::Done) + } + } + + fn seal_child<'b>( + &mut self, + child: Reference<'b>, + ) -> Result>> { + match child { + Reference::Node(node) => Ok(if self.seal(node)? { + Some(Reference::Node(node.sealed())) + } else { + None + }), + Reference::Value(value) => { + if value.is_sealed { + Err(Error::Sealed) + } else if self.key.is_empty() { + Ok(Some(value.sealed().into())) + } else { + Err(Error::NotFound) + } + } + } + } +} + +enum SealResult { + Free, + Replace(RawNode), + Done, +} + +/// Frees node and all its descendants from the allocator. +fn prune(alloc: &mut impl memory::Allocator, ptr: Option) { + let mut ptr = match ptr { + Some(ptr) => ptr, + None => return, + }; + let mut queue = Vec::new(); + loop { + let children = get_children(&alloc.get(ptr)); + alloc.free(ptr); + match children { + (None, None) => match queue.pop() { + Some(p) => ptr = p, + None => break, + }, + (Some(p), None) | (None, Some(p)) => ptr = p, + (Some(lhs), Some(rht)) => { + queue.push(lhs); + ptr = rht + } + } + } +} + +fn get_children(node: &RawNode) -> (Option, Option) { + fn get_ptr(child: Reference) -> Option { + match child { + Reference::Node(node) => node.ptr, + Reference::Value { .. } => None, + } + } + + match node.decode() { + Node::Branch { children: [lft, rht] } => (get_ptr(lft), get_ptr(rht)), + Node::Extension { child, .. } => (get_ptr(child), None), + Node::Value { child, .. } => (child.ptr, None), + } +} diff --git a/sealable-trie/src/trie/set.rs b/sealable-trie/src/trie/set.rs new file mode 100644 index 00000000..2a4cbd39 --- /dev/null +++ b/sealable-trie/src/trie/set.rs @@ -0,0 +1,344 @@ +use super::{Error, Result}; +use crate::hash::CryptoHash; +use crate::memory::Ptr; +use crate::nodes::{Node, NodeRef, RawNode, Reference, ValueRef}; +use crate::{bits, memory}; + +/// Context for [`Trie::set`] operation. 
+pub(super) struct SetContext<'a, A: memory::Allocator> { + /// Part of the key yet to be traversed. + /// + /// It starts as the key user provided and as trie is traversed bits are + /// removed from its front. + key: bits::Slice<'a>, + + /// Hash to insert into the trie. + value_hash: &'a CryptoHash, + + /// Allocator used to allocate new nodes. + wlog: memory::WriteLog<'a, A>, +} + +impl<'a, A: memory::Allocator> SetContext<'a, A> { + pub(super) fn new( + alloc: &'a mut A, + key: bits::Slice<'a>, + value_hash: &'a CryptoHash, + ) -> Self { + let wlog = memory::WriteLog::new(alloc); + Self { key, value_hash, wlog } + } + + /// Inserts value hash into the trie. + pub(super) fn set( + mut self, + root_ptr: Option, + root_hash: &CryptoHash, + ) -> Result<(Ptr, CryptoHash)> { + let res = (|| { + if let Some(ptr) = root_ptr { + // Trie is non-empty, handle normally. + self.handle(NodeRef { ptr: Some(ptr), hash: root_hash }) + } else if *root_hash != super::EMPTY_TRIE_ROOT { + // Trie is sealed (it’s not empty but ptr is None). + Err(Error::Sealed) + } else if let OwnedRef::Node(ptr, hash) = self.insert_value()? { + // Trie is empty and we’ve just inserted Extension leading to + // the value. + Ok((ptr, hash)) + } else { + // If the key was non-empty, self.insert_value would have + // returned a node reference. If it didn’t, it means key was + // empty which is an error condition. + Err(Error::EmptyKey) + } + })(); + if res.is_ok() { + self.wlog.commit(); + } + res + } + + /// Inserts value into the trie starting at node pointed by given reference. 
+ fn handle(&mut self, nref: NodeRef) -> Result<(Ptr, CryptoHash)> { + let nref = (nref.ptr.ok_or(Error::Sealed)?, nref.hash); + let node = self.wlog.allocator().get(nref.0); + let node = node.decode(); + debug_assert_eq!(*nref.1, node.hash()); + match node { + Node::Branch { children } => self.handle_branch(nref, children), + Node::Extension { key, child } => { + self.handle_extension(nref, key, child) + } + Node::Value { value, child } => { + self.handle_value(nref, value, child) + } + } + } + + /// Inserts value assuming current node is a Branch with given children. + fn handle_branch( + &mut self, + nref: (Ptr, &CryptoHash), + children: [Reference<'_>; 2], + ) -> Result<(Ptr, CryptoHash)> { + let bit = if let Some(bit) = self.key.pop_front() { + bit + } else { + // If Key is empty, insert a new Node value with this node as + // a child. + return self.alloc_value_node(self.value_hash, nref.0, nref.1); + }; + + // Figure out which direction the key leads and update the node + // in-place. + let owned_ref = self.handle_reference(children[usize::from(bit)])?; + let child = owned_ref.to_ref(); + let children = + if bit { [children[0], child] } else { [child, children[1]] }; + Ok(self.set_node(nref.0, RawNode::branch(children[0], children[1]))) + // let child = owned_ref.to_ref(); + // let (left, right) = if bit == 0 { + // (child, children[1]) + // } else { + // (children[0], child) + // }; + + // // Update the node in place with the new child. + // Ok((nref.0, self.set_node(nref.0, RawNode::branch(left, right)))) + } + + /// Inserts value assuming current node is an Extension. + fn handle_extension( + &mut self, + nref: (Ptr, &CryptoHash), + mut key: bits::Slice<'_>, + child: Reference<'_>, + ) -> Result<(Ptr, CryptoHash)> { + // If key is empty, insert a new Value node with this node as a child. 
+ // + // P P + // ↓ ↓ + // Ext(key, ⫯) → Val(val, ⫯) + // ↓ ↓ + // C Ext(key, ⫯) + // ↓ + // C + if self.key.is_empty() { + return self.alloc_value_node(self.value_hash, nref.0, nref.1); + } + + let prefix = self.key.forward_common_prefix(&mut key); + let mut suffix = key; + + // The entire extension key matched. Handle the child reference and + // update the node. + // + // P P + // ↓ ↓ + // Ext(key, ⫯) → Ext(key, ⫯) + // ↓ ↓ + // C C′ + if suffix.is_empty() { + let owned_ref = self.handle_reference(child)?; + let node = RawNode::extension(prefix, owned_ref.to_ref()).unwrap(); + return Ok(self.set_node(nref.0, node)); + } + + let our = if let Some(bit) = self.key.pop_front() { + usize::from(bit) + } else { + // Our key is done. We need to split the Extension node into + // two and insert Value node in between. + // + // P P + // ↓ ↓ + // Ext(key, ⫯) → Ext(prefix, ⫯) + // ↓ ↓ + // C Value(val, ⫯) + // ↓ + // Ext(suffix, ⫯) + // ↓ + // C + let (ptr, hash) = self.alloc_extension_node(suffix, child)?; + let (ptr, hash) = + self.alloc_value_node(self.value_hash, ptr, &hash)?; + let child = Reference::node(Some(ptr), &hash); + let node = RawNode::extension(prefix, child).unwrap(); + return Ok(self.set_node(nref.0, node)); + }; + + let theirs = usize::from(suffix.pop_front().unwrap()); + assert_ne!(our, theirs); + + // We need to split the Extension node with a Branch node in between. + // One child of the Branch will lead to our value; the other will lead + // to subtrie that the Extension points to. + // + // + // P P + // ↓ ↓ + // Ext(key, ⫯) → Ext(prefix, ⫯) + // ↓ ↓ + // C Branch(⫯, ⫯) + // ↓ ↓ + // V Ext(suffix, ⫯) + // ↓ + // C + // + // However, keep in mind that each of prefix or suffix may be empty. If + // that’s the case, corresponding Extension node is not created. 
+ let our_ref = self.insert_value()?; + let their_hash: CryptoHash; + let their_ref = if let Some(node) = RawNode::extension(suffix, child) { + let (ptr, hash) = self.alloc_node(node)?; + their_hash = hash; + Reference::node(Some(ptr), &their_hash) + } else { + child + }; + let mut children = [their_ref; 2]; + children[our] = our_ref.to_ref(); + let node = RawNode::branch(children[0], children[1]); + let (ptr, hash) = self.set_node(nref.0, node); + + match RawNode::extension(prefix, Reference::node(Some(ptr), &hash)) { + Some(node) => self.alloc_node(node), + None => Ok((ptr, hash)), + } + } + + /// Inserts value assuming current node is an unsealed Value. + fn handle_value( + &mut self, + nref: (Ptr, &CryptoHash), + value: ValueRef, + child: NodeRef, + ) -> Result<(Ptr, CryptoHash)> { + if value.is_sealed { + return Err(Error::Sealed); + } + let node = if self.key.is_empty() { + RawNode::value(ValueRef::new(false, self.value_hash), child) + } else { + let (ptr, hash) = self.handle(child)?; + RawNode::value(value, NodeRef::new(Some(ptr), &hash)) + }; + Ok(self.set_node(nref.0, node)) + } + + /// Handles a reference which can either point at a node or a value. + /// + /// Returns a new value for the reference updating it such that it points at + /// the subtrie updated with the inserted value. + fn handle_reference(&mut self, child: Reference<'_>) -> Result { + match child { + Reference::Node(node) => { + // Handle node references recursively. We cannot special handle + // our key being empty because we need to handle cases where the + // reference points at a Value node correctly. + self.handle(node).map(|(p, h)| OwnedRef::Node(p, h)) + } + Reference::Value(value) => { + if value.is_sealed { + return Err(Error::Sealed); + } + // It’s a value reference so we just need to update it + // accordingly. One tricky thing is that we need to insert + // Value node with the old hash if our key isn’t empty. + match self.insert_value()? 
{ + rf @ OwnedRef::Value(_) => Ok(rf), + OwnedRef::Node(p, h) => { + let child = NodeRef::new(Some(p), &h); + let node = RawNode::value(value, child); + self.alloc_node(node).map(|(p, h)| OwnedRef::Node(p, h)) + } + } + } + } + } + + /// Inserts the value into the trie and returns reference to it. + /// + /// If key is empty, doesn’t insert any nodes and instead returns a value + /// reference to the value. + /// + /// Otherwise, inserts one or more Extension nodes (depending on the length + /// of the key) and returns reference to the first ancestor node. + fn insert_value(&mut self) -> Result { + let mut ptr: Option = None; + let mut hash = self.value_hash.clone(); + for chunk in self.key.chunks().rev() { + let child = match ptr { + None => Reference::value(false, &hash), + Some(_) => Reference::node(ptr, &hash), + }; + let (p, h) = self.alloc_extension_node(chunk, child)?; + ptr = Some(p); + hash = h; + } + + Ok(if let Some(ptr) = ptr { + // We’ve updated some nodes. Insert node reference to the first + // one. + OwnedRef::Node(ptr, hash) + } else { + // ptr being None means that the above loop never run which means + // self.key is empty. We just need to return value reference. + OwnedRef::Value(hash) + }) + } + + /// A convenience method which allocates a new Extension node and sets it to + /// given value. + /// + /// **Panics** if `key` is empty or too long. + fn alloc_extension_node( + &mut self, + key: bits::Slice<'_>, + child: Reference<'_>, + ) -> Result<(Ptr, CryptoHash)> { + self.alloc_node(RawNode::extension(key, child).unwrap()) + } + + /// A convenience method which allocates a new Value node and sets it to + /// given value. 
+ fn alloc_value_node( + &mut self, + value_hash: &CryptoHash, + ptr: Ptr, + hash: &CryptoHash, + ) -> Result<(Ptr, CryptoHash)> { + let value = ValueRef::new(false, value_hash); + let child = NodeRef::new(Some(ptr), hash); + self.alloc_node(RawNode::value(value, child)) + } + + /// Sets value of a node cell at given address and returns its hash. + fn set_node(&mut self, ptr: Ptr, node: RawNode) -> (Ptr, CryptoHash) { + let hash = node.decode().hash(); + self.wlog.set(ptr, node); + (ptr, hash) + } + + /// Allocates a new node and sets it to given value. + fn alloc_node(&mut self, node: RawNode) -> Result<(Ptr, CryptoHash)> { + let hash = node.decode().hash(); + let ptr = self.wlog.alloc(node)?; + Ok((ptr, hash)) + } +} + +enum OwnedRef { + Node(Ptr, CryptoHash), + Value(CryptoHash), +} + +impl OwnedRef { + fn to_ref(&self) -> Reference { + match self { + Self::Node(ptr, hash) => Reference::node(Some(*ptr), &hash), + Self::Value(hash) => Reference::value(false, &hash), + } + } +} diff --git a/sealable-trie/src/trie/tests.rs b/sealable-trie/src/trie/tests.rs new file mode 100644 index 00000000..a88d634b --- /dev/null +++ b/sealable-trie/src/trie/tests.rs @@ -0,0 +1,179 @@ +use std::collections::HashMap; +use std::println; + +use rand::Rng; + +use crate::hash::CryptoHash; +use crate::memory::test_utils::TestAllocator; + +fn do_test_inserts<'a>( + keys: impl IntoIterator, + verbose: bool, +) -> TestTrie { + let keys = keys.into_iter(); + let count = keys.size_hint().1.unwrap_or(1000).saturating_mul(4); + let mut trie = TestTrie::new(count); + for key in keys { + trie.set(key, verbose) + } + trie +} + +#[test] +fn test_msb_difference() { do_test_inserts([&[0][..], &[0x80][..]], true); } + +#[test] +fn test_sequence() { + do_test_inserts( + b"0123456789:;<=>?".iter().map(core::slice::from_ref), + true, + ); +} + +#[test] +fn test_2byte_extension() { + do_test_inserts([&[123, 40][..], &[134, 233][..]], true); +} + +#[test] +fn test_prefix() { + let key = b"xy"; + 
do_test_inserts([&key[..], &key[..1]], true); + do_test_inserts([&key[..1], &key[..]], true); +} + +#[test] +fn test_seal() { + let mut trie = do_test_inserts( + b"0123456789:;<=>?".iter().map(core::slice::from_ref), + true, + ); + + for b in b'0'..=b'?' { + trie.seal(&[b], true); + } +} + +#[test] +fn stress_test() { + struct RandKeys<'a> { + buf: &'a mut [u8; 35], + rng: rand::rngs::ThreadRng, + } + + impl<'a> Iterator for RandKeys<'a> { + type Item = &'a [u8]; + + fn next(&mut self) -> Option { + let len = self.rng.gen_range(1..self.buf.len()); + let key = &mut self.buf[..len]; + self.rng.fill(key); + let key = &key[..]; + // Transmute lifetimes. This is probably not sound in general but + // it works for our needs in this test. + unsafe { core::mem::transmute(key) } + } + } + + let count = crate::test_utils::get_iteration_count(); + let count = crate::test_utils::div_max_1(count, 100); + let keys = RandKeys { buf: &mut [0; 35], rng: rand::thread_rng() }; + do_test_inserts(keys.take(count), false); +} + +#[derive(Eq, Ord)] +struct Key { + len: u8, + buf: [u8; 35], +} + +impl Key { + fn as_bytes(&self) -> &[u8] { &self.buf[..usize::from(self.len)] } +} + +impl core::cmp::PartialEq for Key { + fn eq(&self, other: &Self) -> bool { self.as_bytes() == other.as_bytes() } +} + +impl core::cmp::PartialOrd for Key { + fn partial_cmp(&self, other: &Self) -> Option { + self.as_bytes().partial_cmp(other.as_bytes()) + } +} + +impl core::hash::Hash for Key { + fn hash(&self, state: &mut H) { + self.as_bytes().hash(state) + } +} + +impl core::fmt::Debug for Key { + fn fmt(&self, fmtr: &mut core::fmt::Formatter) -> core::fmt::Result { + self.as_bytes().fmt(fmtr) + } +} + +struct TestTrie { + trie: super::Trie, + mapping: HashMap, + count: usize, +} + +impl TestTrie { + pub fn new(count: usize) -> Self { + Self { + trie: super::Trie::test(count), + mapping: Default::default(), + count: 0, + } + } + + pub fn set(&mut self, key: &[u8], verbose: bool) { + assert!(key.len() <= 35); 
+ let key = Key { + len: key.len() as u8, + buf: { + let mut buf = [0; 35]; + buf[..key.len()].copy_from_slice(key); + buf + }, + }; + + let value = self.next_value(); + println!("{}Inserting {key:?}", if verbose { "\n" } else { "" }); + self.trie + .set(key.as_bytes(), &value) + .unwrap_or_else(|err| panic!("Failed setting ‘{key:?}’: {err}")); + self.mapping.insert(key, value); + if verbose { + self.trie.print(); + } + for (key, value) in self.mapping.iter() { + let key = key.as_bytes(); + let got = self.trie.get(key).unwrap_or_else(|err| { + panic!("Failed getting ‘{key:?}’: {err}") + }); + assert_eq!(Some(value), got.as_ref(), "Invalid value at ‘{key:?}’"); + } + } + + pub fn seal(&mut self, key: &[u8], verbose: bool) { + println!("{}Sealing {key:?}", if verbose { "\n" } else { "" }); + self.trie + .seal(key) + .unwrap_or_else(|err| panic!("Failed sealing ‘{key:?}’: {err}")); + if verbose { + self.trie.print(); + } + assert_eq!( + Err(super::Error::Sealed), + self.trie.get(key), + "Unexpectedly can read ‘{key:?}’ after sealing" + ) + } + + fn next_value(&mut self) -> CryptoHash { + self.count += 1; + CryptoHash::test(self.count) + } +}