Skip to content

Commit

Permalink
Refactor into helper method
Browse files Browse the repository at this point in the history
  • Loading branch information
richardhuaaa committed Nov 14, 2023
1 parent 74e7169 commit 36e7c53
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 94 deletions.
55 changes: 52 additions & 3 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashSet;
use std::{collections::HashSet, mem::Discriminant};

use diesel::Connection;
use log::debug;
use openmls::{
framing::{MlsMessageIn, MlsMessageInBody},
Expand All @@ -14,7 +15,7 @@ use crate::{
api_client_wrapper::{ApiClientWrapper, IdentityUpdate},
groups::MlsGroup,
identity::Identity,
storage::{group::GroupMembershipState, EncryptedMessageStore, StorageError},
storage::{group::GroupMembershipState, DbConnection, EncryptedMessageStore, StorageError},
types::Address,
utils::topic::get_welcome_topic,
verified_key_package::{KeyPackageVerificationError, VerifiedKeyPackage},
Expand Down Expand Up @@ -46,11 +47,32 @@ pub enum ClientError {
#[error("key package verification: {0}")]
KeyPackageVerification(#[from] KeyPackageVerificationError),
#[error("message processing: {0}")]
MessageProcessing(#[from] crate::groups::MessageProcessingError),
MessageProcessing(#[from] MessageProcessingError),
#[error("generic:{0}")]
Generic(String),
}

#[derive(Debug, Error)]
pub enum MessageProcessingError {
#[error("[{0}] already processed")]
AlreadyProcessed(u64),
#[error("diesel error: {0}")]
Diesel(#[from] diesel::result::Error),
#[error("[{message_time_ns:?}] invalid sender with credential: {credential:?}")]
InvalidSender {
message_time_ns: u64,
credential: Vec<u8>,
},
#[error("openmls process message error: {0}")]
OpenMlsProcessMessage(#[from] openmls::prelude::ProcessMessageError),
#[error("storage error: {0}")]
Storage(#[from] crate::storage::StorageError),
#[error("tls deserialization: {0}")]
TlsDeserialization(#[from] tls_codec::Error),
#[error("unsupported message type: {0:?}")]
UnsupportedMessageType(Discriminant<MlsMessageInBody>),
}

impl From<String> for ClientError {
fn from(value: String) -> Self {
Self::Generic(value)
Expand Down Expand Up @@ -194,6 +216,33 @@ where
Ok(envelopes)
}

pub(crate) fn process_for_topic<F>(
&self,
topic: &str,
envelope_timestamp_ns: u64,
process_envelope: F,
) -> Result<(), MessageProcessingError>
where
F: FnOnce(&mut DbConnection) -> Result<(), MessageProcessingError>,
{
self.store.conn()?.transaction(
|transaction_manager| -> Result<(), MessageProcessingError> {
let is_updated = self.store.update_last_synced_timestamp_for_topic(
transaction_manager,
topic,
envelope_timestamp_ns as i64,
)?;
if !is_updated {
return Err(MessageProcessingError::AlreadyProcessed(
envelope_timestamp_ns,
));
}
process_envelope(transaction_manager)
},
)?;
Ok(())
}

// Get a flat list of one key package per installation for all the wallet addresses provided.
// Revoked installations will be omitted from the list
#[allow(dead_code)]
Expand Down
147 changes: 56 additions & 91 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
mod intents;

use diesel::Connection;
use intents::SendMessageIntentData;
use openmls::{
prelude::{
Expand All @@ -11,15 +10,15 @@ use openmls::{
prelude_test::KeyPackage,
};
use openmls_traits::OpenMlsProvider;
use std::mem::{discriminant, Discriminant};
use std::mem::discriminant;
use thiserror::Error;
use tls_codec::{Deserialize, Serialize};
use xmtp_proto::api_client::{Envelope, XmtpApiClient, XmtpMlsClient};

use self::intents::{AddMembersIntentData, IntentError, PostCommitAction, RemoveMembersIntentData};
use crate::{
api_client_wrapper::WelcomeMessage,
client::ClientError,
client::{ClientError, MessageProcessingError},
configuration::CIPHERSUITE,
identity::Identity,
storage::{
Expand Down Expand Up @@ -65,27 +64,6 @@ pub enum GroupError {
Generic(String),
}

#[derive(Debug, Error)]
pub enum MessageProcessingError {
#[error("[{0}] already processed")]
AlreadyProcessed(u64),
#[error("diesel error: {0}")]
Diesel(#[from] diesel::result::Error),
#[error("[{message_time_ns:?}] invalid sender with credential: {credential:?}")]
InvalidSender {
message_time_ns: u64,
credential: Vec<u8>,
},
#[error("openmls process message error: {0}")]
OpenMlsProcessMessage(#[from] openmls::prelude::ProcessMessageError),
#[error("storage error: {0}")]
Storage(#[from] crate::storage::StorageError),
#[error("tls deserialization: {0}")]
TlsDeserialization(#[from] tls_codec::Error),
#[error("unsupported message type: {0:?}")]
UnsupportedMessageType(Discriminant<MlsMessageInBody>),
}

pub struct MlsGroup<'c, ApiClient> {
pub group_id: Vec<u8>,
pub created_at_ns: i64,
Expand Down Expand Up @@ -213,66 +191,45 @@ where

fn process_private_message(
&self,
transaction_manager: &mut DbConnection,
openmls_group: &mut OpenMlsGroup,
provider: &XmtpOpenMlsProvider,
message: PrivateMessageIn,
envelope_timestamp_ns: u64,
) -> Result<(), MessageProcessingError> {
self.client.store.conn()?.transaction(
|transaction_manager| -> Result<(), MessageProcessingError> {
let is_updated = self.client.store.update_last_synced_timestamp_for_topic(
transaction_manager,
&self.topic(),
envelope_timestamp_ns as i64,
)?;
if !is_updated {
return Err(MessageProcessingError::AlreadyProcessed(
envelope_timestamp_ns,
));
}

// TODO include provider in transaction
let decrypted_message = openmls_group.process_message(provider, message)?;
let (sender_account_address, sender_installation_id) = self
.validate_message_sender(
openmls_group,
&decrypted_message,
envelope_timestamp_ns,
)?;

match decrypted_message.into_content() {
ProcessedMessageContent::ApplicationMessage(application_message) => {
let message_bytes = application_message.into_bytes();
let message_id =
get_message_id(&message_bytes, &self.group_id, envelope_timestamp_ns);
let message = StoredGroupMessage {
id: message_id,
group_id: self.group_id.clone(),
decrypted_message_bytes: message_bytes,
sent_at_ns: envelope_timestamp_ns as i64,
kind: GroupMessageKind::Application,
sender_installation_id,
sender_wallet_address: sender_account_address,
};
message.store(transaction_manager)?;
}
ProcessedMessageContent::ProposalMessage(_proposal_ptr) => {
// intentionally left blank.
}
ProcessedMessageContent::ExternalJoinProposalMessage(
_external_proposal_ptr,
) => {
// intentionally left blank.
}
ProcessedMessageContent::StagedCommitMessage(_commit_ptr) => {
// intentionally left blank.
}
// TODO include provider in transaction
let decrypted_message = openmls_group.process_message(provider, message)?;
let (sender_account_address, sender_installation_id) =
self.validate_message_sender(openmls_group, &decrypted_message, envelope_timestamp_ns)?;

match decrypted_message.into_content() {
ProcessedMessageContent::ApplicationMessage(application_message) => {
let message_bytes = application_message.into_bytes();
let message_id =
get_message_id(&message_bytes, &self.group_id, envelope_timestamp_ns);
let message = StoredGroupMessage {
id: message_id,
group_id: self.group_id.clone(),
decrypted_message_bytes: message_bytes,
sent_at_ns: envelope_timestamp_ns as i64,
kind: GroupMessageKind::Application,
sender_installation_id,
sender_wallet_address: sender_account_address,
};
message.store(transaction_manager)?;
}
ProcessedMessageContent::ProposalMessage(_proposal_ptr) => {
// intentionally left blank.
}
ProcessedMessageContent::ExternalJoinProposalMessage(_external_proposal_ptr) => {
// intentionally left blank.
}
ProcessedMessageContent::StagedCommitMessage(_commit_ptr) => {
// intentionally left blank.
}
};

openmls_group.save(provider.key_store())?; // TODO include provider in transaction
Ok(())
},
)?;
openmls_group.save(provider.key_store())?; // TODO include provider in transaction
Ok(())
}

Expand All @@ -282,23 +239,31 @@ where
let receive_errors: Vec<MessageProcessingError> = envelopes
.into_iter()
.map(|envelope| -> Result<(), MessageProcessingError> {
let mls_message_in = MlsMessageIn::tls_deserialize_exact(&envelope.message)?;

match mls_message_in.extract() {
MlsMessageInBody::PrivateMessage(message) => self.process_private_message(
&mut openmls_group,
&provider,
message,
envelope.timestamp_ns,
),
other => Err(MessageProcessingError::UnsupportedMessageType(
discriminant(&other),
)),
}
self.client.process_for_topic(
&self.topic(),
envelope.timestamp_ns,
|transaction_manager| -> Result<(), MessageProcessingError> {
let mls_message_in =
MlsMessageIn::tls_deserialize_exact(&envelope.message)?;

match mls_message_in.extract() {
MlsMessageInBody::PrivateMessage(message) => self
.process_private_message(
transaction_manager,
&mut openmls_group,
&provider,
message,
envelope.timestamp_ns,
),
other => Err(MessageProcessingError::UnsupportedMessageType(
discriminant(&other),
)),
}
},
)
})
.filter_map(|result| result.err())
.collect();
openmls_group.save(provider.key_store())?; // TODO handle concurrency

if receive_errors.is_empty() {
Ok(())
Expand Down

0 comments on commit 36e7c53

Please sign in to comment.