diff --git a/src/bin_storage.rs b/src/bin_storage.rs index 6956d0b..7aa1ce2 100644 --- a/src/bin_storage.rs +++ b/src/bin_storage.rs @@ -22,18 +22,9 @@ pub trait IterableStorage { where Self: 'a; - fn keys<'a>(&'a self, start: Option<&'a [u8]>, end: Option<&'a [u8]>) - -> Self::KeysIterator<'a>; - fn values<'a>( - &'a self, - start: Option<&'a [u8]>, - end: Option<&'a [u8]>, - ) -> Self::ValuesIterator<'a>; - fn pairs<'a>( - &'a self, - start: Option<&'a [u8]>, - end: Option<&'a [u8]>, - ) -> Self::PairsIterator<'a>; + fn keys<'a>(&'a self, start: Option<&[u8]>, end: Option<&[u8]>) -> Self::KeysIterator<'a>; + fn values<'a>(&'a self, start: Option<&[u8]>, end: Option<&[u8]>) -> Self::ValuesIterator<'a>; + fn pairs<'a>(&'a self, start: Option<&[u8]>, end: Option<&[u8]>) -> Self::PairsIterator<'a>; } pub trait RevIterableStorage { @@ -49,17 +40,17 @@ pub trait RevIterableStorage { fn rev_keys<'a>( &'a self, - start: Option<&'a [u8]>, - end: Option<&'a [u8]>, + start: Option<&[u8]>, + end: Option<&[u8]>, ) -> Self::RevKeysIterator<'a>; fn rev_values<'a>( &'a self, - start: Option<&'a [u8]>, - end: Option<&'a [u8]>, + start: Option<&[u8]>, + end: Option<&[u8]>, ) -> Self::RevValuesIterator<'a>; fn rev_pairs<'a>( &'a self, - start: Option<&'a [u8]>, - end: Option<&'a [u8]>, + start: Option<&[u8]>, + end: Option<&[u8]>, ) -> Self::RevPairsIterator<'a>; } diff --git a/src/containers/item.rs b/src/containers/item.rs index ea5f092..8a81365 100644 --- a/src/containers/item.rs +++ b/src/containers/item.rs @@ -6,7 +6,7 @@ use crate::{ Storage, StorageMut, }; -use super::Storable; +use super::{KeyDecodeError, Storable}; pub struct Item { prefix: &'static [u8], @@ -39,6 +39,9 @@ where T: EncodableWith + DecodableWith, { type AccessorT = ItemAccess; + type Key = (); + type Value = T; + type ValueDecodeError = E::DecodeError; fn access_impl(storage: S) -> ItemAccess { ItemAccess { @@ -46,6 +49,18 @@ where phantom: PhantomData, } } + + fn decode_key(key: &[u8]) -> Result<(), KeyDecodeError> { + if key.is_empty() { + Ok(()) + } else { + Err(KeyDecodeError) + } + } + + fn decode_value(value: &[u8]) -> Result { + T::decode(value) + } } pub struct ItemAccess { diff --git a/src/containers/map.rs b/src/containers/map.rs index f1ec5b4..f413921 100644 --- a/src/containers/map.rs +++ b/src/containers/map.rs @@ -1,8 +1,9 @@ -use std::marker::PhantomData; +use std::{borrow::Borrow, marker::PhantomData}; -use crate::{storage_branch::StorageBranch, Storage}; +use crate::storage_branch::StorageBranch; +use crate::{IterableStorage, Storage}; -use super::{Key, Storable}; +use super::{KeyDecodeError, Storable, StorableIter}; pub struct Map { prefix: &'static [u8], @@ -11,7 +12,7 @@ pub struct Map { impl Map where - K: ?Sized, + K: OwnedKey, V: Storable, { pub const fn new(prefix: &'static [u8]) -> Self { @@ -31,10 +32,13 @@ where impl Storable for Map where - K: ?Sized, + K: OwnedKey, V: Storable, { type AccessorT = MapAccess; + type Key = (K, V::Key); + type Value = V::Value; + type ValueDecodeError = V::ValueDecodeError; fn access_impl(storage: S) -> MapAccess { MapAccess { @@ -42,6 +46,23 @@ where phantom: PhantomData, } } + + fn decode_key(key: &[u8]) -> Result { + let len = *key.first().ok_or(KeyDecodeError)? as usize; + + if key.len() < len + 1 { + return Err(KeyDecodeError); + } + + let map_key = K::from_bytes(&key[1..len + 1]).map_err(|_| KeyDecodeError)?; + let rest = V::decode_key(&key[len + 1..])?; + + Ok((map_key, rest)) + } + + fn decode_value(value: &[u8]) -> Result { + V::decode_value(value) + } } pub struct MapAccess { @@ -51,12 +72,71 @@ pub struct MapAccess { impl MapAccess where - K: Key + ?Sized, + K: Key, V: Storable, S: Storage, { - pub fn get<'s>(&'s self, key: &K) -> V::AccessorT> { - let key = key.bytes(); - V::access_impl(StorageBranch::new(&self.storage, key.to_vec())) + pub fn entry<'s, Q>(&'s self, key: &Q) -> V::AccessorT> + where + K: Borrow, + Q: Key + ?Sized, + { + let len = key.bytes().len(); + let bytes = key.bytes(); + let mut key = Vec::with_capacity(len + 1); + + key.push(len as u8); + key.extend_from_slice(bytes); + + V::access_impl(StorageBranch::new(&self.storage, key)) + } +} + +impl MapAccess +where + K: OwnedKey, + V: Storable, + S: IterableStorage, +{ + pub fn iter<'s>( + &'s self, + start: Option<&[u8]>, + end: Option<&[u8]>, + ) -> StorableIter<'s, Map, S> { + StorableIter { + inner: self.storage.pairs(start, end), + phantom: PhantomData, + } + } +} + +pub trait Key { + fn bytes(&self) -> &[u8]; +} + +pub trait OwnedKey: Key { + fn from_bytes(bytes: &[u8]) -> Result + where + Self: Sized; +} + +impl Key for String { + fn bytes(&self) -> &[u8] { + self.as_bytes() + } +} + +impl OwnedKey for String { + fn from_bytes(bytes: &[u8]) -> Result + where + Self: Sized, + { + std::str::from_utf8(bytes).map(String::from).map_err(|_| ()) + } +} + +impl Key for str { + fn bytes(&self) -> &[u8] { + self.as_bytes() } } diff --git a/src/containers/mod.rs b/src/containers/mod.rs index 955795f..172f4fb 100644 --- a/src/containers/mod.rs +++ b/src/containers/mod.rs @@ -1,21 +1,57 @@ mod item; mod map; +use std::marker::PhantomData; + pub use item::{Item, ItemAccess}; pub use map::{Map, MapAccess}; +use crate::IterableStorage; + pub trait Storable { type AccessorT; + type Key; + type Value; + type ValueDecodeError; fn access_impl(storage: S) -> Self::AccessorT; + + fn decode_key(key: &[u8]) -> Result; + + fn decode_value(value: &[u8]) -> Result; } -pub trait Key { - fn bytes(&self) -> &[u8]; +pub struct KeyDecodeError; + +pub struct StorableIter<'i, S, B> +where + S: Storable, + B: IterableStorage + 'i, +{ + inner: B::PairsIterator<'i>, + phantom: PhantomData, } -impl Key for str { - fn bytes(&self) -> &[u8] { - self.as_bytes() +impl<'i, S, B> Iterator for StorableIter<'i, S, B> +where + S: Storable, + B: IterableStorage + 'i, +{ + type Item = Result<(S::Key, S::Value), KVDecodeError>; + + fn next(&mut self) -> Option { + self.inner.next().map(|(k, v)| -> Self::Item { + match (S::decode_key(&k), S::decode_value(&v)) { + (Err(_), _) => Err(KVDecodeError::Key), + (_, Err(e)) => Err(KVDecodeError::Value(e)), + (Ok(k), Ok(v)) => Ok((k, v)), + } + }) } } + +#[derive(Debug, PartialEq)] +pub enum KVDecodeError { + Key, + Value(V), +} diff --git a/src/storage_branch.rs b/src/storage_branch.rs index 44c5b9b..9640e50 100644 --- a/src/storage_branch.rs +++ b/src/storage_branch.rs @@ -1,4 +1,4 @@ -use crate::{Storage, StorageMut}; +use crate::{IterableStorage, RevIterableStorage, Storage, StorageMut}; pub struct StorageBranch<'s, S> { backend: &'s S, @@ -29,3 +29,220 @@ impl StorageMut for StorageBranch<'_, S> { self.backend.remove(&[&self.prefix[..], key].concat()) } } + +impl IterableStorage for StorageBranch<'_, S> { + type KeysIterator<'a> = BranchKeysIter> where Self: 'a; + type ValuesIterator<'a> = S::ValuesIterator<'a> where Self: 'a; + type PairsIterator<'a> = BranchKVIter> where Self: 'a; + + fn keys<'a>(&'a self, start: Option<&[u8]>, end: Option<&[u8]>) -> Self::KeysIterator<'a> { + let (start, end) = sub_bounds(&self.prefix, start, end); + + BranchKeysIter { + inner: self.backend.keys( + start.as_ref().map(AsRef::as_ref), + end.as_ref().map(AsRef::as_ref), + ), + prefix_len: self.prefix.len(), + } + } + + fn values<'a>(&'a self, start: Option<&[u8]>, end: Option<&[u8]>) -> Self::ValuesIterator<'a> { + let (start, end) = sub_bounds(&self.prefix, start, end); + + self.backend.values( + start.as_ref().map(AsRef::as_ref), + end.as_ref().map(AsRef::as_ref), + ) + } + + fn pairs<'a>(&'a self, start: Option<&[u8]>, end: Option<&[u8]>) -> Self::PairsIterator<'a> { + let (start, end) = sub_bounds(&self.prefix, start, end); + + BranchKVIter { + inner: self.backend.pairs( + start.as_ref().map(AsRef::as_ref), + end.as_ref().map(AsRef::as_ref), + ), + prefix_len: self.prefix.len(), + } + } +} + +impl RevIterableStorage for StorageBranch<'_, S> { + type RevKeysIterator<'a> = BranchKeysIter> where Self: 'a; + type RevValuesIterator<'a> = S::RevValuesIterator<'a> where Self: 'a; + type RevPairsIterator<'a> = BranchKVIter> where Self: 'a; + + fn rev_keys<'a>( + &'a self, + start: Option<&[u8]>, + end: Option<&[u8]>, + ) -> Self::RevKeysIterator<'a> { + let (start, end) = sub_bounds(&self.prefix, start, end); + + BranchKeysIter { + inner: self.backend.rev_keys( + start.as_ref().map(AsRef::as_ref), + end.as_ref().map(AsRef::as_ref), + ), + prefix_len: self.prefix.len(), + } + } + + fn rev_values<'a>( + &'a self, + start: Option<&[u8]>, + end: Option<&[u8]>, + ) -> Self::RevValuesIterator<'a> { + let (start, end) = sub_bounds(&self.prefix, start, end); + + self.backend.rev_values( + start.as_ref().map(AsRef::as_ref), + end.as_ref().map(AsRef::as_ref), + ) + } + + fn rev_pairs<'a>( + &'a self, + start: Option<&[u8]>, + end: Option<&[u8]>, + ) -> Self::RevPairsIterator<'a> { + let (start, end) = sub_bounds(&self.prefix, start, end); + + BranchKVIter { + inner: self.backend.rev_pairs( + start.as_ref().map(AsRef::as_ref), + end.as_ref().map(AsRef::as_ref), + ), + prefix_len: self.prefix.len(), + } + } +} + +fn sub_bounds( + prefix: &[u8], + start: Option<&[u8]>, + end: Option<&[u8]>, +) -> (Option>, Option>) { + if prefix.is_empty() { + (start.map(|s| s.to_vec()), end.map(|s| s.to_vec())) + } else { + ( + Some( + start + .map(|s| [prefix, s].concat()) + .unwrap_or(prefix.to_vec()), + ), + Some(end.map(|e| [prefix, e].concat()).unwrap_or_else(|| { + let mut pref = prefix.to_vec(); + if let Some(x) = pref.last_mut() { + *x += 1; + } + pref + })), + ) + } +} + +pub struct BranchKeysIter { + inner: I, + prefix_len: usize, +} + +impl Iterator for BranchKeysIter +where + I: Iterator>, +{ + type Item = Vec; + + fn next(&mut self) -> Option { + self.inner.next().map(|key| key[self.prefix_len..].to_vec()) + } +} + +pub struct BranchKVIter { + inner: I, + prefix_len: usize, +} + +impl Iterator for BranchKVIter +where + I: Iterator, Vec)>, +{ + type Item = (Vec, Vec); + + fn next(&mut self) -> Option { + self.inner.next().map(|(key, value)| { + let key = key[self.prefix_len..].to_vec(); + (key, value) + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // TODO: move TestStorage and use it for these unit tests? + + // use crate::backend::TestStorage; + + // #[test] + // fn storage_branch() { + // let storage = TestStorage::new(); + // let branch = StorageBranch::new(&storage, b"foo".to_vec()); + + // branch.set(b"bar", b"baz"); + // branch.set(b"qux", b"quux"); + + // assert_eq!(storage.get(b"bar"), None); + // assert_eq!(storage.get(b"qux"), None); + + // assert_eq!(storage.get(b"foobar"), Some(b"baz".to_vec())); + // assert_eq!(storage.get(b"fooqux"), Some(b"quux".to_vec())); + // } + + #[test] + fn sub_bounds_no_prefix() { + assert_eq!( + sub_bounds(&[], Some(b"foo"), Some(b"bar")), + (Some(b"foo".to_vec()), Some(b"bar".to_vec())) + ); + + assert_eq!( + sub_bounds(&[], Some(b"foo"), None), + (Some(b"foo".to_vec()), None) + ); + + assert_eq!( + sub_bounds(&[], None, Some(b"bar")), + (None, Some(b"bar".to_vec())) + ); + + assert_eq!(sub_bounds(&[], None, None), (None, None)); + } + + #[test] + fn sub_bounds_with_prefix() { + assert_eq!( + sub_bounds(b"foo", Some(b"bar"), Some(b"baz")), + (Some(b"foobar".to_vec()), Some(b"foobaz".to_vec())) + ); + + assert_eq!( + sub_bounds(b"foo", Some(b"bar"), None), + (Some(b"foobar".to_vec()), Some(b"fop".to_vec())) + ); + + assert_eq!( + sub_bounds(b"foo", None, Some(b"baz")), + (Some(b"foo".to_vec()), Some(b"foobaz".to_vec())) + ); + + assert_eq!( + sub_bounds(b"foo", None, None), + (Some(b"foo".to_vec()), Some(b"fop".to_vec())) + ); + } +} diff --git a/tests/common/backend.rs b/tests/common/backend.rs index 1e2dadf..c8f9cf0 100644 --- a/tests/common/backend.rs +++ b/tests/common/backend.rs @@ -28,7 +28,7 @@ impl TestStorage { impl stork::Storage for TestStorage { fn get(&self, key: &[u8]) -> Option> { // Safety: see above - unsafe { (&*self.0.get()).get(key).map(|v| v.clone()) } + unsafe { (*self.0.get()).get(key).cloned() } } } @@ -36,14 +36,14 @@ impl stork::StorageMut for TestStorage { fn set(&self, key: &[u8], value: &[u8]) { // Safety: see above unsafe { - (&mut *self.0.get()).insert(key.to_vec(), value.to_vec()); + (*self.0.get()).insert(key.to_vec(), value.to_vec()); } } fn remove(&self, key: &[u8]) { // Safety: see above unsafe { - (&mut *self.0.get()).remove(key); + (*self.0.get()).remove(key); } } } @@ -53,44 +53,41 @@ impl stork::IterableStorage for TestStorage { type ValuesIterator<'a> = Box> + 'a>; type PairsIterator<'a> = Box, Vec)> + 'a>; - fn keys<'a>( - &'a self, - start: Option<&'a [u8]>, - end: Option<&'a [u8]>, - ) -> Self::KeysIterator<'a> { + fn keys<'a>(&'a self, start: Option<&[u8]>, end: Option<&[u8]>) -> Self::KeysIterator<'a> { + let start = start.map(|x| x.to_vec()); + let end = end.map(|x| x.to_vec()); + Box::new( // Safety: see above - unsafe { (&*self.0.get()).clone() } + unsafe { (*self.0.get()).clone() } .into_iter() - .filter(move |(k, _)| check_bounds(k, start, end)) + .filter(move |(k, _)| check_bounds(k, start.as_ref(), end.as_ref())) .map(|(k, _)| k), ) } - fn values<'a>( - &'a self, - start: Option<&'a [u8]>, - end: Option<&'a [u8]>, - ) -> Self::ValuesIterator<'a> { + fn values<'a>(&'a self, start: Option<&[u8]>, end: Option<&[u8]>) -> Self::ValuesIterator<'a> { + let start = start.map(|x| x.to_vec()); + let end = end.map(|x| x.to_vec()); + Box::new( // Safety: see above - unsafe { (&*self.0.get()).clone() } + unsafe { (*self.0.get()).clone() } .into_iter() - .filter(move |(k, _)| check_bounds(k, start, end)) + .filter(move |(k, _)| check_bounds(k, start.as_ref(), end.as_ref())) .map(|(_, v)| v), ) } - fn pairs<'a>( - &'a self, - start: Option<&'a [u8]>, - end: Option<&'a [u8]>, - ) -> Self::PairsIterator<'a> { + fn pairs<'a>(&'a self, start: Option<&[u8]>, end: Option<&[u8]>) -> Self::PairsIterator<'a> { + let start = start.map(|x| x.to_vec()); + let end = end.map(|x| x.to_vec()); + Box::new( // Safety: see above - unsafe { (&*self.0.get()).clone() } + unsafe { (*self.0.get()).clone() } .into_iter() - .filter(move |(k, _)| check_bounds(k, start, end)), + .filter(move |(k, _)| check_bounds(k, start.as_ref(), end.as_ref())), ) } } @@ -102,30 +99,30 @@ impl stork::RevIterableStorage for TestStorage { fn rev_keys<'a>( &'a self, - start: Option<&'a [u8]>, - end: Option<&'a [u8]>, + start: Option<&[u8]>, + end: Option<&[u8]>, ) -> Self::RevKeysIterator<'a> { Box::new(self.keys(start, end).rev()) } fn rev_values<'a>( &'a self, - start: Option<&'a [u8]>, - end: Option<&'a [u8]>, + start: Option<&[u8]>, + end: Option<&[u8]>, ) -> Self::RevValuesIterator<'a> { Box::new(self.values(start, end).rev()) } fn rev_pairs<'a>( &'a self, - start: Option<&'a [u8]>, - end: Option<&'a [u8]>, + start: Option<&[u8]>, + end: Option<&[u8]>, ) -> Self::RevPairsIterator<'a> { Box::new(self.pairs(start, end).rev()) } } -fn check_bounds(v: &[u8], start: Option<&[u8]>, end: Option<&[u8]>) -> bool { +fn check_bounds(v: &[u8], start: Option<&Vec>, end: Option<&Vec>) -> bool { if let Some(start) = start { if v < start { return false; diff --git a/tests/containers.rs b/tests/containers.rs index 14810df..72018ee 100644 --- a/tests/containers.rs +++ b/tests/containers.rs @@ -27,41 +27,113 @@ fn item() { fn map() { let storage = TestStorage::new(); - let map = Map::>::new(&[0]); + let map = Map::>::new(&[0]); let access = map.access(&storage); - access.get("foo").set(&1337).unwrap(); + access.entry("foo").set(&1337).unwrap(); - assert_eq!(access.get("foo").get().unwrap(), Some(1337)); + assert_eq!(access.entry("foo").get().unwrap(), Some(1337)); assert_eq!( - storage.get(&[0, 102, 111, 111]), + storage.get(&[0, 3, 102, 111, 111]), Some(1337u64.to_le_bytes().to_vec()) ); - assert_eq!(access.get("bar").get().unwrap(), None); + assert_eq!(access.entry("bar").get().unwrap(), None); } #[test] fn map_of_map() { let storage = TestStorage::new(); - let map = Map::>>::new(&[0]); + let map = Map::>>::new(&[0]); map.access(&storage) - .get("foo") - .get("bar") + .entry("foo") + .entry("bar") .set(&1337) .unwrap(); assert_eq!( - map.access(&storage).get("foo").get("bar").get().unwrap(), + map.access(&storage) + .entry("foo") + .entry("bar") + .get() + .unwrap(), Some(1337) ); assert_eq!( - storage.get(&[0, 102, 111, 111, 98, 97, 114]), + storage.get(&[0, 3, 102, 111, 111, 3, 98, 97, 114]), Some(1337u64.to_le_bytes().to_vec()) ); assert_eq!( - map.access(&storage).get("foo").get("baz").get().unwrap(), + map.access(&storage) + .entry("foo") + .entry("baz") + .get() + .unwrap(), None ); } + +#[test] +fn simple_iteration() { + let storage = TestStorage::new(); + + let map = Map::>::new(&[0]); + let access = map.access(&storage); + + access.entry("foo").set(&1337).unwrap(); + access.entry("bar").set(&42).unwrap(); + + let items = access + .iter(None, None) + .collect::, _>>() + .unwrap(); + assert_eq!( + items, + vec![ + (("bar".to_string(), ()), 42), + (("foo".to_string(), ()), 1337) + ] + ); +} + +#[test] +fn composable_iteration() { + let storage = TestStorage::new(); + + let map = Map::>>::new(&[0]); + let access = map.access(&storage); + + // populate with data + access.entry("foo").entry("bar").set(&1337).unwrap(); + access.entry("foo").entry("baz").set(&42).unwrap(); + access.entry("qux").entry("quux").set(&9001).unwrap(); + + // iterate over all items + let items = access + .iter(None, None) + .collect::, _>>() + .unwrap(); + assert_eq!( + items, + vec![ + (("foo".to_string(), ("bar".to_string(), ())), 1337), + (("foo".to_string(), ("baz".to_string(), ())), 42), + (("qux".to_string(), ("quux".to_string(), ())), 9001) + ] + ); + + // iterate over items under "foo" + let items = access + .entry("foo") + .iter(None, None) + .collect::, _>>() + .unwrap(); + assert_eq!( + items, + vec![ + (("bar".to_string(), ()), 1337), + (("baz".to_string(), ()), 42) + ] + ); +}