diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 6a379fd9e..c276ae44f 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -1,17 +1,19 @@ -use std::collections::HashSet; +use std::{collections::HashSet, mem::Discriminant}; +use log::debug; use openmls::{ framing::{MlsMessageIn, MlsMessageInBody}, + group::GroupEpoch, messages::Welcome, prelude::TlsSerializeTrait, }; use thiserror::Error; use tls_codec::{Deserialize, Error as TlsSerializationError}; -use xmtp_proto::api_client::{XmtpApiClient, XmtpMlsClient}; +use xmtp_proto::api_client::{Envelope, XmtpApiClient, XmtpMlsClient}; use crate::{ api_client_wrapper::{ApiClientWrapper, IdentityUpdate}, - groups::MlsGroup, + groups::{IntentError, MlsGroup}, identity::Identity, storage::{group::GroupMembershipState, DbConnection, EncryptedMessageStore, StorageError}, types::Address, @@ -44,12 +46,46 @@ pub enum ClientError { Serialization(#[from] TlsSerializationError), #[error("key package verification: {0}")] KeyPackageVerification(#[from] KeyPackageVerificationError), - #[error("message processing: {0}")] - MessageProcessing(#[from] crate::groups::MessageProcessingError), + #[error("syncing errors: {0:?}")] + SyncingError(Vec), #[error("generic:{0}")] Generic(String), } +#[derive(Debug, Error)] +pub enum MessageProcessingError { + #[error("[{0}] already processed")] + AlreadyProcessed(u64), + #[error("diesel error: {0}")] + Diesel(#[from] diesel::result::Error), + #[error("[{message_time_ns:?}] invalid sender with credential: {credential:?}")] + InvalidSender { + message_time_ns: u64, + credential: Vec, + }, + #[error("openmls process message error: {0}")] + OpenMlsProcessMessage(#[from] openmls::prelude::ProcessMessageError), + #[error("merge pending commit: {0}")] + MergePendingCommit(#[from] openmls::group::MergePendingCommitError), + #[error("merge staged commit: {0}")] + MergeStagedCommit(#[from] openmls::group::MergeCommitError), + #[error( + "no pending commit to merge. group epoch is {group_epoch:?} and got {message_epoch:?}" + )] + NoPendingCommit { + message_epoch: GroupEpoch, + group_epoch: GroupEpoch, + }, + #[error("intent error: {0}")] + Intent(#[from] IntentError), + #[error("storage error: {0}")] + Storage(#[from] crate::storage::StorageError), + #[error("tls deserialization: {0}")] + TlsDeserialization(#[from] tls_codec::Error), + #[error("unsupported message type: {0:?}")] + UnsupportedMessageType(Discriminant), +} + impl From for ClientError { fn from(value: String) -> Self { Self::Generic(value) @@ -88,8 +124,16 @@ where } } + pub fn account_address(&self) -> Address { + self.identity.account_address.clone() + } + + pub fn installation_public_key(&self) -> Vec { + self.identity.installation_keys.to_public_vec() + } + // TODO: Remove this and figure out the correct lifetimes to allow long lived provider - pub fn mls_provider(&self, conn: &'a mut DbConnection) -> XmtpOpenMlsProvider<'a> { + pub(crate) fn mls_provider(&self, conn: &'a mut DbConnection) -> XmtpOpenMlsProvider<'a> { XmtpOpenMlsProvider::new(conn) } @@ -134,7 +178,7 @@ where Ok(()) } - async fn get_all_active_installation_ids( + pub async fn get_all_active_installation_ids( &self, wallet_addresses: Vec, ) -> Result>, ClientError> { @@ -165,9 +209,58 @@ where Ok(installation_ids) } + pub(crate) async fn pull_from_topic(&self, topic: &str) -> Result, ClientError> { + let mut conn = self.store.conn()?; + let last_synced_timestamp_ns = + EncryptedMessageStore::get_last_synced_timestamp_for_topic(&mut conn, topic)?; + + let envelopes = self + .api_client + .read_topic(topic, last_synced_timestamp_ns as u64 + 1) + .await?; + + debug!( + "Pulled {} envelopes from topic {} starting at timestamp {}", + envelopes.len(), + topic, + last_synced_timestamp_ns + ); + + Ok(envelopes) + } + + pub(crate) fn process_for_topic( + &self, + topic: &str, + envelope_timestamp_ns: u64, + process_envelope: ProcessingFn, + ) -> Result + where + ProcessingFn: FnOnce(XmtpOpenMlsProvider) -> Result, + { + // TODO: We can handle errors in the transaction() function to make error handling + // cleaner. Retryable errors can possibly be part of their own enum + XmtpOpenMlsProvider::transaction(&mut self.store.conn()?, |provider| { + let is_updated = { + EncryptedMessageStore::update_last_synced_timestamp_for_topic( + &mut provider.conn().borrow_mut(), + topic, + envelope_timestamp_ns as i64, + )? + }; + if !is_updated { + return Err(MessageProcessingError::AlreadyProcessed( + envelope_timestamp_ns, + )); + } + process_envelope(provider) + }) + } + // Get a flat list of one key package per installation for all the wallet addresses provided. // Revoked installations will be omitted from the list - pub async fn get_key_packages_for_wallet_addresses( + #[allow(dead_code)] + pub(crate) async fn get_key_packages_for_wallet_addresses( &self, wallet_addresses: Vec, ) -> Result, ClientError> { @@ -179,7 +272,7 @@ where .await } - pub async fn get_key_packages_for_installation_ids( + pub(crate) async fn get_key_packages_for_installation_ids( &self, installation_ids: Vec>, ) -> Result, ClientError> { @@ -200,28 +293,23 @@ where // Download all unread welcome messages and convert to groups. // Returns any new groups created in the operation - pub async fn sync_welcomes(&self) -> Result>, ClientError> { + #[allow(dead_code)] + pub(crate) async fn sync_welcomes(&self) -> Result>, ClientError> { let welcome_topic = get_welcome_topic(&self.installation_public_key()); - let mut conn = self.store.conn()?; - // TODO: Use the last_message_timestamp_ns field on the TopicRefreshState to only fetch new messages - // Waiting for more atomic update methods - let envelopes = self.api_client.read_topic(&welcome_topic, 0).await?; + let envelopes = self.pull_from_topic(&welcome_topic).await?; let groups: Vec> = envelopes .into_iter() - .filter_map(|envelope| { - // TODO: We can handle errors in the transaction() function to make error handling - // cleaner. Retryable errors can possibly be part of their own enum - XmtpOpenMlsProvider::transaction(&mut conn, |provider| { + .filter_map(|envelope: Envelope| { + self.process_for_topic(&welcome_topic, envelope.timestamp_ns, |provider| { let welcome = match extract_welcome(&envelope.message) { Ok(welcome) => welcome, Err(err) => { log::error!("failed to extract welcome: {}", err); - return Ok::<_, ClientError>(None); + return Ok(None); } }; - // TODO: Update last_message_timestamp_ns on success or non-retryable error // TODO: Abort if error is retryable match MlsGroup::create_from_welcome(self, &provider, welcome) { Ok(mls_group) => Ok(Some(mls_group)), @@ -238,14 +326,6 @@ where Ok(groups) } - - pub fn account_address(&self) -> Address { - self.identity.account_address.clone() - } - - pub fn installation_public_key(&self) -> Vec { - self.identity.installation_keys.to_public_vec() - } } fn extract_welcome(welcome_bytes: &Vec) -> Result { @@ -320,5 +400,8 @@ mod tests { bob_received_groups.first().unwrap().group_id, alice_bob_group.group_id ); + + let duplicate_received_groups = bob.sync_welcomes().await.unwrap(); + assert_eq!(duplicate_received_groups.len(), 0); } } diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 071ab6543..c8ad9da16 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -1,15 +1,11 @@ mod intents; mod members; - -#[cfg(test)] -use std::println as debug; - use intents::SendMessageIntentData; #[cfg(not(test))] use log::debug; use openmls::{ framing::ProtocolMessage, - group::{GroupEpoch, MergePendingCommitError}, + group::MergePendingCommitError, prelude::{ CredentialWithKey, CryptoConfig, GroupId, LeafNodeIndex, MlsGroup as OpenMlsGroup, MlsGroupConfig, MlsMessageIn, MlsMessageInBody, PrivateMessageIn, ProcessedMessage, @@ -18,15 +14,18 @@ use openmls::{ prelude_test::KeyPackage, }; use openmls_traits::OpenMlsProvider; -use std::mem::{discriminant, Discriminant}; +use std::mem::discriminant; +#[cfg(test)] +use std::println as debug; use thiserror::Error; use tls_codec::{Deserialize, Serialize}; use xmtp_proto::api_client::{Envelope, XmtpApiClient, XmtpMlsClient}; -use self::intents::{AddMembersIntentData, IntentError, PostCommitAction, RemoveMembersIntentData}; +pub use self::intents::IntentError; +use self::intents::{AddMembersIntentData, PostCommitAction, RemoveMembersIntentData}; use crate::{ api_client_wrapper::WelcomeMessage, - client::ClientError, + client::{ClientError, MessageProcessingError}, configuration::CIPHERSUITE, identity::Identity, storage::{ @@ -74,36 +73,6 @@ pub enum GroupError { Diesel(#[from] diesel::result::Error), } -#[derive(Debug, Error)] -pub enum MessageProcessingError { - #[error("[{message_time_ns:?}] invalid sender with credential: {credential:?}")] - InvalidSender { - message_time_ns: u64, - credential: Vec, - }, - #[error("openmls process message error: {0}")] - OpenMlsProcessMessage(#[from] openmls::prelude::ProcessMessageError), - #[error("merge pending commit: {0}")] - MergePendingCommit(#[from] openmls::group::MergePendingCommitError), - #[error("merge staged commit: {0}")] - MergeStagedCommit(#[from] openmls::group::MergeCommitError), - #[error( - "no pending commit to merge. group epoch is {group_epoch:?} and got {message_epoch:?}" - )] - NoPendingCommit { - message_epoch: GroupEpoch, - group_epoch: GroupEpoch, - }, - #[error("intent error: {0}")] - Intent(#[from] IntentError), - #[error("storage error: {0}")] - Storage(#[from] crate::storage::StorageError), - #[error("tls deserialization: {0}")] - TlsDeserialization(#[from] tls_codec::Error), - #[error("unsupported message type: {0:?}")] - UnsupportedMessageType(Discriminant), -} - pub struct MlsGroup<'c, ApiClient> { pub group_id: Vec, pub created_at_ns: i64, @@ -153,7 +122,9 @@ where mls_group.save(provider.key_store())?; let group_id = mls_group.group_id().to_vec(); let stored_group = StoredGroup::new(group_id.clone(), now_ns(), membership_state); - stored_group.store(*provider.conn().borrow_mut())?; + { + stored_group.store(*provider.conn().borrow_mut())?; + } Ok(Self::new(client, group_id, stored_group.created_at_ns)) } @@ -170,7 +141,9 @@ where let group_id = mls_group.group_id().to_vec(); let stored_group = StoredGroup::new(group_id.clone(), now_ns(), GroupMembershipState::Pending); - stored_group.store(*provider.conn().borrow_mut())?; + { + stored_group.store(*provider.conn().borrow_mut())?; + } Ok(Self::new(client, group_id, stored_group.created_at_ns)) } @@ -230,7 +203,6 @@ where fn process_own_message( &self, - conn: &mut DbConnection, intent: StoredGroupIntent, openmls_group: &mut OpenMlsGroup, provider: &XmtpOpenMlsProvider, @@ -241,7 +213,7 @@ where return Ok(()); } debug!( - "[{}]processing own message for intent {} / {:?}", + "[{}] processing own message for intent {} / {:?}", self.client.account_address(), intent.id, intent.kind @@ -257,7 +229,12 @@ where "no pending commit to merge. Group epoch: {}. Message epoch: {}", group_epoch, message_epoch ); - EncryptedMessageStore::set_group_intent_to_publish(conn, intent.id)?; + { + EncryptedMessageStore::set_group_intent_to_publish( + &mut provider.conn().borrow_mut(), + intent.id, + )?; + } return Err(MessageProcessingError::NoPendingCommit { message_epoch, @@ -265,14 +242,18 @@ where }); } debug!("[{}] merging pending commit", self.client.account_address()); - match openmls_group.merge_pending_commit(provider) { - Err(MergePendingCommitError::MlsGroupStateError(err)) => { - debug!("error merging commit: {}", err); - openmls_group.clear_pending_commit(); - EncryptedMessageStore::set_group_intent_to_publish(conn, intent.id)?; + if let Err(MergePendingCommitError::MlsGroupStateError(err)) = + openmls_group.merge_pending_commit(provider) + { + debug!("error merging commit: {}", err); + openmls_group.clear_pending_commit(); + { + EncryptedMessageStore::set_group_intent_to_publish( + &mut provider.conn().borrow_mut(), + intent.id, + )?; } - _ => (), - }; + } // TOOD: Handle writing transcript messages for adding/removing members } IntentKind::SendMessage => { @@ -280,27 +261,33 @@ where let group_id = openmls_group.group_id().as_slice(); let decrypted_message_data = intent_data.message.as_slice(); - StoredGroupMessage { - id: get_message_id(decrypted_message_data, group_id, envelope_timestamp_ns), - group_id: group_id.to_vec(), - decrypted_message_bytes: intent_data.message, - sent_at_ns: envelope_timestamp_ns as i64, - kind: GroupMessageKind::Application, - sender_installation_id: self.client.installation_public_key(), - sender_wallet_address: self.client.account_address(), + { + StoredGroupMessage { + id: get_message_id(decrypted_message_data, group_id, envelope_timestamp_ns), + group_id: group_id.to_vec(), + decrypted_message_bytes: intent_data.message, + sent_at_ns: envelope_timestamp_ns as i64, + kind: GroupMessageKind::Application, + sender_installation_id: self.client.installation_public_key(), + sender_wallet_address: self.client.account_address(), + } + .store(&mut provider.conn().borrow_mut())?; } - .store(conn)?; } }; - EncryptedMessageStore::set_group_intent_committed(conn, intent.id)?; + { + EncryptedMessageStore::set_group_intent_committed( + &mut provider.conn().borrow_mut(), + intent.id, + )?; + } Ok(()) } fn process_private_message( &self, - conn: &mut DbConnection, openmls_group: &mut OpenMlsGroup, provider: &XmtpOpenMlsProvider, message: PrivateMessageIn, @@ -328,7 +315,9 @@ where sender_installation_id, sender_wallet_address: sender_account_address, }; - message.store(conn)?; + { + message.store(&mut provider.conn().borrow_mut())?; + } } ProcessedMessageContent::ProposalMessage(_proposal_ptr) => { // intentionally left blank. @@ -343,13 +332,13 @@ where ); openmls_group.merge_staged_commit(provider, *staged_commit)?; } - } + }; + Ok(()) } fn process_message( &self, - conn: &mut DbConnection, openmls_group: &mut OpenMlsGroup, provider: &XmtpOpenMlsProvider, envelope: &Envelope, @@ -363,13 +352,15 @@ where )), }?; - match EncryptedMessageStore::find_group_intent_by_payload_hash( - conn, - sha256(envelope.message.as_slice()), - ) { + let intent = { + EncryptedMessageStore::find_group_intent_by_payload_hash( + &mut provider.conn().borrow_mut(), + sha256(envelope.message.as_slice()), + ) + }; + match intent { // Intent with the payload hash matches Ok(Some(intent)) => self.process_own_message( - conn, intent, openmls_group, provider, @@ -379,7 +370,6 @@ where Err(err) => Err(MessageProcessingError::Storage(err)), // No matching intent found Ok(None) => self.process_private_message( - conn, openmls_group, provider, message, @@ -388,38 +378,36 @@ where } } - pub fn process_messages(&self, envelopes: Vec) -> Result<(), GroupError> { + pub(crate) fn process_messages(&self, envelopes: Vec) -> Result<(), GroupError> { let mut conn = self.client.store.conn()?; let provider = self.client.mls_provider(&mut conn); let mut openmls_group = self.load_mls_group(&provider)?; - let conn = &mut self.client.store.conn()?; let receive_errors: Vec = envelopes .into_iter() .map(|envelope| -> Result<(), MessageProcessingError> { - self.process_message(conn, &mut openmls_group, &provider, &envelope) + self.client.process_for_topic( + &self.topic(), + envelope.timestamp_ns, + |provider| -> Result<(), MessageProcessingError> { + self.process_message(&mut openmls_group, &provider, &envelope)?; + openmls_group.save(provider.key_store())?; + Ok(()) + }, + ) }) - .filter(|result| result.is_err()) - .map(|result| result.unwrap_err()) + .filter_map(|result| result.err()) .collect(); - openmls_group.save(provider.key_store())?; // TODO handle concurrency if receive_errors.is_empty() { Ok(()) } else { + debug!("Message processing errors: {:?}", receive_errors); Err(GroupError::ReceiveError(receive_errors)) } } pub async fn receive(&self) -> Result<(), GroupError> { - let topic = get_group_topic(&self.group_id); - let envelopes = self - .client - .api_client - .read_topic( - &topic, 0, // TODO: query from last query point - ) - .await?; - debug!("Received {} envelopes", envelopes.len()); + let envelopes = self.client.pull_from_topic(&self.topic()).await?; self.process_messages(envelopes) } @@ -490,12 +478,14 @@ where let provider = self.client.mls_provider(conn); let mut openmls_group = self.load_mls_group(&provider)?; - let intents = EncryptedMessageStore::find_group_intents( - &mut provider.conn().borrow_mut(), - self.group_id.clone(), - Some(vec![IntentState::ToPublish]), - None, - )?; + let intents = { + EncryptedMessageStore::find_group_intents( + &mut provider.conn().borrow_mut(), + self.group_id.clone(), + Some(vec![IntentState::ToPublish]), + None, + )? + }; for intent in intents { let result = self.get_publish_intent_data(&provider, &mut openmls_group, &intent); @@ -514,12 +504,14 @@ where .publish_to_group(vec![payload_slice]) .await?; - EncryptedMessageStore::set_group_intent_published( - &mut provider.conn().borrow_mut(), - intent.id, - sha256(payload_slice), - post_commit_data, - )?; + { + EncryptedMessageStore::set_group_intent_published( + &mut provider.conn().borrow_mut(), + intent.id, + sha256(payload_slice), + post_commit_data, + )?; + } } openmls_group.save(self.client.mls_provider(conn).key_store())?; @@ -639,6 +631,7 @@ where ciphertext: action.welcome_message.clone(), }) .collect(); + debug!("Sending {} welcomes", welcomes.len()); self.client.api_client.publish_welcomes(welcomes).await?; } } @@ -650,7 +643,7 @@ where Ok(()) } - pub fn topic(&self) -> String { + fn topic(&self) -> String { get_group_topic(&self.group_id) } } diff --git a/xmtp_mls/src/identity.rs b/xmtp_mls/src/identity.rs index 198ae109d..52e26eaba 100644 --- a/xmtp_mls/src/identity.rs +++ b/xmtp_mls/src/identity.rs @@ -62,7 +62,9 @@ impl Identity { }; identity.new_key_package(provider)?; - StoredIdentity::from(&identity).store(*provider.conn().borrow_mut())?; + { + StoredIdentity::from(&identity).store(*provider.conn().borrow_mut())?; + } // TODO: upload credential_with_key and last_resort_key_package diff --git a/xmtp_mls/src/storage/encrypted_store/topic_refresh_state.rs b/xmtp_mls/src/storage/encrypted_store/topic_refresh_state.rs index ee4f42ab5..482e96a32 100644 --- a/xmtp_mls/src/storage/encrypted_store/topic_refresh_state.rs +++ b/xmtp_mls/src/storage/encrypted_store/topic_refresh_state.rs @@ -1,7 +1,7 @@ use diesel::prelude::*; -use super::schema::topic_refresh_state; -use crate::{impl_fetch, impl_store}; +use super::{schema::topic_refresh_state, DbConnection, EncryptedMessageStore}; +use crate::{impl_fetch, impl_store, storage::StorageError, Fetch, Store}; #[derive(Insertable, Identifiable, Queryable, Debug, Clone)] #[diesel(table_name = topic_refresh_state)] @@ -13,3 +13,119 @@ pub struct TopicRefreshState { impl_fetch!(TopicRefreshState, topic_refresh_state, String); impl_store!(TopicRefreshState, topic_refresh_state); + +impl EncryptedMessageStore { + pub fn get_last_synced_timestamp_for_topic( + conn: &mut DbConnection, + topic: &str, + ) -> Result { + let state: Option = conn.fetch(&topic.to_string())?; + match state { + Some(state) => Ok(state.last_message_timestamp_ns), + None => { + let new_state = TopicRefreshState { + topic: topic.to_string(), + last_message_timestamp_ns: 0, + }; + new_state.store(conn)?; + Ok(0) + } + } + } + + pub fn update_last_synced_timestamp_for_topic( + conn: &mut DbConnection, + topic: &str, + timestamp_ns: i64, + ) -> Result { + let state: Option = conn.fetch(&topic.to_string())?; + match state { + Some(state) => { + use super::schema::topic_refresh_state::dsl; + let num_updated = diesel::update(&state) + .filter(dsl::last_message_timestamp_ns.lt(timestamp_ns)) + .set(dsl::last_message_timestamp_ns.eq(timestamp_ns)) + .execute(conn)?; + Ok(num_updated == 1) + } + None => Err(StorageError::NotFound), + } + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + use crate::{storage::encrypted_store::tests::with_store, Fetch, Store}; + + #[test] + fn get_timestamp_with_no_existing_state() { + with_store(|mut conn| { + let entry: Option = conn.fetch(&"topic".to_string()).unwrap(); + assert!(entry.is_none()); + assert_eq!( + EncryptedMessageStore::get_last_synced_timestamp_for_topic(&mut conn, "topic") + .unwrap(), + 0 + ); + let entry: Option = conn.fetch(&"topic".to_string()).unwrap(); + assert!(entry.is_some()); + }) + } + + #[test] + fn get_timestamp_with_existing_state() { + with_store(|mut conn| { + let entry = TopicRefreshState { + topic: "topic".to_string(), + last_message_timestamp_ns: 123, + }; + entry.store(&mut conn).unwrap(); + assert_eq!( + EncryptedMessageStore::get_last_synced_timestamp_for_topic(&mut conn, "topic") + .unwrap(), + 123 + ); + }) + } + + #[test] + fn update_timestamp_when_bigger() { + with_store(|mut conn| { + let entry = TopicRefreshState { + topic: "topic".to_string(), + last_message_timestamp_ns: 123, + }; + entry.store(&mut conn).unwrap(); + assert_eq!( + EncryptedMessageStore::update_last_synced_timestamp_for_topic( + &mut conn, "topic", 124 + ) + .unwrap(), + true + ); + let entry: Option = conn.fetch(&"topic".to_string()).unwrap(); + assert_eq!(entry.unwrap().last_message_timestamp_ns, 124); + }) + } + + #[test] + fn dont_update_timestamp_when_smaller() { + with_store(|mut conn| { + let entry = TopicRefreshState { + topic: "topic".to_string(), + last_message_timestamp_ns: 123, + }; + entry.store(&mut conn).unwrap(); + assert_eq!( + EncryptedMessageStore::update_last_synced_timestamp_for_topic( + &mut conn, "topic", 122 + ) + .unwrap(), + false + ); + let entry: Option = conn.fetch(&"topic".to_string()).unwrap(); + assert_eq!(entry.unwrap().last_message_timestamp_ns, 123); + }) + } +} diff --git a/xmtp_mls/src/xmtp_openmls_provider.rs b/xmtp_mls/src/xmtp_openmls_provider.rs index d6c1191ac..1f5599c3d 100644 --- a/xmtp_mls/src/xmtp_openmls_provider.rs +++ b/xmtp_mls/src/xmtp_openmls_provider.rs @@ -35,7 +35,10 @@ impl<'a> XmtpOpenMlsProvider<'a> { /// XmtpOpenMlsProvider::transaction(conn, |provider| { /// // do some operations requiring provider /// // access the connection with .conn() - /// provider.conn().borrow_mut() + /// // wrap in a block so that the borrow is ended + /// { + /// provider.conn().borrow_mut() + /// } /// }) /// ``` pub fn transaction(connection: &mut DbConnection, fun: F) -> Result