Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HOLD] feat: updatable mmr #31

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benches/mmr_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ fn bench(c: &mut Criterion) {
let proofs: Vec<_> = (0..10_000)
.map(|_| {
let pos = positions.choose(&mut rng).unwrap();
let elem = (&store).get_elem(*pos).unwrap().unwrap();
let elem = (&store).get(*pos).unwrap().unwrap();
let proof = mmr.gen_proof(vec![*pos]).unwrap();
(pos, elem, proof)
})
Expand Down
4 changes: 3 additions & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ pub enum Error {
NodeProofsNotSupported,
/// The leaves is an empty list, or beyond the mmr range
GenProofForInvalidLeaves,

/// The position of the updating leaf is out of range
UpdateLeafOutOfRange,
/// The two nodes couldn't merge into one.
MergeError(crate::string::String),
}
Expand All @@ -26,6 +27,7 @@ impl core::fmt::Display for Error {
CorruptedProof => write!(f, "Corrupted proof")?,
NodeProofsNotSupported => write!(f, "Tried to verify membership of a non-leaf")?,
GenProofForInvalidLeaves => write!(f, "Generate proof ofr invalid leaves")?,
UpdateLeafOutOfRange => write!(f, "Update leaf out of range")?,
MergeError(msg) => write!(f, "Merge error {}", msg)?,
}
Ok(())
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ cfg_if::cfg_if! {
use std::collections;
use std::vec;
use std::string;
use std::collections::BTreeMap;
} else {
extern crate alloc;
use alloc::borrow;
use alloc::collections;
use alloc::vec;
use alloc::string;
use alloc::collections::BTreeMap;
}
}
52 changes: 51 additions & 1 deletion src/mmr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,17 @@ impl<T, M, S> MMR<T, M, S> {
&self.batch
}

pub fn batch_mut(&mut self) -> &mut MMRBatch<T, S> {
&mut self.batch
}

pub fn store(&self) -> &S {
self.batch.store()
}

pub fn store_mut(&mut self) -> &mut S {
self.batch.store_mut()
}
}

impl<T: Clone + PartialEq, M: Merge<Item = T>, S: MMRStoreReadOps<T>> MMR<T, M, S> {
Expand All @@ -61,7 +69,7 @@ impl<T: Clone + PartialEq, M: Merge<Item = T>, S: MMRStoreReadOps<T>> MMR<T, M,
Ok(Cow::Owned(elem))
}

// push a element and return position
// push an element and return position
pub fn push(&mut self, elem: T) -> Result<u64> {
let mut elems = vec![elem];
let elem_pos = self.mmr_size;
Expand All @@ -84,6 +92,48 @@ impl<T: Clone + PartialEq, M: Merge<Item = T>, S: MMRStoreReadOps<T>> MMR<T, M,
Ok(elem_pos)
}

// update an element at position
pub fn update(&mut self, mut pos: u64, mut elem: T) -> Result<()> {
if pos >= self.mmr_size || pos_height_in_tree(pos) > 0 {
return Err(Error::UpdateLeafOutOfRange);
}
self.batch.insert(pos, elem.clone());

let peaks = get_peaks(self.mmr_size);
let peak = peaks
.into_iter()
.find(|p| pos <= *p)
.expect("checked pos < self.mmr_size");
let mut height = 0;
while pos < peak {
let next_height = pos_height_in_tree(pos + 1);
let sibling_offset = sibling_offset(height);
if next_height > height {
elem = M::merge(
&self
.batch
.get_elem(pos - sibling_offset)?
.ok_or(Error::InconsistentStore)?,
&elem,
)?;
pos += 1;
} else {
elem = M::merge(
&elem,
&self
.batch
.get_elem(pos + sibling_offset)?
.ok_or(Error::InconsistentStore)?,
)?;
pos += parent_offset(height);
}
self.batch.insert(pos, elem.clone());
height += 1;
}

Ok(())
}

