Skip to content

Commit

Permalink
Fix more merge issues
Browse files Browse the repository at this point in the history
  • Loading branch information
richardhuaaa committed Nov 15, 2023
1 parent 3831002 commit 336be08
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 64 deletions.
53 changes: 23 additions & 30 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::{collections::HashSet, mem::Discriminant};

use diesel::Connection;
use log::debug;
use openmls::{
framing::{MlsMessageIn, MlsMessageInBody},
Expand Down Expand Up @@ -212,9 +211,8 @@ where

pub(crate) async fn pull_from_topic(&self, topic: &str) -> Result<Vec<Envelope>, ClientError> {
let mut conn = self.store.conn()?;
let last_synced_timestamp_ns = self
.store
.get_last_synced_timestamp_for_topic(&mut conn, topic)?;
let last_synced_timestamp_ns =
EncryptedMessageStore::get_last_synced_timestamp_for_topic(&mut conn, topic)?;

let envelopes = self
.api_client
Expand All @@ -238,14 +236,13 @@ where
process_envelope: ProcessingFn,
) -> Result<ReturnValue, MessageProcessingError>
where
ProcessingFn: FnOnce(&mut DbConnection) -> Result<ReturnValue, MessageProcessingError>,
ProcessingFn: FnOnce(XmtpOpenMlsProvider) -> Result<ReturnValue, MessageProcessingError>,
{
// TODO: We can handle errors in the transaction() function to make error handling
// cleaner. Retryable errors can possibly be part of their own enum
XmtpOpenMlsProvider::transaction(&mut self.store.conn()?, |provider| {
let transaction_manager = &mut provider.conn().borrow_mut();
let is_updated = self.store.update_last_synced_timestamp_for_topic(
transaction_manager,
let is_updated = EncryptedMessageStore::update_last_synced_timestamp_for_topic(
&mut provider.conn().borrow_mut(),
topic,
envelope_timestamp_ns as i64,
)?;
Expand All @@ -254,7 +251,7 @@ where
envelope_timestamp_ns,
));
}
process_envelope(transaction_manager)
process_envelope(provider)
})
}

Expand Down Expand Up @@ -302,28 +299,24 @@ where
let groups: Vec<MlsGroup<ApiClient>> = envelopes
.into_iter()
.map(|envelope: Envelope| -> Option<MlsGroup<ApiClient>> {
self.process_for_topic(
&welcome_topic,
envelope.timestamp_ns,
|transaction_manager| {
let welcome = match extract_welcome(&envelope.message) {
Ok(welcome) => welcome,
Err(err) => {
log::error!("failed to extract welcome: {}", err);
return Ok(None);
}
};

// TODO: Abort if error is retryable
match MlsGroup::create_from_welcome(self, &provider, welcome) {
Ok(mls_group) => Ok(Some(mls_group)),
Err(err) => {
log::error!("failed to create group from welcome: {}", err);
Ok(None)
}
self.process_for_topic(&welcome_topic, envelope.timestamp_ns, |provider| {
let welcome = match extract_welcome(&envelope.message) {
Ok(welcome) => welcome,
Err(err) => {
log::error!("failed to extract welcome: {}", err);
return Ok(None);
}
},
)
};

// TODO: Abort if error is retryable
match MlsGroup::create_from_welcome(self, &provider, welcome) {
Ok(mls_group) => Ok(Some(mls_group)),
Err(err) => {
log::error!("failed to create group from welcome: {}", err);
Ok(None)
}
}
})
.ok()
.flatten()
})
Expand Down
35 changes: 17 additions & 18 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ where

fn process_own_message(
&self,
conn: &mut DbConnection,
intent: StoredGroupIntent,
openmls_group: &mut OpenMlsGroup,
provider: &XmtpOpenMlsProvider,
Expand All @@ -225,7 +224,10 @@ where
"no pending commit to merge. Group epoch: {}. Message epoch: {}",
group_epoch, message_epoch
);
EncryptedMessageStore::set_group_intent_to_publish(conn, intent.id)?;
EncryptedMessageStore::set_group_intent_to_publish(
&mut provider.conn().borrow_mut(),
intent.id,
)?;

return Err(MessageProcessingError::NoPendingCommit {
message_epoch,
Expand All @@ -237,7 +239,10 @@ where
Err(MergePendingCommitError::MlsGroupStateError(err)) => {
debug!("error merging commit: {}", err);
openmls_group.clear_pending_commit();
EncryptedMessageStore::set_group_intent_to_publish(conn, intent.id)?;
EncryptedMessageStore::set_group_intent_to_publish(
&mut provider.conn().borrow_mut(),
intent.id,
)?;
}
_ => (),
};
Expand All @@ -257,18 +262,20 @@ where
sender_installation_id: self.client.installation_public_key(),
sender_wallet_address: self.client.account_address(),
}
.store(conn)?;
.store(&mut provider.conn().borrow_mut())?;
}
};

EncryptedMessageStore::set_group_intent_committed(conn, intent.id)?;
EncryptedMessageStore::set_group_intent_committed(
&mut provider.conn().borrow_mut(),
intent.id,
)?;

Ok(())
}

