Skip to content

Commit

Permalink
Merge pull request #13 from CosmWasm/mut_storage_through_mut_refs
Browse files Browse the repository at this point in the history
Mutate storage through `&mut` only (maybe?)
  • Loading branch information
uint authored Feb 27, 2024
2 parents 3a2c48a + e24e1bf commit 5bccb84
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 53 deletions.
40 changes: 38 additions & 2 deletions src/bin_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ pub trait Storage {
}

pub trait StorageMut {
fn set(&self, key: &[u8], value: &[u8]);
fn remove(&self, key: &[u8]);
fn set(&mut self, key: &[u8], value: &[u8]);
fn remove(&mut self, key: &[u8]);
}

pub trait IterableStorage {
Expand All @@ -27,6 +27,42 @@ pub trait IterableStorage {
fn pairs<'a>(&'a self, start: Option<&[u8]>, end: Option<&[u8]>) -> Self::PairsIterator<'a>;
}

impl<T: IterableStorage> IterableStorage for &T {
type KeysIterator<'a> = T::KeysIterator<'a> where Self: 'a;
type ValuesIterator<'a> = T::ValuesIterator<'a> where Self: 'a;
type PairsIterator<'a> = T::PairsIterator<'a> where Self: 'a;

fn keys<'a>(&'a self, start: Option<&[u8]>, end: Option<&[u8]>) -> Self::KeysIterator<'a> {
(**self).keys(start, end)
}

fn values<'a>(&'a self, start: Option<&[u8]>, end: Option<&[u8]>) -> Self::ValuesIterator<'a> {
(**self).values(start, end)
}

fn pairs<'a>(&'a self, start: Option<&[u8]>, end: Option<&[u8]>) -> Self::PairsIterator<'a> {
(**self).pairs(start, end)
}
}

impl<T: IterableStorage> IterableStorage for &mut T {
type KeysIterator<'a> = T::KeysIterator<'a> where Self: 'a;
type ValuesIterator<'a> = T::ValuesIterator<'a> where Self: 'a;
type PairsIterator<'a> = T::PairsIterator<'a> where Self: 'a;

fn keys<'a>(&'a self, start: Option<&[u8]>, end: Option<&[u8]>) -> Self::KeysIterator<'a> {
(**self).keys(start, end)
}

fn values<'a>(&'a self, start: Option<&[u8]>, end: Option<&[u8]>) -> Self::ValuesIterator<'a> {
(**self).values(start, end)
}

fn pairs<'a>(&'a self, start: Option<&[u8]>, end: Option<&[u8]>) -> Self::PairsIterator<'a> {
(**self).pairs(start, end)
}
}

pub trait RevIterableStorage {
type RevKeysIterator<'a>: Iterator<Item = Vec<u8>>
where
Expand Down
7 changes: 2 additions & 5 deletions src/containers/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@ where
}
}

pub fn access<'s, S: Storage + 's>(
&self,
storage: &'s S,
) -> ItemAccess<E, T, StorageBranch<'s, S>> {
pub fn access<S>(&self, storage: S) -> ItemAccess<E, T, StorageBranch<S>> {
Self::access_impl(StorageBranch::new(storage, self.prefix.to_vec()))
}
}
Expand Down Expand Up @@ -88,7 +85,7 @@ where
T: EncodableWith<E> + DecodableWith<E>,
S: StorageMut,
{
pub fn set(&self, value: &T) -> Result<(), E::EncodeError> {
pub fn set(&mut self, value: &T) -> Result<(), E::EncodeError> {
let bytes = value.encode()?;
self.storage.set(&[], &bytes);
Ok(())
Expand Down
25 changes: 18 additions & 7 deletions src/containers/map.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{borrow::Borrow, marker::PhantomData};

use crate::storage_branch::StorageBranch;
use crate::{IterableStorage, Storage};
use crate::IterableStorage;

use super::{KeyDecodeError, Storable, StorableIter};

Expand All @@ -22,10 +22,7 @@ where
}
}

pub fn access<'s, S: Storage + 's>(
&self,
storage: &'s S,
) -> MapAccess<K, V, StorageBranch<'s, S>> {
pub fn access<S>(&self, storage: S) -> MapAccess<K, V, StorageBranch<S>> {
Self::access_impl(StorageBranch::new(storage, self.prefix.to_vec()))
}
}
Expand Down Expand Up @@ -74,9 +71,8 @@ impl<K, V, S> MapAccess<K, V, S>
where
K: Key,
V: Storable,
S: Storage,
{
pub fn entry<'s, Q>(&'s self, key: &Q) -> V::AccessorT<StorageBranch<'s, S>>
pub fn entry<Q>(&self, key: &Q) -> V::AccessorT<StorageBranch<&S>>
where
K: Borrow<Q>,
Q: Key + ?Sized,
Expand All @@ -90,6 +86,21 @@ where

