Skip to content

Commit

Permalink
Use StagedWelcome + Other Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
zombieobject committed Mar 26, 2024
1 parent ca6b63d commit 95b65ef
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 23 deletions.
37 changes: 19 additions & 18 deletions mls_validation_service/src/handlers.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use openmls::prelude::{MlsMessageIn, ProtocolMessage, TlsDeserializeTrait};
use openmls::{credentials::BasicCredential, prelude::{tls_codec::Deserialize, MlsMessageIn, ProtocolMessage}};
use openmls_rust_crypto::RustCrypto;
use tonic::{Request, Response, Status};

Expand Down Expand Up @@ -86,9 +86,11 @@ struct ValidateGroupMessageResult {

fn validate_group_message(message: Vec<u8>) -> Result<ValidateGroupMessageResult, String> {
let msg_result = MlsMessageIn::tls_deserialize(&mut message.as_slice())
.map_err(|_| "failed to decode".to_string())?;
.map_err(|e| e.to_string())?;

let protocol_message: ProtocolMessage = msg_result.into();
// EM: Fix Error Handling
let protocol_message: ProtocolMessage = msg_result.try_into_protocol_message()
.map_err(|e| e.to_string())?;

Ok(ValidateGroupMessageResult {
group_id: serialize_group_id(protocol_message.group_id().as_slice()),
Expand All @@ -108,15 +110,18 @@ fn validate_key_package(key_package_bytes: Vec<u8>) -> Result<ValidateKeyPackage
VerifiedKeyPackage::from_bytes(&rust_crypto, key_package_bytes.as_slice())
.map_err(|e| e.to_string())?;

let credential = verified_key_package
.inner
.leaf_node()
.credential();

let basic_credential = BasicCredential::try_from(credential)
.map_err(|e| e.to_string())?;

Ok(ValidateKeyPackageResult {
installation_id: verified_key_package.installation_id(),
account_address: verified_key_package.account_address,
credential_identity_bytes: verified_key_package
.inner
.leaf_node()
.credential()
.identity()
.to_vec(),
credential_identity_bytes: basic_credential.identity().to_vec(),
expiration: verified_key_package.inner.life_time().not_after(),
})
}
Expand All @@ -125,13 +130,9 @@ fn validate_key_package(key_package_bytes: Vec<u8>) -> Result<ValidateKeyPackage
mod tests {
use ethers::signers::LocalWallet;
use openmls::{
extensions::{ApplicationIdExtension, Extension, Extensions},
prelude::{
Ciphersuite, Credential as OpenMlsCredential, CredentialType, CredentialWithKey,
CryptoConfig, TlsSerializeTrait,
},
prelude_test::KeyPackage,
versions::ProtocolVersion,
extensions::{ApplicationIdExtension, Extension, Extensions}, prelude::{
tls_codec::Serialize, Ciphersuite, Credential as OpenMlsCredential, CredentialType, CredentialWithKey, CryptoConfig
}, prelude_test::KeyPackage, versions::ProtocolVersion
};
use openmls_basic_credential::SignatureKeyPair;
use openmls_rust_crypto::OpenMlsRustCrypto;
Expand Down Expand Up @@ -192,7 +193,7 @@ mod tests {
async fn test_validate_key_packages_happy_path() {
let (identity, keypair, account_address) = generate_identity();

let credential = OpenMlsCredential::new(identity, CredentialType::Basic).unwrap();
let credential = OpenMlsCredential::new(CredentialType::Basic, identity);
let credential_with_key = CredentialWithKey {
credential,
signature_key: keypair.to_public_vec().into(),
Expand Down Expand Up @@ -222,7 +223,7 @@ mod tests {
let (identity, keypair, account_address) = generate_identity();
let (_, other_keypair, _) = generate_identity();

let credential = OpenMlsCredential::new(identity, CredentialType::Basic).unwrap();
let credential = OpenMlsCredential::new(CredentialType::Basic, identity);
let credential_with_key = CredentialWithKey {
credential,
// Use the wrong signature key to make the validation fail
Expand Down
15 changes: 10 additions & 5 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use crate::{
use intents::SendMessageIntentData;
use openmls::{
extensions::{Extension, Extensions, Metadata},
group::{MlsGroupCreateConfig, MlsGroupJoinConfig},
group::{MlsGroupCreateConfig, MlsGroupJoinConfig, StagedWelcome},
prelude::{
CredentialWithKey, CryptoConfig, GroupId, MlsGroup as OpenMlsGroup, Welcome as MlsWelcome,
WireFormatPolicy,
Expand Down Expand Up @@ -204,10 +204,15 @@ where
provider: &XmtpOpenMlsProvider,
welcome: MlsWelcome,
) -> Result<Self, GroupError> {
let mut mls_group =
OpenMlsGroup::new_from_welcome(provider, &build_group_join_config(), welcome, None)?;
mls_group.save(provider.key_store())?;
let group_id = mls_group.group_id().to_vec();
let mut mls_welcome =
// EM: Fix error handling here
StagedWelcome::new_from_welcome(provider, &build_group_join_config(), welcome, None)
.expect("Error creating staged join from Welcome")
.into_group(provider)
.expect("Error creating group from staged join");

mls_welcome.save(provider.key_store())?;
let group_id = mls_welcome.group_id().to_vec();

let to_store = StoredGroup::new(group_id, now_ns(), GroupMembershipState::Pending);
let stored_group = provider.conn().insert_or_ignore_group(to_store)?;
Expand Down

0 comments on commit 95b65ef

Please sign in to comment.