fn process_private_message(
&self,
transaction_manager: &mut DbConnection,
openmls_group: &mut OpenMlsGroup,
provider: &XmtpOpenMlsProvider,
message: PrivateMessageIn,
Expand Down Expand Up @@ -297,7 +304,7 @@ where
sender_installation_id,
sender_wallet_address: sender_account_address,
};
message.store(transaction_manager)?;
message.store(&mut provider.conn().borrow_mut())?;
}
ProcessedMessageContent::ProposalMessage(_proposal_ptr) => {
// intentionally left blank.
Expand All @@ -319,7 +326,6 @@ where

fn process_message(
&self,
transaction_manager: &mut DbConnection,
openmls_group: &mut OpenMlsGroup,
provider: &XmtpOpenMlsProvider,
envelope: &Envelope,
Expand All @@ -334,12 +340,11 @@ where
}?;

match EncryptedMessageStore::find_group_intent_by_payload_hash(
transaction_manager,
&mut provider.conn().borrow_mut(),
sha256(envelope.message.as_slice()),
) {
// Intent with the payload hash matches
Ok(Some(intent)) => self.process_own_message(
transaction_manager,
intent,
openmls_group,
provider,
Expand All @@ -349,7 +354,6 @@ where
Err(err) => Err(MessageProcessingError::Storage(err)),
// No matching intent found
Ok(None) => self.process_private_message(
transaction_manager,
openmls_group,
provider,
message,
Expand All @@ -368,13 +372,8 @@ where
self.client.process_for_topic(
&self.topic(),
envelope.timestamp_ns,
|transaction_manager| -> Result<(), MessageProcessingError> {
self.process_message(
transaction_manager,
&mut openmls_group,
&provider,
&envelope,
)?;
|provider| -> Result<(), MessageProcessingError> {
self.process_message(&mut openmls_group, &provider, &envelope)?;
openmls_group.save(provider.key_store())?;
Ok(())
},
Expand Down
30 changes: 14 additions & 16 deletions xmtp_mls/src/storage/encrypted_store/topic_refresh_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ impl_store!(TopicRefreshState, topic_refresh_state);

impl EncryptedMessageStore {
pub fn get_last_synced_timestamp_for_topic(
&self,
conn: &mut DbConnection,
topic: &str,
) -> Result<i64, StorageError> {
Expand All @@ -35,7 +34,6 @@ impl EncryptedMessageStore {
}

pub fn update_last_synced_timestamp_for_topic(
&self,
conn: &mut DbConnection,
topic: &str,
timestamp_ns: i64,
Expand All @@ -62,12 +60,11 @@ pub(crate) mod tests {

#[test]
fn get_timestamp_with_no_existing_state() {
with_store(|store, mut conn| {
with_store(|mut conn| {
let entry: Option<TopicRefreshState> = conn.fetch(&"topic".to_string()).unwrap();
assert!(entry.is_none());
assert_eq!(
store
.get_last_synced_timestamp_for_topic(&mut conn, "topic")
EncryptedMessageStore::get_last_synced_timestamp_for_topic(&mut conn, "topic")
.unwrap(),
0
);
Expand All @@ -78,15 +75,14 @@ pub(crate) mod tests {

#[test]
fn get_timestamp_with_existing_state() {
with_store(|store, mut conn| {
with_store(|mut conn| {
let entry = TopicRefreshState {
topic: "topic".to_string(),
last_message_timestamp_ns: 123,
};
entry.store(&mut conn).unwrap();
assert_eq!(
store
.get_last_synced_timestamp_for_topic(&mut conn, "topic")
EncryptedMessageStore::get_last_synced_timestamp_for_topic(&mut conn, "topic")
.unwrap(),
123
);
Expand All @@ -95,16 +91,17 @@ pub(crate) mod tests {

#[test]
fn update_timestamp_when_bigger() {
with_store(|store, mut conn| {
with_store(|mut conn| {
let entry = TopicRefreshState {
topic: "topic".to_string(),
last_message_timestamp_ns: 123,
};
entry.store(&mut conn).unwrap();
assert_eq!(
store
.update_last_synced_timestamp_for_topic(&mut conn, "topic", 124)
.unwrap(),
EncryptedMessageStore::update_last_synced_timestamp_for_topic(
&mut conn, "topic", 124
)
.unwrap(),
true
);
let entry: Option<TopicRefreshState> = conn.fetch(&"topic".to_string()).unwrap();
Expand All @@ -114,16 +111,17 @@ pub(crate) mod tests {

#[test]
fn dont_update_timestamp_when_smaller() {
with_store(|store, mut conn| {
with_store(|mut conn| {
let entry = TopicRefreshState {
topic: "topic".to_string(),
last_message_timestamp_ns: 123,
};
entry.store(&mut conn).unwrap();
assert_eq!(
store
.update_last_synced_timestamp_for_topic(&mut conn, "topic", 122)
.unwrap(),
EncryptedMessageStore::update_last_synced_timestamp_for_topic(
&mut conn, "topic", 122
)
.unwrap(),
false
);
let entry: Option<TopicRefreshState> = conn.fetch(&"topic".to_string()).unwrap();
Expand Down

0 comments on commit 336be08

Please sign in to comment.