V::access_impl(StorageBranch::new(&self.storage, key))
}

pub fn entry_mut<Q>(&mut self, key: &Q) -> V::AccessorT<StorageBranch<&mut S>>
where
K: Borrow<Q>,
Q: Key + ?Sized,
{
let len = key.bytes().len();
let bytes = key.bytes();
let mut key = Vec::with_capacity(len + 1);

key.push(len as u8);
key.extend_from_slice(bytes);

V::access_impl(StorageBranch::new(&mut self.storage, key))
}
}

impl<K, V, S> MapAccess<K, V, S>
Expand Down
29 changes: 16 additions & 13 deletions src/storage_branch.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,39 @@
use crate::{IterableStorage, RevIterableStorage, Storage, StorageMut};

pub struct StorageBranch<'s, S> {
backend: &'s S,
pub struct StorageBranch<S> {
backend: S,
prefix: Vec<u8>,
}

impl<'s, S> StorageBranch<'s, S>
where
S: Storage,
{
pub fn new(backend: &'s S, prefix: Vec<u8>) -> Self {
impl<S> StorageBranch<S> {
pub fn new(backend: S, prefix: Vec<u8>) -> Self {
Self { backend, prefix }
}
}

impl<S: Storage> Storage for StorageBranch<'_, S> {
impl<S: Storage> Storage for StorageBranch<&S> {
fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
self.backend.get(&[&self.prefix[..], key].concat())
}
}

impl<S: Storage> Storage for StorageBranch<&mut S> {
fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
self.backend.get(&[&self.prefix[..], key].concat())
}
}

