Skip to content

Commit

Permalink
Transactional MlsProvider (#327)
Browse files Browse the repository at this point in the history
  • Loading branch information
insipx authored Nov 15, 2023
1 parent 2b3e0c3 commit b0f00f9
Show file tree
Hide file tree
Showing 13 changed files with 1,115 additions and 1,086 deletions.
1,539 changes: 764 additions & 775 deletions Cargo.lock

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions xmtp_mls/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ pub enum IdentityStrategy<Owner> {
ExternalIdentity(Identity),
}

impl<Owner> IdentityStrategy<Owner>
impl<'a, Owner> IdentityStrategy<Owner>
where
Owner: InboxOwner,
{
fn initialize_identity(
self,
store: &EncryptedMessageStore,
provider: &XmtpOpenMlsProvider,
provider: &'a XmtpOpenMlsProvider,
) -> Result<Identity, ClientBuilderError> {
let identity_option: Option<Identity> =
store.conn()?.fetch(&())?.map(|i: StoredIdentity| i.into());
Expand All @@ -68,7 +68,7 @@ where
}
Ok(identity)
}
None => Ok(Identity::new(store, provider, &owner)?),
None => Ok(Identity::new(provider, &owner)?),
},
#[cfg(test)]
IdentityStrategy::ExternalIdentity(identity) => Ok(identity),
Expand Down Expand Up @@ -140,7 +140,8 @@ where
.store
.take()
.ok_or(ClientBuilderError::MissingParameter { parameter: "store" })?;
let provider = XmtpOpenMlsProvider::new(&store);
let mut conn = store.conn()?;
let provider = XmtpOpenMlsProvider::new(&mut conn);
let identity = self
.identity_strategy
.initialize_identity(&store, &provider)?;
Expand Down
81 changes: 44 additions & 37 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
api_client_wrapper::{ApiClientWrapper, IdentityUpdate},
groups::MlsGroup,
identity::Identity,
storage::{group::GroupMembershipState, EncryptedMessageStore, StorageError},
storage::{group::GroupMembershipState, DbConnection, EncryptedMessageStore, StorageError},
types::Address,
utils::topic::get_welcome_topic,
verified_key_package::{KeyPackageVerificationError, VerifiedKeyPackage},
Expand Down Expand Up @@ -70,7 +70,7 @@ pub struct Client<ApiClient> {
pub(crate) store: EncryptedMessageStore,
}

impl<ApiClient> Client<ApiClient>
impl<'a, ApiClient> Client<ApiClient>
where
ApiClient: XmtpMlsClient + XmtpApiClient,
{
Expand All @@ -89,8 +89,8 @@ where
}

