Skip to content

Commit

Permalink
Update to New BasicCredential API
Browse files Browse the repository at this point in the history
  • Loading branch information
zombieobject committed Mar 21, 2024
1 parent 82bb5b1 commit bc911c4
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 18 deletions.
3 changes: 3 additions & 0 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{collections::HashSet, mem::Discriminant};

use openmls:: {
credentials::errors::BasicCredentialError,
framing::{MlsMessageBodyIn, MlsMessageIn},
group:: GroupEpoch,
messages::Welcome,
Expand Down Expand Up @@ -120,6 +121,8 @@ pub enum MessageProcessingError {
EpochIncrementNotAllowed,
#[error("Welcome processing error: {0}")]
WelcomeProcessing(String),
#[error("wrong credential type")]
WrongCredentialType(#[from] BasicCredentialError),
}

impl crate::retry::RetryableError for MessageProcessingError {
Expand Down
9 changes: 7 additions & 2 deletions xmtp_mls/src/groups/members.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use std::collections::HashMap;

use openmls::group::MlsGroup as OpenMlsGroup;
use openmls::{
credentials::BasicCredential,
group::MlsGroup as OpenMlsGroup
};

use xmtp_proto::api_client::XmtpMlsClient;

use super::{GroupError, MlsGroup};
Expand Down Expand Up @@ -37,8 +41,9 @@ pub fn aggregate_member_list(openmls_group: &OpenMlsGroup) -> Result<Vec<GroupMe
let member_map: HashMap<String, GroupMember> = openmls_group
.members()
.filter_map(|member| {
let basic_credential = BasicCredential::try_from(&member.credential).unwrap();
Identity::get_validated_account_address(
member.credential.identity(),
basic_credential.identity(),
&member.signature_key,
)
.ok()
Expand Down
16 changes: 14 additions & 2 deletions xmtp_mls/src/groups/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ use std::{collections::HashMap, mem::discriminant};

use log::debug;
use openmls::{
credentials::{
BasicCredential,
errors::BasicCredentialError,
},
framing::ProtocolMessage,
group::MergePendingCommitError,
prelude::{
Expand Down Expand Up @@ -773,8 +777,12 @@ fn validate_message_sender(
if let Sender::Member(leaf_node_index) = decrypted_message.sender() {
if let Some(member) = openmls_group.member_at(*leaf_node_index) {
if member.credential.eq(decrypted_message.credential()) {
let basic_credential =
BasicCredential::try_from(&member.credential)
.map_err(|_| BasicCredentialError::WrongCredentialType)?;

sender_account_address = Identity::get_validated_account_address(
member.credential.identity(),
basic_credential.identity(),
&member.signature_key,
)
.ok();
Expand All @@ -784,9 +792,13 @@ fn validate_message_sender(
}

if sender_account_address.is_none() {
let basic_credential =
BasicCredential::try_from(decrypted_message.credential())
.map_err(|_| BasicCredentialError::WrongCredentialType)?;

return Err(MessageProcessingError::InvalidSender {
message_time_ns: message_created_ns,
credential: decrypted_message.credential().identity().to_vec(),
credential: basic_credential.identity().to_vec(),
});
}
Ok((
Expand Down
20 changes: 16 additions & 4 deletions xmtp_mls/src/groups/validated_commit.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::HashMap;

use openmls::{
credentials::CredentialType,
credentials::{BasicCredential, errors::BasicCredentialError, CredentialType},
group::{QueuedAddProposal, QueuedRemoveProposal},
prelude::{LeafNodeIndex, MlsGroup as OpenMlsGroup, Sender, StagedCommit},
};
Expand Down Expand Up @@ -47,6 +47,8 @@ pub enum CommitValidationError {
GroupMetadata(#[from] GroupMetadataError),
#[error("invalid application id")]
InvalidApplicationId,
#[error("wrong credential type")]
WrongCredentialType(#[from] BasicCredentialError),
}

// A participant in a commit. Could be the actor or the subject of a proposal
Expand Down Expand Up @@ -164,8 +166,13 @@ fn extract_actor(
) -> Result<CommitParticipant, CommitValidationError> {
if let Some(leaf_node) = group.member_at(leaf_index) {
let signature_key = leaf_node.signature_key.as_slice();

let basic_credential =
BasicCredential::try_from(&leaf_node.credential)
.map_err(|_| BasicCredentialError::WrongCredentialType)?;

let account_address =
Identity::get_validated_account_address(leaf_node.credential.identity(), signature_key)
Identity::get_validated_account_address(basic_credential.identity(), signature_key)
.map_err(|_| CommitValidationError::InvalidActorCredential)?;

let is_creator = account_address.eq(&group_metadata.creator_account_address);
Expand Down Expand Up @@ -215,8 +222,13 @@ fn extract_identity_from_remove(

if let Some(member) = group.member_at(leaf_index) {
let signature_key = member.signature_key.as_slice();

let basic_credential =
BasicCredential::try_from(&member.credential)
.map_err(|_| BasicCredentialError::WrongCredentialType)?;

let account_address =
Identity::get_validated_account_address(member.credential.identity(), signature_key)
Identity::get_validated_account_address(basic_credential.identity(), signature_key)
.map_err(|_| CommitValidationError::InvalidSubjectCredential)?;
let is_creator = account_address.eq(&group_metadata.creator_account_address);

Expand Down Expand Up @@ -503,7 +515,7 @@ mod tests {
&bola.identity.installation_keys,
CredentialWithKey {
// Broken credential
credential: Credential::new(vec![1, 2, 3], CredentialType::Basic).unwrap(),
credential: Credential::new(CredentialType::Basic, vec![1, 2, 3]),
signature_key: bola.identity.installation_keys.to_public_vec().into(),
},
)
Expand Down
31 changes: 21 additions & 10 deletions xmtp_mls/src/verified_key_package.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
use openmls::prelude::{
KeyPackage, KeyPackageIn, KeyPackageVerifyError,
use openmls::{
credentials::{
BasicCredential, errors::BasicCredentialError
},
prelude::{
tls_codec::{
Deserialize, Error as TlsSerializationError
},
};
Deserialize, Error as TlsSerializationError
},
KeyPackage, KeyPackageIn, KeyPackageVerifyError
}
};

use openmls_rust_crypto::RustCrypto;
use thiserror::Error;

use crate::{
configuration::MLS_PROTOCOL_VERSION,
identity::{Identity, IdentityError},
types::Address,
configuration::MLS_PROTOCOL_VERSION, identity::{Identity, IdentityError}, types::Address
};

#[derive(Debug, Error)]
Expand All @@ -26,10 +29,14 @@ pub enum KeyPackageVerificationError {
InvalidApplicationId,
#[error("application id ({0}) does not match the credential address ({1}).")]
ApplicationIdCredentialMismatch(String, String),
#[error("invalid credential")]
InvalidCredential,
#[error("invalid lifetime")]
InvalidLifetime,
#[error("generic: {0}")]
Generic(String),
#[error("wrong credential type")]
WrongCredentialType(#[from] BasicCredentialError),
}

#[derive(Debug, Clone, PartialEq)]
Expand All @@ -49,9 +56,13 @@ impl VerifiedKeyPackage {
// Validates starting with a KeyPackage (which is already validated by OpenMLS)
pub fn from_key_package(kp: KeyPackage) -> Result<Self, KeyPackageVerificationError> {
let leaf_node = kp.leaf_node();
let identity_bytes = leaf_node.credential().identity();

let basic_credential =
BasicCredential::try_from(leaf_node.credential())
.map_err(|_| BasicCredentialError::WrongCredentialType)?;

let pub_key_bytes = leaf_node.signature_key().as_slice();
let account_address = identity_to_account_address(identity_bytes, pub_key_bytes)?;
let account_address = identity_to_account_address(basic_credential.identity(), pub_key_bytes)?;
let application_id = extract_application_id(&kp)?;
if !account_address.eq(&application_id) {
return Err(
Expand Down

0 comments on commit bc911c4

Please sign in to comment.