Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Safe DB connection sharing #351

Merged
merged 12 commits into from
Nov 29, 2023
1 change: 1 addition & 0 deletions xmtp/src/test_utils.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#[allow(clippy::module_inception)]
#[cfg(test)]
pub mod test_utils {
use xmtp_proto::api_client::XmtpApiClient;
Expand Down
4 changes: 2 additions & 2 deletions xmtp_mls/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ where
.store
.take()
.ok_or(ClientBuilderError::MissingParameter { parameter: "store" })?;
let mut conn = store.conn()?;
let provider = XmtpOpenMlsProvider::new(&mut conn);
let conn = store.conn()?;
let provider = XmtpOpenMlsProvider::new(&conn);
let identity = self
.identity_strategy
.initialize_identity(&store, &provider)?;
Expand Down
49 changes: 21 additions & 28 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ use crate::{
identity::Identity,
retry::Retry,
storage::{
db_connection::DbConnection,
group::{GroupMembershipState, StoredGroup},
DbConnection, EncryptedMessageStore, StorageError,
EncryptedMessageStore, StorageError,
},
types::Address,
utils::topic::get_welcome_topic,
Expand Down Expand Up @@ -147,8 +148,8 @@ where
}

// TODO: Remove this and figure out the correct lifetimes to allow long lived provider
pub(crate) fn mls_provider(&self, conn: &'a mut DbConnection) -> XmtpOpenMlsProvider<'a> {
XmtpOpenMlsProvider::new(conn)
pub(crate) fn mls_provider(&self, conn: &'a DbConnection<'a>) -> XmtpOpenMlsProvider<'a> {
XmtpOpenMlsProvider::<'a>::new(conn)
}

pub fn create_group(&self) -> Result<MlsGroup<ApiClient>, ClientError> {
Expand All @@ -174,24 +175,21 @@ where
created_before_ns: Option<i64>,
limit: Option<i64>,
) -> Result<Vec<MlsGroup<ApiClient>>, ClientError> {
Ok(EncryptedMessageStore::find_groups(
&mut self.store.conn()?,
allowed_states,
created_after_ns,
created_before_ns,
limit,
)?
.into_iter()
.map(|stored_group| MlsGroup::new(self, stored_group.id, stored_group.created_at_ns))
.collect())
Ok(self
.store
.conn()?
.find_groups(allowed_states, created_after_ns, created_before_ns, limit)?
.into_iter()
.map(|stored_group| MlsGroup::new(self, stored_group.id, stored_group.created_at_ns))
.collect())
}

pub async fn register_identity(&self) -> Result<(), ClientError> {
// TODO: Mark key package as last_resort in creation
let mut connection = self.store.conn()?;
let connection = self.store.conn()?;
let last_resort_kp = self
.identity
.new_key_package(&self.mls_provider(&mut connection))?;
.new_key_package(&self.mls_provider(&connection))?;
let last_resort_kp_bytes = last_resort_kp.tls_serialize_detached()?;

self.api_client
Expand Down Expand Up @@ -233,9 +231,8 @@ where
}

pub(crate) async fn pull_from_topic(&self, topic: &str) -> Result<Vec<Envelope>, ClientError> {
let mut conn = self.store.conn()?;
let last_synced_timestamp_ns =
EncryptedMessageStore::get_last_synced_timestamp_for_topic(&mut conn, topic)?;
let conn = self.store.conn()?;
let last_synced_timestamp_ns = conn.get_last_synced_timestamp_for_topic(topic)?;

let envelopes = self
.api_client
Expand All @@ -261,14 +258,10 @@ where
where
ProcessingFn: FnOnce(XmtpOpenMlsProvider) -> Result<ReturnValue, MessageProcessingError>,
{
XmtpOpenMlsProvider::transaction(&mut self.store.conn()?, |provider| {
let is_updated = {
EncryptedMessageStore::update_last_synced_timestamp_for_topic(
&mut provider.conn().borrow_mut(),
topic,
envelope_timestamp_ns as i64,
)?
};
self.store.transaction(|provider| {
let is_updated = provider
.conn()
.update_last_synced_timestamp_for_topic(topic, envelope_timestamp_ns as i64)?;
if !is_updated {
return Err(MessageProcessingError::AlreadyProcessed(
envelope_timestamp_ns,
Expand Down Expand Up @@ -302,12 +295,12 @@ where
.consume_key_packages(installation_ids)
.await?;

let mut conn = self.store.conn()?;
let conn = self.store.conn()?;

Ok(key_package_results
.values()
.map(|bytes| {
VerifiedKeyPackage::from_bytes(&self.mls_provider(&mut conn), bytes.as_slice())
VerifiedKeyPackage::from_bytes(&self.mls_provider(&conn), bytes.as_slice())
})
.collect::<Result<_, _>>()?)
}
Expand Down
2 changes: 1 addition & 1 deletion xmtp_mls/src/codecs/membership_change.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ mod tests {
encoded.clone().r#type.unwrap().type_id,
"group_membership_change"
);
assert!(encoded.content.len() > 0);
assert!(!encoded.content.is_empty());

let decoded = GroupMembershipChangeCodec::decode(encoded).unwrap();
assert_eq!(decoded.members_added[0], new_member);
Expand Down
6 changes: 3 additions & 3 deletions xmtp_mls/src/groups/intents.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,17 +259,17 @@ mod tests {
let wallet = generate_local_wallet();
let wallet_address = wallet.get_address();
let client = ClientBuilder::new_test_client(wallet.into()).await;
let mut conn = client.store.conn().unwrap();
let conn = client.store.conn().unwrap();
let key_package = client
.identity
.new_key_package(&client.mls_provider(&mut conn))
.new_key_package(&client.mls_provider(&conn))
.unwrap();
let verified_key_package = VerifiedKeyPackage::new(key_package, wallet_address.clone());

let intent = AddMembersIntentData::new(vec![verified_key_package.clone()]);
let as_bytes: Vec<u8> = intent.clone().try_into().unwrap();
let restored_intent =
AddMembersIntentData::from_bytes(as_bytes.as_slice(), &client.mls_provider(&mut conn))
AddMembersIntentData::from_bytes(as_bytes.as_slice(), &client.mls_provider(&conn))
.unwrap();

assert!(intent.key_packages[0]
Expand Down
2 changes: 1 addition & 1 deletion xmtp_mls/src/groups/members.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ where
// Load the member list for the group from the DB, merging together multiple installations into a single entry
pub fn members(&self) -> Result<Vec<GroupMember>, GroupError> {
let openmls_group =
self.load_mls_group(&self.client.mls_provider(&mut self.client.store.conn()?))?;
self.load_mls_group(&self.client.mls_provider(&self.client.store.conn()?))?;

let member_map: HashMap<String, GroupMember> = openmls_group
.members()
Expand Down
Loading
Loading