diff --git a/Cargo.lock b/Cargo.lock index 9ceee52..ee6f2dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,12 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "autocfg" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" + [[package]] name = "crunchy" version = "0.2.2" @@ -14,6 +20,34 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "tiny-keccak" version = "2.0.2" @@ -30,3 +64,10 @@ dependencies = [ "hex", "tiny-keccak", ] + +[[package]] +name = "zk-kit-smt" +version = "0.0.1" +dependencies = [ + "num-bigint", +] diff --git a/crates/smt/Cargo.toml b/crates/smt/Cargo.toml new file mode 100644 index 0000000..dbd11fe --- /dev/null +++ b/crates/smt/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "zk-kit-smt" +version.workspace = true +edition.workspace = true +license.workspace = true +publish.workspace = true +description = "Sparse Merkle Tree" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +num-bigint = "0.4.6" diff --git a/crates/smt/LICENSE b/crates/smt/LICENSE new file mode 100644 index 0000000..1763b4d --- /dev/null +++ b/crates/smt/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Pinco + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/crates/smt/README.md b/crates/smt/README.md new file mode 100644 index 0000000..0457578 --- /dev/null +++ b/crates/smt/README.md @@ -0,0 +1,90 @@ +

+

+ Sparse Merkle tree +

+

Sparse Merkle tree implementation in Rust.

+

+ +

+ + + + + License + + + Version + + + Downloads + +

+ +
+

+ + 🗣️ Chat & Support + +

