diff --git a/Cargo.lock b/Cargo.lock index d0e470fd0..6d8a9c3f3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -206,6 +206,17 @@ dependencies = [ "event-listener", ] +[[package]] +name = "async-recursion" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + [[package]] name = "async-stream" version = "0.3.5" @@ -6246,6 +6257,7 @@ dependencies = [ "aes-gcm", "anyhow", "async-barrier", + "async-recursion", "async-stream", "bincode", "criterion", diff --git a/bindings_ffi/Cargo.lock b/bindings_ffi/Cargo.lock index b99ee0775..987768b58 100644 --- a/bindings_ffi/Cargo.lock +++ b/bindings_ffi/Cargo.lock @@ -218,6 +218,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-recursion" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "async-stream" version = "0.3.5" @@ -5689,6 +5700,7 @@ name = "xmtp_mls" version = "0.0.1" dependencies = [ "aes-gcm", + "async-recursion", "async-stream", "bincode", "diesel", diff --git a/bindings_node/Cargo.lock b/bindings_node/Cargo.lock index 3e95fcbd6..e9353d8ea 100644 --- a/bindings_node/Cargo.lock +++ b/bindings_node/Cargo.lock @@ -116,6 +116,17 @@ dependencies = [ "term", ] +[[package]] +name = "async-recursion" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.64", +] + [[package]] name = "async-stream" version = "0.3.5" @@ -5261,6 +5272,7 @@ name = "xmtp_mls" version = "0.0.1" dependencies = [ "aes-gcm", + "async-recursion", "async-stream", "bincode", "diesel", diff --git a/xmtp_mls/Cargo.toml b/xmtp_mls/Cargo.toml index dff1d5bd5..4e83bf993 100644 --- a/xmtp_mls/Cargo.toml +++ b/xmtp_mls/Cargo.toml @@ -72,6 +72,7 @@ tracing-flame = { version = "0.2", optional = true } tracing-subscriber = { workspace = true, optional = true } xmtp_api_grpc = { path = "../xmtp_api_grpc", optional = true } xmtp_api_http = { path = "../xmtp_api_http", optional = true } +async-recursion = "1.1.1" [dev-dependencies] anyhow.workspace = true diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 9f5621460..76dfe2d66 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -11,6 +11,7 @@ mod subscriptions; mod sync; pub mod validated_commit; +use async_recursion::async_recursion; use intents::SendMessageIntentData; use openmls::{ credentials::{BasicCredential, CredentialType}, @@ -86,7 +87,7 @@ use crate::{ consent_record::{ConsentState, ConsentType, StoredConsentRecord}, db_connection::DbConnection, group::{GroupMembershipState, Purpose, StoredGroup}, - group_intent::{IntentKind, NewGroupIntent}, + group_intent::{IntentKind, NewGroupIntent, PreIntentComplete}, group_message::{DeliveryStatus, GroupMessageKind, StoredGroupMessage}, sql_key_store, }, @@ -468,7 +469,6 @@ impl MlsGroup { where ApiClient: XmtpApi, { - self.pre_intent_hook(client).await?; let provider = client.mls_provider()?; let message_id = self.prepare_message(message, provider.conn_ref(), |now| { Self::into_envelope(message, now) @@ -554,7 +554,8 @@ impl MlsGroup { let intent_data: Vec = SendMessageIntentData::new(encoded_envelope).into(); let intent = NewGroupIntent::new(IntentKind::SendMessage, self.group_id.clone(), intent_data); - intent.store(conn)?; + // TODO(rich) Plumb client, call pre-intent hook without syncing + conn.insert_group_intent(intent, PreIntentComplete {})?; // store this unpublished message locally before sending let message_id = calculate_message_id(&self.group_id, message, &now.to_string()); @@ -651,7 +652,6 @@ impl MlsGroup { client: &Client, inbox_ids: Vec, ) -> Result<(), GroupError> { - self.pre_intent_hook(client).await?; let provider = client.mls_provider()?; let intent_data = self .get_membership_update_intent(client, &provider, inbox_ids, vec![]) @@ -665,13 +665,14 @@ impl MlsGroup { return Ok(()); } - let intent = provider - .conn_ref() - .insert_group_intent(NewGroupIntent::new( + let intent = provider.conn_ref().insert_group_intent( + NewGroupIntent::new( IntentKind::UpdateGroupMembership, self.group_id.clone(), intent_data.into(), - ))?; + ), + self.pre_intent_hook(client).await?, + )?; self.sync_until_intent_resolved(&provider, intent.id, client) .await @@ -694,19 +695,19 @@ impl MlsGroup { client: &Client, inbox_ids: Vec, ) -> Result<(), GroupError> { - self.pre_intent_hook(client).await?; let provider = client.mls_provider()?; let intent_data = self .get_membership_update_intent(client, &provider, vec![], inbox_ids) .await?; - let intent = provider - .conn_ref() - .insert_group_intent(NewGroupIntent::new( + let intent = provider.conn_ref().insert_group_intent( + NewGroupIntent::new( IntentKind::UpdateGroupMembership, self.group_id.clone(), intent_data.into(), - ))?; + ), + self.pre_intent_hook(client).await?, + )?; self.sync_until_intent_resolved(&provider, intent.id, client) .await @@ -724,13 +725,14 @@ impl MlsGroup { let intent_data: Vec = UpdateMetadataIntentData::new_update_group_name(group_name).into(); - self.pre_intent_hook(client).await?; - - let intent = conn.insert_group_intent(NewGroupIntent::new( - IntentKind::MetadataUpdate, - self.group_id.clone(), - intent_data, - ))?; + let intent = conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::MetadataUpdate, + self.group_id.clone(), + intent_data, + ), + self.pre_intent_hook(client).await?, + )?; self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await @@ -758,11 +760,14 @@ impl MlsGroup { ) .into(); - let intent = conn.insert_group_intent(NewGroupIntent::new( - IntentKind::UpdatePermission, - self.group_id.clone(), - intent_data, - ))?; + let intent = conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::UpdatePermission, + self.group_id.clone(), + intent_data, + ), + self.pre_intent_hook(client).await?, + )?; self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await @@ -792,11 +797,14 @@ impl MlsGroup { let conn = self.context.store.conn()?; let intent_data: Vec = UpdateMetadataIntentData::new_update_group_description(group_description).into(); - let intent = conn.insert_group_intent(NewGroupIntent::new( - IntentKind::MetadataUpdate, - self.group_id.clone(), - intent_data, - ))?; + let intent = conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::MetadataUpdate, + self.group_id.clone(), + intent_data, + ), + self.pre_intent_hook(client).await?, + )?; self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await @@ -825,15 +833,17 @@ impl MlsGroup { { let conn = self.context.store.conn()?; - self.pre_intent_hook(client).await?; let intent_data: Vec = UpdateMetadataIntentData::new_update_group_image_url_square(group_image_url_square) .into(); - let intent = conn.insert_group_intent(NewGroupIntent::new( - IntentKind::MetadataUpdate, - self.group_id.clone(), - intent_data, - ))?; + let intent = conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::MetadataUpdate, + self.group_id.clone(), + intent_data, + ), + self.pre_intent_hook(client).await?, + )?; self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await @@ -866,11 +876,14 @@ impl MlsGroup { let conn = self.context.store.conn()?; let intent_data: Vec = UpdateMetadataIntentData::new_update_group_pinned_frame_url(pinned_frame_url).into(); - let intent = conn.insert_group_intent(NewGroupIntent::new( - IntentKind::MetadataUpdate, - self.group_id.clone(), - intent_data, - ))?; + let intent = conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::MetadataUpdate, + self.group_id.clone(), + intent_data, + ), + self.pre_intent_hook(client).await?, + )?; self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await @@ -941,12 +954,14 @@ impl MlsGroup { }; let intent_data: Vec = UpdateAdminListIntentData::new(intent_action_type, inbox_id).into(); - self.pre_intent_hook(client).await?; - let intent = conn.insert_group_intent(NewGroupIntent::new( - IntentKind::UpdateAdminList, - self.group_id.clone(), - intent_data, - ))?; + let intent = conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::UpdateAdminList, + self.group_id.clone(), + intent_data, + ), + self.pre_intent_hook(client).await?, + )?; self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await @@ -988,41 +1003,50 @@ impl MlsGroup { } // Update this installation's leaf key in the group by creating a key update commit - pub async fn key_update(&self, client: &Client) -> Result<(), GroupError> + async fn key_update(&self, client: &Client) -> Result<(), GroupError> where ApiClient: XmtpApi, { let conn = self.context.store.conn()?; - let intent = conn.insert_group_intent(NewGroupIntent::new( - IntentKind::KeyUpdate, - self.group_id.clone(), - vec![], - ))?; + let intent = conn.insert_group_intent( + NewGroupIntent::new(IntentKind::KeyUpdate, self.group_id.clone(), vec![]), + // This is a private method called from pre_intent_hook, hence why we pass an empty PreIntentComplete + PreIntentComplete {}, + )?; self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await } - /// Checking the last key rotation time before rotating the key. + /// Pre-intent hook for group intents. + /// + /// This is used to ensure that the group is in a consistent state before + /// performing an action. This includes rotating the encryption keys if needed, + /// as well as syncing the group's installations. + /// + /// This should be called before any intent is inserted into the database. Infinite + /// loops are avoided by checking when each action was last performed. + #[async_recursion] pub async fn pre_intent_hook( &self, client: &Client, - ) -> Result<(), GroupError> + ) -> Result where ApiClient: XmtpApi, { let conn = self.context.store.conn()?; let last_rotated_time = conn.get_rotated_time_checked(self.group_id.clone())?; - if last_rotated_time == 0 { conn.update_rotated_time_checked(self.group_id.clone())?; self.key_update(client).await?; } + let provider: XmtpOpenMlsProvider = conn.into(); let update_interval_ns = Some(SEND_MESSAGE_UPDATE_INSTALLATIONS_INTERVAL_NS); self.maybe_update_installations(&provider, update_interval_ns, client) .await?; - Ok(()) + + Ok(PreIntentComplete {}) } pub fn is_active(&self, provider: impl OpenMlsProvider) -> Result { @@ -1344,6 +1368,7 @@ fn build_group_join_config() -> MlsGroupJoinConfig { #[cfg(test)] mod tests { use crate::groups::GroupMessageVersion; + use crate::storage::group_intent::PreIntentComplete; use diesel::connection::SimpleConnection; use futures::future::join_all; use openmls::prelude::tls_codec::Deserialize; @@ -1354,7 +1379,6 @@ mod tests { use prost::Message; use std::sync::Arc; use xmtp_cryptography::utils::generate_local_wallet; - use xmtp_proto::xmtp::mls::api::v1::GroupMessage; use xmtp_proto::xmtp::mls::message_contents::EncodedContent; use crate::{ @@ -3307,11 +3331,14 @@ mod tests { } let conn = provider.conn_ref(); - conn.insert_group_intent(NewGroupIntent::new( - IntentKind::UpdateGroupMembership, - group.group_id.clone(), - intent_data.into(), - )) + conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::UpdateGroupMembership, + group.group_id.clone(), + intent_data.into(), + ), + PreIntentComplete {}, // TODO(rich) Allow pre-intent hook to run without syncing + ) .unwrap(); } diff --git a/xmtp_mls/src/groups/sync.rs b/xmtp_mls/src/groups/sync.rs index eda0f25e2..6794a2248 100644 --- a/xmtp_mls/src/groups/sync.rs +++ b/xmtp_mls/src/groups/sync.rs @@ -1034,11 +1034,14 @@ impl MlsGroup { debug!("Adding missing installations {:?}", intent_data); let conn = provider.conn_ref(); - let intent = conn.insert_group_intent(NewGroupIntent::new( - IntentKind::UpdateGroupMembership, - self.group_id.clone(), - intent_data.into(), - ))?; + let intent = conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::UpdateGroupMembership, + self.group_id.clone(), + intent_data.into(), + ), + self.pre_intent_hook(client).await?, + )?; self.sync_until_intent_resolved(provider, intent.id, client) .await @@ -1103,6 +1106,10 @@ impl MlsGroup { Ok(updates) })?; + provider + .conn_ref() + .update_installations_time_checked(self.group_id.clone())?; + Ok(UpdateGroupMembershipIntentData::new( changed_inbox_ids, inbox_ids_to_remove, diff --git a/xmtp_mls/src/storage/encrypted_store/group_intent.rs b/xmtp_mls/src/storage/encrypted_store/group_intent.rs index 746d6b1b5..079507d8e 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_intent.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_intent.rs @@ -16,7 +16,7 @@ use super::{ }; use crate::{ groups::{intents::SendMessageIntentData, IntentError}, - impl_fetch, impl_store, + impl_fetch, storage::StorageError, utils::id::calculate_message_id, Delete, @@ -124,6 +124,11 @@ impl Delete for DbConnection { } } +/// A marker struct to indicate that the pre-intent hook has been called. +/// Can be obtained by calling `pre_intent_hook()` on a group. +#[derive(Debug)] +pub struct PreIntentComplete {} + #[derive(Insertable, Debug, PartialEq, Clone)] #[diesel(table_name = group_intents)] pub struct NewGroupIntent { @@ -133,8 +138,6 @@ pub struct NewGroupIntent { pub state: IntentState, } -impl_store!(NewGroupIntent, group_intents); - impl NewGroupIntent { pub fn new(kind: IntentKind, group_id: Vec, data: Vec) -> Self { Self { @@ -151,6 +154,7 @@ impl DbConnection { pub fn insert_group_intent( &self, to_save: NewGroupIntent, + pre_intent_complete: PreIntentComplete, ) -> Result { Ok(self.raw_query(|conn| { diesel::insert_into(dsl::group_intents) @@ -448,7 +452,8 @@ mod tests { // Group needs to exist or FK constraint will fail insert_group(conn, group_id.clone()); - to_insert.store(conn).unwrap(); + conn.insert_group_intent(to_insert, PreIntentComplete {}) + .unwrap(); let results = conn .find_group_intents(group_id.clone(), Some(vec![IntentState::ToPublish]), None) @@ -497,7 +502,8 @@ mod tests { insert_group(conn, group_id.clone()); for case in test_intents { - case.store(conn).unwrap(); + conn.insert_group_intent(case, PreIntentComplete {}) + .unwrap(); } // Can query for multiple states @@ -553,12 +559,14 @@ mod tests { insert_group(conn, group_id.clone()); // Store the intent - NewGroupIntent::new( - IntentKind::UpdateGroupMembership, - group_id.clone(), - rand_vec(), + conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::UpdateGroupMembership, + group_id.clone(), + rand_vec(), + ), + PreIntentComplete {}, ) - .store(conn) .unwrap(); // Find the intent with the ID populated @@ -594,12 +602,14 @@ mod tests { insert_group(conn, group_id.clone()); // Store the intent - NewGroupIntent::new( - IntentKind::UpdateGroupMembership, - group_id.clone(), - rand_vec(), + conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::UpdateGroupMembership, + group_id.clone(), + rand_vec(), + ), + PreIntentComplete {}, ) - .store(conn) .unwrap(); let mut intent = find_first_intent(conn, group_id.clone()); @@ -638,12 +648,14 @@ mod tests { insert_group(conn, group_id.clone()); // Store the intent - NewGroupIntent::new( - IntentKind::UpdateGroupMembership, - group_id.clone(), - rand_vec(), + conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::UpdateGroupMembership, + group_id.clone(), + rand_vec(), + ), + PreIntentComplete {}, ) - .store(conn) .unwrap(); let mut intent = find_first_intent(conn, group_id.clone()); @@ -681,12 +693,14 @@ mod tests { insert_group(conn, group_id.clone()); // Store the intent - NewGroupIntent::new( - IntentKind::UpdateGroupMembership, - group_id.clone(), - rand_vec(), + conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::UpdateGroupMembership, + group_id.clone(), + rand_vec(), + ), + PreIntentComplete {}, ) - .store(conn) .unwrap(); let intent = find_first_intent(conn, group_id.clone()); @@ -712,12 +726,14 @@ mod tests { let group_id = rand_vec(); with_connection(|conn| { insert_group(conn, group_id.clone()); - NewGroupIntent::new( - IntentKind::UpdateGroupMembership, - group_id.clone(), - rand_vec(), + conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::UpdateGroupMembership, + group_id.clone(), + rand_vec(), + ), + PreIntentComplete {}, ) - .store(conn) .unwrap(); let mut intent = find_first_intent(conn, group_id.clone());