/// get_root
pub fn get_root(&self) -> Result<T> {
if self.mmr_size == 0 {
Expand Down
50 changes: 24 additions & 26 deletions src/mmr_store.rs
Original file line number Diff line number Diff line change
@@ -1,65 +1,63 @@
use crate::{vec::Vec, Result};
use crate::{vec::Vec, BTreeMap, Result};

#[derive(Default)]
pub struct MMRBatch<Elem, Store> {
memory_batch: Vec<(u64, Vec<Elem>)>,
memory_batch: BTreeMap<u64, Elem>,
store: Store,
}

impl<Elem, Store> MMRBatch<Elem, Store> {
pub fn new(store: Store) -> Self {
MMRBatch {
memory_batch: Vec::new(),
memory_batch: BTreeMap::new(),
store,
}
}

pub fn append(&mut self, pos: u64, elems: Vec<Elem>) {
self.memory_batch.push((pos, elems));
for (i, elem) in elems.into_iter().enumerate() {
self.insert(pos + i as u64, elem);
}
}

pub fn insert(&mut self, pos: u64, elem: Elem) {
self.memory_batch.insert(pos, elem);
}

pub fn store(&self) -> &Store {
&self.store
}

pub fn store_mut(&mut self) -> &mut Store {
&mut self.store
}
}

impl<Elem: Clone, Store: MMRStoreReadOps<Elem>> MMRBatch<Elem, Store> {
pub fn get_elem(&self, pos: u64) -> Result<Option<Elem>> {
for (start_pos, elems) in self.memory_batch.iter().rev() {
if pos < *start_pos {
continue;
} else if pos < start_pos + elems.len() as u64 {
return Ok(elems.get((pos - start_pos) as usize).cloned());
} else {
break;
}
if let Some(elem) = self.memory_batch.get(&pos) {
Ok(Some(elem.clone()))
} else {
self.store.get(pos)
}
self.store.get_elem(pos)
}
}

impl<Elem, Store: MMRStoreWriteOps<Elem>> MMRBatch<Elem, Store> {
pub fn commit(&mut self) -> Result<()> {
for (pos, elems) in self.memory_batch.drain(..) {
self.store.append(pos, elems)?;
let batch = core::mem::take(&mut self.memory_batch);

for (pos, elem) in batch {
self.store.insert(pos, elem)?;
}
Ok(())
}
}

impl<Elem, Store> IntoIterator for MMRBatch<Elem, Store> {
type Item = (u64, Vec<Elem>);
type IntoIter = crate::vec::IntoIter<Self::Item>;

fn into_iter(self) -> Self::IntoIter {
self.memory_batch.into_iter()
}
}

pub trait MMRStoreReadOps<Elem> {
fn get_elem(&self, pos: u64) -> Result<Option<Elem>>;
fn get(&self, pos: u64) -> Result<Option<Elem>>;
}

pub trait MMRStoreWriteOps<Elem> {
fn append(&mut self, pos: u64, elems: Vec<Elem>) -> Result<()>;
fn insert(&mut self, pos: u64, elem: Elem) -> Result<()>;
}
4 changes: 2 additions & 2 deletions src/tests/test_accumulate_headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ impl Prover {
let mut mmr = MMR::<_, MergeHashWithTD, _>::new(self.positions.len() as u64, &self.store);
// get previous element
let mut previous = if let Some(pos) = self.positions.last() {
mmr.store().get_elem(*pos)?.expect("exists")
mmr.store().get(*pos)?.expect("exists")
} else {
let genesis = Header::default();

Expand Down Expand Up @@ -187,7 +187,7 @@ fn test_insert_header() {
let proof = prover.gen_proof(h1, h2).expect("gen proof");
let pos = leaf_index_to_pos(h1);
assert_eq!(pos, prover.get_pos(h1));
assert_eq!(prove_elem, (&prover.store).get_elem(pos).unwrap().unwrap());
assert_eq!(prove_elem, (&prover.store).get(pos).unwrap().unwrap());
let result = proof.verify(root, vec![(pos, prove_elem)]).expect("verify");
assert!(result);
}
40 changes: 40 additions & 0 deletions src/tests/test_mmr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,46 @@ fn test_gen_proof_with_duplicate_leaves() {
test_mmr(10, vec![5, 5]);
}

#[test]
fn test_update_mmr() {
let count = 1234;

let root = {
let store = MemStore::default();
let mut mmr = MemMMR::<_, MergeNumberHash>::new(0, &store);
let positions: Vec<u64> = (0u32..count / 2)
.map(|i| mmr.push(NumberHash::from(i)).unwrap())
.collect();
assert!(positions.len() == count as usize / 2);

for (i, pos) in positions.into_iter().enumerate() {
if i % 3 == 1 {
mmr.update(pos, NumberHash::from(i as u32 * 3)).unwrap();
}
}

for i in count / 2..count {
mmr.push(NumberHash::from(i)).unwrap();
}
mmr.get_root().expect("get root")
};

let new_root = {
let store = MemStore::default();
let mut mmr = MemMMR::<_, MergeNumberHash>::new(0, &store);
(0u32..count).for_each(|i| {
if i % 3 == 1 && i < count / 2 {
mmr.push(NumberHash::from(i * 3)).unwrap();
} else {
mmr.push(NumberHash::from(i)).unwrap();
}
});
mmr.get_root().expect("get root")
};

assert_eq!(root, new_root);
}

fn test_invalid_proof_verification(
leaf_count: u32,
positions_to_verify: Vec<u64>,
Expand Down
11 changes: 4 additions & 7 deletions src/util.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::collections::BTreeMap;
use crate::{vec::Vec, MMRStoreReadOps, MMRStoreWriteOps, Result, MMR};
use crate::{MMRStoreReadOps, MMRStoreWriteOps, Result, MMR};
use core::cell::RefCell;

#[derive(Clone)]
Expand All @@ -18,17 +18,14 @@ impl<T> MemStore<T> {
}

impl<T: Clone> MMRStoreReadOps<T> for &MemStore<T> {
fn get_elem(&self, pos: u64) -> Result<Option<T>> {
fn get(&self, pos: u64) -> Result<Option<T>> {
Ok(self.0.borrow().get(&pos).cloned())
}
}

impl<T> MMRStoreWriteOps<T> for &MemStore<T> {
fn append(&mut self, pos: u64, elems: Vec<T>) -> Result<()> {
let mut store = self.0.borrow_mut();
for (i, elem) in elems.into_iter().enumerate() {
store.insert(pos + i as u64, elem);
}
fn insert(&mut self, pos: u64, elem: T) -> Result<()> {
self.0.borrow_mut().insert(pos, elem);
Ok(())
}
}
Expand Down