Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
codabrink committed Dec 19, 2024
1 parent be69a8d commit e7f933b
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 17 deletions.
1 change: 1 addition & 0 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,7 @@ where
let query_args = GroupQueryArgs {
consent_state,
include_sync_groups: true,
include_duplicate_dms: true,
..GroupQueryArgs::default()
};
let groups = provider
Expand Down
3 changes: 1 addition & 2 deletions xmtp_mls/src/groups/device_sync/message_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ where
&self,
conn: &DbConnection,
) -> Result<Vec<Syncable>, DeviceSyncError> {
let groups =
conn.find_groups(GroupQueryArgs::default().conversation_type(ConversationType::Group))?;
let groups = conn.find_groups(GroupQueryArgs::default())?;

let mut all_messages = vec![];
for StoredGroup { id, .. } in groups.into_iter() {
Expand Down
71 changes: 69 additions & 2 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1670,16 +1670,23 @@ pub(crate) mod tests {
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker);

use diesel::connection::SimpleConnection;
use diesel::query_dsl::methods::FilterDsl;
use diesel::{ExpressionMethods, RunQueryDsl};
use futures::future::join_all;
use prost::Message;
use std::sync::Arc;
use std::{sync::Arc, time::Duration};
use wasm_bindgen_test::wasm_bindgen_test;
use xmtp_common::assert_err;
use xmtp_content_types::{group_updated::GroupUpdatedCodec, ContentCodec};
use xmtp_cryptography::utils::generate_local_wallet;
use xmtp_proto::xmtp::mls::api::v1::group_message::Version;
use xmtp_proto::xmtp::mls::message_contents::plaintext_envelope::Content;
use xmtp_proto::xmtp::mls::message_contents::EncodedContent;
use xmtp_proto::xmtp::mls::{
api::v1::group_message::Version, message_contents::PlaintextEnvelope,
};

use crate::storage::group::StoredGroup;
use crate::storage::schema::{group_messages, groups};
use crate::{
builder::ClientBuilder,
groups::{
Expand Down Expand Up @@ -2043,6 +2050,66 @@ pub(crate) mod tests {
});
}

#[wasm_bindgen_test(unsupported = tokio::test(flavor = "current_thread"))]
async fn test_dm_stitching() {
let alix_wallet = generate_local_wallet();
let alix = ClientBuilder::new_test_client(&alix_wallet).await;
let alix_provider = alix.mls_provider().unwrap();
let alix_conn = alix_provider.conn_ref();

let bo_wallet = generate_local_wallet();
let bo = ClientBuilder::new_test_client(&bo_wallet).await;
let bo_provider = bo.mls_provider().unwrap();

let bo_dm = bo
.create_dm_by_inbox_id(&bo_provider, alix.inbox_id().to_string())
.await
.unwrap();
let alix_dm = alix
.create_dm_by_inbox_id(&alix_provider, bo.inbox_id().to_string())
.await
.unwrap();

bo_dm.send_message(b"Hello there").await.unwrap();
std::thread::sleep(Duration::from_millis(20));
alix_dm
.send_message(b"No, let's use this dm")
.await
.unwrap();

alix.sync_all_welcomes_and_groups(&alix_provider, None)
.await
.unwrap();

// The dm shows up
let alix_groups = alix_conn
.raw_query(|conn| groups::table.load::<StoredGroup>(conn))
.unwrap();
assert_eq!(alix_groups.len(), 2);
// They should have the same ID
assert_eq!(alix_groups[0].dm_id, alix_groups[1].dm_id);

// The dm is filtered out up
let alix_filtered_groups = alix_conn.find_groups(&GroupQueryArgs::default()).unwrap();
assert_eq!(alix_filtered_groups.len(), 1);

let alix_msgs = alix_conn
.raw_query(|conn| {
group_messages::table
.filter(group_messages::kind.eq(GroupMessageKind::Application))
.load::<StoredGroupMessage>(conn)
})
.unwrap();

assert_eq!(alix_msgs.len(), 2);

let msg = String::from_utf8_lossy(&alix_msgs[1].decrypted_message_bytes);
assert_eq!(msg, "Hello there");

let msg = String::from_utf8_lossy(&alix_msgs[0].decrypted_message_bytes);
assert_eq!(msg, "No, let's use this dm");
}

#[wasm_bindgen_test(unsupported = tokio::test(flavor = "current_thread"))]
async fn test_add_inbox() {
let client = ClientBuilder::new_test_client(&generate_local_wallet()).await;
Expand Down
67 changes: 54 additions & 13 deletions xmtp_mls/src/storage/encrypted_store/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ pub struct GroupQueryArgs {
pub conversation_type: Option<ConversationType>,
pub consent_state: Option<ConsentState>,
pub include_sync_groups: bool,
pub include_duplicate_dms: bool,
}

impl AsRef<GroupQueryArgs> for GroupQueryArgs {
Expand Down Expand Up @@ -220,21 +221,25 @@ impl DbConnection {
conversation_type,
consent_state,
include_sync_groups,
include_duplicate_dms,
} = args.as_ref();

let mut query = groups_dsl::groups
.filter(groups_dsl::conversation_type.ne(ConversationType::Sync))
// Group by dm_id and grab the latest group (conversation stitching)
.filter(sql::<diesel::sql_types::Bool>(
.order(groups_dsl::created_at_ns.asc())
.into_boxed();

if !include_duplicate_dms {
query = query.filter(sql::<diesel::sql_types::Bool>(
"id IN (
SELECT id
FROM groups
GROUP BY CASE WHEN dm_id IS NULL THEN id ELSE dm_id END
ORDER BY last_message_ns DESC
)",
))
.order(groups_dsl::created_at_ns.asc())
.into_boxed();
));
}

if let Some(limit) = limit {
query = query.limit(*limit);
Expand Down Expand Up @@ -581,7 +586,10 @@ pub(crate) mod tests {
#[cfg(target_arch = "wasm32")]
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker);

use std::sync::atomic::{AtomicU16, Ordering};
use std::{
sync::atomic::{AtomicU16, Ordering},
time::Duration,
};

use super::*;
use crate::{
Expand Down Expand Up @@ -620,24 +628,21 @@ pub(crate) mod tests {
}
}

static INBOX_ID: AtomicU16 = AtomicU16::new(2);
static TARGET_INBOX_ID: AtomicU16 = AtomicU16::new(2);

/// Generate a test dm group
pub fn generate_dm(state: Option<GroupMembershipState>) -> StoredGroup {
let id = rand_vec::<24>();
let created_at_ns = now_ns();
let membership_state = state.unwrap_or(GroupMembershipState::Allowed);
let members = DmMembers {
member_one_inbox_id: "placeholder_inbox_id_1".to_string(),
member_two_inbox_id: format!(
"placeholder_inbox_id_{}",
INBOX_ID.fetch_add(1, Ordering::SeqCst)
TARGET_INBOX_ID.fetch_add(1, Ordering::SeqCst)
),
};
StoredGroup::new(
id,
created_at_ns,
membership_state,
rand_vec::<24>(),
now_ns(),
state.unwrap_or(GroupMembershipState::Allowed),
"placeholder_address".to_string(),
Some(members),
)
Expand Down Expand Up @@ -699,6 +704,42 @@ pub(crate) mod tests {
})
.await
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
async fn test_dm_stitching() {
with_connection(|conn| {
let dm1 = StoredGroup::new(
rand_vec::<24>(),
now_ns(),
GroupMembershipState::Allowed,
"placeholder_address".to_string(),
Some(DmMembers {
member_one_inbox_id: "thats_me".to_string(),
member_two_inbox_id: "some_wise_guy".to_string(),
}),
);
dm1.store(conn).unwrap();

let dm2 = StoredGroup::new(
rand_vec::<24>(),
now_ns(),
GroupMembershipState::Allowed,
"placeholder_address".to_string(),
Some(DmMembers {
member_one_inbox_id: "some_wise_guy".to_string(),
member_two_inbox_id: "thats_me".to_string(),
}),
);
dm2.store(conn).unwrap();

let all_groups = conn.find_groups(GroupQueryArgs::default()).unwrap();

assert_eq!(all_groups.len(), 1);
})
.await;
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
async fn test_find_groups() {
Expand Down

0 comments on commit e7f933b

Please sign in to comment.