Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove test only distinction for group_context_ext_proposal #22

Merged
merged 9 commits into from
Apr 10, 2024
9 changes: 5 additions & 4 deletions openmls/src/group/core_group/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1043,11 +1043,7 @@ impl CoreGroup {
group_info: group_info.filter(|_| self.use_ratchet_tree_extension),
})
}
}

// Test functions
#[cfg(test)]
impl CoreGroup {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I need this for some work I have coming up too

pub(crate) fn create_group_context_ext_proposal(
&self,
framing_parameters: FramingParameters,
Expand Down Expand Up @@ -1082,6 +1078,11 @@ impl CoreGroup {
)
.map_err(|e| e.into())
}
}

// Test functions
#[cfg(test)]
impl CoreGroup {

pub(crate) fn use_ratchet_tree_extension(&self) -> bool {
self.use_ratchet_tree_extension
Expand Down
2 changes: 2 additions & 0 deletions openmls/src/group/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,8 @@ pub enum CreateGroupContextExtProposalError {
/// See [`LeafNodeValidationError`] for more details.
#[error(transparent)]
LeafNodeValidation(#[from] LeafNodeValidationError),
#[error(transparent)]
GroupStateError(#[from] MlsGroupStateError),
}

/// Error merging a commit.
Expand Down
46 changes: 42 additions & 4 deletions openmls/src/group/mls_group/proposal.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use core_group::create_commit_params::CreateCommitParams;
Copy link

@keks keks Apr 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really just a stylistic nitpick, but I'd suggest merging this line back into the use super::{...} block below, otherwise it looks like there was a crate called "core_group".

If you need the core_group binding later, you an do

use super::{
  core_group::{self, create_commit_params::CreateCommitParams},
  // ...
}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks fixed here 7395004

use openmls_traits::{
key_store::OpenMlsKeyStore, signatures::Signer, types::Ciphersuite, OpenMlsProvider,
};

use super::{
errors::{ProposalError, ProposeAddMemberError, ProposeRemoveMemberError},
MlsGroup,
core_group, errors::{ProposalError, ProposeAddMemberError, ProposeRemoveMemberError}, CreateGroupContextExtProposalError, GroupContextExtensionProposal, GroupContextExtensionsProposalValidationError, MlsGroup, MlsGroupState, PendingCommitState, Proposal
};
use crate::{
binary_tree::LeafNodeIndex,
Expand All @@ -14,7 +14,7 @@ use crate::{
framing::MlsMessageOut,
group::{errors::CreateAddProposalError, GroupId, QueuedProposal},
key_packages::KeyPackage,
messages::proposals::ProposalOrRefType,
messages::{group_info::GroupInfo, proposals::ProposalOrRefType},
prelude::LibraryError,
schedule::PreSharedKeyId,
treesync::LeafNode,
Expand Down Expand Up @@ -319,7 +319,6 @@ impl MlsGroup {
}
}

#[cfg(test)]
pub fn propose_group_context_extensions(
&mut self,
provider: &impl OpenMlsProvider,
Expand Down Expand Up @@ -350,4 +349,43 @@ impl MlsGroup {

Ok((mls_message, proposal_ref))
}

pub fn update_group_context_extensions(
&mut self,
provider: &impl OpenMlsProvider,
extensions: Extensions,
signer: &impl Signer,
) -> Result<(MlsMessageOut, Option<MlsMessageOut>, Option<GroupInfo>), CreateGroupContextExtProposalError>
{
self.is_operational()?;

// Create inline add proposals from key packages
let mut inline_proposals = vec![];
inline_proposals.push(Proposal::GroupContextExtensions(GroupContextExtensionProposal {
extensions,
}));

let params = CreateCommitParams::builder()
.framing_parameters(self.framing_parameters())
.proposal_store(&self.proposal_store)
.inline_proposals(inline_proposals)
.build();
let create_commit_result = self.group.create_commit(params, provider, signer).unwrap();

let mls_messages = self.content_to_mls_message(create_commit_result.commit, provider)?;
self.group_state = MlsGroupState::PendingCommit(Box::new(PendingCommitState::Member(
create_commit_result.staged_commit,
)));

// Since the state of the group might be changed, arm the state flag
self.flag_state_change();

Ok((
mls_messages,
create_commit_result
.welcome_option
.map(|w| MlsMessageOut::from_welcome(w, self.group.version())),
create_commit_result.group_info,
))
}
}
163 changes: 163 additions & 0 deletions openmls/src/group/mls_group/test_mls_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,169 @@ fn unknown_extensions(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider)
.expect("Error creating group from staged join");
}

// Test the successful update of Group Context Extension with type Extension::Unknown(0xff11)
#[apply(ciphersuites_and_providers)]
fn update_group_context_with_unknown_extension(
ciphersuite: Ciphersuite,
provider: &impl OpenMlsProvider,
) {
let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) =
setup_client("Alice", ciphersuite, provider);

// === Define the unknown group context extension and initial data ===
let unknown_extension_data = vec![1, 2];
let unknown_gc_extension = Extension::Unknown(0xff11, UnknownExtension(unknown_extension_data));
let required_extension_types = &[ExtensionType::Unknown(0xff11)];
let required_capabilities = Extension::RequiredCapabilities(
RequiredCapabilitiesExtension::new(required_extension_types, &[], &[]),
);
let capabilities = Capabilities::new(None, None, Some(required_extension_types), None, None);
let test_gc_extensions = Extensions::from_vec(vec![
unknown_gc_extension.clone(),
required_capabilities.clone(),
])
.expect("error creating test group context extensions");
let mls_group_create_config = MlsGroupCreateConfig::builder()
.with_group_context_extensions(test_gc_extensions.clone())
.expect("error adding unknown extension to config")
.capabilities(capabilities.clone())
.crypto_config(CryptoConfig::with_default_version(ciphersuite))
.build();

// === Alice creates a group ===
let mut alice_group = MlsGroup::new(
provider,
&alice_signer,
&mls_group_create_config,
alice_credential_with_key,
)
.expect("error creating group");

// === Verify the initial group context extension data is correct ===
let group_context_extensions = alice_group.group().context().extensions();
let mut extracted_data = None;
for extension in group_context_extensions.iter() {
if let Extension::Unknown(0xff11, UnknownExtension(data)) = extension {
extracted_data = Some(data.clone());
}
}
assert_eq!(
extracted_data.unwrap(),
vec![1, 2],
"The data of Extension::Unknown(0xff11) does not match the expected data"
);

// === Alice adds Bob ===
let (bob_credential_with_key, _bob_kpb, bob_signer, _bob_pk) =
setup_client("Bob", ciphersuite, provider);

let bob_key_package = KeyPackage::builder()
.leaf_node_capabilities(capabilities)
.build(
CryptoConfig::with_default_version(ciphersuite),
provider,
&bob_signer,
bob_credential_with_key,
)
.expect("error building key package");

let (_, welcome, _) = alice_group
.add_members(provider, &alice_signer, &[bob_key_package.clone()])
.unwrap();
alice_group.merge_pending_commit(provider).unwrap();

let welcome: MlsMessageIn = welcome.into();
let welcome = welcome
.into_welcome()
.expect("expected message to be a welcome");

let bob_group = StagedWelcome::new_from_welcome(
provider,
&MlsGroupJoinConfig::default(),
welcome,
Some(alice_group.export_ratchet_tree().into()),
)
.expect("Error creating staged join from Welcome")
.into_group(provider)
.expect("Error creating group from staged join");

// === Verify Bob's initial group context extension data is correct ===
let group_context_extensions = bob_group.group().context().extensions();
let mut extracted_data_2 = None;
for extension in group_context_extensions.iter() {
if let Extension::Unknown(0xff11, UnknownExtension(data)) = extension {
extracted_data_2 = Some(data.clone());
}
}
assert_eq!(
extracted_data_2.unwrap(),
vec![1, 2],
"The data of Extension::Unknown(0xff11) does not match the expected data"
);

// === Propose the new group context extension ===
let updated_unknown_extension_data = vec![3, 4]; // Sample data for the extension
let updated_unknown_gc_extension = Extension::Unknown(
0xff11,
UnknownExtension(updated_unknown_extension_data.clone()),
);

let mut updated_extensions = test_gc_extensions.clone();
updated_extensions.add_or_replace(updated_unknown_gc_extension);
alice_group
.propose_group_context_extensions(provider, updated_extensions, &alice_signer)
.expect("failed to propose group context extensions with unknown extension");

assert_eq!(
alice_group.pending_proposals().count(),
1,
"Expected one pending proposal"
);

// === Commit to the proposed group context extension ===
alice_group
.commit_to_pending_proposals(provider, &alice_signer)
.expect("failed to commit to pending group context extensions");

alice_group
.merge_pending_commit(provider)
.expect("error merging pending commit");

alice_group
.save(provider.key_store())
.expect("error saving group");

// === Verify the group context extension was updated ===
let group_context_extensions = alice_group.group().context().extensions();
let mut extracted_data_updated = None;
for extension in group_context_extensions.iter() {
if let Extension::Unknown(0xff11, UnknownExtension(data)) = extension {
extracted_data_updated = Some(data.clone());
}
}
assert_eq!(
extracted_data_updated.unwrap(),
vec![3, 4],
"The data of Extension::Unknown(0xff11) does not match the expected data"
);

// === Verify Bob sees the group context extension updated ===
let bob_group_loaded = MlsGroup::load(bob_group.group().group_id(), provider.key_store())
.expect("error loading group");
let group_context_extensions_2 = bob_group_loaded.export_group_context().extensions();
let mut extracted_data_2 = None;
for extension in group_context_extensions_2.iter() {
if let Extension::Unknown(0xff11, UnknownExtension(data)) = extension {
extracted_data_2 = Some(data.clone());
}
}
assert_eq!(
extracted_data_2.unwrap(),
vec![3, 4],
"The data of Extension::Unknown(0xff11) does not match the expected data"
);
}

#[apply(ciphersuites_and_providers)]
fn join_multiple_groups_last_resort_extension(
ciphersuite: Ciphersuite,
Expand Down
3 changes: 1 addition & 2 deletions openmls/src/messages/proposals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -489,12 +489,11 @@ pub struct AppAckProposal {
TlsSize,
)]
pub struct GroupContextExtensionProposal {
extensions: Extensions,
pub(crate) extensions: Extensions,
}

impl GroupContextExtensionProposal {
/// Create a new [`GroupContextExtensionProposal`].
#[cfg(test)]
pub(crate) fn new(extensions: Extensions) -> Self {
Self { extensions }
}
Expand Down
Loading