diff --git a/basic_credential/src/lib.rs b/basic_credential/src/lib.rs index fe8ea30c2..350ea4fc6 100644 --- a/basic_credential/src/lib.rs +++ b/basic_credential/src/lib.rs @@ -109,7 +109,7 @@ impl SignatureKeyPair { } } - fn id(&self) -> StorageId { + pub fn id(&self) -> StorageId { StorageId { value: id(&self.public, self.signature_scheme), } diff --git a/memory_storage/src/lib.rs b/memory_storage/src/lib.rs index 3eadce72f..6e3a502df 100644 --- a/memory_storage/src/lib.rs +++ b/memory_storage/src/lib.rs @@ -803,12 +803,18 @@ impl StorageProvider for MemoryStorage { &self, group_id: &GroupId, ) -> Result<(), Self::Error> { + // Get all proposal refs for this group. + let proposal_refs: Vec = + self.read_list(PROPOSAL_QUEUE_REFS_LABEL, &serde_json::to_vec(group_id)?)?; let mut values = self.values.write().unwrap(); + for proposal_ref in proposal_refs { + // Delete all proposals. + let key = serde_json::to_vec(&(group_id, proposal_ref))?; + values.remove(&key); + } - let key = build_key::(QUEUED_PROPOSAL_LABEL, group_id); - - // XXX #1566: also remove the proposal refs. can't be done now because they are stored in a - // non-recoverable way + // Delete the proposal refs from the store. + let key = build_key::(PROPOSAL_QUEUE_REFS_LABEL, group_id); values.remove(&key); Ok(()) diff --git a/memory_storage/tests/proposals.rs b/memory_storage/tests/proposals.rs new file mode 100644 index 000000000..f0d888be0 --- /dev/null +++ b/memory_storage/tests/proposals.rs @@ -0,0 +1,81 @@ +use openmls_memory_storage::MemoryStorage; +use openmls_traits::storage::{ + traits::{self}, + Entity, Key, StorageProvider, CURRENT_VERSION, +}; +use serde::{Deserialize, Serialize}; + +// Test types +#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] +struct TestGroupId(Vec); +impl traits::GroupId for TestGroupId {} +impl Key for TestGroupId {} + +#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Copy)] +struct ProposalRef(usize); +impl traits::ProposalRef for ProposalRef {} +impl Key for ProposalRef {} +impl Entity for ProposalRef {} + +#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] +struct Proposal(Vec); +impl traits::QueuedProposal for Proposal {} +impl Entity for Proposal {} + +/// Write and read some proposals +#[test] +fn read_write_delete() { + let group_id = TestGroupId(b"TestGroupId".to_vec()); + let proposals = (0..10) + .map(|i| Proposal(format!("TestProposal{i}").as_bytes().to_vec())) + .collect::>(); + let storage = MemoryStorage::default(); + + // Store proposals + for (i, proposal) in proposals.iter().enumerate() { + storage + .queue_proposal(&group_id, &ProposalRef(i), proposal) + .unwrap(); + } + + // Read proposal refs + let proposal_refs_read: Vec = storage.queued_proposal_refs(&group_id).unwrap(); + assert_eq!( + (0..10).map(|i| ProposalRef(i)).collect::>(), + proposal_refs_read + ); + + // Read proposals + let proposals_read: Vec<(ProposalRef, Proposal)> = storage.queued_proposals(&group_id).unwrap(); + let proposals_expected: Vec<(ProposalRef, Proposal)> = (0..10) + .map(|i| ProposalRef(i)) + .zip(proposals.clone().into_iter()) + .collect(); + assert_eq!(proposals_expected, proposals_read); + + // Remove proposal 5 + storage.remove_proposal(&group_id, &ProposalRef(5)).unwrap(); + + let proposal_refs_read: Vec = storage.queued_proposal_refs(&group_id).unwrap(); + let mut expected = (0..10).map(|i| ProposalRef(i)).collect::>(); + expected.remove(5); + assert_eq!(expected, proposal_refs_read); + + let proposals_read: Vec<(ProposalRef, Proposal)> = storage.queued_proposals(&group_id).unwrap(); + let mut proposals_expected: Vec<(ProposalRef, Proposal)> = (0..10) + .map(|i| ProposalRef(i)) + .zip(proposals.clone().into_iter()) + .collect(); + proposals_expected.remove(5); + assert_eq!(proposals_expected, proposals_read); + + // Clear all proposals + storage + .clear_proposal_queue::(&group_id) + .unwrap(); + let proposal_refs_read: Vec = storage.queued_proposal_refs(&group_id).unwrap(); + assert!(proposal_refs_read.is_empty()); + + let proposals_read: Vec<(ProposalRef, Proposal)> = storage.queued_proposals(&group_id).unwrap(); + assert!(proposals_read.is_empty()); +} diff --git a/openmls/src/group/errors.rs b/openmls/src/group/errors.rs index 50e2e8317..8dc1332e5 100644 --- a/openmls/src/group/errors.rs +++ b/openmls/src/group/errors.rs @@ -519,6 +519,9 @@ pub enum CreateGroupContextExtProposalError { /// See [`CreateCommitError`] for more details. #[error(transparent)] CreateCommitError(#[from] CreateCommitError), + /// Error writing updated group to storage. + #[error("Error writing updated group data to storage.")] + StorageError(StorageError), } /// Error merging a commit. diff --git a/openmls/src/group/mls_group/proposal.rs b/openmls/src/group/mls_group/proposal.rs index b76a46bd3..be6df8e93 100644 --- a/openmls/src/group/mls_group/proposal.rs +++ b/openmls/src/group/mls_group/proposal.rs @@ -420,6 +420,11 @@ impl MlsGroup { create_commit_result.staged_commit, ))); + provider + .storage() + .write_group_state(self.group_id(), &self.group_state) + .map_err(CreateGroupContextExtProposalError::StorageError)?; + Ok(( mls_messages, create_commit_result diff --git a/openmls/src/storage.rs b/openmls/src/storage.rs index 854a7a666..585ee4471 100644 --- a/openmls/src/storage.rs +++ b/openmls/src/storage.rs @@ -46,7 +46,7 @@ pub trait OpenMlsProvider: } impl< - Error: std::error::Error + PartialEq, + Error: std::error::Error, SP: StorageProvider, OP: openmls_traits::OpenMlsProvider, > OpenMlsProvider for OP