diff --git a/benches/mmr_benchmark.rs b/benches/mmr_benchmark.rs index d7c9b19..012e4f9 100644 --- a/benches/mmr_benchmark.rs +++ b/benches/mmr_benchmark.rs @@ -78,7 +78,7 @@ fn bench(c: &mut Criterion) { let proofs: Vec<_> = (0..10_000) .map(|_| { let pos = positions.choose(&mut rng).unwrap(); - let elem = (&store).get_elem(*pos).unwrap().unwrap(); + let elem = (&store).get(*pos).unwrap().unwrap(); let proof = mmr.gen_proof(vec![*pos]).unwrap(); (pos, elem, proof) }) diff --git a/src/error.rs b/src/error.rs index c1c9276..9ae1b30 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,7 +11,8 @@ pub enum Error { NodeProofsNotSupported, /// The leaves is an empty list, or beyond the mmr range GenProofForInvalidLeaves, - + /// The position of the updating leaf is out of range + UpdateLeafOutOfRange, /// The two nodes couldn't merge into one. MergeError(crate::string::String), } @@ -26,6 +27,7 @@ impl core::fmt::Display for Error { CorruptedProof => write!(f, "Corrupted proof")?, NodeProofsNotSupported => write!(f, "Tried to verify membership of a non-leaf")?, GenProofForInvalidLeaves => write!(f, "Generate proof ofr invalid leaves")?, + UpdateLeafOutOfRange => write!(f, "Update leaf out of range")?, MergeError(msg) => write!(f, "Merge error {}", msg)?, } Ok(()) diff --git a/src/lib.rs b/src/lib.rs index 96b3fc0..1e15dbe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,11 +21,13 @@ cfg_if::cfg_if! { use std::collections; use std::vec; use std::string; + use std::collections::BTreeMap; } else { extern crate alloc; use alloc::borrow; use alloc::collections; use alloc::vec; use alloc::string; + use alloc::collections::BTreeMap; } } diff --git a/src/mmr.rs b/src/mmr.rs index fcd32c4..9926aa0 100644 --- a/src/mmr.rs +++ b/src/mmr.rs @@ -45,9 +45,17 @@ impl MMR { &self.batch } + pub fn batch_mut(&mut self) -> &mut MMRBatch { + &mut self.batch + } + pub fn store(&self) -> &S { self.batch.store() } + + pub fn store_mut(&mut self) -> &mut S { + self.batch.store_mut() + } } impl, S: MMRStoreReadOps> MMR { @@ -61,7 +69,7 @@ impl, S: MMRStoreReadOps> MMR Result { let mut elems = vec![elem]; let elem_pos = self.mmr_size; @@ -84,6 +92,48 @@ impl, S: MMRStoreReadOps> MMR Result<()> { + if pos >= self.mmr_size || pos_height_in_tree(pos) > 0 { + return Err(Error::UpdateLeafOutOfRange); + } + self.batch.insert(pos, elem.clone()); + + let peaks = get_peaks(self.mmr_size); + let peak = peaks + .into_iter() + .find(|p| pos <= *p) + .expect("checked pos < self.mmr_size"); + let mut height = 0; + while pos < peak { + let next_height = pos_height_in_tree(pos + 1); + let sibling_offset = sibling_offset(height); + if next_height > height { + elem = M::merge( + &self + .batch + .get_elem(pos - sibling_offset)? + .ok_or(Error::InconsistentStore)?, + &elem, + )?; + pos += 1; + } else { + elem = M::merge( + &elem, + &self + .batch + .get_elem(pos + sibling_offset)? + .ok_or(Error::InconsistentStore)?, + )?; + pos += parent_offset(height); + } + self.batch.insert(pos, elem.clone()); + height += 1; + } + + Ok(()) + } + /// get_root pub fn get_root(&self) -> Result { if self.mmr_size == 0 { diff --git a/src/mmr_store.rs b/src/mmr_store.rs index 1586d70..3df38e8 100644 --- a/src/mmr_store.rs +++ b/src/mmr_store.rs @@ -1,65 +1,63 @@ -use crate::{vec::Vec, Result}; +use crate::{vec::Vec, BTreeMap, Result}; #[derive(Default)] pub struct MMRBatch { - memory_batch: Vec<(u64, Vec)>, + memory_batch: BTreeMap, store: Store, } impl MMRBatch { pub fn new(store: Store) -> Self { MMRBatch { - memory_batch: Vec::new(), + memory_batch: BTreeMap::new(), store, } } pub fn append(&mut self, pos: u64, elems: Vec) { - self.memory_batch.push((pos, elems)); + for (i, elem) in elems.into_iter().enumerate() { + self.insert(pos + i as u64, elem); + } + } + + pub fn insert(&mut self, pos: u64, elem: Elem) { + self.memory_batch.insert(pos, elem); } pub fn store(&self) -> &Store { &self.store } + + pub fn store_mut(&mut self) -> &mut Store { + &mut self.store + } } impl> MMRBatch { pub fn get_elem(&self, pos: u64) -> Result> { - for (start_pos, elems) in self.memory_batch.iter().rev() { - if pos < *start_pos { - continue; - } else if pos < start_pos + elems.len() as u64 { - return Ok(elems.get((pos - start_pos) as usize).cloned()); - } else { - break; - } + if let Some(elem) = self.memory_batch.get(&pos) { + Ok(Some(elem.clone())) + } else { + self.store.get(pos) } - self.store.get_elem(pos) } } impl> MMRBatch { pub fn commit(&mut self) -> Result<()> { - for (pos, elems) in self.memory_batch.drain(..) { - self.store.append(pos, elems)?; + let batch = core::mem::take(&mut self.memory_batch); + + for (pos, elem) in batch { + self.store.insert(pos, elem)?; } Ok(()) } } -impl IntoIterator for MMRBatch { - type Item = (u64, Vec); - type IntoIter = crate::vec::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.memory_batch.into_iter() - } -} - pub trait MMRStoreReadOps { - fn get_elem(&self, pos: u64) -> Result>; + fn get(&self, pos: u64) -> Result>; } pub trait MMRStoreWriteOps { - fn append(&mut self, pos: u64, elems: Vec) -> Result<()>; + fn insert(&mut self, pos: u64, elem: Elem) -> Result<()>; } diff --git a/src/tests/test_accumulate_headers.rs b/src/tests/test_accumulate_headers.rs index 71c26b8..8f75700 100644 --- a/src/tests/test_accumulate_headers.rs +++ b/src/tests/test_accumulate_headers.rs @@ -106,7 +106,7 @@ impl Prover { let mut mmr = MMR::<_, MergeHashWithTD, _>::new(self.positions.len() as u64, &self.store); // get previous element let mut previous = if let Some(pos) = self.positions.last() { - mmr.store().get_elem(*pos)?.expect("exists") + mmr.store().get(*pos)?.expect("exists") } else { let genesis = Header::default(); @@ -187,7 +187,7 @@ fn test_insert_header() { let proof = prover.gen_proof(h1, h2).expect("gen proof"); let pos = leaf_index_to_pos(h1); assert_eq!(pos, prover.get_pos(h1)); - assert_eq!(prove_elem, (&prover.store).get_elem(pos).unwrap().unwrap()); + assert_eq!(prove_elem, (&prover.store).get(pos).unwrap().unwrap()); let result = proof.verify(root, vec![(pos, prove_elem)]).expect("verify"); assert!(result); } diff --git a/src/tests/test_mmr.rs b/src/tests/test_mmr.rs index 7e4405c..e2814b9 100644 --- a/src/tests/test_mmr.rs +++ b/src/tests/test_mmr.rs @@ -150,6 +150,46 @@ fn test_gen_proof_with_duplicate_leaves() { test_mmr(10, vec![5, 5]); } +#[test] +fn test_update_mmr() { + let count = 1234; + + let root = { + let store = MemStore::default(); + let mut mmr = MemMMR::<_, MergeNumberHash>::new(0, &store); + let positions: Vec = (0u32..count / 2) + .map(|i| mmr.push(NumberHash::from(i)).unwrap()) + .collect(); + assert!(positions.len() == count as usize / 2); + + for (i, pos) in positions.into_iter().enumerate() { + if i % 3 == 1 { + mmr.update(pos, NumberHash::from(i as u32 * 3)).unwrap(); + } + } + + for i in count / 2..count { + mmr.push(NumberHash::from(i)).unwrap(); + } + mmr.get_root().expect("get root") + }; + + let new_root = { + let store = MemStore::default(); + let mut mmr = MemMMR::<_, MergeNumberHash>::new(0, &store); + (0u32..count).for_each(|i| { + if i % 3 == 1 && i < count / 2 { + mmr.push(NumberHash::from(i * 3)).unwrap(); + } else { + mmr.push(NumberHash::from(i)).unwrap(); + } + }); + mmr.get_root().expect("get root") + }; + + assert_eq!(root, new_root); +} + fn test_invalid_proof_verification( leaf_count: u32, positions_to_verify: Vec, diff --git a/src/util.rs b/src/util.rs index f021de3..2c3a9ae 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,5 +1,5 @@ use crate::collections::BTreeMap; -use crate::{vec::Vec, MMRStoreReadOps, MMRStoreWriteOps, Result, MMR}; +use crate::{MMRStoreReadOps, MMRStoreWriteOps, Result, MMR}; use core::cell::RefCell; #[derive(Clone)] @@ -18,17 +18,14 @@ impl MemStore { } impl MMRStoreReadOps for &MemStore { - fn get_elem(&self, pos: u64) -> Result> { + fn get(&self, pos: u64) -> Result> { Ok(self.0.borrow().get(&pos).cloned()) } } impl MMRStoreWriteOps for &MemStore { - fn append(&mut self, pos: u64, elems: Vec) -> Result<()> { - let mut store = self.0.borrow_mut(); - for (i, elem) in elems.into_iter().enumerate() { - store.insert(pos + i as u64, elem); - } + fn insert(&mut self, pos: u64, elem: T) -> Result<()> { + self.0.borrow_mut().insert(pos, elem); Ok(()) } }