diff --git a/xmtp_mls/src/groups/intents.rs b/xmtp_mls/src/groups/intents.rs index c9b587fc3..021e218c5 100644 --- a/xmtp_mls/src/groups/intents.rs +++ b/xmtp_mls/src/groups/intents.rs @@ -5,8 +5,10 @@ use tls_codec::Serialize; use xmtp_proto::xmtp::mls::database::{ add_members_publish_data::{Version as AddMembersVersion, V1 as AddMembersV1}, post_commit_action::{Kind as PostCommitActionKind, SendWelcomes as SendWelcomesProto}, + remove_members_publish_data::{Version as RemoveMembersVersion, V1 as RemoveMembersV1}, send_message_publish_data::{Version as SendMessageVersion, V1 as SendMessageV1}, - AddMembersPublishData, PostCommitAction as PostCommitActionProto, SendMessagePublishData, + AddMembersPublishData, PostCommitAction as PostCommitActionProto, RemoveMembersPublishData, + SendMessagePublishData, }; use crate::{ @@ -90,7 +92,7 @@ impl AddMembersIntentData { })), } .encode(&mut buf) - .unwrap(); + .expect("encode error"); Ok(buf) } @@ -123,6 +125,47 @@ impl TryFrom for Vec { } } +#[derive(Debug, Clone)] +pub struct RemoveMembersIntentData { + pub installation_ids: Vec>, +} + +impl RemoveMembersIntentData { + pub fn new(installation_ids: Vec>) -> Self { + Self { installation_ids } + } + + pub(crate) fn to_bytes(&self) -> Vec { + let mut buf = Vec::new(); + + RemoveMembersPublishData { + version: Some(RemoveMembersVersion::V1(RemoveMembersV1 { + installation_ids: self.installation_ids.clone(), + })), + } + .encode(&mut buf) + .expect("encode error"); + + buf + } + + pub(crate) fn from_bytes(data: &[u8]) -> Result { + let msg = RemoveMembersPublishData::decode(data)?; + let installation_ids = match msg.version { + Some(RemoveMembersVersion::V1(v1)) => v1.installation_ids, + None => return Err(IntentError::Generic("missing payload".to_string())), + }; + + Ok(Self::new(installation_ids)) + } +} + +impl From for Vec { + fn from(intent: RemoveMembersIntentData) -> Self { + intent.to_bytes() + } +} + #[derive(Debug, Clone)] pub enum PostCommitAction { SendWelcomes(SendWelcomesAction), diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index e5d6e7017..a45039b90 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -3,8 +3,8 @@ mod intents; use intents::SendMessageIntentData; use openmls::{ prelude::{ - CredentialWithKey, CryptoConfig, GroupId, MlsGroup as OpenMlsGroup, MlsGroupConfig, - WireFormatPolicy, + CredentialWithKey, CryptoConfig, GroupId, LeafNodeIndex, MlsGroup as OpenMlsGroup, + MlsGroupConfig, WireFormatPolicy, }, prelude_test::KeyPackage, }; @@ -13,7 +13,7 @@ use thiserror::Error; use tls_codec::Serialize; use xmtp_proto::api_client::{XmtpApiClient, XmtpMlsClient}; -use self::intents::{AddMembersIntentData, IntentError, PostCommitAction}; +use self::intents::{AddMembersIntentData, IntentError, PostCommitAction, RemoveMembersIntentData}; use crate::{ client::ClientError, configuration::CIPHERSUITE, @@ -43,6 +43,8 @@ pub enum GroupError { TlsSerialization(#[from] tls_codec::Error), #[error("add members: {0}")] AddMembers(#[from] openmls::prelude::AddMembersError), + #[error("remove members: {0}")] + RemoveMembers(#[from] openmls::prelude::RemoveMembersError), #[error("group create: {0}")] GroupCreate(#[from] openmls::prelude::NewGroupError), #[error("client: {0}")] @@ -131,6 +133,24 @@ where Ok(()) } + pub async fn remove_members_by_installation_id( + &self, + installation_ids: Vec>, + ) -> Result<(), GroupError> { + let mut conn = self.client.store.conn()?; + let intent_data: Vec = RemoveMembersIntentData::new(installation_ids).into(); + let intent = NewGroupIntent::new( + IntentKind::RemoveMembers, + self.group_id.clone(), + intent_data, + ); + intent.store(&mut conn)?; + + self.publish_intents(&mut conn).await?; + + Ok(()) + } + pub(crate) async fn publish_intents(&self, conn: &mut DbConnection) -> Result<(), GroupError> { let provider = self.client.mls_provider(); let mut openmls_group = self.load_mls_group(&provider)?; @@ -153,15 +173,17 @@ where } let (payload, post_commit_data) = result.expect("result already checked"); + let payload_slice = payload.as_slice(); + self.client .api_client - .publish_to_group(vec![payload.as_slice()]) + .publish_to_group(vec![payload_slice]) .await?; self.client.store.set_group_intent_published( conn, intent.id, - sha256(payload.as_slice()), + sha256(payload_slice), post_commit_data, )?; } @@ -223,6 +245,36 @@ where Ok((commit_bytes, post_commit_data)) } + IntentKind::RemoveMembers => { + let intent_data = RemoveMembersIntentData::from_bytes(intent.data.as_slice())?; + let leaf_nodes: Vec = openmls_group + .members() + .filter(|member| intent_data.installation_ids.contains(&member.signature_key)) + .map(|member| member.index) + .collect(); + + let num_leaf_nodes = leaf_nodes.len(); + + if num_leaf_nodes != intent_data.installation_ids.len() { + return Err(GroupError::Generic(format!( + "expected {} leaf nodes, found {}", + intent_data.installation_ids.len(), + num_leaf_nodes + ))); + } + + // The second return value is a Welcome, which is only possible if there + // are pending proposals. Ignoring for now + let (commit, _, _) = openmls_group.remove_members( + provider, + &self.client.identity.installation_keys, + leaf_nodes.as_slice(), + )?; + + let commit_bytes = commit.tls_serialize_detached()?; + + Ok((commit_bytes, None)) + } _ => Err(GroupError::Generic("invalid intent kind".to_string())), } } @@ -243,6 +295,7 @@ fn build_group_config() -> MlsGroupConfig { #[cfg(test)] mod tests { + use openmls_traits::OpenMlsProvider; use xmtp_cryptography::utils::generate_local_wallet; use crate::builder::ClientBuilder; @@ -302,4 +355,59 @@ mod tests { assert!(result.is_err()); } + + #[tokio::test] + async fn test_remove_member() { + let client_1 = ClientBuilder::new_test_client(generate_local_wallet().into()).await; + // Add another client onto the network + let client_2 = ClientBuilder::new_test_client(generate_local_wallet().into()).await; + client_2.register_identity().await.unwrap(); + + let provider = client_1.mls_provider(); + let group = client_1.create_group().expect("create group"); + group + .add_members_by_installation_id(vec![client_2 + .identity + .installation_keys + .to_public_vec()]) + .await + .expect("group create failure"); + + // Try and add another member without merging the pending commit + group + .remove_members_by_installation_id(vec![client_2 + .identity + .installation_keys + .to_public_vec()]) + .await + .expect("group create failure"); + + // We are expecting 1 message on the group topic, not 2, because the second one should have + // failed + let topic = group.topic(); + let messages = client_1 + .api_client + .read_topic(topic.as_str(), 0) + .await + .expect("read topic"); + + assert_eq!(messages.len(), 1); + // Now merge the commit and try again + let mut mls_group = group.load_mls_group(&provider).unwrap(); + mls_group.merge_pending_commit(&provider).unwrap(); + mls_group.save(provider.key_store()).unwrap(); + + group + .publish_intents(&mut client_1.store.conn().unwrap()) + .await + .unwrap(); + + let messages_after_second_try = client_1 + .api_client + .read_topic(topic.as_str(), 0) + .await + .expect("read topic"); + + assert_eq!(messages_after_second_try.len(), 2) + } }