diff --git a/xmtp/src/test_utils.rs b/xmtp/src/test_utils.rs index f00bf9c2c..5021d373b 100644 --- a/xmtp/src/test_utils.rs +++ b/xmtp/src/test_utils.rs @@ -1,3 +1,4 @@ +#[allow(clippy::module_inception)] #[cfg(test)] pub mod test_utils { use xmtp_proto::api_client::XmtpApiClient; diff --git a/xmtp_mls/src/builder.rs b/xmtp_mls/src/builder.rs index 70de1a37c..54131e778 100644 --- a/xmtp_mls/src/builder.rs +++ b/xmtp_mls/src/builder.rs @@ -140,8 +140,8 @@ where .store .take() .ok_or(ClientBuilderError::MissingParameter { parameter: "store" })?; - let mut conn = store.conn()?; - let provider = XmtpOpenMlsProvider::new(&mut conn); + let conn = store.conn()?; + let provider = XmtpOpenMlsProvider::new(&conn); let identity = self .identity_strategy .initialize_identity(&store, &provider)?; diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 93b30fa94..2d3dced29 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -17,8 +17,9 @@ use crate::{ identity::Identity, retry::Retry, storage::{ + db_connection::DbConnection, group::{GroupMembershipState, StoredGroup}, - DbConnection, EncryptedMessageStore, StorageError, + EncryptedMessageStore, StorageError, }, types::Address, utils::topic::get_welcome_topic, @@ -147,8 +148,8 @@ where } // TODO: Remove this and figure out the correct lifetimes to allow long lived provider - pub(crate) fn mls_provider(&self, conn: &'a mut DbConnection) -> XmtpOpenMlsProvider<'a> { - XmtpOpenMlsProvider::new(conn) + pub(crate) fn mls_provider(&self, conn: &'a DbConnection<'a>) -> XmtpOpenMlsProvider<'a> { + XmtpOpenMlsProvider::<'a>::new(conn) } pub fn create_group(&self) -> Result, ClientError> { @@ -174,24 +175,21 @@ where created_before_ns: Option, limit: Option, ) -> Result>, ClientError> { - Ok(EncryptedMessageStore::find_groups( - &mut self.store.conn()?, - allowed_states, - created_after_ns, - created_before_ns, - limit, - )? - .into_iter() - .map(|stored_group| MlsGroup::new(self, stored_group.id, stored_group.created_at_ns)) - .collect()) + Ok(self + .store + .conn()? + .find_groups(allowed_states, created_after_ns, created_before_ns, limit)? + .into_iter() + .map(|stored_group| MlsGroup::new(self, stored_group.id, stored_group.created_at_ns)) + .collect()) } pub async fn register_identity(&self) -> Result<(), ClientError> { // TODO: Mark key package as last_resort in creation - let mut connection = self.store.conn()?; + let connection = self.store.conn()?; let last_resort_kp = self .identity - .new_key_package(&self.mls_provider(&mut connection))?; + .new_key_package(&self.mls_provider(&connection))?; let last_resort_kp_bytes = last_resort_kp.tls_serialize_detached()?; self.api_client @@ -233,9 +231,8 @@ where } pub(crate) async fn pull_from_topic(&self, topic: &str) -> Result, ClientError> { - let mut conn = self.store.conn()?; - let last_synced_timestamp_ns = - EncryptedMessageStore::get_last_synced_timestamp_for_topic(&mut conn, topic)?; + let conn = self.store.conn()?; + let last_synced_timestamp_ns = conn.get_last_synced_timestamp_for_topic(topic)?; let envelopes = self .api_client @@ -261,14 +258,10 @@ where where ProcessingFn: FnOnce(XmtpOpenMlsProvider) -> Result, { - XmtpOpenMlsProvider::transaction(&mut self.store.conn()?, |provider| { - let is_updated = { - EncryptedMessageStore::update_last_synced_timestamp_for_topic( - &mut provider.conn().borrow_mut(), - topic, - envelope_timestamp_ns as i64, - )? - }; + self.store.transaction(|provider| { + let is_updated = provider + .conn() + .update_last_synced_timestamp_for_topic(topic, envelope_timestamp_ns as i64)?; if !is_updated { return Err(MessageProcessingError::AlreadyProcessed( envelope_timestamp_ns, @@ -302,12 +295,12 @@ where .consume_key_packages(installation_ids) .await?; - let mut conn = self.store.conn()?; + let conn = self.store.conn()?; Ok(key_package_results .values() .map(|bytes| { - VerifiedKeyPackage::from_bytes(&self.mls_provider(&mut conn), bytes.as_slice()) + VerifiedKeyPackage::from_bytes(&self.mls_provider(&conn), bytes.as_slice()) }) .collect::>()?) } diff --git a/xmtp_mls/src/codecs/membership_change.rs b/xmtp_mls/src/codecs/membership_change.rs index 43694dc99..b02e91192 100644 --- a/xmtp_mls/src/codecs/membership_change.rs +++ b/xmtp_mls/src/codecs/membership_change.rs @@ -72,7 +72,7 @@ mod tests { encoded.clone().r#type.unwrap().type_id, "group_membership_change" ); - assert!(encoded.content.len() > 0); + assert!(!encoded.content.is_empty()); let decoded = GroupMembershipChangeCodec::decode(encoded).unwrap(); assert_eq!(decoded.members_added[0], new_member); diff --git a/xmtp_mls/src/groups/intents.rs b/xmtp_mls/src/groups/intents.rs index 98e32b427..0cc8b04ad 100644 --- a/xmtp_mls/src/groups/intents.rs +++ b/xmtp_mls/src/groups/intents.rs @@ -259,17 +259,17 @@ mod tests { let wallet = generate_local_wallet(); let wallet_address = wallet.get_address(); let client = ClientBuilder::new_test_client(wallet.into()).await; - let mut conn = client.store.conn().unwrap(); + let conn = client.store.conn().unwrap(); let key_package = client .identity - .new_key_package(&client.mls_provider(&mut conn)) + .new_key_package(&client.mls_provider(&conn)) .unwrap(); let verified_key_package = VerifiedKeyPackage::new(key_package, wallet_address.clone()); let intent = AddMembersIntentData::new(vec![verified_key_package.clone()]); let as_bytes: Vec = intent.clone().try_into().unwrap(); let restored_intent = - AddMembersIntentData::from_bytes(as_bytes.as_slice(), &client.mls_provider(&mut conn)) + AddMembersIntentData::from_bytes(as_bytes.as_slice(), &client.mls_provider(&conn)) .unwrap(); assert!(intent.key_packages[0] diff --git a/xmtp_mls/src/groups/members.rs b/xmtp_mls/src/groups/members.rs index 905d95702..10810529d 100644 --- a/xmtp_mls/src/groups/members.rs +++ b/xmtp_mls/src/groups/members.rs @@ -19,7 +19,7 @@ where // Load the member list for the group from the DB, merging together multiple installations into a single entry pub fn members(&self) -> Result, GroupError> { let openmls_group = - self.load_mls_group(&self.client.mls_provider(&mut self.client.store.conn()?))?; + self.load_mls_group(&self.client.mls_provider(&self.client.store.conn()?))?; let member_map: HashMap = openmls_group .members() diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 6bbe9f055..a2d580833 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -32,10 +32,11 @@ use crate::{ retry::{Retry, RetryableError}, retryable, storage::{ + db_connection::DbConnection, group::{GroupMembershipState, StoredGroup}, group_intent::{IntentKind, IntentState, NewGroupIntent, StoredGroupIntent}, group_message::{GroupMessageKind, StoredGroupMessage}, - DbConnection, EncryptedMessageStore, StorageError, + StorageError, }, utils::{hash::sha256, id::get_message_id, time::now_ns, topic::get_group_topic}, xmtp_openmls_provider::XmtpOpenMlsProvider, @@ -126,8 +127,8 @@ where client: &'c Client, membership_state: GroupMembershipState, ) -> Result { - let mut conn = client.store.conn()?; - let provider = XmtpOpenMlsProvider::new(&mut conn); + let conn = client.store.conn()?; + let provider = XmtpOpenMlsProvider::new(&conn); let mut mls_group = OpenMlsGroup::new( &provider, &client.identity.installation_keys, @@ -141,10 +142,7 @@ where mls_group.save(provider.key_store())?; let group_id = mls_group.group_id().to_vec(); let stored_group = StoredGroup::new(group_id.clone(), now_ns(), membership_state); - { - stored_group.store(*provider.conn().borrow_mut())?; - } - + stored_group.store(provider.conn())?; Ok(Self::new(client, group_id, stored_group.created_at_ns)) } @@ -160,9 +158,7 @@ where let group_id = mls_group.group_id().to_vec(); let stored_group = StoredGroup::new(group_id.clone(), now_ns(), GroupMembershipState::Pending); - { - stored_group.store(*provider.conn().borrow_mut())?; - } + stored_group.store(provider.conn())?; Ok(Self::new(client, group_id, stored_group.created_at_ns)) } @@ -174,15 +170,9 @@ where sent_after_ns: Option, limit: Option, ) -> Result, GroupError> { - let mut conn = self.client.store.conn()?; - let messages = EncryptedMessageStore::get_group_messages( - &mut conn, - &self.group_id, - sent_after_ns, - sent_before_ns, - kind, - limit, - )?; + let conn = self.client.store.conn()?; + let messages = + conn.get_group_messages(&self.group_id, sent_after_ns, sent_before_ns, kind, limit)?; Ok(messages) } @@ -238,6 +228,7 @@ where intent.kind ); + let conn = provider.conn(); match intent.kind { IntentKind::AddMembers | IntentKind::RemoveMembers | IntentKind::KeyUpdate => { // We don't get errors with merge_pending_commit when there are no commits to merge @@ -248,12 +239,7 @@ where "no pending commit to merge. Group epoch: {}. Message epoch: {}", group_epoch, message_epoch ); - { - EncryptedMessageStore::set_group_intent_to_publish( - &mut provider.conn().borrow_mut(), - intent.id, - )?; - } + conn.set_group_intent_to_publish(intent.id)?; return Err(MessageProcessingError::NoPendingCommit { message_epoch, @@ -266,12 +252,7 @@ where { debug!("error merging commit: {}", err); openmls_group.clear_pending_commit(); - { - EncryptedMessageStore::set_group_intent_to_publish( - &mut provider.conn().borrow_mut(), - intent.id, - )?; - } + conn.set_group_intent_to_publish(intent.id)?; } // TOOD: Handle writing transcript messages for adding/removing members } @@ -280,27 +261,20 @@ where let group_id = openmls_group.group_id().as_slice(); let decrypted_message_data = intent_data.message.as_slice(); - { - StoredGroupMessage { - id: get_message_id(decrypted_message_data, group_id, envelope_timestamp_ns), - group_id: group_id.to_vec(), - decrypted_message_bytes: intent_data.message, - sent_at_ns: envelope_timestamp_ns as i64, - kind: GroupMessageKind::Application, - sender_installation_id: self.client.installation_public_key(), - sender_wallet_address: self.client.account_address(), - } - .store(&mut provider.conn().borrow_mut())?; + StoredGroupMessage { + id: get_message_id(decrypted_message_data, group_id, envelope_timestamp_ns), + group_id: group_id.to_vec(), + decrypted_message_bytes: intent_data.message, + sent_at_ns: envelope_timestamp_ns as i64, + kind: GroupMessageKind::Application, + sender_installation_id: self.client.installation_public_key(), + sender_wallet_address: self.client.account_address(), } + .store(conn)?; } }; - { - EncryptedMessageStore::set_group_intent_committed( - &mut provider.conn().borrow_mut(), - intent.id, - )?; - } + conn.set_group_intent_committed(intent.id)?; Ok(()) } @@ -334,9 +308,7 @@ where sender_installation_id, sender_wallet_address: sender_account_address, }; - { - message.store(&mut provider.conn().borrow_mut())?; - } + message.store(provider.conn())?; } ProcessedMessageContent::ProposalMessage(_proposal_ptr) => { // intentionally left blank. @@ -371,12 +343,9 @@ where )), }?; - let intent = { - EncryptedMessageStore::find_group_intent_by_payload_hash( - &mut provider.conn().borrow_mut(), - sha256(envelope.message.as_slice()), - ) - }; + let intent = provider + .conn() + .find_group_intent_by_payload_hash(sha256(envelope.message.as_slice())); match intent { // Intent with the payload hash matches Ok(Some(intent)) => self.process_own_message( @@ -415,8 +384,8 @@ where } pub fn process_messages(&self, envelopes: Vec) -> Result<(), GroupError> { - let mut conn = self.client.store.conn()?; - let provider = self.client.mls_provider(&mut conn); + let conn = &self.client.store.conn()?; + let provider = self.client.mls_provider(conn); let mut openmls_group = self.load_mls_group(&provider)?; let receive_errors: Vec = envelopes @@ -532,7 +501,7 @@ where self.sync_with_conn(conn).await } - async fn sync_with_conn(&self, conn: &mut DbConnection) -> Result<(), GroupError> { + async fn sync_with_conn<'a>(&self, conn: &'a DbConnection<'a>) -> Result<(), GroupError> { self.publish_intents(conn).await?; if let Err(e) = self.receive().await { log::warn!("receive error {:?}", e); @@ -542,18 +511,18 @@ where Ok(()) } - pub(crate) async fn publish_intents(&self, conn: &mut DbConnection) -> Result<(), GroupError> { + pub(crate) async fn publish_intents<'a>( + &self, + conn: &'a DbConnection<'a>, + ) -> Result<(), GroupError> { let provider = self.client.mls_provider(conn); let mut openmls_group = self.load_mls_group(&provider)?; - let intents = { - EncryptedMessageStore::find_group_intents( - &mut provider.conn().borrow_mut(), - self.group_id.clone(), - Some(vec![IntentState::ToPublish]), - None, - )? - }; + let intents = provider.conn().find_group_intents( + self.group_id.clone(), + Some(vec![IntentState::ToPublish]), + None, + )?; for intent in intents { let result = retry!( @@ -575,16 +544,13 @@ where .publish_to_group(vec![payload_slice]) .await?; - { - EncryptedMessageStore::set_group_intent_published( - &mut provider.conn().borrow_mut(), - intent.id, - sha256(payload_slice), - post_commit_data, - )?; - } + provider.conn().set_group_intent_published( + intent.id, + sha256(payload_slice), + post_commit_data, + )?; } - openmls_group.save(self.client.mls_provider(conn).key_store())?; + openmls_group.save(provider.key_store())?; Ok(()) } @@ -680,9 +646,8 @@ where } } - pub(crate) async fn post_commit(&self, conn: &mut DbConnection) -> Result<(), GroupError> { - let intents = EncryptedMessageStore::find_group_intents( - conn, + pub(crate) async fn post_commit(&self, conn: &DbConnection<'_>) -> Result<(), GroupError> { + let intents = conn.find_group_intents( self.group_id.clone(), Some(vec![IntentState::Committed]), None, @@ -707,7 +672,7 @@ where } } } - let deleter: &mut dyn Delete = conn; + let deleter: &dyn Delete = conn; deleter.delete(intent.id)?; } @@ -734,8 +699,7 @@ mod tests { use xmtp_cryptography::utils::generate_local_wallet; use crate::{ - builder::ClientBuilder, storage::group_intent::IntentState, storage::EncryptedMessageStore, - utils::topic::get_welcome_topic, + builder::ClientBuilder, storage::group_intent::IntentState, utils::topic::get_welcome_topic, }; #[tokio::test] @@ -810,37 +774,37 @@ mod tests { bola_group.receive().await.expect_err("expected error"); // Check Amal's MLS group state. - let mut amal_db = amal.store.conn().unwrap(); + let amal_db = amal.store.conn().unwrap(); let amal_mls_group = amal_group - .load_mls_group(&amal.mls_provider(&mut amal_db)) + .load_mls_group(&amal.mls_provider(&amal_db)) .unwrap(); let amal_members: Vec = amal_mls_group.members().collect(); assert_eq!(amal_members.len(), 3); // Check Bola's MLS group state. - let mut bola_db = bola.store.conn().unwrap(); + let bola_db = bola.store.conn().unwrap(); let bola_mls_group = bola_group - .load_mls_group(&bola.mls_provider(&mut bola_db)) + .load_mls_group(&bola.mls_provider(&bola_db)) .unwrap(); let bola_members: Vec = bola_mls_group.members().collect(); assert_eq!(bola_members.len(), 3); - let amal_uncommitted_intents = EncryptedMessageStore::find_group_intents( - &mut amal.store.conn().unwrap(), - amal_group.group_id.clone(), - Some(vec![IntentState::ToPublish, IntentState::Published]), - None, - ) - .unwrap(); + let amal_uncommitted_intents = amal_db + .find_group_intents( + amal_group.group_id.clone(), + Some(vec![IntentState::ToPublish, IntentState::Published]), + None, + ) + .unwrap(); assert_eq!(amal_uncommitted_intents.len(), 0); - let bola_uncommitted_intents = EncryptedMessageStore::find_group_intents( - &mut bola.store.conn().unwrap(), - bola_group.group_id.clone(), - Some(vec![IntentState::ToPublish, IntentState::Published]), - None, - ) - .unwrap(); + let bola_uncommitted_intents = bola_db + .find_group_intents( + bola_group.group_id.clone(), + Some(vec![IntentState::ToPublish, IntentState::Published]), + None, + ) + .unwrap(); // Bola should have one uncommitted intent for the failed attempt at adding Charlie, who is already in the group assert_eq!(bola_uncommitted_intents.len(), 1); } @@ -934,8 +898,8 @@ mod tests { .unwrap(); assert_eq!(messages.len(), 1); - let mut conn = client.store.conn().unwrap(); - let provider = super::XmtpOpenMlsProvider::new(&mut conn); + let conn = &client.store.conn().unwrap(); + let provider = super::XmtpOpenMlsProvider::new(conn); let mls_group = group.load_mls_group(&provider).unwrap(); let pending_commit = mls_group.pending_commit(); assert!(pending_commit.is_none()); diff --git a/xmtp_mls/src/identity.rs b/xmtp_mls/src/identity.rs index 52e26eaba..12e1bcd70 100644 --- a/xmtp_mls/src/identity.rs +++ b/xmtp_mls/src/identity.rs @@ -62,9 +62,7 @@ impl Identity { }; identity.new_key_package(provider)?; - { - StoredIdentity::from(&identity).store(*provider.conn().borrow_mut())?; - } + StoredIdentity::from(&identity).store(provider.conn())?; // TODO: upload credential_with_key and last_resort_key_package @@ -152,16 +150,16 @@ mod tests { #[test] fn does_not_error() { let store = EncryptedMessageStore::new_test(); - let mut conn = store.conn().unwrap(); - let provider = XmtpOpenMlsProvider::new(&mut conn); + let conn = &store.conn().unwrap(); + let provider = XmtpOpenMlsProvider::new(conn); Identity::new(&provider, &generate_local_wallet()).unwrap(); } #[test] fn test_key_package_extensions() { let store = EncryptedMessageStore::new_test(); - let mut conn = store.conn().unwrap(); - let provider = XmtpOpenMlsProvider::new(&mut conn); + let conn = &store.conn().unwrap(); + let provider = XmtpOpenMlsProvider::new(conn); let identity = Identity::new(&provider, &generate_local_wallet()).unwrap(); let new_key_package = identity.new_key_package(&provider).unwrap(); diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs index 30ebf5beb..910c0b31d 100644 --- a/xmtp_mls/src/lib.rs +++ b/xmtp_mls/src/lib.rs @@ -27,17 +27,17 @@ pub trait InboxOwner { // Inserts a model to the underlying data store pub trait Store { - fn store(&self, into: &mut StorageConnection) -> Result<(), StorageError>; + fn store(&self, into: &StorageConnection) -> Result<(), StorageError>; } pub trait Fetch { type Key; - fn fetch(&mut self, key: &Self::Key) -> Result, StorageError>; + fn fetch(&self, key: &Self::Key) -> Result, StorageError>; } pub trait Delete { type Key; - fn delete(&mut self, key: Self::Key) -> Result; + fn delete(&self, key: Self::Key) -> Result; } #[cfg(test)] diff --git a/xmtp_mls/src/retry.rs b/xmtp_mls/src/retry.rs index c8c34a886..bb5293694 100644 --- a/xmtp_mls/src/retry.rs +++ b/xmtp_mls/src/retry.rs @@ -277,10 +277,7 @@ mod tests { impl RetryableError for SomeError { fn is_retryable(&self) -> bool { - match self { - Self::ARetryableError => true, - _ => false, - } + matches!(self, Self::ARetryableError) } } diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs new file mode 100644 index 000000000..657b7d306 --- /dev/null +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -0,0 +1,47 @@ +use crate::storage::RawDbConnection; +use std::{cell::RefCell, fmt}; + +// Re-implementation of Cow without ToOwned requirement +enum RefOrValue<'a, T> { + Ref(&'a mut T), + Value(T), +} + +/// A wrapper for RawDbConnection that houses all XMTP DB operations. +/// Uses a RefCell internally for interior mutability, so that the connection +/// and transaction state can be shared between the OpenMLS Provider and +/// native XMTP operations +pub struct DbConnection<'a> { + wrapped_conn: RefCell>, +} + +impl<'a> DbConnection<'a> { + pub(crate) fn new(conn: &'a mut RawDbConnection) -> Self { + Self { + wrapped_conn: RefCell::new(RefOrValue::Ref(conn)), + } + } + pub(crate) fn held(conn: RawDbConnection) -> Self { + Self { + wrapped_conn: RefCell::new(RefOrValue::Value(conn)), + } + } + + pub(crate) fn raw_query(&self, fun: F) -> Result + where + F: FnOnce(&mut RawDbConnection) -> Result, + { + match *self.wrapped_conn.borrow_mut() { + RefOrValue::Ref(ref mut conn_ref) => fun(conn_ref), + RefOrValue::Value(ref mut conn) => fun(conn), + } + } +} + +impl fmt::Debug for DbConnection<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("DbConnection") + .field("wrapped_conn", &"DbConnection") + .finish() + } +} diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index fa6268cd8..bd5e0dec9 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -11,8 +11,8 @@ use diesel::{ }; use super::{ + db_connection::DbConnection, schema::{groups, groups::dsl}, - DbConnection, EncryptedMessageStore, }; use crate::{impl_fetch, impl_store, StorageError}; @@ -45,9 +45,9 @@ impl StoredGroup { } } -impl EncryptedMessageStore { +impl DbConnection<'_> { pub fn find_groups( - conn: &mut DbConnection, + &self, allowed_states: Option>, created_after_ns: Option, created_before_ns: Option, @@ -71,18 +71,20 @@ impl EncryptedMessageStore { query = query.limit(limit); } - Ok(query.load(conn)?) + Ok(self.raw_query(|conn| query.load(conn))?) } /// Updates group membership state pub fn update_group_membership>( - conn: &mut DbConnection, + &self, id: GroupId, state: GroupMembershipState, ) -> Result<(), StorageError> { - diesel::update(dsl::groups.find(id.as_ref())) - .set(dsl::membership_state.eq(state)) - .execute(conn)?; + self.raw_query(|conn| { + diesel::update(dsl::groups.find(id.as_ref())) + .set(dsl::membership_state.eq(state)) + .execute(conn) + })?; Ok(()) } @@ -131,7 +133,7 @@ pub(crate) mod tests { use super::*; use crate::{ assert_ok, - storage::encrypted_store::{schema::groups::dsl::groups, tests::with_store}, + storage::encrypted_store::{schema::groups::dsl::groups, tests::with_connection}, utils::{test::rand_vec, time::now_ns}, Fetch, Store, }; @@ -147,41 +149,44 @@ pub(crate) mod tests { #[test] fn it_stores_group() { - with_store(|mut conn| { + with_connection(|conn| { let test_group = generate_group(None); - test_group.store(&mut conn).unwrap(); - assert_eq!(groups.first::(&mut conn).unwrap(), test_group); + test_group.store(conn).unwrap(); + assert_eq!( + conn.raw_query(|raw_conn| groups.first::(raw_conn)) + .unwrap(), + test_group + ); }) } #[test] fn it_fetches_group() { - with_store(|mut conn| { + with_connection(|conn| { let test_group = generate_group(None); - diesel::insert_into(groups) - .values(test_group.clone()) - .execute(&mut conn) - .unwrap(); + conn.raw_query(|raw_conn| { + diesel::insert_into(groups) + .values(test_group.clone()) + .execute(raw_conn) + }) + .unwrap(); - let fetched_group = Fetch::::fetch(&mut conn, &test_group.id); + let fetched_group: Result, StorageError> = + conn.fetch(&test_group.id); assert_ok!(fetched_group, Some(test_group)); }) } #[test] fn it_updates_group_membership_state() { - with_store(|mut conn| { + with_connection(|conn| { let test_group = generate_group(Some(GroupMembershipState::Pending)); - test_group.store(&mut conn).unwrap(); - EncryptedMessageStore::update_group_membership( - &mut conn, - &test_group.id, - GroupMembershipState::Rejected, - ) - .unwrap(); + test_group.store(conn).unwrap(); + conn.update_group_membership(&test_group.id, GroupMembershipState::Rejected) + .unwrap(); let updated_group: StoredGroup = conn.fetch(&test_group.id).ok().flatten().unwrap(); assert_eq!( @@ -196,41 +201,29 @@ pub(crate) mod tests { #[test] fn test_find_groups() { - with_store(|mut conn| { + with_connection(|conn| { let test_group_1 = generate_group(Some(GroupMembershipState::Pending)); - test_group_1.store(&mut conn).unwrap(); + test_group_1.store(conn).unwrap(); let test_group_2 = generate_group(Some(GroupMembershipState::Allowed)); - test_group_2.store(&mut conn).unwrap(); + test_group_2.store(conn).unwrap(); - let all_results = - EncryptedMessageStore::find_groups(&mut conn, None, None, None, None).unwrap(); + let all_results = conn.find_groups(None, None, None, None).unwrap(); assert_eq!(all_results.len(), 2); - let pending_results = EncryptedMessageStore::find_groups( - &mut conn, - Some(vec![GroupMembershipState::Pending]), - None, - None, - None, - ) - .unwrap(); + let pending_results = conn + .find_groups(Some(vec![GroupMembershipState::Pending]), None, None, None) + .unwrap(); assert_eq!(pending_results[0].id, test_group_1.id); assert_eq!(pending_results.len(), 1); // Offset and limit - let results_with_limit = - EncryptedMessageStore::find_groups(&mut conn, None, None, None, Some(1)).unwrap(); + let results_with_limit = conn.find_groups(None, None, None, Some(1)).unwrap(); assert_eq!(results_with_limit.len(), 1); assert_eq!(results_with_limit[0].id, test_group_1.id); - let results_with_created_at_ns_after = EncryptedMessageStore::find_groups( - &mut conn, - None, - Some(test_group_1.created_at_ns), - None, - Some(1), - ) - .unwrap(); + let results_with_created_at_ns_after = conn + .find_groups(None, Some(test_group_1.created_at_ns), None, Some(1)) + .unwrap(); assert_eq!(results_with_created_at_ns_after.len(), 1); assert_eq!(results_with_created_at_ns_after[0].id, test_group_2.id); }) diff --git a/xmtp_mls/src/storage/encrypted_store/group_intent.rs b/xmtp_mls/src/storage/encrypted_store/group_intent.rs index 2a41bb451..22995a2b4 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_intent.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_intent.rs @@ -9,9 +9,9 @@ use diesel::{ }; use super::{ + db_connection::DbConnection, group, schema::{group_intents, group_intents::dsl}, - DbConnection, EncryptedMessageStore, }; use crate::{impl_fetch, impl_store, storage::StorageError, Delete}; @@ -53,10 +53,11 @@ pub struct StoredGroupIntent { impl_fetch!(StoredGroupIntent, group_intents, ID); -impl Delete for DbConnection { +impl Delete for DbConnection<'_> { type Key = ID; - fn delete(&mut self, key: ID) -> Result { - Ok(diesel::delete(dsl::group_intents.find(key)).execute(self)?) + fn delete(&self, key: ID) -> Result { + Ok(self + .raw_query(|raw_conn| diesel::delete(dsl::group_intents.find(key)).execute(raw_conn))?) } } @@ -82,10 +83,10 @@ impl NewGroupIntent { } } -impl EncryptedMessageStore { +impl DbConnection<'_> { // Query for group_intents by group_id, optionally filtering by state and kind pub fn find_group_intents( - conn: &mut DbConnection, + &self, group_id: Vec, allowed_states: Option>, allowed_kinds: Option>, @@ -104,28 +105,30 @@ impl EncryptedMessageStore { query = query.order(dsl::id.asc()); - Ok(query.load::(conn)?) + Ok(self.raw_query(|conn| query.load::(conn))?) } // Set the intent with the given ID to `Published` and set the payload hash. Optionally add // `post_commit_data` pub fn set_group_intent_published( - conn: &mut DbConnection, + &self, intent_id: ID, payload_hash: Vec, post_commit_data: Option>, ) -> Result<(), StorageError> { - let res = diesel::update(dsl::group_intents) - .filter(dsl::id.eq(intent_id)) - // State machine requires that the only valid state transition to Published is from - // ToPublish - .filter(dsl::state.eq(IntentState::ToPublish)) - .set(( - dsl::state.eq(IntentState::Published), - dsl::payload_hash.eq(payload_hash), - dsl::post_commit_data.eq(post_commit_data), - )) - .execute(conn)?; + let res = self.raw_query(|conn| { + diesel::update(dsl::group_intents) + .filter(dsl::id.eq(intent_id)) + // State machine requires that the only valid state transition to Published is from + // ToPublish + .filter(dsl::state.eq(IntentState::ToPublish)) + .set(( + dsl::state.eq(IntentState::Published), + dsl::payload_hash.eq(payload_hash), + dsl::post_commit_data.eq(post_commit_data), + )) + .execute(conn) + })?; match res { // If nothing matched the query, return an error. Either ID or state was wrong @@ -135,17 +138,16 @@ impl EncryptedMessageStore { } // Set the intent with the given ID to `Committed` - pub fn set_group_intent_committed( - conn: &mut DbConnection, - intent_id: ID, - ) -> Result<(), StorageError> { - let res = diesel::update(dsl::group_intents) - .filter(dsl::id.eq(intent_id)) - // State machine requires that the only valid state transition to Committed is from - // Published - .filter(dsl::state.eq(IntentState::Published)) - .set(dsl::state.eq(IntentState::Committed)) - .execute(conn)?; + pub fn set_group_intent_committed(&self, intent_id: ID) -> Result<(), StorageError> { + let res = self.raw_query(|conn| { + diesel::update(dsl::group_intents) + .filter(dsl::id.eq(intent_id)) + // State machine requires that the only valid state transition to Committed is from + // Published + .filter(dsl::state.eq(IntentState::Published)) + .set(dsl::state.eq(IntentState::Committed)) + .execute(conn) + })?; match res { // If nothing matched the query, return an error. Either ID or state was wrong @@ -156,22 +158,21 @@ impl EncryptedMessageStore { // Set the intent with the given ID to `ToPublish`. Wipe any values for `payload_hash` and // `post_commit_data` - pub fn set_group_intent_to_publish( - conn: &mut DbConnection, - intent_id: ID, - ) -> Result<(), StorageError> { - let res = diesel::update(dsl::group_intents) - .filter(dsl::id.eq(intent_id)) - // State machine requires that the only valid state transition to ToPublish is from - // Published - .filter(dsl::state.eq(IntentState::Published)) - .set(( - dsl::state.eq(IntentState::ToPublish), - // When moving to ToPublish, clear the payload hash and post commit data - dsl::payload_hash.eq(None::>), - dsl::post_commit_data.eq(None::>), - )) - .execute(conn)?; + pub fn set_group_intent_to_publish(&self, intent_id: ID) -> Result<(), StorageError> { + let res = self.raw_query(|conn| { + diesel::update(dsl::group_intents) + .filter(dsl::id.eq(intent_id)) + // State machine requires that the only valid state transition to ToPublish is from + // Published + .filter(dsl::state.eq(IntentState::Published)) + .set(( + dsl::state.eq(IntentState::ToPublish), + // When moving to ToPublish, clear the payload hash and post commit data + dsl::payload_hash.eq(None::>), + dsl::post_commit_data.eq(None::>), + )) + .execute(conn) + })?; match res { // If nothing matched the query, return an error. Either ID or state was wrong @@ -183,13 +184,15 @@ impl EncryptedMessageStore { // Simple lookup of intents by payload hash, meant to be used when processing messages off the // network pub fn find_group_intent_by_payload_hash( - conn: &mut DbConnection, + &self, payload_hash: Vec, ) -> Result, StorageError> { - let result = dsl::group_intents - .filter(dsl::payload_hash.eq(payload_hash)) - .first::(conn) - .optional()?; + let result = self.raw_query(|conn| { + dsl::group_intents + .filter(dsl::payload_hash.eq(payload_hash)) + .first::(conn) + .optional() + })?; Ok(result) } @@ -250,13 +253,13 @@ mod tests { use crate::{ storage::encrypted_store::{ group::{GroupMembershipState, StoredGroup}, - tests::with_store, + tests::with_connection, }, utils::test::rand_vec, Fetch, Store, }; - fn insert_group(conn: &mut DbConnection, group_id: Vec) { + fn insert_group(conn: &DbConnection, group_id: Vec) { let group = StoredGroup::new(group_id, 100, GroupMembershipState::Allowed); group.store(conn).unwrap(); } @@ -279,11 +282,13 @@ mod tests { } } - fn find_first_intent(conn: &mut DbConnection, group_id: group::ID) -> StoredGroupIntent { - dsl::group_intents - .filter(dsl::group_id.eq(group_id)) - .first(conn) - .unwrap() + fn find_first_intent(conn: &DbConnection, group_id: group::ID) -> StoredGroupIntent { + conn.raw_query(|raw_conn| { + dsl::group_intents + .filter(dsl::group_id.eq(group_id)) + .first(raw_conn) + }) + .unwrap() } #[test] @@ -295,19 +300,15 @@ mod tests { let to_insert = NewGroupIntent::new_test(kind, group_id.clone(), data.clone(), state); - with_store(|mut conn| { + with_connection(|conn| { // Group needs to exist or FK constraint will fail - insert_group(&mut conn, group_id.clone()); + insert_group(conn, group_id.clone()); - to_insert.store(&mut conn).unwrap(); + to_insert.store(conn).unwrap(); - let results = EncryptedMessageStore::find_group_intents( - &mut conn, - group_id.clone(), - Some(vec![IntentState::ToPublish]), - None, - ) - .unwrap(); + let results = conn + .find_group_intents(group_id.clone(), Some(vec![IntentState::ToPublish]), None) + .unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].kind, kind); @@ -347,60 +348,59 @@ mod tests { ), ]; - with_store(|mut conn| { + with_connection(|conn| { // Group needs to exist or FK constraint will fail - insert_group(&mut conn, group_id.clone()); + insert_group(conn, group_id.clone()); for case in test_intents { - case.store(&mut conn).unwrap(); + case.store(conn).unwrap(); } // Can query for multiple states - let mut results = EncryptedMessageStore::find_group_intents( - &mut conn, - group_id.clone(), - Some(vec![IntentState::ToPublish, IntentState::Published]), - None, - ) - .unwrap(); + let mut results = conn + .find_group_intents( + group_id.clone(), + Some(vec![IntentState::ToPublish, IntentState::Published]), + None, + ) + .unwrap(); assert_eq!(results.len(), 2); // Can query by kind - results = EncryptedMessageStore::find_group_intents( - &mut conn, - group_id.clone(), - None, - Some(vec![IntentKind::RemoveMembers]), - ) - .unwrap(); + results = conn + .find_group_intents( + group_id.clone(), + None, + Some(vec![IntentKind::RemoveMembers]), + ) + .unwrap(); assert_eq!(results.len(), 2); // Can query by kind and state - results = EncryptedMessageStore::find_group_intents( - &mut conn, - group_id.clone(), - Some(vec![IntentState::Committed]), - Some(vec![IntentKind::RemoveMembers]), - ) - .unwrap(); + results = conn + .find_group_intents( + group_id.clone(), + Some(vec![IntentState::Committed]), + Some(vec![IntentKind::RemoveMembers]), + ) + .unwrap(); assert_eq!(results.len(), 1); // Can get no results - results = EncryptedMessageStore::find_group_intents( - &mut conn, - group_id.clone(), - Some(vec![IntentState::Committed]), - Some(vec![IntentKind::SendMessage]), - ) - .unwrap(); + results = conn + .find_group_intents( + group_id.clone(), + Some(vec![IntentState::Committed]), + Some(vec![IntentKind::SendMessage]), + ) + .unwrap(); assert_eq!(results.len(), 0); // Can get all intents - results = - EncryptedMessageStore::find_group_intents(&mut conn, group_id, None, None).unwrap(); + results = conn.find_group_intents(group_id, None, None).unwrap(); assert_eq!(results.len(), 3); }) } @@ -409,32 +409,31 @@ mod tests { fn find_by_payload_hash() { let group_id = rand_vec(); - with_store(|mut conn| { - insert_group(&mut conn, group_id.clone()); + with_connection(|conn| { + insert_group(conn, group_id.clone()); // Store the intent NewGroupIntent::new(IntentKind::AddMembers, group_id.clone(), rand_vec()) - .store(&mut conn) + .store(conn) .unwrap(); // Find the intent with the ID populated - let intent = find_first_intent(&mut conn, group_id.clone()); + let intent = find_first_intent(conn, group_id.clone()); // Set the payload hash let payload_hash = rand_vec(); let post_commit_data = rand_vec(); - EncryptedMessageStore::set_group_intent_published( - &mut conn, + conn.set_group_intent_published( intent.id, payload_hash.clone(), Some(post_commit_data.clone()), ) .unwrap(); - let find_result = - EncryptedMessageStore::find_group_intent_by_payload_hash(&mut conn, payload_hash) - .unwrap() - .unwrap(); + let find_result = conn + .find_group_intent_by_payload_hash(payload_hash) + .unwrap() + .unwrap(); assert_eq!(find_result.id, intent.id); }) @@ -444,21 +443,20 @@ mod tests { fn test_happy_path_state_transitions() { let group_id = rand_vec(); - with_store(|mut conn| { - insert_group(&mut conn, group_id.clone()); + with_connection(|conn| { + insert_group(conn, group_id.clone()); // Store the intent NewGroupIntent::new(IntentKind::AddMembers, group_id.clone(), rand_vec()) - .store(&mut conn) + .store(conn) .unwrap(); - let mut intent = find_first_intent(&mut conn, group_id.clone()); + let mut intent = find_first_intent(conn, group_id.clone()); // Set to published let payload_hash = rand_vec(); let post_commit_data = rand_vec(); - EncryptedMessageStore::set_group_intent_published( - &mut conn, + conn.set_group_intent_published( intent.id, payload_hash.clone(), Some(post_commit_data.clone()), @@ -470,7 +468,7 @@ mod tests { assert_eq!(intent.payload_hash, Some(payload_hash.clone())); assert_eq!(intent.post_commit_data, Some(post_commit_data.clone())); - EncryptedMessageStore::set_group_intent_committed(&mut conn, intent.id).unwrap(); + conn.set_group_intent_committed(intent.id).unwrap(); // Refresh from the DB intent = conn.fetch(&intent.id).unwrap().unwrap(); assert_eq!(intent.state, IntentState::Committed); @@ -483,21 +481,20 @@ mod tests { fn test_republish_state_transition() { let group_id = rand_vec(); - with_store(|mut conn| { - insert_group(&mut conn, group_id.clone()); + with_connection(|conn| { + insert_group(conn, group_id.clone()); // Store the intent NewGroupIntent::new(IntentKind::AddMembers, group_id.clone(), rand_vec()) - .store(&mut conn) + .store(conn) .unwrap(); - let mut intent = find_first_intent(&mut conn, group_id.clone()); + let mut intent = find_first_intent(conn, group_id.clone()); // Set to published let payload_hash = rand_vec(); let post_commit_data = rand_vec(); - EncryptedMessageStore::set_group_intent_published( - &mut conn, + conn.set_group_intent_published( intent.id, payload_hash.clone(), Some(post_commit_data.clone()), @@ -509,7 +506,7 @@ mod tests { assert_eq!(intent.payload_hash, Some(payload_hash.clone())); // Now revert back to ToPublish - EncryptedMessageStore::set_group_intent_to_publish(&mut conn, intent.id).unwrap(); + conn.set_group_intent_to_publish(intent.id).unwrap(); intent = conn.fetch(&intent.id).unwrap().unwrap(); assert_eq!(intent.state, IntentState::ToPublish); assert!(intent.payload_hash.is_none()); @@ -521,23 +518,21 @@ mod tests { fn test_invalid_state_transition() { let group_id = rand_vec(); - with_store(|mut conn| { - insert_group(&mut conn, group_id.clone()); + with_connection(|conn| { + insert_group(conn, group_id.clone()); // Store the intent NewGroupIntent::new(IntentKind::AddMembers, group_id.clone(), rand_vec()) - .store(&mut conn) + .store(conn) .unwrap(); - let intent = find_first_intent(&mut conn, group_id.clone()); + let intent = find_first_intent(conn, group_id.clone()); - let commit_result = - EncryptedMessageStore::set_group_intent_committed(&mut conn, intent.id); + let commit_result = conn.set_group_intent_committed(intent.id); assert!(commit_result.is_err()); assert_eq!(commit_result.err().unwrap(), StorageError::NotFound); - let to_publish_result = - EncryptedMessageStore::set_group_intent_to_publish(&mut conn, intent.id); + let to_publish_result = conn.set_group_intent_to_publish(intent.id); assert!(to_publish_result.is_err()); assert_eq!(to_publish_result.err().unwrap(), StorageError::NotFound); }) diff --git a/xmtp_mls/src/storage/encrypted_store/group_message.rs b/xmtp_mls/src/storage/encrypted_store/group_message.rs index e35c2eef7..feee8197b 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_message.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_message.rs @@ -8,7 +8,7 @@ use diesel::{ sqlite::Sqlite, }; -use super::{schema::group_messages, DbConnection, EncryptedMessageStore}; +use super::{db_connection::DbConnection, schema::group_messages}; use crate::{impl_fetch, impl_store, StorageError}; #[derive(Insertable, Identifiable, Queryable, Debug, Clone, PartialEq, Eq)] @@ -68,10 +68,10 @@ where impl_fetch!(StoredGroupMessage, group_messages, Vec); impl_store!(StoredGroupMessage, group_messages); -impl EncryptedMessageStore { +impl DbConnection<'_> { /// Query for group messages pub fn get_group_messages>( - conn: &mut DbConnection, + &self, group_id: GroupId, sent_after_ns: Option, sent_before_ns: Option, @@ -100,19 +100,21 @@ impl EncryptedMessageStore { query = query.limit(limit); } - Ok(query.load::(conn)?) + Ok(self.raw_query(|conn| query.load::(conn))?) } /// Get a particular group message pub fn get_group_message>( + &self, id: MessageId, - conn: &mut DbConnection, ) -> Result, StorageError> { use super::schema::group_messages::dsl; - Ok(dsl::group_messages - .filter(dsl::id.eq(id.as_ref())) - .first(conn) - .optional()?) + Ok(self.raw_query(|conn| { + dsl::group_messages + .filter(dsl::id.eq(id.as_ref())) + .first(conn) + .optional() + })?) } } @@ -121,7 +123,7 @@ mod tests { use super::*; use crate::{ assert_err, assert_ok, - storage::encrypted_store::{group::tests::generate_group, tests::with_store}, + storage::encrypted_store::{group::tests::generate_group, tests::with_connection}, utils::test::{rand_time, rand_vec}, Store, }; @@ -144,26 +146,23 @@ mod tests { #[test] fn it_does_not_error_on_empty_messages() { - with_store(|mut conn| { + with_connection(|conn| { let id = vec![0x0]; - assert_ok!( - EncryptedMessageStore::get_group_message(&id, &mut conn), - None - ); + assert_ok!(conn.get_group_message(id), None); }) } #[test] fn it_gets_messages() { - with_store(|mut conn| { + with_connection(|conn| { let group = generate_group(None); let message = generate_message(None, Some(&group.id), None); - group.store(&mut conn).unwrap(); + group.store(conn).unwrap(); let id = message.id.clone(); - message.store(&mut conn).unwrap(); + message.store(conn).unwrap(); - let stored_message = EncryptedMessageStore::get_group_message(&id, &mut conn); + let stored_message = conn.get_group_message(id); assert_ok!(stored_message, Some(message)); }) } @@ -172,10 +171,10 @@ mod tests { fn it_cannot_insert_message_without_group() { use diesel::result::{DatabaseErrorKind::ForeignKeyViolation, Error::DatabaseError}; - with_store(|mut conn| { + with_connection(|conn| { let message = generate_message(None, None, None); assert_err!( - message.store(&mut conn), + message.store(conn), StorageError::DieselResult(DatabaseError(ForeignKeyViolation, _)) ); }) @@ -185,34 +184,36 @@ mod tests { fn it_gets_many_messages() { use crate::storage::encrypted_store::schema::group_messages::dsl; - with_store(|mut conn| { + with_connection(|conn| { let group = generate_group(None); - group.store(&mut conn).unwrap(); + group.store(conn).unwrap(); for _ in 0..50 { let msg = generate_message(None, Some(&group.id), None); - assert_ok!(msg.store(&mut conn)); + assert_ok!(msg.store(conn)); } - let count: i64 = dsl::group_messages - .select(diesel::dsl::count_star()) - .first(&mut conn) + let count: i64 = conn + .raw_query(|raw_conn| { + dsl::group_messages + .select(diesel::dsl::count_star()) + .first(raw_conn) + }) .unwrap(); assert_eq!(count, 50); - let messages = EncryptedMessageStore::get_group_messages( - &mut conn, &group.id, None, None, None, None, - ) - .unwrap(); + let messages = conn + .get_group_messages(&group.id, None, None, None, None) + .unwrap(); assert_eq!(messages.len(), 50); }) } #[test] fn it_gets_messages_by_time() { - with_store(|mut conn| { + with_connection(|conn| { let group = generate_group(None); - group.store(&mut conn).unwrap(); + group.store(conn).unwrap(); let messages = vec![ generate_message(None, Some(&group.id), Some(1_000)), @@ -220,48 +221,30 @@ mod tests { generate_message(None, Some(&group.id), Some(100_000)), generate_message(None, Some(&group.id), Some(1_000_000)), ]; - assert_ok!(messages.store(&mut conn)); - let message = EncryptedMessageStore::get_group_messages( - &mut conn, - &group.id, - Some(1_000), - Some(100_000), - None, - None, - ) - .unwrap(); + assert_ok!(messages.store(conn)); + let message = conn + .get_group_messages(&group.id, Some(1_000), Some(100_000), None, None) + .unwrap(); assert_eq!(message.len(), 1); assert_eq!(message.first().unwrap().sent_at_ns, 10_000); - let messages = EncryptedMessageStore::get_group_messages( - &mut conn, - &group.id, - None, - Some(100_000), - None, - None, - ) - .unwrap(); + let messages = conn + .get_group_messages(&group.id, None, Some(100_000), None, None) + .unwrap(); assert_eq!(messages.len(), 2); - let messages = EncryptedMessageStore::get_group_messages( - &mut conn, - &group.id, - Some(10_000), - None, - None, - None, - ) - .unwrap(); + let messages = conn + .get_group_messages(&group.id, Some(10_000), None, None, None) + .unwrap(); assert_eq!(messages.len(), 2); }) } #[test] fn it_gets_messages_by_kind() { - with_store(|mut conn| { + with_connection(|conn| { let group = generate_group(None); - group.store(&mut conn).unwrap(); + group.store(conn).unwrap(); // just a bunch of random messages so we have something to filter through for i in 0..30 { @@ -272,7 +255,7 @@ mod tests { Some(&group.id), None, ); - msg.store(&mut conn).unwrap(); + msg.store(conn).unwrap(); } 1 => { let msg = generate_message( @@ -280,50 +263,50 @@ mod tests { Some(&group.id), None, ); - msg.store(&mut conn).unwrap(); + msg.store(conn).unwrap(); } - 2 | _ => { + _ => { let msg = generate_message( Some(GroupMessageKind::MemberAdded), Some(&group.id), None, ); - msg.store(&mut conn).unwrap(); + msg.store(conn).unwrap(); } } } - let application_messages = EncryptedMessageStore::get_group_messages( - &mut conn, - &group.id, - None, - None, - Some(GroupMessageKind::Application), - None, - ) - .unwrap(); + let application_messages = conn + .get_group_messages( + &group.id, + None, + None, + Some(GroupMessageKind::Application), + None, + ) + .unwrap(); assert_eq!(application_messages.len(), 10); - let member_removed = EncryptedMessageStore::get_group_messages( - &mut conn, - &group.id, - None, - None, - Some(GroupMessageKind::MemberAdded), - None, - ) - .unwrap(); + let member_removed = conn + .get_group_messages( + &group.id, + None, + None, + Some(GroupMessageKind::MemberAdded), + None, + ) + .unwrap(); assert_eq!(member_removed.len(), 10); - let member_added = EncryptedMessageStore::get_group_messages( - &mut conn, - &group.id, - None, - None, - Some(GroupMessageKind::MemberRemoved), - None, - ) - .unwrap(); + let member_added = conn + .get_group_messages( + &group.id, + None, + None, + Some(GroupMessageKind::MemberRemoved), + None, + ) + .unwrap(); assert_eq!(member_added.len(), 10); }) } diff --git a/xmtp_mls/src/storage/encrypted_store/identity.rs b/xmtp_mls/src/storage/encrypted_store/identity.rs index 36a89a90d..2ef3d0dec 100644 --- a/xmtp_mls/src/storage/encrypted_store/identity.rs +++ b/xmtp_mls/src/storage/encrypted_store/identity.rs @@ -72,7 +72,7 @@ mod tests { EncryptedMessageStore::generate_enc_key(), ) .unwrap(); - let conn = &mut store.conn().unwrap(); + let conn = &store.conn().unwrap(); StoredIdentity::new("".to_string(), rand_vec(), rand_vec()) .store(conn) diff --git a/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs b/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs index ef484f885..eb0048f12 100644 --- a/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs +++ b/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs @@ -1,6 +1,6 @@ use diesel::prelude::*; -use super::{schema::openmls_key_store, DbConnection, EncryptedMessageStore, StorageError}; +use super::{db_connection::DbConnection, schema::openmls_key_store, StorageError}; use crate::{impl_fetch, impl_store, Delete}; #[derive(Insertable, Queryable, Debug, Clone)] @@ -14,17 +14,19 @@ pub struct StoredKeyStoreEntry { impl_fetch!(StoredKeyStoreEntry, openmls_key_store, Vec); impl_store!(StoredKeyStoreEntry, openmls_key_store); -impl Delete for DbConnection { +impl Delete for DbConnection<'_> { type Key = Vec; - fn delete(&mut self, key: Vec) -> Result where { + fn delete(&self, key: Vec) -> Result where { use super::schema::openmls_key_store::dsl::*; - Ok(diesel::delete(openmls_key_store.filter(key_bytes.eq(key))).execute(self)?) + Ok(self.raw_query(|conn| { + diesel::delete(openmls_key_store.filter(key_bytes.eq(key))).execute(conn) + })?) } } -impl EncryptedMessageStore { +impl DbConnection<'_> { pub fn insert_or_update_key_store_entry( - conn: &mut DbConnection, + &self, key: Vec, value: Vec, ) -> Result<(), StorageError> { @@ -34,9 +36,11 @@ impl EncryptedMessageStore { value_bytes: value, }; - diesel::replace_into(openmls_key_store) - .values(entry) - .execute(conn)?; + self.raw_query(|conn| { + diesel::replace_into(openmls_key_store) + .values(entry) + .execute(conn) + })?; Ok(()) } } diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index 6f8a9c349..fff68c864 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -10,6 +10,7 @@ //! table definitions `schema.rs` must also be updated. To generate the correct schemas you can run //! `diesel print-schema` or use `cargo run update-schema` which will update the files for you. +pub mod db_connection; pub mod group; pub mod group_intent; pub mod group_message; @@ -31,12 +32,14 @@ use log::warn; use rand::RngCore; use xmtp_cryptography::utils as crypto_utils; +use self::db_connection::DbConnection; + use super::StorageError; -use crate::Store; +use crate::{xmtp_openmls_provider::XmtpOpenMlsProvider, Store}; pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations/"); -pub type DbConnection = PooledConnection>; +pub type RawDbConnection = PooledConnection>; pub type EncryptionKey = [u8; 32]; @@ -115,7 +118,7 @@ impl EncryptedMessageStore { } fn init_db(&mut self) -> Result<(), StorageError> { - let conn = &mut self.conn()?; + let conn = &mut self.raw_conn()?; conn.run_pending_migrations(MIGRATIONS) .map_err(|e| StorageError::DbInit(e.to_string()))?; @@ -123,7 +126,7 @@ impl EncryptedMessageStore { Ok(()) } - pub fn conn( + fn raw_conn( &self, ) -> Result>, StorageError> { let conn = self @@ -134,6 +137,37 @@ impl EncryptedMessageStore { Ok(conn) } + pub fn conn(&self) -> Result { + let conn = self.raw_conn()?; + Ok(DbConnection::held(conn)) + } + + /// Start a new database transaction with the OpenMLS Provider from XMTP + /// # Arguments + /// `fun`: Scoped closure providing a MLSProvider to carry out the transaction + /// + /// # Examples + /// + /// ```ignore + /// store.transaction(|provider| { + /// // do some operations requiring provider + /// // access the connection with .conn() + /// provider.conn().db_operation()?; + /// }) + /// ``` + pub fn transaction(&self, fun: F) -> Result + where + F: FnOnce(XmtpOpenMlsProvider) -> Result, + E: From + From, + { + let mut connection = self.raw_conn()?; + connection.transaction(|conn| { + let db_connection = DbConnection::new(conn); + let provider = XmtpOpenMlsProvider::new(&db_connection); + fun(provider) + }) + } + fn set_sqlcipher_key( pool: Pool>, encryption_key: &[u8; 32], @@ -170,21 +204,25 @@ fn warn_length(list: &Vec, str_id: &str, max_length: usize) { #[macro_export] macro_rules! impl_fetch { ($model:ty, $table:ident) => { - impl $crate::Fetch<$model> for $crate::storage::encrypted_store::DbConnection { + impl $crate::Fetch<$model> + for $crate::storage::encrypted_store::db_connection::DbConnection<'_> + { type Key = (); - fn fetch(&mut self, _key: &Self::Key) -> Result, $crate::StorageError> { + fn fetch(&self, _key: &Self::Key) -> Result, $crate::StorageError> { use $crate::storage::encrypted_store::schema::$table::dsl::*; - Ok($table.first(self).optional()?) + Ok(self.raw_query(|conn| $table.first(conn).optional())?) } } }; ($model:ty, $table:ident, $key:ty) => { - impl $crate::Fetch<$model> for $crate::storage::encrypted_store::DbConnection { + impl $crate::Fetch<$model> + for $crate::storage::encrypted_store::db_connection::DbConnection<'_> + { type Key = $key; - fn fetch(&mut self, key: &Self::Key) -> Result, $crate::StorageError> { + fn fetch(&self, key: &Self::Key) -> Result, $crate::StorageError> { use $crate::storage::encrypted_store::schema::$table::dsl::*; - Ok($table.find(key).first(self).optional()?) + Ok(self.raw_query(|conn| $table.find(key).first(conn).optional())?) } } }; @@ -193,26 +231,29 @@ macro_rules! impl_fetch { #[macro_export] macro_rules! impl_store { ($model:ty, $table:ident) => { - impl $crate::Store<$crate::storage::encrypted_store::DbConnection> for $model { + impl $crate::Store<$crate::storage::encrypted_store::db_connection::DbConnection<'_>> + for $model + { fn store( &self, - into: &mut $crate::storage::encrypted_store::DbConnection, + into: &$crate::storage::encrypted_store::db_connection::DbConnection<'_>, ) -> Result<(), $crate::StorageError> { - diesel::insert_into($table::table) - .values(self) - .execute(into) - .map_err(|e| $crate::StorageError::from(e))?; + into.raw_query(|conn| { + diesel::insert_into($table::table) + .values(self) + .execute(conn) + })?; Ok(()) } } }; } -impl Store for Vec +impl<'a, T> Store> for Vec where - T: Store, + T: Store>, { - fn store(&self, into: &mut DbConnection) -> Result<(), StorageError> { + fn store(&self, into: &DbConnection<'a>) -> Result<(), StorageError> { for item in self { item.store(into)?; } @@ -224,16 +265,19 @@ where mod tests { use std::{boxed::Box, fs}; - use super::{identity::StoredIdentity, EncryptedMessageStore, StorageError, StorageOption}; + use super::{ + db_connection::DbConnection, identity::StoredIdentity, EncryptedMessageStore, StorageError, + StorageOption, + }; use crate::{ utils::test::{rand_vec, tmp_path}, Fetch, Store, }; /// Test harness that loads an Ephemeral store. - pub fn with_store(fun: F) -> R + pub fn with_connection(fun: F) -> R where - F: FnOnce(super::DbConnection) -> R, + F: FnOnce(&DbConnection) -> R, { crate::tests::setup(); let store = EncryptedMessageStore::new( @@ -241,7 +285,7 @@ mod tests { EncryptedMessageStore::generate_enc_key(), ) .unwrap(); - let conn = store.conn().expect("acquiring a Connection failed"); + let conn = &store.conn().expect("acquiring a Connection failed"); fun(conn) } @@ -263,7 +307,7 @@ mod tests { EncryptedMessageStore::generate_enc_key(), ) .unwrap(); - let conn = &mut store.conn().unwrap(); + let conn = &store.conn().unwrap(); let account_address = "address"; StoredIdentity::new(account_address.to_string(), rand_vec(), rand_vec()) @@ -283,7 +327,7 @@ mod tests { EncryptedMessageStore::generate_enc_key(), ) .unwrap(); - let conn = &mut store.conn().unwrap(); + let conn = &store.conn().unwrap(); let account_address = "address"; StoredIdentity::new(account_address.to_string(), rand_vec(), rand_vec()) @@ -309,7 +353,7 @@ mod tests { .unwrap(); StoredIdentity::new("dummy_address".to_string(), rand_vec(), rand_vec()) - .store(&mut store.conn().unwrap()) + .store(&store.conn().unwrap()) .unwrap(); } // Drop it diff --git a/xmtp_mls/src/storage/encrypted_store/topic_refresh_state.rs b/xmtp_mls/src/storage/encrypted_store/topic_refresh_state.rs index 51c4f46c6..612826d74 100644 --- a/xmtp_mls/src/storage/encrypted_store/topic_refresh_state.rs +++ b/xmtp_mls/src/storage/encrypted_store/topic_refresh_state.rs @@ -1,6 +1,6 @@ use diesel::prelude::*; -use super::{schema::topic_refresh_state, DbConnection, EncryptedMessageStore}; +use super::{db_connection::DbConnection, schema::topic_refresh_state}; use crate::{impl_fetch, impl_store, storage::StorageError, Fetch, Store}; #[derive(Insertable, Identifiable, Queryable, Debug, Clone)] @@ -14,12 +14,9 @@ pub struct TopicRefreshState { impl_fetch!(TopicRefreshState, topic_refresh_state, String); impl_store!(TopicRefreshState, topic_refresh_state); -impl EncryptedMessageStore { - pub fn get_last_synced_timestamp_for_topic( - conn: &mut DbConnection, - topic: &str, - ) -> Result { - let state: Option = conn.fetch(&topic.to_string())?; +impl DbConnection<'_> { + pub fn get_last_synced_timestamp_for_topic(&self, topic: &str) -> Result { + let state: Option = self.fetch(&topic.to_string())?; match state { Some(state) => Ok(state.last_message_timestamp_ns), None => { @@ -27,25 +24,27 @@ impl EncryptedMessageStore { topic: topic.to_string(), last_message_timestamp_ns: 0, }; - new_state.store(conn)?; + new_state.store(self)?; Ok(0) } } } pub fn update_last_synced_timestamp_for_topic( - conn: &mut DbConnection, + &self, topic: &str, timestamp_ns: i64, ) -> Result { - let state: Option = conn.fetch(&topic.to_string())?; + let state: Option = self.fetch(&topic.to_string())?; match state { Some(state) => { use super::schema::topic_refresh_state::dsl; - let num_updated = diesel::update(&state) - .filter(dsl::last_message_timestamp_ns.lt(timestamp_ns)) - .set(dsl::last_message_timestamp_ns.eq(timestamp_ns)) - .execute(conn)?; + let num_updated = self.raw_query(|conn| { + diesel::update(&state) + .filter(dsl::last_message_timestamp_ns.lt(timestamp_ns)) + .set(dsl::last_message_timestamp_ns.eq(timestamp_ns)) + .execute(conn) + })?; Ok(num_updated == 1) } None => Err(StorageError::NotFound), @@ -56,16 +55,15 @@ impl EncryptedMessageStore { #[cfg(test)] pub(crate) mod tests { use super::*; - use crate::{storage::encrypted_store::tests::with_store, Fetch, Store}; + use crate::{storage::encrypted_store::tests::with_connection, Fetch, Store}; #[test] fn get_timestamp_with_no_existing_state() { - with_store(|mut conn| { + with_connection(|conn| { let entry: Option = conn.fetch(&"topic".to_string()).unwrap(); assert!(entry.is_none()); assert_eq!( - EncryptedMessageStore::get_last_synced_timestamp_for_topic(&mut conn, "topic") - .unwrap(), + conn.get_last_synced_timestamp_for_topic("topic").unwrap(), 0 ); let entry: Option = conn.fetch(&"topic".to_string()).unwrap(); @@ -75,15 +73,14 @@ pub(crate) mod tests { #[test] fn get_timestamp_with_existing_state() { - with_store(|mut conn| { + with_connection(|conn| { let entry = TopicRefreshState { topic: "topic".to_string(), last_message_timestamp_ns: 123, }; - entry.store(&mut conn).unwrap(); + entry.store(conn).unwrap(); assert_eq!( - EncryptedMessageStore::get_last_synced_timestamp_for_topic(&mut conn, "topic") - .unwrap(), + conn.get_last_synced_timestamp_for_topic("topic").unwrap(), 123 ); }) @@ -91,18 +88,15 @@ pub(crate) mod tests { #[test] fn update_timestamp_when_bigger() { - with_store(|mut conn| { + with_connection(|conn| { let entry = TopicRefreshState { topic: "topic".to_string(), last_message_timestamp_ns: 123, }; - entry.store(&mut conn).unwrap(); - assert!( - EncryptedMessageStore::update_last_synced_timestamp_for_topic( - &mut conn, "topic", 124 - ) - .unwrap() - ); + entry.store(conn).unwrap(); + assert!(conn + .update_last_synced_timestamp_for_topic("topic", 124) + .unwrap()); let entry: Option = conn.fetch(&"topic".to_string()).unwrap(); assert_eq!(entry.unwrap().last_message_timestamp_ns, 124); }) @@ -110,18 +104,15 @@ pub(crate) mod tests { #[test] fn dont_update_timestamp_when_smaller() { - with_store(|mut conn| { + with_connection(|conn| { let entry = TopicRefreshState { topic: "topic".to_string(), last_message_timestamp_ns: 123, }; - entry.store(&mut conn).unwrap(); - assert!( - !EncryptedMessageStore::update_last_synced_timestamp_for_topic( - &mut conn, "topic", 122 - ) - .unwrap() - ); + entry.store(conn).unwrap(); + assert!(!conn + .update_last_synced_timestamp_for_topic("topic", 122) + .unwrap()); let entry: Option = conn.fetch(&"topic".to_string()).unwrap(); assert_eq!(entry.unwrap().last_message_timestamp_ns, 123); }) diff --git a/xmtp_mls/src/storage/mod.rs b/xmtp_mls/src/storage/mod.rs index 85d7fc25a..72e9cd63a 100644 --- a/xmtp_mls/src/storage/mod.rs +++ b/xmtp_mls/src/storage/mod.rs @@ -4,7 +4,7 @@ mod serialization; pub mod sql_key_store; pub use encrypted_store::{ - group, group_intent, group_message, identity, key_store_entry, topic_refresh_state, - DbConnection, EncryptedMessageStore, EncryptionKey, StorageOption, + db_connection, group, group_intent, group_message, identity, key_store_entry, + topic_refresh_state, EncryptedMessageStore, EncryptionKey, RawDbConnection, StorageOption, }; pub use errors::StorageError; diff --git a/xmtp_mls/src/storage/sql_key_store.rs b/xmtp_mls/src/storage/sql_key_store.rs index 447d0e986..057f399a9 100644 --- a/xmtp_mls/src/storage/sql_key_store.rs +++ b/xmtp_mls/src/storage/sql_key_store.rs @@ -1,34 +1,26 @@ use log::{debug, error}; use openmls_traits::key_store::{MlsEntity, OpenMlsKeyStore}; -use std::{cell::RefCell, fmt}; use super::{ - encrypted_store::{key_store_entry::StoredKeyStoreEntry, DbConnection}, + encrypted_store::{db_connection::DbConnection, key_store_entry::StoredKeyStoreEntry}, serialization::{db_deserialize, db_serialize}, - EncryptedMessageStore, StorageError, + StorageError, }; use crate::{Delete, Fetch}; /// CRUD Operations for an [`EncryptedMessageStore`] +#[derive(Debug)] pub struct SqlKeyStore<'a> { - pub conn: RefCell<&'a mut DbConnection>, + conn: &'a DbConnection<'a>, } impl<'a> SqlKeyStore<'a> { - pub fn new(conn: &'a mut DbConnection) -> Self { - Self { conn: conn.into() } + pub fn new(conn: &'a DbConnection<'a>) -> Self { + Self { conn } } - pub fn conn(&self) -> &RefCell<&'a mut DbConnection> { - &self.conn - } -} - -impl fmt::Debug for SqlKeyStore<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("SqlKeyStore") - .field("conn", &"DbConnection") - .finish() + pub fn conn(&self) -> &DbConnection<'a> { + self.conn } } @@ -41,11 +33,8 @@ impl OpenMlsKeyStore for SqlKeyStore<'_> { /// /// Returns an error if storing fails. fn store(&self, k: &[u8], v: &V) -> Result<(), Self::Error> { - EncryptedMessageStore::insert_or_update_key_store_entry( - *self.conn.borrow_mut(), - k.to_vec(), - db_serialize(v)?, - )?; + self.conn() + .insert_or_update_key_store_entry(k.to_vec(), db_serialize(v)?)?; Ok(()) } @@ -54,7 +43,7 @@ impl OpenMlsKeyStore for SqlKeyStore<'_> { /// /// Returns [`None`] if no value is stored for `k` or reading fails. fn read(&self, k: &[u8]) -> Option { - let fetch_result = (*self.conn.borrow_mut()).fetch(&k.to_vec()); + let fetch_result = self.conn().fetch(&k.to_vec()); if let Err(e) = fetch_result { error!("Failed to fetch key: {:?}", e); @@ -74,8 +63,7 @@ impl OpenMlsKeyStore for SqlKeyStore<'_> { /// Interface is unclear on expected behavior when item is already deleted - /// we choose to not surface an error if this is the case. fn delete(&self, k: &[u8]) -> Result<(), Self::Error> { - let mut conn = self.conn.borrow_mut(); - let conn: &mut dyn Delete> = *conn; + let conn: &dyn Delete> = self.conn(); let num_deleted = conn.delete(k.to_vec())?; if num_deleted == 0 { debug!("No entry to delete for key {:?}", k); @@ -104,8 +92,8 @@ mod tests { EncryptedMessageStore::generate_enc_key(), ) .unwrap(); - let mut conn = store.conn().unwrap(); - let key_store = SqlKeyStore::new(&mut conn); + let conn = &store.conn().unwrap(); + let key_store = SqlKeyStore::new(conn); let signature_keys = SignatureKeyPair::new(CIPHERSUITE.signature_algorithm()).unwrap(); let index = "index".as_bytes(); assert!(key_store.read::(index).is_none()); diff --git a/xmtp_mls/src/xmtp_openmls_provider.rs b/xmtp_mls/src/xmtp_openmls_provider.rs index 1f5599c3d..a0cc26934 100644 --- a/xmtp_mls/src/xmtp_openmls_provider.rs +++ b/xmtp_mls/src/xmtp_openmls_provider.rs @@ -1,10 +1,7 @@ -use std::cell::RefCell; - -use diesel::Connection; use openmls_rust_crypto::RustCrypto; use openmls_traits::OpenMlsProvider; -use crate::storage::{sql_key_store::SqlKeyStore, DbConnection}; +use crate::storage::{db_connection::DbConnection, sql_key_store::SqlKeyStore}; #[derive(Debug)] pub struct XmtpOpenMlsProvider<'a> { @@ -13,44 +10,16 @@ pub struct XmtpOpenMlsProvider<'a> { } impl<'a> XmtpOpenMlsProvider<'a> { - pub fn new(conn: &'a mut DbConnection) -> Self { + pub fn new(conn: &'a DbConnection<'a>) -> Self { Self { crypto: RustCrypto::default(), key_store: SqlKeyStore::new(conn), } } - pub(crate) fn conn(&self) -> &RefCell<&'a mut DbConnection> { + pub(crate) fn conn(&self) -> &DbConnection<'a> { self.key_store.conn() } - - /// Start a new database transaction with the OpenMLS Provider from XMTP - /// # Arguments - /// `fun`: Scoped closure providing a MLSProvider to carry out the transaction - /// - /// # Examples - /// - /// ```ignore - /// let connection = EncryptedMessageStore::new_unencrypted(StorageOptions::default()); - /// XmtpOpenMlsProvider::transaction(conn, |provider| { - /// // do some operations requiring provider - /// // access the connection with .conn() - /// // wrap in a block so that the borrow is ended - /// { - /// provider.conn().borrow_mut() - /// } - /// }) - /// ``` - pub fn transaction(connection: &mut DbConnection, fun: F) -> Result - where - F: FnOnce(XmtpOpenMlsProvider) -> Result, - E: From, - { - connection.transaction(|conn| { - let provider = XmtpOpenMlsProvider::new(conn); - fun(provider) - }) - } } impl<'a> OpenMlsProvider for XmtpOpenMlsProvider<'a> { diff --git a/xmtp_proto/src/gen/xmtp.message_contents.serde.rs b/xmtp_proto/src/gen/xmtp.message_contents.serde.rs index 3042dc548..7eb3d12f9 100644 --- a/xmtp_proto/src/gen/xmtp.message_contents.serde.rs +++ b/xmtp_proto/src/gen/xmtp.message_contents.serde.rs @@ -1326,7 +1326,7 @@ impl serde::Serialize for EciesMessage { if let Some(v) = self.version.as_ref() { match v { ecies_message::Version::V1(v) => { - struct_ser.serialize_field("v1", pbjson::private::base64::encode(&v).as_str())?; + struct_ser.serialize_field("v1", pbjson::private::base64::encode(v).as_str())?; } } }