Skip to content

Commit

Permalink
Merge pull request #2 from insipx/insipx/update-xmtp-openmls
Browse files Browse the repository at this point in the history
Bring fork up to date with openmls main
  • Loading branch information
insipx authored Nov 13, 2023
2 parents 89f2cff + 5558e7e commit d723800
Show file tree
Hide file tree
Showing 9 changed files with 209 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_test_workspace.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ jobs:
- uses: dtolnay/rust-toolchain@stable
- uses: Swatinem/rust-cache@v2
- name: Build workspace
run: cargo build --workspace --all-targets
run: cargo build --workspace --all-targets --exclude openmls-fuzz
- name: Test workspace
run: cargo test --workspace --all-targets --exclude=openmls --exclude openmls-fuzz
7 changes: 7 additions & 0 deletions openmls/src/extensions/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use crate::extensions::{
UnknownExtension,
};

use super::last_resort::LastResortExtension;

fn vlbytes_len_len(length: usize) -> usize {
if length < 0x40 {
1
Expand All @@ -34,6 +36,7 @@ impl Size for Extension {
Extension::RequiredCapabilities(e) => e.tls_serialized_len(),
Extension::ExternalPub(e) => e.tls_serialized_len(),
Extension::ExternalSenders(e) => e.tls_serialized_len(),
Extension::LastResort(e) => e.tls_serialized_len(),
Extension::Unknown(_, e) => e.0.len(),
};

Expand Down Expand Up @@ -65,6 +68,7 @@ impl Serialize for Extension {
Extension::RequiredCapabilities(e) => e.tls_serialize(&mut extension_data),
Extension::ExternalPub(e) => e.tls_serialize(&mut extension_data),
Extension::ExternalSenders(e) => e.tls_serialize(&mut extension_data),
Extension::LastResort(e) => e.tls_serialize(&mut extension_data),
Extension::Unknown(_, e) => extension_data
.write_all(e.0.as_slice())
.map(|_| e.0.len())
Expand Down Expand Up @@ -111,6 +115,9 @@ impl Deserialize for Extension {
ExtensionType::ExternalSenders => Extension::ExternalSenders(
ExternalSendersExtension::tls_deserialize(&mut extension_data)?,
),
ExtensionType::LastResort => {
Extension::LastResort(LastResortExtension::tls_deserialize(&mut extension_data)?)
}
ExtensionType::Unknown(unknown) => {
Extension::Unknown(unknown, UnknownExtension(extension_data.to_vec()))
}
Expand Down
28 changes: 28 additions & 0 deletions openmls/src/extensions/last_resort.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use tls_codec::{TlsDeserialize, TlsSerialize, TlsSize};

use super::{Deserialize, Serialize};

/// ```c
/// // draft-ietf-mls-extensions-03
/// struct {} LastResort;
/// ```
#[derive(
PartialEq,
Eq,
Clone,
Debug,
Serialize,
Deserialize,
TlsSerialize,
TlsDeserialize,
TlsSize,
Default,
)]
pub struct LastResortExtension {}

impl LastResortExtension {
/// Create a new `last_resort` extension.
pub fn new() -> Self {
Self::default()
}
}
60 changes: 39 additions & 21 deletions openmls/src/extensions/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
//! # Extensions
//!
//! In MLS, extensions appear in the following places:
//! - In [`KeyPackages`](`crate::key_packages`), to describe client capabilities and aspects of their
//! participation in the group.
//! - In [`KeyPackages`](`crate::key_packages`), to describe client capabilities
//! and aspects of their participation in the group.
//! - In the `GroupInfo`, to tell new members of a group what parameters are
//! being used by the group, and to provide any additional details required
//! to join the group.
//! - In the `GroupContext` object, to ensure that all members of the group
//! have the same view of the parameters in use.
//! being used by the group, and to provide any additional details required to
//! join the group.
//! - In the `GroupContext` object, to ensure that all members of the group have
//! the same view of the parameters in use.
//!
//! Note that `GroupInfo` and `GroupContext` are not exposed in OpenMLS' public
//! API.
Expand All @@ -31,6 +31,7 @@ mod application_id_extension;
mod codec;
mod external_pub_extension;
mod external_sender_extension;
mod last_resort;
mod ratchet_tree_extension;
mod required_capabilities;
use errors::*;
Expand All @@ -44,6 +45,7 @@ pub use external_pub_extension::ExternalPubExtension;
pub use external_sender_extension::{
ExternalSender, ExternalSendersExtension, SenderExtensionIndex,
};
pub use last_resort::LastResortExtension;
pub use ratchet_tree_extension::RatchetTreeExtension;
pub use required_capabilities::RequiredCapabilitiesExtension;

Expand Down Expand Up @@ -71,8 +73,8 @@ pub enum ExtensionType {
/// application-defined identifier to a KeyPackage.
ApplicationId,

/// The ratchet tree extensions provides the whole public state of the ratchet
/// tree.
/// The ratchet tree extensions provides the whole public state of the
/// ratchet tree.
RatchetTree,

/// The required capabilities extension defines the configuration of a group
Expand All @@ -87,6 +89,10 @@ pub enum ExtensionType {
/// of senders that are permitted to send external proposals to the group.
ExternalSenders,

/// KeyPackage extension that marks a KeyPackage for use in a last resort
/// scenario.
LastResort,

/// A currently unknown extension type.
Unknown(u16),
}
Expand Down Expand Up @@ -125,6 +131,7 @@ impl From<u16> for ExtensionType {
3 => ExtensionType::RequiredCapabilities,
4 => ExtensionType::ExternalPub,
5 => ExtensionType::ExternalSenders,
10 => ExtensionType::LastResort,
unknown => ExtensionType::Unknown(unknown),
}
}
Expand All @@ -138,6 +145,7 @@ impl From<ExtensionType> for u16 {
ExtensionType::RequiredCapabilities => 3,
ExtensionType::ExternalPub => 4,
ExtensionType::ExternalSenders => 5,
ExtensionType::LastResort => 10,
ExtensionType::Unknown(unknown) => unknown,
}
}
Expand All @@ -153,6 +161,7 @@ impl ExtensionType {
| ExtensionType::RequiredCapabilities
| ExtensionType::ExternalPub
| ExtensionType::ExternalSenders
| ExtensionType::LastResort
)
}
}
Expand Down Expand Up @@ -182,12 +191,15 @@ pub enum Extension {
/// A [`RequiredCapabilitiesExtension`]
RequiredCapabilities(RequiredCapabilitiesExtension),

/// A [`ExternalPubExtension`]
/// An [`ExternalPubExtension`]
ExternalPub(ExternalPubExtension),

/// A [`ExternalPubExtension`]
/// An [`ExternalSendersExtension`]
ExternalSenders(ExternalSendersExtension),

/// A [`LastResortExtension`]
LastResort(LastResortExtension),

/// A currently unknown extension.
Unknown(u16, UnknownExtension),
}
Expand Down Expand Up @@ -234,7 +246,8 @@ impl Extensions {

/// Create an extension list with multiple extensions.
///
/// This function will fail when the list of extensions contains duplicate extension types.
/// This function will fail when the list of extensions contains duplicate
/// extension types.
pub fn from_vec(extensions: Vec<Extension>) -> Result<Self, InvalidExtensionError> {
extensions.try_into()
}
Expand All @@ -246,7 +259,8 @@ impl Extensions {

/// Add an extension to the extension list.
///
/// Returns an error when there already is an extension with the same extension type.
/// Returns an error when there already is an extension with the same
/// extension type.
pub fn add(&mut self, extension: Extension) -> Result<(), InvalidExtensionError> {
if self.contains(extension.extension_type()) {
return Err(InvalidExtensionError::Duplicate);
Expand All @@ -268,7 +282,8 @@ impl Extensions {

/// Remove an extension from the extension list.
///
/// Returns the removed extension or `None` when there is no extension with the given extension type.
/// Returns the removed extension or `None` when there is no extension with
/// the given extension type.
pub fn remove(&mut self, extension_type: ExtensionType) -> Option<Extension> {
if let Some(pos) = self
.unique
Expand All @@ -281,7 +296,8 @@ impl Extensions {
}
}

/// Returns `true` iff the extension list contains an extension with the given extension type.
/// Returns `true` iff the extension list contains an extension with the
/// given extension type.
pub fn contains(&self, extension_type: ExtensionType) -> bool {
self.unique
.iter()
Expand Down Expand Up @@ -335,7 +351,8 @@ impl Extensions {
})
}

/// Get a reference to the [`RequiredCapabilitiesExtension`] if there is any.
/// Get a reference to the [`RequiredCapabilitiesExtension`] if there is
/// any.
pub fn required_capabilities(&self) -> Option<&RequiredCapabilitiesExtension> {
self.find_by_type(ExtensionType::RequiredCapabilities)
.and_then(|e| match e {
Expand Down Expand Up @@ -389,8 +406,8 @@ impl Extension {
}

/// Get a reference to this extension as [`RequiredCapabilitiesExtension`].
/// Returns an [`ExtensionError::InvalidExtensionType`] error if called on an
/// [`Extension`] that's not a [`RequiredCapabilitiesExtension`].
/// Returns an [`ExtensionError::InvalidExtensionType`] error if called on
/// an [`Extension`] that's not a [`RequiredCapabilitiesExtension`].
pub fn as_required_capabilities_extension(
&self,
) -> Result<&RequiredCapabilitiesExtension, ExtensionError> {
Expand All @@ -403,8 +420,8 @@ impl Extension {
}

/// Get a reference to this extension as [`ExternalPubExtension`].
/// Returns an [`ExtensionError::InvalidExtensionType`] error if called on an
/// [`Extension`] that's not a [`ExternalPubExtension`].
/// Returns an [`ExtensionError::InvalidExtensionType`] error if called on
/// an [`Extension`] that's not a [`ExternalPubExtension`].
pub fn as_external_pub_extension(&self) -> Result<&ExternalPubExtension, ExtensionError> {
match self {
Self::ExternalPub(e) => Ok(e),
Expand All @@ -415,8 +432,8 @@ impl Extension {
}

/// Get a reference to this extension as [`ExternalSendersExtension`].
/// Returns an [`ExtensionError::InvalidExtensionType`] error if called on an
/// [`Extension`] that's not a [`ExternalSendersExtension`].
/// Returns an [`ExtensionError::InvalidExtensionType`] error if called on
/// an [`Extension`] that's not a [`ExternalSendersExtension`].
pub fn as_external_senders_extension(
&self,
) -> Result<&ExternalSendersExtension, ExtensionError> {
Expand All @@ -437,6 +454,7 @@ impl Extension {
Extension::RequiredCapabilities(_) => ExtensionType::RequiredCapabilities,
Extension::ExternalPub(_) => ExtensionType::ExternalPub,
Extension::ExternalSenders(_) => ExtensionType::ExternalSenders,
Extension::LastResort(_) => ExtensionType::LastResort,
Extension::Unknown(kind, _) => ExtensionType::Unknown(*kind),
}
}
Expand Down
101 changes: 99 additions & 2 deletions openmls/src/extensions/test_extensions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
//! Proper testing is done through the public APIs.
use openmls_rust_crypto::OpenMlsRustCrypto;
use openmls_traits::key_store::OpenMlsKeyStore;
use tls_codec::{Deserialize, Serialize};

use super::*;
use crate::{
credentials::*,
framing::*,
group::{errors::*, *},
group::{config::CryptoConfig, errors::*, *},
key_packages::*,
messages::proposals::ProposalType,
prelude::Capabilities,
schedule::psk::store::ResumptionPskStore,
test_utils::*,
versions::ProtocolVersion,
};

#[test]
Expand Down Expand Up @@ -206,7 +209,8 @@ fn ratchet_tree_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvi

#[test]
fn required_capabilities() {
// A raw required capabilities extension with the default values for openmls (none).
// A raw required capabilities extension with the default values for openmls
// (none).
let extension_bytes = vec![0, 3, 3, 0, 0, 0];

let ext = Extension::RequiredCapabilities(RequiredCapabilitiesExtension::default());
Expand Down Expand Up @@ -243,3 +247,96 @@ fn required_capabilities() {
assert_eq!(ext, ext_decoded);
assert_eq!(extension_bytes, encoded);
}

#[apply(ciphersuites_and_providers)]
fn last_resort_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) {
let last_resort = Extension::LastResort(LastResortExtension::default());

// Build a KeyPackage with a last resort extension
let credential = Credential::new(b"Bob".to_vec(), CredentialType::Basic).unwrap();
let signer =
openmls_basic_credential::SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap();

let extensions = Extensions::single(last_resort);
let crypto_config = CryptoConfig::with_default_version(ciphersuite);
let capabilities = Capabilities::new(
None,
None,
// Add last resort extension as supported extension
Some(&[ExtensionType::LastResort]),
None,
None,
);
let kp = KeyPackage::builder()
.key_package_extensions(extensions)
.leaf_node_capabilities(capabilities)
.build(
crypto_config,
provider,
&signer,
CredentialWithKey {
credential: credential.clone(),
signature_key: signer.to_public_vec().into(),
},
)
.expect("error building key package with last resort extension");
assert!(kp.last_resort());
let encoded_kp = kp
.tls_serialize_detached()
.expect("error encoding key package with last resort extension");
let decoded_kp = KeyPackageIn::tls_deserialize(&mut encoded_kp.as_slice())
.expect("error decoding key package with last resort extension")
.validate(provider.crypto(), ProtocolVersion::default())
.expect("error validating key package with last resort extension");
assert!(decoded_kp.last_resort());

// If we join a group using a last resort KP, it shouldn't be deleted from the
// provider.

let alice_credential_with_key_and_signer = tests::utils::generate_credential_with_key(
"Alice".into(),
ciphersuite.signature_algorithm(),
provider,
);

let mls_group_config = MlsGroupConfigBuilder::new()
.crypto_config(CryptoConfig::with_default_version(ciphersuite))
.build();

// === Alice creates a group ===
let mut alice_group = MlsGroup::new(
provider,
&alice_credential_with_key_and_signer.signer,
&mls_group_config,
alice_credential_with_key_and_signer.credential_with_key,
)
.expect("An unexpected error occurred.");

// === Alice adds Bob ===

let (_message, welcome, _group_info) = alice_group
.add_members(
provider,
&alice_credential_with_key_and_signer.signer,
&[kp.clone()],
)
.expect("An unexpected error occurred.");

alice_group.merge_pending_commit(provider).unwrap();

let _bob_group = MlsGroup::new_from_welcome(
provider,
&mls_group_config,
welcome.into_welcome().expect("Unexpected MLS message"),
Some(alice_group.export_ratchet_tree().into()),
)
.expect("An unexpected error occurred.");

// This should not have deleted the KP from the store
let kp: Option<KeyPackage> = provider.key_store().read(
kp.hash_ref(provider.crypto())
.expect("error hashing kp")
.as_slice(),
);
assert!(kp.is_some());
}
Loading

0 comments on commit d723800

Please sign in to comment.