diff --git a/src/accumulator/node_hash.rs b/src/accumulator/node_hash.rs index 820d8c7..83c1309 100644 --- a/src/accumulator/node_hash.rs +++ b/src/accumulator/node_hash.rs @@ -229,6 +229,46 @@ impl NodeHash { pub const fn placeholder() -> Self { NodeHash::Placeholder } + + /// write to buffer + pub(super) fn write(&self, writer: &mut W) -> std::io::Result<()> + where + W: std::io::Write, + { + match self { + Self::Empty => writer.write_all(&[0]), + Self::Placeholder => writer.write_all(&[1]), + Self::Some(hash) => { + writer.write_all(&[2])?; + writer.write_all(hash) + } + } + } + + /// Read from buffer + pub(super) fn read(reader: &mut R) -> std::io::Result + where + R: std::io::Read, + { + let mut tag = [0]; + reader.read_exact(&mut tag)?; + match tag { + [0] => Ok(Self::Empty), + [1] => Ok(Self::Placeholder), + [2] => { + let mut hash = [0; 32]; + reader.read_exact(&mut hash)?; + Ok(Self::Some(hash)) + } + [_] => { + let err = std::io::Error::new( + std::io::ErrorKind::InvalidData, + "unexpected tag for NodeHash", + ); + Err(err) + } + } + } } #[cfg(test)] diff --git a/src/accumulator/pollard.rs b/src/accumulator/pollard.rs index 35e0138..194ae81 100644 --- a/src/accumulator/pollard.rs +++ b/src/accumulator/pollard.rs @@ -91,7 +91,7 @@ impl Node { NodeType::Branch => writer.write_all(&0_u64.to_le_bytes())?, NodeType::Leaf => writer.write_all(&1_u64.to_le_bytes())?, } - writer.write_all(&*self.data.get())?; + self.data.get().write(writer)?; self.left .borrow() .as_ref() @@ -101,7 +101,7 @@ impl Node { self.right .borrow() .as_ref() - .map(|l| l.write_one(writer)) + .map(|r| r.write_one(writer)) .transpose()?; Ok(()) } @@ -118,10 +118,9 @@ impl Node { reader: &mut R, index: &mut HashMap>, ) -> std::io::Result> { - let mut data = [0u8; 32]; let mut ty = [0u8; 8]; reader.read_exact(&mut ty)?; - reader.read_exact(&mut data)?; + let data = NodeHash::read(reader)?; let ty = match u64::from_le_bytes(ty) { 0 => NodeType::Branch, @@ -131,7 +130,7 @@ impl Node { if ty == NodeType::Leaf { let leaf = Rc::new(Node { ty, - data: Cell::new(data.into()), + data: Cell::new(data), parent: RefCell::new(ancestor.map(|a| Rc::downgrade(&a))), left: RefCell::new(None), right: RefCell::new(None), @@ -141,16 +140,17 @@ impl Node { } let node = Rc::new(Node { ty: NodeType::Branch, - data: Cell::new(data.into()), + data: Cell::new(data), parent: RefCell::new(ancestor.map(|a| Rc::downgrade(&a))), left: RefCell::new(None), right: RefCell::new(None), }); - let left = _read_one(Some(node.clone()), reader, index)?; - let right = _read_one(Some(node.clone()), reader, index)?; - node.left.replace(Some(left)); - node.right.replace(Some(right)); - + if !data.is_empty() { + let left = _read_one(Some(node.clone()), reader, index)?; + let right = _read_one(Some(node.clone()), reader, index)?; + node.left.replace(Some(left)); + node.right.replace(Some(right)); + } node.left .borrow() .as_ref() @@ -545,6 +545,9 @@ impl Pollard { } /// to_string returns the full pollard in a string for all forests less than 6 rows. fn string(&self) -> String { + if self.leaves == 0 { + return "empty".to_owned(); + } let fh = tree_rows(self.leaves); // The accumulator should be less than 6 rows. if fh > 6 { @@ -972,4 +975,34 @@ mod test { fn get_hash_vec_of(elements: &[u8]) -> Vec { elements.iter().map(|el| hash_from_u8(*el)).collect() } + + #[test] + fn test_display_empty() { + let p = Pollard::new(); + let _ = p.to_string(); + } + + #[test] + fn test_serialization_roundtrip() { + let mut p = Pollard::new(); + let values = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + let hashes: Vec = values + .into_iter() + .map(|i| NodeHash::from([i; 32])) + .collect(); + p.modify(&hashes, &[]).expect("modify should work"); + assert_eq!(p.get_roots().len(), 1); + assert!(!p.get_roots()[0].get_data().is_empty()); + assert_eq!(p.leaves, 16); + p.modify(&[], &hashes).expect("modify should work"); + assert_eq!(p.get_roots().len(), 1); + assert!(p.get_roots()[0].get_data().is_empty()); + assert_eq!(p.leaves, 16); + let mut serialized = Vec::::new(); + p.serialize(&mut serialized).expect("serialize should work"); + let deserialized = Pollard::deserialize(&*serialized).expect("deserialize should work"); + assert_eq!(deserialized.get_roots().len(), 1); + assert!(deserialized.get_roots()[0].get_data().is_empty()); + assert_eq!(deserialized.leaves, 16); + } }