diff --git a/basic_credential/src/lib.rs b/basic_credential/src/lib.rs index fe8ea30c2..350ea4fc6 100644 --- a/basic_credential/src/lib.rs +++ b/basic_credential/src/lib.rs @@ -109,7 +109,7 @@ impl SignatureKeyPair { } } - fn id(&self) -> StorageId { + pub fn id(&self) -> StorageId { StorageId { value: id(&self.public, self.signature_scheme), } diff --git a/memory_storage/src/lib.rs b/memory_storage/src/lib.rs index 3eadce72f..4290d5723 100644 --- a/memory_storage/src/lib.rs +++ b/memory_storage/src/lib.rs @@ -1,5 +1,4 @@ use openmls_traits::storage::*; - use serde::{Deserialize, Serialize}; use std::{collections::HashMap, sync::RwLock}; @@ -803,12 +802,18 @@ impl StorageProvider for MemoryStorage { &self, group_id: &GroupId, ) -> Result<(), Self::Error> { + // Get all proposal refs for this group. + let proposal_refs: Vec = + self.read_list(PROPOSAL_QUEUE_REFS_LABEL, &serde_json::to_vec(group_id)?)?; let mut values = self.values.write().unwrap(); + for proposal_ref in proposal_refs { + // Delete all proposals. + let key = serde_json::to_vec(&(group_id, proposal_ref))?; + values.remove(&key); + } - let key = build_key::(QUEUED_PROPOSAL_LABEL, group_id); - - // XXX #1566: also remove the proposal refs. can't be done now because they are stored in a - // non-recoverable way + // Delete the proposal refs from the store. + let key = build_key::(PROPOSAL_QUEUE_REFS_LABEL, group_id); values.remove(&key); Ok(()) @@ -888,7 +893,7 @@ impl StorageProvider for MemoryStorage { aad: &[u8], ) -> Result<(), Self::Error> { let key = serde_json::to_vec(group_id)?; - self.write::(AAD_LABEL, &key, serde_json::to_vec(aad).unwrap()) + self.write::(AAD_LABEL, &key, aad.to_vec()) } fn delete_aad>( diff --git a/memory_storage/tests/proposals.rs b/memory_storage/tests/proposals.rs new file mode 100644 index 000000000..f0d888be0 --- /dev/null +++ b/memory_storage/tests/proposals.rs @@ -0,0 +1,81 @@ +use openmls_memory_storage::MemoryStorage; +use openmls_traits::storage::{ + traits::{self}, + Entity, Key, StorageProvider, CURRENT_VERSION, +}; +use serde::{Deserialize, Serialize}; + +// Test types +#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] +struct TestGroupId(Vec); +impl traits::GroupId for TestGroupId {} +impl Key for TestGroupId {} + +#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Copy)] +struct ProposalRef(usize); +impl traits::ProposalRef for ProposalRef {} +impl Key for ProposalRef {} +impl Entity for ProposalRef {} + +#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] +struct Proposal(Vec); +impl traits::QueuedProposal for Proposal {} +impl Entity for Proposal {} + +/// Write and read some proposals +#[test] +fn read_write_delete() { + let group_id = TestGroupId(b"TestGroupId".to_vec()); + let proposals = (0..10) + .map(|i| Proposal(format!("TestProposal{i}").as_bytes().to_vec())) + .collect::>(); + let storage = MemoryStorage::default(); + + // Store proposals + for (i, proposal) in proposals.iter().enumerate() { + storage + .queue_proposal(&group_id, &ProposalRef(i), proposal) + .unwrap(); + } + + // Read proposal refs + let proposal_refs_read: Vec = storage.queued_proposal_refs(&group_id).unwrap(); + assert_eq!( + (0..10).map(|i| ProposalRef(i)).collect::>(), + proposal_refs_read + ); + + // Read proposals + let proposals_read: Vec<(ProposalRef, Proposal)> = storage.queued_proposals(&group_id).unwrap(); + let proposals_expected: Vec<(ProposalRef, Proposal)> = (0..10) + .map(|i| ProposalRef(i)) + .zip(proposals.clone().into_iter()) + .collect(); + assert_eq!(proposals_expected, proposals_read); + + // Remove proposal 5 + storage.remove_proposal(&group_id, &ProposalRef(5)).unwrap(); + + let proposal_refs_read: Vec = storage.queued_proposal_refs(&group_id).unwrap(); + let mut expected = (0..10).map(|i| ProposalRef(i)).collect::>(); + expected.remove(5); + assert_eq!(expected, proposal_refs_read); + + let proposals_read: Vec<(ProposalRef, Proposal)> = storage.queued_proposals(&group_id).unwrap(); + let mut proposals_expected: Vec<(ProposalRef, Proposal)> = (0..10) + .map(|i| ProposalRef(i)) + .zip(proposals.clone().into_iter()) + .collect(); + proposals_expected.remove(5); + assert_eq!(proposals_expected, proposals_read); + + // Clear all proposals + storage + .clear_proposal_queue::(&group_id) + .unwrap(); + let proposal_refs_read: Vec = storage.queued_proposal_refs(&group_id).unwrap(); + assert!(proposal_refs_read.is_empty()); + + let proposals_read: Vec<(ProposalRef, Proposal)> = storage.queued_proposals(&group_id).unwrap(); + assert!(proposals_read.is_empty()); +} diff --git a/openmls/src/group/core_group/mod.rs b/openmls/src/group/core_group/mod.rs index 11f9f3dea..a78fff0fc 100644 --- a/openmls/src/group/core_group/mod.rs +++ b/openmls/src/group/core_group/mod.rs @@ -1135,13 +1135,12 @@ impl CoreGroup { } /// Create a new group context extension proposal - pub(crate) fn create_group_context_ext_proposal( + pub(crate) fn create_group_context_ext_proposal( &self, framing_parameters: FramingParameters, extensions: Extensions, signer: &impl Signer, - ) -> Result> - { + ) -> Result { // Ensure that the group supports all the extensions that are wanted. let required_extension = extensions .iter() diff --git a/openmls/src/group/core_group/test_proposals.rs b/openmls/src/group/core_group/test_proposals.rs index 240c990bd..ca166ea72 100644 --- a/openmls/src/group/core_group/test_proposals.rs +++ b/openmls/src/group/core_group/test_proposals.rs @@ -512,7 +512,10 @@ fn test_group_context_extension_proposal_fails( } #[openmls_test::openmls_test] -fn test_group_context_extension_proposal(ciphersuite: Ciphersuite, provider: &Provider) { +fn test_group_context_extension_proposal( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { // Basic group setup. let group_aad = b"Alice's test group"; let framing_parameters = FramingParameters::new(group_aad, WireFormat::PublicMessage); @@ -583,7 +586,7 @@ fn test_group_context_extension_proposal(ciphersuite: Ciphersuite, provider: &Pr &[CredentialType::Basic], )); let gce_proposal = alice_group - .create_group_context_ext_proposal::( + .create_group_context_ext_proposal( framing_parameters, Extensions::single(required_application_id), &alice_signer, diff --git a/openmls/src/group/mls_group/errors.rs b/openmls/src/group/mls_group/errors.rs index abf94bdbe..b260b4944 100644 --- a/openmls/src/group/mls_group/errors.rs +++ b/openmls/src/group/mls_group/errors.rs @@ -340,7 +340,7 @@ pub enum ProposalError { ValidationError(#[from] ValidationError), /// See [`CreateGroupContextExtProposalError`] for more details. #[error(transparent)] - CreateGroupContextExtProposalError(#[from] CreateGroupContextExtProposalError), + CreateGroupContextExtProposalError(#[from] CreateGroupContextExtProposalError), /// Error writing proposal to storage. #[error("error writing proposal to storage")] StorageError(StorageError), diff --git a/openmls/src/group/mls_group/proposal.rs b/openmls/src/group/mls_group/proposal.rs index b76a46bd3..a747ac141 100644 --- a/openmls/src/group/mls_group/proposal.rs +++ b/openmls/src/group/mls_group/proposal.rs @@ -363,7 +363,7 @@ impl MlsGroup { ) -> Result<(MlsMessageOut, ProposalRef), ProposalError> { self.is_operational()?; - let proposal = self.group.create_group_context_ext_proposal::( + let proposal = self.group.create_group_context_ext_proposal( self.framing_parameters(), extensions, signer, diff --git a/openmls/src/storage.rs b/openmls/src/storage.rs index 854a7a666..585ee4471 100644 --- a/openmls/src/storage.rs +++ b/openmls/src/storage.rs @@ -46,7 +46,7 @@ pub trait OpenMlsProvider: } impl< - Error: std::error::Error + PartialEq, + Error: std::error::Error, SP: StorageProvider, OP: openmls_traits::OpenMlsProvider, > OpenMlsProvider for OP