Skip to content

Commit

Permalink
Safe DB connection sharing (#351)
Browse files Browse the repository at this point in the history
This refactor aims to do the following:
1. Continue allow sharing DB connections between libxmtp logic and the OpenMLS keystore, including via transactions
2. Prevent errors arising from concurrent mutable access to this connection (whether at compile-time via multiple `&mut` refs or run-time via multiple `RefCell` borrows)
3. Simplify the DB interface

This is done by:
1. Create an `DbConnection` struct that wraps the `RawDbConnection` in a `RefCell`. The `DbConnection` uses interior mutability, therefore a non-mut reference to it can be freely shared as many times as desired.
2. Move all DB operations, as well as `fetch` and `store` onto this struct. Additionally, add a method to this struct that allows external callers to execute raw queries via a closure. This ensures that the `RawDbConnection` is only ever accessed internally within `DbConnection`, and uses function scope to make sure that borrows on the `RefCell` are always returned before they are used again.
3. Use visibility to ensure that the `RawDbConnection` is completely inaccessible outside of `EncryptedMessageStore` and `DbConnection`. It is only possible for external callers to interact with the database via `DbConnection`, and all references to `RawDbConnection` in the code have been replaced by `DbConnection`.
  • Loading branch information
richardhuaaa authored Nov 29, 2023
1 parent 6626895 commit 099a525
Show file tree
Hide file tree
Showing 22 changed files with 533 additions and 566 deletions.
1 change: 1 addition & 0 deletions xmtp/src/test_utils.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#[allow(clippy::module_inception)]
#[cfg(test)]
pub mod test_utils {
use xmtp_proto::api_client::XmtpApiClient;
Expand Down
4 changes: 2 additions & 2 deletions xmtp_mls/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ where
.store
.take()
.ok_or(ClientBuilderError::MissingParameter { parameter: "store" })?;
let mut conn = store.conn()?;
let provider = XmtpOpenMlsProvider::new(&mut conn);
let conn = store.conn()?;
let provider = XmtpOpenMlsProvider::new(&conn);
let identity = self
.identity_strategy
.initialize_identity(&store, &provider)?;
Expand Down
49 changes: 21 additions & 28 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ use crate::{
identity::Identity,
retry::Retry,
storage::{
db_connection::DbConnection,
group::{GroupMembershipState, StoredGroup},
DbConnection, EncryptedMessageStore, StorageError,
EncryptedMessageStore, StorageError,
},
types::Address,
utils::topic::get_welcome_topic,
Expand Down Expand Up @@ -147,8 +148,8 @@ where
}

// TODO: Remove this and figure out the correct lifetimes to allow long lived provider
pub(crate) fn mls_provider(&self, conn: &'a mut DbConnection) -> XmtpOpenMlsProvider<'a> {
XmtpOpenMlsProvider::new(conn)
pub(crate) fn mls_provider(&self, conn: &'a DbConnection<'a>) -> XmtpOpenMlsProvider<'a> {
XmtpOpenMlsProvider::<'a>::new(conn)
}

pub fn create_group(&self) -> Result<MlsGroup<ApiClient>, ClientError> {
Expand All @@ -174,24 +175,21 @@ where
created_before_ns: Option<i64>,
limit: Option<i64>,
) -> Result<Vec<MlsGroup<ApiClient>>, ClientError> {
Ok(EncryptedMessageStore::find_groups(
&mut self.store.conn()?,
allowed_states,
created_after_ns,
created_before_ns,
limit,
)?
.into_iter()
.map(|stored_group| MlsGroup::new(self, stored_group.id, stored_group.created_at_ns))
.collect())
Ok(self
.store
.conn()?
.find_groups(allowed_states, created_after_ns, created_before_ns, limit)?
.into_iter()
.map(|stored_group| MlsGroup::new(self, stored_group.id, stored_group.created_at_ns))
.collect())
}

pub async fn register_identity(&self) -> Result<(), ClientError> {
// TODO: Mark key package as last_resort in creation
let mut connection = self.store.conn()?;
let connection = self.store.conn()?;
let last_resort_kp = self
.identity
.new_key_package(&self.mls_provider(&mut connection))?;
.new_key_package(&self.mls_provider(&connection))?;
let last_resort_kp_bytes = last_resort_kp.tls_serialize_detached()?;

self.api_client
Expand Down Expand Up @@ -233,9 +231,8 @@ where
}

pub(crate) async fn pull_from_topic(&self, topic: &str) -> Result<Vec<Envelope>, ClientError> {
let mut conn = self.store.conn()?;
let last_synced_timestamp_ns =
EncryptedMessageStore::get_last_synced_timestamp_for_topic(&mut conn, topic)?;
let conn = self.store.conn()?;
let last_synced_timestamp_ns = conn.get_last_synced_timestamp_for_topic(topic)?;

let envelopes = self
.api_client
Expand All @@ -261,14 +258,10 @@ where
where
ProcessingFn: FnOnce(XmtpOpenMlsProvider) -> Result<ReturnValue, MessageProcessingError>,
{
XmtpOpenMlsProvider::transaction(&mut self.store.conn()?, |provider| {
let is_updated = {
EncryptedMessageStore::update_last_synced_timestamp_for_topic(
&mut provider.conn().borrow_mut(),
topic,
envelope_timestamp_ns as i64,
)?
};
self.store.transaction(|provider| {
let is_updated = provider
.conn()
.update_last_synced_timestamp_for_topic(topic, envelope_timestamp_ns as i64)?;
if !is_updated {
return Err(MessageProcessingError::AlreadyProcessed(
envelope_timestamp_ns,
Expand Down Expand Up @@ -302,12 +295,12 @@ where
.consume_key_packages(installation_ids)
.await?;

let mut conn = self.store.conn()?;
let conn = self.store.conn()?;

Ok(key_package_results
.values()
.map(|bytes| {
VerifiedKeyPackage::from_bytes(&self.mls_provider(&mut conn), bytes.as_slice())
VerifiedKeyPackage::from_bytes(&self.mls_provider(&conn), bytes.as_slice())
})
.collect::<Result<_, _>>()?)
}
Expand Down
2 changes: 1 addition & 1 deletion xmtp_mls/src/codecs/membership_change.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ mod tests {
encoded.clone().r#type.unwrap().type_id,
"group_membership_change"
);
assert!(encoded.content.len() > 0);
assert!(!encoded.content.is_empty());

let decoded = GroupMembershipChangeCodec::decode(encoded).unwrap();
assert_eq!(decoded.members_added[0], new_member);
Expand Down
6 changes: 3 additions & 3 deletions xmtp_mls/src/groups/intents.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,17 +259,17 @@ mod tests {
let wallet = generate_local_wallet();
let wallet_address = wallet.get_address();
let client = ClientBuilder::new_test_client(wallet.into()).await;
let mut conn = client.store.conn().unwrap();
let conn = client.store.conn().unwrap();
let key_package = client
.identity
.new_key_package(&client.mls_provider(&mut conn))
.new_key_package(&client.mls_provider(&conn))
.unwrap();
let verified_key_package = VerifiedKeyPackage::new(key_package, wallet_address.clone());

let intent = AddMembersIntentData::new(vec![verified_key_package.clone()]);
let as_bytes: Vec<u8> = intent.clone().try_into().unwrap();
let restored_intent =
AddMembersIntentData::from_bytes(as_bytes.as_slice(), &client.mls_provider(&mut conn))
AddMembersIntentData::from_bytes(as_bytes.as_slice(), &client.mls_provider(&conn))
.unwrap();

assert!(intent.key_packages[0]
Expand Down
2 changes: 1 addition & 1 deletion xmtp_mls/src/groups/members.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ where
// Load the member list for the group from the DB, merging together multiple installations into a single entry
pub fn members(&self) -> Result<Vec<GroupMember>, GroupError> {
let openmls_group =
self.load_mls_group(&self.client.mls_provider(&mut self.client.store.conn()?))?;
self.load_mls_group(&self.client.mls_provider(&self.client.store.conn()?))?;

let member_map: HashMap<String, GroupMember> = openmls_group
.members()
Expand Down
Loading

0 comments on commit 099a525

Please sign in to comment.