impl<S: StorageMut> StorageMut for StorageBranch<'_, S> {
fn set(&self, key: &[u8], value: &[u8]) {
impl<S: StorageMut> StorageMut for StorageBranch<&mut S> {
fn set(&mut self, key: &[u8], value: &[u8]) {
self.backend.set(&[&self.prefix[..], key].concat(), value)
}

fn remove(&self, key: &[u8]) {
fn remove(&mut self, key: &[u8]) {
self.backend.remove(&[&self.prefix[..], key].concat())
}
}

impl<S: IterableStorage> IterableStorage for StorageBranch<'_, S> {
impl<S: IterableStorage> IterableStorage for StorageBranch<S> {
type KeysIterator<'a> = BranchKeysIter<S::KeysIterator<'a>> where Self: 'a;
type ValuesIterator<'a> = S::ValuesIterator<'a> where Self: 'a;
type PairsIterator<'a> = BranchKVIter<S::PairsIterator<'a>> where Self: 'a;
Expand Down Expand Up @@ -69,7 +72,7 @@ impl<S: IterableStorage> IterableStorage for StorageBranch<'_, S> {
}
}

impl<S: RevIterableStorage> RevIterableStorage for StorageBranch<'_, S> {
impl<S: RevIterableStorage> RevIterableStorage for StorageBranch<S> {
type RevKeysIterator<'a> = BranchKeysIter<S::RevKeysIterator<'a>> where Self: 'a;
type RevValuesIterator<'a> = S::RevValuesIterator<'a> where Self: 'a;
type RevPairsIterator<'a> = BranchKVIter<S::RevPairsIterator<'a>> where Self: 'a;
Expand Down
4 changes: 2 additions & 2 deletions tests/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ mod common;

#[test]
fn storage_backend() {
// TODO: split this into multiple tests
// TODO: split this into multiple tests?

use stork::{IterableStorage as _, RevIterableStorage as _, StorageMut as _};

let storage = common::backend::TestStorage::new();
let mut storage = common::backend::TestStorage::new();

storage.set(&[0], b"bar");
storage.set(&[1], b"baz");
Expand Down
4 changes: 2 additions & 2 deletions tests/common/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ impl stork::Storage for TestStorage {
}

impl stork::StorageMut for TestStorage {
fn set(&self, key: &[u8], value: &[u8]) {
fn set(&mut self, key: &[u8], value: &[u8]) {
// Safety: see above
unsafe {
(*self.0.get()).insert(key.to_vec(), value.to_vec());
}
}

fn remove(&self, key: &[u8]) {
fn remove(&mut self, key: &[u8]) {
// Safety: see above
unsafe {
(*self.0.get()).remove(key);
Expand Down
49 changes: 27 additions & 22 deletions tests/containers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,47 +8,48 @@ use common::encoding::TestEncoding;

#[test]
fn item() {
let storage = TestStorage::new();
let mut storage = TestStorage::new();

let item0 = Item::<u64, TestEncoding>::new(&[0]);
let access0 = item0.access(&storage);
access0.set(&42).unwrap();
item0.access(&mut storage).set(&42).unwrap();

let item1 = Item::<u64, TestEncoding>::new(&[1]);
let access1 = item1.access(&storage);

assert_eq!(access0.get().unwrap(), Some(42));
assert_eq!(item0.access(&storage).get().unwrap(), Some(42));
assert_eq!(storage.get(&[0]), Some(42u64.to_le_bytes().to_vec()));
assert_eq!(access1.get().unwrap(), None);
assert_eq!(storage.get(&[1]), None);
}

#[test]
fn map() {
let storage = TestStorage::new();
let mut storage = TestStorage::new();

let map = Map::<String, Item<u64, TestEncoding>>::new(&[0]);
let access = map.access(&storage);

access.entry("foo").set(&1337).unwrap();
map.access(&mut storage)
.entry_mut("foo")
.set(&1337)
.unwrap();

assert_eq!(access.entry("foo").get().unwrap(), Some(1337));
assert_eq!(map.access(&storage).entry("foo").get().unwrap(), Some(1337));
assert_eq!(
storage.get(&[0, 3, 102, 111, 111]),
Some(1337u64.to_le_bytes().to_vec())
);
assert_eq!(access.entry("bar").get().unwrap(), None);
assert_eq!(map.access(&storage).entry("bar").get().unwrap(), None);
}

#[test]
fn map_of_map() {
let storage = TestStorage::new();
let mut storage = TestStorage::new();

let map = Map::<String, Map<String, Item<u64, TestEncoding>>>::new(&[0]);

map.access(&storage)
.entry("foo")
.entry("bar")
map.access(&mut storage)
.entry_mut("foo")
.entry_mut("bar")
.set(&1337)
.unwrap();

Expand Down Expand Up @@ -76,13 +77,13 @@ fn map_of_map() {

#[test]
fn simple_iteration() {
let storage = TestStorage::new();
let mut storage = TestStorage::new();

let map = Map::<String, Item<u64, TestEncoding>>::new(&[0]);
let access = map.access(&storage);
let mut access = map.access(&mut storage);

access.entry("foo").set(&1337).unwrap();
access.entry("bar").set(&42).unwrap();
access.entry_mut("foo").set(&1337).unwrap();
access.entry_mut("bar").set(&42).unwrap();

let items = access
.iter(None, None)
Expand All @@ -99,15 +100,19 @@ fn simple_iteration() {

#[test]
fn composable_iteration() {
let storage = TestStorage::new();
let mut storage = TestStorage::new();

let map = Map::<String, Map<String, Item<u64, TestEncoding>>>::new(&[0]);
let access = map.access(&storage);
let mut access = map.access(&mut storage);

// populate with data
access.entry("foo").entry("bar").set(&1337).unwrap();
access.entry("foo").entry("baz").set(&42).unwrap();
access.entry("qux").entry("quux").set(&9001).unwrap();
access.entry_mut("foo").entry_mut("bar").set(&1337).unwrap();
access.entry_mut("foo").entry_mut("baz").set(&42).unwrap();
access
.entry_mut("qux")
.entry_mut("quux")
.set(&9001)
.unwrap();

// iterate over all items
let items = access
Expand Down

0 comments on commit 5bccb84

Please sign in to comment.