Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upstream merge #27

Closed
wants to merge 10 commits into from
2 changes: 1 addition & 1 deletion basic_credential/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl SignatureKeyPair {
}
}

fn id(&self) -> StorageId {
pub fn id(&self) -> StorageId {
StorageId {
value: id(&self.public, self.signature_scheme),
}
Expand Down
17 changes: 11 additions & 6 deletions memory_storage/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use openmls_traits::storage::*;

use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::RwLock};

Expand Down Expand Up @@ -803,12 +802,18 @@ impl StorageProvider<CURRENT_VERSION> for MemoryStorage {
&self,
group_id: &GroupId,
) -> Result<(), Self::Error> {
// Get all proposal refs for this group.
let proposal_refs: Vec<ProposalRef> =
self.read_list(PROPOSAL_QUEUE_REFS_LABEL, &serde_json::to_vec(group_id)?)?;
let mut values = self.values.write().unwrap();
for proposal_ref in proposal_refs {
// Delete all proposals.
let key = serde_json::to_vec(&(group_id, proposal_ref))?;
values.remove(&key);
}

let key = build_key::<CURRENT_VERSION, &GroupId>(QUEUED_PROPOSAL_LABEL, group_id);

// XXX #1566: also remove the proposal refs. can't be done now because they are stored in a
// non-recoverable way
// Delete the proposal refs from the store.
let key = build_key::<CURRENT_VERSION, &GroupId>(PROPOSAL_QUEUE_REFS_LABEL, group_id);
values.remove(&key);

Ok(())
Expand Down Expand Up @@ -888,7 +893,7 @@ impl StorageProvider<CURRENT_VERSION> for MemoryStorage {
aad: &[u8],
) -> Result<(), Self::Error> {
let key = serde_json::to_vec(group_id)?;
self.write::<CURRENT_VERSION>(AAD_LABEL, &key, serde_json::to_vec(aad).unwrap())
self.write::<CURRENT_VERSION>(AAD_LABEL, &key, aad.to_vec())
}

fn delete_aad<GroupId: traits::GroupId<CURRENT_VERSION>>(
Expand Down
81 changes: 81 additions & 0 deletions memory_storage/tests/proposals.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use openmls_memory_storage::MemoryStorage;
use openmls_traits::storage::{
traits::{self},
Entity, Key, StorageProvider, CURRENT_VERSION,
};
use serde::{Deserialize, Serialize};

// Test types
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
struct TestGroupId(Vec<u8>);
impl traits::GroupId<CURRENT_VERSION> for TestGroupId {}
impl Key<CURRENT_VERSION> for TestGroupId {}

#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Copy)]
struct ProposalRef(usize);
impl traits::ProposalRef<CURRENT_VERSION> for ProposalRef {}
impl Key<CURRENT_VERSION> for ProposalRef {}
impl Entity<CURRENT_VERSION> for ProposalRef {}

#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
struct Proposal(Vec<u8>);
impl traits::QueuedProposal<CURRENT_VERSION> for Proposal {}
impl Entity<CURRENT_VERSION> for Proposal {}

/// Write and read some proposals
#[test]
fn read_write_delete() {
let group_id = TestGroupId(b"TestGroupId".to_vec());
let proposals = (0..10)
.map(|i| Proposal(format!("TestProposal{i}").as_bytes().to_vec()))
.collect::<Vec<_>>();
let storage = MemoryStorage::default();

// Store proposals
for (i, proposal) in proposals.iter().enumerate() {
storage
.queue_proposal(&group_id, &ProposalRef(i), proposal)
.unwrap();
}

// Read proposal refs
let proposal_refs_read: Vec<ProposalRef> = storage.queued_proposal_refs(&group_id).unwrap();
assert_eq!(
(0..10).map(|i| ProposalRef(i)).collect::<Vec<_>>(),
proposal_refs_read
);

// Read proposals
let proposals_read: Vec<(ProposalRef, Proposal)> = storage.queued_proposals(&group_id).unwrap();
let proposals_expected: Vec<(ProposalRef, Proposal)> = (0..10)
.map(|i| ProposalRef(i))
.zip(proposals.clone().into_iter())
.collect();
assert_eq!(proposals_expected, proposals_read);

// Remove proposal 5
storage.remove_proposal(&group_id, &ProposalRef(5)).unwrap();

let proposal_refs_read: Vec<ProposalRef> = storage.queued_proposal_refs(&group_id).unwrap();
let mut expected = (0..10).map(|i| ProposalRef(i)).collect::<Vec<_>>();
expected.remove(5);
assert_eq!(expected, proposal_refs_read);

let proposals_read: Vec<(ProposalRef, Proposal)> = storage.queued_proposals(&group_id).unwrap();
let mut proposals_expected: Vec<(ProposalRef, Proposal)> = (0..10)
.map(|i| ProposalRef(i))
.zip(proposals.clone().into_iter())
.collect();
proposals_expected.remove(5);
assert_eq!(proposals_expected, proposals_read);

// Clear all proposals
storage
.clear_proposal_queue::<TestGroupId, ProposalRef>(&group_id)
.unwrap();
let proposal_refs_read: Vec<ProposalRef> = storage.queued_proposal_refs(&group_id).unwrap();
assert!(proposal_refs_read.is_empty());

let proposals_read: Vec<(ProposalRef, Proposal)> = storage.queued_proposals(&group_id).unwrap();
assert!(proposals_read.is_empty());
}
5 changes: 2 additions & 3 deletions openmls/src/group/core_group/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1135,13 +1135,12 @@ impl CoreGroup {
}

/// Create a new group context extension proposal
pub(crate) fn create_group_context_ext_proposal<Provider: OpenMlsProvider>(
pub(crate) fn create_group_context_ext_proposal(
&self,
framing_parameters: FramingParameters,
extensions: Extensions,
signer: &impl Signer,
) -> Result<AuthenticatedContent, CreateGroupContextExtProposalError<Provider::StorageError>>
{
) -> Result<AuthenticatedContent, CreateGroupContextExtProposalError> {
// Ensure that the group supports all the extensions that are wanted.
let required_extension = extensions
.iter()
Expand Down
7 changes: 5 additions & 2 deletions openmls/src/group/core_group/test_proposals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,10 @@ fn test_group_context_extension_proposal_fails(
}

#[openmls_test::openmls_test]
fn test_group_context_extension_proposal(ciphersuite: Ciphersuite, provider: &Provider) {
fn test_group_context_extension_proposal(
ciphersuite: Ciphersuite,
provider: &impl crate::storage::OpenMlsProvider,
) {
// Basic group setup.
let group_aad = b"Alice's test group";
let framing_parameters = FramingParameters::new(group_aad, WireFormat::PublicMessage);
Expand Down Expand Up @@ -583,7 +586,7 @@ fn test_group_context_extension_proposal(ciphersuite: Ciphersuite, provider: &Pr
&[CredentialType::Basic],
));
let gce_proposal = alice_group
.create_group_context_ext_proposal::<Provider>(
.create_group_context_ext_proposal(
framing_parameters,
Extensions::single(required_application_id),
&alice_signer,
Expand Down
2 changes: 1 addition & 1 deletion openmls/src/group/mls_group/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ pub enum ProposalError<StorageError> {
ValidationError(#[from] ValidationError),
/// See [`CreateGroupContextExtProposalError`] for more details.
#[error(transparent)]
CreateGroupContextExtProposalError(#[from] CreateGroupContextExtProposalError<StorageError>),
CreateGroupContextExtProposalError(#[from] CreateGroupContextExtProposalError),
/// Error writing proposal to storage.
#[error("error writing proposal to storage")]
StorageError(StorageError),
Expand Down
2 changes: 1 addition & 1 deletion openmls/src/group/mls_group/proposal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ impl MlsGroup {
) -> Result<(MlsMessageOut, ProposalRef), ProposalError<Provider::StorageError>> {
self.is_operational()?;

let proposal = self.group.create_group_context_ext_proposal::<Provider>(
let proposal = self.group.create_group_context_ext_proposal(
self.framing_parameters(),
extensions,
signer,
Expand Down
2 changes: 1 addition & 1 deletion openmls/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub trait OpenMlsProvider:
}

impl<
Error: std::error::Error + PartialEq,
Error: std::error::Error,
SP: StorageProvider<Error = Error>,
OP: openmls_traits::OpenMlsProvider<StorageProvider = SP>,
> OpenMlsProvider for OP
Expand Down
Loading