Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Topic syncing and concurrency #335

Merged
merged 17 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 111 additions & 28 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
use std::collections::HashSet;
use std::{collections::HashSet, mem::Discriminant};

use log::debug;
use openmls::{
framing::{MlsMessageIn, MlsMessageInBody},
group::GroupEpoch,
messages::Welcome,
prelude::TlsSerializeTrait,
};
use thiserror::Error;
use tls_codec::{Deserialize, Error as TlsSerializationError};
use xmtp_proto::api_client::{XmtpApiClient, XmtpMlsClient};
use xmtp_proto::api_client::{Envelope, XmtpApiClient, XmtpMlsClient};

use crate::{
api_client_wrapper::{ApiClientWrapper, IdentityUpdate},
groups::MlsGroup,
groups::{IntentError, MlsGroup},
identity::Identity,
storage::{group::GroupMembershipState, DbConnection, EncryptedMessageStore, StorageError},
types::Address,
Expand Down Expand Up @@ -44,12 +46,46 @@ pub enum ClientError {
Serialization(#[from] TlsSerializationError),
#[error("key package verification: {0}")]
KeyPackageVerification(#[from] KeyPackageVerificationError),
#[error("message processing: {0}")]
MessageProcessing(#[from] crate::groups::MessageProcessingError),
#[error("syncing errors: {0:?}")]
SyncingError(Vec<MessageProcessingError>),
#[error("generic:{0}")]
Generic(String),
}

#[derive(Debug, Error)]
pub enum MessageProcessingError {
#[error("[{0}] already processed")]
AlreadyProcessed(u64),
#[error("diesel error: {0}")]
Diesel(#[from] diesel::result::Error),
#[error("[{message_time_ns:?}] invalid sender with credential: {credential:?}")]
InvalidSender {
message_time_ns: u64,
credential: Vec<u8>,
},
#[error("openmls process message error: {0}")]
OpenMlsProcessMessage(#[from] openmls::prelude::ProcessMessageError),
#[error("merge pending commit: {0}")]
MergePendingCommit(#[from] openmls::group::MergePendingCommitError<StorageError>),
#[error("merge staged commit: {0}")]
MergeStagedCommit(#[from] openmls::group::MergeCommitError<StorageError>),
#[error(
"no pending commit to merge. group epoch is {group_epoch:?} and got {message_epoch:?}"
)]
NoPendingCommit {
message_epoch: GroupEpoch,
group_epoch: GroupEpoch,
},
#[error("intent error: {0}")]
Intent(#[from] IntentError),
#[error("storage error: {0}")]
Storage(#[from] crate::storage::StorageError),
#[error("tls deserialization: {0}")]
TlsDeserialization(#[from] tls_codec::Error),
#[error("unsupported message type: {0:?}")]
UnsupportedMessageType(Discriminant<MlsMessageInBody>),
}

impl From<String> for ClientError {
fn from(value: String) -> Self {
Self::Generic(value)
Expand Down Expand Up @@ -88,8 +124,16 @@ where
}
}

pub fn account_address(&self) -> Address {
self.identity.account_address.clone()
}

pub fn installation_public_key(&self) -> Vec<u8> {
self.identity.installation_keys.to_public_vec()
}

// TODO: Remove this and figure out the correct lifetimes to allow long lived provider
pub fn mls_provider(&self, conn: &'a mut DbConnection) -> XmtpOpenMlsProvider<'a> {
pub(crate) fn mls_provider(&self, conn: &'a mut DbConnection) -> XmtpOpenMlsProvider<'a> {
XmtpOpenMlsProvider::new(conn)
}

Expand Down Expand Up @@ -134,7 +178,7 @@ where
Ok(())
}

async fn get_all_active_installation_ids(
pub async fn get_all_active_installation_ids(
&self,
wallet_addresses: Vec<String>,
) -> Result<Vec<Vec<u8>>, ClientError> {
Expand Down Expand Up @@ -165,9 +209,58 @@ where
Ok(installation_ids)
}

pub(crate) async fn pull_from_topic(&self, topic: &str) -> Result<Vec<Envelope>, ClientError> {
let mut conn = self.store.conn()?;
let last_synced_timestamp_ns =
EncryptedMessageStore::get_last_synced_timestamp_for_topic(&mut conn, topic)?;

let envelopes = self
.api_client
.read_topic(topic, last_synced_timestamp_ns as u64 + 1)
.await?;

debug!(
"Pulled {} envelopes from topic {} starting at timestamp {}",
envelopes.len(),
topic,
last_synced_timestamp_ns
);

Ok(envelopes)
}

pub(crate) fn process_for_topic<ProcessingFn, ReturnValue>(
&self,
topic: &str,
envelope_timestamp_ns: u64,
process_envelope: ProcessingFn,
) -> Result<ReturnValue, MessageProcessingError>
where
ProcessingFn: FnOnce(XmtpOpenMlsProvider) -> Result<ReturnValue, MessageProcessingError>,
{
// TODO: We can handle errors in the transaction() function to make error handling
// cleaner. Retryable errors can possibly be part of their own enum
XmtpOpenMlsProvider::transaction(&mut self.store.conn()?, |provider| {
let is_updated = {
EncryptedMessageStore::update_last_synced_timestamp_for_topic(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrapping this in a block, because the borrow needs to be ended before it's borrowed again somewhere else (e.g. when OpenMLS writes to the key store), or else there will be a runtime error. Don't love this solution but it works

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After discussion, will implement a WrappedConnection struct in the next PR that can solve this problem

&mut provider.conn().borrow_mut(),
topic,
envelope_timestamp_ns as i64,
)?
};
if !is_updated {
return Err(MessageProcessingError::AlreadyProcessed(
envelope_timestamp_ns,
));
}
process_envelope(provider)
})
}

// Get a flat list of one key package per installation for all the wallet addresses provided.
// Revoked installations will be omitted from the list
pub async fn get_key_packages_for_wallet_addresses(
#[allow(dead_code)]
pub(crate) async fn get_key_packages_for_wallet_addresses(
&self,
wallet_addresses: Vec<String>,
) -> Result<Vec<VerifiedKeyPackage>, ClientError> {
Expand All @@ -179,7 +272,7 @@ where
.await
}

pub async fn get_key_packages_for_installation_ids(
pub(crate) async fn get_key_packages_for_installation_ids(
&self,
installation_ids: Vec<Vec<u8>>,
) -> Result<Vec<VerifiedKeyPackage>, ClientError> {
Expand All @@ -200,28 +293,23 @@ where

// Download all unread welcome messages and convert to groups.
// Returns any new groups created in the operation
pub async fn sync_welcomes(&self) -> Result<Vec<MlsGroup<ApiClient>>, ClientError> {
#[allow(dead_code)]
pub(crate) async fn sync_welcomes(&self) -> Result<Vec<MlsGroup<ApiClient>>, ClientError> {
let welcome_topic = get_welcome_topic(&self.installation_public_key());
let mut conn = self.store.conn()?;
// TODO: Use the last_message_timestamp_ns field on the TopicRefreshState to only fetch new messages
// Waiting for more atomic update methods
let envelopes = self.api_client.read_topic(&welcome_topic, 0).await?;
let envelopes = self.pull_from_topic(&welcome_topic).await?;

let groups: Vec<MlsGroup<ApiClient>> = envelopes
.into_iter()
.filter_map(|envelope| {
// TODO: We can handle errors in the transaction() function to make error handling
// cleaner. Retryable errors can possibly be part of their own enum
XmtpOpenMlsProvider::transaction(&mut conn, |provider| {
.filter_map(|envelope: Envelope| {
self.process_for_topic(&welcome_topic, envelope.timestamp_ns, |provider| {
let welcome = match extract_welcome(&envelope.message) {
Ok(welcome) => welcome,
Err(err) => {
log::error!("failed to extract welcome: {}", err);
return Ok::<_, ClientError>(None);
return Ok(None);
}
};

// TODO: Update last_message_timestamp_ns on success or non-retryable error
// TODO: Abort if error is retryable
match MlsGroup::create_from_welcome(self, &provider, welcome) {
Ok(mls_group) => Ok(Some(mls_group)),
Expand All @@ -238,14 +326,6 @@ where

Ok(groups)
}

pub fn account_address(&self) -> Address {
self.identity.account_address.clone()
}

pub fn installation_public_key(&self) -> Vec<u8> {
self.identity.installation_keys.to_public_vec()
}
}

fn extract_welcome(welcome_bytes: &Vec<u8>) -> Result<Welcome, ClientError> {
Expand Down Expand Up @@ -320,5 +400,8 @@ mod tests {
bob_received_groups.first().unwrap().group_id,
alice_bob_group.group_id
);

let duplicate_received_groups = bob.sync_welcomes().await.unwrap();
assert_eq!(duplicate_received_groups.len(), 0);
}
}
Loading
Loading