Skip to content

Commit

Permalink
Publish intents flows (#315)
Browse files Browse the repository at this point in the history
* Implement publish flow of state machine

* Add a few tests

* Refactor rand_string and rand_vec

* Remove unused imports

* Run clippy fix

* Fix issues with updating existing entries

* Clippy fix

* Remove dependency on tempfile

* Remove all usage of local tmp paths

* Require explicit EncryptedMessageStore path

* Run clippy --fix

* Format

* Fmt

* Use generic error

* Add one more small test

* Merge changes

* Fix import

* Un-prefix import

* Remove unused import

* Add missing parameter error

* Remove unused field

* Remove test field

* Add tests for intent data

* Address review comments

* Use ok_or

* Clean up error handling

* Format using nightly toolchain
  • Loading branch information
neekolas authored Nov 7, 2023
1 parent e4a4fba commit f33fd97
Show file tree
Hide file tree
Showing 28 changed files with 1,869 additions and 273 deletions.
4 changes: 0 additions & 4 deletions mls_validation_service/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,11 @@ impl ValidationApi for ValidationService {
match validate_group_message(message.group_message_bytes_tls_serialized) {
Ok(res) => ValidateGroupMessageValidationResponse {
group_id: res.group_id,
epoch: res.epoch,
error_message: "".to_string(),
is_ok: true,
},
Err(e) => ValidateGroupMessageValidationResponse {
group_id: "".to_string(),
epoch: 0,
error_message: e,
is_ok: false,
},
Expand All @@ -86,7 +84,6 @@ impl ValidationApi for ValidationService {

struct ValidateGroupMessageResult {
group_id: String,
epoch: u64,
}

fn validate_group_message(message: Vec<u8>) -> Result<ValidateGroupMessageResult, String> {
Expand All @@ -99,7 +96,6 @@ fn validate_group_message(message: Vec<u8>) -> Result<ValidateGroupMessageResult
// TODO: I wonder if we really want to be base64 encoding this or if we can treat it as a
// slice
group_id: hex_encode(private_message.group_id().as_slice()),
epoch: private_message.epoch().as_u64(),
})
}

Expand Down
26 changes: 19 additions & 7 deletions xmtp_mls/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ where
parameter: "api_client",
})?;
let network = self.network;
let store = self.store.take().unwrap_or_default();
let store = self
.store
.take()
.ok_or(ClientBuilderError::MissingParameter { parameter: "store" })?;
let provider = XmtpOpenMlsProvider::new(&store);
let identity = self
.identity_strategy
Expand All @@ -150,13 +153,13 @@ mod tests {

use ethers::signers::{LocalWallet, Signer, Wallet};
use ethers_core::k256::ecdsa::SigningKey;
use tempfile::TempPath;
use xmtp_api_grpc::grpc_api_helper::Client as GrpcClient;
use xmtp_cryptography::utils::generate_local_wallet;

use super::{ClientBuilder, IdentityStrategy};
use crate::{
storage::{EncryptedMessageStore, StorageOption},
utils::test::tmp_path,
Client,
};

Expand All @@ -171,10 +174,22 @@ mod tests {
self.api_client(get_local_grpc_client().await)
}

fn temp_store(self) -> Self {
let tmpdb = tmp_path();
self.store(
EncryptedMessageStore::new_unencrypted(StorageOption::Persistent(tmpdb)).unwrap(),
)
}

pub async fn new_test_client(
strat: IdentityStrategy<Wallet<SigningKey>>,
) -> Client<GrpcClient> {
Self::new(strat).local_grpc().await.build().unwrap()
Self::new(strat)
.temp_store()
.local_grpc()
.await
.build()
.unwrap()
}
}

