Skip to content

Commit

Permalink
Receive flow for application messages (#326)
Browse files Browse the repository at this point in the history
This reads in application messages, decrypts them, and stores them in the database.

There's still lots of missing pieces here:
1. Sync from last processed payload, rather than syncing all payloads - working on adding the `topic_state` table separately
2. Handle concurrency of syncing
3. Handle payloads coming from yourself (left a comment in the PR body)
4. Better unit tests - we can't decrypt payloads coming from ourselves, so we need to support adding members first
5. This only reads payloads for a group and doesn't return anything - we still need a higher-level method to list all messages in a group by reading from the DB, we also need a higher-level method to read payloads for all groups.

There's plenty of space left here also to add in the logic for other message types
  • Loading branch information
richardhuaaa authored Nov 14, 2023
1 parent 56687a5 commit a9dd41d
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 41 deletions.
7 changes: 3 additions & 4 deletions mls_validation_service/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ use openmls::{
use openmls_rust_crypto::OpenMlsRustCrypto;
use openmls_traits::OpenMlsProvider;
use tonic::{Request, Response, Status};
use xmtp_mls::utils::id::serialize_group_id;
use xmtp_proto::xmtp::mls_validation::v1::{
validate_group_messages_response::ValidationResponse as ValidateGroupMessageValidationResponse,
validate_key_packages_response::ValidationResponse as ValidateKeyPackageValidationResponse,
validation_api_server::ValidationApi, ValidateGroupMessagesRequest,
ValidateGroupMessagesResponse, ValidateKeyPackagesRequest, ValidateKeyPackagesResponse,
};

use crate::validation_helpers::{hex_encode, identity_to_wallet_address};
use crate::validation_helpers::identity_to_wallet_address;

#[derive(Debug, Default)]
pub struct ValidationService {}
Expand Down Expand Up @@ -93,9 +94,7 @@ fn validate_group_message(message: Vec<u8>) -> Result<ValidateGroupMessageResult
let private_message: ProtocolMessage = msg_result.into();

Ok(ValidateGroupMessageResult {
// TODO: I wonder if we really want to be base64 encoding this or if we can treat it as a
// slice
group_id: hex_encode(private_message.group_id().as_slice()),
group_id: serialize_group_id(private_message.group_id().as_slice()),
})
}

Expand Down
4 changes: 0 additions & 4 deletions mls_validation_service/src/validation_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@ use prost::Message;
use xmtp_mls::association::Eip191Association;
use xmtp_proto::xmtp::v3::message_contents::Eip191Association as Eip191AssociationProto;

pub fn hex_encode(key: &[u8]) -> String {
hex::encode(key)
}

pub fn identity_to_wallet_address(identity: &[u8], pub_key: &[u8]) -> Result<String, String> {
let proto_value = Eip191AssociationProto::decode(identity).map_err(|e| format!("{:?}", e))?;
let association = Eip191Association::from_proto_with_expected_address(
Expand Down
5 changes: 2 additions & 3 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use xmtp_proto::api_client::{XmtpApiClient, XmtpMlsClient};

use crate::{
api_client_wrapper::{ApiClientWrapper, IdentityUpdate},
configuration::KEY_PACKAGE_TOP_UP_AMOUNT,
groups::MlsGroup,
identity::Identity,
storage::{group::GroupMembershipState, EncryptedMessageStore, StorageError},
Expand Down Expand Up @@ -58,10 +57,10 @@ impl From<&str> for ClientError {

#[derive(Debug)]
pub struct Client<ApiClient> {
pub api_client: ApiClientWrapper<ApiClient>,
pub(crate) api_client: ApiClientWrapper<ApiClient>,
pub(crate) _network: Network,
pub(crate) identity: Identity,
pub store: EncryptedMessageStore, // Temporarily exposed outside crate for CLI client
pub(crate) store: EncryptedMessageStore,
}

impl<ApiClient> Client<ApiClient>
Expand Down
2 changes: 0 additions & 2 deletions xmtp_mls/src/configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,3 @@ pub const CIPHERSUITE: Ciphersuite =
Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519;

pub const MLS_PROTOCOL_VERSION: ProtocolVersion = ProtocolVersion::Mls10;

pub const KEY_PACKAGE_TOP_UP_AMOUNT: u16 = 100;
171 changes: 165 additions & 6 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,38 @@
mod intents;

#[cfg(test)]
use std::println as debug;

use intents::SendMessageIntentData;
#[cfg(not(test))]
use log::debug;
use openmls::{
prelude::{
CredentialWithKey, CryptoConfig, GroupId, LeafNodeIndex, MlsGroup as OpenMlsGroup,
MlsGroupConfig, WireFormatPolicy,
MlsGroupConfig, MlsMessageIn, MlsMessageInBody, PrivateMessageIn, ProcessedMessage,
ProcessedMessageContent, Sender, WireFormatPolicy,
},
prelude_test::KeyPackage,
};
use openmls_traits::OpenMlsProvider;
use std::mem::{discriminant, Discriminant};
use thiserror::Error;
use tls_codec::Serialize;
use xmtp_proto::api_client::{XmtpApiClient, XmtpMlsClient};
use tls_codec::{Deserialize, Serialize};
use xmtp_proto::api_client::{Envelope, XmtpApiClient, XmtpMlsClient};

use self::intents::{AddMembersIntentData, IntentError, PostCommitAction, RemoveMembersIntentData};
use crate::{
api_client_wrapper::WelcomeMessage,
client::ClientError,
configuration::CIPHERSUITE,
identity::Identity,
storage::{
group::{GroupMembershipState, StoredGroup},
group_intent::{IntentKind, IntentState, NewGroupIntent, StoredGroupIntent},
group_message::{GroupMessageKind, StoredGroupMessage},
DbConnection, StorageError,
},
utils::{hash::sha256, time::now_ns, topic::get_group_topic},
utils::{hash::sha256, id::get_message_id, time::now_ns, topic::get_group_topic},
xmtp_openmls_provider::XmtpOpenMlsProvider,
Client, Delete, Store,
};
Expand Down Expand Up @@ -53,10 +61,29 @@ pub enum GroupError {
SelfUpdate(#[from] openmls::group::SelfUpdateError<StorageError>),
#[error("client: {0}")]
Client(#[from] ClientError),
#[error("receive errors: {0:?}")]
ReceiveError(Vec<MessageProcessingError>),
#[error("generic: {0}")]
Generic(String),
}

#[derive(Debug, Error)]
pub enum MessageProcessingError {
#[error("[{message_time_ns:?}] invalid sender with credential: {credential:?}")]
InvalidSender {
message_time_ns: u64,
credential: Vec<u8>,
},
#[error("openmls process message error: {0}")]
OpenMlsProcessMessage(#[from] openmls::prelude::ProcessMessageError),
#[error("storage error: {0}")]
Storage(#[from] crate::storage::StorageError),
#[error("tls deserialization: {0}")]
TlsDeserialization(#[from] tls_codec::Error),
#[error("unsupported message type: {0:?}")]
UnsupportedMessageType(Discriminant<MlsMessageInBody>),
}

pub struct MlsGroup<'c, ApiClient> {
pub group_id: Vec<u8>,
pub created_at_ns: i64,
Expand Down Expand Up @@ -97,7 +124,6 @@ where
&provider,
&client.identity.installation_keys,
&build_group_config(),
// TODO: Confirm I should be using the installation keys here
CredentialWithKey {
credential: client.identity.credential.clone(),
signature_key: client.identity.installation_keys.to_public_vec().into(),
Expand Down Expand Up @@ -132,6 +158,124 @@ where
Ok(messages)
}

fn validate_message_sender(
&self,
openmls_group: &mut OpenMlsGroup,
decrypted_message: &ProcessedMessage,
envelope_timestamp_ns: u64,
) -> Result<(String, Vec<u8>), MessageProcessingError> {
let mut sender_account_address = None;
let mut sender_installation_id = None;
if let Sender::Member(leaf_node_index) = decrypted_message.sender() {
if let Some(member) = openmls_group.member_at(*leaf_node_index) {
if member.credential.eq(decrypted_message.credential()) {
sender_account_address = Identity::get_validated_account_address(
member.credential.identity(),
&member.signature_key,
)
.ok();
sender_installation_id = Some(member.signature_key);
}
}
}

if sender_account_address.is_none() {
return Err(MessageProcessingError::InvalidSender {
message_time_ns: envelope_timestamp_ns,
credential: decrypted_message.credential().identity().to_vec(),
});
}
Ok((
sender_account_address.unwrap(),
sender_installation_id.unwrap(),
))
}

fn process_private_message(
&self,
openmls_group: &mut OpenMlsGroup,
provider: &XmtpOpenMlsProvider,
message: PrivateMessageIn,
envelope_timestamp_ns: u64,
) -> Result<(), MessageProcessingError> {
let decrypted_message = openmls_group.process_message(provider, message)?;
let (sender_account_address, sender_installation_id) =
self.validate_message_sender(openmls_group, &decrypted_message, envelope_timestamp_ns)?;

match decrypted_message.into_content() {
ProcessedMessageContent::ApplicationMessage(application_message) => {
let message_bytes = application_message.into_bytes();
let message_id =
get_message_id(&message_bytes, &self.group_id, envelope_timestamp_ns);
let message = StoredGroupMessage {
id: message_id,
group_id: self.group_id.clone(),
decrypted_message_bytes: message_bytes,
sent_at_ns: envelope_timestamp_ns as i64,
kind: GroupMessageKind::Application,
sender_installation_id,
sender_wallet_address: sender_account_address,
};
message.store(&mut self.client.store.conn()?)?;
}
ProcessedMessageContent::ProposalMessage(_proposal_ptr) => {
// intentionally left blank.
}
ProcessedMessageContent::ExternalJoinProposalMessage(_external_proposal_ptr) => {
// intentionally left blank.
}
ProcessedMessageContent::StagedCommitMessage(_commit_ptr) => {
// intentionally left blank.
}
}
Ok(())
}

pub fn process_messages(&self, envelopes: Vec<Envelope>) -> Result<(), GroupError> {
let provider = self.client.mls_provider();
let mut openmls_group = self.load_mls_group(&provider)?;
let receive_errors: Vec<MessageProcessingError> = envelopes
.into_iter()
.map(|envelope| -> Result<(), MessageProcessingError> {
let mls_message_in = MlsMessageIn::tls_deserialize_exact(&envelope.message)?;

match mls_message_in.extract() {
MlsMessageInBody::PrivateMessage(message) => self.process_private_message(
&mut openmls_group,
&provider,
message,
envelope.timestamp_ns,
),
other => Err(MessageProcessingError::UnsupportedMessageType(
discriminant(&other),
)),
}
})
.filter(|result| result.is_err())
.map(|result| result.unwrap_err())
.collect();
openmls_group.save(provider.key_store())?; // TODO handle concurrency

if receive_errors.is_empty() {
Ok(())
} else {
Err(GroupError::ReceiveError(receive_errors))
}
}

pub async fn receive(&self) -> Result<(), GroupError> {
let topic = get_group_topic(&self.group_id);
let envelopes = self
.client
.api_client
.read_topic(
&topic, 0, // TODO: query from last query point
)
.await?;
debug!("Received {} envelopes", envelopes.len());
self.process_messages(envelopes)
}

pub async fn send_message(&self, message: &[u8]) -> Result<(), GroupError> {
let mut conn = self.client.store.conn()?;
let intent_data: Vec<u8> = SendMessageIntentData::new(message.to_vec()).into();
Expand Down Expand Up @@ -377,7 +521,7 @@ mod tests {
use openmls_traits::OpenMlsProvider;
use xmtp_cryptography::utils::generate_local_wallet;

use crate::{builder::ClientBuilder, utils::topic::get_welcome_topic};
use crate::{builder::ClientBuilder, groups::GroupError, utils::topic::get_welcome_topic};

#[tokio::test]
async fn test_send_message() {
Expand All @@ -397,6 +541,21 @@ mod tests {
assert_eq!(messages.len(), 1)
}

#[tokio::test]
async fn test_receive_self_message() {
let wallet = generate_local_wallet();
let client = ClientBuilder::new_test_client(wallet.into()).await;
let group = client.create_group().expect("create group");
group.send_message(b"hello").await.expect("send message");

let result = group.receive().await;
if let GroupError::ReceiveError(errors) = result.err().unwrap() {
assert_eq!(errors.len(), 1);
} else {
panic!("expected GroupError::ReceiveError")
}
}

#[tokio::test]
async fn test_add_members() {
let client = ClientBuilder::new_test_client(generate_local_wallet().into()).await;
Expand Down
21 changes: 18 additions & 3 deletions xmtp_mls/src/identity.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::default;

use openmls::{
extensions::LastResortExtension,
prelude::{
Expand Down Expand Up @@ -36,6 +34,8 @@ pub enum IdentityError {
StorageError(#[from] StorageError),
#[error("generating key package")]
KeyPackageGenerationError(#[from] KeyPackageNewError<StorageError>),
#[error("deserialization")]
Deserialization(#[from] prost::DecodeError),
}

#[derive(Debug)]
Expand All @@ -61,7 +61,7 @@ impl Identity {
installation_keys: signature_keys,
credential,
};
identity.new_key_package(&provider)?;
identity.new_key_package(provider)?;
StoredIdentity::from(&identity).store(&mut store.conn()?)?;

// TODO: upload credential_with_key and last_resort_key_package
Expand Down Expand Up @@ -122,6 +122,21 @@ impl Identity {
// Serialize into credential
Ok(Credential::new(association_proto.encode_to_vec(), CredentialType::Basic).unwrap())
}

pub(crate) fn get_validated_account_address(
credential: &[u8],
installation_public_key: &[u8],
) -> Result<String, IdentityError> {
let proto = Eip191AssociationProto::decode(credential)?;
let expected_wallet_address = proto.wallet_address.clone();
let association = Eip191Association::from_proto_with_expected_address(
installation_public_key,
proto,
expected_wallet_address,
)?;

Ok(association.address())
}
}

#[cfg(test)]
Expand Down
19 changes: 19 additions & 0 deletions xmtp_mls/src/utils/id.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use super::hash::sha256;

pub fn serialize_group_id(group_id: &[u8]) -> String {
// TODO: I wonder if we really want to be base64 encoding this or if we can treat it as a
// slice
hex::encode(group_id)
}

pub fn get_message_id(
decrypted_message_bytes: &[u8],
group_id: &[u8],
envelope_timestamp_ns: u64,
) -> Vec<u8> {
let mut id_vec = Vec::new();
id_vec.extend_from_slice(group_id);
id_vec.extend_from_slice(&envelope_timestamp_ns.to_be_bytes());
id_vec.extend_from_slice(decrypted_message_bytes);
sha256(&id_vec)
}
1 change: 1 addition & 0 deletions xmtp_mls/src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod hash;
pub mod id;
#[cfg(test)]
pub mod test;
pub mod time;
Expand Down
6 changes: 4 additions & 2 deletions xmtp_mls/src/utils/topic.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
pub fn get_group_topic(group_id: &Vec<u8>) -> String {
format!("/xmtp/3/g-{}/proto", hex::encode(group_id))
use crate::utils::id::serialize_group_id;

pub fn get_group_topic(group_id: &[u8]) -> String {
format!("/xmtp/3/g-{}/proto", serialize_group_id(group_id))
}

pub fn get_welcome_topic(installation_id: &Vec<u8>) -> String {
Expand Down
Loading

0 comments on commit a9dd41d

Please sign in to comment.