From 70414d947c35b8ae61b06cd9129157104aef76d8 Mon Sep 17 00:00:00 2001 From: cameronvoell Date: Thu, 20 Jun 2024 19:00:05 -0700 Subject: [PATCH] pass back mlsgroup before we rollback transaction so that we can return groups on duplicate welcome messages --- bindings_ffi/src/mls.rs | 8 +-- bindings_node/Cargo.lock | 10 ++++ xmtp_mls/src/client.rs | 9 +++- xmtp_mls/src/groups/mod.rs | 27 ++++++---- xmtp_mls/src/storage/encrypted_store/group.rs | 40 +++++++++++--- xmtp_mls/src/subscriptions.rs | 54 +++++++++++++++---- 6 files changed, 116 insertions(+), 32 deletions(-) diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs index c4700f4ed..94d43b21e 100644 --- a/bindings_ffi/src/mls.rs +++ b/bindings_ffi/src/mls.rs @@ -1514,9 +1514,9 @@ mod tests { .await .unwrap(); alix_group.send("hello1".as_bytes().to_vec()).await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; + // tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; bo.conversations().sync().await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; + // tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; let bo_groups = bo .conversations() @@ -1550,8 +1550,8 @@ mod tests { .unwrap(); assert_eq!(bo_messages2.len(), second_msg_check); - // tokio::time::sleep(tokio::time::Duration::from_millis(10000)).await; - // assert_eq!(message_callbacks.message_count(), 5); + tokio::time::sleep(tokio::time::Duration::from_millis(2000)).await; + assert_eq!(message_callbacks.message_count(), 5); stream_messages.end(); tokio::time::sleep(tokio::time::Duration::from_millis(5)).await; diff --git a/bindings_node/Cargo.lock b/bindings_node/Cargo.lock index 5b989fca4..8d712d2ab 100644 --- a/bindings_node/Cargo.lock +++ b/bindings_node/Cargo.lock @@ -285,6 +285,15 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d86b93f97252c47b41663388e6d155714a9d0c398b99f1005cbc5f978b29f445" +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "bindings_node" version = "0.0.1" @@ -5418,6 +5427,7 @@ dependencies = [ "aes", "aes-gcm", "async-trait", + "bincode", "chrono", "diesel", "diesel_migrations", diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 9d4ecb8e6..7703c0b9e 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -487,7 +487,14 @@ where ) .await { - Ok(mls_group) => Ok(Some(mls_group)), + Ok(create_from_welcome_result) => { + if create_from_welcome_result.requires_rollback { + return Err(MessageProcessingError::WelcomeProcessing( + "failed to create group from welcome".to_string(), + )); + } + Ok(Some(create_from_welcome_result.group)) + } Err(err) => { log::error!("failed to create group from welcome: {}", err); Err(MessageProcessingError::WelcomeProcessing(err.to_string())) diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 17f4e435a..c84c91690 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -218,6 +218,11 @@ pub enum UpdateAdminListType { RemoveSuper, } +pub struct CreateFromWelcomeResult { + pub group: MlsGroup, + pub requires_rollback: bool, +} + impl MlsGroup { // Creates a new group instance. Does not validate that the group exists in the DB pub fn new(context: Arc, group_id: Vec, created_at_ns: i64) -> Self { @@ -295,7 +300,7 @@ impl MlsGroup { welcome: MlsWelcome, added_by_inbox: String, welcome_id: i64, - ) -> Result { + ) -> Result { let mls_welcome = StagedWelcome::new_from_welcome(provider, &build_group_join_config(), welcome, None)?; @@ -325,13 +330,17 @@ impl MlsGroup { validate_initial_group_membership(client, provider.conn_ref(), &mls_group).await?; - let stored_group = provider.conn().insert_or_replace_group(to_store)?; + let insert_result = provider.conn().insert_or_replace_group(to_store)?; + let stored_group = insert_result.group; - Ok(Self::new( - client.context.clone(), - stored_group.id, - stored_group.created_at_ns, - )) + Ok(CreateFromWelcomeResult { + group: Self::new( + client.context.clone(), + stored_group.id, + stored_group.created_at_ns, + ), + requires_rollback: insert_result.requires_rollback, + }) } // Decrypt a welcome message using HPKE and then create and save a group from the stored message @@ -341,14 +350,14 @@ impl MlsGroup { hpke_public_key: &[u8], encrypted_welcome_bytes: Vec, welcome_id: i64, - ) -> Result { + ) -> Result { let welcome_bytes = decrypt_welcome(provider, hpke_public_key, &encrypted_welcome_bytes)?; let welcome = deserialize_welcome(&welcome_bytes)?; let join_config = build_group_join_config(); - let processed_welcome = + let processed_welcome: ProcessedWelcome = ProcessedWelcome::new_from_welcome(provider, &join_config, welcome.clone())?; let psks = processed_welcome.psks(); if !psks.is_empty() { diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index b0615c2cf..6f3d8a7d7 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -41,6 +41,11 @@ pub struct StoredGroup { pub welcome_id: Option, } +pub struct InsertOrReplaceGroupResult { + pub group: StoredGroup, + pub requires_rollback: bool, +} + impl_fetch!(StoredGroup, groups, Vec); impl_store!(StoredGroup, groups); @@ -196,7 +201,10 @@ impl DbConnection { Ok(()) } - pub fn insert_or_replace_group(&self, group: StoredGroup) -> Result { + pub fn insert_or_replace_group( + &self, + group: StoredGroup, + ) -> Result { let stored_group = self.raw_query(|conn| { let maybe_inserted_group: Option = diesel::insert_into(dsl::groups) .values(&group) @@ -208,18 +216,34 @@ impl DbConnection { let existing_group: StoredGroup = dsl::groups.find(group.id).first(conn).unwrap(); if existing_group.welcome_id == group.welcome_id { // Error so OpenMLS db transaction are rolled back on duplicate welcomes - return Err(diesel::result::Error::DatabaseError( - diesel::result::DatabaseErrorKind::UniqueViolation, - Box::new("welcome id already exists".to_string()), - )); + return Ok(InsertOrReplaceGroupResult { + group: existing_group, + requires_rollback: true, + }); + // return Err(diesel::result::Error::DatabaseError( + // diesel::result::DatabaseErrorKind::UniqueViolation, + // Box::new("welcome id already exists".to_string()), + // )); } else { - return Ok(existing_group); + return Ok(InsertOrReplaceGroupResult { + group: existing_group, + requires_rollback: false, + }); } } match maybe_inserted_group { - Some(group) => Ok(group), - None => dsl::groups.find(group.id).first(conn), + Some(group) => Ok(InsertOrReplaceGroupResult { + group, + requires_rollback: false, + }), + None => { + let group = dsl::groups.find(group.id).first(conn)?; + Ok(InsertOrReplaceGroupResult { + group, + requires_rollback: false, + }) + } } })?; diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index c990510ff..3957008ea 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -1,3 +1,4 @@ +use futures::lock::Mutex as AsyncMutex; use std::{ collections::HashMap, pin::Pin, @@ -58,23 +59,56 @@ where welcome: WelcomeMessage, ) -> Result { let welcome_v1 = extract_welcome_message(welcome)?; + let failed_due_to_rollback = Arc::new(AsyncMutex::new(false)); + let existing_group = Arc::new(AsyncMutex::new(None)); + + let failed_due_to_rollback_clone = Arc::clone(&failed_due_to_rollback); + let existing_group_clone = Arc::clone(&existing_group); let creation_result = self .context .store - .transaction_async(|provider| async move { - MlsGroup::create_from_encrypted_welcome( - self, - &provider, - welcome_v1.hpke_public_key.as_slice(), - welcome_v1.data, - welcome_v1.id as i64, - ) - .await + .transaction_async(|provider| { + // let self_ref = self_ref; + let welcome_v1 = welcome_v1.clone(); + println!("welcome_v1 ID is THIS!: {:?}", welcome_v1.id); + let failed_due_to_rollback_clone = Arc::clone(&failed_due_to_rollback_clone); + let existing_group_clone = Arc::clone(&existing_group_clone); + + async move { + let create_from_welcome_result = MlsGroup::create_from_encrypted_welcome( + self, + &provider, + welcome_v1.hpke_public_key.as_slice(), + welcome_v1.data, + welcome_v1.id as i64, + ) + .await; + + if let Ok(create_from_welcome_result) = create_from_welcome_result { + if create_from_welcome_result.requires_rollback { + *failed_due_to_rollback_clone.lock().await = true; + *existing_group_clone.lock().await = + Some(create_from_welcome_result.group); + return Err(ClientError::Generic( + "failed to create group from welcome".to_string(), + )); + } + Ok(create_from_welcome_result.group) + } else { + Err(ClientError::Generic( + "failed to create group from welcome".to_string(), + )) + } + } }) .await; - if let Some(err) = creation_result.as_ref().err() { + if let Err(err) = creation_result { + if *failed_due_to_rollback.lock().await { + log::info!("returning existing group"); + return Ok(existing_group.lock().await.take().unwrap()); + } return Err(ClientError::Generic(err.to_string())); }