diff --git a/Cargo.lock b/Cargo.lock
index 9ceee52..ee6f2dc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2,6 +2,12 @@
# It is not intended for manual editing.
version = 3
+[[package]]
+name = "autocfg"
+version = "1.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
+
[[package]]
name = "crunchy"
version = "0.2.2"
@@ -14,6 +20,34 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
+[[package]]
+name = "num-bigint"
+version = "0.4.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
+dependencies = [
+ "num-integer",
+ "num-traits",
+]
+
+[[package]]
+name = "num-integer"
+version = "0.1.46"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
+dependencies = [
+ "num-traits",
+]
+
+[[package]]
+name = "num-traits"
+version = "0.2.19"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
+dependencies = [
+ "autocfg",
+]
+
[[package]]
name = "tiny-keccak"
version = "2.0.2"
@@ -30,3 +64,10 @@ dependencies = [
"hex",
"tiny-keccak",
]
+
+[[package]]
+name = "zk-kit-smt"
+version = "0.0.1"
+dependencies = [
+ "num-bigint",
+]
diff --git a/crates/smt/Cargo.toml b/crates/smt/Cargo.toml
new file mode 100644
index 0000000..dbd11fe
--- /dev/null
+++ b/crates/smt/Cargo.toml
@@ -0,0 +1,12 @@
+[package]
+name = "zk-kit-smt"
+version.workspace = true
+edition.workspace = true
+license.workspace = true
+publish.workspace = true
+description = "Sparse Merkle Tree"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+num-bigint = "0.4.6"
diff --git a/crates/smt/LICENSE b/crates/smt/LICENSE
new file mode 100644
index 0000000..1763b4d
--- /dev/null
+++ b/crates/smt/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Pinco
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/crates/smt/README.md b/crates/smt/README.md
new file mode 100644
index 0000000..0457578
--- /dev/null
+++ b/crates/smt/README.md
@@ -0,0 +1,90 @@
+
+
+ Sparse Merkle tree
+
+ Sparse Merkle tree implementation in Rust.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+A sparse Merkle tree is a data structure useful for storing a key/value map where every leaf node of the tree contains the cryptographic hash of a key/value pair and every non leaf node contains the concatenated hashes of its child nodes. Sparse Merkle trees provides a secure and efficient verification of large data sets and they are often used in peer-to-peer technologies. This implementation is an optimized version of the traditional sparse Merkle tree and it is based on the concepts expressed in the papers and resources below.
+
+## References
+
+1. Rasmus Dahlberg, Tobias Pulls and Roel Peeters. _Efficient Sparse Merkle Trees: Caching Strategies and Secure (Non-)Membership Proofs_. Cryptology ePrint Archive: Report 2016/683, 2016. https://eprint.iacr.org/2016/683.
+2. Faraz Haider. _Compact sparse merkle trees_. Cryptology ePrint Archive: Report 2018/955, 2018. https://eprint.iacr.org/2018/955.
+3. Jordi Baylina and Marta Bellés. _Sparse Merkle Trees_. https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf.
+4. Vitalik Buterin Fichter. _Optimizing sparse Merkle trees_. https://ethresear.ch/t/optimizing-sparse-merkle-trees/3751.
+
+---
+
+## 🛠 Install
+
+You can install `zk-kit-smt` crate with `cargo`:
+
+```bash
+cargo add zk-kit-smt
+```
+
+## 📜 Usage
+
+```rust
+use zk_kit_smt::smt::{Key, Node, Value, SMT};
+
+fn hash_function(nodes: Vec) -> Node {
+ let strings: Vec = nodes.iter().map(|node| node.to_string()).collect();
+ Node::Str(strings.join(","))
+}
+
+fn main() {
+ // Initialize the Sparse Merkle Tree with a hash function.
+ let mut smt = SMT::new(hash_function, false);
+
+ let key = Key::Str("aaa".to_string());
+ let value = Value::Str("bbb".to_string());
+
+ // Add a key-value pair to the Sparse Merkle Tree.
+ smt.add(key.clone(), value.clone()).unwrap();
+
+ // Get the value of the key.
+ let get = smt.get(key.clone());
+ assert_eq!(get, Some(value));
+
+ // Update the value of the key.
+ let new_value = Value::Str("ccc".to_string());
+ let update = smt.update(key.clone(), new_value.clone());
+ assert!(update.is_ok());
+ assert_eq!(smt.get(key.clone()), Some(new_value));
+
+ // Create and verify a proof for the key.
+ let create_proof = smt.create_proof(key.clone());
+ let verify_proof = smt.verify_proof(create_proof);
+ assert!(verify_proof);
+
+ // Delete the key.
+ let delete = smt.delete(key.clone());
+ assert!(delete.is_ok());
+ assert_eq!(smt.get(key.clone()), None);
+}
+```
diff --git a/crates/smt/examples/smt.rs b/crates/smt/examples/smt.rs
new file mode 100644
index 0000000..8a6fe05
--- /dev/null
+++ b/crates/smt/examples/smt.rs
@@ -0,0 +1,37 @@
+use zk_kit_smt::smt::{Key, Node, Value, SMT};
+
+fn hash_function(nodes: Vec) -> Node {
+ let strings: Vec = nodes.iter().map(|node| node.to_string()).collect();
+ Node::Str(strings.join(","))
+}
+
+fn main() {
+ // Initialize the Sparse Merkle Tree with a hash function.
+ let mut smt = SMT::new(hash_function, false);
+
+ let key = Key::Str("aaa".to_string());
+ let value = Value::Str("bbb".to_string());
+
+ // Add a key-value pair to the Sparse Merkle Tree.
+ smt.add(key.clone(), value.clone()).unwrap();
+
+ // Get the value of the key.
+ let get = smt.get(key.clone());
+ assert_eq!(get, Some(value));
+
+ // Update the value of the key.
+ let new_value = Value::Str("ccc".to_string());
+ let update = smt.update(key.clone(), new_value.clone());
+ assert!(update.is_ok());
+ assert_eq!(smt.get(key.clone()), Some(new_value));
+
+ // Create and verify a proof for the key.
+ let create_proof = smt.create_proof(key.clone());
+ let verify_proof = smt.verify_proof(create_proof);
+ assert!(verify_proof);
+
+ // Delete the key.
+ let delete = smt.delete(key.clone());
+ assert!(delete.is_ok());
+ assert_eq!(smt.get(key.clone()), None);
+}
diff --git a/crates/smt/src/lib.rs b/crates/smt/src/lib.rs
new file mode 100644
index 0000000..6d39c77
--- /dev/null
+++ b/crates/smt/src/lib.rs
@@ -0,0 +1,2 @@
+pub mod smt;
+mod utils;
diff --git a/crates/smt/src/smt.rs b/crates/smt/src/smt.rs
new file mode 100644
index 0000000..265d487
--- /dev/null
+++ b/crates/smt/src/smt.rs
@@ -0,0 +1,933 @@
+use std::{collections::HashMap, str::FromStr};
+
+use num_bigint::BigInt;
+
+use crate::utils::{
+ get_first_common_elements, get_index_of_last_non_zero_element, is_hexadecimal, key_to_path,
+};
+
+use std::fmt;
+
+#[derive(Debug, PartialEq)]
+pub enum SMTError {
+ KeyAlreadyExist(String),
+ KeyDoesNotExist(String),
+ InvalidParameterType(String, String),
+ InvalidSiblingIndex,
+}
+
+impl fmt::Display for SMTError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ SMTError::KeyAlreadyExist(s) => write!(f, "Key {} already exists", s),
+ SMTError::KeyDoesNotExist(s) => write!(f, "Key {} does not exist", s),
+ SMTError::InvalidParameterType(p, t) => {
+ write!(f, "Parameter {} must be a {}", p, t)
+ },
+ SMTError::InvalidSiblingIndex => write!(f, "Invalid sibling index"),
+ }
+ }
+}
+
+impl std::error::Error for SMTError {}
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum Node {
+ Str(String),
+ BigInt(BigInt),
+}
+
+impl fmt::Display for Node {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ Node::Str(s) => write!(f, "{}", s),
+ Node::BigInt(n) => write!(f, "{}", n),
+ }
+ }
+}
+
+impl FromStr for Node {
+ type Err = SMTError;
+
+ fn from_str(s: &str) -> Result {
+ if let Ok(bigint) = s.parse::() {
+ Ok(Node::BigInt(bigint))
+ } else if is_hexadecimal(s) {
+ Ok(Node::Str(s.to_string()))
+ } else {
+ Err(SMTError::InvalidParameterType(
+ s.to_string(),
+ "BigInt or hexadecimal string".to_string(),
+ ))
+ }
+ }
+}
+
+pub type Key = Node;
+pub type Value = Node;
+pub type EntryMark = Node;
+
+pub type Entry = (Key, Value, EntryMark);
+pub type ChildNodes = Vec;
+pub type Siblings = Vec;
+
+pub type HashFunction = fn(ChildNodes) -> Node;
+
+pub struct EntryResponse {
+ pub entry: Vec,
+ pub matching_entry: Option>,
+ pub siblings: Siblings,
+}
+
+#[allow(dead_code)]
+pub struct MerkleProof {
+ entry_response: EntryResponse,
+ root: Node,
+ membership: bool,
+}
+
+#[allow(dead_code)]
+pub struct SMT {
+ hash: HashFunction,
+ big_numbers: bool,
+ zero_node: Node,
+ entry_mark: Node,
+ nodes: HashMap>,
+ root: Node,
+}
+
+impl SMT {
+ /// Initializes a new instance of the Sparse Merkle Tree (SMT).
+ ///
+ /// # Arguments
+ ///
+ /// * `hash` - The hash function used to hash the child nodes.
+ /// * `big_numbers` - A flag indicating whether the SMT supports big numbers or not.
+ ///
+ /// # Returns
+ ///
+ /// A new instance of the SMT.
+ pub fn new(hash: HashFunction, big_numbers: bool) -> Self {
+ let zero_node;
+ let entry_mark;
+
+ if big_numbers {
+ zero_node = Node::BigInt(BigInt::from(0));
+ entry_mark = Node::BigInt(BigInt::from(1));
+ } else {
+ zero_node = Node::Str("0".to_string());
+ entry_mark = Node::Str("1".to_string());
+ }
+
+ SMT {
+ hash,
+ big_numbers,
+ zero_node: zero_node.clone(),
+ entry_mark,
+ nodes: HashMap::new(),
+ root: zero_node,
+ }
+ }
+
+ /// Retrieves the value associated with the given key from the SMT.
+ ///
+ /// # Arguments
+ ///
+ /// * `key` - The key to retrieve the value for.
+ ///
+ /// # Returns
+ ///
+ /// An `Option` containing the value associated with the key, or `None` if the key does not exist.
+ pub fn get(&self, key: Key) -> Option {
+ let key = key.to_string().parse::().unwrap();
+
+ let EntryResponse { entry, .. } = self.retrieve_entry(key);
+
+ entry.get(1).cloned()
+ }
+
+ /// Adds a new key-value pair to the SMT.
+ ///
+ /// It retrieves a matching entry or a zero node with a top-down approach and then it updates
+ /// all the hashes of the nodes in the path of the new entry with a bottom up approach.
+ ///
+ /// # Arguments
+ ///
+ /// * `key` - The key to add.
+ /// * `value` - The value associated with the key.
+ ///
+ /// # Returns
+ ///
+ /// An `Result` indicating whether the operation was successful or not.
+ pub fn add(&mut self, key: Key, value: Value) -> Result<(), SMTError> {
+ let key = key.to_string().parse::().unwrap();
+ let value = value.to_string().parse::().unwrap();
+
+ let EntryResponse {
+ entry,
+ matching_entry,
+ mut siblings,
+ } = self.retrieve_entry(key.clone());
+
+ if entry.get(1).is_some() {
+ return Err(SMTError::KeyAlreadyExist(key.to_string()));
+ }
+
+ let path = key_to_path(&key.to_string());
+ // If there is a matching entry, its node is saved in the `node` variable, otherwise the
+ // `zero_node` is saved. This node is used below as the first node (starting from the
+ // bottom of the tree) to obtain the new nodes up to the root.
+ let node = if let Some(ref matching_entry) = matching_entry {
+ (self.hash)(matching_entry.clone())
+ } else {
+ self.zero_node.clone()
+ };
+
+ // If there are siblings, the old nodes are deleted and will be re-created below with new hashes.
+ if !siblings.is_empty() {
+ self.delete_old_nodes(node.clone(), &path, &siblings)
+ }
+
+ // If there is a matching entry, further N zero siblings are added in the `siblings` vector,
+ // followed by the matching node itself. N is the number of the first matching bits of the paths.
+ // This is helpful in the non-membership proof verification as explained in the function below.
+ if let Some(matching_entry) = matching_entry {
+ let matching_path = key_to_path(&matching_entry[0].to_string());
+ let mut i = siblings.len();
+
+ while matching_path[i] == path[i] {
+ siblings.push(self.zero_node.clone());
+ i += 1;
+ }
+
+ siblings.push(node.clone());
+ }
+
+ // Adds the new entry and re-creates the nodes of the path with the new hashes with a bottom
+ // up approach. The `add_new_nodes` function returns the new root of the tree.
+ let new_node = (self.hash)(vec![key.clone(), value.clone(), self.entry_mark.clone()]);
+
+ self.nodes
+ .insert(new_node.clone(), vec![key, value, self.entry_mark.clone()]);
+ self.root = self
+ .add_new_nodes(new_node, &path, &siblings, None)
+ .unwrap();
+
+ Ok(())
+ }
+
+ /// Updates the value associated with the given key in the SMT.
+ ///
+ /// Also in this case, all the hashes of the nodes in the path of the updated entry are updated
+ /// with a bottom up approach.
+ ///
+ /// # Arguments
+ ///
+ /// * `key` - The key to update the value for.
+ /// * `value` - The new value associated with the key.
+ ///
+ /// # Returns
+ ///
+ /// An `Result` indicating whether the operation was successful or not.
+ pub fn update(&mut self, key: Key, value: Value) -> Result<(), SMTError> {
+ let key = key.to_string().parse::().unwrap();
+ let value = value.to_string().parse::().unwrap();
+
+ let EntryResponse {
+ entry, siblings, ..
+ } = self.retrieve_entry(key.clone());
+
+ if entry.get(1).is_none() {
+ return Err(SMTError::KeyDoesNotExist(key.to_string()));
+ }
+
+ let path = key_to_path(&key.to_string());
+
+ // Deletes the old nodes and re-creates them with the new hashes.
+ let old_node = (self.hash)(entry.clone());
+ self.nodes.remove(&old_node);
+ self.delete_old_nodes(old_node.clone(), &path, &siblings);
+
+ let new_node = (self.hash)(vec![key.clone(), value.clone(), self.entry_mark.clone()]);
+ self.nodes
+ .insert(new_node.clone(), vec![key, value, self.entry_mark.clone()]);
+ self.root = self
+ .add_new_nodes(new_node, &path, &siblings, None)
+ .unwrap();
+
+ Ok(())
+ }
+
+ /// Deletes the key-value pair associated with the given key from the SMT.
+ ///
+ /// Also in this case, all the hashes of the nodes in the path of the deleted entry are updated
+ /// with a bottom up approach.
+ ///
+ /// # Arguments
+ ///
+ /// * `key` - The key to delete.
+ ///
+ /// # Returns
+ ///
+ /// An `Result` indicating whether the operation was successful or not.
+ pub fn delete(&mut self, key: Key) -> Result<(), SMTError> {
+ let key = key.to_string().parse::().unwrap();
+
+ let EntryResponse {
+ entry,
+ mut siblings,
+ ..
+ } = self.retrieve_entry(key.clone());
+
+ if entry.get(1).is_none() {
+ return Err(SMTError::KeyDoesNotExist(key.to_string()));
+ }
+
+ let path = key_to_path(&key.to_string());
+
+ let node = (self.hash)(entry.clone());
+ self.nodes.remove(&node);
+
+ self.root = self.zero_node.clone();
+
+ // If there are siblings, the old nodes are deleted and will be re-created below with new hashes.
+ if !siblings.is_empty() {
+ self.delete_old_nodes(node.clone(), &path, &siblings);
+
+ // If the last sibling is not a leaf node, it adds all the nodes of the path starting from
+ // a zero node, otherwise it removes the last non-zero sibling from the `siblings` vector
+ // and it starts from it by skipping the last zero nodes.
+ if !self.is_leaf(&siblings.last().cloned().unwrap()) {
+ self.root = self
+ .add_new_nodes(self.zero_node.clone(), &path, &siblings, None)
+ .unwrap();
+ } else {
+ let first_sibling = siblings.pop().unwrap();
+ let i = get_index_of_last_non_zero_element(
+ siblings
+ .iter()
+ .map(|s| s.to_string())
+ .collect::>()
+ .iter()
+ .map(|s| s.as_str())
+ .collect::>(),
+ );
+
+ self.root = self.add_new_nodes(first_sibling, &path, &siblings, Some(i))?;
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Creates a proof to prove the membership or the non-membership of a tree entry.
+ ///
+ /// # Arguments
+ ///
+ /// * `key` - The key to create the proof for.
+ ///
+ /// # Returns
+ ///
+ /// A `MerkleProof` containing the proof information.
+ pub fn create_proof(&self, key: Key) -> MerkleProof {
+ let key = key.to_string().parse::().unwrap();
+
+ let EntryResponse {
+ entry,
+ matching_entry,
+ siblings,
+ } = self.retrieve_entry(key);
+
+ // If the key exists, the function returns a proof with the entry itself, otherwise it returns
+ // a non-membership proof with the matching entry.
+ MerkleProof {
+ entry_response: EntryResponse {
+ entry: entry.clone(),
+ matching_entry,
+ siblings,
+ },
+ root: self.root.clone(),
+ membership: entry.get(1).is_some(),
+ }
+ }
+
+ /// Verifies a membership or a non-membership proof for a given key in the SMT.
+ ///
+ /// # Arguments
+ ///
+ /// * `merkle_proof` - The Merkle proof to verify.
+ ///
+ /// # Returns
+ ///
+ /// A boolean indicating whether the proof is valid or not.
+ pub fn verify_proof(&self, merkle_proof: MerkleProof) -> bool {
+ // If there is no matching entry, it simply obtains the root hash by using the siblings and the
+ // path of the key.
+ if merkle_proof.entry_response.matching_entry.is_none() {
+ let path = key_to_path(&merkle_proof.entry_response.entry[0].to_string());
+ // If there is not an entry value, the proof is a non-membership proof. In this case, since there
+ // is not a matching entry, the node is set to a zero node. If there is an entry value, the proof
+ // is a membership proof and the node is set to the hash of the entry.
+ let node = if merkle_proof.entry_response.entry.get(1).is_some() {
+ (self.hash)(merkle_proof.entry_response.entry)
+ } else {
+ self.zero_node.clone()
+ };
+ let root = self.calculate_root(node, &path, &merkle_proof.entry_response.siblings);
+
+ // If the obtained root is equal to the proof root, then the proof is valid.
+ return root == merkle_proof.root;
+ }
+
+ // If there is a matching entry, the proof is definitely a non-membership proof. In this case, it checks
+ // if the matching node belongs to the tree, and then it checks if the number of the first matching bits
+ // of the keys is greater than or equal to the number of the siblings.
+ if let Some(matching_entry) = &merkle_proof.entry_response.matching_entry {
+ let matching_path = key_to_path(&matching_entry[0].to_string());
+ let node = (self.hash)(matching_entry.to_vec());
+ let root =
+ self.calculate_root(node, &matching_path, &merkle_proof.entry_response.siblings);
+
+ if root == merkle_proof.root {
+ let path = key_to_path(&merkle_proof.entry_response.entry[0].to_string());
+ // Returns the first common bits of the two keys: the non-member key and the matching key.
+ let first_matching_bits = get_first_common_elements(&path, &matching_path);
+
+ // If the non-member key was a key of a tree entry, the depth of the matching node should be
+ // greater than the number of the fisrt matching bits. Otherwise, the depth of the node can be
+ // defined by the number of its siblings.
+ return merkle_proof.entry_response.siblings.len() <= first_matching_bits.len();
+ }
+ }
+
+ false
+ }
+
+ /// Retrieves the entry associated with the given key from the SMT.
+ ///
+ /// If the key passed as parameter exists in the SMT, the function returns the entry itself, otherwise
+ /// it returns the entry with only the key. When there is another matching entry in the same path, it
+ /// returns the matching entry as well.
+ ///
+ /// In any case, the function returns the siblings of the path.
+ ///
+ /// # Arguments
+ ///
+ /// * `key` - The key to retrieve the entry for.
+ ///
+ /// # Returns
+ ///
+ /// An `EntryResponse` struct containing the entry, the matching entry (if any), and the siblings of the leaf node.
+ fn retrieve_entry(&self, key: Key) -> EntryResponse {
+ let path = key_to_path(&key.to_string());
+ let mut siblings: Siblings = Vec::new();
+ let mut node = self.root.clone();
+
+ let mut i = 0;
+
+ // Starting from the root, it traverses the tree until it reaches a leaf node, a zero node,
+ // or a matching entry.
+ while node != self.zero_node {
+ let child_nodes = self.nodes.get(&node).unwrap_or(&Vec::new()).clone();
+ let direction = path[i];
+
+ // If the third element of the child nodes is not None, it means that the node is an entry of the tree.
+ if child_nodes.get(2).is_some() {
+ if child_nodes[0] == key {
+ // An entry is found with the same key, and it returns it with the siblings.
+ return EntryResponse {
+ entry: child_nodes,
+ matching_entry: None,
+ siblings,
+ };
+ }
+
+ // An entry was found with a different key, but the key of this particular entry matches the first 'i'
+ // bits of the key passed as parameter. It can be useful in several functions.
+ return EntryResponse {
+ entry: vec![key.clone()],
+ matching_entry: Some(child_nodes),
+ siblings,
+ };
+ }
+
+ // When it goes down into the tree and follows the path, in every step a node is chosen between left
+ // and right child nodes, and the opposite node is saved in the `siblings` vector.
+ node = child_nodes[direction].clone();
+ siblings.push(child_nodes[1 - direction].clone());
+
+ i += 1;
+ }
+
+ // The path led to a zero node.
+ EntryResponse {
+ entry: vec![key],
+ matching_entry: None,
+ siblings,
+ }
+ }
+
+ /// Calculates the root of the tree by using the given node, the path, and the siblings.
+ ///
+ /// It calculates with a bottom up approach by starting from the node and going up to the root.
+ ///
+ /// # Arguments
+ ///
+ /// * `node` - The node to start the calculation from.
+ /// * `path` - The path of the key.
+ /// * `siblings` - The siblings of the path.
+ ///
+ /// # Returns
+ ///
+ /// The root of the tree.
+ fn calculate_root(&self, mut node: Node, path: &[usize], siblings: &Siblings) -> Node {
+ for i in (0..siblings.len()).rev() {
+ let child_nodes: ChildNodes = if path[i] != 0 {
+ vec![siblings[i].clone(), node.clone()]
+ } else {
+ vec![node.clone(), siblings[i].clone()]
+ };
+
+ node = (self.hash)(child_nodes);
+ }
+
+ node
+ }
+
+ /// Adds new nodes to the tree with the new hashes.
+ ///
+ /// It starts with a bottom up approach until it reaches the root of the tree.
+ ///
+ /// # Arguments
+ ///
+ /// * `node` - The node to start the calculation from.
+ /// * `path` - The path of the key.
+ /// * `siblings` - The siblings of the path.
+ /// * `i` - The index of the sibling to start from.
+ ///
+ /// # Returns
+ ///
+ /// The new root of the tree.
+ fn add_new_nodes(
+ &mut self,
+ mut node: Node,
+ path: &[usize],
+ siblings: &Siblings,
+ i: Option,
+ ) -> Result {
+ let mut starting_index = if let Some(i) = i {
+ i
+ } else {
+ siblings.len() as isize - 1
+ };
+
+ while starting_index > 0 {
+ if siblings.get(starting_index as usize).is_none() {
+ return Err(SMTError::InvalidSiblingIndex);
+ }
+
+ let child_nodes: ChildNodes = if path.get(starting_index as usize).is_some() {
+ vec![siblings[starting_index as usize].clone(), node.clone()]
+ } else {
+ vec![node.clone(), siblings[starting_index as usize].clone()]
+ };
+
+ node = (self.hash)(child_nodes.clone());
+
+ self.nodes.insert(node.clone(), child_nodes);
+
+ starting_index -= 1;
+ }
+
+ Ok(node)
+ }
+
+ /// Deletes the old nodes of the tree.
+ ///
+ /// It starts with a bottom up approach until it reaches the root of the tree.
+ ///
+ /// # Arguments
+ ///
+ /// * `node` - The node to start the calculation from.
+ /// * `path` - The path of the key.
+ /// * `siblings` - The siblings of the path.
+ fn delete_old_nodes(&mut self, mut node: Node, path: &[usize], siblings: &Siblings) {
+ for i in (0..siblings.len()).rev() {
+ let child_nodes: ChildNodes = if path.get(i).is_some() {
+ vec![siblings[i].clone(), node.clone()]
+ } else {
+ vec![node.clone(), siblings[i].clone()]
+ };
+
+ node = (self.hash)(child_nodes);
+
+ self.nodes.remove(&node);
+ }
+ }
+
+ /// Checks if the given node is a leaf node or not.
+ ///
+ /// # Arguments
+ ///
+ /// * `node` - The node to check.
+ ///
+ /// # Returns
+ ///
+ /// A boolean indicating whether the node is a leaf node or not.
+ fn is_leaf(&self, node: &Node) -> bool {
+ if let Some(child_nodes) = self.nodes.get(node) {
+ child_nodes.get(2).is_some()
+ } else {
+ false
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ fn hash_function(nodes: Vec) -> Node {
+ let strings: Vec = nodes.iter().map(|node| node.to_string()).collect();
+ Node::Str(strings.join(","))
+ }
+
+ #[test]
+ fn test_new() {
+ let smt = SMT::new(hash_function, false);
+ assert!(!smt.big_numbers);
+ assert_eq!(smt.zero_node, Node::Str("0".to_string()));
+ assert_eq!(smt.entry_mark, Node::Str("1".to_string()));
+ assert_eq!(smt.nodes, HashMap::new());
+ assert_eq!(smt.root, Node::Str("0".to_string()));
+
+ let smt = SMT::new(hash_function, true);
+ assert!(smt.big_numbers);
+ assert_eq!(smt.zero_node, Node::BigInt(BigInt::from(0)));
+ assert_eq!(smt.entry_mark, Node::BigInt(BigInt::from(1)));
+ assert_eq!(smt.nodes, HashMap::new());
+ assert_eq!(smt.root, Node::BigInt(BigInt::from(0)));
+ }
+
+ #[test]
+ fn test_get() {
+ let mut smt = SMT::new(hash_function, false);
+ let key = Key::Str("aaa".to_string());
+ let value = Value::Str("bbb".to_string());
+ let _ = smt.add(key.clone(), value.clone());
+ let result = smt.get(key.clone());
+ println!("{:?}", result);
+ assert_eq!(result, Some(value));
+
+ let key2 = Key::Str("ccc".to_string());
+ let result2 = smt.get(key2.clone());
+ assert_eq!(result2, None);
+
+ let mut smt = SMT::new(hash_function, true);
+ let key = Key::BigInt(BigInt::from(123));
+ let value = Value::BigInt(BigInt::from(456));
+ let _ = smt.add(key.clone(), value.clone());
+ let result = smt.get(key.clone());
+ assert_eq!(result, Some(value));
+ }
+ #[test]
+ fn test_add() {
+ let mut smt = SMT::new(hash_function, false);
+ let key = Key::Str("aaa".to_string());
+ let value = Value::Str("bbb".to_string());
+ let result = smt.add(key.clone(), value.clone());
+ assert!(result.is_ok());
+ assert_eq!(smt.nodes.len(), 1);
+ assert_eq!(
+ smt.nodes.get(&smt.root),
+ Some(&vec![key.clone(), value.clone(), smt.entry_mark.clone()])
+ );
+
+ let mut smt = SMT::new(hash_function, true);
+ let key = Key::BigInt(BigInt::from(123));
+ let value = Value::BigInt(BigInt::from(456));
+ let result = smt.add(key.clone(), value.clone());
+ assert!(result.is_ok());
+ assert_eq!(smt.nodes.len(), 1);
+ assert_eq!(
+ smt.nodes.get(&smt.root),
+ Some(&vec![key.clone(), value.clone(), smt.entry_mark.clone()])
+ );
+ }
+
+ #[test]
+ fn test_update() {
+ let mut smt = SMT::new(hash_function, false);
+ let key = Key::Str("aaa".to_string());
+ let value = Value::Str("bbb".to_string());
+ let _ = smt.add(key.clone(), value.clone());
+
+ let new_value = Value::Str("ccc".to_string());
+ let result = smt.update(key.clone(), new_value.clone());
+ assert!(result.is_ok());
+ assert_eq!(smt.nodes.len(), 1);
+ assert_eq!(
+ smt.nodes.get(&smt.root),
+ Some(&vec![
+ key.clone(),
+ new_value.clone(),
+ smt.entry_mark.clone()
+ ])
+ );
+
+ let key2 = Key::Str("def".to_string());
+ let result2 = smt.update(key2.clone(), new_value.clone());
+ assert_eq!(result2, Err(SMTError::KeyDoesNotExist(key2.to_string())));
+
+ let mut smt = SMT::new(hash_function, true);
+ let key = Key::BigInt(BigInt::from(123));
+ let value = Value::BigInt(BigInt::from(456));
+ let _ = smt.add(key.clone(), value.clone());
+
+ let new_value = Value::BigInt(BigInt::from(789));
+ let result = smt.update(key.clone(), new_value.clone());
+ assert!(result.is_ok());
+ assert_eq!(smt.nodes.len(), 1);
+ assert_eq!(
+ smt.nodes.get(&smt.root),
+ Some(&vec![
+ key.clone(),
+ new_value.clone(),
+ smt.entry_mark.clone()
+ ])
+ );
+ }
+
+ #[test]
+ fn test_delete() {
+ let mut smt = SMT::new(hash_function, false);
+ let key = Key::Str("abc".to_string());
+ let value = Value::Str("123".to_string());
+ let _ = smt.add(key.clone(), value.clone());
+ let result = smt.delete(key.clone());
+ assert!(result.is_ok());
+ assert_eq!(smt.nodes.len(), 0);
+ assert_eq!(smt.root, smt.zero_node);
+
+ let key2 = Key::Str("def".to_string());
+ let result2 = smt.delete(key2.clone());
+ assert_eq!(result2, Err(SMTError::KeyDoesNotExist(key2.to_string())));
+
+ let mut smt = SMT::new(hash_function, true);
+ let key = Key::BigInt(BigInt::from(123));
+ let value = Value::BigInt(BigInt::from(456));
+ let _ = smt.add(key.clone(), value.clone());
+ let result = smt.delete(key.clone());
+ assert!(result.is_ok());
+ assert_eq!(smt.nodes.len(), 0);
+ assert_eq!(smt.root, smt.zero_node);
+ }
+
+ #[test]
+ fn test_create_proof() {
+ let mut smt = SMT::new(hash_function, false);
+ let key = Key::Str("abc".to_string());
+ let value = Value::Str("123".to_string());
+ let _ = smt.add(key.clone(), value.clone());
+ let proof = smt.create_proof(key.clone());
+ assert_eq!(proof.root, smt.root);
+
+ let mut smt = SMT::new(hash_function, true);
+ let key = Key::BigInt(BigInt::from(123));
+ let value = Value::BigInt(BigInt::from(456));
+ let _ = smt.add(key.clone(), value.clone());
+ let proof = smt.create_proof(key.clone());
+ assert_eq!(proof.root, smt.root);
+ }
+
+ #[test]
+ fn test_verify_proof() {
+ let mut smt = SMT::new(hash_function, false);
+ let key = Key::Str("abc".to_string());
+ let value = Value::Str("123".to_string());
+ let _ = smt.add(key.clone(), value.clone());
+ let proof = smt.create_proof(key.clone());
+ let result = smt.verify_proof(proof);
+ assert!(result);
+
+ let key2 = Key::Str("def".to_string());
+ let false_proof = MerkleProof {
+ entry_response: EntryResponse {
+ entry: vec![key2.clone()],
+ matching_entry: None,
+ siblings: Vec::new(),
+ },
+ root: smt.root.clone(),
+ membership: false,
+ };
+ let fun = smt.verify_proof(false_proof);
+ assert!(!fun);
+
+ let mut smt = SMT::new(hash_function, true);
+ let key = Key::BigInt(BigInt::from(123));
+ let value = Value::BigInt(BigInt::from(456));
+ let _ = smt.add(key.clone(), value.clone());
+ let proof = smt.create_proof(key.clone());
+ let result = smt.verify_proof(proof);
+ assert!(result);
+
+ let key2 = Key::BigInt(BigInt::from(789));
+ let false_proof = MerkleProof {
+ entry_response: EntryResponse {
+ entry: vec![key2.clone()],
+ matching_entry: None,
+ siblings: Vec::new(),
+ },
+ root: smt.root.clone(),
+ membership: true,
+ };
+ let fun = smt.verify_proof(false_proof);
+ assert!(!fun);
+ }
+
+ #[test]
+ fn test_retrieve_entry() {
+ let smt = SMT::new(hash_function, false);
+ let key = Key::Str("be12".to_string());
+ let entry_response = smt.retrieve_entry(key.clone());
+ assert_eq!(entry_response.entry, vec![key]);
+ assert_eq!(entry_response.matching_entry, None);
+ assert_eq!(entry_response.siblings, Vec::new());
+
+ let smt = SMT::new(hash_function, true);
+ let key = Key::BigInt(BigInt::from(123));
+ let entry_response = smt.retrieve_entry(key.clone());
+ assert_eq!(entry_response.entry, vec![key]);
+ assert_eq!(entry_response.matching_entry, None);
+ assert_eq!(entry_response.siblings, Vec::new());
+ }
+
+ #[test]
+ fn test_calculate_root() {
+ let smt = SMT::new(hash_function, false);
+ let node = Node::Str("node".to_string());
+ let path = &[0, 1, 0];
+ let siblings = vec![
+ Node::Str("sibling1".to_string()),
+ Node::Str("sibling2".to_string()),
+ Node::Str("sibling3".to_string()),
+ ];
+ let root = smt.calculate_root(node.clone(), path, &siblings);
+ assert_eq!(
+ root,
+ Node::Str("sibling2,node,sibling3,sibling1".to_string())
+ );
+
+ let smt = SMT::new(hash_function, true);
+ let node = Node::BigInt(BigInt::from(123));
+ let path = &[1, 0];
+ let siblings = vec![
+ Node::BigInt(BigInt::from(456)),
+ Node::BigInt(BigInt::from(789)),
+ ];
+ let root = smt.calculate_root(node.clone(), path, &siblings);
+ assert_eq!(root, Node::Str("456,123,789".to_string()));
+ }
+
+ #[test]
+ fn test_add_new_nodes() {
+ let mut smt = SMT::new(hash_function, false);
+ let node = Node::Str("node".to_string());
+ let path = &[0, 1, 0];
+ let siblings = vec![
+ Node::Str("sibling1".to_string()),
+ Node::Str("sibling2".to_string()),
+ Node::Str("sibling3".to_string()),
+ ];
+ let new_node = smt
+ .add_new_nodes(node.clone(), path, &siblings, None)
+ .unwrap();
+ assert_eq!(new_node, Node::Str("sibling2,sibling3,node".to_string()));
+
+ let starting_index = smt
+ .add_new_nodes(node.clone(), path, &siblings, Some(1))
+ .unwrap();
+ assert_eq!(starting_index, Node::Str("sibling2,node".to_string()));
+
+ let mut smt = SMT::new(hash_function, true);
+ let node = Node::BigInt(BigInt::from(111));
+ let path = &[1, 0, 0];
+ let siblings = vec![
+ Node::BigInt(BigInt::from(222)),
+ Node::BigInt(BigInt::from(333)),
+ Node::BigInt(BigInt::from(444)),
+ ];
+ let new_node = smt
+ .add_new_nodes(node.clone(), path, &siblings, None)
+ .unwrap();
+ assert_eq!(new_node, Node::Str("333,444,111".to_string()));
+
+ let starting_index = smt
+ .add_new_nodes(node.clone(), path, &siblings, Some(1))
+ .unwrap();
+ assert_eq!(starting_index, Node::Str("333,111".to_string()));
+ }
+
+ #[test]
+ fn test_delete_old_nodes() {
+ let mut smt = SMT::new(hash_function, false);
+ let node = Node::Str("abc".to_string());
+ let path = &[0, 1, 0];
+ let siblings = vec![
+ Node::Str("sibling1".to_string()),
+ Node::Str("sibling2".to_string()),
+ Node::Str("sibling3".to_string()),
+ ];
+ let new_node = smt
+ .add_new_nodes(node.clone(), path, &siblings, None)
+ .unwrap();
+ assert_eq!(new_node, Node::Str("sibling2,sibling3,abc".to_string()));
+ smt.delete_old_nodes(node.clone(), path, &siblings);
+ assert_eq!(smt.nodes.len(), 0);
+
+ let mut smt = SMT::new(hash_function, true);
+ let node = Node::BigInt(BigInt::from(123));
+ let path = &[1, 0];
+ let siblings = vec![
+ Node::BigInt(BigInt::from(456)),
+ Node::BigInt(BigInt::from(789)),
+ ];
+ let new_node = smt
+ .add_new_nodes(node.clone(), path, &siblings, None)
+ .unwrap();
+ assert_eq!(new_node, Node::Str("789,123".to_string()));
+ smt.delete_old_nodes(node.clone(), path, &siblings);
+ assert_eq!(smt.nodes.len(), 0);
+ }
+
+ #[test]
+ fn test_is_leaf() {
+ let mut smt = SMT::new(hash_function, false);
+ let node = Node::Str("abc".to_string());
+ assert!(!smt.is_leaf(&node));
+
+ smt.nodes.insert(
+ Node::Str("abc".to_string()),
+ vec![
+ Node::Str("123".to_string()),
+ Node::Str("456".to_string()),
+ Node::Str("789".to_string()),
+ ],
+ );
+ assert!(smt.is_leaf(&node));
+
+ let mut smt = SMT::new(hash_function, true);
+ let node = Node::BigInt(BigInt::from(123));
+ assert!(!smt.is_leaf(&node));
+
+ smt.nodes.insert(
+ Node::BigInt(BigInt::from(123)),
+ vec![
+ Node::BigInt(BigInt::from(111)),
+ Node::BigInt(BigInt::from(222)),
+ Node::BigInt(BigInt::from(333)),
+ ],
+ );
+ assert!(smt.is_leaf(&node));
+ }
+}
diff --git a/crates/smt/src/utils.rs b/crates/smt/src/utils.rs
new file mode 100644
index 0000000..1943406
--- /dev/null
+++ b/crates/smt/src/utils.rs
@@ -0,0 +1,185 @@
+/// Converts a hexadecimal string to a binary string.
+///
+/// # Arguments
+///
+/// * `n` - The hexadecimal string to convert.
+///
+/// # Returns
+///
+/// The binary representation of the hexadecimal string.
+pub fn hex_to_bin(n: &str) -> String {
+ let mut chars = n.chars();
+ let first_char = chars.next().unwrap();
+ let mut bin = format!(
+ "{:b}",
+ u8::from_str_radix(&first_char.to_string(), 16).unwrap()
+ );
+
+ for c in chars {
+ bin += &format!("{:04b}", u8::from_str_radix(&c.to_string(), 16).unwrap());
+ }
+
+ bin
+}
+
+/// Converts a hexadecimal key to a path represented as a vector of usize.
+///
+/// For each key, it is possible to obtain an array of 256 padded bits.
+///
+/// # Arguments
+///
+/// * `key` - The hexadecimal key to convert.
+///
+/// # Returns
+///
+/// The path represented as a vector of usize.
+pub fn key_to_path(key: &str) -> Vec {
+ let bits = if let Ok(num) = u128::from_str_radix(key, 16) {
+ format!("{:b}", num)
+ } else {
+ hex_to_bin(key)
+ };
+
+ let padded_bits = format!("{:0>256}", bits).chars().rev().collect::();
+ let bits_array = padded_bits
+ .chars()
+ .map(|c| c.to_digit(10).unwrap() as usize)
+ .collect();
+
+ bits_array
+}
+
+/// Returns the index of the last non-zero element in the array.
+///
+/// # Arguments
+///
+/// * `array` - The array of hexadecimal strings.
+///
+/// # Returns
+///
+/// The index of the last non-zero element in the array, or -1 if no non-zero element is found.
+pub fn get_index_of_last_non_zero_element(array: Vec<&str>) -> isize {
+ for (i, &item) in array.iter().enumerate().rev() {
+ if u128::from_str_radix(item, 16).unwrap_or(0) != 0 {
+ return i as isize;
+ }
+ }
+
+ -1
+}
+
+/// Returns the first common elements between two arrays.
+///
+/// # Arguments
+///
+/// * `array1` - The first array.
+/// * `array2` - The second array.
+///
+/// # Returns
+///
+/// The first common elements between the two arrays.
+pub fn get_first_common_elements(array1: &[T], array2: &[T]) -> Vec {
+ let min_length = std::cmp::min(array1.len(), array2.len());
+
+ for i in 0..min_length {
+ if array1[i] != array2[i] {
+ return array1[0..i].to_vec();
+ }
+ }
+
+ array1[0..min_length].to_vec()
+}
+
+/// Checks if a string is a valid hexadecimal string.
+///
+/// # Arguments
+///
+/// * `s` - The string to check.
+///
+/// # Returns
+///
+/// `true` if the string is a valid hexadecimal string, `false` otherwise.
+pub fn is_hexadecimal(s: &str) -> bool {
+ s.chars().all(|c| c.is_ascii_hexdigit())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_hex_to_bin() {
+ assert_eq!(hex_to_bin("A"), "1010");
+ assert_eq!(hex_to_bin("F"), "1111");
+ assert_eq!(hex_to_bin("1A"), "11010");
+ assert_eq!(hex_to_bin("FF"), "11111111");
+ assert_eq!(hex_to_bin("12"), "10010");
+ }
+
+ #[test]
+ fn test_key_to_path() {
+ let path = key_to_path("17");
+ assert_eq!(path.len(), 256);
+ assert_eq!(&path[0..5], vec![1, 1, 1, 0, 1]);
+ }
+
+ #[test]
+ fn test_get_index_of_last_non_zero_element() {
+ assert_eq!(get_index_of_last_non_zero_element(vec![]), -1);
+ assert_eq!(get_index_of_last_non_zero_element(vec!["0", "0", "0"]), -1);
+
+ assert_eq!(get_index_of_last_non_zero_element(vec!["0", "0", "1"]), 2);
+ assert_eq!(get_index_of_last_non_zero_element(vec!["0", "1", "0"]), 1);
+ assert_eq!(get_index_of_last_non_zero_element(vec!["1", "0", "0"]), 0);
+
+ assert_eq!(
+ get_index_of_last_non_zero_element(vec!["0", "1", "0", "1", "0"]),
+ 3
+ );
+ assert_eq!(
+ get_index_of_last_non_zero_element(vec!["1", "0", "1", "0", "0"]),
+ 2
+ );
+ assert_eq!(
+ get_index_of_last_non_zero_element(vec!["0", "0", "0", "1", "1"]),
+ 4
+ );
+ assert_eq!(
+ get_index_of_last_non_zero_element(vec![
+ "0", "17", "3", "0", "3", "0", "3", "2", "0", "0"
+ ]),
+ 7
+ )
+ }
+
+ #[test]
+ fn test_get_first_common_elements() {
+ assert_eq!(get_first_common_elements::(&[], &[]), vec![]);
+
+ assert_eq!(
+ get_first_common_elements(&[1, 2, 3], &[1, 2, 3, 4, 5]),
+ vec![1, 2, 3]
+ );
+ assert_eq!(
+ get_first_common_elements(&[1, 2, 3, 4, 5], &[1, 2, 3]),
+ vec![1, 2, 3]
+ );
+
+ assert_eq!(
+ get_first_common_elements(&[1, 2, 3], &[1, 2, 4]),
+ vec![1, 2]
+ );
+ assert_eq!(get_first_common_elements(&[1, 2, 3], &[4, 5, 6]), vec![]);
+ }
+
+ #[test]
+ fn test_is_hexadecimal() {
+ assert!(is_hexadecimal("be12"));
+ assert!(is_hexadecimal("ABCDEF"));
+ assert!(is_hexadecimal("1234567890abcdef"));
+
+ assert!(!is_hexadecimal("gbe12"));
+ assert!(!is_hexadecimal("123XYZ"));
+ assert!(!is_hexadecimal("abcdefg"));
+ }
+}