diff --git a/Cargo.lock b/Cargo.lock index d0e470fd0..6d8a9c3f3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -206,6 +206,17 @@ dependencies = [ "event-listener", ] +[[package]] +name = "async-recursion" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + [[package]] name = "async-stream" version = "0.3.5" @@ -6246,6 +6257,7 @@ dependencies = [ "aes-gcm", "anyhow", "async-barrier", + "async-recursion", "async-stream", "bincode", "criterion", diff --git a/bindings_ffi/Cargo.lock b/bindings_ffi/Cargo.lock index b99ee0775..987768b58 100644 --- a/bindings_ffi/Cargo.lock +++ b/bindings_ffi/Cargo.lock @@ -218,6 +218,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-recursion" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "async-stream" version = "0.3.5" @@ -5689,6 +5700,7 @@ name = "xmtp_mls" version = "0.0.1" dependencies = [ "aes-gcm", + "async-recursion", "async-stream", "bincode", "diesel", diff --git a/bindings_node/Cargo.lock b/bindings_node/Cargo.lock index 3e95fcbd6..e9353d8ea 100644 --- a/bindings_node/Cargo.lock +++ b/bindings_node/Cargo.lock @@ -116,6 +116,17 @@ dependencies = [ "term", ] +[[package]] +name = "async-recursion" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.64", +] + [[package]] name = "async-stream" version = "0.3.5" @@ -5261,6 +5272,7 @@ name = "xmtp_mls" version = "0.0.1" dependencies = [ "aes-gcm", + "async-recursion", "async-stream", "bincode", "diesel", diff --git a/xmtp_mls/Cargo.toml b/xmtp_mls/Cargo.toml index dff1d5bd5..4e83bf993 100644 --- a/xmtp_mls/Cargo.toml +++ b/xmtp_mls/Cargo.toml @@ -72,6 +72,7 @@ tracing-flame = { version = "0.2", optional = true } tracing-subscriber = { workspace = true, optional = true } xmtp_api_grpc = { path = "../xmtp_api_grpc", optional = true } xmtp_api_http = { path = "../xmtp_api_http", optional = true } +async-recursion = "1.1.1" [dev-dependencies] anyhow.workspace = true diff --git a/xmtp_mls/migrations/2024-06-20-061128_key_update/down.sql b/xmtp_mls/migrations/2024-06-20-061128_key_update/down.sql new file mode 100644 index 000000000..d9a93fe9a --- /dev/null +++ b/xmtp_mls/migrations/2024-06-20-061128_key_update/down.sql @@ -0,0 +1 @@ +-- This file should undo anything in `up.sql` diff --git a/xmtp_mls/migrations/2024-06-20-061128_key_update/up.sql b/xmtp_mls/migrations/2024-06-20-061128_key_update/up.sql new file mode 100644 index 000000000..5df66c13c --- /dev/null +++ b/xmtp_mls/migrations/2024-06-20-061128_key_update/up.sql @@ -0,0 +1,2 @@ +ALTER TABLE groups +ADD COLUMN rotated_at_ns BIGINT NOT NULL DEFAULT 0 \ No newline at end of file diff --git a/xmtp_mls/migrations/README.md b/xmtp_mls/migrations/README.md index 6612df812..8983f3297 100644 --- a/xmtp_mls/migrations/README.md +++ b/xmtp_mls/migrations/README.md @@ -6,6 +6,8 @@ cargo install diesel_cli --no-default-features --features sqlite ``` +### Change directory to libxmtp/xmtp_mls/ + ### Create your migration SQL In this example the migration is called `create_key_store`: diff --git a/xmtp_mls/migrations/update-lJTWiV2PuD3QOanK.db3-shm b/xmtp_mls/migrations/update-lJTWiV2PuD3QOanK.db3-shm new file mode 100644 index 000000000..fe9ac2845 Binary files /dev/null and b/xmtp_mls/migrations/update-lJTWiV2PuD3QOanK.db3-shm differ diff --git a/xmtp_mls/migrations/update-lJTWiV2PuD3QOanK.db3-wal b/xmtp_mls/migrations/update-lJTWiV2PuD3QOanK.db3-wal new file mode 100644 index 000000000..e69de29bb diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 6f5bc6246..9e72c79af 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -40,6 +40,7 @@ use xmtp_proto::xmtp::mls::api::v1::{ use crate::{ api::ApiClientWrapper, + configuration::SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS, groups::{ group_permissions::PolicySet, validated_commit::CommitValidationError, GroupError, GroupMetadataOptions, IntentError, MlsGroup, @@ -756,7 +757,11 @@ where ); if mls_group.is_active() { group - .maybe_update_installations(provider_ref, None, self) + .maybe_update_installations( + provider_ref, + SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS, + self, + ) .await?; group.sync_with_conn(provider_ref, self).await?; diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 1bf4775d3..78862853f 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -11,6 +11,7 @@ mod subscriptions; mod sync; pub mod validated_commit; +use async_recursion::async_recursion; use intents::SendMessageIntentData; use openmls::{ credentials::{BasicCredential, CredentialType}, @@ -86,7 +87,7 @@ use crate::{ consent_record::{ConsentState, ConsentType, StoredConsentRecord}, db_connection::DbConnection, group::{GroupMembershipState, Purpose, StoredGroup}, - group_intent::{IntentKind, NewGroupIntent}, + group_intent::{IntentKind, NewGroupIntent, PreIntentComplete, StoredGroupIntent}, group_message::{DeliveryStatus, GroupMessageKind, StoredGroupMessage}, sql_key_store, }, @@ -316,7 +317,7 @@ impl MlsGroup { )?; let group_id = mls_group.group_id().to_vec(); - let stored_group = StoredGroup::new( + let stored_group = StoredGroup::new_as_creator( group_id.clone(), now_ns(), membership_state, @@ -444,8 +445,11 @@ impl MlsGroup { )?; let group_id = mls_group.group_id().to_vec(); - let stored_group = - StoredGroup::new_sync_group(group_id.clone(), now_ns(), GroupMembershipState::Allowed); + let stored_group = StoredGroup::new_sync_group_as_creator( + group_id.clone(), + now_ns(), + GroupMembershipState::Allowed, + ); stored_group.store(provider.conn_ref())?; @@ -465,12 +469,7 @@ impl MlsGroup { where ApiClient: XmtpApi, { - let update_interval_ns = Some(SEND_MESSAGE_UPDATE_INSTALLATIONS_INTERVAL_NS); - let conn = self.context.store.conn()?; - let provider = XmtpOpenMlsProvider::from(conn); - self.maybe_update_installations(&provider, update_interval_ns, client) - .await?; - + let provider = client.mls_provider()?; let message_id = self.prepare_message(message, provider.conn_ref(), |now| { Self::into_envelope(message, now) }); @@ -494,9 +493,6 @@ impl MlsGroup { { let conn = self.context.store.conn()?; let provider = XmtpOpenMlsProvider::from(conn); - let update_interval_ns = Some(SEND_MESSAGE_UPDATE_INSTALLATIONS_INTERVAL_NS); - self.maybe_update_installations(&provider, update_interval_ns, client) - .await?; self.sync_until_last_intent_resolved(&provider, client) .await?; @@ -516,7 +512,7 @@ impl MlsGroup { { let conn = self.context.store.conn()?; let provider = XmtpOpenMlsProvider::from(conn); - self.maybe_update_installations(&provider, Some(0), client) + self.maybe_update_installations(&provider, 0 /*interval_ns*/, client) .await?; Ok(()) } @@ -555,7 +551,16 @@ impl MlsGroup { let intent_data: Vec = SendMessageIntentData::new(encoded_envelope).into(); let intent = NewGroupIntent::new(IntentKind::SendMessage, self.group_id.clone(), intent_data); - intent.store(conn)?; + // TODO(rich) Plumb client, call pre-intent hook without syncing + // Issue: maybe_update_installations is async because it needs a network call to figure out the + // members to include on the intent_data. We can't block prepare_message on a network call. + // Ideas: + // - Figure out the intent_data at time of publish + // - Update installations at time of publish, AFTER intent is inserted (i.e. goes through after message is sent) + // - Make prepare_message/send_message do everything in another thread + // Worth checking: If both key rotation and update intent happen at the same time, will update intent commit + // use wrong encryption at time of publish? + conn.insert_group_intent(intent, PreIntentComplete {})?; // store this unpublished message locally before sending let message_id = calculate_message_id(&self.group_id, message, &now.to_string()); @@ -645,6 +650,7 @@ impl MlsGroup { .await } + // Before calling this function, please verify pre_intent_hook has been called. #[tracing::instrument(level = "trace", skip_all)] pub async fn add_members_by_inbox_id( &self, @@ -664,13 +670,14 @@ impl MlsGroup { return Ok(()); } - let intent = provider - .conn_ref() - .insert_group_intent(NewGroupIntent::new( + let intent = provider.conn_ref().insert_group_intent( + NewGroupIntent::new( IntentKind::UpdateGroupMembership, self.group_id.clone(), intent_data.into(), - ))?; + ), + self.pre_intent_hook(client).await?, + )?; self.sync_until_intent_resolved(&provider, intent.id, client) .await @@ -693,19 +700,19 @@ impl MlsGroup { client: &Client, inbox_ids: Vec, ) -> Result<(), GroupError> { - let provider = client.store().conn()?.into(); - + let provider = client.mls_provider()?; let intent_data = self .get_membership_update_intent(client, &provider, vec![], inbox_ids) .await?; - let intent = provider - .conn_ref() - .insert_group_intent(NewGroupIntent::new( + let intent = provider.conn_ref().insert_group_intent( + NewGroupIntent::new( IntentKind::UpdateGroupMembership, self.group_id.clone(), intent_data.into(), - ))?; + ), + self.pre_intent_hook(client).await?, + )?; self.sync_until_intent_resolved(&provider, intent.id, client) .await @@ -722,11 +729,15 @@ impl MlsGroup { let conn = self.context.store.conn()?; let intent_data: Vec = UpdateMetadataIntentData::new_update_group_name(group_name).into(); - let intent = conn.insert_group_intent(NewGroupIntent::new( - IntentKind::MetadataUpdate, - self.group_id.clone(), - intent_data, - ))?; + + let intent = conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::MetadataUpdate, + self.group_id.clone(), + intent_data, + ), + self.pre_intent_hook(client).await?, + )?; self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await @@ -754,11 +765,14 @@ impl MlsGroup { ) .into(); - let intent = conn.insert_group_intent(NewGroupIntent::new( - IntentKind::UpdatePermission, - self.group_id.clone(), - intent_data, - ))?; + let intent = conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::UpdatePermission, + self.group_id.clone(), + intent_data, + ), + self.pre_intent_hook(client).await?, + )?; self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await @@ -788,11 +802,14 @@ impl MlsGroup { let conn = self.context.store.conn()?; let intent_data: Vec = UpdateMetadataIntentData::new_update_group_description(group_description).into(); - let intent = conn.insert_group_intent(NewGroupIntent::new( - IntentKind::MetadataUpdate, - self.group_id.clone(), - intent_data, - ))?; + let intent = conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::MetadataUpdate, + self.group_id.clone(), + intent_data, + ), + self.pre_intent_hook(client).await?, + )?; self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await @@ -820,14 +837,18 @@ impl MlsGroup { ApiClient: XmtpApi, { let conn = self.context.store.conn()?; + let intent_data: Vec = UpdateMetadataIntentData::new_update_group_image_url_square(group_image_url_square) .into(); - let intent = conn.insert_group_intent(NewGroupIntent::new( - IntentKind::MetadataUpdate, - self.group_id.clone(), - intent_data, - ))?; + let intent = conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::MetadataUpdate, + self.group_id.clone(), + intent_data, + ), + self.pre_intent_hook(client).await?, + )?; self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await @@ -860,11 +881,14 @@ impl MlsGroup { let conn = self.context.store.conn()?; let intent_data: Vec = UpdateMetadataIntentData::new_update_group_pinned_frame_url(pinned_frame_url).into(); - let intent = conn.insert_group_intent(NewGroupIntent::new( - IntentKind::MetadataUpdate, - self.group_id.clone(), - intent_data, - ))?; + let intent = conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::MetadataUpdate, + self.group_id.clone(), + intent_data, + ), + self.pre_intent_hook(client).await?, + )?; self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await @@ -935,11 +959,14 @@ impl MlsGroup { }; let intent_data: Vec = UpdateAdminListIntentData::new(intent_action_type, inbox_id).into(); - let intent = conn.insert_group_intent(NewGroupIntent::new( - IntentKind::UpdateAdminList, - self.group_id.clone(), - intent_data, - ))?; + let intent = conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::UpdateAdminList, + self.group_id.clone(), + intent_data, + ), + self.pre_intent_hook(client).await?, + )?; self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await @@ -980,22 +1007,60 @@ impl MlsGroup { Ok(()) } + fn prepare_key_update(&self, conn: &DbConnection) -> Result { + let intent = conn.insert_group_intent( + NewGroupIntent::new(IntentKind::KeyUpdate, self.group_id.clone(), vec![]), + // This is a private method called from pre_intent_hook, hence why we pass an empty PreIntentComplete + PreIntentComplete {}, + )?; + Ok(intent) + } + // Update this installation's leaf key in the group by creating a key update commit - pub async fn key_update(&self, client: &Client) -> Result<(), GroupError> + async fn key_update(&self, client: &Client) -> Result<(), GroupError> where ApiClient: XmtpApi, { let conn = self.context.store.conn()?; - let intent = conn.insert_group_intent(NewGroupIntent::new( - IntentKind::KeyUpdate, - self.group_id.clone(), - vec![], - ))?; - + let intent = self.prepare_key_update(&conn)?; self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await } + /// Pre-intent hook for group intents. + /// + /// This is used to ensure that the group is in a consistent state before + /// performing an action. This includes rotating the encryption keys if needed, + /// as well as syncing the group's installations. + /// + /// This should be called before any intent is inserted into the database. Infinite + /// loops are avoided by checking when each action was last performed. + #[async_recursion] + pub async fn pre_intent_hook( + &self, + client: &Client, + ) -> Result + where + ApiClient: XmtpApi, + { + let conn = self.context.store.conn()?; + let last_rotated_time = conn.get_rotated_time_checked(self.group_id.clone())?; + if last_rotated_time == 0 { + conn.update_rotated_time_checked(self.group_id.clone())?; + self.key_update(client).await?; + } + + let provider: XmtpOpenMlsProvider = conn.into(); + self.maybe_update_installations( + &provider, + SEND_MESSAGE_UPDATE_INSTALLATIONS_INTERVAL_NS, + client, + ) + .await?; + + Ok(PreIntentComplete {}) + } + pub fn is_active(&self, provider: impl OpenMlsProvider) -> Result { let mls_group = self.load_mls_group(provider)?; Ok(mls_group.is_active()) @@ -1314,8 +1379,14 @@ fn build_group_join_config() -> MlsGroupJoinConfig { #[cfg(test)] mod tests { + use crate::groups::GroupMessageVersion; + use crate::storage::group_intent::PreIntentComplete; use diesel::connection::SimpleConnection; use futures::future::join_all; + use openmls::prelude::tls_codec::Deserialize; + use openmls::prelude::MlsMessageBodyIn; + use openmls::prelude::MlsMessageIn; + use openmls::prelude::ProcessedMessageContent; use openmls::prelude::{tls_codec::Serialize, Member, MlsGroup as OpenMlsGroup}; use prost::Message; use std::sync::Arc; @@ -1805,6 +1876,164 @@ mod tests { assert_eq!(bola_messages.len(), 1); } + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_pre_intent_hook() { + let client_a = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let client_b = ClientBuilder::new_test_client(&generate_local_wallet()).await; + + // client A makes a group with client B. + let group = client_a + .create_group(None, GroupMetadataOptions::default()) + .expect("create group"); + let mut messages = client_a + .api_client + .query_group_messages(group.group_id.clone(), None) + .await + .unwrap(); + assert_eq!(messages.len(), 0); + + group + .add_members_by_inbox_id(&client_a, vec![client_b.inbox_id()]) + .await + .unwrap(); + + // client B creates it from welcome. + let client_b_group = receive_group_invite(&client_b).await; + client_b_group.sync(&client_b).await.unwrap(); + + // verify no new payloads on client A. + messages = client_a + .api_client + .query_group_messages(group.group_id.clone(), None) + .await + .unwrap(); + assert_eq!(messages.len(), 3); + + // call pre_intent_hook on client B. + client_b_group.pre_intent_hook(&client_b).await.unwrap(); + + // Verify client A receives a key rotation payload + messages = client_b + .api_client + .query_group_messages(group.group_id.clone(), None) + .await + .unwrap(); + assert_eq!(messages.len(), 4); + + // steps to get the leaf node of the updated path. + let first_message = &messages[messages.len() - 1]; + + let msgv1 = match &first_message.version { + Some(GroupMessageVersion::V1(value)) => value, + _ => panic!("error msgv1"), + }; + + let mls_message_in = MlsMessageIn::tls_deserialize_exact(&msgv1.data).unwrap(); + let mls_message = match mls_message_in.extract() { + MlsMessageBodyIn::PrivateMessage(mls_message) => mls_message, + _ => panic!("error mls_message"), + }; + + let provider = client_a.mls_provider().unwrap(); + let mut openmls_group = group.load_mls_group(&provider).unwrap(); + let decrypted_message = openmls_group + .process_message(&provider, mls_message) + .unwrap(); + + let staged_commit = match decrypted_message.into_content() { + ProcessedMessageContent::StagedCommitMessage(staged_commit) => *staged_commit, + _ => panic!("error staged_commit"), + }; + + // check there is indeed some updated leaf node, which means the key update works. + let path_update_leaf_node = staged_commit.update_path_leaf_node(); + assert!(path_update_leaf_node.is_some()); + + // call pre_intent_hook on client B again, client A receives nothing new. + client_b_group.pre_intent_hook(&client_b).await.unwrap(); + messages = client_b + .api_client + .query_group_messages(group.group_id.clone(), None) + .await + .unwrap(); + assert_eq!(messages.len(), 4); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_send_message_with_pre_intent_hook() { + let client_a = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let client_b = ClientBuilder::new_test_client(&generate_local_wallet()).await; + + // client A makes a group with client B. + let group = client_a + .create_group(None, GroupMetadataOptions::default()) + .expect("create group"); + group + .add_members_by_inbox_id(&client_a, vec![client_b.inbox_id()]) + .await + .unwrap(); + + // client B creates it from welcome + let client_b_group = receive_group_invite(&client_b).await; + client_b_group.sync(&client_b).await.unwrap(); + + // Verify no new payloads on client A + let mut messages = client_a + .api_client + .query_group_messages(group.group_id.clone(), None) + .await + .unwrap(); + assert_eq!(messages.len(), 3); + + // Client B sends a message to Client A + let b_message = b"hello from client b"; + client_b_group + .send_message(b_message, &client_b) + .await + .expect("send message"); + + // Verify client A receives a key rotation. + messages = client_b + .api_client + .query_group_messages(group.group_id.clone(), None) + .await + .unwrap(); + assert_eq!(messages.len(), 5); + + // Steps to get the leaf node of the updated path. + let queried_message = &messages[messages.len() - 2]; + + let msgv1 = match &queried_message.version { + Some(GroupMessageVersion::V1(value)) => value, + _ => panic!("error msgv1"), + }; + + let mls_message_in = MlsMessageIn::tls_deserialize_exact(&msgv1.data).unwrap(); + let mls_message = match mls_message_in.extract() { + MlsMessageBodyIn::PrivateMessage(mls_message) => mls_message, + _ => panic!("error mls_message"), + }; + + let provider = client_a.mls_provider().unwrap(); + let mut openmls_group = group.load_mls_group(&provider).unwrap(); + let decrypted_message = openmls_group + .process_message(&provider, mls_message) + .unwrap(); + + let staged_commit = match decrypted_message.into_content() { + ProcessedMessageContent::StagedCommitMessage(staged_commit) => *staged_commit, + _ => panic!("error staged_commit"), + }; + + // Check there is indeed some updated leaf node, which means the key update works. + let path_update_leaf_node = staged_commit.update_path_leaf_node(); + assert!(path_update_leaf_node.is_some()); + + // Verify client A receives the message. + let message = get_latest_message(&group, &client_a).await; + assert_eq!(message.decrypted_message_bytes, b_message); + } + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_post_commit() { let client = ClientBuilder::new_test_client(&generate_local_wallet()).await; @@ -3114,11 +3343,14 @@ mod tests { } let conn = provider.conn_ref(); - conn.insert_group_intent(NewGroupIntent::new( - IntentKind::UpdateGroupMembership, - group.group_id.clone(), - intent_data.into(), - )) + conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::UpdateGroupMembership, + group.group_id.clone(), + intent_data.into(), + ), + PreIntentComplete {}, // No pre-intent needed in tests + ) .unwrap(); } diff --git a/xmtp_mls/src/groups/sync.rs b/xmtp_mls/src/groups/sync.rs index c5c10f824..fea575141 100644 --- a/xmtp_mls/src/groups/sync.rs +++ b/xmtp_mls/src/groups/sync.rs @@ -99,8 +99,12 @@ impl MlsGroup { client.inbox_id(), self.load_mls_group(&mls_provider)?.epoch() ); - self.maybe_update_installations(&mls_provider, None, client) - .await?; + self.maybe_update_installations( + &mls_provider, + SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS, + client, + ) + .await?; self.sync_with_conn(&mls_provider, client).await } @@ -200,6 +204,7 @@ impl MlsGroup { } Ok(Some(StoredGroupIntent { id, + kind, state: IntentState::Error, .. })) => { @@ -207,9 +212,9 @@ impl MlsGroup { "not retrying intent ID {id}. since it is in state Error. {:?}", last_err ); - return Err(last_err.unwrap_or(GroupError::Generic( - "Group intent could not be committed".to_string(), - ))); + return Err(last_err.unwrap_or(GroupError::Generic(format!( + "Group intent {id} of kind {kind} could not be committed" + )))); } Ok(Some(StoredGroupIntent { id, state, .. })) => { tracing::warn!("retrying intent ID {id}. intent currently in state {state:?}"); @@ -978,28 +983,22 @@ impl MlsGroup { pub async fn maybe_update_installations( &self, provider: &XmtpOpenMlsProvider, - update_interval_ns: Option, + interval_ns: i64, client: &Client, ) -> Result<(), GroupError> where ApiClient: XmtpApi, { - // determine how long of an interval in time to use before updating list - let interval_ns = match update_interval_ns { - Some(val) => val, - None => SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS, - }; - let now_ns = crate::utils::time::now_ns(); let last_ns = provider .conn_ref() .get_installations_time_checked(self.group_id.clone())?; let elapsed_ns = now_ns - last_ns; if elapsed_ns > interval_ns { - self.add_missing_installations(provider, client).await?; provider .conn_ref() .update_installations_time_checked(self.group_id.clone())?; + self.add_missing_installations(provider, client).await?; } Ok(()) @@ -1033,11 +1032,14 @@ impl MlsGroup { debug!("Adding missing installations {:?}", intent_data); let conn = provider.conn_ref(); - let intent = conn.insert_group_intent(NewGroupIntent::new( - IntentKind::UpdateGroupMembership, - self.group_id.clone(), - intent_data.into(), - ))?; + let intent = conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::UpdateGroupMembership, + self.group_id.clone(), + intent_data.into(), + ), + self.pre_intent_hook(client).await?, + )?; self.sync_until_intent_resolved(provider, intent.id, client) .await @@ -1102,6 +1104,10 @@ impl MlsGroup { Ok(updates) })?; + provider + .conn_ref() + .update_installations_time_checked(self.group_id.clone())?; + Ok(UpdateGroupMembershipIntentData::new( changed_inbox_ids, inbox_ids_to_remove, diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index 1142e7dff..4b7d03569 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -39,6 +39,7 @@ pub struct StoredGroup { pub added_by_inbox_id: String, /// The sequence id of the welcome message pub welcome_id: Option, + pub rotated_at_ns: i64, } impl_fetch!(StoredGroup, groups, Vec); @@ -62,11 +63,12 @@ impl StoredGroup { purpose, added_by_inbox_id, welcome_id: Some(welcome_id), + rotated_at_ns: 0, } } /// Create a new [`Purpose::Conversation`] group. This is the default type of group. - pub fn new( + pub fn new_as_creator( id: ID, created_at_ns: i64, membership_state: GroupMembershipState, @@ -80,12 +82,13 @@ impl StoredGroup { purpose: Purpose::Conversation, added_by_inbox_id, welcome_id: None, + rotated_at_ns: crate::utils::time::now_ns(), } } /// Create a new [`Purpose::Sync`] group. This is less common and is used to sync message history. /// TODO: Set added_by_inbox to your own inbox_id - pub fn new_sync_group( + pub fn new_sync_group_as_creator( id: ID, created_at_ns: i64, membership_state: GroupMembershipState, @@ -98,6 +101,7 @@ impl StoredGroup { purpose: Purpose::Sync, added_by_inbox_id: "".into(), welcome_id: None, + rotated_at_ns: crate::utils::time::now_ns(), } } } @@ -199,6 +203,34 @@ impl DbConnection { ))) } + pub fn get_rotated_time_checked(&self, group_id: Vec) -> Result { + let last_ts = self.raw_query(|conn| { + let ts = dsl::groups + .find(&group_id) + .select(dsl::rotated_at_ns) + .first(conn) + .optional()?; + Ok(ts) + })?; + + last_ts.ok_or(StorageError::NotFound(format!( + "installation time for group {}", + hex::encode(group_id) + ))) + } + + /// Update the 'rotated time' once we checked in pre_intent_hook + pub fn update_rotated_time_checked(&self, group_id: Vec) -> Result<(), StorageError> { + self.raw_query(|conn| { + let now = crate::utils::time::now_ns(); + diesel::update(dsl::groups.find(&group_id)) + .set(dsl::rotated_at_ns.eq(now)) + .execute(conn) + })?; + + Ok(()) + } + /// Updates the 'last time checked' we checked for new installations. pub fn update_installations_time_checked(&self, group_id: Vec) -> Result<(), StorageError> { self.raw_query(|conn| { @@ -331,7 +363,7 @@ pub(crate) mod tests { let id = rand_vec(); let created_at_ns = now_ns(); let membership_state = state.unwrap_or(GroupMembershipState::Allowed); - StoredGroup::new( + StoredGroup::new_as_creator( id, created_at_ns, membership_state, @@ -473,7 +505,8 @@ pub(crate) mod tests { let created_at_ns = now_ns(); let membership_state = GroupMembershipState::Allowed; - let sync_group = StoredGroup::new_sync_group(id, created_at_ns, membership_state); + let sync_group = + StoredGroup::new_sync_group_as_creator(id, created_at_ns, membership_state); let purpose = sync_group.purpose; assert_eq!(purpose, Purpose::Sync); diff --git a/xmtp_mls/src/storage/encrypted_store/group_intent.rs b/xmtp_mls/src/storage/encrypted_store/group_intent.rs index 4ff7615e7..079507d8e 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_intent.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_intent.rs @@ -16,7 +16,7 @@ use super::{ }; use crate::{ groups::{intents::SendMessageIntentData, IntentError}, - impl_fetch, impl_store, + impl_fetch, storage::StorageError, utils::id::calculate_message_id, Delete, @@ -124,6 +124,11 @@ impl Delete for DbConnection { } } +/// A marker struct to indicate that the pre-intent hook has been called. +/// Can be obtained by calling `pre_intent_hook()` on a group. +#[derive(Debug)] +pub struct PreIntentComplete {} + #[derive(Insertable, Debug, PartialEq, Clone)] #[diesel(table_name = group_intents)] pub struct NewGroupIntent { @@ -133,8 +138,6 @@ pub struct NewGroupIntent { pub state: IntentState, } -impl_store!(NewGroupIntent, group_intents); - impl NewGroupIntent { pub fn new(kind: IntentKind, group_id: Vec, data: Vec) -> Self { Self { @@ -151,6 +154,7 @@ impl DbConnection { pub fn insert_group_intent( &self, to_save: NewGroupIntent, + pre_intent_complete: PreIntentComplete, ) -> Result { Ok(self.raw_query(|conn| { diesel::insert_into(dsl::group_intents) @@ -399,7 +403,7 @@ mod tests { }; fn insert_group(conn: &DbConnection, group_id: Vec) { - let group = StoredGroup::new( + let group = StoredGroup::new_as_creator( group_id, 100, GroupMembershipState::Allowed, @@ -448,7 +452,8 @@ mod tests { // Group needs to exist or FK constraint will fail insert_group(conn, group_id.clone()); - to_insert.store(conn).unwrap(); + conn.insert_group_intent(to_insert, PreIntentComplete {}) + .unwrap(); let results = conn .find_group_intents(group_id.clone(), Some(vec![IntentState::ToPublish]), None) @@ -497,7 +502,8 @@ mod tests { insert_group(conn, group_id.clone()); for case in test_intents { - case.store(conn).unwrap(); + conn.insert_group_intent(case, PreIntentComplete {}) + .unwrap(); } // Can query for multiple states @@ -553,12 +559,14 @@ mod tests { insert_group(conn, group_id.clone()); // Store the intent - NewGroupIntent::new( - IntentKind::UpdateGroupMembership, - group_id.clone(), - rand_vec(), + conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::UpdateGroupMembership, + group_id.clone(), + rand_vec(), + ), + PreIntentComplete {}, ) - .store(conn) .unwrap(); // Find the intent with the ID populated @@ -594,12 +602,14 @@ mod tests { insert_group(conn, group_id.clone()); // Store the intent - NewGroupIntent::new( - IntentKind::UpdateGroupMembership, - group_id.clone(), - rand_vec(), + conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::UpdateGroupMembership, + group_id.clone(), + rand_vec(), + ), + PreIntentComplete {}, ) - .store(conn) .unwrap(); let mut intent = find_first_intent(conn, group_id.clone()); @@ -638,12 +648,14 @@ mod tests { insert_group(conn, group_id.clone()); // Store the intent - NewGroupIntent::new( - IntentKind::UpdateGroupMembership, - group_id.clone(), - rand_vec(), + conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::UpdateGroupMembership, + group_id.clone(), + rand_vec(), + ), + PreIntentComplete {}, ) - .store(conn) .unwrap(); let mut intent = find_first_intent(conn, group_id.clone()); @@ -681,12 +693,14 @@ mod tests { insert_group(conn, group_id.clone()); // Store the intent - NewGroupIntent::new( - IntentKind::UpdateGroupMembership, - group_id.clone(), - rand_vec(), + conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::UpdateGroupMembership, + group_id.clone(), + rand_vec(), + ), + PreIntentComplete {}, ) - .store(conn) .unwrap(); let intent = find_first_intent(conn, group_id.clone()); @@ -712,12 +726,14 @@ mod tests { let group_id = rand_vec(); with_connection(|conn| { insert_group(conn, group_id.clone()); - NewGroupIntent::new( - IntentKind::UpdateGroupMembership, - group_id.clone(), - rand_vec(), + conn.insert_group_intent( + NewGroupIntent::new( + IntentKind::UpdateGroupMembership, + group_id.clone(), + rand_vec(), + ), + PreIntentComplete {}, ) - .store(conn) .unwrap(); let mut intent = find_first_intent(conn, group_id.clone()); diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index 8cd5baf9e..169ee427f 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -676,7 +676,7 @@ mod tests { barrier.wait(); let result = store_pointer.transaction(|provider| -> Result<(), anyhow::Error> { let connection = provider.conn_ref(); - let group = StoredGroup::new( + let group = StoredGroup::new_as_creator( b"should not exist".to_vec(), 0, GroupMembershipState::Allowed, @@ -727,7 +727,7 @@ mod tests { .store(conn1) .unwrap(); - let group = StoredGroup::new( + let group = StoredGroup::new_as_creator( b"should not exist".to_vec(), 0, GroupMembershipState::Allowed, diff --git a/xmtp_mls/src/storage/encrypted_store/schema.rs b/xmtp_mls/src/storage/encrypted_store/schema.rs index 7d835a5b9..24f9bed88 100644 --- a/xmtp_mls/src/storage/encrypted_store/schema.rs +++ b/xmtp_mls/src/storage/encrypted_store/schema.rs @@ -53,6 +53,7 @@ diesel::table! { purpose -> Integer, added_by_inbox_id -> Text, welcome_id -> Nullable, + rotated_at_ns -> BigInt, } }