Skip to content

Commit

Permalink
pass back mlsgroup before we rollback transaction so that we can retu…
Browse files Browse the repository at this point in the history
…rn groups on duplicate welcome messages
  • Loading branch information
cameronvoell committed Jun 21, 2024
1 parent 43e21e0 commit 70414d9
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 32 deletions.
8 changes: 4 additions & 4 deletions bindings_ffi/src/mls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1514,9 +1514,9 @@ mod tests {
.await
.unwrap();
alix_group.send("hello1".as_bytes().to_vec()).await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
// tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
bo.conversations().sync().await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
// tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;

let bo_groups = bo
.conversations()
Expand Down Expand Up @@ -1550,8 +1550,8 @@ mod tests {
.unwrap();
assert_eq!(bo_messages2.len(), second_msg_check);

// tokio::time::sleep(tokio::time::Duration::from_millis(10000)).await;
// assert_eq!(message_callbacks.message_count(), 5);
tokio::time::sleep(tokio::time::Duration::from_millis(2000)).await;
assert_eq!(message_callbacks.message_count(), 5);

stream_messages.end();
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
Expand Down
10 changes: 10 additions & 0 deletions bindings_node/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 8 additions & 1 deletion xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,14 @@ where
)
.await
{
Ok(mls_group) => Ok(Some(mls_group)),
Ok(create_from_welcome_result) => {
if create_from_welcome_result.requires_rollback {
return Err(MessageProcessingError::WelcomeProcessing(
"failed to create group from welcome".to_string(),
));
}
Ok(Some(create_from_welcome_result.group))
}
Err(err) => {
log::error!("failed to create group from welcome: {}", err);
Err(MessageProcessingError::WelcomeProcessing(err.to_string()))
Expand Down
27 changes: 18 additions & 9 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ pub enum UpdateAdminListType {
RemoveSuper,
}

pub struct CreateFromWelcomeResult {
pub group: MlsGroup,
pub requires_rollback: bool,
}

impl MlsGroup {
// Creates a new group instance. Does not validate that the group exists in the DB
pub fn new(context: Arc<XmtpMlsLocalContext>, group_id: Vec<u8>, created_at_ns: i64) -> Self {
Expand Down Expand Up @@ -295,7 +300,7 @@ impl MlsGroup {
welcome: MlsWelcome,
added_by_inbox: String,
welcome_id: i64,
) -> Result<Self, GroupError> {
) -> Result<CreateFromWelcomeResult, GroupError> {
let mls_welcome =
StagedWelcome::new_from_welcome(provider, &build_group_join_config(), welcome, None)?;

Expand Down Expand Up @@ -325,13 +330,17 @@ impl MlsGroup {

validate_initial_group_membership(client, provider.conn_ref(), &mls_group).await?;

let stored_group = provider.conn().insert_or_replace_group(to_store)?;
let insert_result = provider.conn().insert_or_replace_group(to_store)?;
let stored_group = insert_result.group;

Ok(Self::new(
client.context.clone(),
stored_group.id,
stored_group.created_at_ns,
))
Ok(CreateFromWelcomeResult {
group: Self::new(
client.context.clone(),
stored_group.id,
stored_group.created_at_ns,
),
requires_rollback: insert_result.requires_rollback,
})
}

// Decrypt a welcome message using HPKE and then create and save a group from the stored message
Expand All @@ -341,14 +350,14 @@ impl MlsGroup {
hpke_public_key: &[u8],
encrypted_welcome_bytes: Vec<u8>,
welcome_id: i64,
) -> Result<Self, GroupError> {
) -> Result<CreateFromWelcomeResult, GroupError> {
let welcome_bytes = decrypt_welcome(provider, hpke_public_key, &encrypted_welcome_bytes)?;

let welcome = deserialize_welcome(&welcome_bytes)?;

let join_config = build_group_join_config();

let processed_welcome =
let processed_welcome: ProcessedWelcome =
ProcessedWelcome::new_from_welcome(provider, &join_config, welcome.clone())?;
let psks = processed_welcome.psks();
if !psks.is_empty() {
Expand Down
40 changes: 32 additions & 8 deletions xmtp_mls/src/storage/encrypted_store/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ pub struct StoredGroup {
pub welcome_id: Option<i64>,
}

pub struct InsertOrReplaceGroupResult {
pub group: StoredGroup,
pub requires_rollback: bool,
}

impl_fetch!(StoredGroup, groups, Vec<u8>);
impl_store!(StoredGroup, groups);

Expand Down Expand Up @@ -196,7 +201,10 @@ impl DbConnection {
Ok(())
}

pub fn insert_or_replace_group(&self, group: StoredGroup) -> Result<StoredGroup, StorageError> {
pub fn insert_or_replace_group(
&self,
group: StoredGroup,
) -> Result<InsertOrReplaceGroupResult, StorageError> {
let stored_group = self.raw_query(|conn| {
let maybe_inserted_group: Option<StoredGroup> = diesel::insert_into(dsl::groups)
.values(&group)
Expand All @@ -208,18 +216,34 @@ impl DbConnection {
let existing_group: StoredGroup = dsl::groups.find(group.id).first(conn).unwrap();
if existing_group.welcome_id == group.welcome_id {
// Error so OpenMLS db transaction are rolled back on duplicate welcomes
return Err(diesel::result::Error::DatabaseError(
diesel::result::DatabaseErrorKind::UniqueViolation,
Box::new("welcome id already exists".to_string()),
));
return Ok(InsertOrReplaceGroupResult {
group: existing_group,
requires_rollback: true,
});
// return Err(diesel::result::Error::DatabaseError(
// diesel::result::DatabaseErrorKind::UniqueViolation,
// Box::new("welcome id already exists".to_string()),
// ));
} else {
return Ok(existing_group);
return Ok(InsertOrReplaceGroupResult {
group: existing_group,
requires_rollback: false,
});
}
}

match maybe_inserted_group {
Some(group) => Ok(group),
None => dsl::groups.find(group.id).first(conn),
Some(group) => Ok(InsertOrReplaceGroupResult {
group,
requires_rollback: false,
}),
None => {
let group = dsl::groups.find(group.id).first(conn)?;
Ok(InsertOrReplaceGroupResult {
group,
requires_rollback: false,
})
}
}
})?;

Expand Down
54 changes: 44 additions & 10 deletions xmtp_mls/src/subscriptions.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use futures::lock::Mutex as AsyncMutex;
use std::{
collections::HashMap,
pin::Pin,
Expand Down Expand Up @@ -58,23 +59,56 @@ where
welcome: WelcomeMessage,
) -> Result<MlsGroup, ClientError> {
let welcome_v1 = extract_welcome_message(welcome)?;
let failed_due_to_rollback = Arc::new(AsyncMutex::new(false));
let existing_group = Arc::new(AsyncMutex::new(None));

let failed_due_to_rollback_clone = Arc::clone(&failed_due_to_rollback);
let existing_group_clone = Arc::clone(&existing_group);

let creation_result = self
.context
.store
.transaction_async(|provider| async move {
MlsGroup::create_from_encrypted_welcome(
self,
&provider,
welcome_v1.hpke_public_key.as_slice(),
welcome_v1.data,
welcome_v1.id as i64,
)
.await
.transaction_async(|provider| {
// let self_ref = self_ref;
let welcome_v1 = welcome_v1.clone();
println!("welcome_v1 ID is THIS!: {:?}", welcome_v1.id);
let failed_due_to_rollback_clone = Arc::clone(&failed_due_to_rollback_clone);
let existing_group_clone = Arc::clone(&existing_group_clone);

async move {
let create_from_welcome_result = MlsGroup::create_from_encrypted_welcome(
self,
&provider,
welcome_v1.hpke_public_key.as_slice(),
welcome_v1.data,
welcome_v1.id as i64,
)
.await;

if let Ok(create_from_welcome_result) = create_from_welcome_result {
if create_from_welcome_result.requires_rollback {
*failed_due_to_rollback_clone.lock().await = true;
*existing_group_clone.lock().await =
Some(create_from_welcome_result.group);
return Err(ClientError::Generic(
"failed to create group from welcome".to_string(),
));
}
Ok(create_from_welcome_result.group)
} else {
Err(ClientError::Generic(
"failed to create group from welcome".to_string(),
))
}
}
})
.await;

if let Some(err) = creation_result.as_ref().err() {
if let Err(err) = creation_result {
if *failed_due_to_rollback.lock().await {
log::info!("returning existing group");
return Ok(existing_group.lock().await.take().unwrap());
}
return Err(ClientError::Generic(err.to_string()));
}

Expand Down

0 comments on commit 70414d9

Please sign in to comment.