// TODO: Remove this and figure out the correct lifetimes to allow long lived provider
pub fn mls_provider(&self) -> XmtpOpenMlsProvider {
XmtpOpenMlsProvider::new(&self.store)
pub fn mls_provider(&self, conn: &'a mut DbConnection) -> XmtpOpenMlsProvider<'a> {
XmtpOpenMlsProvider::new(conn)
}

pub fn create_group(&self) -> Result<MlsGroup<ApiClient>, ClientError> {
Expand All @@ -107,23 +107,24 @@ where
created_before_ns: Option<i64>,
limit: Option<i64>,
) -> Result<Vec<MlsGroup<ApiClient>>, ClientError> {
Ok(self
.store
.find_groups(
&mut self.store.conn()?,
allowed_states,
created_after_ns,
created_before_ns,
limit,
)?
.into_iter()
.map(|stored_group| MlsGroup::new(self, stored_group.id, stored_group.created_at_ns))
.collect())
Ok(EncryptedMessageStore::find_groups(
&mut self.store.conn()?,
allowed_states,
created_after_ns,
created_before_ns,
limit,
)?
.into_iter()
.map(|stored_group| MlsGroup::new(self, stored_group.id, stored_group.created_at_ns))
.collect())
}

pub async fn register_identity(&self) -> Result<(), ClientError> {
// TODO: Mark key package as last_resort in creation
let last_resort_kp = self.identity.new_key_package(&self.mls_provider())?;
let mut connection = self.store.conn()?;
let last_resort_kp = self
.identity
.new_key_package(&self.mls_provider(&mut connection))?;
let last_resort_kp_bytes = last_resort_kp.tls_serialize_detached()?;

self.api_client
Expand Down Expand Up @@ -187,11 +188,13 @@ where
.consume_key_packages(installation_ids)
.await?;

let mls_provider = self.mls_provider();
let mut conn = self.store.conn()?;

Ok(key_package_results
.values()
.map(|bytes| VerifiedKeyPackage::from_bytes(&mls_provider, bytes.as_slice()))
.map(|bytes| {
VerifiedKeyPackage::from_bytes(&self.mls_provider(&mut conn), bytes.as_slice())
})
.collect::<Result<_, _>>()?)
}

Expand All @@ -200,32 +203,36 @@ where
pub async fn sync_welcomes(&self) -> Result<Vec<MlsGroup<ApiClient>>, ClientError> {
let welcome_topic = get_welcome_topic(&self.installation_public_key());
let mut conn = self.store.conn()?;
let provider = self.mls_provider();
// TODO: Use the last_message_timestamp_ns field on the TopicRefreshState to only fetch new messages
// Waiting for more atomic update methods
let envelopes = self.api_client.read_topic(&welcome_topic, 0).await?;

let groups: Vec<MlsGroup<ApiClient>> = envelopes
.into_iter()
.filter_map(|envelope| {
// TODO: Wrap in a transaction
let welcome = match extract_welcome(&envelope.message) {
Ok(welcome) => welcome,
Err(err) => {
log::error!("failed to extract welcome: {}", err);
return None;
// TODO: We can handle errors in the transaction() function to make error handling
// cleaner. Retryable errors can possibly be part of their own enum
XmtpOpenMlsProvider::transaction(&mut conn, |provider| {
let welcome = match extract_welcome(&envelope.message) {
Ok(welcome) => welcome,
Err(err) => {
log::error!("failed to extract welcome: {}", err);
return Ok::<_, ClientError>(None);
}
};

// TODO: Update last_message_timestamp_ns on success or non-retryable error
// TODO: Abort if error is retryable
match MlsGroup::create_from_welcome(self, &provider, welcome) {
Ok(mls_group) => Ok(Some(mls_group)),
Err(err) => {
log::error!("failed to create group from welcome: {}", err);
Ok(None)
}
}
};

// TODO: Update last_message_timestamp_ns on success or non-retryable error
// TODO: Abort if error is retryable
match MlsGroup::create_from_welcome(self, &mut conn, &provider, welcome) {
Ok(mls_group) => Some(mls_group),
Err(err) => {
log::error!("failed to create group from welcome: {}", err);
None
}
}
})
.ok()
.flatten()
})
.collect();

Expand Down
6 changes: 4 additions & 2 deletions xmtp_mls/src/groups/intents.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,16 +260,18 @@ mod tests {
let wallet = generate_local_wallet();
let wallet_address = wallet.get_address();
let client = ClientBuilder::new_test_client(wallet.into()).await;
let mut conn = client.store.conn().unwrap();
let key_package = client
.identity
.new_key_package(&client.mls_provider())
.new_key_package(&client.mls_provider(&mut conn))
.unwrap();
let verified_key_package = VerifiedKeyPackage::new(key_package, wallet_address.clone());

let intent = AddMembersIntentData::new(vec![verified_key_package.clone()]);
let as_bytes: Vec<u8> = intent.clone().try_into().unwrap();
let restored_intent =
AddMembersIntentData::from_bytes(as_bytes.as_slice(), &client.mls_provider()).unwrap();
AddMembersIntentData::from_bytes(as_bytes.as_slice(), &client.mls_provider(&mut conn))
.unwrap();

assert!(intent.key_packages[0]
.inner
Expand Down
Loading

0 comments on commit b0f00f9

Please sign in to comment.