+
+ +A sparse Merkle tree is a data structure useful for storing a key/value map where every leaf node of the tree contains the cryptographic hash of a key/value pair and every non leaf node contains the concatenated hashes of its child nodes. Sparse Merkle trees provides a secure and efficient verification of large data sets and they are often used in peer-to-peer technologies. This implementation is an optimized version of the traditional sparse Merkle tree and it is based on the concepts expressed in the papers and resources below. + +## References + +1. Rasmus Dahlberg, Tobias Pulls and Roel Peeters. _Efficient Sparse Merkle Trees: Caching Strategies and Secure (Non-)Membership Proofs_. Cryptology ePrint Archive: Report 2016/683, 2016. https://eprint.iacr.org/2016/683. +2. Faraz Haider. _Compact sparse merkle trees_. Cryptology ePrint Archive: Report 2018/955, 2018. https://eprint.iacr.org/2018/955. +3. Jordi Baylina and Marta Bellés. _Sparse Merkle Trees_. https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf. +4. Vitalik Buterin Fichter. _Optimizing sparse Merkle trees_. https://ethresear.ch/t/optimizing-sparse-merkle-trees/3751. + +--- + +## 🛠 Install + +You can install `zk-kit-smt` crate with `cargo`: + +```bash +cargo add zk-kit-smt +``` + +## 📜 Usage + +```rust +use zk_kit_smt::smt::{Key, Node, Value, SMT}; + +fn hash_function(nodes: Vec) -> Node { + let strings: Vec = nodes.iter().map(|node| node.to_string()).collect(); + Node::Str(strings.join(",")) +} + +fn main() { + // Initialize the Sparse Merkle Tree with a hash function. + let mut smt = SMT::new(hash_function, false); + + let key = Key::Str("aaa".to_string()); + let value = Value::Str("bbb".to_string()); + + // Add a key-value pair to the Sparse Merkle Tree. + smt.add(key.clone(), value.clone()).unwrap(); + + // Get the value of the key. + let get = smt.get(key.clone()); + assert_eq!(get, Some(value)); + + // Update the value of the key. + let new_value = Value::Str("ccc".to_string()); + let update = smt.update(key.clone(), new_value.clone()); + assert!(update.is_ok()); + assert_eq!(smt.get(key.clone()), Some(new_value)); + + // Create and verify a proof for the key. + let create_proof = smt.create_proof(key.clone()); + let verify_proof = smt.verify_proof(create_proof); + assert!(verify_proof); + + // Delete the key. + let delete = smt.delete(key.clone()); + assert!(delete.is_ok()); + assert_eq!(smt.get(key.clone()), None); +} +``` diff --git a/crates/smt/examples/smt.rs b/crates/smt/examples/smt.rs new file mode 100644 index 0000000..8a6fe05 --- /dev/null +++ b/crates/smt/examples/smt.rs @@ -0,0 +1,37 @@ +use zk_kit_smt::smt::{Key, Node, Value, SMT}; + +fn hash_function(nodes: Vec) -> Node { + let strings: Vec = nodes.iter().map(|node| node.to_string()).collect(); + Node::Str(strings.join(",")) +} + +fn main() { + // Initialize the Sparse Merkle Tree with a hash function. + let mut smt = SMT::new(hash_function, false); + + let key = Key::Str("aaa".to_string()); + let value = Value::Str("bbb".to_string()); + + // Add a key-value pair to the Sparse Merkle Tree. + smt.add(key.clone(), value.clone()).unwrap(); + + // Get the value of the key. + let get = smt.get(key.clone()); + assert_eq!(get, Some(value)); + + // Update the value of the key. + let new_value = Value::Str("ccc".to_string()); + let update = smt.update(key.clone(), new_value.clone()); + assert!(update.is_ok()); + assert_eq!(smt.get(key.clone()), Some(new_value)); + + // Create and verify a proof for the key. + let create_proof = smt.create_proof(key.clone()); + let verify_proof = smt.verify_proof(create_proof); + assert!(verify_proof); + + // Delete the key. + let delete = smt.delete(key.clone()); + assert!(delete.is_ok()); + assert_eq!(smt.get(key.clone()), None); +} diff --git a/crates/smt/src/lib.rs b/crates/smt/src/lib.rs new file mode 100644 index 0000000..6d39c77 --- /dev/null +++ b/crates/smt/src/lib.rs @@ -0,0 +1,2 @@ +pub mod smt; +mod utils; diff --git a/crates/smt/src/smt.rs b/crates/smt/src/smt.rs new file mode 100644 index 0000000..265d487 --- /dev/null +++ b/crates/smt/src/smt.rs @@ -0,0 +1,933 @@ +use std::{collections::HashMap, str::FromStr}; + +use num_bigint::BigInt; + +use crate::utils::{ + get_first_common_elements, get_index_of_last_non_zero_element, is_hexadecimal, key_to_path, +}; + +use std::fmt; + +#[derive(Debug, PartialEq)] +pub enum SMTError { + KeyAlreadyExist(String), + KeyDoesNotExist(String), + InvalidParameterType(String, String), + InvalidSiblingIndex, +} + +impl fmt::Display for SMTError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SMTError::KeyAlreadyExist(s) => write!(f, "Key {} already exists", s), + SMTError::KeyDoesNotExist(s) => write!(f, "Key {} does not exist", s), + SMTError::InvalidParameterType(p, t) => { + write!(f, "Parameter {} must be a {}", p, t) + }, + SMTError::InvalidSiblingIndex => write!(f, "Invalid sibling index"), + } + } +} + +impl std::error::Error for SMTError {} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Node { + Str(String), + BigInt(BigInt), +} + +impl fmt::Display for Node { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Node::Str(s) => write!(f, "{}", s), + Node::BigInt(n) => write!(f, "{}", n), + } + } +} + +impl FromStr for Node { + type Err = SMTError; + + fn from_str(s: &str) -> Result { + if let Ok(bigint) = s.parse::() { + Ok(Node::BigInt(bigint)) + } else if is_hexadecimal(s) { + Ok(Node::Str(s.to_string())) + } else { + Err(SMTError::InvalidParameterType( + s.to_string(), + "BigInt or hexadecimal string".to_string(), + )) + } + } +} + +pub type Key = Node; +pub type Value = Node; +pub type EntryMark = Node; + +pub type Entry = (Key, Value, EntryMark); +pub type ChildNodes = Vec; +pub type Siblings = Vec; + +pub type HashFunction = fn(ChildNodes) -> Node; + +pub struct EntryResponse { + pub entry: Vec, + pub matching_entry: Option>, + pub siblings: Siblings, +} + +#[allow(dead_code)] +pub struct MerkleProof { + entry_response: EntryResponse, + root: Node, + membership: bool, +} + +#[allow(dead_code)] +pub struct SMT { + hash: HashFunction, + big_numbers: bool, + zero_node: Node, + entry_mark: Node, + nodes: HashMap>, + root: Node, +} + +impl SMT { + /// Initializes a new instance of the Sparse Merkle Tree (SMT). + /// + /// # Arguments + /// + /// * `hash` - The hash function used to hash the child nodes. + /// * `big_numbers` - A flag indicating whether the SMT supports big numbers or not. + /// + /// # Returns + /// + /// A new instance of the SMT. + pub fn new(hash: HashFunction, big_numbers: bool) -> Self { + let zero_node; + let entry_mark; + + if big_numbers { + zero_node = Node::BigInt(BigInt::from(0)); + entry_mark = Node::BigInt(BigInt::from(1)); + } else { + zero_node = Node::Str("0".to_string()); + entry_mark = Node::Str("1".to_string()); + } + + SMT { + hash, + big_numbers, + zero_node: zero_node.clone(), + entry_mark, + nodes: HashMap::new(), + root: zero_node, + } + } + + /// Retrieves the value associated with the given key from the SMT. + /// + /// # Arguments + /// + /// * `key` - The key to retrieve the value for. + /// + /// # Returns + /// + /// An `Option` containing the value associated with the key, or `None` if the key does not exist. + pub fn get(&self, key: Key) -> Option { + let key = key.to_string().parse::().unwrap(); + + let EntryResponse { entry, .. } = self.retrieve_entry(key); + + entry.get(1).cloned() + } + + /// Adds a new key-value pair to the SMT. + /// + /// It retrieves a matching entry or a zero node with a top-down approach and then it updates + /// all the hashes of the nodes in the path of the new entry with a bottom up approach. + /// + /// # Arguments + /// + /// * `key` - The key to add. + /// * `value` - The value associated with the key. + /// + /// # Returns + /// + /// An `Result` indicating whether the operation was successful or not. + pub fn add(&mut self, key: Key, value: Value) -> Result<(), SMTError> { + let key = key.to_string().parse::().unwrap(); + let value = value.to_string().parse::().unwrap(); + + let EntryResponse { + entry, + matching_entry, + mut siblings, + } = self.retrieve_entry(key.clone()); + + if entry.get(1).is_some() { + return Err(SMTError::KeyAlreadyExist(key.to_string())); + } + + let path = key_to_path(&key.to_string()); + // If there is a matching entry, its node is saved in the `node` variable, otherwise the + // `zero_node` is saved. This node is used below as the first node (starting from the + // bottom of the tree) to obtain the new nodes up to the root. + let node = if let Some(ref matching_entry) = matching_entry { + (self.hash)(matching_entry.clone()) + } else { + self.zero_node.clone() + }; + + // If there are siblings, the old nodes are deleted and will be re-created below with new hashes. + if !siblings.is_empty() { + self.delete_old_nodes(node.clone(), &path, &siblings) + } + + // If there is a matching entry, further N zero siblings are added in the `siblings` vector, + // followed by the matching node itself. N is the number of the first matching bits of the paths. + // This is helpful in the non-membership proof verification as explained in the function below. + if let Some(matching_entry) = matching_entry { + let matching_path = key_to_path(&matching_entry[0].to_string()); + let mut i = siblings.len(); + + while matching_path[i] == path[i] { + siblings.push(self.zero_node.clone()); + i += 1; + } + + siblings.push(node.clone()); + } + + // Adds the new entry and re-creates the nodes of the path with the new hashes with a bottom + // up approach. The `add_new_nodes` function returns the new root of the tree. + let new_node = (self.hash)(vec![key.clone(), value.clone(), self.entry_mark.clone()]); + + self.nodes + .insert(new_node.clone(), vec![key, value, self.entry_mark.clone()]); + self.root = self + .add_new_nodes(new_node, &path, &siblings, None) + .unwrap(); + + Ok(()) + } + + /// Updates the value associated with the given key in the SMT. + /// + /// Also in this case, all the hashes of the nodes in the path of the updated entry are updated + /// with a bottom up approach. + /// + /// # Arguments + /// + /// * `key` - The key to update the value for. + /// * `value` - The new value associated with the key. + /// + /// # Returns + /// + /// An `Result` indicating whether the operation was successful or not. + pub fn update(&mut self, key: Key, value: Value) -> Result<(), SMTError> { + let key = key.to_string().parse::().unwrap(); + let value = value.to_string().parse::().unwrap(); + + let EntryResponse { + entry, siblings, .. + } = self.retrieve_entry(key.clone()); + + if entry.get(1).is_none() { + return Err(SMTError::KeyDoesNotExist(key.to_string())); + } + + let path = key_to_path(&key.to_string()); + + // Deletes the old nodes and re-creates them with the new hashes. + let old_node = (self.hash)(entry.clone()); + self.nodes.remove(&old_node); + self.delete_old_nodes(old_node.clone(), &path, &siblings); + + let new_node = (self.hash)(vec![key.clone(), value.clone(), self.entry_mark.clone()]); + self.nodes + .insert(new_node.clone(), vec![key, value, self.entry_mark.clone()]); + self.root = self + .add_new_nodes(new_node, &path, &siblings, None) + .unwrap(); + + Ok(()) + } + + /// Deletes the key-value pair associated with the given key from the SMT. + /// + /// Also in this case, all the hashes of the nodes in the path of the deleted entry are updated + /// with a bottom up approach. + /// + /// # Arguments + /// + /// * `key` - The key to delete. + /// + /// # Returns + /// + /// An `Result` indicating whether the operation was successful or not. + pub fn delete(&mut self, key: Key) -> Result<(), SMTError> { + let key = key.to_string().parse::().unwrap(); + + let EntryResponse { + entry, + mut siblings, + .. + } = self.retrieve_entry(key.clone()); + + if entry.get(1).is_none() { + return Err(SMTError::KeyDoesNotExist(key.to_string())); + } + + let path = key_to_path(&key.to_string()); + + let node = (self.hash)(entry.clone()); + self.nodes.remove(&node); + + self.root = self.zero_node.clone(); + + // If there are siblings, the old nodes are deleted and will be re-created below with new hashes. + if !siblings.is_empty() { + self.delete_old_nodes(node.clone(), &path, &siblings); + + // If the last sibling is not a leaf node, it adds all the nodes of the path starting from + // a zero node, otherwise it removes the last non-zero sibling from the `siblings` vector + // and it starts from it by skipping the last zero nodes. + if !self.is_leaf(&siblings.last().cloned().unwrap()) { + self.root = self + .add_new_nodes(self.zero_node.clone(), &path, &siblings, None) + .unwrap(); + } else { + let first_sibling = siblings.pop().unwrap(); + let i = get_index_of_last_non_zero_element( + siblings + .iter() + .map(|s| s.to_string()) + .collect::>() + .iter() + .map(|s| s.as_str()) + .collect::>(), + ); + + self.root = self.add_new_nodes(first_sibling, &path, &siblings, Some(i))?; + } + } + + Ok(()) + } + + /// Creates a proof to prove the membership or the non-membership of a tree entry. + /// + /// # Arguments + /// + /// * `key` - The key to create the proof for. + /// + /// # Returns + /// + /// A `MerkleProof` containing the proof information. + pub fn create_proof(&self, key: Key) -> MerkleProof { + let key = key.to_string().parse::().unwrap(); + + let EntryResponse { + entry, + matching_entry, + siblings, + } = self.retrieve_entry(key); + + // If the key exists, the function returns a proof with the entry itself, otherwise it returns + // a non-membership proof with the matching entry. + MerkleProof { + entry_response: EntryResponse { + entry: entry.clone(), + matching_entry, + siblings, + }, + root: self.root.clone(), + membership: entry.get(1).is_some(), + } + } + + /// Verifies a membership or a non-membership proof for a given key in the SMT. + /// + /// # Arguments + /// + /// * `merkle_proof` - The Merkle proof to verify. + /// + /// # Returns + /// + /// A boolean indicating whether the proof is valid or not. + pub fn verify_proof(&self, merkle_proof: MerkleProof) -> bool { + // If there is no matching entry, it simply obtains the root hash by using the siblings and the + // path of the key. + if merkle_proof.entry_response.matching_entry.is_none() { + let path = key_to_path(&merkle_proof.entry_response.entry[0].to_string()); + // If there is not an entry value, the proof is a non-membership proof. In this case, since there + // is not a matching entry, the node is set to a zero node. If there is an entry value, the proof + // is a membership proof and the node is set to the hash of the entry. + let node = if merkle_proof.entry_response.entry.get(1).is_some() { + (self.hash)(merkle_proof.entry_response.entry) + } else { + self.zero_node.clone() + }; + let root = self.calculate_root(node, &path, &merkle_proof.entry_response.siblings); + + // If the obtained root is equal to the proof root, then the proof is valid. + return root == merkle_proof.root; + } + + // If there is a matching entry, the proof is definitely a non-membership proof. In this case, it checks + // if the matching node belongs to the tree, and then it checks if the number of the first matching bits + // of the keys is greater than or equal to the number of the siblings. + if let Some(matching_entry) = &merkle_proof.entry_response.matching_entry { + let matching_path = key_to_path(&matching_entry[0].to_string()); + let node = (self.hash)(matching_entry.to_vec()); + let root = + self.calculate_root(node, &matching_path, &merkle_proof.entry_response.siblings); + + if root == merkle_proof.root { + let path = key_to_path(&merkle_proof.entry_response.entry[0].to_string()); + // Returns the first common bits of the two keys: the non-member key and the matching key. + let first_matching_bits = get_first_common_elements(&path, &matching_path); + + // If the non-member key was a key of a tree entry, the depth of the matching node should be + // greater than the number of the fisrt matching bits. Otherwise, the depth of the node can be + // defined by the number of its siblings. + return merkle_proof.entry_response.siblings.len() <= first_matching_bits.len(); + } + } + + false + } + + /// Retrieves the entry associated with the given key from the SMT. + /// + /// If the key passed as parameter exists in the SMT, the function returns the entry itself, otherwise + /// it returns the entry with only the key. When there is another matching entry in the same path, it + /// returns the matching entry as well. + /// + /// In any case, the function returns the siblings of the path. + /// + /// # Arguments + /// + /// * `key` - The key to retrieve the entry for. + /// + /// # Returns + /// + /// An `EntryResponse` struct containing the entry, the matching entry (if any), and the siblings of the leaf node. + fn retrieve_entry(&self, key: Key) -> EntryResponse { + let path = key_to_path(&key.to_string()); + let mut siblings: Siblings = Vec::new(); + let mut node = self.root.clone(); + + let mut i = 0; + + // Starting from the root, it traverses the tree until it reaches a leaf node, a zero node, + // or a matching entry. + while node != self.zero_node { + let child_nodes = self.nodes.get(&node).unwrap_or(&Vec::new()).clone(); + let direction = path[i]; + + // If the third element of the child nodes is not None, it means that the node is an entry of the tree. + if child_nodes.get(2).is_some() { + if child_nodes[0] == key { + // An entry is found with the same key, and it returns it with the siblings. + return EntryResponse { + entry: child_nodes, + matching_entry: None, + siblings, + }; + } + + // An entry was found with a different key, but the key of this particular entry matches the first 'i' + // bits of the key passed as parameter. It can be useful in several functions. + return EntryResponse { + entry: vec![key.clone()], + matching_entry: Some(child_nodes), + siblings, + }; + } + + // When it goes down into the tree and follows the path, in every step a node is chosen between left + // and right child nodes, and the opposite node is saved in the `siblings` vector. + node = child_nodes[direction].clone(); + siblings.push(child_nodes[1 - direction].clone()); + + i += 1; + } + + // The path led to a zero node. + EntryResponse { + entry: vec![key], + matching_entry: None, + siblings, + } + } + + /// Calculates the root of the tree by using the given node, the path, and the siblings. + /// + /// It calculates with a bottom up approach by starting from the node and going up to the root. + /// + /// # Arguments + /// + /// * `node` - The node to start the calculation from. + /// * `path` - The path of the key. + /// * `siblings` - The siblings of the path. + /// + /// # Returns + /// + /// The root of the tree. + fn calculate_root(&self, mut node: Node, path: &[usize], siblings: &Siblings) -> Node { + for i in (0..siblings.len()).rev() { + let child_nodes: ChildNodes = if path[i] != 0 { + vec![siblings[i].clone(), node.clone()] + } else { + vec![node.clone(), siblings[i].clone()] + }; + + node = (self.hash)(child_nodes); + } + + node + } + + /// Adds new nodes to the tree with the new hashes. + /// + /// It starts with a bottom up approach until it reaches the root of the tree. + /// + /// # Arguments + /// + /// * `node` - The node to start the calculation from. + /// * `path` - The path of the key. + /// * `siblings` - The siblings of the path. + /// * `i` - The index of the sibling to start from. + /// + /// # Returns + /// + /// The new root of the tree. + fn add_new_nodes( + &mut self, + mut node: Node, + path: &[usize], + siblings: &Siblings, + i: Option, + ) -> Result { + let mut starting_index = if let Some(i) = i { + i + } else { + siblings.len() as isize - 1 + }; + + while starting_index > 0 { + if siblings.get(starting_index as usize).is_none() { + return Err(SMTError::InvalidSiblingIndex); + } + + let child_nodes: ChildNodes = if path.get(starting_index as usize).is_some() { + vec![siblings[starting_index as usize].clone(), node.clone()] + } else { + vec![node.clone(), siblings[starting_index as usize].clone()] + }; + + node = (self.hash)(child_nodes.clone()); + + self.nodes.insert(node.clone(), child_nodes); + + starting_index -= 1; + } + + Ok(node) + } + + /// Deletes the old nodes of the tree. + /// + /// It starts with a bottom up approach until it reaches the root of the tree. + /// + /// # Arguments + /// + /// * `node` - The node to start the calculation from. + /// * `path` - The path of the key. + /// * `siblings` - The siblings of the path. + fn delete_old_nodes(&mut self, mut node: Node, path: &[usize], siblings: &Siblings) { + for i in (0..siblings.len()).rev() { + let child_nodes: ChildNodes = if path.get(i).is_some() { + vec![siblings[i].clone(), node.clone()] + } else { + vec![node.clone(), siblings[i].clone()] + }; + + node = (self.hash)(child_nodes); + + self.nodes.remove(&node); + } + } + + /// Checks if the given node is a leaf node or not. + /// + /// # Arguments + /// + /// * `node` - The node to check. + /// + /// # Returns + /// + /// A boolean indicating whether the node is a leaf node or not. + fn is_leaf(&self, node: &Node) -> bool { + if let Some(child_nodes) = self.nodes.get(node) { + child_nodes.get(2).is_some() + } else { + false + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn hash_function(nodes: Vec) -> Node { + let strings: Vec = nodes.iter().map(|node| node.to_string()).collect(); + Node::Str(strings.join(",")) + } + + #[test] + fn test_new() { + let smt = SMT::new(hash_function, false); + assert!(!smt.big_numbers); + assert_eq!(smt.zero_node, Node::Str("0".to_string())); + assert_eq!(smt.entry_mark, Node::Str("1".to_string())); + assert_eq!(smt.nodes, HashMap::new()); + assert_eq!(smt.root, Node::Str("0".to_string())); + + let smt = SMT::new(hash_function, true); + assert!(smt.big_numbers); + assert_eq!(smt.zero_node, Node::BigInt(BigInt::from(0))); + assert_eq!(smt.entry_mark, Node::BigInt(BigInt::from(1))); + assert_eq!(smt.nodes, HashMap::new()); + assert_eq!(smt.root, Node::BigInt(BigInt::from(0))); + } + + #[test] + fn test_get() { + let mut smt = SMT::new(hash_function, false); + let key = Key::Str("aaa".to_string()); + let value = Value::Str("bbb".to_string()); + let _ = smt.add(key.clone(), value.clone()); + let result = smt.get(key.clone()); + println!("{:?}", result); + assert_eq!(result, Some(value)); + + let key2 = Key::Str("ccc".to_string()); + let result2 = smt.get(key2.clone()); + assert_eq!(result2, None); + + let mut smt = SMT::new(hash_function, true); + let key = Key::BigInt(BigInt::from(123)); + let value = Value::BigInt(BigInt::from(456)); + let _ = smt.add(key.clone(), value.clone()); + let result = smt.get(key.clone()); + assert_eq!(result, Some(value)); + } + #[test] + fn test_add() { + let mut smt = SMT::new(hash_function, false); + let key = Key::Str("aaa".to_string()); + let value = Value::Str("bbb".to_string()); + let result = smt.add(key.clone(), value.clone()); + assert!(result.is_ok()); + assert_eq!(smt.nodes.len(), 1); + assert_eq!( + smt.nodes.get(&smt.root), + Some(&vec![key.clone(), value.clone(), smt.entry_mark.clone()]) + ); + + let mut smt = SMT::new(hash_function, true); + let key = Key::BigInt(BigInt::from(123)); + let value = Value::BigInt(BigInt::from(456)); + let result = smt.add(key.clone(), value.clone()); + assert!(result.is_ok()); + assert_eq!(smt.nodes.len(), 1); + assert_eq!( + smt.nodes.get(&smt.root), + Some(&vec![key.clone(), value.clone(), smt.entry_mark.clone()]) + ); + } + + #[test] + fn test_update() { + let mut smt = SMT::new(hash_function, false); + let key = Key::Str("aaa".to_string()); + let value = Value::Str("bbb".to_string()); + let _ = smt.add(key.clone(), value.clone()); + + let new_value = Value::Str("ccc".to_string()); + let result = smt.update(key.clone(), new_value.clone()); + assert!(result.is_ok()); + assert_eq!(smt.nodes.len(), 1); + assert_eq!( + smt.nodes.get(&smt.root), + Some(&vec![ + key.clone(), + new_value.clone(), + smt.entry_mark.clone() + ]) + ); + + let key2 = Key::Str("def".to_string()); + let result2 = smt.update(key2.clone(), new_value.clone()); + assert_eq!(result2, Err(SMTError::KeyDoesNotExist(key2.to_string()))); + + let mut smt = SMT::new(hash_function, true); + let key = Key::BigInt(BigInt::from(123)); + let value = Value::BigInt(BigInt::from(456)); + let _ = smt.add(key.clone(), value.clone()); + + let new_value = Value::BigInt(BigInt::from(789)); + let result = smt.update(key.clone(), new_value.clone()); + assert!(result.is_ok()); + assert_eq!(smt.nodes.len(), 1); + assert_eq!( + smt.nodes.get(&smt.root), + Some(&vec![ + key.clone(), + new_value.clone(), + smt.entry_mark.clone() + ]) + ); + } + + #[test] + fn test_delete() { + let mut smt = SMT::new(hash_function, false); + let key = Key::Str("abc".to_string()); + let value = Value::Str("123".to_string()); + let _ = smt.add(key.clone(), value.clone()); + let result = smt.delete(key.clone()); + assert!(result.is_ok()); + assert_eq!(smt.nodes.len(), 0); + assert_eq!(smt.root, smt.zero_node); + + let key2 = Key::Str("def".to_string()); + let result2 = smt.delete(key2.clone()); + assert_eq!(result2, Err(SMTError::KeyDoesNotExist(key2.to_string()))); + + let mut smt = SMT::new(hash_function, true); + let key = Key::BigInt(BigInt::from(123)); + let value = Value::BigInt(BigInt::from(456)); + let _ = smt.add(key.clone(), value.clone()); + let result = smt.delete(key.clone()); + assert!(result.is_ok()); + assert_eq!(smt.nodes.len(), 0); + assert_eq!(smt.root, smt.zero_node); + } + + #[test] + fn test_create_proof() { + let mut smt = SMT::new(hash_function, false); + let key = Key::Str("abc".to_string()); + let value = Value::Str("123".to_string()); + let _ = smt.add(key.clone(), value.clone()); + let proof = smt.create_proof(key.clone()); + assert_eq!(proof.root, smt.root); + + let mut smt = SMT::new(hash_function, true); + let key = Key::BigInt(BigInt::from(123)); + let value = Value::BigInt(BigInt::from(456)); + let _ = smt.add(key.clone(), value.clone()); + let proof = smt.create_proof(key.clone()); + assert_eq!(proof.root, smt.root); + } + + #[test] + fn test_verify_proof() { + let mut smt = SMT::new(hash_function, false); + let key = Key::Str("abc".to_string()); + let value = Value::Str("123".to_string()); + let _ = smt.add(key.clone(), value.clone()); + let proof = smt.create_proof(key.clone()); + let result = smt.verify_proof(proof); + assert!(result); + + let key2 = Key::Str("def".to_string()); + let false_proof = MerkleProof { + entry_response: EntryResponse { + entry: vec![key2.clone()], + matching_entry: None, + siblings: Vec::new(), + }, + root: smt.root.clone(), + membership: false, + }; + let fun = smt.verify_proof(false_proof); + assert!(!fun); + + let mut smt = SMT::new(hash_function, true); + let key = Key::BigInt(BigInt::from(123)); + let value = Value::BigInt(BigInt::from(456)); + let _ = smt.add(key.clone(), value.clone()); + let proof = smt.create_proof(key.clone()); + let result = smt.verify_proof(proof); + assert!(result); + + let key2 = Key::BigInt(BigInt::from(789)); + let false_proof = MerkleProof { + entry_response: EntryResponse { + entry: vec![key2.clone()], + matching_entry: None, + siblings: Vec::new(), + }, + root: smt.root.clone(), + membership: true, + }; + let fun = smt.verify_proof(false_proof); + assert!(!fun); + } + + #[test] + fn test_retrieve_entry() { + let smt = SMT::new(hash_function, false); + let key = Key::Str("be12".to_string()); + let entry_response = smt.retrieve_entry(key.clone()); + assert_eq!(entry_response.entry, vec![key]); + assert_eq!(entry_response.matching_entry, None); + assert_eq!(entry_response.siblings, Vec::new()); + + let smt = SMT::new(hash_function, true); + let key = Key::BigInt(BigInt::from(123)); + let entry_response = smt.retrieve_entry(key.clone()); + assert_eq!(entry_response.entry, vec![key]); + assert_eq!(entry_response.matching_entry, None); + assert_eq!(entry_response.siblings, Vec::new()); + } + + #[test] + fn test_calculate_root() { + let smt = SMT::new(hash_function, false); + let node = Node::Str("node".to_string()); + let path = &[0, 1, 0]; + let siblings = vec![ + Node::Str("sibling1".to_string()), + Node::Str("sibling2".to_string()), + Node::Str("sibling3".to_string()), + ]; + let root = smt.calculate_root(node.clone(), path, &siblings); + assert_eq!( + root, + Node::Str("sibling2,node,sibling3,sibling1".to_string()) + ); + + let smt = SMT::new(hash_function, true); + let node = Node::BigInt(BigInt::from(123)); + let path = &[1, 0]; + let siblings = vec![ + Node::BigInt(BigInt::from(456)), + Node::BigInt(BigInt::from(789)), + ]; + let root = smt.calculate_root(node.clone(), path, &siblings); + assert_eq!(root, Node::Str("456,123,789".to_string())); + } + + #[test] + fn test_add_new_nodes() { + let mut smt = SMT::new(hash_function, false); + let node = Node::Str("node".to_string()); + let path = &[0, 1, 0]; + let siblings = vec![ + Node::Str("sibling1".to_string()), + Node::Str("sibling2".to_string()), + Node::Str("sibling3".to_string()), + ]; + let new_node = smt + .add_new_nodes(node.clone(), path, &siblings, None) + .unwrap(); + assert_eq!(new_node, Node::Str("sibling2,sibling3,node".to_string())); + + let starting_index = smt + .add_new_nodes(node.clone(), path, &siblings, Some(1)) + .unwrap(); + assert_eq!(starting_index, Node::Str("sibling2,node".to_string())); + + let mut smt = SMT::new(hash_function, true); + let node = Node::BigInt(BigInt::from(111)); + let path = &[1, 0, 0]; + let siblings = vec![ + Node::BigInt(BigInt::from(222)), + Node::BigInt(BigInt::from(333)), + Node::BigInt(BigInt::from(444)), + ]; + let new_node = smt + .add_new_nodes(node.clone(), path, &siblings, None) + .unwrap(); + assert_eq!(new_node, Node::Str("333,444,111".to_string())); + + let starting_index = smt + .add_new_nodes(node.clone(), path, &siblings, Some(1)) + .unwrap(); + assert_eq!(starting_index, Node::Str("333,111".to_string())); + } + + #[test] + fn test_delete_old_nodes() { + let mut smt = SMT::new(hash_function, false); + let node = Node::Str("abc".to_string()); + let path = &[0, 1, 0]; + let siblings = vec![ + Node::Str("sibling1".to_string()), + Node::Str("sibling2".to_string()), + Node::Str("sibling3".to_string()), + ]; + let new_node = smt + .add_new_nodes(node.clone(), path, &siblings, None) + .unwrap(); + assert_eq!(new_node, Node::Str("sibling2,sibling3,abc".to_string())); + smt.delete_old_nodes(node.clone(), path, &siblings); + assert_eq!(smt.nodes.len(), 0); + + let mut smt = SMT::new(hash_function, true); + let node = Node::BigInt(BigInt::from(123)); + let path = &[1, 0]; + let siblings = vec![ + Node::BigInt(BigInt::from(456)), + Node::BigInt(BigInt::from(789)), + ]; + let new_node = smt + .add_new_nodes(node.clone(), path, &siblings, None) + .unwrap(); + assert_eq!(new_node, Node::Str("789,123".to_string())); + smt.delete_old_nodes(node.clone(), path, &siblings); + assert_eq!(smt.nodes.len(), 0); + } + + #[test] + fn test_is_leaf() { + let mut smt = SMT::new(hash_function, false); + let node = Node::Str("abc".to_string()); + assert!(!smt.is_leaf(&node)); + + smt.nodes.insert( + Node::Str("abc".to_string()), + vec![ + Node::Str("123".to_string()), + Node::Str("456".to_string()), + Node::Str("789".to_string()), + ], + ); + assert!(smt.is_leaf(&node)); + + let mut smt = SMT::new(hash_function, true); + let node = Node::BigInt(BigInt::from(123)); + assert!(!smt.is_leaf(&node)); + + smt.nodes.insert( + Node::BigInt(BigInt::from(123)), + vec![ + Node::BigInt(BigInt::from(111)), + Node::BigInt(BigInt::from(222)), + Node::BigInt(BigInt::from(333)), + ], + ); + assert!(smt.is_leaf(&node)); + } +} diff --git a/crates/smt/src/utils.rs b/crates/smt/src/utils.rs new file mode 100644 index 0000000..1943406 --- /dev/null +++ b/crates/smt/src/utils.rs @@ -0,0 +1,185 @@ +/// Converts a hexadecimal string to a binary string. +/// +/// # Arguments +/// +/// * `n` - The hexadecimal string to convert. +/// +/// # Returns +/// +/// The binary representation of the hexadecimal string. +pub fn hex_to_bin(n: &str) -> String { + let mut chars = n.chars(); + let first_char = chars.next().unwrap(); + let mut bin = format!( + "{:b}", + u8::from_str_radix(&first_char.to_string(), 16).unwrap() + ); + + for c in chars { + bin += &format!("{:04b}", u8::from_str_radix(&c.to_string(), 16).unwrap()); + } + + bin +} + +/// Converts a hexadecimal key to a path represented as a vector of usize. +/// +/// For each key, it is possible to obtain an array of 256 padded bits. +/// +/// # Arguments +/// +/// * `key` - The hexadecimal key to convert. +/// +/// # Returns +/// +/// The path represented as a vector of usize. +pub fn key_to_path(key: &str) -> Vec { + let bits = if let Ok(num) = u128::from_str_radix(key, 16) { + format!("{:b}", num) + } else { + hex_to_bin(key) + }; + + let padded_bits = format!("{:0>256}", bits).chars().rev().collect::(); + let bits_array = padded_bits + .chars() + .map(|c| c.to_digit(10).unwrap() as usize) + .collect(); + + bits_array +} + +/// Returns the index of the last non-zero element in the array. +/// +/// # Arguments +/// +/// * `array` - The array of hexadecimal strings. +/// +/// # Returns +/// +/// The index of the last non-zero element in the array, or -1 if no non-zero element is found. +pub fn get_index_of_last_non_zero_element(array: Vec<&str>) -> isize { + for (i, &item) in array.iter().enumerate().rev() { + if u128::from_str_radix(item, 16).unwrap_or(0) != 0 { + return i as isize; + } + } + + -1 +} + +/// Returns the first common elements between two arrays. +/// +/// # Arguments +/// +/// * `array1` - The first array. +/// * `array2` - The second array. +/// +/// # Returns +/// +/// The first common elements between the two arrays. +pub fn get_first_common_elements(array1: &[T], array2: &[T]) -> Vec { + let min_length = std::cmp::min(array1.len(), array2.len()); + + for i in 0..min_length { + if array1[i] != array2[i] { + return array1[0..i].to_vec(); + } + } + + array1[0..min_length].to_vec() +} + +/// Checks if a string is a valid hexadecimal string. +/// +/// # Arguments +/// +/// * `s` - The string to check. +/// +/// # Returns +/// +/// `true` if the string is a valid hexadecimal string, `false` otherwise. +pub fn is_hexadecimal(s: &str) -> bool { + s.chars().all(|c| c.is_ascii_hexdigit()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_hex_to_bin() { + assert_eq!(hex_to_bin("A"), "1010"); + assert_eq!(hex_to_bin("F"), "1111"); + assert_eq!(hex_to_bin("1A"), "11010"); + assert_eq!(hex_to_bin("FF"), "11111111"); + assert_eq!(hex_to_bin("12"), "10010"); + } + + #[test] + fn test_key_to_path() { + let path = key_to_path("17"); + assert_eq!(path.len(), 256); + assert_eq!(&path[0..5], vec![1, 1, 1, 0, 1]); + } + + #[test] + fn test_get_index_of_last_non_zero_element() { + assert_eq!(get_index_of_last_non_zero_element(vec![]), -1); + assert_eq!(get_index_of_last_non_zero_element(vec!["0", "0", "0"]), -1); + + assert_eq!(get_index_of_last_non_zero_element(vec!["0", "0", "1"]), 2); + assert_eq!(get_index_of_last_non_zero_element(vec!["0", "1", "0"]), 1); + assert_eq!(get_index_of_last_non_zero_element(vec!["1", "0", "0"]), 0); + + assert_eq!( + get_index_of_last_non_zero_element(vec!["0", "1", "0", "1", "0"]), + 3 + ); + assert_eq!( + get_index_of_last_non_zero_element(vec!["1", "0", "1", "0", "0"]), + 2 + ); + assert_eq!( + get_index_of_last_non_zero_element(vec!["0", "0", "0", "1", "1"]), + 4 + ); + assert_eq!( + get_index_of_last_non_zero_element(vec![ + "0", "17", "3", "0", "3", "0", "3", "2", "0", "0" + ]), + 7 + ) + } + + #[test] + fn test_get_first_common_elements() { + assert_eq!(get_first_common_elements::(&[], &[]), vec![]); + + assert_eq!( + get_first_common_elements(&[1, 2, 3], &[1, 2, 3, 4, 5]), + vec![1, 2, 3] + ); + assert_eq!( + get_first_common_elements(&[1, 2, 3, 4, 5], &[1, 2, 3]), + vec![1, 2, 3] + ); + + assert_eq!( + get_first_common_elements(&[1, 2, 3], &[1, 2, 4]), + vec![1, 2] + ); + assert_eq!(get_first_common_elements(&[1, 2, 3], &[4, 5, 6]), vec![]); + } + + #[test] + fn test_is_hexadecimal() { + assert!(is_hexadecimal("be12")); + assert!(is_hexadecimal("ABCDEF")); + assert!(is_hexadecimal("1234567890abcdef")); + + assert!(!is_hexadecimal("gbe12")); + assert!(!is_hexadecimal("123XYZ")); + assert!(!is_hexadecimal("abcdefg")); + } +}