From 0d0928b4b2ee8b83d70ee6a34824f97d2295e69a Mon Sep 17 00:00:00 2001 From: Michal Nazarewicz Date: Mon, 14 Aug 2023 20:54:22 +0200 Subject: [PATCH] Initial sealable trie implementation (#1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit As per the documentation, the sealable trie data structure is a Merkle Patricia Trie with an additional feature where values can be sealed such that they can no longer be changed but also can no longer be accessed. The second property allows the trie to remove nodes keeping sealed values thus reducing storage requirements for the trie. As often is the case, there’s room to add more tests to the code, but for the large part the trie is functional. --- .github/workflows/master.yml | 61 ++ .rustfmt.toml | 26 +- Cargo.toml | 19 + deny.toml | 12 + sealable-trie/Cargo.toml | 14 + sealable-trie/src/bits.rs | 1021 +++++++++++++++++++++++ sealable-trie/src/hash.rs | 214 +++++ sealable-trie/src/lib.rs | 17 + sealable-trie/src/memory.rs | 408 +++++++++ sealable-trie/src/nodes.rs | 584 +++++++++++++ sealable-trie/src/nodes/stress_tests.rs | 134 +++ sealable-trie/src/nodes/tests.rs | 251 ++++++ sealable-trie/src/proof.rs | 399 +++++++++ sealable-trie/src/stdx.rs | 53 ++ sealable-trie/src/test_utils.rs | 23 + sealable-trie/src/trie.rs | 317 +++++++ sealable-trie/src/trie/seal.rs | 175 ++++ sealable-trie/src/trie/set.rs | 344 ++++++++ sealable-trie/src/trie/tests.rs | 179 ++++ 19 files changed, 4238 insertions(+), 13 deletions(-) create mode 100644 .github/workflows/master.yml create mode 100644 Cargo.toml create mode 100644 deny.toml create mode 100644 sealable-trie/Cargo.toml create mode 100644 sealable-trie/src/bits.rs create mode 100644 sealable-trie/src/hash.rs create mode 100644 sealable-trie/src/lib.rs create mode 100644 sealable-trie/src/memory.rs create mode 100644 sealable-trie/src/nodes.rs create mode 100644 sealable-trie/src/nodes/stress_tests.rs create mode 100644 
sealable-trie/src/nodes/tests.rs create mode 100644 sealable-trie/src/proof.rs create mode 100644 sealable-trie/src/stdx.rs create mode 100644 sealable-trie/src/test_utils.rs create mode 100644 sealable-trie/src/trie.rs create mode 100644 sealable-trie/src/trie/seal.rs create mode 100644 sealable-trie/src/trie/set.rs create mode 100644 sealable-trie/src/trie/tests.rs diff --git a/.github/workflows/master.yml b/.github/workflows/master.yml new file mode 100644 index 00000000..1b061340 --- /dev/null +++ b/.github/workflows/master.yml @@ -0,0 +1,61 @@ +--- +name: PR Checks +on: + pull_request: + branches: + - '*' + push: + branches: + - master + +jobs: + misc: + name: Miscellaneous checks + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install Rust + uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly + components: rustfmt + + - name: Check formatting + run: cargo fmt --all --check + + - name: Install cargo-deny + run: cargo install cargo-deny + + - name: Check bans + run: cargo-deny --all-features check bans + + stable: + name: Rust stable + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install Rust + uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + components: rustfmt + + - name: Run tests + run: cargo test + + nightly: + name: Rust nightly + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install Rust + uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly + components: miri + + - name: Run tests + run: cargo test + + - name: Run tests with Miri + run: cargo miri test -- --skip ::stress_test diff --git a/.rustfmt.toml b/.rustfmt.toml index 346ded89..060cd137 100644 --- a/.rustfmt.toml +++ b/.rustfmt.toml @@ -1,19 +1,19 @@ -# binop_separator = "Back" -# blank_lines_lower_bound = 0 -# blank_lines_upper_bound = 3 -# brace_style = "PreferSameLine" -# color = "Auto" -# condense_wildcard_suffixes = true -# fn_single_line = true -# format_macro_matchers = 
true -# format_strings = true -# group_imports = "StdExternalCrate" -# normalize_doc_attributes = true -# overflow_delimited_expr = true - +binop_separator = "Back" +blank_lines_lower_bound = 0 +blank_lines_upper_bound = 3 +color = "Auto" +condense_wildcard_suffixes = true +fn_single_line = true +format_macro_matchers = true +format_strings = true +group_imports = "StdExternalCrate" +imports_granularity = "Module" max_width = 80 newline_style = "Unix" +normalize_doc_attributes = true +overflow_delimited_expr = true reorder_imports = true reorder_modules = true use_field_init_shorthand = true +use_small_heuristics = "Max" use_try_shorthand = true diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 00000000..5ea2296a --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,19 @@ +[workspace.package] +version = "0.0.0" +authors = ["Michal Nazarewicz "] +edition = "2021" +rust-version = "1.71.0" + +[workspace] +members = [ + "sealable-trie", +] +resolver = "2" + +[workspace.dependencies] +base64 = { version = "0.21", default-features = false, features = ["alloc"] } +derive_more = "0.99.17" +pretty_assertions = "1.4.0" +rand = { version = "0.8.5" } +sha2 = { version = "0.10.7", default-features = false } +strum = { version = "0.25.0", default-features = false, features = ["derive"] } diff --git a/deny.toml b/deny.toml new file mode 100644 index 00000000..f009c807 --- /dev/null +++ b/deny.toml @@ -0,0 +1,12 @@ +[sources] +unknown-registry = "deny" +unknown-git = "deny" +allow-registry = ["https://github.com/rust-lang/crates.io-index"] +allow-git = [] + +[bans] +multiple-versions = "deny" +skip = [ + # derive_more still uses old syn + { name = "syn", version = "1.0.*" }, +] diff --git a/sealable-trie/Cargo.toml b/sealable-trie/Cargo.toml new file mode 100644 index 00000000..b6e04c9b --- /dev/null +++ b/sealable-trie/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "sealable-trie" +version = "0.0.0" +edition = "2021" + +[dependencies] +base64.workspace = true 
+derive_more.workspace = true +sha2.workspace = true +strum.workspace = true + +[dev-dependencies] +pretty_assertions.workspace = true +rand.workspace = true diff --git a/sealable-trie/src/bits.rs b/sealable-trie/src/bits.rs new file mode 100644 index 00000000..ffd3dcba --- /dev/null +++ b/sealable-trie/src/bits.rs @@ -0,0 +1,1021 @@ +use core::fmt; + +#[cfg(test)] +use pretty_assertions::assert_eq; + +use crate::{nodes, stdx}; + +/// Representation of a slice of bits. +/// +/// **Note**: slices with different starting offset are considered different +/// even if iterating over all the bits gives the same result. +#[derive(Clone, Copy)] +pub struct Slice<'a> { + /// Offset in bits to start the slice in `bytes`. + /// + /// In other words, how many most significant bits to skip from `bytes`. + /// This is always less than eight (i.e. we never skip more than one byte). + pub(crate) offset: u8, + + /// Length of the slice in bits. + /// + /// `length + offset` is never more than `36 * 8`. + pub(crate) length: u16, + + /// The bytes to read the bits from. + /// + /// Value of bits outside of the range defined by `offset` and `length` is + /// unspecified and shouldn’t be read. + // Invariant: `ptr` points at `offset + length` valid bits. In other words, + // at `(offset + length + 7) / 8` valid bytes. + pub(crate) ptr: *const u8, + + phantom: core::marker::PhantomData<&'a [u8]>, +} + +/// An iterator over bits in a bit slice. +#[derive(Clone, Copy)] +pub struct Iter<'a> { + /// A 1-bit mask of the next bit to read from `*ptr`. + /// + /// Each time next bit is read, the mask is shifted once to the right. Once + /// it reaches zero, it’s reset to `0x80` and `ptr` is advanced to the next + /// byte. + mask: u8, + + /// Length of the slice in bits. + length: u16, + + /// Pointer to the byte at the beginning of the iterator. + // Invariant: `ptr` points at `offset + length` valid bits where `offset` + // equals `mask.leading_zeros()`. 
In other words, at `(offset + length + 7) + // / 8` valid bytes. + ptr: *const u8, + + phantom: core::marker::PhantomData<&'a [u8]>, +} + +/// An iterator over chunks of a slice where each chunk (except for the last +/// one) occupies exactly 34 bytes. +#[derive(Clone, Copy)] +pub struct Chunks<'a>(Slice<'a>); + +impl<'a> Slice<'a> { + /// Constructs a new bit slice. + /// + /// `bytes` is underlying bytes slice to read bits from. + /// + /// `offset` specifies how many most significant bits of the first byte of + /// the bytes slice to skip. Must be at most 7. + /// + /// `length` specifies length in bits of the entire bit slice. + /// + /// Returns `None` if `offset` or `length` is too large or `bytes` doesn’t + /// have enough underlying data for the length of the slice. + pub fn new(bytes: &'a [u8], offset: u8, length: u16) -> Option { + if offset >= 8 { + return None; + } + let has_bits = + u32::try_from(bytes.len()).unwrap_or(u32::MAX).saturating_mul(8); + (u32::from(length) + u32::from(offset) <= has_bits).then_some(Self { + offset, + length, + ptr: bytes.as_ptr(), + phantom: Default::default(), + }) + } + + /// Constructs a new bit slice going through all bits in a bytes slice. + /// + /// Returns `None` if the slice is too long. The maximum length is 8191 + /// bytes. + pub fn from_bytes(bytes: &'a [u8]) -> Option { + Some(Self { + offset: 0, + length: u16::try_from(bytes.len().checked_mul(8)?).ok()?, + ptr: bytes.as_ptr(), + phantom: Default::default(), + }) + } + + /// Returns length of the slice in bits. + pub fn len(&self) -> u16 { self.length } + + /// Returns whether the slice is empty. + pub fn is_empty(&self) -> bool { self.length == 0 } + + /// Returns the first bit in the slice and advances the slice by one position. 
+ /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits; + /// + /// let mut slice = bits::Slice::new(&[0x60], 0, 3).unwrap(); + /// assert_eq!(Some(false), slice.pop_front()); + /// assert_eq!(Some(true), slice.pop_front()); + /// assert_eq!(Some(true), slice.pop_front()); + /// assert_eq!(None, slice.pop_front()); + /// ``` + pub fn pop_front(&mut self) -> Option { + if self.length == 0 { + return None; + } + // SAFETY: self.length != 0 ⇒ self.ptr points at a valid byte + let bit = (unsafe { self.ptr.read() } & (0x80 >> self.offset)) != 0; + self.offset = (self.offset + 1) & 7; + if self.offset == 0 { + // SAFETY: self.ptr pointing at valid byte ⇒ self.ptr+1 is valid + // pointer + self.ptr = unsafe { self.ptr.add(1) } + } + self.length -= 1; + Some(bit) + } + + /// Returns the last bit in the slice shrinking the slice by one bit. + /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits; + /// + /// let mut slice = bits::Slice::new(&[0x60], 0, 3).unwrap(); + /// assert_eq!(Some(true), slice.pop_back()); + /// assert_eq!(Some(true), slice.pop_back()); + /// assert_eq!(Some(false), slice.pop_back()); + /// assert_eq!(None, slice.pop_back()); + /// ``` + pub fn pop_back(&mut self) -> Option { + self.length = self.length.checked_sub(1)?; + let total_bits = self.underlying_bits_length(); + // SAFETY: `ptr` is guaranteed to point at offset + original length + // valid bits. Furthermore, since original length was positive then + // there’s at least one byte we can read. + let byte = unsafe { self.ptr.add(total_bits / 8).read() }; + let mask = 0x80 >> (total_bits % 8); + Some(byte & mask != 0) + } + + /// Returns subslice from the end of the slice shrinking the slice by its + /// length. + /// + /// Returns `None` if the slice is too short. + /// + /// This is an ‘rsplit_at’ operation but instead of returning two slices it + /// shortens the slice and returns the tail. 
+ /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits; + /// + /// let mut slice = bits::Slice::new(&[0x81], 0, 8).unwrap(); + /// let tail = slice.pop_back_slice(4).unwrap(); + /// assert_eq!(bits::Slice::new(&[0x80], 0, 4), Some(slice)); + /// assert_eq!(bits::Slice::new(&[0x01], 4, 4), Some(tail)); + /// + /// assert_eq!(None, slice.pop_back_slice(5)); + /// assert_eq!(bits::Slice::new(&[0x80], 0, 4), Some(slice)); + /// ``` + pub fn pop_back_slice(&mut self, length: u16) -> Option { + self.length = self.length.checked_sub(length)?; + let total_bits = self.underlying_bits_length(); + // SAFETY: `ptr` is guaranteed to point at offset + original length + // valid bits. + let ptr = unsafe { self.ptr.add(total_bits / 8) }; + let offset = (total_bits % 8) as u8; + Some(Self { ptr, offset, length, phantom: Default::default() }) + } + + /// Returns an iterator over bits in the bit slice. + /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits; + /// + /// let slice = bits::Slice::new(&[0xA0], 0, 3).unwrap(); + /// let bits: Vec = slice.iter().collect(); + /// assert_eq!(&[true, false, true], bits.as_slice()); + /// ``` + pub fn iter(&self) -> Iter<'a> { + Iter { + mask: 0x80 >> self.offset, + length: self.length, + ptr: self.ptr, + phantom: self.phantom, + } + } + + /// Returns iterator over chunks of slice where each chunk occupies at most + /// 34 bytes. + /// + /// The final chunk may be shorter. Note that due to offset the first chunk + /// may be shorter than 272 bits (i.e. 34 * 8) however it will span full 34 + /// bytes. + pub fn chunks(&self) -> Chunks<'a> { Chunks(*self) } + + /// Returns whether the slice starts with given prefix. + /// + /// **Note**: If the `prefix` slice has a different bit offset it is not + /// considered a prefix even if it starts with the same bits. 
+ /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits; + /// + /// let mut slice = bits::Slice::new(&[0xAA, 0xA0], 0, 12).unwrap(); + /// + /// assert!(slice.starts_with(bits::Slice::new(&[0xAA], 0, 4).unwrap())); + /// assert!(!slice.starts_with(bits::Slice::new(&[0xFF], 0, 4).unwrap())); + /// // Different offset: + /// assert!(!slice.starts_with(bits::Slice::new(&[0xAA], 4, 4).unwrap())); + /// ``` + pub fn starts_with(&self, prefix: Slice<'_>) -> bool { + if self.offset != prefix.offset || self.length < prefix.length { + return false; + } + let subslice = Slice { length: prefix.length, ..*self }; + subslice == prefix + } + + /// Removes prefix from the slice; returns `false` if slice doesn’t start + /// with given prefix. + /// + /// **Note**: If the `prefix` slice has a different bit offset it is not + /// considered a prefix even if it starts with the same bits. + /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits; + /// + /// let mut slice = bits::Slice::new(&[0xAA, 0xA0], 0, 12).unwrap(); + /// + /// assert!(slice.strip_prefix(bits::Slice::new(&[0xAA], 0, 4).unwrap())); + /// assert_eq!(bits::Slice::new(&[0x0A, 0xA0], 4, 8).unwrap(), slice); + /// + /// // Doesn’t match: + /// assert!(!slice.strip_prefix(bits::Slice::new(&[0x0F], 4, 4).unwrap())); + /// // Different offset: + /// assert!(!slice.strip_prefix(bits::Slice::new(&[0xAA], 0, 4).unwrap())); + /// // Too long: + /// assert!(!slice.strip_prefix(bits::Slice::new(&[0x0A, 0xAA], 4, 12).unwrap())); + /// + /// assert!(slice.strip_prefix(bits::Slice::new(&[0xAA, 0xAA], 4, 6).unwrap())); + /// assert_eq!(bits::Slice::new(&[0x20], 2, 2).unwrap(), slice); + /// + /// assert!(slice.strip_prefix(slice.clone())); + /// assert_eq!(bits::Slice::new(&[0x00], 4, 0).unwrap(), slice); + /// ``` + pub fn strip_prefix(&mut self, prefix: Slice<'_>) -> bool { + if !self.starts_with(prefix) { + return false; + } + let length = usize::from(prefix.length) + usize::from(prefix.offset); + 
// SAFETY: self.ptr points to at least length+offset valid bits. + unsafe { self.ptr = self.ptr.add(length / 8) }; + self.offset = (length % 8) as u8; + self.length -= prefix.length; + true + } + + /// Strips common prefix from two slices; returns new slice with the common + /// prefix. + /// + /// **Note**: If two slices have different bit offset they are considered to + /// have an empty prefix. + /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits; + /// + /// let mut left = bits::Slice::new(&[0xFF], 0, 8).unwrap(); + /// let mut right = bits::Slice::new(&[0xF0], 0, 8).unwrap(); + /// assert_eq!(bits::Slice::new(&[0xF0], 0, 4).unwrap(), + /// left.forward_common_prefix(&mut right)); + /// assert_eq!(bits::Slice::new(&[0xFF], 4, 4).unwrap(), left); + /// assert_eq!(bits::Slice::new(&[0xF0], 4, 4).unwrap(), right); + /// + /// let mut left = bits::Slice::new(&[0xFF], 0, 8).unwrap(); + /// let mut right = bits::Slice::new(&[0xFF], 0, 6).unwrap(); + /// assert_eq!(bits::Slice::new(&[0xFC], 0, 6).unwrap(), + /// left.forward_common_prefix(&mut right)); + /// assert_eq!(bits::Slice::new(&[0xFF], 6, 2).unwrap(), left); + /// assert_eq!(bits::Slice::new(&[0xFF], 6, 0).unwrap(), right); + /// ``` + pub fn forward_common_prefix(&mut self, other: &mut Slice<'_>) -> Self { + let offset = self.offset; + if offset != other.offset { + return Self { length: 0, ..*self }; + } + + let length = self.length.min(other.length); + // SAFETY: offset is common offset of both slices and length is shorter + // of either slice, which means that both pointers point to at least + // offset+length bits. 
+ let (idx, length) = unsafe { + forward_common_prefix_impl(self.ptr, other.ptr, offset, length) + }; + let result = Self { length, ..*self }; + + self.length -= length; + self.offset = ((u16::from(self.offset) + length) % 8) as u8; + other.length -= length; + other.offset = self.offset; + // SAFETY: forward_common_prefix_impl guarantees that `idx` is no more + // than what the slices have. + unsafe { + self.ptr = self.ptr.add(idx); + other.ptr = other.ptr.add(idx); + } + + result + } + + /// Checks that all bits outside of the specified range are set to zero. + fn check_bytes(bytes: &[u8], offset: u8, length: u16) -> bool { + let (front, back) = Self::masks(offset, length); + let bytes_len = (usize::from(offset) + usize::from(length) + 7) / 8; + bytes_len <= bytes.len() && + (bytes[0] & !front) == 0 && + (bytes[bytes_len - 1] & !back) == 0 && + bytes[bytes_len..].iter().all(|&b| b == 0) + } + + /// Returns total number of underlying bits, i.e. bits in the slice plus the + /// offset. + fn underlying_bits_length(&self) -> usize { + usize::from(self.offset) + usize::from(self.length) + } + + /// Returns bytes underlying the bit slice. + fn bytes(&self) -> &'a [u8] { + let len = (self.underlying_bits_length() + 7) / 8; + // SAFETY: `ptr` is guaranteed to point at offset+length valid bits. + unsafe { core::slice::from_raw_parts(self.ptr, len) } + } + + /// Encodes key into raw binary representation. + /// + /// Fills entire 36-byte buffer. The first two bytes encode + /// length and offset (`(length << 3) | offset` specifically leaving the + /// four most significant bits zero) and the rest are bytes holding the + /// bits. Bits which are not part of the slice are set to zero. + /// + /// The first byte written will be xored with `tag`. + /// + /// Returns the length of relevant portion of the buffer. For example, if + /// slice’s length is say 20 bits with zero offset returns five (two bytes + /// for the encoded length and three bytes for the 20 bits). 
+ /// + /// Returns `None` if the slice is empty or too long and won’t fit in the + /// destination buffer. + pub(crate) fn encode_into( + &self, + dest: &mut [u8; 36], + tag: u8, + ) -> Option { + if self.length == 0 { + return None; + } + let bytes = self.bytes(); + if bytes.is_empty() || bytes.len() > dest.len() - 2 { + return None; + } + let (num, tail) = + stdx::split_array_mut::<2, { nodes::MAX_EXTENSION_KEY_SIZE }, 36>( + dest, + ); + tail.fill(0); + *num = self.encode_num(tag); + let (key, _) = tail.split_at_mut(bytes.len()); + let (front, back) = Self::masks(self.offset, self.length); + key.copy_from_slice(bytes); + key[0] &= front; + key[bytes.len() - 1] &= back; + Some(2 + bytes.len()) + } + + /// Encodes key into raw binary representation and sends it to the consumer. + /// + /// This is like [`Self::encode_into`] except that it doesn’t check the + /// length of the key. + pub(crate) fn write_into(&self, mut consumer: impl FnMut(&[u8]), tag: u8) { + consumer(&self.encode_num(tag)); + + let (front, back) = Self::masks(self.offset, self.length); + let bytes = self.bytes(); + match bytes.len() { + 0 => (), + 1 => consumer(&[bytes[0] & front & back]), + 2 => consumer(&[bytes[0] & front, bytes[1] & back]), + n => { + consumer(&[bytes[0] & front]); + consumer(&bytes[1..n - 1]); + consumer(&[bytes[n - 1] & back]); + } + } + } + + /// Decodes key from a raw binary representation. + /// + /// The first byte read will be xored with `tag`. + /// + /// This is the inverse of [`Self::encode_into`]. + pub(crate) fn decode(src: &'a [u8], tag: u8) -> Option { + let (&[high, low], bytes) = stdx::split_at(src)?; + let tag = u16::from_be_bytes([high ^ tag, low]); + let (offset, length) = ((tag % 8) as u8, tag / 8); + (length > 0 && Self::check_bytes(bytes, offset, length)).then_some( + Self { + offset, + length, + ptr: bytes.as_ptr(), + phantom: Default::default(), + }, + ) + } + + /// Encodes offset and length as a two-byte number. 
+ /// + /// The encoding is `llll_llll llll_looo`, i.e. 13-bit length in the most + /// significant bits and 3-bit offset in the least significant bits. The + /// first byte is then further xored with the `tag` argument. + /// + /// This method doesn’t check whether the length and offset are within range. + fn encode_num(&self, tag: u8) -> [u8; 2] { + let num = (self.length << 3) | u16::from(self.offset); + (num ^ (u16::from(tag) << 8)).to_be_bytes() + } + + /// Helper method which returns masks for leading and trailing byte. + /// + /// Based on provided bit offset (which must be ≤ 7) and bit length of the + /// slice returns: mask of bits in the first byte that are part of the + /// slice and mask of bits in the last byte that are part of the slice. + fn masks(offset: u8, length: u16) -> (u8, u8) { + let bits = usize::from(offset) + usize::from(length); + // `1 << 20` is an arbitrary number which is divisible by 8 and greater + // than bits. + let tail = ((1 << 20) - bits) % 8; + (0xFF >> offset, 0xFF << tail) + } +} + +/// Implementation of [`Slice::forward_common_prefix`]. +/// +/// ## Safety +/// +/// `lhs` and `rhs` must point to at least `offset + max_length` bits. +unsafe fn forward_common_prefix_impl( + lhs: *const u8, + rhs: *const u8, + offset: u8, + max_length: u16, +) -> (usize, u16) { + let max_length = u32::from(max_length) + u32::from(offset); + // SAFETY: Caller promises that both pointers point to at least offset + + // max_length bits. 
+ let (lhs, rhs) = unsafe { + let len = ((max_length + 7) / 8) as usize; + let lhs = core::slice::from_raw_parts(lhs, len).split_first(); + let rhs = core::slice::from_raw_parts(rhs, len).split_first(); + (lhs, rhs) + }; + + let (first, lhs, rhs) = match (lhs, rhs) { + (Some(lhs), Some(rhs)) => (lhs.0 ^ rhs.0, lhs.1, rhs.1), + _ => return (0, 0), + }; + let first = first & (0xFF >> offset); + + let total_bits_matched = if first != 0 { + first.leading_zeros() + } else if let Some(n) = lhs.iter().zip(rhs.iter()).position(|(a, b)| a != b) + { + 8 + n as u32 * 8 + (lhs[n] ^ rhs[n]).leading_zeros() + } else { + 8 + lhs.len() as u32 * 8 + } + .min(max_length); + + ( + (total_bits_matched / 8) as usize, + total_bits_matched.saturating_sub(u32::from(offset)) as u16, + ) +} + +impl<'a> core::iter::IntoIterator for Slice<'a> { + type Item = bool; + type IntoIter = Iter<'a>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { self.iter() } +} + +impl<'a> core::iter::IntoIterator for &'a Slice<'a> { + type Item = bool; + type IntoIter = Iter<'a>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { (*self).iter() } +} + +impl core::cmp::PartialEq for Slice<'_> { + /// Compares two slices to see if they contain the same bits and have the + /// same offset. + /// + /// **Note**: If the slices are the same length and contain the same bits + /// but their offsets are different, they are considered non-equal. 
+ /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits; + /// + /// assert_eq!(bits::Slice::new(&[0xFF], 0, 6), + /// bits::Slice::new(&[0xFF], 0, 6)); + /// assert_ne!(bits::Slice::new(&[0xFF], 0, 6), + /// bits::Slice::new(&[0xF0], 0, 6)); + /// assert_ne!(bits::Slice::new(&[0xFF], 0, 6), + /// bits::Slice::new(&[0xFF], 2, 6)); + /// ``` + fn eq(&self, other: &Self) -> bool { + if self.offset != other.offset || self.length != other.length { + return false; + } else if self.length == 0 { + return true; + } + let (front, back) = Self::masks(self.offset, self.length); + let (lhs, rhs) = (self.bytes(), other.bytes()); + let len = lhs.len(); + if len == 1 { + ((lhs[0] ^ rhs[0]) & front & back) == 0 + } else { + ((lhs[0] ^ rhs[0]) & front) == 0 && + ((lhs[len - 1] ^ rhs[len - 1]) & back) == 0 && + lhs[1..len - 1] == rhs[1..len - 1] + } + } +} + +impl fmt::Display for Slice<'_> { + fn fmt(&self, fmtr: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(buf: &mut [u8], mut byte: u8) { + for ch in buf.iter_mut().rev() { + *ch = b'0' + (byte & 1); + byte >>= 1; + } + } + + let (first, mid) = match self.bytes().split_first() { + None => return fmtr.write_str("∅"), + Some(pair) => pair, + }; + + let off = usize::from(self.offset); + let len = usize::from(self.length); + let mut buf = [0; 10]; + fmt(&mut buf[2..], *first); + buf[0] = b'0'; + buf[1] = b'b'; + buf[2..2 + off].fill(b'.'); + + let (last, mid) = match mid.split_last() { + None => { + buf[2 + off + len..].fill(b'.'); + let val = unsafe { core::str::from_utf8_unchecked(&buf) }; + return fmtr.write_str(val); + } + Some(pair) => pair, + }; + + fmtr.write_str(unsafe { core::str::from_utf8_unchecked(&buf) })?; + for byte in mid { + write!(fmtr, "_{:08b}", byte)?; + } + fmt(&mut buf[..9], *last); + buf[0] = b'_'; + let len = (off + len) % 8; + if len != 0 { + buf[1 + len..].fill(b'.'); + } + fmtr.write_str(unsafe { core::str::from_utf8_unchecked(&buf[..9]) }) + } +} + +impl fmt::Debug for Slice<'_> { + fn 
fmt(&self, fmtr: &mut fmt::Formatter<'_>) -> fmt::Result { + debug_fmt("Slice", self, fmtr) + } +} + +impl fmt::Debug for Iter<'_> { + fn fmt(&self, fmtr: &mut fmt::Formatter<'_>) -> fmt::Result { + let slice = Slice { + offset: self.mask.leading_zeros() as u8, + length: self.length, + ptr: self.ptr, + phantom: self.phantom, + }; + debug_fmt("Iter", &slice, fmtr) + } +} + +impl fmt::Debug for Chunks<'_> { + fn fmt(&self, fmtr: &mut fmt::Formatter<'_>) -> fmt::Result { + debug_fmt("Chunks", &self.0, fmtr) + } +} + +/// Internal function for debug formatting objects. +fn debug_fmt( + name: &str, + slice: &Slice<'_>, + fmtr: &mut fmt::Formatter<'_>, +) -> fmt::Result { + fmtr.debug_struct(name) + .field("offset", &slice.offset) + .field("length", &slice.length) + .field("bytes", &core::format_args!("{:02x?}", slice.bytes())) + .finish() +} + +impl<'a> core::iter::Iterator for Iter<'a> { + type Item = bool; + + #[inline] + fn next(&mut self) -> Option { + if self.length == 0 { + return None; + } + // SAFETY: When length is non-zero, ptr points to a valid byte. + let result = (unsafe { self.ptr.read() } & self.mask) != 0; + self.length -= 1; + self.mask = self.mask.rotate_right(1); + if self.mask == 0x80 { + // SAFETY: ptr points to a valid object (see above) so ptr+1 is + // a valid pointer (at worst it’s one-past-the-end pointer). 
+ self.ptr = unsafe { self.ptr.add(1) }; + } + Some(result) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (usize::from(self.length), Some(usize::from(self.length))) + } + + #[inline] + fn count(self) -> usize { usize::from(self.length) } +} + +impl<'a> core::iter::ExactSizeIterator for Iter<'a> { + #[inline] + fn len(&self) -> usize { usize::from(self.length) } +} + +impl<'a> core::iter::FusedIterator for Iter<'a> {} + +impl<'a> core::iter::Iterator for Chunks<'a> { + type Item = Slice<'a>; + + fn next(&mut self) -> Option> { + let bytes_len = self.0.bytes().len().min(nodes::MAX_EXTENSION_KEY_SIZE); + if bytes_len == 0 { + return None; + } + let slice = &mut self.0; + let offset = slice.offset; + let length = (bytes_len * 8 - usize::from(offset)) + .min(usize::from(slice.length)) as u16; + let ptr = slice.ptr; + slice.offset = 0; + slice.length -= length; + // SAFETY: `ptr` points at a slice which is at least `bytes_len` bytes + // long so it’s safe to advance it by that offset. + slice.ptr = unsafe { slice.ptr.add(bytes_len) }; + Some(Slice { offset, length, ptr, phantom: Default::default() }) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.len(); + (len, Some(len)) + } +} + +impl<'a> core::iter::ExactSizeIterator for Chunks<'a> { + #[inline] + fn len(&self) -> usize { + self.0.bytes().chunks(nodes::MAX_EXTENSION_KEY_SIZE).len() + } +} + +impl<'a> core::iter::DoubleEndedIterator for Chunks<'a> { + fn next_back(&mut self) -> Option> { + let mut chunks = self.0.bytes().chunks(nodes::MAX_EXTENSION_KEY_SIZE); + let bytes = chunks.next_back()?; + + if chunks.next().is_none() { + let empty = Slice { + offset: 0, + length: 0, + ptr: self.0.ptr, + phantom: Default::default(), + }; + return Some(core::mem::replace(&mut self.0, empty)); + } + + // `1 << 20` is an arbitrary number which is divisible by 8 and greater + // than underlying_bits_length. 
+ let tail = ((1 << 20) - self.0.underlying_bits_length()) % 8; + let length = (bytes.len() * 8 - tail) as u16; + self.0.length -= length; + + Some(Slice { + offset: 0, + length, + ptr: bytes.as_ptr(), + phantom: Default::default(), + }) + } +} + +#[test] +fn test_encode() { + #[track_caller] + fn test(want_encoded: &[u8], offset: u8, length: u16, bytes: &[u8]) { + let slice = Slice::new(bytes, offset, length).unwrap(); + + if want_encoded.len() <= 36 { + let mut buf = [0; 36]; + let len = slice + .encode_into(&mut buf, 0) + .unwrap_or_else(|| panic!("Failed encoding {slice}")); + assert_eq!( + want_encoded, + &buf[..len], + "Unexpected encoded representation of {slice}" + ); + } + + let mut buf = alloc::vec::Vec::with_capacity(want_encoded.len()); + slice.write_into(|bytes| buf.extend_from_slice(bytes), 0); + assert_eq!( + want_encoded, + buf.as_slice(), + "Unexpected written representation of {slice}" + ); + + let round_trip = Slice::decode(want_encoded, 0) + .unwrap_or_else(|| panic!("Failed decoding {want_encoded:?}")); + assert_eq!(slice, round_trip); + } + + test(&[0, 1 * 8 + 0, 0x80], 0, 1, &[0x80]); + test(&[0, 1 * 8 + 0, 0x80], 0, 1, &[0xFF]); + test(&[0, 1 * 8 + 4, 0x08], 4, 1, &[0xFF]); + test(&[0, 9 * 8 + 0, 0xFF, 0x80], 0, 9, &[0xFF, 0xFF]); + test(&[0, 9 * 8 + 4, 0x0F, 0xF8], 4, 9, &[0xFF, 0xFF]); + test(&[0, 17 * 8 + 0, 0xFF, 0xFF, 0x80], 0, 17, &[0xFF, 0xFF, 0xFF]); + test(&[0, 17 * 8 + 4, 0x0F, 0xFF, 0xF8], 4, 17, &[0xFF, 0xFF, 0xFF]); + + let mut want = [0xFF; 1026]; + want[0] = (8191u16 >> 5) as u8; + want[1] = (8191u16 << 3) as u8; + want[1025] = 0xFE; + test(&want[..], 0, 8191, &[0xFF; 1024][..]); + + want[1] += 1; + want[2] = 0x7F; + want[1025] = 0xFF; + test(&want[..], 1, 8191, &[0xFF; 1024][..]); +} + +#[test] +fn test_decode() { + #[track_caller] + fn ok(num: u16, bytes: &[u8], want_offset: u8, want_length: u16) { + let bytes = [&num.to_be_bytes()[..], bytes].concat(); + let got = Slice::decode(&bytes, 0).unwrap_or_else(|| { + 
panic!("Expected to get a Slice from {bytes:x?}") + }); + assert_eq!((want_offset, want_length), (got.offset, got.length)); + } + + // Correct values, all bits zero. + ok(34 * 64, &[0; 34], 0, 34 * 8); + ok(33 * 64 + 7, &[0; 34], 7, 264); + ok(2 * 64, &[0, 0], 0, 16); + + // Empty + assert_eq!(None, Slice::decode(&[], 0)); + assert_eq!(None, Slice::decode(&[0], 0)); + assert_eq!(None, Slice::decode(&[0, 0], 0)); + + #[track_caller] + fn test(length: u16, offset: u8, bad: &[u8], good: &[u8]) { + let num = length * 8 + u16::from(offset); + let bad = [&num.to_be_bytes()[..], bad].concat(); + assert_eq!(None, Slice::decode(&bad, 0)); + + let good = [&num.to_be_bytes()[..], good].concat(); + let got = Slice::decode(&good, 0).unwrap_or_else(|| { + panic!("Expected to get a Slice from {good:x?}") + }); + assert_eq!( + (offset, length), + (got.offset, got.length), + "Invalid offset and length decoding {good:x?}" + ); + + let good = [&good[..], &[0, 0]].concat(); + let got = Slice::decode(&good, 0).unwrap_or_else(|| { + panic!("Expected to get a Slice from {good:x?}") + }); + assert_eq!( + (offset, length), + (got.offset, got.length), + "Invalid offset and length decoding {good:x?}" + ); + } + + // Bytes buffer doesn’t match the length. + test(8, 0, &[], &[0]); + test(8, 7, &[0], &[0, 0]); + test(16, 1, &[0, 0], &[0, 0, 0]); + + // Bits which should be zero aren’t. 
+ // Leading bits are skipped: + test(16 - 1, 1, &[0x80, 0], &[0x7F, 0xFF]); + test(16 - 2, 2, &[0x40, 0], &[0x3F, 0xFF]); + test(16 - 3, 3, &[0x20, 0], &[0x1F, 0xFF]); + test(16 - 4, 4, &[0x10, 0], &[0x0F, 0xFF]); + test(16 - 5, 5, &[0x08, 0], &[0x07, 0xFF]); + test(16 - 6, 6, &[0x04, 0], &[0x03, 0xFF]); + test(16 - 7, 7, &[0x02, 0], &[0x01, 0xFF]); + + // Tailing bits are skipped: + test(16 - 1, 0, &[0, 0x01], &[0xFF, 0xFE]); + test(16 - 2, 0, &[0, 0x02], &[0xFF, 0xFC]); + test(16 - 3, 0, &[0, 0x04], &[0xFF, 0xF8]); + test(16 - 4, 0, &[0, 0x08], &[0xFF, 0xF0]); + test(16 - 5, 0, &[0, 0x10], &[0xFF, 0xE0]); + test(16 - 6, 0, &[0, 0x20], &[0xFF, 0xC0]); + test(16 - 7, 0, &[0, 0x40], &[0xFF, 0x80]); + + // Some leading and some tailing bits are skipped of the same byte: + test(1, 1, &[!0x40], &[0x40]); + test(1, 2, &[!0x20], &[0x20]); + test(1, 3, &[!0x10], &[0x10]); + test(1, 4, &[!0x08], &[0x08]); + test(1, 5, &[!0x04], &[0x04]); + test(1, 6, &[!0x02], &[0x02]); +} + +#[test] +fn test_common_prefix() { + let mut lhs = Slice::new(&[0x86, 0xE9], 1, 15).unwrap(); + let mut rhs = Slice::new(&[0x06, 0xE9], 1, 15).unwrap(); + let got = lhs.forward_common_prefix(&mut rhs); + let want = ( + Slice::new(&[0x06, 0xE9], 1, 15).unwrap(), + Slice::new(&[], 0, 0).unwrap(), + Slice::new(&[], 0, 0).unwrap(), + ); + assert_eq!(want, (got, lhs, rhs)); +} + +#[test] +fn test_display() { + fn test(want: &str, bytes: &[u8], offset: u8, length: u16) { + use alloc::string::ToString; + + let got = Slice::new(bytes, offset, length).unwrap().to_string(); + assert_eq!(want, got) + } + + test("0b111111..", &[0xFF], 0, 6); + test("0b..1111..", &[0xFF], 2, 4); + test("0b..111111_11......", &[0xFF, 0xFF], 2, 8); + test("0b..111111_11111111_11......", &[0xFF, 0xFF, 0xFF], 2, 16); + + test("0b10101010", &[0xAA], 0, 8); + test("0b...0101.", &[0xAA], 3, 4); +} + +#[test] +fn test_eq() { + assert_eq!(Slice::new(&[0xFF], 0, 8), Slice::new(&[0xFF], 0, 8)); + assert_eq!(Slice::new(&[0xFF], 0, 4), 
Slice::new(&[0xF0], 0, 4)); + assert_eq!(Slice::new(&[0xFF], 4, 4), Slice::new(&[0x0F], 4, 4)); +} + +#[test] +#[rustfmt::skip] +fn test_iter() { + use alloc::vec::Vec; + + #[track_caller] + fn test(want: &[u8], bytes: &[u8], offset: u8, length: u16) { + let want = want.iter().map(|&b| b != 0).collect::>(); + let slice = Slice::new(bytes, offset, length).unwrap(); + let got = slice.iter().collect::>(); + assert_eq!(want, got); + } + + test(&[1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0], + &[0xAA, 0xAA], 0, 16); + test(&[1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0], + &[0x0A, 0xFA], 4, 12); + test(&[0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1], + &[0x0A, 0xFA], 0, 12); + test(&[1, 1, 0, 0], &[0x30], 2, 4); +} + +#[test] +fn test_chunks() { + let data = (0..=255).collect::>(); + let data = data.as_slice(); + + let slice = |off: u8, len: u16| Slice::new(data, off, len).unwrap(); + + // Single chunk + for offset in 0..8 { + for length in 1..(34 * 8 - u16::from(offset)) { + let want = Slice::new(data, offset, length); + + let mut chunks = slice(offset, length).chunks(); + assert_eq!(want, chunks.next()); + assert_eq!(None, chunks.next()); + + let mut chunks = slice(offset, length).chunks(); + assert_eq!(want, chunks.next_back()); + assert_eq!(None, chunks.next()); + } + } + + // Two chunks + for offset in 0..8 { + let want_first = Slice::new(data, offset, 34 * 8 - u16::from(offset)); + let want_second = Slice::new(&data[34..], 0, 10 + u16::from(offset)); + + let mut chunks = slice(offset, 34 * 8 + 10).chunks(); + assert_eq!(want_first, chunks.next()); + assert_eq!(want_second, chunks.next()); + assert_eq!(None, chunks.next()); + + let mut chunks = slice(offset, 34 * 8 + 10).chunks(); + assert_eq!(want_second, chunks.next_back()); + assert_eq!(want_first, chunks.next_back()); + assert_eq!(None, chunks.next()); + + let mut chunks = slice(offset, 34 * 8 + 10).chunks(); + assert_eq!(want_second, chunks.next_back()); + assert_eq!(want_first, chunks.next()); + assert_eq!(None, 
chunks.next()); + } +} + +#[test] +fn test_pop() { + use alloc::string::String; + + const WANT: &str = concat!("11001110", "00011110", "00011111"); + const BYTES: [u8; 3] = [0b1100_1110, 0b0001_1110, 0b0001_1111]; + + fn test( + want: &str, + mut slice: Slice, + reverse: bool, + pop: fn(&mut Slice) -> Option, + ) { + let got = core::iter::from_fn(move || pop(&mut slice)) + .map(|bit| char::from(b'0' + u8::from(bit))) + .collect::(); + let want = if reverse { + want.chars().rev().collect() + } else { + String::from(want) + }; + assert_eq!(want, got); + } + + fn test_set(reverse: bool, pop: fn(&mut Slice) -> Option) { + for start in 0..8 { + for end in start..=24 { + let slice = + Slice::new(&BYTES[..], start as u8, (end - start) as u16); + test(&WANT[start..end], slice.unwrap(), reverse, pop); + } + } + } + + test_set(false, |slice| slice.pop_front()); + test_set(true, |slice| slice.pop_back()); +} diff --git a/sealable-trie/src/hash.rs b/sealable-trie/src/hash.rs new file mode 100644 index 00000000..b0d2cc1b --- /dev/null +++ b/sealable-trie/src/hash.rs @@ -0,0 +1,214 @@ +use base64::engine::general_purpose::STANDARD as BASE64_ENGINE; +use base64::Engine; +use sha2::Digest; + +/// A cryptographic hash. +#[derive( + Clone, + Default, + PartialEq, + Eq, + derive_more::AsRef, + derive_more::AsMut, + derive_more::From, + derive_more::Into, +)] +#[as_ref(forward)] +#[into(owned, ref, ref_mut)] +#[repr(transparent)] +pub struct CryptoHash(pub [u8; CryptoHash::LENGTH]); + +// TODO(mina86): Make the code generic such that CryptoHash::digest take generic +// argument for the hash to use. This would then mean that Trie, Proof and +// other objects which need to calculate hashes would need to take that argument +// as well. +impl CryptoHash { + /// Length in bytes of the cryptographic hash. + pub const LENGTH: usize = 32; + + /// Default hash value (all zero bits). 
+ pub const DEFAULT: CryptoHash = CryptoHash([0; 32]); + + /// Returns a builder which can be used to construct cryptographic hash by + /// digesting bytes. + #[inline] + pub fn builder() -> Builder { Builder::default() } + + /// Returns hash of given bytes. + #[inline] + pub fn digest(bytes: &[u8]) -> Self { + Self(sha2::Sha256::digest(bytes).into()) + } + + /// Returns hash of concatenation of given byte slices. + #[inline] + pub fn digest_vec(slices: &[&[u8]]) -> Self { + let mut builder = Self::builder(); + for slice in slices { + builder.update(slice); + } + builder.build() + } + + + /// Creates a new hash with given number encoded in its first bytes. + /// + /// This is meant for tests which need to use arbitrary hash values. + #[cfg(test)] + pub(crate) const fn test(num: usize) -> CryptoHash { + let mut buf = [0; Self::LENGTH]; + let num = (num as u32).to_be_bytes(); + let mut idx = 0; + while idx < buf.len() { + buf[idx] = num[idx % num.len()]; + idx += 1; + } + Self(buf) + } + + /// Returns whether the hash is all zero bits. Equivalent to comparing to + /// the default `CryptoHash` object. + #[inline] + pub fn is_zero(&self) -> bool { self.0.iter().all(|&byte| byte == 0) } + + /// Returns reference to the hash as slice of bytes. + #[inline] + pub fn as_slice(&self) -> &[u8] { &self.0[..] } +} + +impl core::fmt::Display for CryptoHash { + /// Encodes the hash as base64 and prints it as a string. + fn fmt(&self, fmtr: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + const ENCODED_LENGTH: usize = (CryptoHash::LENGTH + 2) / 3 * 4; + let mut buf = [0u8; ENCODED_LENGTH]; + let len = + BASE64_ENGINE.encode_slice(self.as_slice(), &mut buf[..]).unwrap(); + // SAFETY: base64 fills the buffer with ASCII characters only. + fmtr.write_str(unsafe { core::str::from_utf8_unchecked(&buf[..len]) }) + } +} + +impl core::fmt::Debug for CryptoHash { + /// Encodes the hash as base64 and prints it as a string. 
+ #[inline] + fn fmt(&self, fmtr: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + core::fmt::Display::fmt(self, fmtr) + } +} + +impl<'a> From<&'a [u8; CryptoHash::LENGTH]> for CryptoHash { + #[inline] + fn from(hash: &'a [u8; CryptoHash::LENGTH]) -> Self { + <&CryptoHash>::from(hash).clone() + } +} + +impl From<&'_ CryptoHash> for [u8; CryptoHash::LENGTH] { + #[inline] + fn from(hash: &'_ CryptoHash) -> Self { hash.0.clone() } +} + +impl<'a> From<&'a [u8; CryptoHash::LENGTH]> for &'a CryptoHash { + #[inline] + fn from(hash: &'a [u8; CryptoHash::LENGTH]) -> Self { + let hash = + (hash as *const [u8; CryptoHash::LENGTH]).cast::(); + // SAFETY: CryptoHash is repr(transparent) over [u8; CryptoHash::LENGTH] + // thus transmuting is safe. + unsafe { &*hash } + } +} + +impl<'a> From<&'a mut [u8; CryptoHash::LENGTH]> for &'a mut CryptoHash { + #[inline] + fn from(hash: &'a mut [u8; CryptoHash::LENGTH]) -> Self { + let hash = (hash as *mut [u8; CryptoHash::LENGTH]).cast::(); + // SAFETY: CryptoHash is repr(transparent) over [u8; CryptoHash::LENGTH] + // thus transmuting is safe. + unsafe { &mut *hash } + } +} + +impl<'a> TryFrom<&'a [u8]> for &'a CryptoHash { + type Error = core::array::TryFromSliceError; + + #[inline] + fn try_from(hash: &'a [u8]) -> Result { + <&[u8; CryptoHash::LENGTH]>::try_from(hash).map(Into::into) + } +} + +impl<'a> TryFrom<&'a mut [u8]> for &'a mut CryptoHash { + type Error = core::array::TryFromSliceError; + + #[inline] + fn try_from(hash: &'a mut [u8]) -> Result { + <&mut [u8; CryptoHash::LENGTH]>::try_from(hash).map(Into::into) + } +} + +impl<'a> TryFrom<&'a [u8]> for CryptoHash { + type Error = core::array::TryFromSliceError; + + #[inline] + fn try_from(hash: &'a [u8]) -> Result { + <&CryptoHash>::try_from(hash).map(Clone::clone) + } +} + +/// Builder for the cryptographic hash. +/// +/// The builder calculates the digest of bytes that it’s fed using the +/// [`Builder::update`] method. 
+/// +/// This is useful if there are multiple discontiguous buffers that hold the +/// data to be hashed. If all data is in a single contiguous buffer it’s more +/// convenient to use [`CryptoHash::digest`] instead. +#[derive(Default)] +pub struct Builder(sha2::Sha256); + +impl Builder { + /// Process data, updating the internal state of the digest. + #[inline] + pub fn update(&mut self, bytes: &[u8]) { self.0.update(bytes) } + + /// Finalises the digest and returns the cryptographic hash. + #[inline] + pub fn build(self) -> CryptoHash { CryptoHash(self.0.finalize().into()) } +} + +#[test] +fn test_new_hash() { + assert_eq!(CryptoHash::from([0; 32]), CryptoHash::default()); + + // https://www.di-mgt.com.au/sha_testvectors.html + let want = CryptoHash::from([ + 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, 0x9a, 0xfb, 0xf4, 0xc8, + 0x99, 0x6f, 0xb9, 0x24, 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, + 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55, + ]); + assert_eq!(want, CryptoHash::digest(b"")); + assert_eq!(want, CryptoHash::builder().build()); + let got = { + let mut builder = CryptoHash::builder(); + builder.update(b""); + builder.build() + }; + assert_eq!(want, got); + assert_eq!(want, CryptoHash::builder().build()); + + let want = CryptoHash::from([ + 0xba, 0x78, 0x16, 0xbf, 0x8f, 0x01, 0xcf, 0xea, 0x41, 0x41, 0x40, 0xde, + 0x5d, 0xae, 0x22, 0x23, 0xb0, 0x03, 0x61, 0xa3, 0x96, 0x17, 0x7a, 0x9c, + 0xb4, 0x10, 0xff, 0x61, 0xf2, 0x00, 0x15, 0xad, + ]); + assert_eq!(want, CryptoHash::digest(b"abc")); + assert_eq!(want, CryptoHash::digest_vec(&[b"a", b"bc"])); + let got = { + let mut builder = CryptoHash::builder(); + builder.update(b"a"); + builder.update(b"bc"); + builder.build() + }; + assert_eq!(want, got); +} diff --git a/sealable-trie/src/lib.rs b/sealable-trie/src/lib.rs new file mode 100644 index 00000000..49500c3b --- /dev/null +++ b/sealable-trie/src/lib.rs @@ -0,0 +1,17 @@ +#![no_std] +extern crate alloc; +#[cfg(test)] +extern crate std; + +pub 
mod bits; +pub mod hash; +pub mod memory; +pub mod nodes; +pub mod proof; +pub(crate) mod stdx; +pub mod trie; + +#[cfg(test)] +mod test_utils; + +pub use trie::Trie; diff --git a/sealable-trie/src/memory.rs b/sealable-trie/src/memory.rs new file mode 100644 index 00000000..16492844 --- /dev/null +++ b/sealable-trie/src/memory.rs @@ -0,0 +1,408 @@ +use alloc::vec::Vec; +use core::fmt; +use core::num::NonZeroU32; + +use crate::nodes::RawNode; + +/// A pointer value. The value is 30-bit and always non-zero. +#[derive( + Copy, + Clone, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + derive_more::Into, + derive_more::Deref, +)] +#[into(owned, ref, ref_mut)] +#[repr(transparent)] +pub struct Ptr(NonZeroU32); + +#[derive(Copy, Clone, Debug, PartialEq, Eq, derive_more::Display)] +pub struct OutOfMemory; + +#[derive(Copy, Clone, Debug, PartialEq, Eq, derive_more::Display)] +pub struct AddressTooLarge(pub NonZeroU32); + +impl Ptr { + /// Largest value that can be stored in the pointer. + // The two most significant bits are used internally in RawNode encoding + // thus the max value is 30-bit. + const MAX: u32 = u32::MAX >> 2; + + /// Constructs a new pointer from given address. + /// + /// If the value is zero, returns `None` indicating a null pointer. If + /// value fits in 30 bits, returns a new `Ptr` with that value. Otherwise + /// returns an error with the argument. 
+ /// + /// ## Example + /// + /// ``` + /// # use core::num::NonZeroU32; + /// # use sealable_trie::memory::*; + /// + /// assert_eq!(Ok(None), Ptr::new(0)); + /// assert_eq!(42, Ptr::new(42).unwrap().unwrap().get()); + /// assert_eq!((1 << 30) - 1, + /// Ptr::new((1 << 30) - 1).unwrap().unwrap().get()); + /// assert_eq!(Err(AddressTooLarge(NonZeroU32::new(1 << 30).unwrap())), + /// Ptr::new(1 << 30)); + /// ``` + pub const fn new(ptr: u32) -> Result, AddressTooLarge> { + // Using match so the function is const + match NonZeroU32::new(ptr) { + None => Ok(None), + Some(num) if num.get() <= Self::MAX => Ok(Some(Self(num))), + Some(num) => Err(AddressTooLarge(num)), + } + } + + /// Constructs a new pointer from given address. + /// + /// Two most significant bits of the address are masked out thus ensuring + /// that the value is never too large. + pub(crate) fn new_truncated(ptr: u32) -> Option { + NonZeroU32::new(ptr & Self::MAX).map(Self) + } +} + +impl TryFrom for Ptr { + type Error = AddressTooLarge; + + /// Constructs a new pointer from given non-zero address. + /// + /// If the address is too large (see [`Ptr::MAX`]) returns an error with the + /// address which has been passed. 
+ /// + /// ## Example + /// + /// ``` + /// # use core::num::NonZeroU32; + /// # use sealable_trie::memory::*; + /// + /// let answer = NonZeroU32::new(42).unwrap(); + /// assert_eq!(42, Ptr::try_from(answer).unwrap().get()); + /// + /// let large = NonZeroU32::new(1 << 30).unwrap(); + /// assert_eq!(Err(AddressTooLarge(large)), Ptr::try_from(large)); + /// ``` + fn try_from(num: NonZeroU32) -> Result { + if num.get() <= Ptr::MAX { + Ok(Ptr(num)) + } else { + Err(AddressTooLarge(num)) + } + } +} + +impl fmt::Display for Ptr { + fn fmt(&self, fmtr: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.get().fmt(fmtr) + } +} + +impl fmt::Debug for Ptr { + fn fmt(&self, fmtr: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.get().fmt(fmtr) + } +} + +/// An interface for memory management used by the trie. +pub trait Allocator { + /// Allocates a new block and initialises it to given value. + fn alloc(&mut self, value: RawNode) -> Result; + + /// Returns value stored at given pointer. + /// + /// May panic or return garbage if `ptr` is invalid. + fn get(&self, ptr: Ptr) -> RawNode; + + /// Sets value at given pointer. + fn set(&mut self, ptr: Ptr, value: RawNode); + + /// Frees a block. + fn free(&mut self, ptr: Ptr); +} + +/// A write log which can be committed or rolled back. +/// +/// Rather than writing data directly to the allocator, it keeps all changes in +/// memory. When committing, the changes are then applied. Similarly, a list of +/// all allocated nodes is kept and during rollback all of those nodes are +/// freed. +/// +/// **Note:** *All* reads are passed directly to the underlying allocator. This +/// means reading a node that has been written to will return the old result. +/// +/// Note that the write log doesn’t offer isolation. Most notably, writes to +/// the allocator performed outside of the write log are visible when accessing +/// the nodes via the write log.
(To indicate that, the API doesn’t offer `get` +/// method and instead all reads need to go through the underlying allocator). +/// +/// Secondly, allocations done via the write log are visible outside of the +/// write log. The assumption is that nothing outside of the client of the +/// write log knows the pointer thus in practice they cannot refer to those +/// allocated but not-yet-committed nodes. +pub struct WriteLog<'a, A: Allocator> { + /// Allocator to pass requests to. + alloc: &'a mut A, + + /// List of changes in the transaction. + write_log: Vec<(Ptr, RawNode)>, + + /// List pointers to nodes allocated during the transaction. + allocated: Vec, + + /// List of nodes freed during the transaction. + freed: Vec, +} + +impl<'a, A: Allocator> WriteLog<'a, A> { + pub fn new(alloc: &'a mut A) -> Self { + Self { + alloc, + write_log: Vec::new(), + allocated: Vec::new(), + freed: Vec::new(), + } + } + + /// Commit all changes to the allocator. + /// + /// There’s no explicit rollback method. To roll changes back, drop the + /// object. + pub fn commit(mut self) { + self.allocated.clear(); + for (ptr, value) in self.write_log.drain(..) { + self.alloc.set(ptr, value) + } + for ptr in self.freed.drain(..) { + self.alloc.free(ptr) + } + } + + /// Returns underlying allocator. + pub fn allocator(&self) -> &A { &*self.alloc } + + pub fn alloc(&mut self, value: RawNode) -> Result { + let ptr = self.alloc.alloc(value)?; + self.allocated.push(ptr); + Ok(ptr) + } + + pub fn set(&mut self, ptr: Ptr, value: RawNode) { + self.write_log.push((ptr, value)) + } + + pub fn free(&mut self, ptr: Ptr) { self.freed.push(ptr); } +} + +impl<'a, A: Allocator> core::ops::Drop for WriteLog<'a, A> { + fn drop(&mut self) { + self.write_log.clear(); + self.freed.clear(); + for ptr in self.allocated.drain(..) 
{ + self.alloc.free(ptr) + } + } +} + +#[cfg(test)] +pub(crate) mod test_utils { + use super::*; + use crate::stdx; + + pub struct TestAllocator { + count: usize, + free: Option, + pool: alloc::vec::Vec, + allocated: std::collections::HashMap, + } + + impl TestAllocator { + pub fn new(capacity: usize) -> Self { + let max_cap = usize::try_from(Ptr::MAX).unwrap_or(usize::MAX); + let capacity = capacity.min(max_cap); + let mut pool = alloc::vec::Vec::with_capacity(capacity); + pool.push(RawNode([0xAA; 72])); + Self { count: 0, free: None, pool, allocated: Default::default() } + } + + pub fn count(&self) -> usize { self.count } + + /// Verifies that block at `ptr` is currently allocated and returns its pool index. Panics (naming `action` in the message) if it was never allocated or has been freed. + fn check_allocated(&self, action: &str, ptr: Ptr) -> usize { + let adj = match self.allocated.get(&ptr.get()).copied() { + None => "unallocated", + Some(false) => "freed", + Some(true) => return usize::try_from(ptr.get()).unwrap(), + }; + panic!("Tried to {action} {adj} block at {ptr}") + } + } + + impl Allocator for TestAllocator { + fn alloc(&mut self, value: RawNode) -> Result { + let ptr = if let Some(ptr) = self.free { + // Grab node from free list + let node = &mut self.pool[ptr.get() as usize]; + let bytes = stdx::split_array_ref::<4, 68, 72>(&node.0).0; + self.free = Ptr::new(u32::from_ne_bytes(*bytes)).unwrap(); + *node = value; + ptr + } else if self.pool.len() < self.pool.capacity() { + // Grab new node + self.pool.push(value); + Ptr::new((self.pool.len() - 1) as u32).unwrap().unwrap() + } else { + // No free node to allocate + return Err(OutOfMemory); + }; + + assert!( + self.allocated.insert(ptr.get(), true) != Some(true), + "Internal error: Allocated the same block twice at {ptr}", + ); + self.count += 1; + Ok(ptr) + } + + #[track_caller] + fn get(&self, ptr: Ptr) -> RawNode { + self.pool[self.check_allocated("read", ptr)].clone() + } + + #[track_caller] + fn set(&mut self, ptr: Ptr, value: RawNode) { + let idx = self.check_allocated("write", ptr); +
self.pool[idx] = value + } + + fn free(&mut self, ptr: Ptr) { + let idx = self.check_allocated("free", ptr); + self.allocated.insert(ptr.get(), false); + *stdx::split_array_mut::<4, 68, 72>(&mut self.pool[idx].0).0 = + self.free.map_or(0, |ptr| ptr.get()).to_ne_bytes(); + self.free = Some(ptr); + self.count -= 1; + } + } +} + +#[cfg(test)] +mod test_write_log { + use super::test_utils::TestAllocator; + use super::*; + use crate::hash::CryptoHash; + + fn make_allocator() -> (TestAllocator, Vec) { + let mut alloc = TestAllocator::new(100); + let ptrs = (0..10) + .map(|num| alloc.alloc(make_node(num)).unwrap()) + .collect::>(); + assert_nodes(10, &alloc, &ptrs, 0); + (alloc, ptrs) + } + + fn make_node(num: usize) -> RawNode { + let hash = CryptoHash::test(num); + let child = crate::nodes::Reference::node(None, &hash); + RawNode::branch(child, child) + } + + #[track_caller] + fn assert_nodes( + count: usize, + alloc: &TestAllocator, + ptrs: &[Ptr], + offset: usize, + ) { + assert_eq!(count, alloc.count()); + for (idx, ptr) in ptrs.iter().enumerate() { + assert_eq!( + make_node(idx + offset), + alloc.get(*ptr), + "Invalid value when reading {ptr}" + ); + } + } + + #[test] + fn test_set_commit() { + let (mut alloc, ptrs) = make_allocator(); + let mut wlog = WriteLog::new(&mut alloc); + for (idx, &ptr) in ptrs.iter().take(5).enumerate() { + wlog.set(ptr, make_node(idx + 10)); + } + assert_nodes(10, wlog.allocator(), &ptrs, 0); + wlog.commit(); + assert_nodes(10, &alloc, &ptrs[..5], 10); + assert_nodes(10, &alloc, &ptrs[5..], 5); + } + + #[test] + fn test_set_rollback() { + let (mut alloc, ptrs) = make_allocator(); + let mut wlog = WriteLog::new(&mut alloc); + for (idx, &ptr) in ptrs.iter().take(5).enumerate() { + wlog.set(ptr, make_node(idx + 10)); + } + assert_nodes(10, wlog.allocator(), &ptrs, 0); + core::mem::drop(wlog); + assert_nodes(10, &alloc, &ptrs, 0); + } + + #[test] + fn test_alloc_commit() { + let (mut alloc, ptrs) = make_allocator(); + let mut wlog = 
WriteLog::new(&mut alloc); + let new_ptrs = (10..20) + .map(|num| wlog.alloc(make_node(num)).unwrap()) + .collect::>(); + assert_nodes(20, &wlog.allocator(), &ptrs, 0); + assert_nodes(20, &wlog.allocator(), &new_ptrs, 10); + wlog.commit(); + assert_nodes(20, &alloc, &ptrs, 0); + assert_nodes(20, &alloc, &new_ptrs, 10); + } + + #[test] + fn test_alloc_rollback() { + let (mut alloc, ptrs) = make_allocator(); + let mut wlog = WriteLog::new(&mut alloc); + let new_ptrs = (10..20) + .map(|num| wlog.alloc(make_node(num)).unwrap()) + .collect::>(); + assert_nodes(20, &wlog.allocator(), &ptrs, 0); + assert_nodes(20, &wlog.allocator(), &new_ptrs, 10); + core::mem::drop(wlog); + assert_nodes(10, &alloc, &ptrs, 0); + } + + #[test] + fn test_free_commit() { + let (mut alloc, ptrs) = make_allocator(); + let mut wlog = WriteLog::new(&mut alloc); + for num in 5..10 { + wlog.free(ptrs[num]); + } + assert_nodes(10, wlog.allocator(), &ptrs, 0); + wlog.commit(); + assert_nodes(5, &alloc, &ptrs[..5], 0); + } + + #[test] + fn test_free_rollback() { + let (mut alloc, ptrs) = make_allocator(); + let mut wlog = WriteLog::new(&mut alloc); + for num in 5..10 { + wlog.free(ptrs[num]); + } + assert_nodes(10, wlog.allocator(), &ptrs, 0); + core::mem::drop(wlog); + assert_nodes(10, &alloc, &ptrs, 0); + } +} diff --git a/sealable-trie/src/nodes.rs b/sealable-trie/src/nodes.rs new file mode 100644 index 00000000..88ddae53 --- /dev/null +++ b/sealable-trie/src/nodes.rs @@ -0,0 +1,584 @@ +use crate::bits::Slice; +use crate::hash::CryptoHash; +use crate::memory::Ptr; +use crate::{bits, stdx}; + +#[cfg(test)] +mod stress_tests; +#[cfg(test)] +mod tests; + +pub(crate) const MAX_EXTENSION_KEY_SIZE: usize = 34; + +type Result = core::result::Result; + +/// A trie node. +/// +/// There are three types of nodes: branches, extensions and values. +/// +/// A branch node has two children which reference other nodes (both are always +/// present). 
+/// +/// An extension represents a path in a node which doesn’t branch. For example, +/// if trie contains key 0 and 1 then the root node will be an extension with +/// 0b0000_000 as the key and a branch node as a child. +/// +/// The space for key in extension node is limited (max 34 bytes), if longer key +/// is needed, an extension node may point at another extension node. +/// +/// A value node holds hash of the stored value at the key. Furthermore, if the +/// key is a prefix it stores a reference to another node which continues the +/// key. This reference is never a value reference. +/// +/// A node reference either points at another Node or is hash of the stored +/// value. The reference is represented by the [`Reference`] type. +/// +/// A [`Node`] object can be constructed from a [`RawNode`] (see [`RawNode::decode`]). +/// +/// The generic argument `P` specifies how pointers to nodes are represented and +/// `S` specifies how value being sealed or not is encoded. To represent value +/// parsed from a raw node representation, those types should be `Option<Ptr>` +/// and `bool` respectively. However, when dealing with proofs, pointer and +/// seal information is not available thus both of those types should be a unit +/// type. +#[derive(Clone, Copy, Debug)] +pub enum Node<'a, P = Option, S = bool> { + Branch { + /// Children of the branch. Both are always set. + children: [Reference<'a, P, S>; 2], + }, + Extension { + /// Key of the extension. + key: Slice<'a>, + /// Child node or value pointed by the extension. + child: Reference<'a, P, S>, + }, + Value { + value: ValueRef<'a, S>, + child: NodeRef<'a, P>, + }, +} + +/// Binary representation of the node as kept in the persistent storage. +/// +/// This representation is compact and includes internal details needed to +/// maintain the data-structure which shouldn’t be leaked to the clients of the +/// library and which don’t take part in hashing of the node. +// +// ```ignore +// Branch: <ref-1> <ref-2> +// A branch holds two references.
Both of them are always set. Note that +// reference’s most significant bit is always zero thus the first bit of +// a node representation distinguishes whether node is a branch or not. +// +// Extension: 1000_kkkk kkkk_kooo <key> <ref> +// `kkkk` is the length of the key in bits and `ooo` is the number of most +// significant bits in <key> to skip before getting to the key. <key> is a +// 34-byte array which holds the key extension. Only `o..o+k` bits in it +// are the actual key; others are set to zero. +// +// Value: 11s0_0000 0000_0000 0000_0000 0000_0000 <vhash> <node-ref> +// <vhash> is the hash of the stored value. `s` is zero if the value hasn’t +// been sealed, one otherwise. <node-ref> is a reference to the child node +// which points to the subtrie rooted at the key of the value. Value node +// can only point at Branch or Extension node. +// ``` +// +// A Reference is a 36-byte sequence consisting of a 4-byte pointer and +// a 32-byte hash. The most significant bit of the pointer is always zero (this +// is so that Branch nodes can be distinguished from other nodes). The second +// most significant bit is zero if the reference is a node reference and one if +// it’s a value reference. +// +// ```ignore +// Node Ref: 0b00pp_pppp pppp_pppp pppp_pppp pppp_pppp <hash> +// `ppp` is the pointer to the node. If it’s zero then the node is sealed +// and it’s not stored anywhere. +// +// Value Ref: 0b01s0_0000 0000_0000 0000_0000 0000_0000 <hash> +// `s` determines whether the value is sealed or not. If it is, it cannot be +// changed. +// ``` +// +// The actual pointer value is therefore 30-bit long. +#[derive(Clone, Copy, PartialEq, derive_more::Deref)] +#[repr(transparent)] +pub struct RawNode(pub(crate) [u8; 72]); + +/// Reference either to a node or a value as held in Branch or Extension nodes. +/// +/// See [`Node`] documentation for meaning of `P` and `S` generic arguments.
+#[derive(Clone, Copy, Debug, derive_more::From, derive_more::TryInto)] +pub enum Reference<'a, P = Option, S = bool> { + Node(NodeRef<'a, P>), + Value(ValueRef<'a, S>), +} + +/// Reference to a node as held in Value node. +/// +/// See [`Node`] documentation for meaning of the `P` generic argument. +#[derive(Clone, Copy, Debug)] +pub struct NodeRef<'a, P = Option> { + pub hash: &'a CryptoHash, + pub ptr: P, +} + +/// Reference to a value as held in Value node. +/// +/// See [`Node`] documentation for meaning of the `S` generic argument. +#[derive(Clone, Copy, Debug)] +pub struct ValueRef<'a, S = bool> { + pub hash: &'a CryptoHash, + pub is_sealed: S, +} + + +// ============================================================================= +// Implementations + +impl<'a, P, S> Node<'a, P, S> { + /// Constructs a Branch node with specified children. + pub fn branch( + left: Reference<'a, P, S>, + right: Reference<'a, P, S>, + ) -> Self { + Self::Branch { children: [left, right] } + } + + /// Constructs an Extension node with given key and child. + /// + /// Note that length of the key is not checked. It’s possible to create + /// a node which cannot be encoded either in raw or proof format. For an + /// Extension node to be able to be encoded, the key’s underlying bytes + /// slice must not exceed [`MAX_EXTENSION_KEY_SIZE`] bytes. + pub fn extension(key: Slice<'a>, child: Reference<'a, P, S>) -> Self { + Self::Extension { key, child } + } + + /// Constructs a Value node with given value hash and child. + pub fn value(value: ValueRef<'a, S>, child: NodeRef<'a, P>) -> Self { + Self::Value { value, child } + } + + /// Returns a hash of the node. + /// + /// Hash changes if and only if the value of the node (if any) and any child + /// node changes. Sealing or changing pointer value in a node reference + /// doesn’t count as changing the node. 
+ pub fn hash(&self) -> CryptoHash { + let mut buf = [0; 68]; + + fn tag_hash_hash( + buf: &mut [u8; 68], + tag: u8, + lft: &CryptoHash, + rht: &CryptoHash, + ) -> usize { + let buf = stdx::split_array_mut::<65, 3, 68>(buf).0; + let (t, rest) = stdx::split_array_mut::<1, 64, 65>(buf); + let (l, r) = stdx::split_array_mut(rest); + *t = [tag]; + *l = lft.0; + *r = rht.0; + buf.len() + } + + let len = match self { + Node::Branch { children: [left, right] } => { + // tag = 0b0000_00xy where x and y indicate whether left and + // right children respectively are value references. + let tag = (u8::from(left.is_value()) << 1) | + u8::from(right.is_value()); + tag_hash_hash(&mut buf, tag, left.hash(), right.hash()) + } + Node::Value { value, child } => { + tag_hash_hash(&mut buf, 0xC0, &value.hash, &child.hash) + } + Node::Extension { key, child } => { + let key_buf = stdx::split_array_mut::<36, 32, 68>(&mut buf).0; + // tag = 0b100v_0000 where v indicates whether the child is + // a value reference. + let tag = 0x80 | (u8::from(child.is_value()) << 4); + if let Some(len) = key.encode_into(key_buf, tag) { + buf[len..len + 32].copy_from_slice(child.hash().as_slice()); + len + 32 + } else { + return hash_extension_slow_path(*key, child); + } + } + }; + CryptoHash::digest(&buf[..len]) + } +} + +impl<'a> Node<'a> { + /// Builds raw representation of given node. + /// + /// Returns an error if this node is an Extension with a key of invalid + /// length (either empty or too long). + pub fn encode(&self) -> Result { + match self { + Node::Branch { children: [left, right] } => { + Ok(RawNode::branch(*left, *right)) + } + Node::Extension { key, child } => { + RawNode::extension(*key, *child).ok_or(()) + } + Node::Value { value, child } => Ok(RawNode::value(*value, *child)), + } + } +} + +/// Hashes an Extension node with oversized key. +/// +/// Normally, this is never called since we should calculate hashes of nodes +/// whose keys fit in the [`MAX_EXTENSION_KEY_SIZE`] limit. 
However, to +/// avoid having to handle errors we use this slow path to calculate hashes +/// for nodes with longer keys. +#[cold] +fn hash_extension_slow_path( + key: bits::Slice, + child: &Reference, +) -> CryptoHash { + let mut builder = CryptoHash::builder(); + // tag = 0b100v_0000 where v indicates whether the child is a value + // reference. + let tag = 0x80 | (u8::from(child.is_value()) << 4); + key.write_into(|bytes| builder.update(bytes), tag); + builder.update(child.hash().as_slice()); + builder.build() +} + +impl RawNode { + /// Constructs a Branch node with specified children. + pub fn branch(left: Reference, right: Reference) -> Self { + let mut res = Self([0; 72]); + let (lft, rht) = res.halfs_mut(); + *lft = left.encode(); + *rht = right.encode(); + res + } + + /// Constructs an Extension node with given key and child. + /// + /// Fails and returns `None` if the key is empty or its underlying bytes + /// slice is too long. The slice must not exceed [`MAX_EXTENSION_KEY_SIZE`] + /// to be valid. + pub fn extension(key: Slice, child: Reference) -> Option { + let mut res = Self([0; 72]); + let (lft, rht) = res.halfs_mut(); + key.encode_into(lft, 0x80)?; + *rht = child.encode(); + Some(res) + } + + /// Constructs a Value node with given value hash and child. + pub fn value(value: ValueRef, child: NodeRef) -> Self { + let mut res = Self([0; 72]); + let (lft, rht) = res.halfs_mut(); + *lft = Reference::Value(value).encode(); + lft[0] |= 0x80; + *rht = Reference::Node(child).encode(); + res + } + + /// Decodes raw node into a [`Node`]. + /// + /// In debug builds panics if `node` holds malformed representation, i.e. if + /// any unused bits (which must be cleared) are set. + // TODO(mina86): Convert debug_assertions to the method returning Result. + pub fn decode(&self) -> Node { + let (left, right) = self.halfs(); + let right = Reference::from_raw(right, false); + // `>> 6` to grab the two most significant bits only. 
+ let tag = self.first() >> 6; + if tag == 0 || tag == 1 { + // Branch + Node::Branch { children: [Reference::from_raw(left, false), right] } + } else if tag == 2 { + // Extension + let key = Slice::decode(left, 0x80).unwrap_or_else(|| { + panic!("Failed decoding raw: {self:?}"); + }); + Node::Extension { key, child: right } + } else { + // Value + let (num, value) = stdx::split_array_ref::<4, 32, 36>(left); + let num = u32::from_be_bytes(*num); + debug_assert_eq!( + 0xC000_0000, + num & !0x2000_0000, + "Failed decoding raw node: {self:?}", + ); + let value = ValueRef::new(num & 0x2000_0000 != 0, value.into()); + let child = right.try_into().unwrap_or_else(|_| { + debug_assert!(false, "Failed decoding raw node: {self:?}"); + NodeRef::new(None, &CryptoHash::DEFAULT) + }); + Node::Value { value, child } + } + } + + /// Returns the first byte in the raw representation. + fn first(&self) -> u8 { self.0[0] } + + /// Splits the raw byte representation in two halfs. + fn halfs(&self) -> (&[u8; 36], &[u8; 36]) { + stdx::split_array_ref::<36, 36, 72>(&self.0) + } + + /// Splits the raw byte representation in two halfs. + fn halfs_mut(&mut self) -> (&mut [u8; 36], &mut [u8; 36]) { + stdx::split_array_mut::<36, 36, 72>(&mut self.0) + } +} + +impl<'a, P, S> Reference<'a, P, S> { + /// Returns whether the reference is to a node. + pub fn is_node(&self) -> bool { matches!(self, Self::Node(_)) } + + /// Returns whether the reference is to a value. + pub fn is_value(&self) -> bool { matches!(self, Self::Value(_)) } + + /// Returns node’s or value’s hash depending on type of reference. + /// + /// Use [`Self::is_value`] and [`Self::is_proof`] to check whether + fn hash(&self) -> &'a CryptoHash { + match self { + Self::Node(node) => node.hash, + Self::Value(value) => value.hash, + } + } +} + +impl<'a> Reference<'a> { + /// Creates a new reference pointing at given node. 
+ #[inline] + pub fn node(ptr: Option, hash: &'a CryptoHash) -> Self { + Self::Node(NodeRef::new(ptr, hash)) + } + + /// Creates a new reference pointing at value with given hash. + #[inline] + pub fn value(is_sealed: bool, hash: &'a CryptoHash) -> Self { + Self::Value(ValueRef::new(is_sealed, hash)) + } + + /// Returns whether the reference is to a sealed node or value. + #[inline] + pub fn is_sealed(&self) -> bool { + match self { + Self::Node(node) => node.ptr.is_none(), + Self::Value(value) => value.is_sealed, + } + } + + /// Parses bytes to form a raw node reference representation. + /// + /// Assumes that the bytes are trusted. I.e. doesn’t verify that the most + /// significant bit is zero or that if second bit is one than pointer value + /// must be zero. + /// + /// In debug builds, panics if `bytes` has non-canonical representation, + /// i.e. any unused bits are set. `value_high_bit` in this case determines + /// whether for value reference the most significant bit should be set or + /// not. This is to facilitate decoding Value nodes. The argument is + /// ignored in builds with debug assertions disabled. + // TODO(mina86): Convert debug_assertions to the method returning Result. + fn from_raw(bytes: &'a [u8; 36], value_high_bit: bool) -> Self { + let (ptr, hash) = stdx::split_array_ref::<4, 32, 36>(bytes); + let ptr = u32::from_be_bytes(*ptr); + let hash = hash.into(); + if ptr & 0x4000_0000 == 0 { + // The two most significant bits must be zero. + debug_assert_eq!( + 0, + ptr & 0xC000_0000, + "Failed decoding Reference: {bytes:?}" + ); + let ptr = Ptr::new_truncated(ptr); + Self::Node(NodeRef { ptr, hash }) + } else { + // * The most significant bit is set only if value_high_bit is true. + // * The second most significant bit (so 0b4000_0000) is always set. + // * The third most significant bit (so 0b2000_0000) specifies + // whether value is sealed. 
+ debug_assert_eq!( + 0x4000_0000 | (u32::from(value_high_bit) << 31), + ptr & !0x2000_0000, + "Failed decoding Reference: {bytes:?}" + ); + let is_sealed = ptr & 0x2000_0000 != 0; + Self::Value(ValueRef { is_sealed, hash }) + } + } + + /// Encodes the node reference into the buffer. + fn encode(&self) -> [u8; 36] { + let (num, hash) = match self { + Self::Node(node) => { + (node.ptr.map_or(0, |ptr| ptr.get()), node.hash) + } + Self::Value(value) => { + (0x4000_0000 | (u32::from(value.is_sealed) << 29), value.hash) + } + }; + let mut buf = [0; 36]; + let (left, right) = stdx::split_array_mut::<4, 32, 36>(&mut buf); + *left = num.to_be_bytes(); + *right = hash.into(); + buf + } +} + +impl<'a> Reference<'a, (), ()> { + pub fn new(is_value: bool, hash: &'a CryptoHash) -> Self { + match is_value { + false => NodeRef::new((), hash).into(), + true => ValueRef::new((), hash).into(), + } + } +} + +impl<'a, P> NodeRef<'a, P> { + /// Constructs a new node reference. + #[inline] + pub fn new(ptr: P, hash: &'a CryptoHash) -> Self { Self { ptr, hash } } +} + +impl<'a> NodeRef<'a, Option> { + /// Returns sealed version of the reference. The hash remains unchanged. + #[inline] + pub fn sealed(self) -> Self { Self { ptr: None, hash: self.hash } } +} + +impl<'a, S> ValueRef<'a, S> { + /// Constructs a new node reference. + #[inline] + pub fn new(is_sealed: S, hash: &'a CryptoHash) -> Self { + Self { is_sealed, hash } + } +} + +impl<'a> ValueRef<'a, bool> { + /// Returns sealed version of the reference. The hash remains unchanged. + #[inline] + pub fn sealed(self) -> Self { Self { is_sealed: true, hash: self.hash } } +} + + +// ============================================================================= +// PartialEq + +// Are those impls dumb? Yes, they absolutely are. However, when I used +// #[derive(PartialEq)] I run into lifetime issues. +// +// My understanding is that derive would create implementation for the same +// lifetime on LHS and RHS types (e.g. 
`impl<'a> PartialEq> for +// Ref<'a>`). As a result, when comparing two objects Rust would try to match +// their lifetimes which wasn’t always possible. + +impl<'a, 'b, P, S> core::cmp::PartialEq> for Node<'a, P, S> +where + P: PartialEq, + S: PartialEq, +{ + fn eq(&self, rhs: &Node<'b, P, S>) -> bool { + match (self, rhs) { + ( + Node::Branch { children: lhs }, + Node::Branch { children: rhs }, + ) => lhs == rhs, + ( + Node::Extension { key: lhs_key, child: lhs_child }, + Node::Extension { key: rhs_key, child: rhs_child }, + ) => lhs_key == rhs_key && lhs_child == rhs_child, + ( + Node::Value { value: lhs_value, child: lhs }, + Node::Value { value: rhs_value, child: rhs }, + ) => lhs_value == rhs_value && lhs == rhs, + _ => false, + } + } +} + +impl<'a, 'b, P, S> core::cmp::PartialEq> + for Reference<'a, P, S> +where + P: PartialEq, + S: PartialEq, +{ + fn eq(&self, rhs: &Reference<'b, P, S>) -> bool { + match (self, rhs) { + (Reference::Node(lhs), Reference::Node(rhs)) => lhs == rhs, + (Reference::Value(lhs), Reference::Value(rhs)) => lhs == rhs, + _ => false, + } + } +} + +impl<'a, 'b, P> core::cmp::PartialEq> for NodeRef<'a, P> +where + P: PartialEq, +{ + fn eq(&self, rhs: &NodeRef<'b, P>) -> bool { + self.ptr == rhs.ptr && self.hash == rhs.hash + } +} + +impl<'a, 'b, S> core::cmp::PartialEq> for ValueRef<'a, S> +where + S: PartialEq, +{ + fn eq(&self, rhs: &ValueRef<'b, S>) -> bool { + self.is_sealed == rhs.is_sealed && self.hash == rhs.hash + } +} + +// ============================================================================= +// Formatting + +impl core::fmt::Debug for RawNode { + fn fmt(&self, fmtr: &mut core::fmt::Formatter) -> core::fmt::Result { + fn write_raw_key( + fmtr: &mut core::fmt::Formatter, + separator: &str, + bytes: &[u8; 36], + ) -> core::fmt::Result { + let (tag, key) = stdx::split_array_ref::<2, 34, 36>(bytes); + write!(fmtr, "{separator}{:04x}", u16::from_be_bytes(*tag))?; + write_binary(fmtr, ":", key) + } + + fn write_raw_ptr( + 
fmtr: &mut core::fmt::Formatter, + separator: &str, + bytes: &[u8; 36], + ) -> core::fmt::Result { + let (ptr, hash) = stdx::split_array_ref::<4, 32, 36>(bytes); + let ptr = u32::from_be_bytes(*ptr); + let hash = <&CryptoHash>::from(hash); + write!(fmtr, "{separator}{ptr:08x}:{hash}") + } + + let (left, right) = self.halfs(); + if self.first() & 0xC0 == 0x80 { + write_raw_key(fmtr, "", left) + } else { + write_raw_ptr(fmtr, "", left) + }?; + write_raw_ptr(fmtr, ":", right) + } +} + +fn write_binary( + fmtr: &mut core::fmt::Formatter, + mut separator: &str, + bytes: &[u8], +) -> core::fmt::Result { + for byte in bytes { + write!(fmtr, "{separator}{byte:02x}")?; + separator = "_"; + } + Ok(()) +} diff --git a/sealable-trie/src/nodes/stress_tests.rs b/sealable-trie/src/nodes/stress_tests.rs new file mode 100644 index 00000000..ba73d74e --- /dev/null +++ b/sealable-trie/src/nodes/stress_tests.rs @@ -0,0 +1,134 @@ +//! Random stress tests. They generate random data and perform round-trip +//! conversion checking if they result in the same output. +//! +//! The test may be slow, especially when run under MIRI. Number of iterations +//! it performs can be controlled by STRESS_TEST_ITERATIONS environment +//! variable. + +use pretty_assertions::assert_eq; + +use crate::memory::Ptr; +use crate::nodes::{self, Node, NodeRef, RawNode, Reference, ValueRef}; +use crate::test_utils::get_iteration_count; +use crate::{bits, stdx}; + +/// Generates random raw representation and checks decode→encode round-trip. +#[test] +fn stress_test_raw_encoding_round_trip() { + let mut rng = rand::thread_rng(); + let mut raw = RawNode([0; 72]); + for _ in 0..get_iteration_count() { + gen_random_raw_node(&mut rng, &mut raw.0); + let node = raw.decode(); + // Test RawNode→Node→RawNode round trip conversion. + assert_eq!(Ok(raw), node.encode(), "node: {node:?}"); + } +} + +/// Generates a random raw node representation in canonical representation. 
+fn gen_random_raw_node(rng: &mut impl rand::Rng, bytes: &mut [u8; 72]) { + fn make_ref_canonical(bytes: &mut [u8]) { + if bytes[0] & 0x40 == 0 { + // Node reference. Pointer can be non-zero. + bytes[0] &= !0x80; + } else { + // Value reference. Pointer must be zero but key is_sealed flag: + // 0b01s0_0000 + bytes[..4].copy_from_slice(&0x6000_0000_u32.to_be_bytes()); + } + } + + rng.fill(&mut bytes[..]); + let tag = bytes[0] >> 6; + if tag == 0 || tag == 1 { + // Branch. + make_ref_canonical(&mut bytes[..36]); + make_ref_canonical(&mut bytes[36..]); + } else if tag == 2 { + // Extension. Key must be valid and the most significant bit of + // the child must be zero. For the former it’s easiest to just + // regenerate random data. + + // Random length and offset for the key. + let offset = rng.gen::() % 8; + let max_length = (nodes::MAX_EXTENSION_KEY_SIZE * 8) as u16; + let length = rng.gen_range(1..=max_length - u16::from(offset)); + let tag = 0x8000 | (length << 3) | u16::from(offset); + bytes[..2].copy_from_slice(&tag.to_be_bytes()[..]); + + // Clear unused bits in the key. The easiest way to do it is by using + // bits::Slice. + let mut tmp = [0; 36]; + bits::Slice::new(&bytes[2..36], offset, length) + .unwrap() + .encode_into(&mut tmp, 0) + .unwrap(); + bytes[0..36].copy_from_slice(&tmp); + + make_ref_canonical(&mut bytes[36..]); + } else { + // Value. Most bits in the first four bytes must be zero and child must + // be a node reference. + bytes[0] &= 0xE0; + bytes[1] = 0; + bytes[2] = 0; + bytes[3] = 0; + bytes[36] &= !0xC0; + } +} + +// ============================================================================= + +/// Generates random node and tests encode→decode round trips. 
+#[test] +fn stress_test_node_encoding_round_trip() { + let mut rng = rand::thread_rng(); + let mut buf = [0; 66]; + for _ in 0..get_iteration_count() { + let node = gen_random_node(&mut rng, &mut buf); + + let raw = super::tests::raw_from_node(&node); + assert_eq!(node, raw.decode(), "Failed decoding Raw: {raw:?}"); + } +} + +/// Generates a random Node. +fn gen_random_node<'a>( + rng: &mut impl rand::Rng, + buf: &'a mut [u8; 66], +) -> Node<'a> { + fn rand_ref<'a>( + rng: &mut impl rand::Rng, + hash: &'a [u8; 32], + ) -> Reference<'a> { + let num = rng.gen::(); + if num < 0x8000_0000 { + Reference::node(Ptr::new(num).ok().flatten(), hash.into()) + } else { + Reference::value(num & 1 != 0, hash.into()) + } + } + + rng.fill(&mut buf[..]); + let (key, right) = stdx::split_array_ref::<34, 32, 66>(buf); + let (_, left) = stdx::split_array_ref::<2, 32, 34>(key); + match rng.gen_range(0..3) { + 0 => Node::branch(rand_ref(rng, &left), rand_ref(rng, &right)), + 1 => { + let offset = rng.gen::() % 8; + let max_length = (nodes::MAX_EXTENSION_KEY_SIZE * 8) as u16; + let length = rng.gen_range(1..=max_length - u16::from(offset)); + let key = bits::Slice::new(&key[..], offset, length).unwrap(); + Node::extension(key, rand_ref(rng, &right)) + } + 2 => { + let num = rng.gen::(); + let is_sealed = num & 0x8000_0000 != 0; + let value = ValueRef::new(is_sealed, left.into()); + let ptr = Ptr::new(num & 0x7FFF_FFFF).ok().flatten(); + let child = NodeRef::new(ptr, right.into()); + Node::value(value, child) + } + _ => unreachable!(), + } +} diff --git a/sealable-trie/src/nodes/tests.rs b/sealable-trie/src/nodes/tests.rs new file mode 100644 index 00000000..e1e8dd5e --- /dev/null +++ b/sealable-trie/src/nodes/tests.rs @@ -0,0 +1,251 @@ +use base64::engine::general_purpose::STANDARD as BASE64_ENGINE; +use base64::Engine; +use pretty_assertions::assert_eq; + +use crate::bits; +use crate::hash::CryptoHash; +use crate::memory::Ptr; +use crate::nodes::{Node, NodeRef, RawNode, Reference, 
ValueRef}; + +const DEAD: Ptr = match Ptr::new(0xDEAD) { + Ok(Some(ptr)) => ptr, + _ => panic!(), +}; +const BEEF: Ptr = match Ptr::new(0xBEEF) { + Ok(Some(ptr)) => ptr, + _ => panic!(), +}; + +const ONE: CryptoHash = CryptoHash([1; 32]); +const TWO: CryptoHash = CryptoHash([2; 32]); + +/// Converts `Node` into `RawNode` while also checking inverse conversion. +/// +/// Converts `Node` into `RawNode` and then back into `Node`. Panics if the +/// first and last objects aren’t equal. Returns the raw node. +#[track_caller] +pub(super) fn raw_from_node(node: &Node) -> RawNode { + let raw = node + .encode() + .unwrap_or_else(|()| panic!("Failed encoding node as raw: {node:?}")); + assert_eq!( + *node, + raw.decode(), + "Node → RawNode → Node gave different result:\n Raw: {raw:?}" + ); + raw +} + +/// Checks raw encoding of given node. +/// +/// 1. Encodes `node` into raw node node representation and compares the result +/// with expected `want` slices. +/// 2. Verifies Node→RawNode→Node round-trip conversion. +/// 3. Verifies that hash of the node equals the one provided. +/// 4. If node is an Extension, checks if slow path hash calculation produces +/// the same hash. +#[track_caller] +fn check_node_encoding(node: Node, want: [u8; 72], want_hash: &str) { + let raw = raw_from_node(&node); + assert_eq!(want, raw.0, "Unexpected raw representation"); + assert_eq!(node, RawNode(want).decode(), "Bad Raw→Node conversion"); + + let want_hash = BASE64_ENGINE.decode(want_hash).unwrap(); + let want_hash = <&[u8; 32]>::try_from(want_hash.as_slice()).unwrap(); + let want_hash = CryptoHash::from(*want_hash); + assert_eq!(want_hash, node.hash(), "Unexpected hash of {node:?}"); + + if let Node::Extension { key, child } = node { + let got = super::hash_extension_slow_path(key, &child); + assert_eq!(want_hash, got, "Unexpected slow path hash of {node:?}"); + } +} + +#[test] +#[rustfmt::skip] +fn test_branch_encoding() { + // Branch with two node children. 
+ check_node_encoding(Node::Branch { + children: [ + Reference::node(Some(DEAD), &ONE), + Reference::node(None, &TWO), + ], + }, [ + /* ptr1: */ 0, 0, 0xDE, 0xAD, + /* hash1: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr2: */ 0, 0, 0, 0, + /* hash2: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "MvstRBYGfFv/BkI+GHFK04hDZde4FtNKd7M1J9hDhiQ="); + + check_node_encoding(Node::Branch { + children: [ + Reference::node(None, &ONE), + Reference::node(Some(DEAD), &TWO), + ], + }, [ + /* ptr1: */ 0, 0, 0, 0, + /* hash1: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr2: */ 0, 0, 0xDE, 0xAD, + /* hash2: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "MvstRBYGfFv/BkI+GHFK04hDZde4FtNKd7M1J9hDhiQ="); + + // Branch with first child being a node and second being a value. + check_node_encoding(Node::Branch { + children: [ + Reference::node(Some(DEAD), &ONE), + Reference::value(false, &TWO), + ], + }, [ + /* ptr1: */ 0, 0, 0xDE, 0xAD, + /* hash1: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr2: */ 0x40, 0, 0, 0, + /* hash2: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "szHabsSdRUfZlCpnJ+USP2m+1aC5esFxz7/WIBQx/Po="); + + check_node_encoding(Node::Branch { + children: [ + Reference::node(None, &ONE), + Reference::value(true, &TWO), + ], + }, [ + /* ptr1: */ 0, 0, 0, 0, + /* hash1: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr2: */ 0x60, 0, 0, 0, + /* hash2: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "szHabsSdRUfZlCpnJ+USP2m+1aC5esFxz7/WIBQx/Po="); + + // Branch with first child being a value and second being a node. 
+ check_node_encoding(Node::Branch { + children: [ + Reference::value(true, &ONE), + Reference::node(Some(BEEF), &TWO), + ], + }, [ + /* ptr1: */ 0x60, 0, 0, 0, + /* hash1: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr2: */ 0, 0, 0xBE, 0xEF, + /* hash2: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "LGZgDJ1qtRlrhOX7OJQBVprw9OvP2sXOdj9Ow0xMQ18="); + + check_node_encoding(Node::Branch { + children: [ + Reference::value(false, &ONE), + Reference::node(None, &TWO), + ], + }, [ + /* ptr1: */ 0x40, 0, 0, 0, + /* hash1: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr2: */ 0, 0, 0, 0, + /* hash2: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "LGZgDJ1qtRlrhOX7OJQBVprw9OvP2sXOdj9Ow0xMQ18="); + + // Branch with both children being values. + check_node_encoding(Node::Branch { + children: [ + Reference::value(false, &ONE), + Reference::value(true, &TWO), + ], + }, [ + /* ptr1: */ 0x40, 0, 0, 0, + /* hash1: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr2: */ 0x60, 0, 0, 0, + /* hash2: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "O+AyRw5cqn52zppsf3w7xebru6xQ50qGvI7JgFQBNnE="); + + check_node_encoding(Node::Branch { + children: [ + Reference::value(true, &ONE), + Reference::value(false, &TWO), + ], + }, [ + /* ptr1: */ 0x60, 0, 0, 0, + /* hash1: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr2: */ 0x40, 0, 0, 0, + /* hash2: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "O+AyRw5cqn52zppsf3w7xebru6xQ50qGvI7JgFQBNnE="); +} + +#[test] +#[rustfmt::skip] +fn test_extension_encoding() { + // Extension pointing at a node + 
check_node_encoding(Node::Extension { + key: bits::Slice::new(&[0xFF; 34], 5, 25).unwrap(), + child: Reference::node(Some(DEAD), &ONE), + }, [ + /* tag: */ 0x80, 0xCD, + /* key: */ 0x07, 0xFF, 0xFF, 0xFC, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + /* ptr: */ 0, 0, 0xDE, 0xAD, + /* hash: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + ], "JnUeS7R/A/mp22ytw/gzGLu24zHArCmVZJoMm4bcqGY="); + + // Extension pointing at a sealed node + check_node_encoding(Node::Extension { + key: bits::Slice::new(&[0xFF; 34], 5, 25).unwrap(), + child: Reference::node(None, &ONE), + }, [ + /* tag: */ 0x80, 0xCD, + /* key: */ 0x07, 0xFF, 0xFF, 0xFC, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + /* ptr: */ 0, 0, 0, 0, + /* hash: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + ], "JnUeS7R/A/mp22ytw/gzGLu24zHArCmVZJoMm4bcqGY="); + + // Extension pointing at a value + check_node_encoding(Node::Extension { + key: bits::Slice::new(&[0xFF; 34], 4, 248).unwrap(), + child: Reference::value(false, &ONE), + }, [ + /* tag: */ 0x87, 0xC4, + /* key: */ 0x0F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + /* */ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xF0, + /* */ 0x00, 0x00, + /* ptr: */ 0x40, 0, 0, 0, + /* hash: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + ], "uU9GlH+fEQAnezn3HWuvo/ZSBIhuSkuE2IGjhUFdC04="); + + check_node_encoding(Node::Extension { + key: bits::Slice::new(&[0xFF; 34], 4, 248).unwrap(), + child: Reference::value(true, &ONE), + }, [ + /* tag: */ 0x87, 0xC4, + /* key: */ 0x0F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + /* */ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 
0xFF, 0xFF, 0xFF, 0xF0, + /* */ 0x00, 0x00, + /* ptr: */ 0x60, 0, 0, 0, + /* hash: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + ], "uU9GlH+fEQAnezn3HWuvo/ZSBIhuSkuE2IGjhUFdC04="); +} + +#[test] +#[rustfmt::skip] +fn test_value_encoding() { + check_node_encoding(Node::Value { + value: ValueRef::new(false, &ONE), + child: NodeRef::new(Some(BEEF), &TWO), + }, [ + /* tag: */ 0xC0, 0, 0, 0, + /* vhash: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr: */ 0, 0, 0xBE, 0xEF, + /* chash: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "1uLWUNQTQCTNVP3Wle2aK1vQlrOPXf9EC0J6TLl4hrY="); + + check_node_encoding(Node::Value { + value: ValueRef::new(true, &ONE), + child: NodeRef::new(Some(BEEF), &TWO), + }, [ + /* tag: */ 0xE0, 0, 0, 0, + /* vhash: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr: */ 0, 0, 0xBE, 0xEF, + /* chash: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "1uLWUNQTQCTNVP3Wle2aK1vQlrOPXf9EC0J6TLl4hrY="); + + check_node_encoding(Node::Value { + value: ValueRef::new(true, &ONE), + child: NodeRef::new(None, &TWO), + }, [ + /* tag: */ 0xE0, 0, 0, 0, + /* vhash: */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + /* ptr: */ 0, 0, 0, 0, + /* chash: */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], "1uLWUNQTQCTNVP3Wle2aK1vQlrOPXf9EC0J6TLl4hrY="); +} diff --git a/sealable-trie/src/proof.rs b/sealable-trie/src/proof.rs new file mode 100644 index 00000000..922a884c --- /dev/null +++ b/sealable-trie/src/proof.rs @@ -0,0 +1,399 @@ +use alloc::boxed::Box; +use alloc::vec::Vec; +use core::num::NonZeroU16; + +use crate::bits; +use crate::hash::CryptoHash; +use crate::nodes::{Node, NodeRef, 
Reference, ValueRef}; + +/// A proof of a membership or non-membership of a key. +/// +/// The proof doesn’t include the key or value (in case of existence proofs). +/// It’s caller responsibility to pair proof with correct key and value. +#[derive(Clone, Debug, derive_more::From)] +pub enum Proof { + Positive(Membership), + Negative(NonMembership), +} + +/// A proof of a membership of a key. +#[derive(Clone, Debug)] +pub struct Membership(Vec); + +/// A proof of a membership of a key. +#[derive(Clone, Debug)] +pub struct NonMembership(Option>, Vec); + +/// A single item in a proof corresponding to a node in the trie. +#[derive(Clone, Debug)] +pub(crate) enum Item { + /// A Branch node where the other child is given reference. + Branch(OwnedRef), + /// An Extension node whose key has given length in bits. + Extension(NonZeroU16), + /// A Value node. + Value(CryptoHash), +} + +/// For non-membership proofs, description of the condition at which the lookup +/// failed. +#[derive(Clone, Debug)] +pub(crate) enum Actual { + /// A Branch node that has been reached at given key. + Branch(OwnedRef, OwnedRef), + + /// Length of the lookup key remaining after reaching given Extension node + /// whose key doesn’t match the lookup key. + Extension(u16, Box<[u8]>, OwnedRef), + + /// Length of the lookup key remaining after reaching a value reference with + /// given value hash. + LookupKeyLeft(NonZeroU16, CryptoHash), +} + +/// A reference to value or node. +#[derive(Clone, Debug)] +pub(crate) struct OwnedRef { + /// Whether the reference is for a value (rather than value). + is_value: bool, + /// Hash of the node or value the reference points at. + hash: CryptoHash, +} + +/// Builder for the proof. +pub(crate) struct Builder(Vec); + +impl Proof { + /// Verifies that this object proves membership or non-membership of given + /// key. + /// + /// If `value_hash` is `None`, verifies a non-membership proof. 
That is, + /// that this object proves that given `root_hash` there’s no value at + /// specified `key`. + /// + /// Otherwise, verifies a membership-proof. That is, that this object + /// proves that given `root_hash` there’s given `value_hash` stored at + /// specified `key`. + pub fn verify( + &self, + root_hash: &CryptoHash, + key: &[u8], + value_hash: Option<&CryptoHash>, + ) -> bool { + match (self, value_hash) { + (Self::Positive(proof), Some(hash)) => { + proof.verify(root_hash, key, hash) + } + (Self::Negative(proof), None) => proof.verify(root_hash, key), + _ => false, + } + } + + /// Creates a non-membership proof for cases when trie is empty. + pub(crate) fn empty_trie() -> Proof { + NonMembership(None, Vec::new()).into() + } + + /// Creates a builder which allows creation of proofs. + pub(crate) fn builder() -> Builder { Builder(Vec::new()) } +} + +impl Membership { + /// Verifies that this object proves membership of a given key. + pub fn verify( + &self, + root_hash: &CryptoHash, + key: &[u8], + value_hash: &CryptoHash, + ) -> bool { + if *root_hash == crate::trie::EMPTY_TRIE_ROOT { + false + } else if let Some(key) = bits::Slice::from_bytes(key) { + let want = OwnedRef::value(value_hash.clone()); + verify_impl(root_hash, key, want, &self.0).is_some() + } else { + false + } + } +} + +impl NonMembership { + /// Verifies that this object proves non-membership of a given key. + pub fn verify(&self, root_hash: &CryptoHash, key: &[u8]) -> bool { + if *root_hash == crate::trie::EMPTY_TRIE_ROOT { + true + } else if let Some((key, want)) = self.get_reference(key) { + verify_impl(root_hash, key, want, &self.1).is_some() + } else { + false + } + } + + /// Figures out reference to prove. + /// + /// For non-membership proofs, the proofs include the actual node that has + /// been found while looking up the key. This translates that information + /// into a key and reference that the rest of the commitment needs to prove. 
+ fn get_reference<'a>( + &self, + key: &'a [u8], + ) -> Option<(bits::Slice<'a>, OwnedRef)> { + let mut key = bits::Slice::from_bytes(key)?; + match self.0.as_deref()? { + Actual::Branch(lft, rht) => { + // When traversing the trie, we’ve reached a Branch node at the + // lookup key. Lookup key is therefore a prefix of an existing + // value but there’s no value stored at it. + // + // We’re converting non-membership proof into proof that at key + // the given branch Node exists. + let node = Node::Branch { children: [lft.into(), rht.into()] }; + Some((key, OwnedRef::to(node))) + } + + Actual::Extension(left, key_buf, child) => { + // When traversing the trie, we’ve reached an Extension node + // whose key wasn’t a prefix of a lookup key. This could be + // because the extension key was longer or because some bits + // didn’t match. + // + // The first element specifies how many bits of the lookup key + // were left in it when the Extension node has been reached. + // + // We’re converting non-membership proof into proof that at + // shortened key the given Extension node exists. + let suffix = key.pop_back_slice(*left)?; + let ext_key = bits::Slice::decode(&key_buf[..], 0)?; + if suffix.starts_with(ext_key) { + // If key in the Extension node is a prefix of the + // remaining suffix of the lookup key, the proof is + // invalid. + None + } else { + let node = Node::Extension { + key: ext_key, + child: Reference::from(child), + }; + Some((key, OwnedRef::to(node))) + } + } + + Actual::LookupKeyLeft(len, hash) => { + // When traversing the trie, we’ve encountered a value reference + // before the lookup key has finished. `len` determines how + // many bits of the lookup key were not processed. `hash` is + // the value that was found at key[..(key.len() - len)] key. + // + // We’re converting non-membership proof into proof that at + // key[..(key.len() - len)] a `hash` value is stored. 
+ key.pop_back_slice(len.get())?; + Some((key, OwnedRef::value(hash.clone()))) + } + } + } +} + +fn verify_impl( + root_hash: &CryptoHash, + mut key: bits::Slice, + mut want: OwnedRef, + proof: &[Item], +) -> Option<()> { + for item in proof { + let node = match item { + Item::Value(child) if want.is_value => Node::Value { + value: ValueRef::new((), &want.hash), + child: NodeRef::new((), child), + }, + Item::Value(child) => Node::Value { + value: ValueRef::new((), child), + child: NodeRef::new((), &want.hash), + }, + + Item::Branch(child) => { + let us = Reference::from(&want); + let them = child.into(); + let children = match key.pop_back()? { + false => [us, them], + true => [them, us], + }; + Node::Branch { children } + } + + Item::Extension(length) => Node::Extension { + key: key.pop_back_slice(length.get())?, + child: Reference::from(&want), + }, + }; + want = OwnedRef::to(node); + } + + // If we’re here we’ve reached root hash according to the proof. Check the + // key is empty and that hash we’ve calculated is the actual root. + (key.is_empty() && !want.is_value && *root_hash == want.hash).then_some(()) +} + +impl Item { + /// Constructs a new proof item corresponding to given branch. + /// + /// `us` indicates which branch is ours and which one is theirs. When + /// verifying a proof our hash can be computed thus the proof item will only + /// include their hash. + pub fn branch(us: bool, children: &[Reference; 2]) -> Self { + Self::Branch((&children[1 - usize::from(us)]).clone().into()) + } + + pub fn extension(length: u16) -> Option { + NonZeroU16::new(length).map(Self::Extension) + } +} + +impl Builder { + /// Adds a new item to the proof. + pub fn push(&mut self, item: Item) { self.0.push(item); } + + /// Reverses order of items in the builder. + /// + /// The items in the proof must be ordered from the node with the value + /// first with root node as the last entry. When traversing the trie nodes + /// may end up being added in opposite order. 
In those cases, this can be + /// used to reverse order of items so they are in the correct order. + pub fn reversed(mut self) -> Self { + self.0.reverse(); + self + } + + /// Constructs a new membership proof from added items. + pub fn build>(self) -> T { T::from(Membership(self.0)) } + + /// Constructs a new non-membership proof from added items and given + /// ‘actual’ entry. + /// + /// The actual describes what was actually found when traversing the trie. + pub fn negative>(self, actual: Actual) -> T { + T::from(NonMembership(Some(Box::new(actual)), self.0)) + } + + /// Creates a new non-membership proof after lookup reached a Branch node. + /// + /// If a Branch node has been found at the lookup key (rather than Value + /// node), this method allows creation of a non-membership proof. + /// `children` specifies children of the encountered Branch node. + pub fn reached_branch, P, S>( + self, + children: [Reference; 2], + ) -> T { + let [lft, rht] = children; + self.negative(Actual::Branch(lft.into(), rht.into())) + } + + /// Creates a new non-membership proof after lookup reached a Extension node. + /// + /// If a Extension node has been found which doesn’t match corresponding + /// portion of the lookup key (the extension key may be too long or just not + /// match it), this method allows creation of a non-membership proof. + /// + /// `left` is the number of bits left in the lookup key at the moment the + /// Extension node was encountered. `key` and `child` are corresponding + /// fields of the extension node. + pub fn reached_extension>( + self, + left: u16, + key: bits::Slice, + child: Reference, + ) -> T { + let mut buf = [0; 36]; + let len = key.encode_into(&mut buf, 0).unwrap(); + let ext_key = buf[..len].to_vec().into_boxed_slice(); + self.negative(Actual::Extension(left, ext_key, child.into())) + } + + /// Creates a new non-membership proof after lookup reached a value + /// reference. 
+ /// + /// If the lookup key hasn’t terminated yet but a value reference has been + /// found, this method allows creation of a non-membership proof. + /// + /// `left` is the number of bits left in the lookup key at the moment the + /// reference was encountered. `value` is the hash of the value from the + /// reference. + pub fn lookup_key_left>( + self, + left: NonZeroU16, + value: CryptoHash, + ) -> T { + self.negative(Actual::LookupKeyLeft(left, value)) + } +} + +impl OwnedRef { + /// Creates a reference pointing at node with given hash. + fn node(hash: CryptoHash) -> Self { Self { is_value: false, hash } } + /// Creates a reference pointing at value with given hash. + fn value(hash: CryptoHash) -> Self { Self { is_value: true, hash } } + /// Creates a reference pointing at given node. + fn to(node: Node) -> Self { Self::node(node.hash()) } +} + +impl<'a, P, S> From<&'a Reference<'a, P, S>> for OwnedRef { + fn from(rf: &'a Reference<'a, P, S>) -> OwnedRef { + let (is_value, hash) = match rf { + Reference::Node(node) => (false, node.hash.clone()), + Reference::Value(value) => (true, value.hash.clone()), + }; + Self { is_value, hash } + } +} + +impl<'a, P, S> From> for OwnedRef { + fn from(rf: Reference<'a, P, S>) -> OwnedRef { Self::from(&rf) } +} + +impl<'a> From<&'a OwnedRef> for Reference<'a, (), ()> { + fn from(rf: &'a OwnedRef) -> Self { Self::new(rf.is_value, &rf.hash) } +} + +#[test] +fn test_simple_success() { + let mut trie = crate::trie::Trie::test(1000); + let some_hash = CryptoHash::test(usize::MAX); + + for (idx, key) in ["foo", "bar", "baz", "qux"].into_iter().enumerate() { + let hash = CryptoHash::test(idx); + + assert_eq!( + Ok(()), + trie.set(key.as_bytes(), &hash), + "Failed setting {key} → {hash}", + ); + + let (got, proof) = trie.prove(key.as_bytes()).unwrap(); + assert_eq!(Some(hash.clone()), got, "Failed getting {key}"); + assert!( + proof.verify(trie.hash(), key.as_bytes(), Some(&hash)), + "Failed verifying {key} → {hash} get proof:
{proof:?}", + ); + + assert!( + !proof.verify(trie.hash(), key.as_bytes(), None), + "Unexpectedly succeeded {key} → (none) proof: {proof:?}", + ); + assert!( + !proof.verify(trie.hash(), key.as_bytes(), Some(&some_hash)), + "Unexpectedly succeeded {key} → {some_hash} proof: {proof:?}", + ); + } + + for key in ["Foo", "fo", "ba", "bay", "foobar"] { + let (got, proof) = trie.prove(key.as_bytes()).unwrap(); + assert_eq!(None, got, "Unexpected result when getting {key}"); + assert!( + proof.verify(trie.hash(), key.as_bytes(), None), + "Failed verifying {key} → (none) proof: {proof:?}", + ); + assert!( + !proof.verify(trie.hash(), key.as_bytes(), Some(&some_hash)), + "Unexpectedly succeeded {key} → {some_hash} proof: {proof:?}", + ); + } +} diff --git a/sealable-trie/src/stdx.rs b/sealable-trie/src/stdx.rs new file mode 100644 index 00000000..5431d01d --- /dev/null +++ b/sealable-trie/src/stdx.rs @@ -0,0 +1,53 @@ +/// Splits `&[u8; L + R]` into `(&[u8; L], &[u8; R])`. +pub(crate) fn split_array_ref< + const L: usize, + const R: usize, + const N: usize, +>( + xs: &[u8; N], +) -> (&[u8; L], &[u8; R]) { + let () = AssertEqSum::::OK; + + let (left, right) = xs.split_at(L); + (left.try_into().unwrap(), right.try_into().unwrap()) +} + +/// Splits `&mut [u8; L + R]` into `(&mut [u8; L], &mut [u8; R])`. +pub(crate) fn split_array_mut< + const L: usize, + const R: usize, + const N: usize, +>( + xs: &mut [u8; N], +) -> (&mut [u8; L], &mut [u8; R]) { + let () = AssertEqSum::::OK; + + let (left, right) = xs.split_at_mut(L); + (left.try_into().unwrap(), right.try_into().unwrap()) +} + +/// Splits `&[u8]` into `(&[u8; L], &[u8])`. Returns `None` if input is too +/// short. +pub(crate) fn split_at(xs: &[u8]) -> Option<(&[u8; L], &[u8])> { + if xs.len() < L { + return None; + } + let (head, tail) = xs.split_at(L); + Some((head.try_into().unwrap(), tail)) +} + +/// Splits `&[u8]` into `(&[u8], &[u8; R])`. Returns `None` if input is too +/// short.
+#[allow(dead_code)] +pub(crate) fn rsplit_at( + xs: &[u8], +) -> Option<(&[u8], &[u8; R])> { + let (head, tail) = xs.split_at(xs.len().checked_sub(R)?); + Some((head, tail.try_into().unwrap())) +} + +/// Asserts, at compile time, that `A + B == S`. +struct AssertEqSum; +impl AssertEqSum { + const OK: () = assert!(S == A + B); +} diff --git a/sealable-trie/src/test_utils.rs b/sealable-trie/src/test_utils.rs new file mode 100644 index 00000000..49fc317c --- /dev/null +++ b/sealable-trie/src/test_utils.rs @@ -0,0 +1,23 @@ +/// Reads `STRESS_TEST_ITERATIONS` environment variable to determine how many +/// iterations random tests should try. +/// +/// The variable is used by stress tests which generate random data to verify +/// invariant. By default they run hundred thousand iterations. The +/// aforementioned environment variable allows that number to be changed +/// (including to zero which effectively disables such tests). +pub(crate) fn get_iteration_count() -> usize { + use core::str::FromStr; + match std::env::var_os("STRESS_TEST_ITERATIONS") { + None => 100_000, + Some(val) => usize::from_str(val.to_str().unwrap()).unwrap(), + } +} + +/// Returns zero if dividend is zero otherwise `max(dividend / divisor, 1)`. +pub(crate) fn div_max_1(dividind: usize, divisor: usize) -> usize { + if dividind == 0 { + 0 + } else { + 1.max(dividind / divisor) + } +} diff --git a/sealable-trie/src/trie.rs b/sealable-trie/src/trie.rs new file mode 100644 index 00000000..396430c8 --- /dev/null +++ b/sealable-trie/src/trie.rs @@ -0,0 +1,317 @@ +use core::num::NonZeroU16; + +use crate::hash::CryptoHash; +use crate::memory::Ptr; +use crate::nodes::{Node, NodeRef, Reference}; +use crate::{bits, memory, proof}; + +mod seal; +mod set; +#[cfg(test)] +mod tests; + +/// Root trie hash if the trie is empty. +pub const EMPTY_TRIE_ROOT: CryptoHash = CryptoHash::DEFAULT; + +/// A Merkle Patricia Trie with sealing/pruning feature. 
+/// +/// The trie is designed to work in situations where space is constrained. To +/// that effect, it imposes certain limitations and implements features which +/// help reduce its size. +/// +/// In the abstract, the trie is a regular Merkle Patricia Trie which allows +/// storing arbitrary (key, value) pairs. However: +/// +/// 1. The trie doesn’t actually store values but only their hashes. (Another +/// way to think about it is that all values are 32-byte long byte slices). +/// It’s assumed that clients store values in another location and use this +/// data structure only as a witness. Even though it doesn’t contain the +/// values it can generate proof of existence or non-existence of keys. +/// +/// 2. The trie allows values to be sealed. A hash of a sealed value can no +/// longer be accessed even though in the abstract sense the value still resides +/// in the trie. That is, sealing a value doesn’t affect the state root +/// hash and old proofs for the value continue to be valid. +/// +/// Nodes of sealed values are removed from the trie to save storage. +/// Furthermore, if all children of an internal node have been sealed, that +/// node becomes sealed as well. For example, if keys `a` and `b` have +/// both been sealed, then the branch node above them becomes sealed as well. +/// +/// To benefit most from sealing, it’s best to seal consecutive keys. +/// For example, sealing keys `a`, `b`, `c` and `d` will seal their parents +/// as well while sealing keys `a`, `c`, `e` and `g` will leave their parents +/// unsealed and thus kept in the trie. +/// +/// 3. The trie is designed to work with a pool allocator and supports keeping +/// at most 2³⁰-2 nodes. Sealed values don’t count towards this limit since +/// they aren’t stored. In any case, this should be plenty since a fully +/// balanced binary tree with that many nodes allows storing 500K keys. +/// +/// 4.
Keys are limited to 8191 bytes (technically 2¹⁶-1 bits but there’s no +/// interface for keys which hold partial bytes). It would be possible to +/// extend this limit but 8k bytes should be plenty for any reasonable usage. +/// +/// As an optimisation to take advantage of trie’s internal structure, it’s +/// best to keep keys up to 36-byte long. Or at least, to keep common key +/// prefixes to be at most 36-byte long. For example, a trie which has +/// a single value at a key whose length is within 36 bytes has a single +/// node; however, if that key is longer than 36 bytes the trie needs at least +/// two nodes. +pub struct Trie { + /// Pointer to the root node. `None` if the trie is empty or the root node + /// has been sealed. + root_ptr: Option, + + /// Hash of the root node; [`EMPTY_TRIE_ROOT`] if trie is empty. + root_hash: CryptoHash, + + /// Allocator used to access and allocate nodes. + alloc: A, +} + +/// Possible errors when reading or modifying the trie. +#[derive(Copy, Clone, PartialEq, Eq, Debug, derive_more::Display)] +pub enum Error { + #[display(fmt = "Tried to access empty key")] + EmptyKey, + #[display(fmt = "Key longer than 8191 bytes")] + KeyTooLong, + #[display(fmt = "Tried to access sealed node")] + Sealed, + #[display(fmt = "Value not found")] + NotFound, + #[display(fmt = "Not enough space")] + OutOfMemory, +} + +impl From for Error { + fn from(_: memory::OutOfMemory) -> Self { Self::OutOfMemory } +} + +type Result = ::core::result::Result; + +macro_rules! proof { + ($proof:ident push $item:expr) => { + $proof.as_mut().map(|proof| proof.push($item)); + }; + ($proof:ident rev) => { + $proof.map(|builder| builder.reversed().build()) + }; + ($proof:ident rev .$func:ident $($tt:tt)*) => { + $proof.map(|builder| builder.reversed().$func $($tt)*) + }; +} + +impl Trie { + /// Creates a new trie using given allocator.
+ pub fn new(alloc: A) -> Self { + Self { root_ptr: None, root_hash: EMPTY_TRIE_ROOT, alloc } + } + + /// Returns hash of the root node. + pub fn hash(&self) -> &CryptoHash { &self.root_hash } + + /// Returns whether the trie is empty. + pub fn is_empty(&self) -> bool { self.root_hash == EMPTY_TRIE_ROOT } + + /// Retrieves value at given key without building a proof. + /// + /// Returns `None` if there’s no value at given key. Returns an error if + /// the value (or its ancestor) has been sealed. + pub fn get(&self, key: &[u8]) -> Result> { + let (value, _) = self.get_impl(key, false)?; + Ok(value) + } + + /// Retrieves value at given key and provides proof of the result. + /// + /// Returns `None` if there’s no value at given key. Returns an error if + /// the value (or its ancestor) has been sealed. + pub fn prove( + &self, + key: &[u8], + ) -> Result<(Option, proof::Proof)> { + let (value, proof) = self.get_impl(key, true)?; + Ok((value, proof.unwrap())) + } + + fn get_impl( + &self, + key: &[u8], + include_proof: bool, + ) -> Result<(Option, Option)> { + let mut key = bits::Slice::from_bytes(key).ok_or(Error::KeyTooLong)?; + if self.root_hash == EMPTY_TRIE_ROOT { + let proof = include_proof.then(|| proof::Proof::empty_trie()); + return Ok((None, proof)); + } + + let mut proof = include_proof.then(|| proof::Proof::builder()); + let mut node_ptr = self.root_ptr; + let mut node_hash = self.root_hash.clone(); + loop { + let node = self.alloc.get(node_ptr.ok_or(Error::Sealed)?); + let node = node.decode(); + debug_assert_eq!(node_hash, node.hash()); + + let child = match node { + Node::Branch { children } => { + if let Some(us) = key.pop_front() { + proof!(proof push proof::Item::branch(us, &children)); + children[usize::from(us)] + } else { + let proof = proof!(proof rev.reached_branch(children)); + return Ok((None, proof)); + } + } + + Node::Extension { key: ext_key, child } => { + if key.strip_prefix(ext_key) { + proof!(proof push proof::Item::extension(ext_key.len()).unwrap()); + child +
} else { + let proof = proof!(proof rev.reached_extension(key.len(), ext_key, child)); + return Ok((None, proof)); + } + } + + Node::Value { value, child } => { + if value.is_sealed { + return Err(Error::Sealed); + } else if key.is_empty() { + proof!(proof push proof::Item::Value(child.hash.clone())); + let proof = proof!(proof rev.build()); + return Ok((Some(value.hash.clone()), proof)); + } else { + proof!(proof push proof::Item::Value(value.hash.clone())); + node_ptr = child.ptr; + node_hash = child.hash.clone(); + continue; + } + } + }; + + match child { + Reference::Node(node) => { + node_ptr = node.ptr; + node_hash = node.hash.clone(); + } + Reference::Value(value) => { + return if value.is_sealed { + Err(Error::Sealed) + } else if let Some(len) = NonZeroU16::new(key.len()) { + let proof = proof!(proof rev.lookup_key_left(len, value.hash.clone())); + Ok((None, proof)) + } else { + let proof = proof!(proof rev.build()); + Ok((Some(value.hash.clone()), proof)) + }; + } + }; + } + } + + /// Inserts a new value hash at given key. + /// + /// Sets value hash at given key to given to the provided one. If the value + /// (or one of its ancestors) has been sealed the operation fails with + /// [`Error::Sealed`] error. + /// + /// If `proof` is specified, stores proof nodes into the provided vector. + // TODO(mina86): Add set_with_proof as well as set_and_seal and + // set_and_seal_with_proof. + pub fn set(&mut self, key: &[u8], value_hash: &CryptoHash) -> Result<()> { + let (ptr, hash) = (self.root_ptr, self.root_hash.clone()); + let key = bits::Slice::from_bytes(key).ok_or(Error::KeyTooLong)?; + let (ptr, hash) = + set::SetContext::new(&mut self.alloc, key, value_hash) + .set(ptr, &hash)?; + self.root_ptr = Some(ptr); + self.root_hash = hash; + Ok(()) + } + + /// Seals value at given key as well as all descendant values. + /// + /// Once value is sealed, its hash can no longer be retrieved nor can it be + /// changed. 
Sealing a value seals the entire subtrie rooted at the key + /// (that is, if key `foo` is sealed, `foobar` is also sealed). + /// + /// However, it’s impossible to seal a subtrie unless there’s a value stored + /// at the key. For example, if trie contains key `foobar` only, neither + /// `foo` nor `qux` keys can be sealed. In those cases, function returns + /// an error. + // TODO(mina86): Add seal_with_proof. + pub fn seal(&mut self, key: &[u8]) -> Result<()> { + let key = bits::Slice::from_bytes(key).ok_or(Error::KeyTooLong)?; + if self.root_hash == EMPTY_TRIE_ROOT { + return Err(Error::NotFound); + } + + let seal = seal::SealContext::new(&mut self.alloc, key) + .seal(NodeRef::new(self.root_ptr, &self.root_hash))?; + if seal { + self.root_ptr = None; + } + Ok(()) + } + + /// Prints the trie. Used for testing and debugging only. + #[cfg(test)] + pub(crate) fn print(&self) { + use std::println; + + if self.root_hash == EMPTY_TRIE_ROOT { + println!("(empty)"); + } else { + self.print_impl(NodeRef::new(self.root_ptr, &self.root_hash), 0); + } + } + + #[cfg(test)] + fn print_impl(&self, nref: NodeRef, depth: usize) { + use std::{print, println}; + + let print_ref = |rf, depth| match rf { + Reference::Node(node) => self.print_impl(node, depth), + Reference::Value(value) => { + let is_sealed = if value.is_sealed { " (sealed)" } else { "" }; + println!("{:depth$}value {}{}", "", value.hash, is_sealed) + } + }; + + print!("{:depth$}«{}»", "", nref.hash); + let ptr = if let Some(ptr) = nref.ptr { + ptr + } else { + println!(" (sealed)"); + return; + }; + match self.alloc.get(ptr).decode() { + Node::Branch { children } => { + println!(" Branch"); + print_ref(children[0], depth + 2); + print_ref(children[1], depth + 2); + } + Node::Extension { key, child } => { + println!(" Extension {key}"); + print_ref(child, depth + 2); + } + Node::Value { value, child } => { + let is_sealed = if value.is_sealed { " (sealed)" } else { "" }; + println!(" Value {}{}", value.hash, 
is_sealed); + print_ref(Reference::from(child), depth + 2); + } + } + } +} + + +#[cfg(test)] +impl Trie { + /// Creates a test trie using a TestAllocator with given capacity. + pub(crate) fn test(capacity: usize) -> Self { + Self::new(memory::test_utils::TestAllocator::new(capacity)) + } +} diff --git a/sealable-trie/src/trie/seal.rs b/sealable-trie/src/trie/seal.rs new file mode 100644 index 00000000..50e68fdb --- /dev/null +++ b/sealable-trie/src/trie/seal.rs @@ -0,0 +1,175 @@ +use alloc::vec::Vec; + +use super::{Error, Result}; +use crate::memory::Ptr; +use crate::nodes::{Node, NodeRef, RawNode, Reference, ValueRef}; +use crate::{bits, memory}; + +/// Context for [`Trie::seal`] operation. +pub(super) struct SealContext<'a, A> { + /// Part of the key yet to be traversed. + /// + /// It starts as the key user provided and as trie is traversed bits are + /// removed from its front. + key: bits::Slice<'a>, + + /// Allocator used to retrieve and free nodes. + alloc: &'a mut A, +} + +impl<'a, A: memory::Allocator> SealContext<'a, A> { + pub(super) fn new(alloc: &'a mut A, key: bits::Slice<'a>) -> Self { + Self { key, alloc } + } + + /// Traverses the trie starting from node `ptr` to find node at context’s + /// key and seals it. + /// + /// Returns `true` if node at `ptr` has been sealed. This lets caller know + /// that `ptr` has been freed and it has to update references to it. 
+ pub(super) fn seal(&mut self, nref: NodeRef) -> Result { + let ptr = nref.ptr.ok_or(Error::Sealed)?; + let node = self.alloc.get(ptr); + let node = node.decode(); + debug_assert_eq!(*nref.hash, node.hash()); + + let result = match node { + Node::Branch { children } => self.seal_branch(children), + Node::Extension { key, child } => self.seal_extension(key, child), + Node::Value { value, child } => self.seal_value(value, child), + }?; + + match result { + SealResult::Replace(node) => { + self.alloc.set(ptr, node); + Ok(false) + } + SealResult::Free => { + self.alloc.free(ptr); + Ok(true) + } + SealResult::Done => Ok(false), + } + } + + fn seal_branch( + &mut self, + mut children: [Reference; 2], + ) -> Result { + let side = match self.key.pop_front() { + None => return Err(Error::NotFound), + Some(bit) => usize::from(bit), + }; + match self.seal_child(children[side])? { + None => Ok(SealResult::Done), + Some(_) if children[1 - side].is_sealed() => Ok(SealResult::Free), + Some(child) => { + children[side] = child; + let node = RawNode::branch(children[0], children[1]); + Ok(SealResult::Replace(node)) + } + } + } + + fn seal_extension( + &mut self, + ext_key: bits::Slice, + child: Reference, + ) -> Result { + if !self.key.strip_prefix(ext_key) { + return Err(Error::NotFound); + } + Ok(if let Some(child) = self.seal_child(child)? { + let node = RawNode::extension(ext_key, child).unwrap(); + SealResult::Replace(node) + } else { + SealResult::Done + }) + } + + fn seal_value( + &mut self, + value: ValueRef, + child: NodeRef, + ) -> Result { + if value.is_sealed { + Err(Error::Sealed) + } else if self.key.is_empty() { + prune(self.alloc, child.ptr); + Ok(SealResult::Free) + } else if self.seal(child)? 
{ + let child = NodeRef::new(None, child.hash); + let node = RawNode::value(value, child); + Ok(SealResult::Replace(node)) + } else { + Ok(SealResult::Done) + } + } + + fn seal_child<'b>( + &mut self, + child: Reference<'b>, + ) -> Result>> { + match child { + Reference::Node(node) => Ok(if self.seal(node)? { + Some(Reference::Node(node.sealed())) + } else { + None + }), + Reference::Value(value) => { + if value.is_sealed { + Err(Error::Sealed) + } else if self.key.is_empty() { + Ok(Some(value.sealed().into())) + } else { + Err(Error::NotFound) + } + } + } + } +} + +enum SealResult { + Free, + Replace(RawNode), + Done, +} + +/// Frees node and all its descendants from the allocator. +fn prune(alloc: &mut impl memory::Allocator, ptr: Option) { + let mut ptr = match ptr { + Some(ptr) => ptr, + None => return, + }; + let mut queue = Vec::new(); + loop { + let children = get_children(&alloc.get(ptr)); + alloc.free(ptr); + match children { + (None, None) => match queue.pop() { + Some(p) => ptr = p, + None => break, + }, + (Some(p), None) | (None, Some(p)) => ptr = p, + (Some(lhs), Some(rht)) => { + queue.push(lhs); + ptr = rht + } + } + } +} + +fn get_children(node: &RawNode) -> (Option, Option) { + fn get_ptr(child: Reference) -> Option { + match child { + Reference::Node(node) => node.ptr, + Reference::Value { .. } => None, + } + } + + match node.decode() { + Node::Branch { children: [lft, rht] } => (get_ptr(lft), get_ptr(rht)), + Node::Extension { child, .. } => (get_ptr(child), None), + Node::Value { child, .. } => (child.ptr, None), + } +} diff --git a/sealable-trie/src/trie/set.rs b/sealable-trie/src/trie/set.rs new file mode 100644 index 00000000..2a4cbd39 --- /dev/null +++ b/sealable-trie/src/trie/set.rs @@ -0,0 +1,344 @@ +use super::{Error, Result}; +use crate::hash::CryptoHash; +use crate::memory::Ptr; +use crate::nodes::{Node, NodeRef, RawNode, Reference, ValueRef}; +use crate::{bits, memory}; + +/// Context for [`Trie::set`] operation. 
+pub(super) struct SetContext<'a, A: memory::Allocator> { + /// Part of the key yet to be traversed. + /// + /// It starts as the key user provided and as trie is traversed bits are + /// removed from its front. + key: bits::Slice<'a>, + + /// Hash to insert into the trie. + value_hash: &'a CryptoHash, + + /// Allocator used to allocate new nodes. + wlog: memory::WriteLog<'a, A>, +} + +impl<'a, A: memory::Allocator> SetContext<'a, A> { + pub(super) fn new( + alloc: &'a mut A, + key: bits::Slice<'a>, + value_hash: &'a CryptoHash, + ) -> Self { + let wlog = memory::WriteLog::new(alloc); + Self { key, value_hash, wlog } + } + + /// Inserts value hash into the trie. + pub(super) fn set( + mut self, + root_ptr: Option, + root_hash: &CryptoHash, + ) -> Result<(Ptr, CryptoHash)> { + let res = (|| { + if let Some(ptr) = root_ptr { + // Trie is non-empty, handle normally. + self.handle(NodeRef { ptr: Some(ptr), hash: root_hash }) + } else if *root_hash != super::EMPTY_TRIE_ROOT { + // Trie is sealed (it’s not empty but ptr is None). + Err(Error::Sealed) + } else if let OwnedRef::Node(ptr, hash) = self.insert_value()? { + // Trie is empty and we’ve just inserted Extension leading to + // the value. + Ok((ptr, hash)) + } else { + // If the key was non-empty, self.insert_value would have + // returned a node reference. If it didn’t, it means key was + // empty which is an error condition. + Err(Error::EmptyKey) + } + })(); + if res.is_ok() { + self.wlog.commit(); + } + res + } + + /// Inserts value into the trie starting at node pointed by given reference. 
+ fn handle(&mut self, nref: NodeRef) -> Result<(Ptr, CryptoHash)> { + let nref = (nref.ptr.ok_or(Error::Sealed)?, nref.hash); + let node = self.wlog.allocator().get(nref.0); + let node = node.decode(); + debug_assert_eq!(*nref.1, node.hash()); + match node { + Node::Branch { children } => self.handle_branch(nref, children), + Node::Extension { key, child } => { + self.handle_extension(nref, key, child) + } + Node::Value { value, child } => { + self.handle_value(nref, value, child) + } + } + } + + /// Inserts value assuming current node is a Branch with given children. + fn handle_branch( + &mut self, + nref: (Ptr, &CryptoHash), + children: [Reference<'_>; 2], + ) -> Result<(Ptr, CryptoHash)> { + let bit = if let Some(bit) = self.key.pop_front() { + bit + } else { + // If Key is empty, insert a new Node value with this node as + // a child. + return self.alloc_value_node(self.value_hash, nref.0, nref.1); + }; + + // Figure out which direction the key leads and update the node + // in-place. + let owned_ref = self.handle_reference(children[usize::from(bit)])?; + let child = owned_ref.to_ref(); + let children = + if bit { [children[0], child] } else { [child, children[1]] }; + Ok(self.set_node(nref.0, RawNode::branch(children[0], children[1]))) + // let child = owned_ref.to_ref(); + // let (left, right) = if bit == 0 { + // (child, children[1]) + // } else { + // (children[0], child) + // }; + + // // Update the node in place with the new child. + // Ok((nref.0, self.set_node(nref.0, RawNode::branch(left, right)))) + } + + /// Inserts value assuming current node is an Extension. + fn handle_extension( + &mut self, + nref: (Ptr, &CryptoHash), + mut key: bits::Slice<'_>, + child: Reference<'_>, + ) -> Result<(Ptr, CryptoHash)> { + // If key is empty, insert a new Value node with this node as a child. 
+ // + // P P + // ↓ ↓ + // Ext(key, ⫯) → Val(val, ⫯) + // ↓ ↓ + // C Ext(key, ⫯) + // ↓ + // C + if self.key.is_empty() { + return self.alloc_value_node(self.value_hash, nref.0, nref.1); + } + + let prefix = self.key.forward_common_prefix(&mut key); + let mut suffix = key; + + // The entire extension key matched. Handle the child reference and + // update the node. + // + // P P + // ↓ ↓ + // Ext(key, ⫯) → Ext(key, ⫯) + // ↓ ↓ + // C C′ + if suffix.is_empty() { + let owned_ref = self.handle_reference(child)?; + let node = RawNode::extension(prefix, owned_ref.to_ref()).unwrap(); + return Ok(self.set_node(nref.0, node)); + } + + let our = if let Some(bit) = self.key.pop_front() { + usize::from(bit) + } else { + // Our key is done. We need to split the Extension node into + // two and insert Value node in between. + // + // P P + // ↓ ↓ + // Ext(key, ⫯) → Ext(prefix, ⫯) + // ↓ ↓ + // C Value(val, ⫯) + // ↓ + // Ext(suffix, ⫯) + // ↓ + // C + let (ptr, hash) = self.alloc_extension_node(suffix, child)?; + let (ptr, hash) = + self.alloc_value_node(self.value_hash, ptr, &hash)?; + let child = Reference::node(Some(ptr), &hash); + let node = RawNode::extension(prefix, child).unwrap(); + return Ok(self.set_node(nref.0, node)); + }; + + let theirs = usize::from(suffix.pop_front().unwrap()); + assert_ne!(our, theirs); + + // We need to split the Extension node with a Branch node in between. + // One child of the Branch will lead to our value; the other will lead + // to subtrie that the Extension points to. + // + // + // P P + // ↓ ↓ + // Ext(key, ⫯) → Ext(prefix, ⫯) + // ↓ ↓ + // C Branch(⫯, ⫯) + // ↓ ↓ + // V Ext(suffix, ⫯) + // ↓ + // C + // + // However, keep in mind that each of prefix or suffix may be empty. If + // that’s the case, corresponding Extension node is not created. 
+ let our_ref = self.insert_value()?; + let their_hash: CryptoHash; + let their_ref = if let Some(node) = RawNode::extension(suffix, child) { + let (ptr, hash) = self.alloc_node(node)?; + their_hash = hash; + Reference::node(Some(ptr), &their_hash) + } else { + child + }; + let mut children = [their_ref; 2]; + children[our] = our_ref.to_ref(); + let node = RawNode::branch(children[0], children[1]); + let (ptr, hash) = self.set_node(nref.0, node); + + match RawNode::extension(prefix, Reference::node(Some(ptr), &hash)) { + Some(node) => self.alloc_node(node), + None => Ok((ptr, hash)), + } + } + + /// Inserts value assuming current node is an unsealed Value. + fn handle_value( + &mut self, + nref: (Ptr, &CryptoHash), + value: ValueRef, + child: NodeRef, + ) -> Result<(Ptr, CryptoHash)> { + if value.is_sealed { + return Err(Error::Sealed); + } + let node = if self.key.is_empty() { + RawNode::value(ValueRef::new(false, self.value_hash), child) + } else { + let (ptr, hash) = self.handle(child)?; + RawNode::value(value, NodeRef::new(Some(ptr), &hash)) + }; + Ok(self.set_node(nref.0, node)) + } + + /// Handles a reference which can either point at a node or a value. + /// + /// Returns a new value for the reference updating it such that it points at + /// the subtrie updated with the inserted value. + fn handle_reference(&mut self, child: Reference<'_>) -> Result { + match child { + Reference::Node(node) => { + // Handle node references recursively. We cannot special handle + // our key being empty because we need to handle cases where the + // reference points at a Value node correctly. + self.handle(node).map(|(p, h)| OwnedRef::Node(p, h)) + } + Reference::Value(value) => { + if value.is_sealed { + return Err(Error::Sealed); + } + // It’s a value reference so we just need to update it + // accordingly. One tricky thing is that we need to insert + // Value node with the old hash if our key isn’t empty. + match self.insert_value()? 
{ + rf @ OwnedRef::Value(_) => Ok(rf), + OwnedRef::Node(p, h) => { + let child = NodeRef::new(Some(p), &h); + let node = RawNode::value(value, child); + self.alloc_node(node).map(|(p, h)| OwnedRef::Node(p, h)) + } + } + } + } + } + + /// Inserts the value into the trie and returns reference to it. + /// + /// If key is empty, doesn’t insert any nodes and instead returns a value + /// reference to the value. + /// + /// Otherwise, inserts one or more Extension nodes (depending on the length + /// of the key) and returns reference to the first ancestor node. + fn insert_value(&mut self) -> Result { + let mut ptr: Option = None; + let mut hash = self.value_hash.clone(); + for chunk in self.key.chunks().rev() { + let child = match ptr { + None => Reference::value(false, &hash), + Some(_) => Reference::node(ptr, &hash), + }; + let (p, h) = self.alloc_extension_node(chunk, child)?; + ptr = Some(p); + hash = h; + } + + Ok(if let Some(ptr) = ptr { + // We’ve updated some nodes. Insert node reference to the first + // one. + OwnedRef::Node(ptr, hash) + } else { + // ptr being None means that the above loop never run which means + // self.key is empty. We just need to return value reference. + OwnedRef::Value(hash) + }) + } + + /// A convenience method which allocates a new Extension node and sets it to + /// given value. + /// + /// **Panics** if `key` is empty or too long. + fn alloc_extension_node( + &mut self, + key: bits::Slice<'_>, + child: Reference<'_>, + ) -> Result<(Ptr, CryptoHash)> { + self.alloc_node(RawNode::extension(key, child).unwrap()) + } + + /// A convenience method which allocates a new Value node and sets it to + /// given value. 
+ fn alloc_value_node( + &mut self, + value_hash: &CryptoHash, + ptr: Ptr, + hash: &CryptoHash, + ) -> Result<(Ptr, CryptoHash)> { + let value = ValueRef::new(false, value_hash); + let child = NodeRef::new(Some(ptr), hash); + self.alloc_node(RawNode::value(value, child)) + } + + /// Sets value of a node cell at given address and returns its hash. + fn set_node(&mut self, ptr: Ptr, node: RawNode) -> (Ptr, CryptoHash) { + let hash = node.decode().hash(); + self.wlog.set(ptr, node); + (ptr, hash) + } + + /// Allocates a new node and sets it to given value. + fn alloc_node(&mut self, node: RawNode) -> Result<(Ptr, CryptoHash)> { + let hash = node.decode().hash(); + let ptr = self.wlog.alloc(node)?; + Ok((ptr, hash)) + } +} + +enum OwnedRef { + Node(Ptr, CryptoHash), + Value(CryptoHash), +} + +impl OwnedRef { + fn to_ref(&self) -> Reference { + match self { + Self::Node(ptr, hash) => Reference::node(Some(*ptr), &hash), + Self::Value(hash) => Reference::value(false, &hash), + } + } +} diff --git a/sealable-trie/src/trie/tests.rs b/sealable-trie/src/trie/tests.rs new file mode 100644 index 00000000..a88d634b --- /dev/null +++ b/sealable-trie/src/trie/tests.rs @@ -0,0 +1,179 @@ +use std::collections::HashMap; +use std::println; + +use rand::Rng; + +use crate::hash::CryptoHash; +use crate::memory::test_utils::TestAllocator; + +fn do_test_inserts<'a>( + keys: impl IntoIterator, + verbose: bool, +) -> TestTrie { + let keys = keys.into_iter(); + let count = keys.size_hint().1.unwrap_or(1000).saturating_mul(4); + let mut trie = TestTrie::new(count); + for key in keys { + trie.set(key, verbose) + } + trie +} + +#[test] +fn test_msb_difference() { do_test_inserts([&[0][..], &[0x80][..]], true); } + +#[test] +fn test_sequence() { + do_test_inserts( + b"0123456789:;<=>?".iter().map(core::slice::from_ref), + true, + ); +} + +#[test] +fn test_2byte_extension() { + do_test_inserts([&[123, 40][..], &[134, 233][..]], true); +} + +#[test] +fn test_prefix() { + let key = b"xy"; + 
+ do_test_inserts([&key[..], &key[..1]], true); + do_test_inserts([&key[..1], &key[..]], true); +} + +#[test] +fn test_seal() { + let mut trie = do_test_inserts( + b"0123456789:;<=>?".iter().map(core::slice::from_ref), + true, + ); + + for b in b'0'..=b'?' { + trie.seal(&[b], true); + } +} + +#[test] +fn stress_test() { + struct RandKeys<'a> { + buf: &'a mut [u8; 35], + rng: rand::rngs::ThreadRng, + } + + impl<'a> Iterator for RandKeys<'a> { + type Item = &'a [u8]; + + fn next(&mut self) -> Option { + let len = self.rng.gen_range(1..self.buf.len()); + let key = &mut self.buf[..len]; + self.rng.fill(key); + let key = &key[..]; + // Transmute lifetimes. This is probably not sound in general but + // it works for our needs in this test. + unsafe { core::mem::transmute(key) } + } + } + + let count = crate::test_utils::get_iteration_count(); + let count = crate::test_utils::div_max_1(count, 100); + let keys = RandKeys { buf: &mut [0; 35], rng: rand::thread_rng() }; + do_test_inserts(keys.take(count), false); +} + +#[derive(Eq)] +struct Key { + len: u8, + buf: [u8; 35], +} + +impl Key { + fn as_bytes(&self) -> &[u8] { &self.buf[..usize::from(self.len)] } +} + +impl core::cmp::PartialEq for Key { + fn eq(&self, other: &Self) -> bool { self.as_bytes() == other.as_bytes() } +} + +impl core::cmp::PartialOrd for Key { + fn partial_cmp(&self, other: &Self) -> Option { + self.as_bytes().partial_cmp(other.as_bytes()) + } +} + +impl core::hash::Hash for Key { + fn hash(&self, state: &mut H) { + self.as_bytes().hash(state) + } +} + +impl core::fmt::Debug for Key { + fn fmt(&self, fmtr: &mut core::fmt::Formatter) -> core::fmt::Result { + self.as_bytes().fmt(fmtr) + } +} + +struct TestTrie { + trie: super::Trie, + mapping: HashMap, + count: usize, +} + +impl TestTrie { + pub fn new(count: usize) -> Self { + Self { + trie: super::Trie::test(count), + mapping: Default::default(), + count: 0, + } + } + + pub fn set(&mut self, key: &[u8], verbose: bool) { + assert!(key.len() <= 35);
+ let key = Key { + len: key.len() as u8, + buf: { + let mut buf = [0; 35]; + buf[..key.len()].copy_from_slice(key); + buf + }, + }; + + let value = self.next_value(); + println!("{}Inserting {key:?}", if verbose { "\n" } else { "" }); + self.trie + .set(key.as_bytes(), &value) + .unwrap_or_else(|err| panic!("Failed setting ‘{key:?}’: {err}")); + self.mapping.insert(key, value); + if verbose { + self.trie.print(); + } + for (key, value) in self.mapping.iter() { + let key = key.as_bytes(); + let got = self.trie.get(key).unwrap_or_else(|err| { + panic!("Failed getting ‘{key:?}’: {err}") + }); + assert_eq!(Some(value), got.as_ref(), "Invalid value at ‘{key:?}’"); + } + } + + pub fn seal(&mut self, key: &[u8], verbose: bool) { + println!("{}Sealing {key:?}", if verbose { "\n" } else { "" }); + self.trie + .seal(key) + .unwrap_or_else(|err| panic!("Failed sealing ‘{key:?}’: {err}")); + if verbose { + self.trie.print(); + } + assert_eq!( + Err(super::Error::Sealed), + self.trie.get(key), + "Unexpectedly can read ‘{key:?}’ after sealing" + ) + } + + fn next_value(&mut self) -> CryptoHash { + self.count += 1; + CryptoHash::test(self.count) + } +}