Skip to content

Commit

Permalink
fix(group): make MLS group thread safe #1349 (#1404)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Andrew Plaza <[email protected]>
Co-authored-by: Dakota Brink <github@codabrink>
Co-authored-by: Ry Racherbaumer <[email protected]>
  • Loading branch information
4 people authored Dec 17, 2024
1 parent 5e0761f commit b5f3237
Show file tree
Hide file tree
Showing 16 changed files with 813 additions and 616 deletions.
59 changes: 50 additions & 9 deletions bindings_ffi/src/mls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1317,13 +1317,13 @@ impl FfiConversation {
Ok(())
}

pub fn find_messages(
pub async fn find_messages(
&self,
opts: FfiListMessagesOptions,
) -> Result<Vec<FfiMessage>, GenericError> {
let delivery_status = opts.delivery_status.map(|status| status.into());
let direction = opts.direction.map(|dir| dir.into());
let kind = match self.conversation_type()? {
let kind = match self.conversation_type().await? {
FfiConversationType::Group => None,
FfiConversationType::Dm => Some(GroupMessageKind::Application),
FfiConversationType::Sync => None,
Expand Down Expand Up @@ -1445,7 +1445,7 @@ impl FfiConversation {

pub fn group_image_url_square(&self) -> Result<String, GenericError> {
let provider = self.inner.mls_provider()?;
Ok(self.inner.group_image_url_square(provider)?)
Ok(self.inner.group_image_url_square(&provider)?)
}

pub async fn update_group_description(
Expand All @@ -1461,7 +1461,7 @@ impl FfiConversation {

pub fn group_description(&self) -> Result<String, GenericError> {
let provider = self.inner.mls_provider()?;
Ok(self.inner.group_description(provider)?)
Ok(self.inner.group_description(&provider)?)
}

pub async fn update_group_pinned_frame_url(
Expand Down Expand Up @@ -1593,9 +1593,9 @@ impl FfiConversation {
self.inner.added_by_inbox_id().map_err(Into::into)
}

pub fn group_metadata(&self) -> Result<Arc<FfiConversationMetadata>, GenericError> {
pub async fn group_metadata(&self) -> Result<Arc<FfiConversationMetadata>, GenericError> {
let provider = self.inner.mls_provider()?;
let metadata = self.inner.metadata(provider)?;
let metadata = self.inner.metadata(&provider).await?;
Ok(Arc::new(FfiConversationMetadata {
inner: Arc::new(metadata),
}))
Expand All @@ -1605,9 +1605,9 @@ impl FfiConversation {
self.inner.dm_inbox_id().map_err(Into::into)
}

pub fn conversation_type(&self) -> Result<FfiConversationType, GenericError> {
pub async fn conversation_type(&self) -> Result<FfiConversationType, GenericError> {
let provider = self.inner.mls_provider()?;
let conversation_type = self.inner.conversation_type(&provider)?;
let conversation_type = self.inner.conversation_type(&provider).await?;
Ok(conversation_type.into())
}
}
Expand Down Expand Up @@ -2104,6 +2104,9 @@ mod tests {
.await
.unwrap();

let conn = client.inner_client.context().store().conn().unwrap();
conn.register_triggers();

register_client(&ffi_inbox_owner, &client).await;
client
}
Expand Down Expand Up @@ -2595,6 +2598,8 @@ mod tests {
async fn test_can_stream_group_messages_for_updates() {
let alix = new_test_client().await;
let bo = new_test_client().await;
let alix_provider = alix.inner_client.mls_provider().unwrap();
let bo_provider = bo.inner_client.mls_provider().unwrap();

// Stream all group messages
let message_callbacks = Arc::new(RustStreamCallback::default());
Expand Down Expand Up @@ -2627,21 +2632,29 @@ mod tests {
.unwrap();
let bo_group = &bo_groups[0];
bo_group.sync().await.unwrap();

// alix published + processed group creation and name update
assert_eq!(alix_provider.conn_ref().intents_published(), 2);
assert_eq!(alix_provider.conn_ref().intents_deleted(), 2);

bo_group
.update_group_name("Old Name2".to_string())
.await
.unwrap();
message_callbacks.wait_for_delivery(None).await.unwrap();
assert_eq!(bo_provider.conn_ref().intents_published(), 1);

alix_group.send(b"Hello there".to_vec()).await.unwrap();
message_callbacks.wait_for_delivery(None).await.unwrap();
assert_eq!(alix_provider.conn_ref().intents_published(), 3);

let dm = bo
.conversations()
.create_dm(alix.account_address.clone())
.await
.unwrap();
dm.send(b"Hello again".to_vec()).await.unwrap();
assert_eq!(bo_provider.conn_ref().intents_published(), 3);
message_callbacks.wait_for_delivery(None).await.unwrap();

// Uncomment the following lines to add more group name updates
Expand All @@ -2650,6 +2663,8 @@ mod tests {
.await
.unwrap();
message_callbacks.wait_for_delivery(None).await.unwrap();
message_callbacks.wait_for_delivery(None).await.unwrap();
assert_eq!(bo_provider.conn_ref().intents_published(), 4);

assert_eq!(message_callbacks.message_count(), 6);

Expand Down Expand Up @@ -2693,9 +2708,11 @@ mod tests {

let bo_messages1 = bo_group1
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
let bo_messages5 = bo_group5
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
assert_eq!(bo_messages1.len(), 0);
assert_eq!(bo_messages5.len(), 0);
Expand All @@ -2707,9 +2724,11 @@ mod tests {

let bo_messages1 = bo_group1
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
let bo_messages5 = bo_group5
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
assert_eq!(bo_messages1.len(), 1);
assert_eq!(bo_messages5.len(), 1);
Expand Down Expand Up @@ -2828,11 +2847,13 @@ mod tests {
alix_group.sync().await.unwrap();
let alix_messages = alix_group
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();

bo_group.sync().await.unwrap();
let bo_messages = bo_group
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
assert_eq!(bo_messages.len(), 9);
assert_eq!(alix_messages.len(), 10);
Expand Down Expand Up @@ -3016,15 +3037,19 @@ mod tests {
// Get the message count for all the clients
let caro_messages = caro_group
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
let alix_messages = alix_group
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
let bo_messages = bo_group
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
let bo2_messages = bo2_group
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();

assert_eq!(caro_messages.len(), 5);
Expand Down Expand Up @@ -3080,9 +3105,11 @@ mod tests {

let alix_messages = alix_group
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
let bo_messages = bo_group
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();

let alix_can_see_bo_message = alix_messages
Expand Down Expand Up @@ -3189,6 +3216,7 @@ mod tests {

let bo_messages = bo_group
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
assert_eq!(bo_messages.len(), 0);

Expand All @@ -3204,8 +3232,12 @@ mod tests {

let bo_messages = bo_group
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
assert!(bo_messages.first().unwrap().kind == FfiConversationMessageKind::MembershipChange);
assert_eq!(
bo_messages.first().unwrap().kind,
FfiConversationMessageKind::MembershipChange
);
assert_eq!(bo_messages.len(), 1);

let bo_members = bo_group.list_members().await.unwrap();
Expand Down Expand Up @@ -3263,6 +3295,7 @@ mod tests {

let bo_messages1 = bo_group
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
assert_eq!(bo_messages1.len(), first_msg_check);

Expand All @@ -3275,6 +3308,7 @@ mod tests {

let alix_messages = alix_group
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
assert_eq!(alix_messages.len(), second_msg_check);

Expand All @@ -3284,6 +3318,7 @@ mod tests {

let bo_messages2 = bo_group
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
assert_eq!(bo_messages2.len(), second_msg_check);
assert_eq!(message_callbacks.message_count(), second_msg_check as u32);
Expand Down Expand Up @@ -4529,15 +4564,19 @@ mod tests {
// Get messages for both participants in both conversations
let alix_dm_messages = alix_dm
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
let bo_dm_messages = bo_dm
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
let alix_group_messages = alix_group
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
let bo_group_messages = bo_group
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();

// Verify DM messages
Expand Down Expand Up @@ -4658,13 +4697,15 @@ mod tests {
.await
.unwrap()[0]
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
let bo_dm_messages = client_b
.conversations()
.list(FfiListConversationsOptions::default())
.await
.unwrap()[0]
.find_messages(FfiListMessagesOptions::default())
.await
.unwrap();
assert_eq!(alix_dm_messages[0].content, "Hello in DM".as_bytes());
assert_eq!(bo_dm_messages[0].content, "Hello in DM".as_bytes());
Expand Down
22 changes: 12 additions & 10 deletions bindings_node/src/conversation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ impl Conversation {
}

#[napi]
pub fn find_messages(&self, opts: Option<ListMessagesOptions>) -> Result<Vec<Message>> {
pub async fn find_messages(&self, opts: Option<ListMessagesOptions>) -> Result<Vec<Message>> {
let opts = opts.unwrap_or_default();
let group = MlsGroup::new(
self.inner_client.clone(),
Expand All @@ -171,6 +171,7 @@ impl Conversation {
let provider = group.mls_provider().map_err(ErrorWrapper::from)?;
let conversation_type = group
.conversation_type(&provider)
.await
.map_err(ErrorWrapper::from)?;
let kind = match conversation_type {
ConversationType::Group => None,
Expand Down Expand Up @@ -250,7 +251,7 @@ impl Conversation {
);

let admin_list = group
.admin_list(group.mls_provider().map_err(ErrorWrapper::from)?)
.admin_list(&group.mls_provider().map_err(ErrorWrapper::from)?)
.map_err(ErrorWrapper::from)?;

Ok(admin_list)
Expand All @@ -265,7 +266,7 @@ impl Conversation {
);

let super_admin_list = group
.super_admin_list(group.mls_provider().map_err(ErrorWrapper::from)?)
.super_admin_list(&group.mls_provider().map_err(ErrorWrapper::from)?)
.map_err(ErrorWrapper::from)?;

Ok(super_admin_list)
Expand Down Expand Up @@ -451,7 +452,7 @@ impl Conversation {
);

let group_name = group
.group_name(group.mls_provider().map_err(ErrorWrapper::from)?)
.group_name(&group.mls_provider().map_err(ErrorWrapper::from)?)
.map_err(ErrorWrapper::from)?;

Ok(group_name)
Expand Down Expand Up @@ -482,7 +483,7 @@ impl Conversation {
);

let group_image_url_square = group
.group_image_url_square(group.mls_provider().map_err(ErrorWrapper::from)?)
.group_image_url_square(&group.mls_provider().map_err(ErrorWrapper::from)?)
.map_err(ErrorWrapper::from)?;

Ok(group_image_url_square)
Expand Down Expand Up @@ -513,7 +514,7 @@ impl Conversation {
);

let group_description = group
.group_description(group.mls_provider().map_err(ErrorWrapper::from)?)
.group_description(&group.mls_provider().map_err(ErrorWrapper::from)?)
.map_err(ErrorWrapper::from)?;

Ok(group_description)
Expand Down Expand Up @@ -544,7 +545,7 @@ impl Conversation {
);

let group_pinned_frame_url = group
.group_pinned_frame_url(group.mls_provider().map_err(ErrorWrapper::from)?)
.group_pinned_frame_url(&group.mls_provider().map_err(ErrorWrapper::from)?)
.map_err(ErrorWrapper::from)?;

Ok(group_pinned_frame_url)
Expand Down Expand Up @@ -587,7 +588,7 @@ impl Conversation {

Ok(
group
.is_active(group.mls_provider().map_err(ErrorWrapper::from)?)
.is_active(&group.mls_provider().map_err(ErrorWrapper::from)?)
.map_err(ErrorWrapper::from)?,
)
}
Expand All @@ -604,15 +605,16 @@ impl Conversation {
}

#[napi]
pub fn group_metadata(&self) -> Result<GroupMetadata> {
pub async fn group_metadata(&self) -> Result<GroupMetadata> {
let group = MlsGroup::new(
self.inner_client.clone(),
self.group_id.clone(),
self.created_at_ns,
);

let metadata = group
.metadata(group.mls_provider().map_err(ErrorWrapper::from)?)
.metadata(&group.mls_provider().map_err(ErrorWrapper::from)?)
.await
.map_err(ErrorWrapper::from)?;

Ok(GroupMetadata { inner: metadata })
Expand Down
Loading

0 comments on commit b5f3237

Please sign in to comment.