Expand All @@ -189,10 +204,7 @@ mod tests {

#[tokio::test]
async fn identity_persistence_test() {
let tmpdb = TempPath::from_path("./db.db3")
.to_str()
.unwrap()
.to_string();
let tmpdb = tmp_path();
let wallet = generate_local_wallet();

// Generate a new Wallet + Store
Expand Down
22 changes: 19 additions & 3 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ use xmtp_proto::api_client::{XmtpApiClient, XmtpMlsClient};
use crate::{
api_client_wrapper::{ApiClientWrapper, IdentityUpdate},
configuration::KEY_PACKAGE_TOP_UP_AMOUNT,
groups::MlsGroup,
identity::Identity,
storage::{EncryptedMessageStore, StorageError},
storage::{group::GroupMembershipState, EncryptedMessageStore, StorageError},
types::Address,
verified_key_package::{KeyPackageVerificationError, VerifiedKeyPackage},
xmtp_openmls_provider::XmtpOpenMlsProvider,
Expand All @@ -30,7 +31,7 @@ pub enum ClientError {
#[error("storage error: {0}")]
Storage(#[from] StorageError),
#[error("dieselError: {0}")]
Ddd(#[from] diesel::result::Error),
Diesel(#[from] diesel::result::Error),
#[error("Query failed: {0}")]
QueryError(#[from] xmtp_proto::api_client::Error),
#[error("identity error: {0}")]
Expand Down Expand Up @@ -82,10 +83,17 @@ where
}

// TODO: Remove this and figure out the correct lifetimes to allow long lived provider
fn mls_provider(&self) -> XmtpOpenMlsProvider {
pub fn mls_provider(&self) -> XmtpOpenMlsProvider {
XmtpOpenMlsProvider::new(&self.store)
}

pub fn create_group(&self) -> Result<MlsGroup<ApiClient>, ClientError> {
let group = MlsGroup::create_and_insert(self, GroupMembershipState::Allowed)
.map_err(|e| ClientError::Generic(format!("group create error {}", e)))?;

Ok(group)
}

pub async fn register_identity(&self) -> Result<(), ClientError> {
// TODO: Mark key package as last_resort in creation
let last_resort_kp = self.identity.new_key_package(&self.mls_provider())?;
Expand Down Expand Up @@ -154,6 +162,14 @@ where
.get_all_active_installation_ids(wallet_addresses)
.await?;

self.get_key_packages_for_installation_ids(installation_ids)
.await
}

pub async fn get_key_packages_for_installation_ids(
&self,
installation_ids: Vec<Vec<u8>>,
) -> Result<Vec<VerifiedKeyPackage>, ClientError> {
let key_package_results = self
.api_client
.consume_key_packages(installation_ids)
Expand Down
239 changes: 239 additions & 0 deletions xmtp_mls/src/groups/intents.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
use openmls::prelude::MlsMessageOut;
use prost::{DecodeError, Message};
use thiserror::Error;
use tls_codec::Serialize;
use xmtp_proto::xmtp::mls::database::{
add_members_publish_data::{Version as AddMembersVersion, V1 as AddMembersV1},
post_commit_action::{Kind as PostCommitActionKind, SendWelcomes as SendWelcomesProto},
send_message_publish_data::{Version as SendMessageVersion, V1 as SendMessageV1},
AddMembersPublishData, PostCommitAction as PostCommitActionProto, SendMessagePublishData,
};

use crate::{
verified_key_package::{KeyPackageVerificationError, VerifiedKeyPackage},
xmtp_openmls_provider::XmtpOpenMlsProvider,
};

#[derive(Debug, Error)]
pub enum IntentError {
#[error("decode error: {0}")]
Decode(#[from] DecodeError),
#[error("key package verification: {0}")]
KeyPackageVerification(#[from] KeyPackageVerificationError),
#[error("tls codec: {0}")]
TlsCodec(#[from] tls_codec::Error),
#[error("generic: {0}")]
Generic(String),
}

#[derive(Debug, Clone)]
pub struct SendMessageIntentData {
pub message: Vec<u8>,
}

impl SendMessageIntentData {
pub fn new(message: Vec<u8>) -> Self {
Self { message }
}

pub(crate) fn to_bytes(&self) -> Vec<u8> {
let mut buf = Vec::new();
SendMessagePublishData {
version: Some(SendMessageVersion::V1(SendMessageV1 {
payload_bytes: self.message.clone(),
})),
}
.encode(&mut buf)
.unwrap();

buf
}

pub(crate) fn from_bytes(data: &[u8]) -> Result<Self, IntentError> {
let msg = SendMessagePublishData::decode(data)?;
let payload_bytes = match msg.version {
Some(SendMessageVersion::V1(v1)) => v1.payload_bytes,
None => return Err(IntentError::Generic("missing payload".to_string())),
};

Ok(Self::new(payload_bytes))
}
}

impl From<SendMessageIntentData> for Vec<u8> {
fn from(intent: SendMessageIntentData) -> Self {
intent.to_bytes()
}
}

#[derive(Debug, Clone)]
pub struct AddMembersIntentData {
pub key_packages: Vec<VerifiedKeyPackage>,
}

impl AddMembersIntentData {
pub fn new(key_packages: Vec<VerifiedKeyPackage>) -> Self {
Self { key_packages }
}

pub(crate) fn to_bytes(&self) -> Result<Vec<u8>, IntentError> {
let mut buf = Vec::new();
let key_package_bytes_result: Result<Vec<Vec<u8>>, tls_codec::Error> = self
.key_packages
.iter()
.map(|kp| kp.inner.tls_serialize_detached())
.collect();

AddMembersPublishData {
version: Some(AddMembersVersion::V1(AddMembersV1 {
key_packages_bytes_tls_serialized: key_package_bytes_result?,
})),
}
.encode(&mut buf)
.unwrap();

Ok(buf)
}

pub(crate) fn from_bytes(
data: &[u8],
provider: &XmtpOpenMlsProvider,
) -> Result<Self, IntentError> {
let msg = AddMembersPublishData::decode(data)?;
let key_package_bytes = match msg.version {
Some(AddMembersVersion::V1(v1)) => v1.key_packages_bytes_tls_serialized,
None => return Err(IntentError::Generic("missing payload".to_string())),
};
let key_packages: Result<Vec<VerifiedKeyPackage>, KeyPackageVerificationError> =
key_package_bytes
.iter()
// TODO: Serialize VerifiedKeyPackages directly, so that we don't have to re-verify
.map(|kp| VerifiedKeyPackage::from_bytes(provider, kp))
.collect();

Ok(Self::new(key_packages?))
}
}

impl TryFrom<AddMembersIntentData> for Vec<u8> {
type Error = IntentError;

fn try_from(intent: AddMembersIntentData) -> Result<Self, Self::Error> {
intent.to_bytes()
}
}

#[derive(Debug, Clone)]
pub enum PostCommitAction {
SendWelcomes(SendWelcomesAction),
}

#[derive(Debug, Clone)]
pub struct SendWelcomesAction {
pub installation_ids: Vec<Vec<u8>>,
pub welcome_message: Vec<u8>,
}

impl SendWelcomesAction {
pub fn new(installation_ids: Vec<Vec<u8>>, welcome_message: Vec<u8>) -> Self {
Self {
installation_ids,
welcome_message,
}
}

pub(crate) fn to_bytes(&self) -> Vec<u8> {
let mut buf = Vec::new();
PostCommitActionProto {
kind: Some(PostCommitActionKind::SendWelcomes(SendWelcomesProto {
installation_ids: self.installation_ids.clone(),
welcome_message: self.welcome_message.clone(),
})),
}
.encode(&mut buf)
.unwrap();

buf
}
}

impl PostCommitAction {
pub(crate) fn to_bytes(&self) -> Vec<u8> {
match self {
PostCommitAction::SendWelcomes(action) => action.to_bytes(),
}
}

pub(crate) fn from_bytes(data: &[u8]) -> Result<Self, IntentError> {
let decoded = PostCommitActionProto::decode(data)?;
match decoded.kind {
Some(PostCommitActionKind::SendWelcomes(proto)) => Ok(Self::SendWelcomes(
SendWelcomesAction::new(proto.installation_ids, proto.welcome_message),
)),
None => Err(IntentError::Generic(
"missing post commit action".to_string(),
)),
}
}

pub(crate) fn from_welcome(
welcome: MlsMessageOut,
installation_ids: Vec<Vec<u8>>,
) -> Result<Self, IntentError> {
let welcome_bytes = welcome.tls_serialize_detached()?;

Ok(Self::SendWelcomes(SendWelcomesAction::new(
installation_ids,
welcome_bytes,
)))
}
}

impl From<Vec<u8>> for PostCommitAction {
fn from(data: Vec<u8>) -> Self {
PostCommitAction::from_bytes(data.as_slice()).unwrap()
}
}

#[cfg(test)]
mod tests {
use xmtp_cryptography::utils::generate_local_wallet;

use super::*;
use crate::{builder::ClientBuilder, InboxOwner};

#[test]
fn test_serialize_send_message() {
let message = vec![1, 2, 3];
let intent = SendMessageIntentData::new(message.clone());
let as_bytes: Vec<u8> = intent.into();
let restored_intent = SendMessageIntentData::from_bytes(as_bytes.as_slice()).unwrap();

assert_eq!(restored_intent.message, message);
}

#[tokio::test]
async fn test_serialize_add_members() {
let wallet = generate_local_wallet();
let wallet_address = wallet.get_address();
let client = ClientBuilder::new_test_client(wallet.into()).await;
let key_package = client
.identity
.new_key_package(&client.mls_provider())
.unwrap();
let verified_key_package = VerifiedKeyPackage::new(key_package, wallet_address.clone());

let intent = AddMembersIntentData::new(vec![verified_key_package.clone()]);
let as_bytes: Vec<u8> = intent.clone().try_into().unwrap();
let restored_intent =
AddMembersIntentData::from_bytes(as_bytes.as_slice(), &client.mls_provider()).unwrap();

assert!(intent.key_packages[0]
.inner
.eq(&restored_intent.key_packages[0].inner));
assert_eq!(
intent.key_packages[0].wallet_address,
restored_intent.key_packages[0].wallet_address
);
}
}
Loading

0 comments on commit f33fd97

Please sign in to comment.