Skip to content

Commit

Permalink
Topic syncing and concurrency (#335)
Browse files Browse the repository at this point in the history
1. Only pull down payloads we haven't already processed (via `pull_from_topic`)
2. Do all payload processing within an atomic transaction that also updates the `last_synced_payload_time` (via `process_for_topic`)

These methods are agnostic to which topic, so can be used for both welcomes and group message processing, as well as any other case in the future. Unsure if they belong on `Client` or somewhere else.

Some missing pieces:
1. More complicated concurrency test cases
1. Sophisticated error handling (when an error happens, do we update `last_synced_payload_time` or do we not?)
1. Will add a follow-up PR to improve the [safety](#335 (comment)) of borrowing the DB connection in a transaction
  • Loading branch information
richardhuaaa authored Nov 17, 2023
1 parent 945205f commit 6c05ffa
Show file tree
Hide file tree
Showing 5 changed files with 327 additions and 130 deletions.
139 changes: 111 additions & 28 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
use std::collections::HashSet;
use std::{collections::HashSet, mem::Discriminant};

use log::debug;
use openmls::{
framing::{MlsMessageIn, MlsMessageInBody},
group::GroupEpoch,
messages::Welcome,
prelude::TlsSerializeTrait,
};
use thiserror::Error;
use tls_codec::{Deserialize, Error as TlsSerializationError};
use xmtp_proto::api_client::{XmtpApiClient, XmtpMlsClient};
use xmtp_proto::api_client::{Envelope, XmtpApiClient, XmtpMlsClient};

use crate::{
api_client_wrapper::{ApiClientWrapper, IdentityUpdate},
groups::MlsGroup,
groups::{IntentError, MlsGroup},
identity::Identity,
storage::{
group::{GroupMembershipState, StoredGroup},
Expand Down Expand Up @@ -48,12 +50,46 @@ pub enum ClientError {
Serialization(#[from] TlsSerializationError),
#[error("key package verification: {0}")]
KeyPackageVerification(#[from] KeyPackageVerificationError),
#[error("message processing: {0}")]
MessageProcessing(#[from] crate::groups::MessageProcessingError),
#[error("syncing errors: {0:?}")]
SyncingError(Vec<MessageProcessingError>),
#[error("generic:{0}")]
Generic(String),
}

#[derive(Debug, Error)]
pub enum MessageProcessingError {
#[error("[{0}] already processed")]
AlreadyProcessed(u64),
#[error("diesel error: {0}")]
Diesel(#[from] diesel::result::Error),
#[error("[{message_time_ns:?}] invalid sender with credential: {credential:?}")]
InvalidSender {
message_time_ns: u64,
credential: Vec<u8>,
},
#[error("openmls process message error: {0}")]
OpenMlsProcessMessage(#[from] openmls::prelude::ProcessMessageError),
#[error("merge pending commit: {0}")]
MergePendingCommit(#[from] openmls::group::MergePendingCommitError<StorageError>),
#[error("merge staged commit: {0}")]
MergeStagedCommit(#[from] openmls::group::MergeCommitError<StorageError>),
#[error(
"no pending commit to merge. group epoch is {group_epoch:?} and got {message_epoch:?}"
)]
NoPendingCommit {
message_epoch: GroupEpoch,
group_epoch: GroupEpoch,
},
#[error("intent error: {0}")]
Intent(#[from] IntentError),
#[error("storage error: {0}")]
Storage(#[from] crate::storage::StorageError),
#[error("tls deserialization: {0}")]
TlsDeserialization(#[from] tls_codec::Error),
#[error("unsupported message type: {0:?}")]
UnsupportedMessageType(Discriminant<MlsMessageInBody>),
}

impl From<String> for ClientError {
fn from(value: String) -> Self {
Self::Generic(value)
Expand Down Expand Up @@ -92,8 +128,16 @@ where
}
}

pub fn account_address(&self) -> Address {
self.identity.account_address.clone()
}

pub fn installation_public_key(&self) -> Vec<u8> {
self.identity.installation_keys.to_public_vec()
}

// TODO: Remove this and figure out the correct lifetimes to allow long lived provider
pub fn mls_provider(&self, conn: &'a mut DbConnection) -> XmtpOpenMlsProvider<'a> {
pub(crate) fn mls_provider(&self, conn: &'a mut DbConnection) -> XmtpOpenMlsProvider<'a> {
XmtpOpenMlsProvider::new(conn)
}

Expand Down Expand Up @@ -147,7 +191,7 @@ where
Ok(())
}

async fn get_all_active_installation_ids(
pub async fn get_all_active_installation_ids(
&self,
wallet_addresses: Vec<String>,
) -> Result<Vec<Vec<u8>>, ClientError> {
Expand Down Expand Up @@ -178,9 +222,58 @@ where
Ok(installation_ids)
}

pub(crate) async fn pull_from_topic(&self, topic: &str) -> Result<Vec<Envelope>, ClientError> {
let mut conn = self.store.conn()?;
let last_synced_timestamp_ns =
EncryptedMessageStore::get_last_synced_timestamp_for_topic(&mut conn, topic)?;

let envelopes = self
.api_client
.read_topic(topic, last_synced_timestamp_ns as u64 + 1)
.await?;

debug!(
"Pulled {} envelopes from topic {} starting at timestamp {}",
envelopes.len(),
topic,
last_synced_timestamp_ns
);

Ok(envelopes)
}

pub(crate) fn process_for_topic<ProcessingFn, ReturnValue>(
&self,
topic: &str,
envelope_timestamp_ns: u64,
process_envelope: ProcessingFn,
) -> Result<ReturnValue, MessageProcessingError>
where
ProcessingFn: FnOnce(XmtpOpenMlsProvider) -> Result<ReturnValue, MessageProcessingError>,
{
// TODO: We can handle errors in the transaction() function to make error handling
// cleaner. Retryable errors can possibly be part of their own enum
XmtpOpenMlsProvider::transaction(&mut self.store.conn()?, |provider| {
let is_updated = {
EncryptedMessageStore::update_last_synced_timestamp_for_topic(
&mut provider.conn().borrow_mut(),
topic,
envelope_timestamp_ns as i64,
)?
};
if !is_updated {
return Err(MessageProcessingError::AlreadyProcessed(
envelope_timestamp_ns,
));
}
process_envelope(provider)
})
}

// Get a flat list of one key package per installation for all the wallet addresses provided.
// Revoked installations will be omitted from the list
pub async fn get_key_packages_for_wallet_addresses(
#[allow(dead_code)]
pub(crate) async fn get_key_packages_for_wallet_addresses(
&self,
wallet_addresses: Vec<String>,
) -> Result<Vec<VerifiedKeyPackage>, ClientError> {
Expand All @@ -192,7 +285,7 @@ where
.await
}

pub async fn get_key_packages_for_installation_ids(
pub(crate) async fn get_key_packages_for_installation_ids(
&self,
installation_ids: Vec<Vec<u8>>,
) -> Result<Vec<VerifiedKeyPackage>, ClientError> {
Expand All @@ -213,28 +306,23 @@ where

// Download all unread welcome messages and convert to groups.
// Returns any new groups created in the operation
pub async fn sync_welcomes(&self) -> Result<Vec<MlsGroup<ApiClient>>, ClientError> {
#[allow(dead_code)]
pub(crate) async fn sync_welcomes(&self) -> Result<Vec<MlsGroup<ApiClient>>, ClientError> {
let welcome_topic = get_welcome_topic(&self.installation_public_key());
let mut conn = self.store.conn()?;
// TODO: Use the last_message_timestamp_ns field on the TopicRefreshState to only fetch new messages
// Waiting for more atomic update methods
let envelopes = self.api_client.read_topic(&welcome_topic, 0).await?;
let envelopes = self.pull_from_topic(&welcome_topic).await?;

let groups: Vec<MlsGroup<ApiClient>> = envelopes
.into_iter()
.filter_map(|envelope| {
// TODO: We can handle errors in the transaction() function to make error handling
// cleaner. Retryable errors can possibly be part of their own enum
XmtpOpenMlsProvider::transaction(&mut conn, |provider| {
.filter_map(|envelope: Envelope| {
self.process_for_topic(&welcome_topic, envelope.timestamp_ns, |provider| {
let welcome = match extract_welcome(&envelope.message) {
Ok(welcome) => welcome,
Err(err) => {
log::error!("failed to extract welcome: {}", err);
return Ok::<_, ClientError>(None);
return Ok(None);
}
};

// TODO: Update last_message_timestamp_ns on success or non-retryable error
// TODO: Abort if error is retryable
match MlsGroup::create_from_welcome(self, &provider, welcome) {
Ok(mls_group) => Ok(Some(mls_group)),
Expand All @@ -251,14 +339,6 @@ where

Ok(groups)
}

pub fn account_address(&self) -> Address {
self.identity.account_address.clone()
}

pub fn installation_public_key(&self) -> Vec<u8> {
self.identity.installation_keys.to_public_vec()
}
}

fn extract_welcome(welcome_bytes: &Vec<u8>) -> Result<Welcome, ClientError> {
Expand Down Expand Up @@ -333,5 +413,8 @@ mod tests {
bob_received_groups.first().unwrap().group_id,
alice_bob_group.group_id
);

let duplicate_received_groups = bob.sync_welcomes().await.unwrap();
assert_eq!(duplicate_received_groups.len(), 0);
}
}
Loading

0 comments on commit 6c05ffa

Please sign in to comment.